JaxPM/jaxpm/solvers.py
Wassim KABALAN da06f4dba8 Add solvers
2024-07-08 00:24:07 +02:00

61 lines
1.5 KiB
Python

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Tuple
import jax
from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_cosmo import Cosmology
from jaxtyping import Array, Bool, PyTree, Real, Shaped
class PMSolverStatus(Enum):
INIT = 0
LPT = 1
LPT2 = 2
ODE = 3
@register_pytree_node_class
@dataclass
class State(object):
displacements: Tuple[int, int, int, 3]
velocities: Tuple[int, int, int, 3]
cosmo: Cosmology
kvec: list[Array, Array, Array]
solver_stats: dict[str, Any]
status: PMSolverStatus
def tree_flatten(self):
children = (self.displacements, self.velocities, self.cosmo)
aux_data = (self.kvec, self.solver_stats, self.status)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(*children)
class FastPM(object):
initial_conditions: Array
def init_state(self, cosmo, particules, kvec, initial_conditions):
self.initial_conditions = initial_conditions
# Check sharding on this
zeros = jnp.zeros_like(particules)
state = State(displacements=zeros,
velocities=zeros,
cosmo=cosmo,
kvec=kvec)
return state
def lpt(state, a=0.1):
pass
def lpt2(state, a=0.1):
pass
def nbody(state, solver, stepsize_controller, t0=0.1, t1=1, dt0=0.01):
pass