mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
update demo
This commit is contained in:
parent
0433c615f3
commit
2f509932f5
1 changed files with 9 additions and 6 deletions
|
@ -14,18 +14,19 @@ import numpy as np
|
||||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
SaveAt, diffeqsolve)
|
SaveAt, diffeqsolve)
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jax.experimental.multihost_utils import process_allgather
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
from jaxpm.painting import cic_paint_dx
|
from jaxpm.painting import cic_paint_dx
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
|
|
||||||
size = 256
|
size = 64
|
||||||
mesh_shape = [size] * 3
|
mesh_shape = [size] * 3
|
||||||
box_size = [float(size)] * 3
|
box_size = [float(size)] * 3
|
||||||
snapshots = jnp.linspace(0.1, 1., 4)
|
snapshots = jnp.linspace(0.1, 1., 4)
|
||||||
halo_size = 32
|
halo_size = 4
|
||||||
pdims = (1, 1)
|
pdims = (1, 1)
|
||||||
mesh = None
|
mesh = None
|
||||||
sharding = None
|
sharding = None
|
||||||
|
@ -59,9 +60,9 @@ def run_simulation(omega_c, sigma8):
|
||||||
0.1,
|
0.1,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
return initial_conditions, cic_paint_dx(dx,
|
# return initial_conditions, cic_paint_dx(dx,
|
||||||
halo_size=halo_size,
|
# halo_size=halo_size,
|
||||||
sharding=sharding), None, None
|
# sharding=sharding), None, None
|
||||||
# Evolve the simulation forward
|
# Evolve the simulation forward
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
||||||
term = ODETerm(
|
term = ODETerm(
|
||||||
|
@ -91,6 +92,8 @@ def run_simulation(omega_c, sigma8):
|
||||||
|
|
||||||
|
|
||||||
# Run the simulation
|
# Run the simulation
|
||||||
|
distributed_str = "distributed" if mesh is not None else "single device"
|
||||||
|
print(f"running {distributed_str} simulation")
|
||||||
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
||||||
|
|
||||||
# # Print the statistics
|
# # Print the statistics
|
||||||
|
|
Loading…
Add table
Reference in a new issue