diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..9f020be --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + distributed: mark a test as distributed + single_device: mark a test as single_device diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 60d95ca..6d17939 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -1,21 +1,21 @@ import pytest - - +from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from helpers import MSE, MSRE from jax import numpy as jnp +from jaxpm.distributed import uniform_particles from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.pm import lpt, make_diffrax_ode -from jaxpm.distributed import uniform_particles -from helpers import MSE , MSRE -from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Dopri5 from jaxpm.utils import power_spectrum _TOLERANCE = 1e-4 _PM_TOLERANCE = 1e-3 -@pytest.mark.parametrize("order", [1 , 2]) -def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor, - fpm_lpt1_field,fpm_lpt2_field, cosmo , order): + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, + fpm_lpt1_field, fpm_lpt2_field, cosmo, order): mesh_shape, box_shape = simulation_config cosmo._workspace = {} @@ -31,17 +31,18 @@ def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor, fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx) - _ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape) - _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE +@pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, - fpm_lpt1_field,fpm_lpt2_field, cosmo , order): - + fpm_lpt1_field, fpm_lpt2_field, cosmo, order): + mesh_shape, box_shape = simulation_config cosmo._workspace = {} # Initial displacement @@ -51,26 +52,31 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field - _ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape) - _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE +@pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) -def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, - nbody_from_lpt1,nbody_from_lpt2, cosmo , order): - +def test_nbody_absolute(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + mesh_shape, box_shape = simulation_config cosmo._workspace = {} particles = uniform_particles(mesh_shape) # Initial displacement - dx, p, _ = lpt(cosmo, initial_conditions,particles , a=lpt_scale_factor, order=order) + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=lpt_scale_factor, + order=order) - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -84,33 +90,36 @@ def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, y0 = jnp.stack([particles + dx, p]) solutions = diffeqsolve(ode_fn, - solver, - t0=lpt_scale_factor, - t1=1.0, - dt0=None, - y0=y0, - stepsize_controller=controller, - saveat=saveat) + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) - final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1 , 0]) + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 - _ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape) - _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE + +@pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) -def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, - nbody_from_lpt1,nbody_from_lpt2, cosmo , order): +def test_nbody_relative(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): mesh_shape, box_shape = simulation_config cosmo._workspace = {} # Initial displacement - dx, p, _ = lpt(cosmo, initial_conditions , a=lpt_scale_factor, order=order) + dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) ode_fn = ODETerm( make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) @@ -127,23 +136,20 @@ def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, y0 = jnp.stack([dx, p]) solutions = diffeqsolve(ode_fn, - solver, - t0=lpt_scale_factor, - t1=1.0, - dt0=None, - y0=y0, - stepsize_controller=controller, - saveat=saveat) + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) - final_field = cic_paint_dx(solutions.ys[-1 , 0]) + final_field = cic_paint_dx(solutions.ys[-1, 0]) fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 - _ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape) - _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE - - - diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index 39b48b9..cec0bce 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -1,20 +1,26 @@ from conftest import initialize_distributed -initialize_distributed() # ignore : E402 -import jax # noqa : E402 -from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 -import jax.numpy as jnp # noqa : E402 -import pytest # noqa : E402 -from jax.experimental.multihost_utils import process_allgather # noqa : E402 -from jaxpm.distributed import uniform_particles # noqa : E402 -from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 -from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve # noqa : E402 -from jax.sharding import PartitionSpec as P, NamedSharding # noqa : E402 -from helpers import MSE # noqa : E402 -from jax import lax # noqa : E402 +initialize_distributed() # ignore : E402 -_TOLERANCE = 3.0 # 🙃🙃 +import jax # noqa : E402 +import jax.numpy as jnp # noqa : E402 +import pytest # noqa : E402 +from diffrax import (Dopri5, ODETerm, PIDController, SaveAt, # noqa : E402 + diffeqsolve) +from helpers import MSE # noqa : E402 +from jax import lax # noqa : E402 +from jax.experimental.multihost_utils import process_allgather # noqa : E402 +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P # noqa : E402 +from jaxpm.distributed import uniform_particles # noqa : E402 +from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 +from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 + +_TOLERANCE = 3.0 # 🙃🙃 + + +@pytest.mark.distributed @pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("absolute_painting", [True, False]) def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, @@ -27,10 +33,10 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, particles = uniform_particles(mesh_shape) # Initial displacement dx, p, _ = lpt(cosmo, - initial_conditions, - particles, - a=0.1, - order=order) + initial_conditions, + particles, + a=0.1, + order=order) ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: @@ -41,10 +47,10 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, solver = Dopri5() controller = PIDController(rtol=1e-8, - atol=1e-8, - pcoeff=0.4, - icoeff=1, - dcoeff=0) + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) saveat = SaveAt(t1=True) @@ -59,7 +65,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, if absolute_painting: single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), - solutions.ys[-1, 0]) + solutions.ys[-1, 0]) else: single_device_final_field = cic_paint_dx(solutions.ys[-1, 0]) @@ -70,7 +76,8 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding = NamedSharding(mesh, P('x', 'y')) halo_size = mesh_shape[0] // 2 - initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding) + initial_conditions = lax.with_sharding_constraint(initial_conditions, + sharding) print(f"sharded initial conditions {initial_conditions.sharding}") @@ -128,17 +135,18 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, if absolute_painting: multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), - solutions.ys[-1, 0], - halo_size=halo_size, - sharding=sharding) + solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) else: multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], - halo_size=halo_size, - sharding=sharding) + halo_size=halo_size, + sharding=sharding) - multi_device_final_field = process_allgather(multi_device_final_field , tiled=True) + multi_device_final_field = process_allgather(multi_device_final_field, + tiled=True) mse = MSE(single_device_final_field, multi_device_final_field) print(f"MSE is {mse}") - assert mse < _TOLERANCE \ No newline at end of file + assert mse < _TOLERANCE