mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
update code
This commit is contained in:
parent
e0c118a540
commit
21373b89ee
7 changed files with 84 additions and 100 deletions
51
jaxpm/pm.py
51
jaxpm/pm.py
|
@ -1,11 +1,9 @@
|
|||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
|
||||
normal_field, zeros)
|
||||
from jaxpm.distributed import (fft3d, ifft3d,
|
||||
normal_field)
|
||||
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||
growth_rate, growth_rate_second)
|
||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||
|
@ -29,17 +27,17 @@ def pm_forces(positions,
|
|||
mesh_shape = delta.shape
|
||||
|
||||
if paint_absolute_pos:
|
||||
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
|
||||
x,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
read_fn = lambda x: cic_read(
|
||||
x, positions, halo_size=halo_size, sharding=sharding)
|
||||
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
|
||||
pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
read_fn = lambda grid_mesh, pos: cic_read(
|
||||
grid_mesh, pos, halo_size=halo_size, sharding=sharding)
|
||||
else:
|
||||
paint_fn = partial(cic_paint_dx,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding)
|
||||
paint_fn = lambda disp: cic_paint_dx(
|
||||
disp, halo_size=halo_size, sharding=sharding)
|
||||
read_fn = lambda grid_mesh, disp: cic_read_dx(
|
||||
grid_mesh, disp, halo_size=halo_size, sharding=sharding)
|
||||
|
||||
if delta is None:
|
||||
field = paint_fn(positions)
|
||||
|
@ -55,7 +53,7 @@ def pm_forces(positions,
|
|||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||
) for i in range(3)], axis=-1) # yapf: disable
|
||||
|
||||
return forces
|
||||
|
@ -73,6 +71,8 @@ def lpt(cosmo,
|
|||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||
"""
|
||||
paint_absolute_pos = particles is not None
|
||||
if particles is None:
|
||||
particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
@ -167,25 +167,27 @@ def make_ode_fn(mesh_shape,
|
|||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
#dpos = dpos if not paint_absolute_pos else dpos + pos
|
||||
|
||||
return dpos, dvel
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
||||
def make_diffrax_ode(cosmo, mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
def nbody_ode(a, state, args):
|
||||
"""
|
||||
State is an array [position, velocities]
|
||||
|
||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
forces = pm_forces(
|
||||
pos, mesh_shape, halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
|
||||
forces = pm_forces(pos,
|
||||
mesh_shape=mesh_shape,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
@ -197,7 +199,6 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
|||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def pgd_correction(pos, mesh_shape, params):
|
||||
"""
|
||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue