mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
update formatting
This commit is contained in:
parent
6408aff1de
commit
319942a6bc
5 changed files with 113 additions and 96 deletions
|
@ -1,4 +1,5 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -9,15 +10,17 @@ size = jax.process_count()
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
from jaxpm.pm import linear_field, lpt
|
|
||||||
from jaxpm.painting import cic_paint
|
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
from jax.sharding import Mesh
|
from jax.sharding import Mesh
|
||||||
|
|
||||||
mesh_shape= [256, 256, 256]
|
from jaxpm.painting import cic_paint
|
||||||
box_size = [256.,256.,256.]
|
from jaxpm.pm import linear_field, lpt
|
||||||
|
|
||||||
|
mesh_shape = [256, 256, 256]
|
||||||
|
box_size = [256., 256., 256.]
|
||||||
snapshots = jnp.linspace(0.1, 1., 2)
|
snapshots = jnp.linspace(0.1, 1., 2)
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def run_simulation(omega_c, sigma8, seed):
|
def run_simulation(omega_c, sigma8, seed):
|
||||||
# Create a cosmology
|
# Create a cosmology
|
||||||
|
@ -25,8 +28,10 @@ def run_simulation(omega_c, sigma8, seed):
|
||||||
|
|
||||||
# Create a small function to generate the matter power spectrum
|
# Create a small function to generate the matter power spectrum
|
||||||
k = jnp.logspace(-4, 1, 128)
|
k = jnp.logspace(-4, 1, 128)
|
||||||
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
pk = jc.power.linear_matter_power(
|
||||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk
|
||||||
|
).reshape(x.shape)
|
||||||
|
|
||||||
# Create initial conditions
|
# Create initial conditions
|
||||||
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)
|
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)
|
||||||
|
@ -37,26 +42,28 @@ def run_simulation(omega_c, sigma8, seed):
|
||||||
field = cic_paint(jnp.zeros_like(initial_conditions), dx)
|
field = cic_paint(jnp.zeros_like(initial_conditions), dx)
|
||||||
return field
|
return field
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# Setting up distributed random numbers
|
# Setting up distributed random numbers
|
||||||
master_key = jax.random.PRNGKey(42)
|
master_key = jax.random.PRNGKey(42)
|
||||||
key = jax.random.split(master_key, size)[rank]
|
key = jax.random.split(master_key, size)[rank]
|
||||||
|
|
||||||
# Create computing mesh and sharding information
|
# Create computing mesh and sharding information
|
||||||
devices = mesh_utils.create_device_mesh((2,2))
|
devices = mesh_utils.create_device_mesh((2, 2))
|
||||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||||
|
|
||||||
# Run the simulation on the compute mesh
|
# Run the simulation on the compute mesh
|
||||||
with mesh:
|
with mesh:
|
||||||
field = run_simulation(0.32, 0.8, key)
|
field = run_simulation(0.32, 0.8, key)
|
||||||
|
|
||||||
print('done')
|
print('done')
|
||||||
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||||
|
|
||||||
|
# Closing distributed jax
|
||||||
|
jax.distributed.shutdown()
|
||||||
|
|
||||||
# Closing distributed jax
|
|
||||||
jax.distributed.shutdown()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
|
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|
|
@ -16,10 +16,10 @@ from jax.experimental.shard_map import shard_map
|
||||||
|
|
||||||
|
|
||||||
def autoshmap(f: Callable,
|
def autoshmap(f: Callable,
|
||||||
in_specs: Specs,
|
in_specs: Specs,
|
||||||
out_specs: Specs,
|
out_specs: Specs,
|
||||||
check_rep: bool = True,
|
check_rep: bool = True,
|
||||||
auto: frozenset[AxisName] = frozenset()):
|
auto: frozenset[AxisName] = frozenset()):
|
||||||
"""Helper function to wrap the provided function in a shard map if
|
"""Helper function to wrap the provided function in a shard map if
|
||||||
the code is being executed in a mesh context."""
|
the code is being executed in a mesh context."""
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
|
@ -28,23 +28,28 @@ def autoshmap(f: Callable,
|
||||||
else:
|
else:
|
||||||
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||||
|
|
||||||
|
|
||||||
def fft3d(x):
|
def fft3d(x):
|
||||||
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
|
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||||
else:
|
else:
|
||||||
return jnp.fft.rfftn(x)
|
return jnp.fft.rfftn(x)
|
||||||
|
|
||||||
|
|
||||||
def ifft3d(x):
|
def ifft3d(x):
|
||||||
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
|
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
return jaxdecomp.pifft3d(x).real
|
return jaxdecomp.pifft3d(x).real
|
||||||
else:
|
else:
|
||||||
return jnp.fft.irfftn(x)
|
return jnp.fft.irfftn(x)
|
||||||
|
|
||||||
|
|
||||||
def get_local_shape(mesh_shape):
|
def get_local_shape(mesh_shape):
|
||||||
""" Helper function to get the local size of a mesh given the global size.
|
""" Helper function to get the local size of a mesh given the global size.
|
||||||
"""
|
"""
|
||||||
if mesh_lib.thread_resources.env.physical_mesh.empty:
|
if mesh_lib.thread_resources.env.physical_mesh.empty:
|
||||||
return mesh_shape
|
return mesh_shape
|
||||||
else:
|
else:
|
||||||
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
|
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
|
||||||
return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]
|
return [
|
||||||
|
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
||||||
|
]
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
from jaxpm.distributed import autoshmap
|
|
||||||
from jax.sharding import PartitionSpec as P
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
|
||||||
def fftk(shape, dtype=np.float32):
|
def fftk(shape, dtype=np.float32):
|
||||||
"""
|
"""
|
||||||
Generate Fourier transform wave numbers for a given mesh.
|
Generate Fourier transform wave numbers for a given mesh.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -16,18 +18,19 @@ def fftk(shape, dtype=np.float32):
|
||||||
list: List of wave number arrays for each dimension in
|
list: List of wave number arrays for each dimension in
|
||||||
the order [kx, ky, kz].
|
the order [kx, ky, kz].
|
||||||
"""
|
"""
|
||||||
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
|
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
|
||||||
@partial(
|
|
||||||
autoshmap,
|
@partial(autoshmap,
|
||||||
in_specs=(P('x'), P('y'), P(None)),
|
in_specs=(P('x'), P('y'), P(None)),
|
||||||
out_specs=(P('x'), P(None, 'y'), P(None)))
|
out_specs=(P('x'), P(None, 'y'), P(None)))
|
||||||
def get_kvec(ky, kz, kx):
|
def get_kvec(ky, kz, kx):
|
||||||
return (ky.reshape([-1, 1, 1]),
|
return (ky.reshape([-1, 1, 1]),
|
||||||
kz.reshape([1, -1, 1]),
|
kz.reshape([1, -1, 1]),
|
||||||
kx.reshape([1, 1, -1])) # yapf: disable
|
kx.reshape([1, 1, -1])) # yapf: disable
|
||||||
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
||||||
# to the order of dimensions in the transposed FFT
|
# to the order of dimensions in the transposed FFT
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,26 +1,28 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from jaxpm.kernels import cic_compensation, fftk
|
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from functools import partial
|
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
from jaxpm.kernels import cic_compensation, fftk
|
||||||
|
|
||||||
|
|
||||||
@partial(autoshmap,
|
@partial(autoshmap,
|
||||||
in_specs=(P('x', 'y'), P('x','y'), P('x','y')),
|
in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')),
|
||||||
out_specs=P('x', 'y'))
|
out_specs=P('x', 'y'))
|
||||||
def cic_paint(mesh, displacement, weight=None):
|
def cic_paint(mesh, displacement, weight=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
displacement field: [nx, ny, nz, 3]
|
displacement field: [nx, ny, nz, 3]
|
||||||
"""
|
"""
|
||||||
part_shape = displacement.shape
|
part_shape = displacement.shape
|
||||||
positions = jnp.stack(jnp.meshgrid(
|
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
|
||||||
jnp.arange(part_shape[0]),
|
jnp.arange(part_shape[1]),
|
||||||
jnp.arange(part_shape[1]),
|
jnp.arange(part_shape[2]),
|
||||||
jnp.arange(part_shape[2]),
|
indexing='ij'),
|
||||||
indexing='ij'), axis=-1) + displacement
|
axis=-1) + displacement
|
||||||
positions = positions.reshape([-1, 3])
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
|
@ -46,9 +48,7 @@ def cic_paint(mesh, displacement, weight=None):
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
@partial(autoshmap,
|
@partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
|
||||||
in_specs=(P('x', 'y'), P('x','y')),
|
|
||||||
out_specs=P('x', 'y'))
|
|
||||||
def cic_read(mesh, displacement):
|
def cic_read(mesh, displacement):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
|
@ -56,11 +56,11 @@ def cic_read(mesh, displacement):
|
||||||
"""
|
"""
|
||||||
# Compute the position of the particles on a regular grid
|
# Compute the position of the particles on a regular grid
|
||||||
part_shape = displacement.shape
|
part_shape = displacement.shape
|
||||||
positions = jnp.stack(jnp.meshgrid(
|
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
|
||||||
jnp.arange(part_shape[0]),
|
jnp.arange(part_shape[1]),
|
||||||
jnp.arange(part_shape[1]),
|
jnp.arange(part_shape[2]),
|
||||||
jnp.arange(part_shape[2]),
|
indexing='ij'),
|
||||||
indexing='ij'), axis=-1) + displacement
|
axis=-1) + displacement
|
||||||
positions = positions.reshape([-1, 3])
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
|
@ -75,7 +75,8 @@ def cic_read(mesh, displacement):
|
||||||
jnp.array(mesh.shape))
|
jnp.array(mesh.shape))
|
||||||
|
|
||||||
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1])
|
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(
|
||||||
|
displacement.shape[:-1])
|
||||||
|
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d
|
||||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
||||||
longrange_kernel)
|
longrange_kernel)
|
||||||
from jaxpm.painting import cic_paint, cic_read
|
from jaxpm.painting import cic_paint, cic_read
|
||||||
from jaxpm.distributed import fft3d, ifft3d, autoshmap, get_local_shape
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Reference in a new issue