format tests

This commit is contained in:
Wassim Kabalan 2024-12-08 22:46:20 +01:00
parent af29c4005d
commit 97f39bd051
3 changed files with 94 additions and 76 deletions

4
pytest.ini Normal file
View file

@ -0,0 +1,4 @@
[pytest]
markers =
distributed: mark a test as distributed
single_device: mark a test as single_device

View file

@ -1,20 +1,20 @@
import pytest import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.distributed import uniform_particles
from helpers import MSE , MSRE
from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Dopri5
from jaxpm.utils import power_spectrum from jaxpm.utils import power_spectrum
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order): fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
@ -38,6 +38,7 @@ def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor,
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order): fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
@ -58,19 +59,24 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, def test_nbody_absolute(simulation_config, initial_conditions,
nbody_from_lpt1,nbody_from_lpt2, cosmo , order): lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
particles = uniform_particles(mesh_shape) particles = uniform_particles(mesh_shape)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions,particles , a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
make_diffrax_ode(cosmo, mesh_shape))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -102,9 +108,12 @@ def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor,
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_nbody_relative(simulation_config, initial_conditions,
nbody_from_lpt1,nbody_from_lpt2, cosmo , order): lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -144,6 +153,3 @@ def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor,
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

View file

@ -1,20 +1,26 @@
from conftest import initialize_distributed from conftest import initialize_distributed
initialize_distributed() # ignore : E402 initialize_distributed() # ignore : E402
import jax # noqa : E402 import jax # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
import jax.numpy as jnp # noqa : E402 import jax.numpy as jnp # noqa : E402
import pytest # noqa : E402 import pytest # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402 from diffrax import (Dopri5, ODETerm, PIDController, SaveAt, # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402 diffeqsolve)
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve # noqa : E402
from jax.sharding import PartitionSpec as P, NamedSharding # noqa : E402
from helpers import MSE # noqa : E402 from helpers import MSE # noqa : E402
from jax import lax # noqa : E402 from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 _TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
@ -70,7 +76,8 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding) initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}") print(f"sharded initial conditions {initial_conditions.sharding}")
@ -136,7 +143,8 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
multi_device_final_field = process_allgather(multi_device_final_field , tiled=True) multi_device_final_field = process_allgather(multi_device_final_field,
tiled=True)
mse = MSE(single_device_final_field, multi_device_final_field) mse = MSE(single_device_final_field, multi_device_final_field)
print(f"MSE is {mse}") print(f"MSE is {mse}")