mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-12 14:10:55 +00:00
update to leapfrog
This commit is contained in:
parent
5a587fd402
commit
a160a3faa9
1 changed files with 4 additions and 6 deletions
|
@ -11,7 +11,7 @@ size = jax.device_count()
|
|||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
@ -27,7 +27,7 @@ snapshots = jnp.linspace(0.1, 1., 4)
|
|||
halo_size = 32
|
||||
pdims = (1, 1)
|
||||
if jax.device_count() > 1:
|
||||
pdims = (4, 2)
|
||||
pdims = (8, 1)
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
|
@ -51,15 +51,13 @@ def run_simulation(omega_c, sigma8):
|
|||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
||||
return initial_conditions, cic_paint_dx(dx,
|
||||
halo_size=halo_size), None, None
|
||||
# Evolve the simulation forward
|
||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||
term = ODETerm(
|
||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||
solver = Dopri5()
|
||||
solver = LeapfrogMidpoint()
|
||||
|
||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||
stepsize_controller = ConstantStepSize()
|
||||
res = diffeqsolve(term,
|
||||
solver,
|
||||
t0=0.1,
|
||||
|
|
Loading…
Add table
Reference in a new issue