From e1daa8cba4ece2c9449d2fdbe5b195a58fe19caf Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Fri, 28 Feb 2025 14:03:33 +0100 Subject: [PATCH] final touches --- tests/test_against_fpm.py | 2 ++ tests/test_distributed_pm.py | 2 -- tests/test_gradients.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 5ef5211..5ebcbc2 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -95,6 +95,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -140,6 +141,7 @@ def test_nbody_relative(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index 6e94f56..69c37ed 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -258,7 +258,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, def test_fwd_rev_gradients(cosmo, pdims): mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) - # SINGLE DEVICE RUN cosmo._workspace = {} mesh = jax.make_mesh(pdims, ('x', 'y')) @@ -328,7 +327,6 @@ def test_fwd_rev_gradients(cosmo, pdims): def test_vmap(cosmo, pdims): mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) - # SINGLE DEVICE RUN cosmo._workspace = {} mesh = jax.make_mesh(pdims, ('x', 'y')) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index bb48920..1f611aa 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, particles, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: @@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, a=lpt_scale_factor, order=order) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5() @@ -66,6 +66,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, t1=1.0, dt0=None, y0=y0, + args=cosmo, adjoint=adjoint, stepsize_controller=controller, saveat=saveat)