mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
update formatting
This commit is contained in:
parent
6408aff1de
commit
319942a6bc
5 changed files with 113 additions and 96 deletions
|
@ -1,4 +1,5 @@
|
|||
import argparse
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
|
@ -9,15 +10,17 @@ size = jax.process_count()
|
|||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from jaxpm.pm import linear_field, lpt
|
||||
from jaxpm.painting import cic_paint
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh
|
||||
|
||||
mesh_shape= [256, 256, 256]
|
||||
box_size = [256.,256.,256.]
|
||||
from jaxpm.painting import cic_paint
|
||||
from jaxpm.pm import linear_field, lpt
|
||||
|
||||
mesh_shape = [256, 256, 256]
|
||||
box_size = [256., 256., 256.]
|
||||
snapshots = jnp.linspace(0.1, 1., 2)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_simulation(omega_c, sigma8, seed):
|
||||
# Create a cosmology
|
||||
|
@ -25,38 +28,42 @@ def run_simulation(omega_c, sigma8, seed):
|
|||
|
||||
# Create a small function to generate the matter power spectrum
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk
|
||||
).reshape(x.shape)
|
||||
|
||||
# Create initial conditions
|
||||
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)
|
||||
|
||||
# Initialize particle displacements
|
||||
|
||||
# Initialize particle displacements
|
||||
dx, p, f = lpt(cosmo, initial_conditions, 1.0)
|
||||
|
||||
field = cic_paint(jnp.zeros_like(initial_conditions), dx)
|
||||
return field
|
||||
|
||||
|
||||
def main(args):
|
||||
# Setting up distributed random numbers
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
key = jax.random.split(master_key, size)[rank]
|
||||
# Setting up distributed random numbers
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
key = jax.random.split(master_key, size)[rank]
|
||||
|
||||
# Create computing mesh and sharding information
|
||||
devices = mesh_utils.create_device_mesh((2,2))
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
# Create computing mesh and sharding information
|
||||
devices = mesh_utils.create_device_mesh((2, 2))
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
|
||||
# Run the simulation on the compute mesh
|
||||
with mesh:
|
||||
field = run_simulation(0.32, 0.8, key)
|
||||
# Run the simulation on the compute mesh
|
||||
with mesh:
|
||||
field = run_simulation(0.32, 0.8, key)
|
||||
|
||||
print('done')
|
||||
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||
|
||||
# Closing distributed jax
|
||||
jax.distributed.shutdown()
|
||||
|
||||
print('done')
|
||||
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||
|
||||
# Closing distributed jax
|
||||
jax.distributed.shutdown()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue