mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-11 21:50:55 +00:00
format
This commit is contained in:
parent
51ee4dd937
commit
b43cb373a0
1 changed files with 6 additions and 2 deletions
|
@ -18,7 +18,8 @@ import numpy as np
|
||||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
PIDController, SaveAt, diffeqsolve)
|
PIDController, SaveAt, diffeqsolve)
|
||||||
from jax.experimental.multihost_utils import process_allgather
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
from jax.sharding import PartitionSpec as P, NamedSharding
|
from jax.sharding import NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
from jaxpm.painting import cic_paint_dx
|
from jaxpm.painting import cic_paint_dx
|
||||||
|
@ -104,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
ode_fn = ODETerm(
|
ode_fn = ODETerm(
|
||||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size))
|
make_diffrax_ode(mesh_shape,
|
||||||
|
paint_absolute_pos=False,
|
||||||
|
sharding=sharding,
|
||||||
|
halo_size=halo_size))
|
||||||
|
|
||||||
# Choose solver
|
# Choose solver
|
||||||
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||||
|
|
Loading…
Add table
Reference in a new issue