From 20fe25c324d0a9a13fe21ab62b7133b9ba0f7f41 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Mon, 20 Jan 2025 22:40:28 +0100 Subject: [PATCH] update tests and add test for FWD REV gradient --- tests/conftest.py | 16 +++ tests/helpers.py | 4 +- tests/test_against_fpm.py | 77 +++++++++--- tests/test_distributed_pm.py | 231 +++++++++++++++++++++++++++++++---- tests/test_sharded_array.py | 147 ++++++++++++++++++++++ 5 files changed, 434 insertions(+), 41 deletions(-) create mode 100644 tests/test_sharded_array.py diff --git a/tests/conftest.py b/tests/conftest.py index 6d91684..1ea04c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -173,3 +173,19 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor): fpm_mesh = particle_mesh.paint(finalstate.X).value return fpm_mesh + +def compare_sharding(sharding1, sharding2): + def get_axis_size(sharding, idx): + axis_name = sharding.spec[idx] + if axis_name is None: + return 1 + else: + return sharding.mesh.shape[sharding.spec[idx]] + def get_pdims_from_sharding(sharding): + return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))]) + + pdims1 = get_pdims_from_sharding(sharding1) + pdims2 = get_pdims_from_sharding(sharding2) + pdims1 = pdims1 + (1,) * (3 - len(pdims1)) + pdims2 = pdims2 + (1,) * (3 - len(pdims2)) + return pdims1 == pdims2 \ No newline at end of file diff --git a/tests/helpers.py b/tests/helpers.py index 0b85161..40a6253 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,7 +2,7 @@ import jax.numpy as jnp def MSE(x, y): - return jnp.mean((x - y)**2) + return ((x - y)**2).mean() def MSE_3D(x, y): @@ -10,4 +10,4 @@ def MSE_3D(x, y): def MSRE(x, y): - return jnp.mean(((x - y) / y)**2) + return (((x - y) / y)**2).mean() diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 6d17939..9a7bc93 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -3,24 +3,30 @@ from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from helpers import MSE, MSRE from jax import numpy as jnp +from jaxdecomp import ShardedArray from jaxpm.distributed import uniform_particles from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.utils import power_spectrum - +import jax _TOLERANCE = 1e-4 _PM_TOLERANCE = 1e-3 @pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, - fpm_lpt1_field, fpm_lpt2_field, cosmo, order): + fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} particles = uniform_particles(mesh_shape) + if shardedArrayAPI: + particles = ShardedArray(particles) + initial_conditions = ShardedArray(initial_conditions) + # Initial displacement dx, _, _ = lpt(cosmo, initial_conditions, @@ -31,44 +37,61 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx) - _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) + lpt_field_arr, = jax.tree.leaves(lpt_field) + _, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE + if shardedArrayAPI: + assert type(dx) == ShardedArray + assert type(lpt_field) == ShardedArray + @pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, - fpm_lpt1_field, fpm_lpt2_field, cosmo, order): + fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} + if shardedArrayAPI: + initial_conditions = ShardedArray(initial_conditions) + # Initial displacement dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) lpt_field = cic_paint_dx(dx) fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field - - _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) + lpt_field_arr, = jax.tree.leaves(lpt_field) + _, jpm_ps = power_spectrum(lpt_field_arr, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert MSE(lpt_field_arr, fpm_ref_field) < _TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE + if shardedArrayAPI: + assert type(dx) == ShardedArray + assert type(lpt_field) == ShardedArray @pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, - cosmo, order): + cosmo, order , shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} particles = uniform_particles(mesh_shape) + if shardedArrayAPI: + particles = ShardedArray(particles) + initial_conditions = ShardedArray(initial_conditions) + # Initial displacement dx, p, _ = lpt(cosmo, initial_conditions, @@ -76,7 +99,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -87,7 +110,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, saveat = SaveAt(t1=True) - y0 = jnp.stack([particles + dx, p]) + y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p) solutions = diffeqsolve(ode_fn, solver, @@ -95,6 +118,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -102,27 +126,37 @@ def test_nbody_absolute(simulation_config, initial_conditions, fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 - _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) + final_field_arr, = jax.tree.leaves(final_field) + _, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE + if shardedArrayAPI: + assert type(dx) == ShardedArray + assert type( solutions.ys[-1, 0]) == ShardedArray + assert type(final_field) == ShardedArray + @pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, - cosmo, order): + cosmo, order , shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} + if shardedArrayAPI: + initial_conditions = ShardedArray(initial_conditions) + # Initial displacement dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) solver = Dopri5() controller = PIDController(rtol=1e-9, @@ -133,7 +167,7 @@ def test_nbody_relative(simulation_config, initial_conditions, saveat = SaveAt(t1=True) - y0 = jnp.stack([dx, p]) + y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p) solutions = diffeqsolve(ode_fn, solver, @@ -141,6 +175,7 @@ def test_nbody_relative(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -148,8 +183,14 @@ def test_nbody_relative(simulation_config, initial_conditions, fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 - _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) + final_field_arr, = jax.tree.leaves(final_field) + _, jpm_ps = power_spectrum(final_field_arr, box_shape=box_shape) _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) - assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert MSE(final_field_arr, fpm_ref_field) < _PM_TOLERANCE assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE + + if shardedArrayAPI: + assert type(dx) == ShardedArray + assert type( solutions.ys[-1, 0]) == ShardedArray + assert type(final_field) == ShardedArray diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index fd683ab..8054408 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -1,4 +1,4 @@ -from conftest import initialize_distributed +from conftest import initialize_distributed , compare_sharding initialize_distributed() # ignore : E402 @@ -12,38 +12,48 @@ from jax import lax # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P # noqa : E402 - -from jaxpm.distributed import uniform_particles # noqa : E402 +from jaxpm.pm import pm_forces # noqa : E402 +from jaxpm.distributed import uniform_particles , fft3d # noqa : E402 from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 - +from jaxdecomp import ShardedArray # noqa : E402 +from functools import partial # noqa : E402 +import jax_cosmo as jc # noqa : E402 _TOLERANCE = 3.0 # 🙃🙃 @pytest.mark.distributed @pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("absolute_painting", [True, False]) +@pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, - absolute_painting): + absolute_painting,shardedArrayAPI): mesh_shape, box_shape = simulation_config # SINGLE DEVICE RUN cosmo._workspace = {} + if shardedArrayAPI: + ic = ShardedArray(initial_conditions) + else: + ic = initial_conditions + if absolute_painting: particles = uniform_particles(mesh_shape) + if shardedArrayAPI: + particles = ShardedArray(particles) # Initial displacement dx, p, _ = lpt(cosmo, - initial_conditions, + ic, particles, a=0.1, order=order) - ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) - y0 = jnp.stack([particles + dx, p]) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) + y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) else: - dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) + dx, p, _ = lpt(cosmo, ic, a=0.1, order=order) ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) - y0 = jnp.stack([dx, p]) + make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) + y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -59,6 +69,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, t0=0.1, t1=1.0, dt0=None, + args=cosmo, y0=y0, stepsize_controller=controller, saveat=saveat) @@ -76,17 +87,22 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding = NamedSharding(mesh, P('x', 'y')) halo_size = mesh_shape[0] // 2 - initial_conditions = lax.with_sharding_constraint(initial_conditions, + ic = lax.with_sharding_constraint(initial_conditions, sharding) - print(f"sharded initial conditions {initial_conditions.sharding}") + print(f"sharded initial conditions {ic.sharding}") + + if shardedArrayAPI: + ic = ShardedArray(ic , sharding) cosmo._workspace = {} if absolute_painting: particles = uniform_particles(mesh_shape, sharding=sharding) + if shardedArrayAPI: + particles = ShardedArray(particles, sharding) # Initial displacement dx, p, _ = lpt(cosmo, - initial_conditions, + ic, particles, a=0.1, order=order, @@ -94,26 +110,26 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, + make_diffrax_ode( mesh_shape, halo_size=halo_size, sharding=sharding)) - y0 = jnp.stack([particles + dx, p]) + y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) else: dx, p, _ = lpt(cosmo, - initial_conditions, + ic, a=0.1, order=order, halo_size=halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(cosmo, + make_diffrax_ode( mesh_shape, paint_absolute_pos=False, halo_size=halo_size, sharding=sharding)) - y0 = jnp.stack([dx, p]) + y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -130,6 +146,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, t1=1.0, dt0=None, y0=y0, + args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -143,10 +160,182 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, halo_size=halo_size, sharding=sharding) - multi_device_final_field = process_allgather(multi_device_final_field, + multi_device_final_field_g = process_allgather(multi_device_final_field, tiled=True) - mse = MSE(single_device_final_field, multi_device_final_field) + single_device_final_field_arr, = jax.tree.leaves(single_device_final_field) + multi_device_final_field_arr, = jax.tree.leaves(multi_device_final_field_g) + mse = MSE(single_device_final_field_arr, multi_device_final_field_arr) print(f"MSE is {mse}") + if shardedArrayAPI: + assert type(multi_device_final_field) == ShardedArray + assert compare_sharding(multi_device_final_field.sharding , sharding) + assert compare_sharding(multi_device_final_field.initial_sharding , sharding) + assert mse < _TOLERANCE + + + +@pytest.mark.distributed +@pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("absolute_painting", [True, False]) +def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2, + absolute_painting): + + mesh_shape, box_shape = simulation_config + # SINGLE DEVICE RUN + cosmo._workspace = {} + + mesh = jax.make_mesh((1, 8), ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + initial_conditions = lax.with_sharding_constraint(initial_conditions, + sharding) + + print(f"sharded initial conditions {initial_conditions.sharding}") + + + initial_conditions = ShardedArray(initial_conditions , sharding) + + cosmo._workspace = {} + + @jax.jit + def forward_model(initial_conditions , cosmo): + + + if absolute_painting: + particles = uniform_particles(mesh_shape, sharding=sharding) + particles = ShardedArray(particles, sharding) + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + + ode_fn = ODETerm( + make_diffrax_ode( + mesh_shape, + halo_size=halo_size, + sharding=sharding)) + + y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) + else: + dx, p, _ = lpt(cosmo, + initial_conditions, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + ode_fn = ODETerm( + make_diffrax_ode( + mesh_shape, + paint_absolute_pos=False, + halo_size=halo_size, + sharding=sharding)) + y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1.0, + dt0=None, + y0=y0, + args=cosmo, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), + solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + else: + multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + + return multi_device_final_field + + @jax.jit + def model(initial_conditions , cosmo): + + final_field = forward_model(initial_conditions , cosmo) + final_field, = jax.tree.leaves(final_field) + + return MSE(final_field, + nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + obs_val = model(initial_conditions , cosmo) + + shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5 + + good_grads = jax.grad(model)(initial_conditions , cosmo) + off_grads = jax.grad(model)(shifted_initial_conditions , cosmo) + + assert compare_sharding(good_grads.sharding , initial_conditions.sharding) + assert compare_sharding(off_grads.sharding , initial_conditions.sharding) + + +@pytest.mark.distributed +@pytest.mark.parametrize("absolute_painting", [True, False]) +def test_fwd_rev_gradients(cosmo,absolute_painting): + + mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0) + # SINGLE DEVICE RUN + cosmo._workspace = {} + + mesh = jax.make_mesh((1, 8), ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape) + + initial_conditions = lax.with_sharding_constraint(initial_conditions, + sharding) + + print(f"sharded initial conditions {initial_conditions.sharding}") + initial_conditions = ShardedArray(initial_conditions , sharding) + + cosmo._workspace = {} + + @partial(jax.jit , static_argnums=(3,4 , 5)) + def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None): + + paint_absolute_pos = particles is not None + if particles is None: + particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, + shape=(*ic.shape, 3)) , initial_conditions) + + a = jnp.atleast_1d(a) + E = jnp.sqrt(jc.background.Esqr(cosmo, a)) + delta_k = fft3d(initial_conditions) + initial_force = pm_forces(particles, + delta=delta_k, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) + + return initial_force[...,0] + + particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None + forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) + back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) + fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) + + assert compare_sharding(forces.sharding , initial_conditions.sharding) + assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding) + assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding) diff --git a/tests/test_sharded_array.py b/tests/test_sharded_array.py new file mode 100644 index 0000000..d73e525 --- /dev/null +++ b/tests/test_sharded_array.py @@ -0,0 +1,147 @@ +import os +#os.environ["JAX_PLATFORM_NAME"] = "cpu" +#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + + +import os +os.environ["EQX_ON_ERROR"] = "nan" +import jax +import jax.numpy as jnp +import jax_cosmo as jc +from jax.debug import visualize_array_sharding + +from jaxpm.kernels import interpolate_power_spectrum +from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read +from jaxpm.pm import linear_field, lpt, make_diffrax_ode +from functools import partial +from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve +from jaxpm.distributed import uniform_particles + +#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices" + + + +from jax.experimental.mesh_utils import create_device_mesh +from jax.experimental.multihost_utils import process_allgather +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +all_gather = partial(process_allgather, tiled=False) + +pdims = (2, 4) +#devices = create_device_mesh(pdims) +#mesh = Mesh(devices, axis_names=('x', 'y')) +#sharding = NamedSharding(mesh, P('x', 'y')) +sharding = None + + +from typing import NamedTuple +from jaxdecomp import ShardedArray + +mesh_shape = 64 +box_size = 64. +halo_size = 2 +snapshots = (0.5, 1.0) + +class Params(NamedTuple): + omega_c: float + sigma8: float + initial_conditions : jnp.ndarray + +mesh_shape = (mesh_shape,) * 3 +box_size = (box_size,) * 3 +omega_c = 0.25 +sigma8 = 0.8 +# Create a small function to generate the matter power spectrum +k = jnp.logspace(-4, 1, 128) +pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) +pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) + +initial_conditions = linear_field(mesh_shape, + box_size, + pk_fn, + seed=jax.random.PRNGKey(0), + sharding=sharding) + + +#initial_conditions = ShardedArray(initial_conditions, sharding) + +params = Params(omega_c, sigma8, initial_conditions) + + + +@partial(jax.jit , static_argnums=(1 , 2,3,4 )) +def forward_model(params , mesh_shape,box_size,halo_size , snapshots): + + # Create initial conditions + cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8) + particles = uniform_particles(mesh_shape , sharding) + ic_structure = jax.tree.structure(params.initial_conditions) + particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles)) + # Initial displacement + dx, p, f = lpt(cosmo, + params.initial_conditions, + particles, + a=0.1, + order=2, + halo_size=halo_size, + sharding=sharding) + + # Evolve the simulation forward + ode_fn = ODETerm( + make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding)) + solver = LeapfrogMidpoint() + + y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p) + print(f"y0 structure: {jax.tree.structure(y0)}") + + stepsize_controller = ConstantStepSize() + res = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1., + dt0=0.01, + y0=y0, + args=cosmo, + saveat=SaveAt(ts=snapshots), + stepsize_controller=stepsize_controller) + ode_solutions = [sol[0] for sol in res.ys] + + ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1]) + return particles + dx , ode_field + + + ode_field = cic_paint_dx(ode_solutions[-1]) + return dx , ode_field + + + +lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots) + + +import matplotlib.pyplot as plt + +lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles) +#lpt_field = cic_paint_dx(lpt_particles) + +plt.figure(figsize=(12, 6)) +plt.subplot(121) +plt.imshow(lpt_field.sum(axis=0) , cmap='magma') +plt.colorbar() +plt.title('LPT field') +plt.subplot(122) +plt.imshow(ode_field.sum(axis=0) , cmap='magma') +plt.colorbar() +plt.title('ODE field') +plt.show() +plt.close() + +#particles = jax.random.uniform(jax.random.PRNGKey(0), (4 , 4 ,4 , 3), minval=0.1, maxval=0.9) +#field = jax.random.uniform(jax.random.PRNGKey(0), (4, 4, 4)) +# +#partiles = ShardedArray(particles, sharding) +#field = ShardedArray(field, sharding) +# +# +#cic_read_dx(field , particles ) \ No newline at end of file