update tests and add test for FWD REV gradient

This commit is contained in:
Wassim Kabalan 2025-01-20 22:40:28 +01:00
parent 151fa09247
commit 20fe25c324
5 changed files with 434 additions and 41 deletions

View file

@ -173,3 +173,19 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
fpm_mesh = particle_mesh.paint(finalstate.X).value
return fpm_mesh
def compare_sharding(sharding1, sharding2):
def get_axis_size(sharding, idx):
axis_name = sharding.spec[idx]
if axis_name is None:
return 1
else:
return sharding.mesh.shape[sharding.spec[idx]]
def get_pdims_from_sharding(sharding):
return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))])
pdims1 = get_pdims_from_sharding(sharding1)
pdims2 = get_pdims_from_sharding(sharding2)
pdims1 = pdims1 + (1,) * (3 - len(pdims1))
pdims2 = pdims2 + (1,) * (3 - len(pdims2))
return pdims1 == pdims2

View file

@ -2,7 +2,7 @@ import jax.numpy as jnp
def MSE(x, y):
return jnp.mean((x - y)**2)
return ((x - y)**2).mean()
def MSE_3D(x, y):
@ -10,4 +10,4 @@ def MSE_3D(x, y):
def MSRE(x, y):
return jnp.mean(((x - y) / y)**2)
return (((x - y) / y)**2).mean()

View file

@ -3,24 +3,30 @@ from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from jaxdecomp import ShardedArray
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum
import jax
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
particles = uniform_particles(mesh_shape)
if shardedArrayAPI:
particles = ShardedArray(particles)
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement
dx, _, _ = lpt(cosmo,
initial_conditions,
@ -31,44 +37,61 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx)
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
lpt_field_arr, = jax.tree.leaves(lpt_field)
_, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order):
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
if shardedArrayAPI:
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement
dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
lpt_field = cic_paint_dx(dx)
fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
lpt_field_arr, = jax.tree.leaves(lpt_field)
_, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_absolute(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
particles = uniform_particles(mesh_shape)
if shardedArrayAPI:
particles = ShardedArray(particles)
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
@ -76,7 +99,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
a=lpt_scale_factor,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -87,7 +110,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
saveat = SaveAt(t1=True)
y0 = jnp.stack([particles + dx, p])
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p)
solutions = diffeqsolve(ode_fn,
solver,
@ -95,6 +118,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -102,27 +126,37 @@ def test_nbody_absolute(simulation_config, initial_conditions,
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
final_field_arr, = jax.tree.leaves(final_field)
_, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_relative(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order):
cosmo, order , shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
if shardedArrayAPI:
initial_conditions = ShardedArray(initial_conditions)
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
@ -133,7 +167,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
saveat = SaveAt(t1=True)
y0 = jnp.stack([dx, p])
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p)
solutions = diffeqsolve(ode_fn,
solver,
@ -141,6 +175,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -148,8 +183,14 @@ def test_nbody_relative(simulation_config, initial_conditions,
fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
final_field_arr, = jax.tree.leaves(final_field)
_, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape)
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray

View file

@ -1,4 +1,4 @@
from conftest import initialize_distributed
from conftest import initialize_distributed , compare_sharding
initialize_distributed() # ignore : E402
@ -12,38 +12,48 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.distributed import uniform_particles # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxdecomp import ShardedArray # noqa : E402
from functools import partial # noqa : E402
import jax_cosmo as jc # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting):
absolute_painting,shardedArrayAPI):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
if shardedArrayAPI:
ic = ShardedArray(initial_conditions)
else:
ic = initial_conditions
if absolute_painting:
particles = uniform_particles(mesh_shape)
if shardedArrayAPI:
particles = ShardedArray(particles)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
ic,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
y0 = jnp.stack([particles + dx, p])
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
else:
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order)
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
y0 = jnp.stack([dx, p])
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -59,6 +69,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t0=0.1,
t1=1.0,
dt0=None,
args=cosmo,
y0=y0,
stepsize_controller=controller,
saveat=saveat)
@ -76,17 +87,22 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
ic = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
print(f"sharded initial conditions {ic.sharding}")
if shardedArrayAPI:
ic = ShardedArray(ic , sharding)
cosmo._workspace = {}
if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding)
if shardedArrayAPI:
particles = ShardedArray(particles, sharding)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
ic,
particles,
a=0.1,
order=order,
@ -94,26 +110,26 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
make_diffrax_ode(
mesh_shape,
halo_size=halo_size,
sharding=sharding))
y0 = jnp.stack([particles + dx, p])
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
ic,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(cosmo,
make_diffrax_ode(
mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jnp.stack([dx, p])
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -130,6 +146,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
@ -143,10 +160,182 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size,
sharding=sharding)
multi_device_final_field = process_allgather(multi_device_final_field,
multi_device_final_field_g = process_allgather(multi_device_final_field,
tiled=True)
mse = MSE(single_device_final_field, multi_device_final_field)
single_device_final_field_arr, = jax.tree.leaves(single_device_final_field)
multi_device_final_field_arr, = jax.tree.leaves(multi_device_final_field_g)
mse = MSE(single_device_final_field_arr, multi_device_final_field_arr)
print(f"MSE is {mse}")
if shardedArrayAPI:
assert type(multi_device_final_field) == ShardedArray
assert compare_sharding(multi_device_final_field.sharding , sharding)
assert compare_sharding(multi_device_final_field.initial_sharding , sharding)
assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2,
absolute_painting):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions , sharding)
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions , cosmo):
if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding)
particles = ShardedArray(particles, sharding)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(
mesh_shape,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
order=order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(
mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=1,
dcoeff=0)
saveat = SaveAt(t1=True)
solutions = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.0,
dt0=None,
y0=y0,
args=cosmo,
stepsize_controller=controller,
saveat=saveat)
if absolute_painting:
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
else:
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
halo_size=halo_size,
sharding=sharding)
return multi_device_final_field
@jax.jit
def model(initial_conditions , cosmo):
final_field = forward_model(initial_conditions , cosmo)
final_field, = jax.tree.leaves(final_field)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions , cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions , cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions , cosmo)
assert compare_sharding(good_grads.sharding , initial_conditions.sharding)
assert compare_sharding(off_grads.sharding , initial_conditions.sharding)
@pytest.mark.distributed
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_fwd_rev_gradients(cosmo,absolute_painting):
mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0)
# SINGLE DEVICE RUN
cosmo._workspace = {}
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions,
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions , sharding)
cosmo._workspace = {}
@partial(jax.jit , static_argnums=(3,4 , 5))
def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None):
paint_absolute_pos = particles is not None
if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
shape=(*ic.shape, 3)) , initial_conditions)
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = fft3d(initial_conditions)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
return initial_force[...,0]
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None
forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
assert compare_sharding(forces.sharding , initial_conditions.sharding)
assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding)
assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding)

147
tests/test_sharded_array.py Normal file
View file

@ -0,0 +1,147 @@
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import os
os.environ["EQX_ON_ERROR"] = "nan"
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.debug import visualize_array_sharding
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
from jaxpm.distributed import uniform_particles
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
all_gather = partial(process_allgather, tiled=False)
pdims = (2, 4)
#devices = create_device_mesh(pdims)
#mesh = Mesh(devices, axis_names=('x', 'y'))
#sharding = NamedSharding(mesh, P('x', 'y'))
sharding = None
from typing import NamedTuple
from jaxdecomp import ShardedArray
mesh_shape = 64
box_size = 64.
halo_size = 2
snapshots = (0.5, 1.0)
class Params(NamedTuple):
omega_c: float
sigma8: float
initial_conditions : jnp.ndarray
mesh_shape = (mesh_shape,) * 3
box_size = (box_size,) * 3
omega_c = 0.25
sigma8 = 0.8
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0),
sharding=sharding)
#initial_conditions = ShardedArray(initial_conditions, sharding)
params = Params(omega_c, sigma8, initial_conditions)
@partial(jax.jit , static_argnums=(1 , 2,3,4 ))
def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
# Create initial conditions
cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
particles = uniform_particles(mesh_shape , sharding)
ic_structure = jax.tree.structure(params.initial_conditions)
particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles))
# Initial displacement
dx, p, f = lpt(cosmo,
params.initial_conditions,
particles,
a=0.1,
order=2,
halo_size=halo_size,
sharding=sharding)
# Evolve the simulation forward
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding))
solver = LeapfrogMidpoint()
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p)
print(f"y0 structure: {jax.tree.structure(y0)}")
stepsize_controller = ConstantStepSize()
res = diffeqsolve(ode_fn,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=y0,
args=cosmo,
saveat=SaveAt(ts=snapshots),
stepsize_controller=stepsize_controller)
ode_solutions = [sol[0] for sol in res.ys]
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1])
return particles + dx , ode_field
ode_field = cic_paint_dx(ode_solutions[-1])
return dx , ode_field
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots)
import matplotlib.pyplot as plt
lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles)
#lpt_field = cic_paint_dx(lpt_particles)
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(lpt_field.sum(axis=0) , cmap='magma')
plt.colorbar()
plt.title('LPT field')
plt.subplot(122)
plt.imshow(ode_field.sum(axis=0) , cmap='magma')
plt.colorbar()
plt.title('ODE field')
plt.show()
plt.close()
#particles = jax.random.uniform(jax.random.PRNGKey(0), (4 , 4 ,4 , 3), minval=0.1, maxval=0.9)
#field = jax.random.uniform(jax.random.PRNGKey(0), (4, 4, 4))
#
#partiles = ShardedArray(particles, sharding)
#field = ShardedArray(field, sharding)
#
#
#cic_read_dx(field , particles )