Fixes in solvers

This commit is contained in:
Wassim KABALAN 2024-07-09 02:35:24 +02:00
parent a0be772f3c
commit f3599e73da

View file

@ -1,47 +1,49 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Tuple
from functools import partial
from typing import Any, Sequence
import jax
import jax_cosmo as jc
from diffrax import (AbstractSolver, AbstractStepSizeController,
ConstantStepSize, ODETerm, SaveAt, Tsit5, diffeqsolve)
from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_cosmo import Cosmology
from jaxtyping import Array, Bool, PyTree, Real, Shaped
import jaxpm as jpm
from jaxpm.growth import (growth_factor, growth_factor_second, growth_rate,
growth_rate_second)
class PMSolverStatus(Enum):
INIT = 0
LPT = 1
LPT2 = 2
ODE = 3
DONE = 4
@register_pytree_node_class
@partial(jax.tree_util.register_dataclass,
data_fields=['displacements', 'velocities', 'cosmo'],
meta_fields=['solver_stats', 'status', 'kvec'])
@dataclass
class State(object):
displacements: Tuple[int, int, int, 3]
velocities: Tuple[int, int, int, 3]
displacements: Array
velocities: Array
cosmo: Cosmology
kvec: list[Array, Array, Array]
solver_stats: dict[str, Any]
status: PMSolverStatus
def tree_flatten(self):
children = (self.displacements, self.velocities, self.cosmo)
aux_data = (self.kvec, self.solver_stats, self.status)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(*children)
kvec: Sequence
solver_stats: dict[str, Any] = None
status: PMSolverStatus = PMSolverStatus.INIT
class FastPM(object):
initial_conditions: Array
initial_delta_k: Array
kvec: list[Array, Array, Array]
def init_state(self, cosmo, particules, kvec, initial_conditions):
self.initial_conditions = initial_conditions
def init_state(self, cosmo, particules, kvec, initial_field, box_size):
self.initial_delta_k = jpm.ops.fftn(initial_field)
self.box_size = box_size
# Check sharding on this
zeros = jnp.zeros_like(particules)
state = State(displacements=zeros,
@ -51,11 +53,206 @@ class FastPM(object):
return state
def lpt(state, a=0.1):
pass
def compute_initial_forces(self, state, delta_k):
def lpt2(state, a=0.1):
pass
#TODO this must done in a function generat_ic
mesh_shape = state.displacements.shape[:3]
box_size = self.box_size
ky, kz, kx = state.kvec
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
(ky / box_size[1] * mesh_shape[1])**2 +
(kz / box_size[1] * mesh_shape[1])**2)
delta_k = jpm.ops.interpolate_ic(delta_k, kk, state.cosmo, box_size)
kernel_lap = jnp.where(
kk == 0, 1.,
1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange
pot_k = delta_k * kernel_lap
# Forces have to be a Z pencil because they are going to be IFFT back to X pencil
forces_k = jnp.stack([
pot_k * 1j / 6.0 *
(8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz))
],
axis=-1)
def nbody(state, solver, stepsize_controller, t0=0.1, t1=1, dt0=0.01):
pass
init_force = jnp.stack(
[jpm.ops.ifftn(forces_k[..., i]).real for i in range(3)], axis=-1)
return init_force, delta_k
def compute_ode_forces(self, state):
mesh_shape = self.initial_delta_k.shape
box_size = self.box_size
print(f"type of state {type(state)}")
pos = jnp.array(state.displacements)
ky, kz, kx = state.kvec
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
(ky / box_size[1] * mesh_shape[1])**2 +
(kz / box_size[1] * mesh_shape[1])**2)
delta_k = jpm.painting.cic_paint_dx(pos)
kernel_lap = jnp.where(
kk == 0, 1.,
1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange
pot_k = delta_k * kernel_lap
# Forces have to be a Z pencil because they are going to be IFFT back to X pencil
forces_k = jnp.stack([
pot_k * 1j / 6.0 *
(8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 *
(8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz))
],
axis=-1)
forces = jnp.stack([
jpm.painting.cic_read_dx(jpm.ops.ifftn(forces_k[..., i])).real
for i in range(3)
],
axis=-1)
forces = forces * 1.5 * state.cosmo.Omega_m
return forces
def make_ode_fn(self):
def ode_fn(a, state, args):
print(f"helo")
forces = self.compute_ode_forces(state)
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(
state.cosmo, a))) * state.velocities
# Computes the update of velocity (kick)
dvel = 1. / (a**2 *
jnp.sqrt(jc.background.Esqr(state.cosmo, a))) * forces
state = State(displacements=dpos,
velocities=dvel,
cosmo=state.cosmo,
kvec=state.kvec,
status=PMSolverStatus.ODE)
return state
return ode_fn
def compute_lpt2_source(self, delta_k):
mesh_shape = self.initial_delta_k.shape
box_size = self.box_size
ky, kz, kx = self.kvec
kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 +
(ky / box_size[1] * mesh_shape[1])**2 +
(kz / box_size[1] * mesh_shape[1])**2)
invlaplace_kernel = -jnp.where(kk == 0, 1., 1. /
(kx**2 + ky**2 + kz**2))
pot_k = delta_k * invlaplace_kernel
# Taken from https://github.com/hsimonfroy/montecosmo
# Based on https://arxiv.org/abs/0910.0258
delta2 = 0
shear_acc = 0
for i, ki in enumerate(self.kvec):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
shear_ii = jpm.ops.ifft(-ki**2 * pot_k)
delta2 += shear_ii * shear_acc
shear_acc += shear_ii
for kj in self.kvec[i + 1:]:
# Substract squared strict-up-triangle terms
delta2 -= jpm.ops.ifft(-ki * kj * pot_k)**2
return delta2
def lpt(self, state, a=0.1):
a = jnp.atleast_1d(a)
if state.status != PMSolverStatus.INIT:
raise ValueError(
f"LPT simulation has to be done before the other steps")
init_force, _ = self.compute_initial_forces(state,
self.initial_delta_k)
dx = growth_factor(state.cosmo, a) * init_force
p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
jc.background.Esqr(state.cosmo, a)) * dx
return State(displacements=dx,
velocities=p,
cosmo=state.cosmo,
kvec=state.kvec,
status=PMSolverStatus.LPT)
def lpt2(self, state, a=0.1):
a = jnp.atleast_1d(a)
if state.status != PMSolverStatus.INIT:
raise ValueError(
f"LPT2 simulation has to be done in the beginning")
init_force, delta_k = self.compute_initial_forces(
state, self.initial_delta_k)
dx = growth_factor(state.cosmo, a) * init_force
p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt(
jc.background.Esqr(state.cosmo, a)) * dx
delta2 = self.compute_lpt2_source(delta_k)
init_force2 = self.compute_initial_forces(state, delta2)
dx2 = 3 / 7 * growth_factor_second(state.cosmo, a) * init_force2
p2 = a**2 * growth_rate_second(state.cosmo, a) * jnp.sqrt(
jc.background.Esqr(state.cosmo, a)) * dx2
dx += dx2
p += p2
return State(displacements=dx,
velocities=p,
cosmo=state.cosmo,
status=PMSolverStatus.LPT2)
def nbody(self,
state,
solver: AbstractSolver,
stepsize_controller: AbstractStepSizeController,
t0=0.1,
t1=1,
dt0=0.01):
if state.status == PMSolverStatus.INIT:
state = self.lpt(state, a=t0)
elif state.status == PMSolverStatus.ODE or \
state.status == PMSolverStatus.DONE:
raise ValueError(f"nbody already done on state {state.status}")
ode_fn = self.make_ode_fn()
state.status = PMSolverStatus.ODE
solution = diffeqsolve(ODETerm(ode_fn),
solver,
t0=t0,
t1=t1,
dt0=dt0,
y0=state,
saveat=SaveAt(t1=True),
args=None,
stepsize_controller=stepsize_controller)
return State(displacements=solution.ys.displacements[-1],
velocities=solution.ys.velocities[-1],
cosmo=state.cosmo,
kvec=state.kvec,
status=PMSolverStatus.DONE,
solver_stats=solution.stats)