mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
pm now accept pytrees
This commit is contained in:
parent
204a9526ec
commit
9e203b5680
1 changed files with 10 additions and 8 deletions
18
jaxpm/pm.py
18
jaxpm/pm.py
|
@ -7,7 +7,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||||
invlaplace_kernel, longrange_kernel)
|
invlaplace_kernel, longrange_kernel)
|
||||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||||
|
import jax
|
||||||
|
|
||||||
def pm_forces(positions,
|
def pm_forces(positions,
|
||||||
mesh_shape=None,
|
mesh_shape=None,
|
||||||
|
@ -51,9 +51,11 @@ def pm_forces(positions,
|
||||||
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
|
||||||
kvec, r_split=r_split)
|
kvec, r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([
|
forces = [
|
||||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||||
) for i in range(3)], axis=-1) # yapf: disable
|
) for i in range(3)]
|
||||||
|
|
||||||
|
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
|
||||||
|
@ -71,8 +73,8 @@ def lpt(cosmo,
|
||||||
"""
|
"""
|
||||||
paint_absolute_pos = particles is not None
|
paint_absolute_pos = particles is not None
|
||||||
if particles is None:
|
if particles is None:
|
||||||
particles = jnp.zeros_like(initial_conditions,
|
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
|
||||||
shape=(*initial_conditions.shape, 3))
|
shape=(*ic.shape, 3)) , initial_conditions)
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
|
@ -172,8 +174,7 @@ def make_ode_fn(mesh_shape,
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def make_diffrax_ode(cosmo,
|
def make_diffrax_ode(mesh_shape,
|
||||||
mesh_shape,
|
|
||||||
paint_absolute_pos=True,
|
paint_absolute_pos=True,
|
||||||
halo_size=0,
|
halo_size=0,
|
||||||
sharding=None):
|
sharding=None):
|
||||||
|
@ -183,6 +184,7 @@ def make_diffrax_ode(cosmo,
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
|
cosmo = args
|
||||||
|
|
||||||
forces = pm_forces(pos,
|
forces = pm_forces(pos,
|
||||||
mesh_shape=mesh_shape,
|
mesh_shape=mesh_shape,
|
||||||
|
@ -196,7 +198,7 @@ def make_diffrax_ode(cosmo,
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
return jnp.stack([dpos, dvel])
|
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue