update formatting

This commit is contained in:
EiffL 2024-07-09 18:02:57 -04:00
parent 6408aff1de
commit 319942a6bc
5 changed files with 113 additions and 96 deletions

View file

@ -1,4 +1,5 @@
import argparse import argparse
import jax import jax
import numpy as np import numpy as np
@ -9,15 +10,17 @@ size = jax.process_count()
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from jaxpm.pm import linear_field, lpt
from jaxpm.painting import cic_paint
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh from jax.sharding import Mesh
mesh_shape= [256, 256, 256] from jaxpm.painting import cic_paint
box_size = [256.,256.,256.] from jaxpm.pm import linear_field, lpt
mesh_shape = [256, 256, 256]
box_size = [256., 256., 256.]
snapshots = jnp.linspace(0.1, 1., 2) snapshots = jnp.linspace(0.1, 1., 2)
@jax.jit @jax.jit
def run_simulation(omega_c, sigma8, seed): def run_simulation(omega_c, sigma8, seed):
# Create a cosmology # Create a cosmology
@ -25,8 +28,10 @@ def run_simulation(omega_c, sigma8, seed):
# Create a small function to generate the matter power spectrum # Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128) k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) pk = jc.power.linear_matter_power(
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape) jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk
).reshape(x.shape)
# Create initial conditions # Create initial conditions
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed) initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)
@ -37,13 +42,14 @@ def run_simulation(omega_c, sigma8, seed):
field = cic_paint(jnp.zeros_like(initial_conditions), dx) field = cic_paint(jnp.zeros_like(initial_conditions), dx)
return field return field
def main(args): def main(args):
# Setting up distributed random numbers # Setting up distributed random numbers
master_key = jax.random.PRNGKey(42) master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank] key = jax.random.split(master_key, size)[rank]
# Create computing mesh and sharding information # Create computing mesh and sharding information
devices = mesh_utils.create_device_mesh((2,2)) devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices.T, axis_names=('x', 'y')) mesh = Mesh(devices.T, axis_names=('x', 'y'))
# Run the simulation on the compute mesh # Run the simulation on the compute mesh
@ -56,6 +62,7 @@ def main(args):
# Closing distributed jax # Closing distributed jax
jax.distributed.shutdown() jax.distributed.shutdown()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
args = parser.parse_args() args = parser.parse_args()

View file

@ -28,18 +28,21 @@ def autoshmap(f: Callable,
else: else:
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def fft3d(x): def fft3d(x):
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pfft3d(x.astype(jnp.complex64)) return jaxdecomp.pfft3d(x.astype(jnp.complex64))
else: else:
return jnp.fft.rfftn(x) return jnp.fft.rfftn(x)
def ifft3d(x): def ifft3d(x):
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pifft3d(x).real return jaxdecomp.pifft3d(x).real
else: else:
return jnp.fft.irfftn(x) return jnp.fft.irfftn(x)
def get_local_shape(mesh_shape): def get_local_shape(mesh_shape):
""" Helper function to get the local size of a mesh given the global size. """ Helper function to get the local size of a mesh given the global size.
""" """
@ -47,4 +50,6 @@ def get_local_shape(mesh_shape):
return mesh_shape return mesh_shape
else: else:
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]] return [
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
]

View file

@ -1,8 +1,10 @@
from jaxpm.distributed import autoshmap
from jax.sharding import PartitionSpec as P
from functools import partial from functools import partial
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
def fftk(shape, dtype=np.float32): def fftk(shape, dtype=np.float32):
@ -17,8 +19,8 @@ def fftk(shape, dtype=np.float32):
the order [kx, ky, kz]. the order [kx, ky, kz].
""" """
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape] kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
@partial(
autoshmap, @partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)), in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None))) out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx): def get_kvec(ky, kz, kx):
@ -29,6 +31,7 @@ def fftk(shape, dtype=np.float32):
# to the order of dimensions in the transposed FFT # to the order of dimensions in the transposed FFT
return kx, ky, kz return kx, ky, kz
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, order=1):
""" """
Computes the gradient kernel in the requested direction Computes the gradient kernel in the requested direction

View file

@ -1,14 +1,16 @@
from functools import partial
import jax import jax
import jax.lax as lax import jax.lax as lax
import jax.numpy as jnp import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from functools import partial
from jaxpm.distributed import autoshmap from jaxpm.distributed import autoshmap
from jaxpm.kernels import cic_compensation, fftk
@partial(autoshmap, @partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y'), P('x','y')), in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y')) out_specs=P('x', 'y'))
def cic_paint(mesh, displacement, weight=None): def cic_paint(mesh, displacement, weight=None):
""" Paints positions onto mesh """ Paints positions onto mesh
@ -16,11 +18,11 @@ def cic_paint(mesh, displacement, weight=None):
displacement field: [nx, ny, nz, 3] displacement field: [nx, ny, nz, 3]
""" """
part_shape = displacement.shape part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid( positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]), jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]), jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement indexing='ij'),
axis=-1) + displacement
positions = positions.reshape([-1, 3]) positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
@ -46,9 +48,7 @@ def cic_paint(mesh, displacement, weight=None):
return mesh return mesh
@partial(autoshmap, @partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
in_specs=(P('x', 'y'), P('x','y')),
out_specs=P('x', 'y'))
def cic_read(mesh, displacement): def cic_read(mesh, displacement):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
@ -56,11 +56,11 @@ def cic_read(mesh, displacement):
""" """
# Compute the position of the particles on a regular grid # Compute the position of the particles on a regular grid
part_shape = displacement.shape part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid( positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]), jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]), jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement indexing='ij'),
axis=-1) + displacement
positions = positions.reshape([-1, 3]) positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
@ -75,7 +75,8 @@ def cic_read(mesh, displacement):
jnp.array(mesh.shape)) jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1]) neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(
displacement.shape[:-1])
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):

View file

@ -1,15 +1,16 @@
from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d
from jaxpm.growth import dGfa, growth_factor, growth_rate from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel) longrange_kernel)
from jaxpm.painting import cic_paint, cic_read from jaxpm.painting import cic_paint, cic_read
from jaxpm.distributed import fft3d, ifft3d, autoshmap, get_local_shape
from functools import partial
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
""" """