update tests and add test for FWD REV gradient

This commit is contained in:
Wassim Kabalan 2025-01-20 22:40:28 +01:00
parent 151fa09247
commit 20fe25c324
5 changed files with 434 additions and 41 deletions

View file

@ -173,3 +173,19 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
fpm_mesh = particle_mesh.paint(finalstate.X).value fpm_mesh = particle_mesh.paint(finalstate.X).value
return fpm_mesh return fpm_mesh
def compare_sharding(sharding1, sharding2):
    """Return True when two shardings describe the same process grid.

    Each sharding is reduced to its "pdims": for every entry of
    ``sharding.spec``, the number of devices that dimension is split over,
    looked up in ``sharding.mesh.shape``. Both tuples are right-padded with
    ones to length 3 before being compared, so trailing unsharded dimensions
    do not affect the result.

    Args:
        sharding1: Object exposing ``spec`` (a sequence whose entries are an
            axis name, a tuple/list of axis names, or ``None``) and
            ``mesh.shape`` (a mapping from axis name to axis size).
        sharding2: Same contract as ``sharding1``.

    Returns:
        bool: ``True`` if the padded pdims tuples are equal.

    Raises:
        KeyError: If a spec entry names an axis missing from ``mesh.shape``.
    """

    def get_axis_size(sharding, axis_name):
        # ``None`` marks a dimension that is not split over any mesh axis.
        if axis_name is None:
            return 1
        # A spec entry may group several mesh axes, e.g. P(('x', 'y')); the
        # dimension is then split over the product of their sizes. Looking the
        # tuple itself up in ``mesh.shape`` would fail.
        if isinstance(axis_name, (tuple, list)):
            size = 1
            for name in axis_name:
                size *= sharding.mesh.shape[name]
            return size
        return sharding.mesh.shape[axis_name]

    def get_pdims_from_sharding(sharding):
        pdims = tuple(get_axis_size(sharding, name) for name in sharding.spec)
        # Pad to three entries; a non-positive repeat count adds nothing, so
        # specs that already have three or more entries are left untouched.
        return pdims + (1,) * (3 - len(pdims))

    return get_pdims_from_sharding(sharding1) == get_pdims_from_sharding(sharding2)

View file

@ -2,7 +2,7 @@ import jax.numpy as jnp
def MSE(x, y): def MSE(x, y):
return jnp.mean((x - y)**2) return ((x - y)**2).mean()
def MSE_3D(x, y): def MSE_3D(x, y):
@ -10,4 +10,4 @@ def MSE_3D(x, y):
def MSRE(x, y): def MSRE(x, y):
return jnp.mean(((x - y) / y)**2) return (((x - y) / y)**2).mean()

View file

