Fixes in solvers

This commit is contained in:
Wassim KABALAN 2024-07-09 02:35:24 +02:00
parent a0be772f3c
commit f3599e73da

View file

@ -1,47 +1,49 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, Optional, Tuple from functools import partial
from typing import Any, Sequence
import jax import jax
import jax_cosmo as jc
from diffrax import (AbstractSolver, AbstractStepSizeController,
ConstantStepSize, ODETerm, SaveAt, Tsit5, diffeqsolve)
from jax import numpy as jnp from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_cosmo import Cosmology from jax_cosmo import Cosmology
from jaxtyping import Array, Bool, PyTree, Real, Shaped from jaxtyping import Array, Bool, PyTree, Real, Shaped
import jaxpm as jpm
from jaxpm.growth import (growth_factor, growth_factor_second, growth_rate,
growth_rate_second)
class PMSolverStatus(Enum): class PMSolverStatus(Enum):
INIT = 0 INIT = 0
LPT = 1 LPT = 1
LPT2 = 2 LPT2 = 2
ODE = 3 ODE = 3
DONE = 4
@register_pytree_node_class @partial(jax.tree_util.register_dataclass,
data_fields=['displacements', 'velocities', 'cosmo'],
meta_fields=['solver_stats', 'status', 'kvec'])
@dataclass @dataclass
class State(object): class State(object):
displacements: Tuple[int, int, int, 3] displacements: Array
velocities: Tuple[int, int, int, 3] velocities: Array
cosmo: Cosmology cosmo: Cosmology
kvec: list[Array, Array, Array] kvec: Sequence
solver_stats: dict[str, Any] solver_stats: dict[str, Any] = None
status: PMSolverStatus status: PMSolverStatus = PMSolverStatus.INIT
def tree_flatten(self):
children = (self.displacements, self.velocities, self.cosmo)
aux_data = (self.kvec, self.solver_stats, self.status)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(*children)
class FastPM(object): class FastPM(object):
initial_conditions: Array initial_delta_k: Array
kvec: list[Array, Array, Array]
def init_state(self, cosmo, particules, kvec, initial_conditions): def init_state(self, cosmo, particules, kvec, initial_field, box_size):
self.initial_conditions = initial_conditions self.initial_delta_k = jpm.ops.fftn(initial_field)
self.box_size = box_size
# Check sharding on this # Check sharding on this
zeros = jnp.zeros_like(particules) zeros = jnp.zeros_like(particules)
state = State(displacements=zeros, state = State(displacements=zeros,
@ -51,11 +53,206 @@ class FastPM(object):
return state return state
def lpt(state, a=0.1): def compute_initial_forces(self, state, delta_k):
pass
def lpt2(state, a=0.1): #TODO this must done in a function generat_ic
pass mesh_shape = state.displacements.shape[:3]
box_size = self.box_size
ky, kz, kx = state.kvec
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
(ky / box_size[1] * mesh_shape[1])**2 +
(kz / box_size[1] * mesh_shape[1])**2)
delta_k = jpm.ops.interpolate_ic(delta_k, kk, state.cosmo, box_size)
kernel_lap = jnp.where(
kk == 0, 1.,
1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange
pot_k = delta_k * kernel_lap
# Forces have to be a Z pencil because they are going to be IFFT back to X pencil
forces_k = jnp.stack([
pot_k * 1j / 6.0 *
(8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz))
],
axis=-1)
def nbody(state, solver, stepsize_controller, t0=0.1, t1=1, dt0=0.01): init_force = jnp.stack(
pass [jpm.ops.ifftn(forces_k[..., i]).real for i in range(3)], axis=-1)
return init_force, delta_k
def compute_ode_forces(self, state):
mesh_shape = self.initial_delta_k.shape
box_size = self.box_size
print(f"type of state {type(state)}")
pos = jnp.array(state.displacements)
ky, kz, kx = state.kvec
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
(ky / box_size[1] * mesh_shape[1])**2 +
(kz / box_size[1] * mesh_shape[1])**2)
delta_k = jpm.painting.cic_paint_dx(pos)
kernel_lap = jnp.where(
kk == 0, 1.,
1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange
pot_k = delta_k * kernel_lap
# Forces have to be a Z pencil because they are going to be IFFT back to X pencil
forces_k = jnp.stack([
pot_k * 1j / 6.0 *
(8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz))
],
axis=-1)
forces = jnp.stack([
jpm.painting.cic_read_dx(jpm.ops.ifftn(forces_k[..., i])).real
for i in range(3)
],
axis=-1)
forces = forces * 1.5 * state.cosmo.Omega_m
return forces
def make_ode_fn(self):
def ode_fn(a, state, args):
print(f"helo")
forces = self.compute_ode_forces(state)
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(
state.cosmo, a))) * state.velocities
# Computes the update of velocity (kick)
dvel = 1. / (a**2 *
jnp.sqrt(jc.background.Esqr(state.cosmo, a))) * forces
state = State(displacements=dpos,
velocities=dvel,
cosmo=state.cosmo,
kvec=state.kvec,
status=PMSolverStatus.ODE)
return state
return ode_fn
def compute_lpt2_source(self, delta_k):
mesh_shape = self.initial_delta_k.shape
box_size = self.box_size
ky, kz, kx = self.kvec
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
(ky / box_size[1] * mesh_shape[1])**2 +
(kz / box_size[1] * mesh_shape[1])**2)
invlaplace_kernel = -jnp.where(kk == 0, 1., 1. /
(kx**2 + ky**2 + kz**2))
pot_k = delta_k * invlaplace_kernel
# Taken from https://github.com/hsimonfroy/montecosmo
# Based on https://arxiv.org/abs/0910.0258
delta2 = 0
shear_acc = 0
for i, ki in enumerate(self.kvec):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
shear_ii = jpm.ops.ifft(-ki**2 * pot_k)
delta2 += shear_ii * shear_acc
shear_acc += shear_ii
for kj in self.kvec[i + 1:]:
# Substract squared strict-up-triangle terms
delta2 -= jpm.ops.ifft(-ki * kj * pot_k)**2
return delta2
def lpt(self, state, a=0.1):
a = jnp.atleast_1d(a)
if state.status != PMSolverStatus.INIT:
raise ValueError(
f"LPT simulation has to be done before the other steps")
init_force, _ = self.compute_initial_forces(state,
self.initial_delta_k)
dx = growth_factor(state.cosmo, a) * init_force
p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
jc.background.Esqr(state.cosmo, a)) * dx
return State(displacements=dx,
velocities=p,
cosmo=state.cosmo,
kvec=state.kvec,
status=PMSolverStatus.LPT)
def lpt2(self, state, a=0.1):
a = jnp.atleast_1d(a)
if state.status != PMSolverStatus.INIT:
raise ValueError(
f"LPT2 simulation has to be done in the beginning")
init_force, delta_k = self.compute_initial_forces(
state, self.initial_delta_k)
dx = growth_factor(state.cosmo, a) * init_force
p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
jc.background.Esqr(state.cosmo, a)) * dx
delta2 = self.compute_lpt2_source(delta_k)
init_force2 = self.compute_initial_forces(state, delta2)
dx2 = 3 / 7 * growth_factor_second(state.cosmo, a) * init_force2
p2 = a**2 * growth_rate_second(state.cosmo, a) * jnp.sqrt(
jc.background.Esqr(state.cosmo, a)) * dx2
dx += dx2
p += p2
return State(displacements=dx,
velocities=p,
cosmo=state.cosmo,
status=PMSolverStatus.LPT2)
def nbody(self,
state,
solver: AbstractSolver,
stepsize_controller: AbstractStepSizeController,
t0=0.1,
t1=1,
dt0=0.01):
if state.status == PMSolverStatus.INIT:
state = self.lpt(state, a=t0)
elif state.status == PMSolverStatus.ODE or \
state.status == PMSolverStatus.DONE:
raise ValueError(f"nbody already done on state {state.status}")
ode_fn = self.make_ode_fn()
state.status = PMSolverStatus.ODE
solution = diffeqsolve(ODETerm(ode_fn),
solver,
t0=t0,
t1=t1,
dt0=dt0,
y0=state,
saveat=SaveAt(t1=True),
args=None,
stepsize_controller=stepsize_controller)
return State(displacements=solution.ys.displacements[-1],
velocities=solution.ys.velocities[-1],
cosmo=state.cosmo,
kvec=state.kvec,
status=PMSolverStatus.DONE,
solver_stats=solution.stats)