mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
fixed a whole lot of issues
This commit is contained in:
parent
429813ad92
commit
72ae0fd88f
5 changed files with 251 additions and 155 deletions
120
jaxpm/pm.py
120
jaxpm/pm.py
|
@ -1,107 +1,99 @@
|
|||
import jax
|
||||
from jax.experimental.maps import xmap
|
||||
import jax.numpy as jnp
|
||||
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.ops import fft3d, ifft3d, zeros
|
||||
from jaxpm.ops import fft3d, ifft3d, zeros, normal
|
||||
from jaxpm.kernels import fftk, apply_gradient_laplace
|
||||
from jaxpm.painting import cic_paint, cic_read
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
if mesh_shape is None:
|
||||
mesh_shape = delta_k.shape
|
||||
|
||||
kvec = fftk(mesh_shape, comms=comms)
|
||||
|
||||
if delta_k is None:
|
||||
delta, token = cic_paint(zeros(mesh_shape,comms=comms),
|
||||
positions,
|
||||
halo_size=halo_size, token=token, comms=comms)
|
||||
delta_k, token = fft3d(delta, token=token, comms=comms)
|
||||
|
||||
# Computes gravitational potential
|
||||
forces_k = apply_gradient_laplace(kfield, kvec)
|
||||
delta = cic_paint(zeros(mesh_shape, comms=comms),
|
||||
positions,
|
||||
halo_size=halo_size, comms=comms)
|
||||
delta_k = fft3d(delta, comms=comms)
|
||||
|
||||
# Computes gravitational forces
|
||||
fx, token = ifft3d(forces_k[...,0], token=token, comms=comms)
|
||||
fx, token = cic_read(fx, positions, halo_size=halo_size, comms=comms)
|
||||
kvec = fftk(delta_k.shape, symmetric=False, comms=comms)
|
||||
forces_k = apply_gradient_laplace(delta_k, kvec)
|
||||
|
||||
fy, token = ifft3d(forces_k[...,1], token=token, comms=comms)
|
||||
fy, token = cic_read(fy, positions, halo_size=halo_size, comms=comms)
|
||||
# Interpolate forces at the position of particles
|
||||
return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real,
|
||||
positions, halo_size=halo_size, comms=comms)
|
||||
for i in range(3)], axis=-1)
|
||||
|
||||
fz, token = ifft3d(forces_k[...,2], token=token, comms=comms)
|
||||
fz, token = cic_read(fz, positions, halo_size=halo_size, comms=comms)
|
||||
|
||||
return jnp.stack([fx,fy,fz],axis=-1), token
|
||||
|
||||
def lpt(cosmo, initial_conditions, positions, a, token=token, comms=comms):
|
||||
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
|
||||
"""
|
||||
Computes first order LPT displacement
|
||||
"""
|
||||
initial_force = pm_forces(positions, delta=initial_conditions, token=token, comms=comms)
|
||||
initial_force = pm_forces(
|
||||
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
|
||||
return dx, p, f, comms
|
||||
p = a**2 * growth_rate(cosmo, a) * \
|
||||
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
|
||||
dGfa(cosmo, a) * initial_force
|
||||
return dx, p, f
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed):
|
||||
|
||||
def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
|
||||
"""
|
||||
Generate initial conditions.
|
||||
Generate initial conditions in Fourier space.
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
# Sample normal field
|
||||
field = normal(key, mesh_shape, comms=comms)
|
||||
|
||||
field = jax.random.normal(seed, mesh_shape)
|
||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
||||
field = jnp.fft.irfftn(field)
|
||||
return field
|
||||
# Transform to Fourier space
|
||||
kfield = fft3d(field, comms=comms)
|
||||
|
||||
# Rescaling k to physical units
|
||||
kvec = [k / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(kfield.shape,
|
||||
symmetric=False,
|
||||
comms=comms))]
|
||||
|
||||
# Evaluating linear matter powerspectrum
|
||||
k = jnp.logspace(-4, 2, 256)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
|
||||
) / (box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
# Multipliyng the field by the proper power spectrum
|
||||
kfield = xmap(lambda kfield, kx, ky, kz:
|
||||
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2),
|
||||
k, jnp.sqrt(pk)),
|
||||
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]),
|
||||
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
|
||||
|
||||
return kfield
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape, halo_size=0, comms=None):
|
||||
|
||||
def make_ode_fn(mesh_shape):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
"""
|
||||
state is a tuple (position, velocities)
|
||||
"""
|
||||
pos, vel = state
|
||||
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape,
|
||||
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
|
||||
return dpos, dvel
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def pgd_correction(pos, params):
|
||||
"""
|
||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
||||
args:
|
||||
pos: particle positions [npart, 3]
|
||||
params: [alpha, kl, ks] pgd parameters
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||
alpha, kl, ks = params
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
PGD_range=PGD_kernel(kvec, kl, ks)
|
||||
|
||||
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
|
||||
|
||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
|
||||
dpos_pgd = forces_pgd*alpha
|
||||
|
||||
return dpos_pgd
|
Loading…
Add table
Add a link
Reference in a new issue