mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
add tests
This commit is contained in:
parent
36ef18e3d0
commit
f70583b5fd
4 changed files with 483 additions and 0 deletions
180
tests/conftest.py
Normal file
180
tests/conftest.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Parameterized fixture for mesh_shape
|
||||
import os
|
||||
import pytest
|
||||
os.environ["EQX_ON_ERROR"] = "nan"
|
||||
setup_done = False
|
||||
on_cluster = False
|
||||
|
||||
|
||||
def is_on_cluster():
|
||||
global on_cluster
|
||||
return on_cluster
|
||||
|
||||
def initialize_distributed():
|
||||
global setup_done
|
||||
global on_cluster
|
||||
if not setup_done:
|
||||
if "SLURM_JOB_ID" in os.environ:
|
||||
on_cluster = True
|
||||
print("Running on cluster")
|
||||
import jax
|
||||
jax.distributed.initialize()
|
||||
setup_done = True
|
||||
on_cluster = True
|
||||
else:
|
||||
print("Running locally")
|
||||
setup_done = True
|
||||
on_cluster = False
|
||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
import jax
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_teardown_session():
|
||||
# Code to run at the start of the session
|
||||
print("Starting session...")
|
||||
initialize_distributed()
|
||||
# Setup code here
|
||||
# e.g., connecting to a database, initializing some resources, etc.
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
((64, 64, 64) , (512., 512., 512.)), # BOX
|
||||
((64, 64, 128) , (256. , 256. , 512.)), # RECTANGULAR
|
||||
])
|
||||
def simulation_config(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=[0.1 , 0.5 , 0.8])
|
||||
def lpt_scale_factor(request):
|
||||
return request.param
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def cosmo():
|
||||
from jax_cosmo import Cosmology
|
||||
from functools import partial
|
||||
Planck18 = partial(
|
||||
Cosmology,
|
||||
# Omega_m = 0.3111
|
||||
Omega_c=0.2607,
|
||||
Omega_b=0.0490,
|
||||
Omega_k=0.0,
|
||||
h=0.6766,
|
||||
n_s=0.9665,
|
||||
sigma8=0.8102,
|
||||
w0=-1.0,
|
||||
wa=0.0,
|
||||
)
|
||||
|
||||
return Planck18()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def particle_mesh(simulation_config):
|
||||
from pmesh.pm import ParticleMesh
|
||||
mesh_shape, box_shape = simulation_config
|
||||
return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4')
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_initial_conditions(cosmo, particle_mesh):
|
||||
from jax import numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
|
||||
# Generate initial particle positions
|
||||
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
|
||||
np.float32)
|
||||
# Interpolate with linear_matter spectrum to get initial density field
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
|
||||
def pk_fn(x):
|
||||
return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
||||
|
||||
whitec = particle_mesh.generate_whitenoise(42,
|
||||
type='complex',
|
||||
unitary=False)
|
||||
lineark = whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5
|
||||
* v * (1 / v.BoxSize).prod()**0.5)
|
||||
init_mesh = lineark.c2r().value # XXX
|
||||
|
||||
return lineark, grid, init_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def initial_conditions(fpm_initial_conditions):
|
||||
_, _, init_mesh = fpm_initial_conditions
|
||||
return init_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def solver(cosmo, particle_mesh):
|
||||
from fastpm.core import Solver, Cosmology as FastPMCosmology
|
||||
ref_cosmo = FastPMCosmology(cosmo)
|
||||
return Solver(particle_mesh, ref_cosmo, B=1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor):
|
||||
|
||||
lineark, grid, _ = fpm_initial_conditions
|
||||
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1)
|
||||
return statelpt
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt1_field(fpm_lpt1, particle_mesh):
|
||||
return particle_mesh.paint(fpm_lpt1.X).value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor):
|
||||
|
||||
lineark, grid, _ = fpm_initial_conditions
|
||||
statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2)
|
||||
return statelpt
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def fpm_lpt2_field(fpm_lpt2, particle_mesh):
|
||||
return particle_mesh.paint(fpm_lpt2.X).value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
||||
from fastpm.core import leapfrog
|
||||
import numpy as np
|
||||
|
||||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
||||
return fpm_mesh
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
||||
from fastpm.core import leapfrog
|
||||
import numpy as np
|
||||
|
||||
if lpt_scale_factor == 0.8:
|
||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||
|
||||
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
|
||||
|
||||
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
|
||||
fpm_mesh = particle_mesh.paint(finalstate.X).value
|
||||
|
||||
return fpm_mesh
|
10
tests/helpers.py
Normal file
10
tests/helpers.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import jax.numpy as jnp
|
||||
|
||||
def MSE(x , y):
|
||||
return jnp.mean((x - y)**2)
|
||||
|
||||
def MSE_3D(x , y):
|
||||
return ((x - y)**2).mean(axis=0)
|
||||
|
||||
def MSRE(x , y):
|
||||
return jnp.mean(((x - y)/ y)**2)
|
149
tests/test_against_fpm.py
Normal file
149
tests/test_against_fpm.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
import pytest
|
||||
|
||||
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from helpers import MSE , MSRE
|
||||
from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Dopri5
|
||||
from jaxpm.utils import power_spectrum
|
||||
|
||||
_TOLERANCE = 1e-4
|
||||
_PM_TOLERANCE = 1e-3
|
||||
|
||||
@pytest.mark.parametrize("order", [1 , 2])
|
||||
def test_lpt_absoulute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field,fpm_lpt2_field, cosmo , order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
particles = uniform_particles(mesh_shape)
|
||||
|
||||
# Initial displacement
|
||||
dx, _, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
|
||||
|
||||
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
|
||||
_ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape)
|
||||
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field,fpm_lpt2_field, cosmo , order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
# Initial displacement
|
||||
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
lpt_field = cic_paint_dx(dx)
|
||||
|
||||
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
|
||||
|
||||
_ , jpm_ps = power_spectrum(lpt_field ,box_shape=box_shape)
|
||||
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
nbody_from_lpt1,nbody_from_lpt2, cosmo , order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
particles = uniform_particles(mesh_shape)
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions,particles , a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1 , 0])
|
||||
|
||||
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
|
||||
|
||||
_ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape)
|
||||
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
nbody_from_lpt1,nbody_from_lpt2, cosmo , order):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions , a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
atol=1e-9,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=lpt_scale_factor,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
final_field = cic_paint_dx(solutions.ys[-1 , 0])
|
||||
|
||||
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
|
||||
|
||||
_ , jpm_ps = power_spectrum(final_field ,box_shape=box_shape)
|
||||
_ , fpm_ps = power_spectrum(fpm_ref_field ,box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
||||
|
||||
|
144
tests/test_distributed_pm.py
Normal file
144
tests/test_distributed_pm.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
from conftest import initialize_distributed
|
||||
initialize_distributed() # ignore : E402
|
||||
|
||||
import jax # noqa : E402
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
import jax.numpy as jnp # noqa : E402
|
||||
import pytest # noqa : E402
|
||||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||
from jaxpm.distributed import uniform_particles # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||
from diffrax import ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve # noqa : E402
|
||||
from jax.sharding import PartitionSpec as P, NamedSharding # noqa : E402
|
||||
from helpers import MSE # noqa : E402
|
||||
from jax import lax # noqa : E402
|
||||
|
||||
_TOLERANCE = 3.0 # 🙃🙃
|
||||
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||
absolute_painting):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0])
|
||||
else:
|
||||
single_device_final_field = cic_paint_dx(solutions.ys[-1, 0])
|
||||
|
||||
print("Done with single device run")
|
||||
# MULTI DEVICE RUN
|
||||
|
||||
mesh = jax.make_mesh((1, 8), ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding)
|
||||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape, sharding=sharding)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
y0 = jnp.stack([particles + dx, p])
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jnp.stack([dx, p])
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
t0=0.1,
|
||||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
if absolute_painting:
|
||||
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
else:
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
multi_device_final_field = process_allgather(multi_device_final_field , tiled=True)
|
||||
|
||||
mse = MSE(single_device_final_field, multi_device_final_field)
|
||||
print(f"MSE is {mse}")
|
||||
|
||||
assert mse < _TOLERANCE
|
Loading…
Add table
Reference in a new issue