diff --git a/jaxpm/solvers.py b/jaxpm/solvers.py index 90cd8f3..fa2dc6e 100644 --- a/jaxpm/solvers.py +++ b/jaxpm/solvers.py @@ -1,47 +1,49 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Optional, Tuple +from functools import partial +from typing import Any, Sequence import jax +import jax_cosmo as jc +from diffrax import (AbstractSolver, AbstractStepSizeController, + ConstantStepSize, ODETerm, SaveAt, Tsit5, diffeqsolve) from jax import numpy as jnp -from jax.tree_util import register_pytree_node_class from jax_cosmo import Cosmology from jaxtyping import Array, Bool, PyTree, Real, Shaped +import jaxpm as jpm +from jaxpm.growth import (growth_factor, growth_factor_second, growth_rate, + growth_rate_second) + class PMSolverStatus(Enum): INIT = 0 LPT = 1 LPT2 = 2 ODE = 3 + DONE = 4 -@register_pytree_node_class +@partial(jax.tree_util.register_dataclass, + data_fields=['displacements', 'velocities', 'cosmo'], + meta_fields=['solver_stats', 'status', 'kvec']) @dataclass class State(object): - displacements: Tuple[int, int, int, 3] - velocities: Tuple[int, int, int, 3] + displacements: Array + velocities: Array cosmo: Cosmology - kvec: list[Array, Array, Array] - solver_stats: dict[str, Any] - status: PMSolverStatus - - def tree_flatten(self): - children = (self.displacements, self.velocities, self.cosmo) - aux_data = (self.kvec, self.solver_stats, self.status) - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - del aux_data - return cls(*children) + kvec: Sequence + solver_stats: dict[str, Any] = None + status: PMSolverStatus = PMSolverStatus.INIT class FastPM(object): - initial_conditions: Array + initial_delta_k: Array + kvec: list[Array, Array, Array] - def init_state(self, cosmo, particules, kvec, initial_conditions): - self.initial_conditions = initial_conditions + def init_state(self, cosmo, particules, kvec, initial_field, box_size): + self.initial_delta_k = jpm.ops.fftn(initial_field) + self.box_size = box_size # Check sharding on this zeros = jnp.zeros_like(particules) state = State(displacements=zeros, @@ -51,11 +53,206 @@ class FastPM(object): return state - def lpt(state, a=0.1): - pass + def compute_initial_forces(self, state, delta_k): - def lpt2(state, a=0.1): - pass + #TODO this must done in a function generat_ic + mesh_shape = state.displacements.shape[:3] + box_size = self.box_size + ky, kz, kx = state.kvec + kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 + + (ky / box_size[1] * mesh_shape[1])**2 + + (kz / box_size[1] * mesh_shape[1])**2) + delta_k = jpm.ops.interpolate_ic(delta_k, kk, state.cosmo, box_size) + kernel_lap = jnp.where( + kk == 0, 1., + 1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange + pot_k = delta_k * kernel_lap + # Forces have to be a Z pencil because they are going to be IFFT back to X pencil + forces_k = jnp.stack([ + pot_k * 1j / 6.0 * + (8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 * + (8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 * + (8 * jnp.sin(kz) - jnp.sin(2 * kz)) + ], + axis=-1) - def nbody(state, solver, stepsize_controller, t0=0.1, t1=1, dt0=0.01): - pass + init_force = jnp.stack( + [jpm.ops.ifftn(forces_k[..., i]).real for i in range(3)], axis=-1) + + return init_force, delta_k + + def compute_ode_forces(self, state): + + mesh_shape = self.initial_delta_k.shape + box_size = self.box_size + print(f"type of state {type(state)}") + + pos = jnp.array(state.displacements) + + ky, kz, kx = state.kvec + kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 + + (ky / box_size[1] * mesh_shape[1])**2 + + (kz / box_size[1] * mesh_shape[1])**2) + delta_k = jpm.painting.cic_paint_dx(pos) + kernel_lap = jnp.where( + kk == 0, 1., + 1. / (kx**2 + ky**2 + kz**2)) # Laplace kernel + longrange + pot_k = delta_k * kernel_lap + # Forces have to be a Z pencil because they are going to be IFFT back to X pencil + forces_k = jnp.stack([ + pot_k * 1j / 6.0 * + (8 * jnp.sin(kx) - jnp.sin(2 * kx)), pot_k * 1j / 6.0 * + (8 * jnp.sin(ky) - jnp.sin(2 * ky)), pot_k * 1j / 6.0 * + (8 * jnp.sin(kz) - jnp.sin(2 * kz)) + ], + axis=-1) + + forces = jnp.stack([ + jpm.painting.cic_read_dx(jpm.ops.ifftn(forces_k[..., i])).real + for i in range(3) + ], + axis=-1) + forces = forces * 1.5 * state.cosmo.Omega_m + + return forces + + def make_ode_fn(self): + + def ode_fn(a, state, args): + + print(f"helo") + forces = self.compute_ode_forces(state) + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr( + state.cosmo, a))) * state.velocities + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * + jnp.sqrt(jc.background.Esqr(state.cosmo, a))) * forces + + state = State(displacements=dpos, + velocities=dvel, + cosmo=state.cosmo, + kvec=state.kvec, + status=PMSolverStatus.ODE) + return state + + return ode_fn + + def compute_lpt2_source(self, delta_k): + + mesh_shape = self.initial_delta_k.shape + box_size = self.box_size + + ky, kz, kx = self.kvec + kk = jnp.sqrt((kx / box_size[0] * mesh_shape[0])**2 + + (ky / box_size[1] * mesh_shape[1])**2 + + (kz / box_size[1] * mesh_shape[1])**2) + invlaplace_kernel = -jnp.where(kk == 0, 1., 1. / + (kx**2 + ky**2 + kz**2)) + + pot_k = delta_k * invlaplace_kernel + + # Taken from https://github.com/hsimonfroy/montecosmo + # Based on https://arxiv.org/abs/0910.0258 + delta2 = 0 + shear_acc = 0 + for i, ki in enumerate(self.kvec): + # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... + shear_ii = jpm.ops.ifft(-ki**2 * pot_k) + delta2 += shear_ii * shear_acc + shear_acc += shear_ii + + for kj in self.kvec[i + 1:]: + # Substract squared strict-up-triangle terms + delta2 -= jpm.ops.ifft(-ki * kj * pot_k)**2 + + return delta2 + + def lpt(self, state, a=0.1): + + a = jnp.atleast_1d(a) + + if state.status != PMSolverStatus.INIT: + raise ValueError( + f"LPT simulation has to be done before the other steps") + + init_force, _ = self.compute_initial_forces(state, + self.initial_delta_k) + + dx = growth_factor(state.cosmo, a) * init_force + + p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt( + jc.background.Esqr(state.cosmo, a)) * dx + + return State(displacements=dx, + velocities=p, + cosmo=state.cosmo, + kvec=state.kvec, + status=PMSolverStatus.LPT) + + def lpt2(self, state, a=0.1): + a = jnp.atleast_1d(a) + + if state.status != PMSolverStatus.INIT: + raise ValueError( + f"LPT2 simulation has to be done in the beginning") + + init_force, delta_k = self.compute_initial_forces( + state, self.initial_delta_k) + + dx = growth_factor(state.cosmo, a) * init_force + + p = a**2 * growth_rate(state.cosmo, a) * jnp.sqrt( + jc.background.Esqr(state.cosmo, a)) * dx + + delta2 = self.compute_lpt2_source(delta_k) + + init_force2 = self.compute_initial_forces(state, delta2) + + dx2 = 3 / 7 * growth_factor_second(state.cosmo, a) * init_force2 + p2 = a**2 * growth_rate_second(state.cosmo, a) * jnp.sqrt( + jc.background.Esqr(state.cosmo, a)) * dx2 + + dx += dx2 + p += p2 + + return State(displacements=dx, + velocities=p, + cosmo=state.cosmo, + status=PMSolverStatus.LPT2) + + def nbody(self, + state, + solver: AbstractSolver, + stepsize_controller: AbstractStepSizeController, + t0=0.1, + t1=1, + dt0=0.01): + + if state.status == PMSolverStatus.INIT: + state = self.lpt(state, a=t0) + elif state.status == PMSolverStatus.ODE or \ + state.status == PMSolverStatus.DONE: + raise ValueError(f"nbody already done on state {state.status}") + + ode_fn = self.make_ode_fn() + + state.status = PMSolverStatus.ODE + solution = diffeqsolve(ODETerm(ode_fn), + solver, + t0=t0, + t1=t1, + dt0=dt0, + y0=state, + saveat=SaveAt(t1=True), + args=None, + stepsize_controller=stepsize_controller) + + return State(displacements=solution.ys.displacements[-1], + velocities=solution.ys.velocities[-1], + cosmo=state.cosmo, + kvec=state.kvec, + status=PMSolverStatus.DONE, + solver_stats=solution.stats)