mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-12 22:20:55 +00:00
update to leapfrog
This commit is contained in:
parent
5a587fd402
commit
a160a3faa9
1 changed files with 4 additions and 6 deletions
|
@ -11,7 +11,7 @@ size = jax.device_count()
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
@ -27,7 +27,7 @@ snapshots = jnp.linspace(0.1, 1., 4)
|
||||||
halo_size = 32
|
halo_size = 32
|
||||||
pdims = (1, 1)
|
pdims = (1, 1)
|
||||||
if jax.device_count() > 1:
|
if jax.device_count() > 1:
|
||||||
pdims = (4, 2)
|
pdims = (8, 1)
|
||||||
devices = mesh_utils.create_device_mesh(pdims)
|
devices = mesh_utils.create_device_mesh(pdims)
|
||||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
|
@ -51,15 +51,13 @@ def run_simulation(omega_c, sigma8):
|
||||||
|
|
||||||
# Initial displacement
|
# Initial displacement
|
||||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
||||||
return initial_conditions, cic_paint_dx(dx,
|
|
||||||
halo_size=halo_size), None, None
|
|
||||||
# Evolve the simulation forward
|
# Evolve the simulation forward
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||||
term = ODETerm(
|
term = ODETerm(
|
||||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||||
solver = Dopri5()
|
solver = LeapfrogMidpoint()
|
||||||
|
|
||||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
stepsize_controller = ConstantStepSize()
|
||||||
res = diffeqsolve(term,
|
res = diffeqsolve(term,
|
||||||
solver,
|
solver,
|
||||||
t0=0.1,
|
t0=0.1,
|
||||||
|
|
Loading…
Add table
Reference in a new issue