From f70583b5fd87af5d14f4af58aa8b59532cfffa22 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Fri, 6 Dec 2024 18:56:57 +0100 Subject: [PATCH] add tests --- tests/conftest.py | 180 +++++++++++++++++++++++++++++++++++ tests/helpers.py | 10 ++ tests/test_against_fpm.py | 149 +++++++++++++++++++++++++++++ tests/test_distributed_pm.py | 144 ++++++++++++++++++++++++++++ 4 files changed, 483 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/helpers.py create mode 100644 tests/test_against_fpm.py create mode 100644 tests/test_distributed_pm.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..94c3fb2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,180 @@ +# Parameterized fixture for mesh_shape +import os +import pytest +os.environ["EQX_ON_ERROR"] = "nan" +setup_done = False +on_cluster = False + + +def is_on_cluster(): + global on_cluster + return on_cluster + +def initialize_distributed(): + global setup_done + global on_cluster + if not setup_done: + if "SLURM_JOB_ID" in os.environ: + on_cluster = True + print("Running on cluster") + import jax + jax.distributed.initialize() + setup_done = True + on_cluster = True + else: + print("Running locally") + setup_done = True + on_cluster = False + os.environ["JAX_PLATFORM_NAME"] = "cpu" + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_teardown_session(): + # Code to run at the start of the session + print("Starting session...") + initialize_distributed() + # Setup code here + # e.g., connecting to a database, initializing some resources, etc. + + + +@pytest.fixture( + scope="session", + params=[ + ((64, 64, 64) , (512., 512., 512.)), # BOX + ((64, 64, 128) , (256. , 256. , 512.)), # RECTANGULAR + ]) +def simulation_config(request): + return request.param + + +@pytest.fixture(scope="session", params=[0.1 , 0.5 , 0.8]) +def lpt_scale_factor(request): + return request.param + + + +@pytest.fixture(scope="session") +def cosmo(): + from jax_cosmo import Cosmology + from functools import partial + Planck18 = partial( + Cosmology, + # Omega_m = 0.3111 + Omega_c=0.2607, + Omega_b=0.0490, + Omega_k=0.0, + h=0.6766, + n_s=0.9665, + sigma8=0.8102, + w0=-1.0, + wa=0.0, + ) + + return Planck18() + + +@pytest.fixture(scope="session") +def particle_mesh(simulation_config): + from pmesh.pm import ParticleMesh + mesh_shape, box_shape = simulation_config + return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4') + + +@pytest.fixture(scope="session") +def fpm_initial_conditions(cosmo, particle_mesh): + from jax import numpy as jnp + import jax_cosmo as jc + import numpy as np + + # Generate initial particle positions + grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype( + np.float32) + # Interpolate with linear_matter spectrum to get initial density field + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power(cosmo, k) + + def pk_fn(x): + return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape) + + whitec = particle_mesh.generate_whitenoise(42, + type='complex', + unitary=False) + lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5 + * v * (1 / v.BoxSize).prod()**0.5) + init_mesh = lineark.c2r().value # XXX + + return lineark, grid, init_mesh + + +@pytest.fixture(scope="session") +def initial_conditions(fpm_initial_conditions): + _, _, init_mesh = fpm_initial_conditions + return init_mesh + + +@pytest.fixture(scope="session") +def solver(cosmo, particle_mesh): + from fastpm.core import Solver, Cosmology as FastPMCosmology + ref_cosmo = FastPMCosmology(cosmo) + return Solver(particle_mesh, ref_cosmo, B=1) + + +@pytest.fixture(scope="session") +def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor): + + lineark, grid, _ = fpm_initial_conditions + statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1) + return statelpt + + +@pytest.fixture(scope="session") +def fpm_lpt1_field(fpm_lpt1, particle_mesh): + return particle_mesh.paint(fpm_lpt1.X).value + + +@pytest.fixture(scope="session") +def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor): + + lineark, grid, _ = fpm_initial_conditions + statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2) + return statelpt + + +@pytest.fixture(scope="session") +def fpm_lpt2_field(fpm_lpt2, particle_mesh): + return particle_mesh.paint(fpm_lpt2.X).value + + +@pytest.fixture(scope="session") +def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor): + from fastpm.core import leapfrog + import numpy as np + + if lpt_scale_factor == 0.8: + pytest.skip("Do not run nbody simulation from scale factor 0.8") + + stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True) + + finalstate = solver.nbody(fpm_lpt1, leapfrog(stages)) + fpm_mesh = particle_mesh.paint(finalstate.X).value + + return fpm_mesh + + +@pytest.fixture(scope="session") +def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor): + from fastpm.core import leapfrog + import numpy as np + + if lpt_scale_factor == 0.8: + pytest.skip("Do not run nbody simulation from scale factor 0.8") + + stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True) + + finalstate = solver.nbody(fpm_lpt2, leapfrog(stages)) + fpm_mesh = particle_mesh.paint(finalstate.X).value + + return fpm_mesh diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..3369a8c --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,10 @@ +import jax.numpy as jnp + +def MSE(x , y): + return jnp.mean((x - y)**2) + +def MSE_3D(x , y): + return ((x - y)**2).mean(axis=0) + +def MSRE(x , y): + return jnp.mean(((x - y)/ y)**2) \ No newline at end of file diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py new file mode 100644 index 0000000..60d95ca --- /dev/null +++ b/tests/test_against_fpm.py @@ -0,0 +1,149 @@ +import pytest + + +from jax import numpy as jnp + +from jaxpm.painting import cic_paint, cic_paint_dx +from jaxpm.pm import lpt, make_diffrax_ode +from jaxpm.distributed import uniform_particles +from helpers import MSE , MSRE +from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Dopri5 +from jaxpm.utils import power_spectrum + +_TOLERANCE = 1e-4 +_PM_TOLERANCE = 1e-3 + +@pytest.mark.parametrize("order", [1 , 2]) +def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor, + fpm_lpt1_field,fpm_lpt2_field, cosmo , order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + particles = uniform_particles(mesh_shape) + + # Initial displacement + dx, _, _ = lpt(cosmo, + initial_conditions, + particles, + a=lpt_scale_factor, + order=order) + + fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field + + lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx) + _ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape) + _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + + assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE + + +@pytest.mark.parametrize("order", [1, 2]) +def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, + fpm_lpt1_field,fpm_lpt2_field, cosmo , order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + # Initial displacement + dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) + + lpt_field = cic_paint_dx(dx) + + fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field + + _ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape) + _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + + assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE + + +@pytest.mark.parametrize("order", [1, 2]) +def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, + nbody_from_lpt1,nbody_from_lpt2, cosmo , order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + particles = uniform_particles(mesh_shape) + + # Initial displacement + dx, p, _ = lpt(cosmo, initial_conditions,particles , a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape)) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([particles + dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1 , 0]) + + fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 + + _ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape) + _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + + assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE + +@pytest.mark.parametrize("order", [1, 2]) +def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, + nbody_from_lpt1,nbody_from_lpt2, cosmo , order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + + # Initial displacement + dx, p, _ = lpt(cosmo, initial_conditions , a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + + solver = Dopri5() + controller = PIDController(rtol=1e-9, + atol=1e-9, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint_dx(solutions.ys[-1 , 0]) + + fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 + + _ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape) + _ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape) + + assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE + + + diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py new file mode 100644 index 0000000..39b48b9 --- /dev/null +++ b/tests/test_distributed_pm.py @@ -0,0 +1,144 @@ +from conftest import initialize_distributed +initialize_distributed() # ignore : E402 + +import jax # noqa : E402 +from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 +import jax.numpy as jnp # noqa : E402 +import pytest # noqa : E402 +from jax.experimental.multihost_utils import process_allgather # noqa : E402 +from jaxpm.distributed import uniform_particles # noqa : E402 +from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 +from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve # noqa : E402 +from jax.sharding import PartitionSpec as P, NamedSharding # noqa : E402 +from helpers import MSE # noqa : E402 +from jax import lax # noqa : E402 + +_TOLERANCE = 3.0 # 🙃🙃 + +@pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("absolute_painting", [True, False]) +def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, + absolute_painting): + + mesh_shape, box_shape = simulation_config + # SINGLE DEVICE RUN + cosmo._workspace = {} + if absolute_painting: + particles = uniform_particles(mesh_shape) + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=0.1, + order=order) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + y0 = jnp.stack([particles + dx, p]) + else: + dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + y0 = jnp.stack([dx, p]) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), + solutions.ys[-1, 0]) + else: + single_device_final_field = cic_paint_dx(solutions.ys[-1, 0]) + + print("Done with single device run") + # MULTI DEVICE RUN + + mesh = jax.make_mesh((1, 8), ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding) + + print(f"sharded initial conditions {initial_conditions.sharding}") + + cosmo._workspace = {} + if absolute_painting: + particles = uniform_particles(mesh_shape, sharding=sharding) + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, + mesh_shape, + halo_size=halo_size, + sharding=sharding)) + + y0 = jnp.stack([particles + dx, p]) + else: + dx, p, _ = lpt(cosmo, + initial_conditions, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, + mesh_shape, + paint_absolute_pos=False, + halo_size=halo_size, + sharding=sharding)) + y0 = jnp.stack([dx, p]) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), + solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + else: + multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + + multi_device_final_field = process_allgather(multi_device_final_field , tiled=True) + + mse = MSE(single_device_final_field, multi_device_final_field) + print(f"MSE is {mse}") + + assert mse < _TOLERANCE \ No newline at end of file