update demo

This commit is contained in:
Wassim KABALAN 2024-10-22 12:16:28 -04:00
parent 0433c615f3
commit 2f509932f5

View file

@ -14,18 +14,19 @@ import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
SaveAt, diffeqsolve)
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
size = 256
size = 64
mesh_shape = [size] * 3
box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 32
halo_size = 4
pdims = (1, 1)
mesh = None
sharding = None
@ -59,9 +60,9 @@ def run_simulation(omega_c, sigma8):
0.1,
halo_size=halo_size,
sharding=sharding)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size,
sharding=sharding), None, None
# return initial_conditions, cic_paint_dx(dx,
# halo_size=halo_size,
# sharding=sharding), None, None
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm(
@ -91,6 +92,8 @@ def run_simulation(omega_c, sigma8):
# Run the simulation
distributed_str = "distributed" if mesh is not None else "single device"
print(f"running {distributed_str} simulation")
init, field, final_fields, stats = run_simulation(0.32, 0.8)
# # Print the statistics