mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 19:50:55 +00:00
258 lines
8.6 KiB
Python
258 lines
8.6 KiB
Python
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from functools import partial
|
|
from typing import Any, Sequence
|
|
|
|
import jax
|
|
import jax_cosmo as jc
|
|
from diffrax import (AbstractSolver, AbstractStepSizeController,
|
|
ConstantStepSize, ODETerm, SaveAt, Tsit5, diffeqsolve)
|
|
from jax import numpy as jnp
|
|
from jax_cosmo import Cosmology
|
|
from jaxtyping import Array, Bool, PyTree, Real, Shaped
|
|
|
|
import jaxpm as jpm
|
|
from jaxpm.growth import (growth_factor, growth_factor_second, growth_rate,
|
|
growth_rate_second)
|
|
|
|
|
|
class PMSolverStatus(Enum):
    """Stage reached by a simulation ``State``.

    The solver methods check this tag to refuse steps run out of order and
    stamp it on every ``State`` they return.
    """

    # Default status of a freshly built State.
    INIT = 0
    # Set on the State returned by FastPM.lpt.
    LPT = 1
    # Set on the State returned by FastPM.lpt2.
    LPT2 = 2
    # Set on the State while it is being integrated by FastPM.nbody.
    ODE = 3
    # Set on the State returned by FastPM.nbody.
    DONE = 4
|
|
|
|
|
|
@dataclass
class State:
    """Simulation state carried between the solver steps.

    Registered below as a JAX pytree: ``displacements``, ``velocities`` and
    ``cosmo`` are data fields (traversed as children), while ``kvec``,
    ``solver_stats`` and ``status`` are meta fields (static auxiliary data).
    """

    # Particle displacements.
    displacements: Array
    # Particle velocities / momenta.
    velocities: Array
    # Cosmology used by the growth and background functions.
    cosmo: Cosmology
    # Wave-vector components; the solver unpacks them as (ky, kz, kx).
    # NOTE(review): this is a meta field, so it ends up in the treedef;
    # array-valued aux data is presumably not hashable -- confirm this is
    # compatible with jit before relying on it.
    kvec: Sequence
    # Statistics reported by the ODE solver; None until nbody has run.
    solver_stats: dict[str, Any] = None
    # Stage reached by the simulation.
    status: PMSolverStatus = PMSolverStatus.INIT


# Same registration the decorator form performs, written as an explicit call.
jax.tree_util.register_dataclass(
    State,
    data_fields=['displacements', 'velocities', 'cosmo'],
    meta_fields=['solver_stats', 'status', 'kvec'])
|
|
|
|
|
|
class FastPM:
    """Particle-mesh solver: LPT initial displacements plus a diffrax ODE solve.

    As used in this class, the flow is ``init_state`` -> optional ``lpt`` or
    ``lpt2`` -> ``nbody``. Progress is tracked through ``State.status``.
    """

    # Fourier transform of the initial field; set by ``init_state``.
    initial_delta_k: Array
    # Wave-vector components, unpacked throughout as (ky, kz, kx); set by
    # ``init_state``.
    kvec: Sequence[Array]

    # ------------------------------------------------------------------
    # Private helpers shared by the force computations
    # ------------------------------------------------------------------
    def _scaled_k_norm(self, kvec, mesh_shape):
        """Return the norm of ``kvec`` rescaled per axis by mesh_shape / box_size."""
        ky, kz, kx = kvec
        box_size = self.box_size
        # Each axis uses its own box/mesh entry (x -> 0, y -> 1, z -> 2).
        return jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
                        (ky / box_size[1] * mesh_shape[1])**2 +
                        (kz / box_size[2] * mesh_shape[2])**2)

    @staticmethod
    def _inverse_k2(kk, kvec):
        """Return 1 / (kx^2 + ky^2 + kz^2), with the ``kk == 0`` entry set to 1."""
        ky, kz, kx = kvec
        return jnp.where(kk == 0, 1.,
                         1. / (kx**2 + ky**2 + kz**2))  # Laplace kernel + longrange

    @staticmethod
    def _gradient_forces_k(pot_k, kvec):
        """Apply the ``1j / 6 * (8 sin k - sin 2k)`` gradient kernel per axis.

        Returns the three Fourier-space components stacked on a new last
        axis, ordered (x, y, z).
        """
        ky, kz, kx = kvec
        # Forces have to be a Z pencil because they are going to be IFFT back to X pencil
        return jnp.stack([
            pot_k * 1j / 6.0 * (8 * jnp.sin(k) - jnp.sin(2 * k))
            for k in (kx, ky, kz)
        ],
                         axis=-1)

    def _first_order_lpt(self, state, a):
        """Return first-order displacement, momentum and the interpolated delta_k."""
        init_force, delta_k = self.compute_initial_forces(
            state, self.initial_delta_k)

        dx = growth_factor(state.cosmo, a) * init_force
        p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
            jc.background.Esqr(state.cosmo, a)) * dx
        return dx, p, delta_k

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def init_state(self, cosmo, particules, kvec, initial_field, box_size):
        """Store the solver inputs and return a zero-initialised ``State``.

        Args:
            cosmo: cosmology stored on the returned state.
            particules: array whose shape and dtype define the zero
                displacement and velocity arrays.
            kvec: wave-vector components, in (ky, kz, kx) order.
            initial_field: field whose ``jpm.ops.fftn`` transform is kept as
                ``self.initial_delta_k``.
            box_size: per-axis box size, indexed 0..2 by the force kernels.

        Returns:
            A ``State`` with status ``PMSolverStatus.INIT``.
        """
        self.initial_delta_k = jpm.ops.fftn(initial_field)
        self.box_size = box_size
        # compute_lpt2_source reads self.kvec, so it has to be stored here.
        self.kvec = kvec
        # Check sharding on this
        zeros = jnp.zeros_like(particules)
        return State(displacements=zeros,
                     velocities=zeros,
                     cosmo=cosmo,
                     kvec=kvec)

    def compute_initial_forces(self, state, delta_k):
        """Compute mesh forces from a Fourier-space field.

        ``delta_k`` is first passed through ``jpm.ops.interpolate_ic``; the
        result is multiplied by the inverse-k^2 kernel, differentiated per
        axis and transformed back with ``jpm.ops.ifftn``.

        Returns:
            ``(init_force, delta_k)``: the real part of the force field
            stacked on the last axis, and the interpolated ``delta_k``.
        """
        #TODO this must done in a function generat_ic
        mesh_shape = state.displacements.shape[:3]
        kk = self._scaled_k_norm(state.kvec, mesh_shape)
        delta_k = jpm.ops.interpolate_ic(delta_k, kk, state.cosmo,
                                         self.box_size)

        pot_k = delta_k * self._inverse_k2(kk, state.kvec)
        forces_k = self._gradient_forces_k(pot_k, state.kvec)

        init_force = jnp.stack(
            [jpm.ops.ifftn(forces_k[..., i]).real for i in range(3)], axis=-1)

        return init_force, delta_k

    def compute_ode_forces(self, state):
        """Compute the forces acting on the particles of ``state``.

        The displacements are CIC-painted onto the mesh, transformed to
        Fourier space, turned into forces with the same kernels as
        ``compute_initial_forces``, read back at the particles and scaled by
        ``1.5 * Omega_m``.
        """
        mesh_shape = self.initial_delta_k.shape
        kk = self._scaled_k_norm(state.kvec, mesh_shape)

        pos = jnp.array(state.displacements)
        # The kernels below live in Fourier space, so the painted (real-space)
        # density must be transformed before they are applied.
        delta_k = jpm.ops.fftn(jpm.painting.cic_paint_dx(pos))

        pot_k = delta_k * self._inverse_k2(kk, state.kvec)
        forces_k = self._gradient_forces_k(pot_k, state.kvec)

        forces = jnp.stack([
            jpm.painting.cic_read_dx(jpm.ops.ifftn(forces_k[..., i])).real
            for i in range(3)
        ],
                           axis=-1)
        return forces * 1.5 * state.cosmo.Omega_m

    def make_ode_fn(self):
        """Build the ``(a, state, args) -> State`` vector field for diffrax."""

        def ode_fn(a, state, args):
            forces = self.compute_ode_forces(state)
            sqrt_esqr = jnp.sqrt(jc.background.Esqr(state.cosmo, a))

            # Computes the update of position (drift)
            dpos = 1. / (a**3 * sqrt_esqr) * state.velocities

            # Computes the update of velocity (kick)
            dvel = 1. / (a**2 * sqrt_esqr) * forces

            # ``cosmo`` is a data field of State, so the integrator evolves
            # it with whatever is returned here; a zero derivative keeps the
            # cosmological parameters constant during the solve.
            dcosmo = jax.tree_util.tree_map(jnp.zeros_like, state.cosmo)

            return State(displacements=dpos,
                         velocities=dvel,
                         cosmo=dcosmo,
                         kvec=state.kvec,
                         status=PMSolverStatus.ODE)

        return ode_fn

    def compute_lpt2_source(self, delta_k):
        """Compute the second-order LPT source term from ``delta_k``.

        Uses the wave vectors stored on the solver by ``init_state``.
        """
        mesh_shape = self.initial_delta_k.shape
        kvec = self.kvec

        kk = self._scaled_k_norm(kvec, mesh_shape)
        invlaplace_kernel = -self._inverse_k2(kk, kvec)
        pot_k = delta_k * invlaplace_kernel

        # Taken from https://github.com/hsimonfroy/montecosmo
        # Based on https://arxiv.org/abs/0910.0258
        # NOTE(review): every other transform in this class goes through
        # jpm.ops.ifftn; confirm that jpm.ops.ifft is the intended call.
        delta2 = 0
        shear_acc = 0
        for i, ki in enumerate(kvec):
            # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
            shear_ii = jpm.ops.ifft(-ki**2 * pot_k)
            delta2 += shear_ii * shear_acc
            shear_acc += shear_ii

            for kj in kvec[i + 1:]:
                # Subtract squared strict-up-triangle terms
                delta2 -= jpm.ops.ifft(-ki * kj * pot_k)**2

        return delta2

    def lpt(self, state, a=0.1):
        """Run first-order LPT at scale factor ``a``.

        Raises:
            ValueError: if ``state.status`` is not ``PMSolverStatus.INIT``.

        Returns:
            A ``State`` with status ``PMSolverStatus.LPT``.
        """
        a = jnp.atleast_1d(a)

        if state.status != PMSolverStatus.INIT:
            raise ValueError(
                "LPT simulation has to be done before the other steps")

        dx, p, _ = self._first_order_lpt(state, a)

        return State(displacements=dx,
                     velocities=p,
                     cosmo=state.cosmo,
                     kvec=state.kvec,
                     status=PMSolverStatus.LPT)

    def lpt2(self, state, a=0.1):
        """Run second-order LPT at scale factor ``a``.

        Raises:
            ValueError: if ``state.status`` is not ``PMSolverStatus.INIT``.

        Returns:
            A ``State`` with status ``PMSolverStatus.LPT2``.
        """
        a = jnp.atleast_1d(a)

        if state.status != PMSolverStatus.INIT:
            raise ValueError("LPT2 simulation has to be done in the beginning")

        dx, p, delta_k = self._first_order_lpt(state, a)

        delta2 = self.compute_lpt2_source(delta_k)

        # compute_initial_forces returns (force, delta_k); only the force is
        # needed for the second-order term.
        # NOTE(review): delta2 comes out of inverse transforms, while
        # compute_initial_forces expects a Fourier-space field and runs it
        # through interpolate_ic -- confirm whether a forward FFT is needed
        # and whether the interpolation should be skipped here.
        init_force2, _ = self.compute_initial_forces(state, delta2)

        dx2 = 3 / 7 * growth_factor_second(state.cosmo, a) * init_force2
        p2 = a**2 * growth_rate_second(state.cosmo, a) * jnp.sqrt(
            jc.background.Esqr(state.cosmo, a)) * dx2

        dx = dx + dx2
        p = p + p2

        # kvec is a required State field and must be carried forward.
        return State(displacements=dx,
                     velocities=p,
                     cosmo=state.cosmo,
                     kvec=state.kvec,
                     status=PMSolverStatus.LPT2)

    def nbody(self,
              state,
              solver: AbstractSolver,
              stepsize_controller: AbstractStepSizeController,
              t0=0.1,
              t1=1,
              dt0=0.01):
        """Integrate ``state`` from ``t0`` to ``t1`` with diffrax.

        A state still in ``INIT`` is first advanced with ``lpt`` at ``t0``.

        Args:
            state: starting ``State``.
            solver: diffrax solver passed to ``diffeqsolve``.
            stepsize_controller: diffrax step-size controller.
            t0: initial integration time, also used as the LPT scale factor.
            t1: final integration time.
            dt0: initial step size.

        Raises:
            ValueError: if ``state.status`` is already ``ODE`` or ``DONE``.

        Returns:
            A ``State`` with status ``PMSolverStatus.DONE`` holding the last
            saved displacements and velocities, plus the solver statistics.
        """
        if state.status == PMSolverStatus.INIT:
            state = self.lpt(state, a=t0)
        elif state.status in (PMSolverStatus.ODE, PMSolverStatus.DONE):
            raise ValueError(f"nbody already done on state {state.status}")

        ode_fn = self.make_ode_fn()

        # Build a fresh ODE-tagged state instead of mutating the caller's
        # object; its meta fields then match what ode_fn returns.
        y0 = State(displacements=state.displacements,
                   velocities=state.velocities,
                   cosmo=state.cosmo,
                   kvec=state.kvec,
                   status=PMSolverStatus.ODE)

        solution = diffeqsolve(ODETerm(ode_fn),
                               solver,
                               t0=t0,
                               t1=t1,
                               dt0=dt0,
                               y0=y0,
                               saveat=SaveAt(t1=True),
                               args=None,
                               stepsize_controller=stepsize_controller)

        return State(displacements=solution.ys.displacements[-1],
                     velocities=solution.ys.velocities[-1],
                     cosmo=state.cosmo,
                     kvec=state.kvec,
                     status=PMSolverStatus.DONE,
                     solver_stats=solution.stats)
|