mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Add solvers
This commit is contained in:
parent
8b9287184a
commit
da06f4dba8
1 changed files with 61 additions and 0 deletions
61
jaxpm/solvers.py
Normal file
61
jaxpm/solvers.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
from jax_cosmo import Cosmology
|
||||
from jaxtyping import Array, Bool, PyTree, Real, Shaped
|
||||
|
||||
|
||||
class PMSolverStatus(Enum):
|
||||
INIT = 0
|
||||
LPT = 1
|
||||
LPT2 = 2
|
||||
ODE = 3
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
@dataclass
|
||||
class State(object):
|
||||
displacements: Tuple[int, int, int, 3]
|
||||
velocities: Tuple[int, int, int, 3]
|
||||
cosmo: Cosmology
|
||||
kvec: list[Array, Array, Array]
|
||||
solver_stats: dict[str, Any]
|
||||
status: PMSolverStatus
|
||||
|
||||
def tree_flatten(self):
|
||||
children = (self.displacements, self.velocities, self.cosmo)
|
||||
aux_data = (self.kvec, self.solver_stats, self.status)
|
||||
return (children, aux_data)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
del aux_data
|
||||
return cls(*children)
|
||||
|
||||
|
||||
class FastPM(object):
|
||||
initial_conditions: Array
|
||||
|
||||
def init_state(self, cosmo, particules, kvec, initial_conditions):
|
||||
self.initial_conditions = initial_conditions
|
||||
# Check sharding on this
|
||||
zeros = jnp.zeros_like(particules)
|
||||
state = State(displacements=zeros,
|
||||
velocities=zeros,
|
||||
cosmo=cosmo,
|
||||
kvec=kvec)
|
||||
|
||||
return state
|
||||
|
||||
def lpt(state, a=0.1):
|
||||
pass
|
||||
|
||||
def lpt2(state, a=0.1):
|
||||
pass
|
||||
|
||||
def nbody(state, solver, stepsize_controller, t0=0.1, t1=1, dt0=0.01):
|
||||
pass
|
Loading…
Add table
Reference in a new issue