From f8325b1c675950aebc29b5699dc17816f8aff7a9 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Fri, 28 Feb 2025 10:06:09 +0100 Subject: [PATCH] fix tests --- tests/test_distributed_pm.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index fd683ab..eb44456 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -37,12 +37,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, particles, a=0.1, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape, + paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5() @@ -94,8 +94,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, - mesh_shape, + make_diffrax_ode(mesh_shape, halo_size=halo_size, sharding=sharding)) @@ -108,8 +107,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, halo_size=halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, - mesh_shape, + make_diffrax_ode(mesh_shape, paint_absolute_pos=False, halo_size=halo_size, sharding=sharding))