@ -3,24 +3,30 @@ from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from jaxdecomp import ShardedArray
from jaxpm.distributed import uniform_particles from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum from jaxpm.utils import power_spectrum
import jax
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
@pytest.mark.single_device @pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order): fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
particles = uniform_particles(mesh_shape) particles = uniform_particles(mesh_shape)
if shardedArrayAPI:
particles = ShardedArray(particles)
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement # Initial displacement
dx, _, _ = lpt(cosmo, dx, _, _ = lpt(cosmo,
initial_conditions, initial_conditions,
@ -31,44 +37,61 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx) lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) lpt_field_arr, = jax.tree.leaves(lpt_field)
_, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray
@pytest.mark.single_device @pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order): fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
if shardedArrayAPI:
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement # Initial displacement
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
lpt_field = cic_paint_dx(dx) lpt_field = cic_paint_dx(dx)
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
lpt_field_arr, = jax.tree.leaves(lpt_field)
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) _, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray
@pytest.mark.single_device @pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_absolute(simulation_config, initial_conditions, def test_nbody_absolute(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order): cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
particles = uniform_particles(mesh_shape) particles = uniform_particles(mesh_shape)
if shardedArrayAPI:
particles = ShardedArray(particles)
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, initial_conditions,
@ -76,7 +99,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor, a=lpt_scale_factor,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -87,7 +110,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
saveat = SaveAt(t1=True) saveat = SaveAt(t1=True)
y0 = jnp.stack([particles + dx, p]) y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p)
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
@ -95,6 +118,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -102,27 +126,37 @@ def test_nbody_absolute(simulation_config, initial_conditions,
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape) final_field_arr, = jax.tree.leaves(final_field)
_, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray
@pytest.mark.single_device @pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_relative(simulation_config, initial_conditions, def test_nbody_relative(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order): cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
if shardedArrayAPI:
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-9, controller = PIDController(rtol=1e-9,
@ -133,7 +167,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
saveat = SaveAt(t1=True) saveat = SaveAt(t1=True)
y0 = jnp.stack([dx, p]) y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p)
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
@ -141,6 +175,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -148,8 +183,14 @@ def test_nbody_relative(simulation_config, initial_conditions,
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape) final_field_arr, = jax.tree.leaves(final_field)
_, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray

View file

@ -1,4 +1,4 @@
from conftest import initialize_distributed from conftest import initialize_distributed , compare_sharding
initialize_distributed() # ignore : E402 initialize_distributed() # ignore : E402
@ -12,38 +12,48 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402 from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402 from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxdecomp import ShardedArray # noqa : E402
from functools import partial # noqa : E402
import jax_cosmo as jc # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 _TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed @pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting): absolute_painting,shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN # SINGLE DEVICE RUN
cosmo._workspace = {} cosmo._workspace = {}
if shardedArrayAPI:
ic = ShardedArray(initial_conditions)
else:
ic = initial_conditions
if absolute_painting: if absolute_painting:
particles = uniform_particles(mesh_shape) particles = uniform_particles(mesh_shape)
if shardedArrayAPI:
particles = ShardedArray(particles)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, ic,
particles, particles,
a=0.1, a=0.1,
order=order) order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jnp.stack([particles + dx, p]) y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
else: else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p]) y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -59,6 +69,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t0=0.1, t0=0.1,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
args=cosmo,
y0=y0, y0=y0,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -76,17 +87,22 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions, ic = lax.with_sharding_constraint(initial_conditions,
sharding) sharding)
print(f"sharded initial conditions {initial_conditions.sharding}") print(f"sharded initial conditions {ic.sharding}")
if shardedArrayAPI:
ic = ShardedArray(ic , sharding)
cosmo._workspace = {} cosmo._workspace = {}
if absolute_painting: if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding) particles = uniform_particles(mesh_shape, sharding=sharding)
if shardedArrayAPI:
particles = ShardedArray(particles, sharding)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, ic,
particles, particles,
a=0.1, a=0.1,
order=order, order=order,
@ -94,26 +110,26 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(
mesh_shape, mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jnp.stack([particles + dx, p]) y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
else: else:
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, ic,
a=0.1, a=0.1,
order=order, order=order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(cosmo, make_diffrax_ode(
mesh_shape, mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jnp.stack([dx, p]) y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -130,6 +146,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0, t1=1.0,
dt0=None, dt0=None,
y0=y0, y0=y0,
args=cosmo,
stepsize_controller=controller, stepsize_controller=controller,
saveat=saveat) saveat=saveat)
@ -143,10 +160,182 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
multi_device_final_field = process_allgather(multi_device_final_field, multi_device_final_field_g = process_allgather(multi_device_final_field,
tiled=True) tiled=True)
mse = MSE(single_device_final_field, multi_device_final_field) single_device_final_field_arr, = jax.tree.leaves(single_device_final_field)
multi_device_final_field_arr, = jax.tree.leaves(multi_device_final_field_g)
mse = MSE(single_device_final_field_arr, multi_device_final_field_arr)
print(f"MSE is {mse}") print(f"MSE is {mse}")
if shardedArrayAPI:
assert type(multi_device_final_field) == ShardedArray
assert compare_sharding(multi_device_final_field.sharding , sharding)
assert compare_sharding(multi_device_final_field.initial_sharding , sharding)
assert mse < _TOLERANCE assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
                              order, nbody_from_lpt1, nbody_from_lpt2,
                              absolute_painting):
    """Differentiate a sharded N-body loss and check the gradient sharding.

    Builds a jitted LPT + diffrax forward model on a (1, 8) device mesh, wraps
    it in an MSE loss against the reference field selected by ``order``, and
    asserts that ``jax.grad`` of that loss, taken at the given initial
    conditions and at a noise-shifted copy, is sharded like the input.
    """
    mesh_shape, box_shape = simulation_config
    cosmo._workspace = {}

    device_mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(device_mesh, P('x', 'y'))
    halo_size = mesh_shape[0] // 2
    # Distribution options shared by every sharded jaxpm call below.
    dist_kwargs = dict(halo_size=halo_size, sharding=sharding)

    initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                      sharding)
    print(f"sharded initial conditions {initial_conditions.sharding}")
    initial_conditions = ShardedArray(initial_conditions, sharding)

    cosmo._workspace = {}

    # The comparison target depends only on the parametrized LPT order.
    reference_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2

    @jax.jit
    def forward_model(initial_conditions, cosmo):
        # LPT initial state -> diffrax integration -> painted final field.
        if absolute_painting:
            grid = ShardedArray(
                uniform_particles(mesh_shape, sharding=sharding), sharding)
            dx, p, _ = lpt(cosmo,
                           initial_conditions,
                           grid,
                           a=0.1,
                           order=order,
                           **dist_kwargs)
            ode_fn = ODETerm(make_diffrax_ode(mesh_shape, **dist_kwargs))
            # State is [positions, momenta]; positions are grid + displacement.
            y0 = jax.tree.map(
                lambda pos, disp, mom: jnp.stack([pos + disp, mom]),
                grid, dx, p)
        else:
            dx, p, _ = lpt(cosmo,
                           initial_conditions,
                           a=0.1,
                           order=order,
                           **dist_kwargs)
            ode_fn = ODETerm(
                make_diffrax_ode(mesh_shape,
                                 paint_absolute_pos=False,
                                 **dist_kwargs))
            # State is [displacements, momenta].
            y0 = jax.tree.map(lambda disp, mom: jnp.stack([disp, mom]), dx, p)

        step_controller = PIDController(rtol=1e-8,
                                        atol=1e-8,
                                        pcoeff=0.4,
                                        icoeff=1,
                                        dcoeff=0)
        solutions = diffeqsolve(ode_fn,
                                Dopri5(),
                                t0=0.1,
                                t1=1.0,
                                dt0=None,
                                y0=y0,
                                args=cosmo,
                                stepsize_controller=step_controller,
                                saveat=SaveAt(t1=True))

        # Last saved state, first stacked component (the one that is painted).
        final_state = solutions.ys[-1, 0]
        if absolute_painting:
            return cic_paint(jnp.zeros(shape=mesh_shape), final_state,
                             **dist_kwargs)
        return cic_paint_dx(final_state, **dist_kwargs)

    @jax.jit
    def model(initial_conditions, cosmo):
        # Scalar loss: MSE between the simulated and the reference field.
        field, = jax.tree.leaves(forward_model(initial_conditions, cosmo))
        return MSE(field, reference_field)

    # The value is not asserted on; the call runs the loss once before grads.
    obs_val = model(initial_conditions, cosmo)

    noise = jax.random.normal(jax.random.key(42), initial_conditions.shape)
    shifted_initial_conditions = initial_conditions + noise * 5

    loss_grad = jax.grad(model)
    good_grads = loss_grad(initial_conditions, cosmo)
    off_grads = loss_grad(shifted_initial_conditions, cosmo)

    for grads in (good_grads, off_grads):
        assert compare_sharding(grads.sharding, initial_conditions.sharding)
@pytest.mark.distributed
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_fwd_rev_gradients(cosmo, absolute_painting):
    """Check the sharding of PM forces and of their fwd/rev-mode Jacobians.

    Random initial conditions on an 8^3 mesh are sharded over a (1, 8) device
    mesh. The first component of ``pm_forces`` is evaluated, then
    differentiated with ``jax.jacrev`` and ``jax.jacfwd``; each result is
    asserted to be sharded like the initial conditions.
    """
    mesh_shape = (8, 8, 8)
    box_shape = (20.0, 20.0, 20.0)  # not used further in this test
    cosmo._workspace = {}

    device_mesh = jax.make_mesh((1, 8), ('x', 'y'))
    sharding = NamedSharding(device_mesh, P('x', 'y'))
    halo_size = mesh_shape[0] // 2

    initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
    initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                      sharding)
    print(f"sharded initial conditions {initial_conditions.sharding}")
    initial_conditions = ShardedArray(initial_conditions, sharding)

    cosmo._workspace = {}

    @partial(jax.jit, static_argnums=(3, 4, 5))
    def compute_forces(initial_conditions,
                       cosmo,
                       particles=None,
                       a=0.5,
                       halo_size=0,
                       sharding=None):
        # Absolute painting is chosen purely by whether positions were given.
        paint_absolute_pos = particles is not None
        if not paint_absolute_pos:
            # No positions supplied: use zeros shaped like the field plus a
            # trailing axis of length 3.
            particles = jax.tree.map(
                lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
                initial_conditions)

        a = jnp.atleast_1d(a)
        # NOTE(review): E is computed but never used below - confirm intent.
        E = jnp.sqrt(jc.background.Esqr(cosmo, a))
        delta_k = fft3d(initial_conditions)
        initial_force = pm_forces(particles,
                                  delta=delta_k,
                                  paint_absolute_pos=paint_absolute_pos,
                                  halo_size=halo_size,
                                  sharding=sharding)

        # Keep only the first component along the last axis.
        return initial_force[..., 0]

    if absolute_painting:
        particles = ShardedArray(
            uniform_particles(mesh_shape, sharding=sharding), sharding)
    else:
        particles = None

    # Same keyword arguments for the plain call and both Jacobians.
    call_kwargs = dict(particles=particles,
                       halo_size=halo_size,
                       sharding=sharding)

    forces = compute_forces(initial_conditions, cosmo, **call_kwargs)
    back_gradient = jax.jacrev(compute_forces)(initial_conditions, cosmo,
                                               **call_kwargs)
    fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions, cosmo,
                                              **call_kwargs)

    expected = initial_conditions.sharding
    assert compare_sharding(forces.sharding, expected)
    assert compare_sharding(back_gradient[0, 0, 0, ...].sharding, expected)
    assert compare_sharding(fwd_gradient.sharding, expected)

147
tests/test_sharded_array.py Normal file
View file

@ -0,0 +1,147 @@
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import os
os.environ["EQX_ON_ERROR"] = "nan"
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.debug import visualize_array_sharding
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
from jaxpm.distributed import uniform_particles
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
# Gather helper: process_allgather with tiled=False pre-bound.
all_gather = partial(process_allgather, tiled=False)
# Process-grid dimensions; only referenced by the commented-out mesh setup below.
pdims = (2, 4)
#devices = create_device_mesh(pdims)
#mesh = Mesh(devices, axis_names=('x', 'y'))
#sharding = NamedSharding(mesh, P('x', 'y'))
# Sharding is disabled: this None is what the rest of the module passes as `sharding`.
sharding = None
from typing import NamedTuple
from jaxdecomp import ShardedArray
# Scalar settings; mesh_shape and box_size are rebound to 3-tuples further down.
mesh_shape = 64
box_size = 64.
halo_size = 2
# Passed to forward_model, which uses it as the SaveAt(ts=...) save times.
snapshots = (0.5, 1.0)
class Params(NamedTuple):
    """Inputs handed to ``forward_model`` as a single tuple-like value."""

    # Passed as Omega_c to jc.Planck15 inside forward_model.
    omega_c: float
    # Passed as sigma8 to jc.Planck15 inside forward_model.
    sigma8: float
    # Field given to lpt(); its tree structure is also reused for the particles.
    initial_conditions: jnp.ndarray
