mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
format
This commit is contained in:
parent
82be56836a
commit
4f508b7cb6
3 changed files with 54 additions and 44 deletions
|
@ -133,7 +133,6 @@ def get_local_shape(mesh_shape):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def normal_field(mesh_shape, seed=None):
|
def normal_field(mesh_shape, seed=None):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
|
|
|
@ -38,6 +38,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo, initial_conditions, a, halo_size=0):
|
def lpt(cosmo, initial_conditions, a, halo_size=0):
|
||||||
"""
|
"""
|
||||||
Computes first order LPT displacement
|
Computes first order LPT displacement
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
jax.distributed.initialize()
|
jax.distributed.initialize()
|
||||||
|
|
||||||
rank = jax.process_index()
|
rank = jax.process_index()
|
||||||
|
@ -8,18 +10,15 @@ size = jax.process_count()
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
|
||||||
from jax.experimental.ode import odeint
|
|
||||||
|
|
||||||
from jaxpm.painting import cic_paint, cic_read , cic_paint_dx , cic_read_dx
|
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
|
||||||
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController, SaveAt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
from jax.sharding import Mesh, PartitionSpec as P , NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jaxpm.distributed import normal_field
|
from jax.sharding import PartitionSpec as P
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
|
||||||
|
|
||||||
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
|
from jaxpm.painting import cic_paint_dx
|
||||||
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
|
|
||||||
size = 256
|
size = 256
|
||||||
mesh_shape = [size] * 3
|
mesh_shape = [size] * 3
|
||||||
|
@ -38,15 +37,19 @@ if jax.device_count() > 1 :
|
||||||
def run_simulation(omega_c, sigma8):
|
def run_simulation(omega_c, sigma8):
|
||||||
# Create a small function to generate the matter power spectrum
|
# Create a small function to generate the matter power spectrum
|
||||||
k = jnp.logspace(-4, 1, 128)
|
k = jnp.logspace(-4, 1, 128)
|
||||||
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
pk = jc.power.linear_matter_power(
|
||||||
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
|
||||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
|
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
|
||||||
# Create initial conditions
|
# Create initial conditions
|
||||||
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))
|
initial_conditions = linear_field(mesh_shape,
|
||||||
|
box_size,
|
||||||
|
pk_fn,
|
||||||
|
seed=jax.random.PRNGKey(0))
|
||||||
|
|
||||||
# Create particles
|
# Create particles
|
||||||
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])
|
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
|
||||||
|
axis=-1).reshape([-1, 3])
|
||||||
|
|
||||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
|
|
||||||
|
@ -55,11 +58,17 @@ def run_simulation(omega_c, sigma8):
|
||||||
|
|
||||||
# Evolve the simulation forward
|
# Evolve the simulation forward
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||||
term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
term = ODETerm(
|
||||||
|
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
|
|
||||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||||
res = diffeqsolve(term, solver, t0=0.1, t1=1., dt0=0.01, y0=jnp.stack([dx, p],axis=0),
|
res = diffeqsolve(term,
|
||||||
|
solver,
|
||||||
|
t0=0.1,
|
||||||
|
t1=1.,
|
||||||
|
dt0=0.01,
|
||||||
|
y0=jnp.stack([dx, p], axis=0),
|
||||||
args=cosmo,
|
args=cosmo,
|
||||||
saveat=SaveAt(ts=snapshots),
|
saveat=SaveAt(ts=snapshots),
|
||||||
stepsize_controller=stepsize_controller)
|
stepsize_controller=stepsize_controller)
|
||||||
|
@ -67,7 +76,9 @@ def run_simulation(omega_c, sigma8):
|
||||||
# Return the simulation volume at requested
|
# Return the simulation volume at requested
|
||||||
states = res.ys
|
states = res.ys
|
||||||
field = cic_paint_dx(dx, halo_size=halo_size)
|
field = cic_paint_dx(dx, halo_size=halo_size)
|
||||||
final_fields = [cic_paint_dx(state[0] , halo_size = halo_size) for state in states]
|
final_fields = [
|
||||||
|
cic_paint_dx(state[0], halo_size=halo_size) for state in states
|
||||||
|
]
|
||||||
|
|
||||||
return initial_conditions, field, final_fields, res.stats
|
return initial_conditions, field, final_fields, res.stats
|
||||||
|
|
||||||
|
@ -89,7 +100,6 @@ np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||||
|
|
||||||
if final_fields is not None:
|
if final_fields is not None:
|
||||||
for i, final_field in enumerate(final_fields):
|
for i, final_field in enumerate(final_fields):
|
||||||
np.save(f'final_field_{i}_{rank}.npy',
|
np.save(f'final_field_{i}_{rank}.npy', final_field.addressable_data(0))
|
||||||
final_field.addressable_data(0))
|
|
||||||
|
|
||||||
print(f"Finished!!")
|
print(f"Finished!!")
|
Loading…
Add table
Reference in a new issue