mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Fixes in solvers
This commit is contained in:
parent
a0be772f3c
commit
f3599e73da
1 changed files with 224 additions and 27 deletions
251
jaxpm/solvers.py
251
jaxpm/solvers.py
|
@ -1,47 +1,49 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
from functools import partial
|
||||
from typing import Any, Sequence
|
||||
|
||||
import jax
|
||||
import jax_cosmo as jc
|
||||
from diffrax import (AbstractSolver, AbstractStepSizeController,
|
||||
ConstantStepSize, ODETerm, SaveAt, Tsit5, diffeqsolve)
|
||||
from jax import numpy as jnp
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
from jax_cosmo import Cosmology
|
||||
from jaxtyping import Array, Bool, PyTree, Real, Shaped
|
||||
|
||||
import jaxpm as jpm
|
||||
from jaxpm.growth import (growth_factor, growth_factor_second, growth_rate,
|
||||
growth_rate_second)
|
||||
|
||||
|
||||
class PMSolverStatus(Enum):
|
||||
INIT = 0
|
||||
LPT = 1
|
||||
LPT2 = 2
|
||||
ODE = 3
|
||||
DONE = 4
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
@partial(jax.tree_util.register_dataclass,
|
||||
data_fields=['displacements', 'velocities', 'cosmo'],
|
||||
meta_fields=['solver_stats', 'status', 'kvec'])
|
||||
@dataclass
|
||||
class State(object):
|
||||
displacements: Tuple[int, int, int, 3]
|
||||
velocities: Tuple[int, int, int, 3]
|
||||
displacements: Array
|
||||
velocities: Array
|
||||
cosmo: Cosmology
|
||||
kvec: list[Array, Array, Array]
|
||||
solver_stats: dict[str, Any]
|
||||
status: PMSolverStatus
|
||||
|
||||
def tree_flatten(self):
|
||||
children = (self.displacements, self.velocities, self.cosmo)
|
||||
aux_data = (self.kvec, self.solver_stats, self.status)
|
||||
return (children, aux_data)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
del aux_data
|
||||
return cls(*children)
|
||||
kvec: Sequence
|
||||
solver_stats: dict[str, Any] = None
|
||||
status: PMSolverStatus = PMSolverStatus.INIT
|
||||
|
||||
|
||||
class FastPM(object):
|
||||
initial_conditions: Array
|
||||
initial_delta_k: Array
|
||||
kvec: list[Array, Array, Array]
|
||||
|
||||
def init_state(self, cosmo, particules, kvec, initial_conditions):
|
||||
self.initial_conditions = initial_conditions
|
||||
def init_state(self, cosmo, particules, kvec, initial_field, box_size):
|
||||
self.initial_delta_k = jpm.ops.fftn(initial_field)
|
||||
self.box_size = box_size
|
||||
# Check sharding on this
|
||||
zeros = jnp.zeros_like(particules)
|
||||
state = State(displacements=zeros,
|
||||
|
@ -51,11 +53,206 @@ class FastPM(object):
|
|||
|
||||
return state
|
||||
|
||||
def lpt(state, a=0.1):
|
||||
pass
|
||||
def compute_initial_forces(self, state, delta_k):
|
||||
|
||||
def lpt2(state, a=0.1):
|
||||
pass
|
||||
#TODO this must done in a function generat_ic
|
||||
mesh_shape = state.displacements.shape[:3]
|
||||
box_size = self.box_size
|
||||
ky, kz, kx = state.kvec
|
||||
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
|
||||
(ky / box_size[1] * mesh_shape[1])**2 +
|
||||
(kz / box_size[1] * mesh_shape[1])**2)
|
||||
delta_k = jpm.ops.interpolate_ic(delta_k, kk, state.cosmo, box_size)
|
||||
kernel_lap = jnp.where(
|
||||
kk == 0, 1.,
|
||||
1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange
|
||||
pot_k = delta_k * kernel_lap
|
||||
# Forces have to be a Z pencil because they are going to be IFFT back to X pencil
|
||||
forces_k = jnp.stack([
|
||||
pot_k * 1j / 6.0 *
|
||||
(8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
|
||||
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz))
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
def nbody(state, solver, stepsize_controller, t0=0.1, t1=1, dt0=0.01):
|
||||
pass
|
||||
init_force = jnp.stack(
|
||||
[jpm.ops.ifftn(forces_k[..., i]).real for i in range(3)], axis=-1)
|
||||
|
||||
return init_force, delta_k
|
||||
|
||||
def compute_ode_forces(self, state):
|
||||
|
||||
mesh_shape = self.initial_delta_k.shape
|
||||
box_size = self.box_size
|
||||
print(f"type of state {type(state)}")
|
||||
|
||||
pos = jnp.array(state.displacements)
|
||||
|
||||
ky, kz, kx = state.kvec
|
||||
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
|
||||
(ky / box_size[1] * mesh_shape[1])**2 +
|
||||
(kz / box_size[1] * mesh_shape[1])**2)
|
||||
delta_k = jpm.painting.cic_paint_dx(pos)
|
||||
kernel_lap = jnp.where(
|
||||
kk == 0, 1.,
|
||||
1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange
|
||||
pot_k = delta_k * kernel_lap
|
||||
# Forces have to be a Z pencil because they are going to be IFFT back to X pencil
|
||||
forces_k = jnp.stack([
|
||||
pot_k * 1j / 6.0 *
|
||||
(8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
|
||||
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz))
|
||||
],
|
||||
axis=-1)
|
||||
|
||||
forces = jnp.stack([
|
||||
jpm.painting.cic_read_dx(jpm.ops.ifftn(forces_k[..., i])).real
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
forces = forces * 1.5 * state.cosmo.Omega_m
|
||||
|
||||
return forces
|
||||
|
||||
def make_ode_fn(self):
|
||||
|
||||
def ode_fn(a, state, args):
|
||||
|
||||
print(f"helo")
|
||||
forces = self.compute_ode_forces(state)
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(
|
||||
state.cosmo, a))) * state.velocities
|
||||
|
||||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 *
|
||||
jnp.sqrt(jc.background.Esqr(state.cosmo, a))) * forces
|
||||
|
||||
state = State(displacements=dpos,
|
||||
velocities=dvel,
|
||||
cosmo=state.cosmo,
|
||||
kvec=state.kvec,
|
||||
status=PMSolverStatus.ODE)
|
||||
return state
|
||||
|
||||
return ode_fn
|
||||
|
||||
def compute_lpt2_source(self, delta_k):
|
||||
|
||||
mesh_shape = self.initial_delta_k.shape
|
||||
box_size = self.box_size
|
||||
|
||||
ky, kz, kx = self.kvec
|
||||
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
|
||||
(ky / box_size[1] * mesh_shape[1])**2 +
|
||||
(kz / box_size[1] * mesh_shape[1])**2)
|
||||
invlaplace_kernel = -jnp.where(kk == 0, 1., 1. /
|
||||
(kx**2 + ky**2 + kz**2))
|
||||
|
||||
pot_k = delta_k * invlaplace_kernel
|
||||
|
||||
# Taken from https://github.com/hsimonfroy/montecosmo
|
||||
# Based on https://arxiv.org/abs/0910.0258
|
||||
delta2 = 0
|
||||
shear_acc = 0
|
||||
for i, ki in enumerate(self.kvec):
|
||||
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
|
||||
shear_ii = jpm.ops.ifft(-ki**2 * pot_k)
|
||||
delta2 += shear_ii * shear_acc
|
||||
shear_acc += shear_ii
|
||||
|
||||
for kj in self.kvec[i + 1:]:
|
||||
# Substract squared strict-up-triangle terms
|
||||
delta2 -= jpm.ops.ifft(-ki * kj * pot_k)**2
|
||||
|
||||
return delta2
|
||||
|
||||
def lpt(self, state, a=0.1):
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
|
||||
if state.status != PMSolverStatus.INIT:
|
||||
raise ValueError(
|
||||
f"LPT simulation has to be done before the other steps")
|
||||
|
||||
init_force, _ = self.compute_initial_forces(state,
|
||||
self.initial_delta_k)
|
||||
|
||||
dx = growth_factor(state.cosmo, a) * init_force
|
||||
|
||||
p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
|
||||
jc.background.Esqr(state.cosmo, a)) * dx
|
||||
|
||||
return State(displacements=dx,
|
||||
velocities=p,
|
||||
cosmo=state.cosmo,
|
||||
kvec=state.kvec,
|
||||
status=PMSolverStatus.LPT)
|
||||
|
||||
def lpt2(self, state, a=0.1):
|
||||
a = jnp.atleast_1d(a)
|
||||
|
||||
if state.status != PMSolverStatus.INIT:
|
||||
raise ValueError(
|
||||
f"LPT2 simulation has to be done in the beginning")
|
||||
|
||||
init_force, delta_k = self.compute_initial_forces(
|
||||
state, self.initial_delta_k)
|
||||
|
||||
dx = growth_factor(state.cosmo, a) * init_force
|
||||
|
||||
p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
|
||||
jc.background.Esqr(state.cosmo, a)) * dx
|
||||
|
||||
delta2 = self.compute_lpt2_source(delta_k)
|
||||
|
||||
init_force2 = self.compute_initial_forces(state, delta2)
|
||||
|
||||
dx2 = 3 / 7 * growth_factor_second(state.cosmo, a) * init_force2
|
||||
p2 = a**2 * growth_rate_second(state.cosmo, a) * jnp.sqrt(
|
||||
jc.background.Esqr(state.cosmo, a)) * dx2
|
||||
|
||||
dx += dx2
|
||||
p += p2
|
||||
|
||||
return State(displacements=dx,
|
||||
velocities=p,
|
||||
cosmo=state.cosmo,
|
||||
status=PMSolverStatus.LPT2)
|
||||
|
||||
def nbody(self,
|
||||
state,
|
||||
solver: AbstractSolver,
|
||||
stepsize_controller: AbstractStepSizeController,
|
||||
t0=0.1,
|
||||
t1=1,
|
||||
dt0=0.01):
|
||||
|
||||
if state.status == PMSolverStatus.INIT:
|
||||
state = self.lpt(state, a=t0)
|
||||
elif state.status == PMSolverStatus.ODE or \
|
||||
state.status == PMSolverStatus.DONE:
|
||||
raise ValueError(f"nbody already done on state {state.status}")
|
||||
|
||||
ode_fn = self.make_ode_fn()
|
||||
|
||||
state.status = PMSolverStatus.ODE
|
||||
solution = diffeqsolve(ODETerm(ode_fn),
|
||||
solver,
|
||||
t0=t0,
|
||||
t1=t1,
|
||||
dt0=dt0,
|
||||
y0=state,
|
||||
saveat=SaveAt(t1=True),
|
||||
args=None,
|
||||
stepsize_controller=stepsize_controller)
|
||||
|
||||
return State(displacements=solution.ys.displacements[-1],
|
||||
velocities=solution.ys.velocities[-1],
|
||||
cosmo=state.cosmo,
|
||||
kvec=state.kvec,
|
||||
status=PMSolverStatus.DONE,
|
||||
solver_stats=solution.stats)
|
||||
|
|
Loading…
Add table
Reference in a new issue