2024-12-20 11:44:02 +01:00
|
|
|
import pytest
|
|
|
|
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
|
|
|
from helpers import MSE, MSRE
|
|
|
|
from jax import numpy as jnp
|
|
|
|
|
2025-01-20 22:40:28 +01:00
|
|
|
from jaxdecomp import ShardedArray
|
2024-12-20 11:44:02 +01:00
|
|
|
from jaxpm.distributed import uniform_particles
|
|
|
|
from jaxpm.painting import cic_paint, cic_paint_dx
|
|
|
|
from jaxpm.pm import lpt, make_diffrax_ode
|
|
|
|
from jaxpm.utils import power_spectrum
|
2025-01-20 22:40:28 +01:00
|
|
|
import jax
|
2024-12-20 11:44:02 +01:00
|
|
|
# Upper bound used by the LPT tests for both the field `MSE` and the
# power-spectrum `MSRE` comparisons against the reference fields.
_TOLERANCE = 1e-4
# Looser upper bound used by the N-body tests, which integrate an ODE from the
# LPT scale factor to a=1 before comparing against the reference fields.
_PM_TOLERANCE = 1e-3
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
                      fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
                      shardedArrayAPI):
    """Compare an LPT field painted from absolute positions to a reference.

    Runs `lpt` at the requested `order` on a uniform particle grid, paints
    `particles + dx` with `cic_paint`, and checks both the painted field
    (via `MSE`) and its power spectrum (via `MSRE`) against the matching
    reference fixture within `_TOLERANCE`.

    When `shardedArrayAPI` is true the particles and initial conditions are
    wrapped in `ShardedArray` first, and the test additionally checks that
    the displacement and the painted field come back as `ShardedArray`.
    """
    mesh_shape, box_shape = simulation_config
    # Start from an empty workspace on the (possibly shared) cosmology fixture.
    # NOTE(review): presumably this drops cached cosmology tables - confirm.
    cosmo._workspace = {}
    particles = uniform_particles(mesh_shape)

    if shardedArrayAPI:
        particles = ShardedArray(particles)
        initial_conditions = ShardedArray(initial_conditions)

    # Initial displacement
    dx, _, _ = lpt(cosmo,
                   initial_conditions,
                   particles,
                   a=lpt_scale_factor,
                   order=order)

    fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field

    lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
    # Unpack the single pytree leaf so the comparisons below receive the same
    # kind of object in both API modes; the one-element unpacking raises if
    # the painted field does not flatten to exactly one leaf.
    lpt_field_arr, = jax.tree.leaves(lpt_field)
    _, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE

    if shardedArrayAPI:
        # The sharded wrapper must survive both the LPT step and the painting.
        assert isinstance(dx, ShardedArray)
        assert isinstance(lpt_field, ShardedArray)
|
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
|
|
|
|
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
                      fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
                      shardedArrayAPI):
    """Compare an LPT field painted from displacements only to a reference.

    Runs `lpt` at the requested `order` without passing a particle grid,
    paints the returned displacement with `cic_paint_dx`, and checks both the
    painted field (via `MSE`) and its power spectrum (via `MSRE`) against the
    matching reference fixture within `_TOLERANCE`.

    When `shardedArrayAPI` is true the initial conditions are wrapped in
    `ShardedArray` first, and the test additionally checks that the
    displacement and the painted field come back as `ShardedArray`.
    """
    mesh_shape, box_shape = simulation_config
    # Start from an empty workspace on the (possibly shared) cosmology fixture.
    # NOTE(review): presumably this drops cached cosmology tables - confirm.
    cosmo._workspace = {}

    if shardedArrayAPI:
        initial_conditions = ShardedArray(initial_conditions)

    # Initial displacement
    dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)

    lpt_field = cic_paint_dx(dx)

    fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field

    # Unpack the single pytree leaf so the comparisons below receive the same
    # kind of object in both API modes; the one-element unpacking raises if
    # the painted field does not flatten to exactly one leaf.
    lpt_field_arr, = jax.tree.leaves(lpt_field)
    _, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE

    if shardedArrayAPI:
        # The sharded wrapper must survive both the LPT step and the painting.
        assert isinstance(dx, ShardedArray)
        assert isinstance(lpt_field, ShardedArray)
|
2024-12-20 11:44:02 +01:00
|
|
|
|
|
|
|
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_absolute(simulation_config, initial_conditions,
                        lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
                        cosmo, order, shardedArrayAPI):
    """Compare an N-body run on absolute positions to a reference field.

    Initialises positions and momenta with `lpt` at the requested `order`,
    integrates the ODE built by `make_diffrax_ode(mesh_shape)` with `Dopri5`
    from `lpt_scale_factor` to 1.0, paints the final positions with
    `cic_paint`, and checks both the painted field (via `MSE`) and its power
    spectrum (via `MSRE`) against the matching reference fixture within
    `_PM_TOLERANCE`.

    When `shardedArrayAPI` is true the particles and initial conditions are
    wrapped in `ShardedArray` first, and the test additionally checks that
    the displacement, the final positions and the painted field are
    `ShardedArray` instances.
    """
    mesh_shape, box_shape = simulation_config
    # Start from an empty workspace on the (possibly shared) cosmology fixture.
    # NOTE(review): presumably this drops cached cosmology tables - confirm.
    cosmo._workspace = {}
    particles = uniform_particles(mesh_shape)

    if shardedArrayAPI:
        particles = ShardedArray(particles)
        initial_conditions = ShardedArray(initial_conditions)

    # Initial displacement
    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   particles,
                   a=lpt_scale_factor,
                   order=order)

    ode_fn = ODETerm(make_diffrax_ode(mesh_shape))

    solver = Dopri5()
    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    # Only the state at t1 is saved, so `solutions.ys` has one time entry.
    saveat = SaveAt(t1=True)

    # Stack [absolute positions, momenta] leaf-wise so the same expression
    # handles plain arrays and pytree-wrapped inputs.
    y0 = jax.tree.map(lambda pos, disp, mom: jnp.stack([pos + disp, mom]),
                      particles, dx, p)

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=lpt_scale_factor,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            args=cosmo,
                            stepsize_controller=controller,
                            saveat=saveat)

    # [-1, 0]: last saved time, first entry of the stacked state (positions).
    final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0])

    fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2

    # Unpack the single pytree leaf so the comparisons below receive the same
    # kind of object in both API modes; the one-element unpacking raises if
    # the painted field does not flatten to exactly one leaf.
    final_field_arr, = jax.tree.leaves(final_field)
    _, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

    if shardedArrayAPI:
        # The sharded wrapper must survive LPT, the ODE solve and the painting.
        assert isinstance(dx, ShardedArray)
        assert isinstance(solutions.ys[-1, 0], ShardedArray)
        assert isinstance(final_field, ShardedArray)
|
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
|
|
|
|
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_relative(simulation_config, initial_conditions,
                        lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
                        cosmo, order, shardedArrayAPI):
    """Compare an N-body run on displacements only to a reference field.

    Initialises displacements and momenta with `lpt` at the requested `order`
    (no particle grid is passed), integrates the ODE built by
    `make_diffrax_ode(mesh_shape, paint_absolute_pos=False)` with `Dopri5`
    from `lpt_scale_factor` to 1.0, paints the final displacement with
    `cic_paint_dx`, and checks both the painted field (via `MSE`) and its
    power spectrum (via `MSRE`) against the matching reference fixture within
    `_PM_TOLERANCE`.

    When `shardedArrayAPI` is true the initial conditions are wrapped in
    `ShardedArray` first, and the test additionally checks that the
    displacement, the final state and the painted field are `ShardedArray`
    instances.
    """
    mesh_shape, box_shape = simulation_config
    # Start from an empty workspace on the (possibly shared) cosmology fixture.
    # NOTE(review): presumably this drops cached cosmology tables - confirm.
    cosmo._workspace = {}

    if shardedArrayAPI:
        initial_conditions = ShardedArray(initial_conditions)

    # Initial displacement
    dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)

    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False))

    solver = Dopri5()
    controller = PIDController(rtol=1e-9,
                               atol=1e-9,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)

    # Only the state at t1 is saved, so `solutions.ys` has one time entry.
    saveat = SaveAt(t1=True)

    # Stack [displacements, momenta] leaf-wise so the same expression handles
    # plain arrays and pytree-wrapped inputs.
    y0 = jax.tree.map(lambda disp, mom: jnp.stack([disp, mom]), dx, p)

    solutions = diffeqsolve(ode_fn,
                            solver,
                            t0=lpt_scale_factor,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            args=cosmo,
                            stepsize_controller=controller,
                            saveat=saveat)

    # [-1, 0]: last saved time, first entry of the stacked state
    # (displacements).
    final_field = cic_paint_dx(solutions.ys[-1, 0])

    fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2

    # Unpack the single pytree leaf so the comparisons below receive the same
    # kind of object in both API modes; the one-element unpacking raises if
    # the painted field does not flatten to exactly one leaf.
    final_field_arr, = jax.tree.leaves(final_field)
    _, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape)
    _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)

    assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE
    assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE

    if shardedArrayAPI:
        # The sharded wrapper must survive LPT, the ODE solve and the painting.
        assert isinstance(dx, ShardedArray)
        assert isinstance(solutions.ys[-1, 0], ShardedArray)
        assert isinstance(final_field, ShardedArray)
|