adding example of distributed solution

This commit is contained in:
EiffL 2024-07-09 17:45:28 -04:00
parent a2811c0606
commit a742065ffd
5 changed files with 192 additions and 62 deletions

62
dev/jaxdecomp.py Normal file
View file

@ -0,0 +1,62 @@
import argparse
import jax
import numpy as np
# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.pm import linear_field, lpt
from jaxpm.painting import cic_paint
from jax.experimental import mesh_utils
from jax.sharding import Mesh
mesh_shape= [256, 256, 256]
box_size = [256.,256.,256.]
snapshots = jnp.linspace(0.1, 1., 2)
@jax.jit
def run_simulation(omega_c, sigma8, seed):
# Create a cosmology
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)
# Create initial conditions
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)
# Initialize particle displacements
dx, p, f = lpt(cosmo, initial_conditions, 1.0)
field = cic_paint(jnp.zeros_like(initial_conditions), dx)
return field
def main(args):
# Setting up distributed random numbers
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
# Create computing mesh and sharding information
devices = mesh_utils.create_device_mesh((2,2))
mesh = Mesh(devices.T, axis_names=('x', 'y'))
# Run the simulation on the compute mesh
with mesh:
field = run_simulation(0.32, 0.8, key)
print('done')
np.save(f'field_{rank}.npy', field.addressable_data(0))
# Closing distributed jax
jax.distributed.shutdown()
if __name__ == '__main__':
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
args = parser.parse_args()
main(args)

50
jaxpm/distributed.py Normal file
View file

@ -0,0 +1,50 @@
from typing import Any, Callable, Hashable
Specs = Any
AxisName = Hashable
try:
import jaxdecomp
distributed = True
except ImportError:
print("jaxdecomp not installed. Distributed functions will not work.")
distributed = False
import jax.numpy as jnp
from jax._src import mesh as mesh_lib
from jax.experimental.shard_map import shard_map
def autoshmap(f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset()):
"""Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context."""
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
return f
else:
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def fft3d(x):
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
else:
return jnp.fft.rfftn(x)
def ifft3d(x):
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pifft3d(x).real
else:
return jnp.fft.irfftn(x)
def get_local_shape(mesh_shape):
""" Helper function to get the local size of a mesh given the global size.
"""
if mesh_lib.thread_resources.env.physical_mesh.empty:
return mesh_shape
else:
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]

View file

@ -1,24 +1,33 @@
from jaxpm.distributed import autoshmap
from jax.sharding import PartitionSpec as P
from functools import partial
import jax.numpy as jnp
import numpy as np
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
""" Return k_vector given a shape (nc, nc, nc) and box_size
def fftk(shape, dtype=np.float32):
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
Generate Fourier transform wave numbers for a given mesh.
k.append(kd.astype(dtype))
del kd, kdshape
return k
Args:
nc (int): Shape of the mesh grid.
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
@partial(
autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
# to the order of dimensions in the transposed FFT
return kx, ky, kz
def gradient_kernel(kvec, direction, order=1):
"""
@ -60,11 +69,7 @@ def laplace_kernel(kvec):
Complex kernel
"""
kk = sum(ki**2 for ki in kvec)
mask = (kk == 0).nonzero()
kk[mask] = 1
wts = 1. / kk
imask = (~(kk == 0)).astype(int)
wts *= imask
wts = jnp.where(kk == 0, 1., 1. / kk)
return wts

View file

@ -3,13 +3,25 @@ import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jax.sharding import PartitionSpec as P
from functools import partial
from jaxpm.distributed import autoshmap
def cic_paint(mesh, positions, weight=None):
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y'), P('x','y')),
out_specs=P('x', 'y'))
def cic_paint(mesh, displacement, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
"""
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -34,11 +46,22 @@ def cic_paint(mesh, positions, weight=None):
return mesh
def cic_read(mesh, positions):
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y')),
out_specs=P('x', 'y'))
def cic_read(mesh, displacement):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
displacement: [nx,ny,nz, 3]
"""
# Compute the position of the particles on a regular grid
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -52,7 +75,7 @@ def cic_read(mesh, positions):
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1])
def cic_paint_2d(mesh, positions, weight):

View file

@ -1,12 +1,15 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.sharding import PartitionSpec as P
from jaxpm.growth import dGfa, growth_factor, growth_rate
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
longrange_kernel)
from jaxpm.painting import cic_paint, cic_read
from jaxpm.distributed import fft3d, ifft3d, autoshmap, get_local_shape
from functools import partial
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
"""
@ -17,26 +20,34 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
kvec = fftk(mesh_shape)
if delta is None:
delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
delta_k = fft3d(cic_paint(jnp.zeros(mesh_shape), positions))
else:
delta_k = jnp.fft.rfftn(delta)
delta_k = fft3d(delta)
# Computes gravitational potential
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
r_split=r_split)
# Computes gravitational forces
return jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
cic_read(ifft3d(gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
],
axis=-1)
def lpt(cosmo, initial_conditions, positions, a):
def lpt(cosmo, initial_conditions, a, particles_shape=None):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta=initial_conditions)
if particles_shape is None:
particles_shape = initial_conditions.shape
local_mesh_shape = get_local_shape(particles_shape)
displacement = autoshmap(
partial(jnp.zeros, shape=local_mesh_shape+[3], dtype='float32'),
in_specs=(),
out_specs=P('x', 'y'))() # yapf: disable
initial_force = pm_forces(displacement, delta=initial_conditions)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
@ -56,9 +67,15 @@ def linear_field(mesh_shape, box_size, pk, seed):
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
# Initialize a random field with one slice on each gpu
local_mesh_shape = get_local_shape(mesh_shape)
field = autoshmap(
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
in_specs=P(None),
out_specs=P('x', 'y'))(seed) # yapf: disable
field = fft3d(field) * pkmesh**0.5
field = ifft3d(field)
return field
@ -81,30 +98,3 @@ def make_ode_fn(mesh_shape):
return dpos, dvel
return nbody_ode
def pgd_correction(pos, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
args:
pos: particle positions [npart, 3]
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range = PGD_kernel(kvec, kl, ks)
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
forces_pgd = jnp.stack([
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
for i in range(3)
],
axis=-1)
dpos_pgd = forces_pgd * alpha
return dpos_pgd