from conftest import compare_sharding, initialize_distributed

# NOTE(review): this runs before the jax imports below, which is why every
# later import carries an E402 suppression. Presumably the distributed
# runtime has to be initialised first -- confirm against conftest.
initialize_distributed()

from functools import partial  # noqa: E402

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402
import jax_cosmo as jc  # noqa: E402
import pytest  # noqa: E402
from diffrax import SaveAt  # noqa: E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve  # noqa: E402
from helpers import MSE  # noqa: E402
from jax import lax  # noqa: E402
from jax.experimental.multihost_utils import process_allgather  # noqa: E402
from jax.sharding import NamedSharding  # noqa: E402
from jax.sharding import PartitionSpec as P  # noqa: E402
from jaxdecomp import ShardedArray  # noqa: E402
from jaxpm.distributed import fft3d, uniform_particles  # noqa: E402
from jaxpm.painting import cic_paint, cic_paint_dx  # noqa: E402
from jaxpm.pm import pm_forces  # noqa: E402
from jaxpm.pm import lpt, make_diffrax_ode  # noqa: E402

# Upper bound asserted on the MSE between the single-device and the
# multi-device final fields in test_distrubted_pm.
_TOLERANCE = 3.0  # 🙃🙃


def _pm_final_field(cosmo, ic, particles, mesh_shape, order, **dist_kwargs):
    """Run LPT, integrate the diffrax PM ODE and paint the final field.

    Shared by the single-device and multi-device halves of
    ``test_distrubted_pm`` so both go through exactly the same pipeline.

    Args:
        cosmo: cosmology object forwarded to ``lpt`` and, as ``args``, to
            ``diffeqsolve``.
        ic: initial conditions (plain array or ``ShardedArray``).
        particles: absolute particle positions, or ``None`` to work with
            displacements only (relative painting).
        mesh_shape: shape of the mesh the field is painted on.
        order: LPT order forwarded to ``lpt``.
        **dist_kwargs: extra keyword arguments (``halo_size`` and
            ``sharding`` in the multi-device run, nothing in the
            single-device run) forwarded unchanged to ``lpt``,
            ``make_diffrax_ode`` and the painting function.

    Returns:
        The field painted from the positions/displacements saved at ``t1``.
    """
    absolute_painting = particles is not None

    # Initial displacement and the matching ODE right-hand side / state.
    if absolute_painting:
        dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order,
                       **dist_kwargs)
        ode_fn = ODETerm(make_diffrax_ode(mesh_shape, **dist_kwargs))
        y0 = jax.tree.map(
            lambda particles, dx, p: jnp.stack([particles + dx, p]),
            particles, dx, p)
    else:
        dx, p, _ = lpt(cosmo, ic, a=0.1, order=order, **dist_kwargs)
        ode_fn = ODETerm(
            make_diffrax_ode(mesh_shape,
                             paint_absolute_pos=False,
                             **dist_kwargs))
        y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)

    controller = PIDController(rtol=1e-8,
                               atol=1e-8,
                               pcoeff=0.4,
                               icoeff=1,
                               dcoeff=0)
    solutions = diffeqsolve(ode_fn,
                            Dopri5(),
                            t0=0.1,
                            t1=1.0,
                            dt0=None,
                            y0=y0,
                            args=cosmo,
                            stepsize_controller=controller,
                            saveat=SaveAt(t1=True))

    # Last saved step, first stacked component (positions or displacements).
    final_state = solutions.ys[-1, 0]
    if absolute_painting:
        return cic_paint(jnp.zeros(shape=mesh_shape), final_state,
                         **dist_kwargs)
    return cic_paint_dx(final_state, **dist_kwargs)


@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
                       absolute_painting, shardedArrayAPI):
    """Compare a single-device PM run against a run sharded on a (1, 8) mesh.

    Both runs use the same initial conditions; the test asserts that the MSE
    between the two final fields is below ``_TOLERANCE`` and, when the
    ``ShardedArray`` API is used, that the distributed result is a
    ``ShardedArray`` carrying the requested sharding.
    """
    mesh_shape, _ = simulation_config

    # SINGLE DEVICE RUN
    cosmo._workspace = {}
    ic = ShardedArray(initial_conditions) if shardedArrayAPI \
        else initial_conditions

    particles = None
    if absolute_painting:
        particles = uniform_particles(mesh_shape)
        if shardedArrayAPI:
            particles = ShardedArray(particles)

    single_device_final_field = _pm_final_field(cosmo, ic, particles,
                                                mesh_shape, order)
    print("Done with single device run")

    # MULTI DEVICE RUN
    mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    halo_size = mesh_shape[0] // 2

    ic = lax.with_sharding_constraint(initial_conditions, sharding)
    print(f"sharded initial conditions {ic.sharding}")
    if shardedArrayAPI:
        ic = ShardedArray(ic, sharding)

    cosmo._workspace = {}
    particles = None
    if absolute_painting:
        particles = uniform_particles(mesh_shape, sharding=sharding)
        if shardedArrayAPI:
            particles = ShardedArray(particles, sharding)

    multi_device_final_field = _pm_final_field(cosmo,
                                               ic,
                                               particles,
                                               mesh_shape,
                                               order,
                                               halo_size=halo_size,
                                               sharding=sharding)

    # Gather the distributed field so it can be compared leaf-to-leaf.
    multi_device_final_field_g = process_allgather(multi_device_final_field,
                                                   tiled=True)
    single_device_final_field_arr, = jax.tree.leaves(single_device_final_field)
    multi_device_final_field_arr, = jax.tree.leaves(multi_device_final_field_g)
    mse = MSE(single_device_final_field_arr, multi_device_final_field_arr)
    print(f"MSE is {mse}")

    if shardedArrayAPI:
        assert isinstance(multi_device_final_field, ShardedArray)
        assert compare_sharding(multi_device_final_field.sharding, sharding)
        assert compare_sharding(multi_device_final_field.initial_sharding,
                                sharding)

    assert mse < _TOLERANCE


@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
                              order, nbody_from_lpt1, nbody_from_lpt2,
                              absolute_painting):
    """Check that gradients of the sharded PM pipeline keep the input sharding.

    The model is the MSE between the final field of a distributed
    LPT + diffrax run and a reference field (``nbody_from_lpt1`` or
    ``nbody_from_lpt2`` depending on ``order``). Its gradient with respect to
    the initial conditions is evaluated at the original and at randomly
    shifted initial conditions; both gradients must carry the sharding of
    the initial conditions.
    """
    mesh_shape, _ = simulation_config

    # MULTI DEVICE SETUP (this test has no single-device counterpart)
    mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    halo_size = mesh_shape[0] // 2
    dist_kwargs = dict(halo_size=halo_size, sharding=sharding)

    initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                      sharding)
    print(f"sharded initial conditions {initial_conditions.sharding}")
    initial_conditions = ShardedArray(initial_conditions, sharding)

    cosmo._workspace = {}

    @jax.jit
    def forward_model(initial_conditions, cosmo):
        """Run LPT + the diffrax PM solve and paint the final field."""
        if absolute_painting:
            particles = uniform_particles(mesh_shape, sharding=sharding)
            particles = ShardedArray(particles, sharding)
            # Initial displacement
            dx, p, _ = lpt(cosmo,
                           initial_conditions,
                           particles,
                           a=0.1,
                           order=order,
                           **dist_kwargs)
            ode_fn = ODETerm(make_diffrax_ode(mesh_shape, **dist_kwargs))
            y0 = jax.tree.map(
                lambda particles, dx, p: jnp.stack([particles + dx, p]),
                particles, dx, p)
        else:
            dx, p, _ = lpt(cosmo,
                           initial_conditions,
                           a=0.1,
                           order=order,
                           **dist_kwargs)
            ode_fn = ODETerm(
                make_diffrax_ode(mesh_shape,
                                 paint_absolute_pos=False,
                                 **dist_kwargs))
            y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)

        controller = PIDController(rtol=1e-8,
                                   atol=1e-8,
                                   pcoeff=0.4,
                                   icoeff=1,
                                   dcoeff=0)
        solutions = diffeqsolve(ode_fn,
                                Dopri5(),
                                t0=0.1,
                                t1=1.0,
                                dt0=None,
                                y0=y0,
                                args=cosmo,
                                stepsize_controller=controller,
                                saveat=SaveAt(t1=True))

        final_state = solutions.ys[-1, 0]
        if absolute_painting:
            return cic_paint(jnp.zeros(shape=mesh_shape), final_state,
                             **dist_kwargs)
        return cic_paint_dx(final_state, **dist_kwargs)

    @jax.jit
    def model(initial_conditions, cosmo):
        """Scalar loss: MSE between the final field and the reference."""
        final_field = forward_model(initial_conditions, cosmo)
        final_field, = jax.tree.leaves(final_field)
        return MSE(final_field,
                   nbody_from_lpt1 if order == 1 else nbody_from_lpt2)

    # Plain forward evaluation; its value is not checked, the call only
    # exercises the un-differentiated path.
    model(initial_conditions, cosmo)

    shifted_initial_conditions = initial_conditions + jax.random.normal(
        jax.random.key(42), initial_conditions.shape) * 5

    grad_fn = jax.grad(model)
    good_grads = grad_fn(initial_conditions, cosmo)
    off_grads = grad_fn(shifted_initial_conditions, cosmo)

    assert compare_sharding(good_grads.sharding, initial_conditions.sharding)
    assert compare_sharding(off_grads.sharding, initial_conditions.sharding)


@pytest.mark.distributed
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_fwd_rev_gradients(cosmo, absolute_painting):
    """Check sharding of PM forces and of their forward/reverse Jacobians.

    Random initial conditions on an 8x8x8 mesh are sharded over a (1, 8)
    device mesh. The first component of ``pm_forces`` is computed, then
    differentiated with ``jax.jacrev`` and ``jax.jacfwd`` with respect to the
    initial conditions; the forces and both Jacobians must carry the sharding
    of the initial conditions.
    """
    mesh_shape = (8, 8, 8)

    # MULTI DEVICE SETUP (this test has no single-device counterpart)
    mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))
    halo_size = mesh_shape[0] // 2

    initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
    initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                      sharding)
    print(f"sharded initial conditions {initial_conditions.sharding}")
    initial_conditions = ShardedArray(initial_conditions, sharding)

    cosmo._workspace = {}

    @partial(jax.jit, static_argnums=(3, 4, 5))
    def compute_forces(initial_conditions,
                       cosmo,
                       particles=None,
                       a=0.5,
                       halo_size=0,
                       sharding=None):
        """Return the first component of the PM force for the given field."""
        # Absolute painting is used exactly when particle positions are given;
        # otherwise zero displacements of shape (*ic.shape, 3) are used.
        paint_absolute_pos = particles is not None
        if particles is None:
            particles = jax.tree.map(
                lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
                initial_conditions)

        a = jnp.atleast_1d(a)
        # NOTE(review): E is computed but never used below -- confirm whether
        # the force was meant to be scaled by it.
        E = jnp.sqrt(jc.background.Esqr(cosmo, a))
        delta_k = fft3d(initial_conditions)
        initial_force = pm_forces(particles,
                                  delta=delta_k,
                                  paint_absolute_pos=paint_absolute_pos,
                                  halo_size=halo_size,
                                  sharding=sharding)

        return initial_force[..., 0]

    particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding),
                             sharding) if absolute_painting else None

    # Same keyword arguments for the primal call and both Jacobians.
    force_kwargs = dict(particles=particles,
                        halo_size=halo_size,
                        sharding=sharding)

    forces = compute_forces(initial_conditions, cosmo, **force_kwargs)
    back_gradient = jax.jacrev(compute_forces)(initial_conditions, cosmo,
                                               **force_kwargs)
    fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions, cosmo,
                                              **force_kwargs)

    assert compare_sharding(forces.sharding, initial_conditions.sharding)
    assert compare_sharding(back_gradient[0, 0, 0, ...].sharding,
                            initial_conditions.sharding)
    assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding)