mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 01:57:10 +00:00
pm now accept pytrees
This commit is contained in:
parent
204a9526ec
commit
9e203b5680
1 changed files with 10 additions and 8 deletions
18
jaxpm/pm.py
18
jaxpm/pm.py
|
@ -7,7 +7,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
|||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||
invlaplace_kernel, longrange_kernel)
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||
|
||||
import jax
|
||||
|
||||
def pm_forces(positions,
|
||||
mesh_shape=None,
|
||||
|
@ -51,9 +51,11 @@ def pm_forces(positions,
|
|||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
||||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
forces = [
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||
) for i in range(3)], axis=-1) # yapf: disable
|
||||
) for i in range(3)]
|
||||
|
||||
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
|
||||
|
||||
return forces
|
||||
|
||||
|
@ -71,8 +73,8 @@ def lpt(cosmo,
|
|||
"""
|
||||
paint_absolute_pos = particles is not None
|
||||
if particles is None:
|
||||
particles = jnp.zeros_like(initial_conditions,
|
||||
shape=(*initial_conditions.shape, 3))
|
||||
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
|
||||
shape=(*ic.shape, 3)) , initial_conditions)
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
@ -172,8 +174,7 @@ def make_ode_fn(mesh_shape,
|
|||
return nbody_ode
|
||||
|
||||
|
||||
def make_diffrax_ode(cosmo,
|
||||
mesh_shape,
|
||||
def make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
@ -183,6 +184,7 @@ def make_diffrax_ode(cosmo,
|
|||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
cosmo = args
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
|
@ -196,7 +198,7 @@ def make_diffrax_ode(cosmo,
|
|||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
return jnp.stack([dpos, dvel])
|
||||
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue