JaxPM/tests/test_gradients.py
Wassim KABALAN 6693e5c725
Some checks failed
Code Formatting / formatting (push) Failing after 4m30s
Tests / run_tests (3.10) (push) Failing after 1m41s
Tests / run_tests (3.11) (push) Failing after 1m42s
Tests / run_tests (3.12) (push) Failing after 1m15s
Fix sharding error (#37)
* Use cosmo as arg for the ODE function

* Update examples

* format

* notebook update

* fix tests

* add correct annotations for weights in painting and warning for cic_paint in distributed pm

* update test_against_fpm

* update distributed tests and add jacfwd jacrev and vmap tests

* format

* add Caveats to notebook readme

* final touches

* update Growth.py to allow using FastPM solver

* fix 2D painting when input is (X , Y , 2) shape

* update cic read halo size and notebooks examples

* Allow env variable control of caching in growth

* Format

* update test jax version

* update notebooks/03-MultiGPU_PM_Halo.ipynb

* update numpy install in wf

* update tolerance :)

* reorganize install in test workflow

* update tests

* add mpi4py

* update tests.yml

* update tests

* update wf

* format

* make normal_field signature consistent with jax.random.normal

* update by default normal_field dtype to match JAX

* format

* debug test workflow

* format

* debug test workflow

* updating tests

* fix accuracy

* fixed tolerance

* adding caching

* Update conftest.py

* Update tolerance and precision settings in distributed PM tests

* reverting changes to growth.py

---------

Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2025-06-28 23:07:31 +02:00

88 lines
3.1 KiB
Python

import jax
import pytest
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE
from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
                    nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
                    absolute_painting, adjoint):
    """Check the gradient of an N-body MSE loss w.r.t. the initial conditions.

    The loss runs LPT followed by a Dopri5 diffrax integration from
    ``lpt_scale_factor`` to ``a = 1.0``, paints the final state onto a mesh
    and compares it (via ``MSE``) with the reference field
    ``nbody_from_lpt1`` (``order == 1``) or ``nbody_from_lpt2`` (otherwise).

    The test asserts that the gradient magnitude is below ``1e-5`` when
    evaluated at ``initial_conditions`` and above ``1e-5`` when evaluated at a
    noise-perturbed copy of them.

    Parameters (pytest fixtures / parametrization):
        simulation_config: tuple whose first element is the mesh shape.
        initial_conditions: array the loss is differentiated against.
        lpt_scale_factor: scale factor used for LPT and as ODE start time.
        nbody_from_lpt1, nbody_from_lpt2: reference final fields.
        cosmo: cosmology object, passed to ``lpt`` and as ODE ``args``.
        order: LPT order (1 or 2).
        absolute_painting: paint absolute positions (``cic_paint``) when
            true, displacements (``cic_paint_dx``) otherwise.
        adjoint: ``'DTO'`` selects ``RecursiveCheckpointAdjoint``; ``'OTD'``
            is currently skipped.
    """
    mesh_shape, _ = simulation_config
    # Reset the workspace dict carried by the cosmology object.
    # NOTE(review): presumably this clears cached growth tables left by
    # earlier tests -- confirm against the cosmology implementation.
    cosmo._workspace = {}

    if adjoint == 'OTD':
        pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")

    # The BacksolveAdjoint branch is only reachable once the skip above is
    # lifted; it is kept so that enabling OTD is a one-line change.
    adjoint = RecursiveCheckpointAdjoint(
    ) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())

    # jax.grad differentiates w.r.t. the first positional argument, so the
    # compiled callable returns d(loss)/d(initial_conditions).
    @jax.jit
    @jax.grad
    def loss_gradient(initial_conditions, cosmo):
        # Initial displacement
        if absolute_painting:
            particles = uniform_particles(mesh_shape)
            dx, p, _ = lpt(cosmo,
                           initial_conditions,
                           particles,
                           a=lpt_scale_factor,
                           order=order)
            ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
            y0 = jnp.stack([particles + dx, p])
        else:
            dx, p, _ = lpt(cosmo,
                           initial_conditions,
                           a=lpt_scale_factor,
                           order=order)
            ode_fn = ODETerm(
                make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
            y0 = jnp.stack([dx, p])

        solver = Dopri5()
        controller = PIDController(rtol=1e-7,
                                   atol=1e-7,
                                   pcoeff=0.4,
                                   icoeff=1,
                                   dcoeff=0)
        # Only the state at t1 is saved.
        saveat = SaveAt(t1=True)

        solutions = diffeqsolve(ode_fn,
                                solver,
                                t0=lpt_scale_factor,
                                t1=1.0,
                                dt0=None,
                                y0=y0,
                                args=cosmo,
                                adjoint=adjoint,
                                stepsize_controller=controller,
                                saveat=saveat)

        # ys[-1, 0] is the first component of the stacked state (positions
        # or displacements, matching y0) at the last saved time.
        if absolute_painting:
            final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
        else:
            final_field = cic_paint_dx(solutions.ys[-1, 0])

        return MSE(final_field,
                   nbody_from_lpt1 if order == 1 else nbody_from_lpt2)

    bad_initial_conditions = initial_conditions + jax.random.normal(
        jax.random.PRNGKey(0), initial_conditions.shape) * 0.5

    grad_at_reference = loss_gradient(initial_conditions, cosmo)
    grad_at_perturbed = loss_gradient(bad_initial_conditions, cosmo)

    # Compare magnitudes: a signed max would let an arbitrarily large
    # negative gradient pass the "vanishing gradient" check.
    assert jnp.max(jnp.abs(grad_at_reference)) < 1e-5
    assert jnp.max(jnp.abs(grad_at_perturbed)) > 1e-5