format tests

This commit is contained in:
Wassim Kabalan 2024-12-08 22:46:20 +01:00
parent af29c4005d
commit 97f39bd051
3 changed files with 94 additions and 76 deletions

4
pytest.ini Normal file
View file

@ -0,0 +1,4 @@
[pytest]
markers =
distributed: mark a test as distributed
single_device: mark a test as single_device

View file

@ -1,21 +1,21 @@
import pytest import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.distributed import uniform_particles
from helpers import MSE , MSRE
from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Dopri5
from jaxpm.utils import power_spectrum from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
@pytest.mark.parametrize("order", [1 , 2])
def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor, @pytest.mark.single_device
fpm_lpt1_field,fpm_lpt2_field, cosmo , order): @pytest.mark.parametrize("order", [1, 2])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -31,17 +31,18 @@ def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx) lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
_ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field,fpm_lpt2_field, cosmo , order): fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
# Initial displacement # Initial displacement
@ -51,26 +52,31 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
_ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, def test_nbody_absolute(simulation_config, initial_conditions,
nbody_from_lpt1,nbody_from_lpt2, cosmo , order): lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
particles = uniform_particles(mesh_shape) particles = uniform_particles(mesh_shape)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions,particles , a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
make_diffrax_ode(cosmo, mesh_shape))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -84,33 +90,36 @@ def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor,
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
t0=lpt_scale_factor, t0=lpt_scale_factor,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1 , 0]) final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape) _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_nbody_relative(simulation_config, initial_conditions,
nbody_from_lpt1,nbody_from_lpt2, cosmo , order): lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions , a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
@ -127,23 +136,20 @@ def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor,
y0 = jnp.stack([dx, p]) y0 = jnp.stack([dx, p])
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
t0=lpt_scale_factor, t0=lpt_scale_factor,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
final_field = cic_paint_dx(solutions.ys[-1 , 0]) final_field = cic_paint_dx(solutions.ys[-1, 0])
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape) _, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

View file

@ -1,20 +1,26 @@
from conftest import initialize_distributed from conftest import initialize_distributed
initialize_distributed() # ignore : E402
import jax # noqa : E402 initialize_distributed() # ignore : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
import jax.numpy as jnp # noqa : E402
import pytest # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve # noqa : E402
from jax.sharding import PartitionSpec as P, NamedSharding # noqa : E402
from helpers import MSE # noqa : E402
from jax import lax # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import pytest # noqa : E402
from diffrax import (Dopri5, ODETerm, PIDController, SaveAt, # noqa : E402
diffeqsolve)
from helpers import MSE # noqa : E402
from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
@ -27,10 +33,10 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
particles = uniform_particles(mesh_shape) particles = uniform_particles(mesh_shape)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, initial_conditions,
particles, particles,
a=0.1, a=0.1,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jnp.stack([particles + dx, p])
else: else:
@ -41,10 +47,10 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
atol=1e-8, atol=1e-8,
pcoeff=0.4, pcoeff=0.4,
icoeff=1, icoeff=1,
dcoeff=0) dcoeff=0)
saveat = SaveAt(t1=True) saveat = SaveAt(t1=True)
@ -59,7 +65,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if absolute_painting: if absolute_painting:
single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0]) solutions.ys[-1, 0])
else: else:
single_device_final_field = cic_paint_dx(solutions.ys[-1, 0]) single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])
@ -70,7 +76,8 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding) initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}") print(f"sharded initial conditions {initial_conditions.sharding}")
@ -128,17 +135,18 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if absolute_painting: if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0], solutions.ys[-1, 0],
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
else: else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
multi_device_final_field = process_allgather(multi_device_final_field , tiled=True) multi_device_final_field = process_allgather(multi_device_final_field,
tiled=True)
mse = MSE(single_device_final_field, multi_device_final_field) mse = MSE(single_device_final_field, multi_device_final_field)
print(f"MSE is {mse}") print(f"MSE is {mse}")
assert mse < _TOLERANCE assert mse < _TOLERANCE