mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
adding example of distributed solution
This commit is contained in:
parent
a2811c0606
commit
a742065ffd
5 changed files with 192 additions and 62 deletions
62
dev/jaxdecomp.py
Normal file
62
dev/jaxdecomp.py
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
import argparse
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Setting up distributed jax
|
||||||
|
jax.distributed.initialize()
|
||||||
|
rank = jax.process_index()
|
||||||
|
size = jax.process_count()
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo as jc
|
||||||
|
from jaxpm.pm import linear_field, lpt
|
||||||
|
from jaxpm.painting import cic_paint
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import Mesh
|
||||||
|
|
||||||
|
mesh_shape= [256, 256, 256]
|
||||||
|
box_size = [256.,256.,256.]
|
||||||
|
snapshots = jnp.linspace(0.1, 1., 2)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def run_simulation(omega_c, sigma8, seed):
|
||||||
|
# Create a cosmology
|
||||||
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
|
|
||||||
|
# Create a small function to generate the matter power spectrum
|
||||||
|
k = jnp.logspace(-4, 1, 128)
|
||||||
|
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)
|
||||||
|
|
||||||
|
# Create initial conditions
|
||||||
|
initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)
|
||||||
|
|
||||||
|
# Initialize particle displacements
|
||||||
|
dx, p, f = lpt(cosmo, initial_conditions, 1.0)
|
||||||
|
|
||||||
|
field = cic_paint(jnp.zeros_like(initial_conditions), dx)
|
||||||
|
return field
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Setting up distributed random numbers
|
||||||
|
master_key = jax.random.PRNGKey(42)
|
||||||
|
key = jax.random.split(master_key, size)[rank]
|
||||||
|
|
||||||
|
# Create computing mesh and sharding information
|
||||||
|
devices = mesh_utils.create_device_mesh((2,2))
|
||||||
|
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||||
|
|
||||||
|
# Run the simulation on the compute mesh
|
||||||
|
with mesh:
|
||||||
|
field = run_simulation(0.32, 0.8, key)
|
||||||
|
|
||||||
|
print('done')
|
||||||
|
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||||
|
|
||||||
|
# Closing distributed jax
|
||||||
|
jax.distributed.shutdown()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
50
jaxpm/distributed.py
Normal file
50
jaxpm/distributed.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
from typing import Any, Callable, Hashable
|
||||||
|
|
||||||
|
Specs = Any
|
||||||
|
AxisName = Hashable
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jaxdecomp
|
||||||
|
distributed = True
|
||||||
|
except ImportError:
|
||||||
|
print("jaxdecomp not installed. Distributed functions will not work.")
|
||||||
|
distributed = False
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax._src import mesh as mesh_lib
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
|
||||||
|
|
||||||
|
def autoshmap(f: Callable,
|
||||||
|
in_specs: Specs,
|
||||||
|
out_specs: Specs,
|
||||||
|
check_rep: bool = True,
|
||||||
|
auto: frozenset[AxisName] = frozenset()):
|
||||||
|
"""Helper function to wrap the provided function in a shard map if
|
||||||
|
the code is being executed in a mesh context."""
|
||||||
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
|
if mesh.empty:
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||||
|
|
||||||
|
def fft3d(x):
|
||||||
|
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
|
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||||
|
else:
|
||||||
|
return jnp.fft.rfftn(x)
|
||||||
|
|
||||||
|
def ifft3d(x):
|
||||||
|
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
|
return jaxdecomp.pifft3d(x).real
|
||||||
|
else:
|
||||||
|
return jnp.fft.irfftn(x)
|
||||||
|
|
||||||
|
def get_local_shape(mesh_shape):
|
||||||
|
""" Helper function to get the local size of a mesh given the global size.
|
||||||
|
"""
|
||||||
|
if mesh_lib.thread_resources.env.physical_mesh.empty:
|
||||||
|
return mesh_shape
|
||||||
|
else:
|
||||||
|
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
|
||||||
|
return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]
|
|
@ -1,24 +1,33 @@
|
||||||
|
from jaxpm.distributed import autoshmap
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from functools import partial
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
|
def fftk(shape, dtype=np.float32):
|
||||||
""" Return k_vector given a shape (nc, nc, nc) and box_size
|
|
||||||
"""
|
"""
|
||||||
k = []
|
Generate Fourier transform wave numbers for a given mesh.
|
||||||
for d in range(len(shape)):
|
|
||||||
kd = np.fft.fftfreq(shape[d])
|
|
||||||
kd *= 2 * np.pi
|
|
||||||
kdshape = np.ones(len(shape), dtype='int')
|
|
||||||
if symmetric and d == len(shape) - 1:
|
|
||||||
kd = kd[:shape[d] // 2 + 1]
|
|
||||||
kdshape[d] = len(kd)
|
|
||||||
kd = kd.reshape(kdshape)
|
|
||||||
|
|
||||||
k.append(kd.astype(dtype))
|
Args:
|
||||||
del kd, kdshape
|
nc (int): Shape of the mesh grid.
|
||||||
return k
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of wave number arrays for each dimension in
|
||||||
|
the order [kx, ky, kz].
|
||||||
|
"""
|
||||||
|
kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]
|
||||||
|
@partial(
|
||||||
|
autoshmap,
|
||||||
|
in_specs=(P('x'), P('y'), P(None)),
|
||||||
|
out_specs=(P('x'), P(None, 'y'), P(None)))
|
||||||
|
def get_kvec(ky, kz, kx):
|
||||||
|
return (ky.reshape([-1, 1, 1]),
|
||||||
|
kz.reshape([1, -1, 1]),
|
||||||
|
kx.reshape([1, 1, -1])) # yapf: disable
|
||||||
|
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
||||||
|
# to the order of dimensions in the transposed FFT
|
||||||
|
return kx, ky, kz
|
||||||
|
|
||||||
def gradient_kernel(kvec, direction, order=1):
|
def gradient_kernel(kvec, direction, order=1):
|
||||||
"""
|
"""
|
||||||
|
@ -60,11 +69,7 @@ def laplace_kernel(kvec):
|
||||||
Complex kernel
|
Complex kernel
|
||||||
"""
|
"""
|
||||||
kk = sum(ki**2 for ki in kvec)
|
kk = sum(ki**2 for ki in kvec)
|
||||||
mask = (kk == 0).nonzero()
|
wts = jnp.where(kk == 0, 1., 1. / kk)
|
||||||
kk[mask] = 1
|
|
||||||
wts = 1. / kk
|
|
||||||
imask = (~(kk == 0)).astype(int)
|
|
||||||
wts *= imask
|
|
||||||
return wts
|
return wts
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,25 @@ import jax.lax as lax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from jaxpm.kernels import cic_compensation, fftk
|
from jaxpm.kernels import cic_compensation, fftk
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from functools import partial
|
||||||
|
from jaxpm.distributed import autoshmap
|
||||||
|
|
||||||
|
@partial(autoshmap,
|
||||||
def cic_paint(mesh, positions, weight=None):
|
in_specs=(P('x', 'y'), P('x','y'), P('x','y')),
|
||||||
|
out_specs=P('x', 'y'))
|
||||||
|
def cic_paint(mesh, displacement, weight=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
displacement field: [nx, ny, nz, 3]
|
||||||
"""
|
"""
|
||||||
|
part_shape = displacement.shape
|
||||||
|
positions = jnp.stack(jnp.meshgrid(
|
||||||
|
jnp.arange(part_shape[0]),
|
||||||
|
jnp.arange(part_shape[1]),
|
||||||
|
jnp.arange(part_shape[2]),
|
||||||
|
indexing='ij'), axis=-1) + displacement
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
@ -34,11 +46,22 @@ def cic_paint(mesh, positions, weight=None):
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def cic_read(mesh, positions):
|
@partial(autoshmap,
|
||||||
|
in_specs=(P('x', 'y'), P('x','y')),
|
||||||
|
out_specs=P('x', 'y'))
|
||||||
|
def cic_read(mesh, displacement):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
displacement: [nx,ny,nz, 3]
|
||||||
"""
|
"""
|
||||||
|
# Compute the position of the particles on a regular grid
|
||||||
|
part_shape = displacement.shape
|
||||||
|
positions = jnp.stack(jnp.meshgrid(
|
||||||
|
jnp.arange(part_shape[0]),
|
||||||
|
jnp.arange(part_shape[1]),
|
||||||
|
jnp.arange(part_shape[2]),
|
||||||
|
indexing='ij'), axis=-1) + displacement
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
@ -52,7 +75,7 @@ def cic_read(mesh, positions):
|
||||||
jnp.array(mesh.shape))
|
jnp.array(mesh.shape))
|
||||||
|
|
||||||
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1])
|
||||||
|
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
|
|
60
jaxpm/pm.py
60
jaxpm/pm.py
|
@ -1,12 +1,15 @@
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
||||||
longrange_kernel)
|
longrange_kernel)
|
||||||
from jaxpm.painting import cic_paint, cic_read
|
from jaxpm.painting import cic_paint, cic_read
|
||||||
|
from jaxpm.distributed import fft3d, ifft3d, autoshmap, get_local_shape
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||||
"""
|
"""
|
||||||
|
@ -17,26 +20,34 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||||
kvec = fftk(mesh_shape)
|
kvec = fftk(mesh_shape)
|
||||||
|
|
||||||
if delta is None:
|
if delta is None:
|
||||||
delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions))
|
delta_k = fft3d(cic_paint(jnp.zeros(mesh_shape), positions))
|
||||||
else:
|
else:
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
delta_k = fft3d(delta)
|
||||||
|
|
||||||
# Computes gravitational potential
|
# Computes gravitational potential
|
||||||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||||
r_split=r_split)
|
r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
return jnp.stack([
|
return jnp.stack([
|
||||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
|
cic_read(ifft3d(gradient_kernel(kvec, i) * pot_k), positions)
|
||||||
for i in range(3)
|
for i in range(3)
|
||||||
],
|
],
|
||||||
axis=-1)
|
axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo, initial_conditions, positions, a):
|
def lpt(cosmo, initial_conditions, a, particles_shape=None):
|
||||||
"""
|
"""
|
||||||
Computes first order LPT displacement
|
Computes first order LPT displacement
|
||||||
"""
|
"""
|
||||||
initial_force = pm_forces(positions, delta=initial_conditions)
|
if particles_shape is None:
|
||||||
|
particles_shape = initial_conditions.shape
|
||||||
|
local_mesh_shape = get_local_shape(particles_shape)
|
||||||
|
displacement = autoshmap(
|
||||||
|
partial(jnp.zeros, shape=local_mesh_shape+[3], dtype='float32'),
|
||||||
|
in_specs=(),
|
||||||
|
out_specs=P('x', 'y'))() # yapf: disable
|
||||||
|
|
||||||
|
initial_force = pm_forces(displacement, delta=initial_conditions)
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
|
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
|
||||||
|
@ -56,9 +67,15 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
||||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
|
||||||
box_size[0] * box_size[1] * box_size[2])
|
box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
field = jax.random.normal(seed, mesh_shape)
|
# Initialize a random field with one slice on each gpu
|
||||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
local_mesh_shape = get_local_shape(mesh_shape)
|
||||||
field = jnp.fft.irfftn(field)
|
field = autoshmap(
|
||||||
|
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
|
||||||
|
in_specs=P(None),
|
||||||
|
out_specs=P('x', 'y'))(seed) # yapf: disable
|
||||||
|
|
||||||
|
field = fft3d(field) * pkmesh**0.5
|
||||||
|
field = ifft3d(field)
|
||||||
return field
|
return field
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,30 +98,3 @@ def make_ode_fn(mesh_shape):
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def pgd_correction(pos, params):
|
|
||||||
"""
|
|
||||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
|
||||||
args:
|
|
||||||
pos: particle positions [npart, 3]
|
|
||||||
params: [alpha, kl, ks] pgd parameters
|
|
||||||
"""
|
|
||||||
kvec = fftk(mesh_shape)
|
|
||||||
|
|
||||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
|
||||||
alpha, kl, ks = params
|
|
||||||
delta_k = jnp.fft.rfftn(delta)
|
|
||||||
PGD_range = PGD_kernel(kvec, kl, ks)
|
|
||||||
|
|
||||||
pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
|
|
||||||
|
|
||||||
forces_pgd = jnp.stack([
|
|
||||||
cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
|
|
||||||
for i in range(3)
|
|
||||||
],
|
|
||||||
axis=-1)
|
|
||||||
|
|
||||||
dpos_pgd = forces_pgd * alpha
|
|
||||||
|
|
||||||
return dpos_pgd
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue