From 9e203b56800b312850665955c7c0866f9c12b430 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 18 Jan 2025 01:13:51 +0100 Subject: [PATCH] pm now accept pytrees --- jaxpm/pm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index e34d584..a4fffc7 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -7,7 +7,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel) from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx - +import jax def pm_forces(positions, mesh_shape=None, @@ -51,9 +51,11 @@ def pm_forces(positions, pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel( kvec, r_split=r_split) # Computes gravitational forces - forces = jnp.stack([ + forces = [ read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions - ) for i in range(3)], axis=-1) # yapf: disable + ) for i in range(3)] + + forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2]) return forces @@ -71,8 +73,8 @@ def lpt(cosmo, """ paint_absolute_pos = particles is not None if particles is None: - particles = jnp.zeros_like(initial_conditions, - shape=(*initial_conditions.shape, 3)) + particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, + shape=(*ic.shape, 3)) , initial_conditions) a = jnp.atleast_1d(a) E = jnp.sqrt(jc.background.Esqr(cosmo, a)) @@ -172,8 +174,7 @@ def make_ode_fn(mesh_shape, return nbody_ode -def make_diffrax_ode(cosmo, - mesh_shape, +def make_diffrax_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None): @@ -183,6 +184,7 @@ def make_diffrax_ode(cosmo, state is a tuple (position, velocities) """ pos, vel = state + cosmo = args forces = pm_forces(pos, mesh_shape=mesh_shape, @@ -196,7 +198,7 @@ def make_diffrax_ode(cosmo, # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - return jnp.stack([dpos, dvel]) + return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel) return nbody_ode