diff --git a/dev/jaxdecomp.py b/dev/jaxdecomp.py new file mode 100644 index 0000000..14b249b --- /dev/null +++ b/dev/jaxdecomp.py @@ -0,0 +1,62 @@ +import argparse +import jax +import numpy as np + +# Setting up distributed jax +jax.distributed.initialize() +rank = jax.process_index() +size = jax.process_count() + +import jax.numpy as jnp +import jax_cosmo as jc +from jaxpm.pm import linear_field, lpt +from jaxpm.painting import cic_paint +from jax.experimental import mesh_utils +from jax.sharding import Mesh + +mesh_shape= [256, 256, 256] +box_size = [256.,256.,256.] +snapshots = jnp.linspace(0.1, 1., 2) + +@jax.jit +def run_simulation(omega_c, sigma8, seed): + # Create a cosmology + cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) + + # Create a small function to generate the matter power spectrum + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape) + + # Create initial conditions + initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed) + + # Initialize particle displacements + dx, p, f = lpt(cosmo, initial_conditions, 1.0) + + field = cic_paint(jnp.zeros_like(initial_conditions), dx) + return field + +def main(args): + # Setting up distributed random numbers + master_key = jax.random.PRNGKey(42) + key = jax.random.split(master_key, size)[rank] + + # Create computing mesh and sharding information + devices = mesh_utils.create_device_mesh((2,2)) + mesh = Mesh(devices.T, axis_names=('x', 'y')) + + # Run the simulation on the compute mesh + with mesh: + field = run_simulation(0.32, 0.8, key) + + print('done') + np.save(f'field_{rank}.npy', field.addressable_data(0)) + + # Closing distributed jax + jax.distributed.shutdown() + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py new file mode 100644 index 0000000..9a81440 --- /dev/null +++ b/jaxpm/distributed.py @@ -0,0 +1,50 @@ +from typing import Any, Callable, Hashable + +Specs = Any +AxisName = Hashable + +try: + import jaxdecomp + distributed = True +except ImportError: + print("jaxdecomp not installed. Distributed functions will not work.") + distributed = False + +import jax.numpy as jnp +from jax._src import mesh as mesh_lib +from jax.experimental.shard_map import shard_map + + +def autoshmap(f: Callable, + in_specs: Specs, + out_specs: Specs, + check_rep: bool = True, + auto: frozenset[AxisName] = frozenset()): + """Helper function to wrap the provided function in a shard map if + the code is being executed in a mesh context.""" + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + return f + else: + return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) + +def fft3d(x): + if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty): + return jaxdecomp.pfft3d(x.astype(jnp.complex64)) + else: + return jnp.fft.rfftn(x) + +def ifft3d(x): + if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty): + return jaxdecomp.pifft3d(x).real + else: + return jnp.fft.irfftn(x) + +def get_local_shape(mesh_shape): + """ Helper function to get the local size of a mesh given the global size. + """ + if mesh_lib.thread_resources.env.physical_mesh.empty: + return mesh_shape + else: + pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape + return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]] \ No newline at end of file diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 8447f8a..64001f5 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,24 +1,33 @@ +from jaxpm.distributed import autoshmap +from jax.sharding import PartitionSpec as P +from functools import partial import jax.numpy as jnp import numpy as np -def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ Return k_vector given a shape (nc, nc, nc) and box_size +def fftk(shape, dtype=np.float32): """ - k = [] - for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi - kdshape = np.ones(len(shape), dtype='int') - if symmetric and d == len(shape) - 1: - kd = kd[:shape[d] // 2 + 1] - kdshape[d] = len(kd) - kd = kd.reshape(kdshape) + Generate Fourier transform wave numbers for a given mesh. - k.append(kd.astype(dtype)) - del kd, kdshape - return k + Args: + nc (int): Shape of the mesh grid. + Returns: + list: List of wave number arrays for each dimension in + the order [kx, ky, kz]. + """ + kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape] + @partial( + autoshmap, + in_specs=(P('x'), P('y'), P(None)), + out_specs=(P('x'), P(None, 'y'), P(None))) + def get_kvec(ky, kz, kx): + return (ky.reshape([-1, 1, 1]), + kz.reshape([1, -1, 1]), + kx.reshape([1, 1, -1])) # yapf: disable + ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds + # to the order of dimensions in the transposed FFT + return kx, ky, kz def gradient_kernel(kvec, direction, order=1): """ @@ -60,11 +69,7 @@ def laplace_kernel(kvec): Complex kernel """ kk = sum(ki**2 for ki in kvec) - mask = (kk == 0).nonzero() - kk[mask] = 1 - wts = 1. / kk - imask = (~(kk == 0)).astype(int) - wts *= imask + wts = jnp.where(kk == 0, 1., 1. / kk) return wts diff --git a/jaxpm/painting.py b/jaxpm/painting.py index fb5dbd5..bacaf46 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -3,13 +3,25 @@ import jax.lax as lax import jax.numpy as jnp from jaxpm.kernels import cic_compensation, fftk +from jax.sharding import PartitionSpec as P +from functools import partial +from jaxpm.distributed import autoshmap - -def cic_paint(mesh, positions, weight=None): +@partial(autoshmap, + in_specs=(P('x', 'y'), P('x','y'), P('x','y')), + out_specs=P('x', 'y')) +def cic_paint(mesh, displacement, weight=None): """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ + mesh: [nx, ny, nz] + displacement field: [nx, ny, nz, 3] + """ + part_shape = displacement.shape + positions = jnp.stack(jnp.meshgrid( + jnp.arange(part_shape[0]), + jnp.arange(part_shape[1]), + jnp.arange(part_shape[2]), + indexing='ij'), axis=-1) + displacement + positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -34,11 +46,22 @@ def cic_paint(mesh, positions, weight=None): return mesh -def cic_read(mesh, positions): +@partial(autoshmap, + in_specs=(P('x', 'y'), P('x','y')), + out_specs=P('x', 'y')) +def cic_read(mesh, displacement): """ Paints positions onto mesh mesh: [nx, ny, nz] - positions: [npart, 3] + displacement: [nx,ny,nz, 3] """ + # Compute the position of the particles on a regular grid + part_shape = displacement.shape + positions = jnp.stack(jnp.meshgrid( + jnp.arange(part_shape[0]), + jnp.arange(part_shape[1]), + jnp.arange(part_shape[2]), + indexing='ij'), axis=-1) + displacement + positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -52,7 +75,7 @@ def cic_read(mesh, positions): jnp.array(mesh.shape)) return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], - neighboor_coords[..., 3]] * kernel).sum(axis=-1) + neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1]) def cic_paint_2d(mesh, positions, weight): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 41ab2a7..d6d80c2 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,12 +1,15 @@ import jax import jax.numpy as jnp import jax_cosmo as jc +from jax.sharding import PartitionSpec as P from jaxpm.growth import dGfa, growth_factor, growth_rate from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, longrange_kernel) from jaxpm.painting import cic_paint, cic_read +from jaxpm.distributed import fft3d, ifft3d, autoshmap, get_local_shape +from functools import partial def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): """ @@ -17,26 +20,34 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): kvec = fftk(mesh_shape) if delta is None: - delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions)) + delta_k = fft3d(cic_paint(jnp.zeros(mesh_shape), positions)) else: - delta_k = jnp.fft.rfftn(delta) + delta_k = fft3d(delta) # Computes gravitational potential pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) # Computes gravitational forces return jnp.stack([ - cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions) + cic_read(ifft3d(gradient_kernel(kvec, i) * pot_k), positions) for i in range(3) ], axis=-1) -def lpt(cosmo, initial_conditions, positions, a): +def lpt(cosmo, initial_conditions, a, particles_shape=None): """ Computes first order LPT displacement """ - initial_force = pm_forces(positions, delta=initial_conditions) + if particles_shape is None: + particles_shape = initial_conditions.shape + local_mesh_shape = get_local_shape(particles_shape) + displacement = autoshmap( + partial(jnp.zeros, shape=local_mesh_shape+[3], dtype='float32'), + in_specs=(), + out_specs=P('x', 'y'))() # yapf: disable + + initial_force = pm_forces(displacement, delta=initial_conditions) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, @@ -56,9 +67,15 @@ def linear_field(mesh_shape, box_size, pk, seed): pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( box_size[0] * box_size[1] * box_size[2]) - field = jax.random.normal(seed, mesh_shape) - field = jnp.fft.rfftn(field) * pkmesh**0.5 - field = jnp.fft.irfftn(field) + # Initialize a random field with one slice on each gpu + local_mesh_shape = get_local_shape(mesh_shape) + field = autoshmap( + partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'), + in_specs=P(None), + out_specs=P('x', 'y'))(seed) # yapf: disable + + field = fft3d(field) * pkmesh**0.5 + field = ifft3d(field) return field @@ -81,30 +98,3 @@ def make_ode_fn(mesh_shape): return dpos, dvel return nbody_ode - - -def pgd_correction(pos, params): - """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 - args: - pos: particle positions [npart, 3] - params: [alpha, kl, ks] pgd parameters - """ - kvec = fftk(mesh_shape) - - delta = cic_paint(jnp.zeros(mesh_shape), pos) - alpha, kl, ks = params - delta_k = jnp.fft.rfftn(delta) - PGD_range = PGD_kernel(kvec, kl, ks) - - pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range - - forces_pgd = jnp.stack([ - cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos) - for i in range(3) - ], - axis=-1) - - dpos_pgd = forces_pgd * alpha - - return dpos_pgd