mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
update demo
This commit is contained in:
parent
0433c615f3
commit
2f509932f5
1 changed files with 9 additions and 6 deletions
|
@ -14,18 +14,19 @@ import numpy as np
|
|||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||
SaveAt, diffeqsolve)
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
size = 256
|
||||
size = 64
|
||||
mesh_shape = [size] * 3
|
||||
box_size = [float(size)] * 3
|
||||
snapshots = jnp.linspace(0.1, 1., 4)
|
||||
halo_size = 32
|
||||
halo_size = 4
|
||||
pdims = (1, 1)
|
||||
mesh = None
|
||||
sharding = None
|
||||
|
@ -59,9 +60,9 @@ def run_simulation(omega_c, sigma8):
|
|||
0.1,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
return initial_conditions, cic_paint_dx(dx,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding), None, None
|
||||
# return initial_conditions, cic_paint_dx(dx,
|
||||
# halo_size=halo_size,
|
||||
# sharding=sharding), None, None
|
||||
# Evolve the simulation forward
|
||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
||||
term = ODETerm(
|
||||
|
@ -91,6 +92,8 @@ def run_simulation(omega_c, sigma8):
|
|||
|
||||
|
||||
# Run the simulation
|
||||
distributed_str = "distributed" if mesh is not None else "single device"
|
||||
print(f"running {distributed_str} simulation")
|
||||
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
||||
|
||||
# # Print the statistics
|
||||
|
|
Loading…
Add table
Reference in a new issue