From b132a0e2aa182cd9ff1ac695b9ed2d01cefc3cd6 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 21 Dec 2024 23:14:45 +0100 Subject: [PATCH] update jaxdecomp version and test gradients --- pyproject.toml | 2 +- tests/test_gradients.py | 115 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 tests/test_gradients.py diff --git a/pyproject.toml b/pyproject.toml index a41096d..a204633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ readme = "README.md" requires-python = ">=3.9" license = { file = "LICENSE" } urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" } -dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"] +dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"] [tool.setuptools] packages = ["jaxpm"] diff --git a/tests/test_gradients.py b/tests/test_gradients.py new file mode 100644 index 0000000..1ac10b5 --- /dev/null +++ b/tests/test_gradients.py @@ -0,0 +1,115 @@ +import pytest +from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from helpers import MSE +from jax import numpy as jnp + +from jaxpm.distributed import uniform_particles +from jaxpm.painting import cic_paint, cic_paint_dx +from jaxpm.pm import lpt, make_diffrax_ode +import jax + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_grad_relative(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + + mesh_shape, _ = simulation_config + cosmo._workspace = {} + + @jax.jit + @jax.grad + def forward_model(initial_conditions, cosmo): + + # Initial displacement + dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + + solver = Dopri5() + controller = PIDController(rtol=1e-7, + atol=1e-7, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint_dx(solutions.ys[-1, 0]) + + return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + + bad_initial_conditions = initial_conditions + jax.random.normal(jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 + best_ic = forward_model(initial_conditions , cosmo) + bad_ic = forward_model(bad_initial_conditions, cosmo) + + assert jnp.max(best_ic) < 1e-5 + assert jnp.max(bad_ic) > 1e-5 + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_grad_absolute(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + + mesh_shape, _ = simulation_config + cosmo._workspace = {} + + @jax.jit + @jax.grad + def forward_model(initial_conditions, cosmo): + + # Initial displacement + particles = uniform_particles(mesh_shape) + dx, p, _ = lpt(cosmo, initial_conditions,particles, a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=True)) + + solver = Dopri5() + controller = PIDController(rtol=1e-7, + atol=1e-7, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([particles + dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) + + return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + + bad_initial_conditions = initial_conditions + jax.random.normal(jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 + best_ic = forward_model(initial_conditions , cosmo) + bad_ic = forward_model(bad_initial_conditions, cosmo) + + assert jnp.max(best_ic) < 1e-5 + assert jnp.max(bad_ic) > 1e-5 + +