# Expand the scalar sizes into per-axis 3-tuples (rebinds both names in place).
mesh_shape = (mesh_shape,) * 3
box_size = (box_size,) * 3
omega_c = 0.25
sigma8 = 0.8
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
    jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
# The lambda reads the module-level k, pk and sharding when it is called.
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
# Initial field drawn with a fixed PRNG key (0).
initial_conditions = linear_field(mesh_shape,
                                  box_size,
                                  pk_fn,
                                  seed=jax.random.PRNGKey(0),
                                  sharding=sharding)
#initial_conditions = ShardedArray(initial_conditions, sharding)
# Bundle the cosmology scalars and the field for forward_model.
params = Params(omega_c, sigma8, initial_conditions)
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def forward_model(params, mesh_shape, box_size, halo_size, snapshots):
    """Run 2nd-order LPT, integrate with diffrax, and paint the last snapshot.

    Reads the module-level ``sharding`` for every distributed call.

    Args:
        params: Object exposing ``omega_c``, ``sigma8`` and
            ``initial_conditions``.
        mesh_shape: Static. Used for the particle grid, the ODE term and the
            shape of the painted mesh.
        box_size: Static. Not used by the body; kept so existing callers that
            pass it keep working.
        halo_size: Static. Forwarded to ``lpt`` and ``make_diffrax_ode``.
        snapshots: Static. Save times handed to ``SaveAt(ts=...)``.

    Returns:
        tuple: ``(particles + dx, ode_field)``, i.e. the LPT positions and the
        field painted with ``cic_paint`` from the last saved ODE state.
    """
    # Create the cosmology from the sampled parameters
    cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
    particles = uniform_particles(mesh_shape, sharding)
    # Re-wrap the particle leaves in the same tree structure as the initial
    # conditions so both inputs to lpt() share one pytree layout.
    ic_structure = jax.tree.structure(params.initial_conditions)
    particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles))

    # Initial displacement (third return value of lpt is not needed here)
    dx, p, _ = lpt(cosmo,
                   params.initial_conditions,
                   particles,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape,
                         paint_absolute_pos=True,
                         halo_size=halo_size,
                         sharding=sharding))
    solver = LeapfrogMidpoint()
    # State is [positions, momenta] stacked along a new leading axis.
    y0 = jax.tree.map(
        lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0),
        particles, dx, p)
    # Under jit this runs while the function is traced, not on every call.
    print(f"y0 structure: {jax.tree.structure(y0)}")

    stepsize_controller = ConstantStepSize()
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=y0,
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    # First stacked component (positions) of every saved state.
    ode_solutions = [sol[0] for sol in res.ys]

    ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1])
    # The statements that used to follow this return (a cic_paint_dx variant
    # returning dx) could never execute and were removed as dead code.
    return particles + dx, ode_field
# Module-level statement: the forward model runs once when this file is imported.
# NOTE(review): the file is named like a pytest module but no test functions
# are visible here; presumably collection would execute this and the plot - confirm.
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots)
import matplotlib.pyplot as plt
# Paint the LPT positions returned by forward_model onto an empty float32 mesh.
lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles)
#lpt_field = cic_paint_dx(lpt_particles)
# Two-panel figure: each field summed along axis 0, LPT on the left, ODE on the right.
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(lpt_field.sum(axis=0) , cmap='magma')
plt.colorbar()
plt.title('LPT field')
plt.subplot(122)
plt.imshow(ode_field.sum(axis=0) , cmap='magma')
plt.colorbar()
plt.title('ODE field')
plt.show()
plt.close()
#particles = jax.random.uniform(jax.random.PRNGKey(0), (4 , 4 ,4 , 3), minval=0.1, maxval=0.9)
#field = jax.random.uniform(jax.random.PRNGKey(0), (4, 4, 4))
#
#particles = ShardedArray(particles, sharding)
#field = ShardedArray(field, sharding)
#
#
#cic_read_dx(field , particles )