# .github/workflows/formatting.yml
---
name: Code Formatting

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      # This job defines no matrix, so the step name must not reference
      # `matrix.python-version` (it would render as an empty string).
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip isort
          python -m pip install pre-commit
      - name: Run pre-commit
        run: python -m pre_commit run --all-files

# .github/workflows/tests.yml
---
name: Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  run_tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so that 3.10 is not parsed as the float 3.1.
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      - name: Checkout Source
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          # Refresh the package index first; a stale index makes
          # `apt-get install` fail intermittently on hosted runners.
          sudo apt-get update
          sudo apt-get install -y libopenmpi-dev
          python -m pip install --upgrade pip
          pip install jax==0.4.35
          pip install numpy setuptools cython wheel
          pip install git+https://github.com/MP-Gadget/pfft-python
          pip install git+https://github.com/MP-Gadget/pmesh
          pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
          pip install .[test]

      - name: Run Single Device Tests
        run: |
          cd tests
          pytest -v -m "not distributed"
      # NOTE(review): unlike the step above, this one runs from the
      # repository root (each `run` starts in the workspace directory).
      # Confirm that is intended.
      - name: Run Distributed tests
        run: |
          pytest -v -m distributed
+*.out # SageMath parsed files *.sage.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d476f32..f44eaca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,4 @@ repos: rev: 5.13.2 hooks: - id: isort - name: isort (python) \ No newline at end of file + name: isort (python) diff --git a/README.md b/README.md index b1a6bd7..c87e191 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,15 @@ JAX-powered Cosmological Particle-Mesh N-body Solver -**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)** - ## Goals Provide a modern infrastructure to support differentiable PM N-body simulations using JAX: - Keep implementation simple and readable, in pure NumPy API -- Transparent distribution using builtin `xmap` - Any order forward and backward automatic differentiation - Support automated batching using `vmap` - Compatibility with external optimizer libraries like `optax` +- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with `JAX v0.4.35` + ## Open development and use @@ -23,6 +22,10 @@ Current expectations are: - Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal). - Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers. +## Getting Started + +To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. See the [notebooks README](notebooks/README.md) for a structured guide through each tutorial. 
+ ## Contributors ✨ diff --git a/design.md b/design.md deleted file mode 100644 index a0727a1..0000000 --- a/design.md +++ /dev/null @@ -1,52 +0,0 @@ -# Design Document for JaxPM - -This document aims to detail some of the API, implementation choices, and internal mechanism. - -## Objective - -Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. - -## Related Work - -This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models. - -- [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow -- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD -- Borg - - -In addition, a number of fast N-body simulation projets exist out there: -- [FastPM](https://github.com/fastpm/fastpm) -- ... - -## Design Overview - -### Coding principles - -Following recent trends and JAX philosophy, the library should have a functional programming type of interface. - - -### Illustration of API - -Here is a potential illustration of what the user interface could be for the simulation code: -```python -import jaxpm as jpm -import jax_cosmo as jc - -# Instantiate differentiable cosmology object -cosmo = jc.Planck() - -# Creates initial conditions -inital_conditions = jpm.generate_ic(cosmo, boxsize, nmesh, dtype='float32') - -# Create a particular solver -solver = jpm.solvers.fastpm(cosmo, B=1) - -# Initialize and run the simulation -state = solver.init(initial_conditions) -state = solver.nbody(state) - -# Painting the results -density = jpm.zeros(boxsize, nmesh) -density = jpm.paint(density, state.positions) -``` diff --git a/dev/test_pfft.py b/dev/test_pfft.py deleted file mode 100644 index 5a956d8..0000000 --- a/dev/test_pfft.py +++ /dev/null @@ -1,96 +0,0 @@ -# Can be executed with: -# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py -from functools import partial - -import jax -import jax.lax as lax -import jax.numpy as 
jnp -import numpy as np -from jax.experimental.maps import Mesh, xmap -from jax.experimental.pjit import PartitionSpec, pjit - -jax.distributed.initialize() - -cube_size = 2048 - - -@partial(xmap, - in_axes=[...], - out_axes=['x', 'y', ...], - axis_sizes={ - 'x': cube_size, - 'y': cube_size - }, - axis_resources={ - 'x': 'nx', - 'y': 'ny', - 'key_x': 'nx', - 'key_y': 'ny' - }) -def pnormal(key): - return jax.random.normal(key, shape=[cube_size]) - - -@partial(xmap, - in_axes={ - 0: 'x', - 1: 'y' - }, - out_axes=['x', 'y', ...], - axis_resources={ - 'x': 'nx', - 'y': 'ny' - }) -@jax.jit -def pfft3d(mesh): - # [x, y, z] - mesh = jnp.fft.fft(mesh) # Transform on z - mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x] - mesh = jnp.fft.fft(mesh) # Transform on x - mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y] - mesh = jnp.fft.fft(mesh) # Transform on y - # [z, x, y] - return mesh - - -@partial(xmap, - in_axes={ - 0: 'x', - 1: 'y' - }, - out_axes=['x', 'y', ...], - axis_resources={ - 'x': 'nx', - 'y': 'ny' - }) -@jax.jit -def pifft3d(mesh): - # [z, x, y] - mesh = jnp.fft.ifft(mesh) # Transform on y - mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x] - mesh = jnp.fft.ifft(mesh) # Transform on x - mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z] - mesh = jnp.fft.ifft(mesh) # Transform on z - # [x, y, z] - return mesh - - -key = jax.random.PRNGKey(42) -# keys = jax.random.split(key, 4).reshape((2,2,2)) - -# We reshape all our devices to the mesh shape we want -devices = np.array(jax.devices()).reshape((2, 4)) - -with Mesh(devices, ('nx', 'ny')): - mesh = pnormal(key) - kmesh = pfft3d(mesh) - kmesh.block_until_ready() - -# jax.profiler.start_trace("tensorboard") -# with Mesh(devices, ('nx', 'ny')): -# mesh = pnormal(key) -# kmesh = pfft3d(mesh) -# kmesh.block_until_ready() -# jax.profiler.stop_trace() - -print('Done') diff --git a/dev/test_script.py b/dev/test_script.py deleted file mode 100644 index 
4f3ca06..0000000 --- a/dev/test_script.py +++ /dev/null @@ -1,68 +0,0 @@ -# Start this script with: -# mpirun -np 4 python test_script.py -import os - -os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' -import jax -import jax.lax as lax -import jax.numpy as jnp -import matplotlib.pylab as plt -import numpy as np -import tensorflow_probability as tfp -from jax.experimental.maps import mesh, xmap -from jax.experimental.pjit import PartitionSpec, pjit - -tfp = tfp.substrates.jax -tfd = tfp.distributions - - -def cic_paint(mesh, positions): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], - [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - - dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), - inserted_window_dims=(0, 1, 2), - scatter_dims_to_operand_dims=(0, 1, - 2)) - mesh = lax.scatter_add( - mesh, - neighboor_coords.reshape([-1, 8, 3]).astype('int32'), - kernel.reshape([-1, 8]), dnums) - return mesh - - -# And let's draw some points from some 3D distribution -dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.], - scale_identity_multiplier=3.) 
-pos = dist.sample(1e4, seed=jax.random.PRNGKey(0)) - -f = pjit(lambda x: cic_paint(x, pos), - in_axis_resources=PartitionSpec('x', 'y', 'z'), - out_axis_resources=None) - -devices = np.array(jax.devices()).reshape((2, 2, 1)) - -# Let's import the mesh -m = jnp.zeros([32, 32, 32]) - -with mesh(devices, ('x', 'y', 'z')): - # Shard the mesh, I'm not sure this is absolutely necessary - m = pjit(lambda x: x, - in_axis_resources=None, - out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) - - # Apply the sharded CiC function - res = f(m) - -plt.imshow(res.sum(axis=2)) -plt.show() diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py new file mode 100644 index 0000000..721a971 --- /dev/null +++ b/jaxpm/distributed.py @@ -0,0 +1,198 @@ +from typing import Any, Callable, Hashable + +Specs = Any +AxisName = Hashable + +from functools import partial + +import jax +import jax.numpy as jnp +import jaxdecomp +from jax import lax +from jax.experimental.shard_map import shard_map +from jax.sharding import AbstractMesh, Mesh +from jax.sharding import PartitionSpec as P + + +def autoshmap( + f: Callable, + gpu_mesh: Mesh | AbstractMesh | None, + in_specs: Specs, + out_specs: Specs, + check_rep: bool = False, + auto: frozenset[AxisName] = frozenset()) -> Callable: + """Helper function to wrap the provided function in a shard map if + the code is being executed in a mesh context.""" + if gpu_mesh is None or gpu_mesh.empty: + return f + else: + return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto) + + +def fft3d(x): + return jaxdecomp.pfft3d(x) + + +def ifft3d(x): + return jaxdecomp.pifft3d(x).real + + +def get_halo_size(halo_size, sharding): + gpu_mesh = sharding.mesh if sharding is not None else None + if gpu_mesh is None or gpu_mesh.empty: + zero_ext = (0, 0) + zero_tuple = (0, 0) + return (zero_tuple, zero_tuple, zero_tuple), zero_ext + else: + pdims = gpu_mesh.devices.shape + halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size) + halo_y = (0, 0) if 
pdims[1] == 1 else (halo_size, halo_size) + + halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2 + halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2 + return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext)) + + +def halo_exchange(x, halo_extents, halo_periods=(True, True)): + if (halo_extents[0] > 0 or halo_extents[1] > 0): + return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) + else: + return x + + +def slice_unpad_impl(x, pad_width): + + halo_x, _ = pad_width[0] + halo_y, _ = pad_width[1] + # Apply corrections along x + x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2]) + x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:]) + # Apply corrections along y + x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) + x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) + + unpad_slice = [slice(None)] * 3 + if halo_x > 0: + unpad_slice[0] = slice(halo_x, -halo_x) + if halo_y > 0: + unpad_slice[1] = slice(halo_y, -halo_y) + + return x[tuple(unpad_slice)] + + +def slice_pad(x, pad_width, sharding): + gpu_mesh = sharding.mesh if sharding is not None else None + if gpu_mesh is not None and not (gpu_mesh.empty) and ( + pad_width[0][0] > 0 or pad_width[1][0] > 0): + assert sharding is not None + spec = sharding.spec + return shard_map((partial(jnp.pad, pad_width=pad_width)), + mesh=gpu_mesh, + in_specs=spec, + out_specs=spec)(x) + else: + return x + + +def slice_unpad(x, pad_width, sharding): + mesh = sharding.mesh if sharding is not None else None + if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0 + or pad_width[1][0] > 0): + assert sharding is not None + spec = sharding.spec + return shard_map(partial(slice_unpad_impl, pad_width=pad_width), + mesh=mesh, + in_specs=spec, + out_specs=spec)(x) + else: + return x + + +def get_local_shape(mesh_shape, sharding=None): + """ Helper function to get the local size of a mesh given the global size. 
+ """ + gpu_mesh = sharding.mesh if sharding is not None else None + if gpu_mesh is None or gpu_mesh.empty: + return mesh_shape + else: + pdims = gpu_mesh.devices.shape + return [ + mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], + *mesh_shape[2:] + ] + + +def _axis_names(spec): + if len(spec) == 1: + x_axis, = spec + y_axis = None + single_axis = True + elif len(spec) == 2: + x_axis, y_axis = spec + if y_axis == None: + single_axis = True + elif x_axis == None: + x_axis = y_axis + single_axis = True + else: + single_axis = False + else: + raise ValueError("Only 1 or 2 axis sharding is supported") + return x_axis, y_axis, single_axis + + +def uniform_particles(mesh_shape, sharding=None): + + gpu_mesh = sharding.mesh if sharding is not None else None + if gpu_mesh is not None and not (gpu_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape, sharding) + spec = sharding.spec + x_axis, y_axis, single_axis = _axis_names(spec) + + def particles(): + x_indx = lax.axis_index(x_axis) + y_indx = 0 if single_axis else lax.axis_index(y_axis) + + x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0] + y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1] + z = jnp.arange(local_mesh_shape[2]) + return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1) + + return shard_map(particles, mesh=gpu_mesh, in_specs=(), + out_specs=spec)() + else: + return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape], + indexing='ij'), + axis=-1) + + +def normal_field(mesh_shape, seed, sharding=None): + """Generate a Gaussian random field with the given power spectrum.""" + gpu_mesh = sharding.mesh if sharding is not None else None + if gpu_mesh is not None and not (gpu_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape, sharding) + + size = jax.device_count() + # rank = jax.process_index() + # process_index is multi_host only + # to make the code work both in multi host and single controller we can do this trick + keys = 
jax.random.split(seed, size) + spec = sharding.spec + x_axis, y_axis, single_axis = _axis_names(spec) + + def normal(keys, shape, dtype): + idx = lax.axis_index(x_axis) + if not single_axis: + y_index = lax.axis_index(y_axis) + x_size = lax.psum(1, axis_name=x_axis) + idx += y_index * x_size + + return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype) + + return shard_map( + partial(normal, shape=local_mesh_shape, dtype='float32'), + mesh=gpu_mesh, + in_specs=P(None), + out_specs=spec)(keys) # yapf: disable + else: + return jax.random.normal(shape=mesh_shape, key=seed) diff --git a/jaxpm/growth.py b/jaxpm/growth.py index 5b6908c..ec248f3 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -1,6 +1,6 @@ import jax.numpy as np +from jax.numpy import interp from jax_cosmo.background import * -from jax_cosmo.scipy.interpolate import interp from jax_cosmo.scipy.ode import odeint @@ -587,5 +587,6 @@ def dGf2a(cosmo, a): cache = cosmo._workspace['background.growth_factor'] f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = interp(np.log(a), np.log(cache['a']), f2p) - E = E(cosmo, a) - return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f) + E_a = E(cosmo, a) + return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + + 3 * a**2 * E_a * D2f) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 3bcb9ee..912fe2f 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,30 +1,46 @@ import jax.numpy as jnp import numpy as np +from jax.lax import FftType +from jax.sharding import PartitionSpec as P +from jaxdecomp import fftfreq3d, get_output_specs + +from jaxpm.distributed import autoshmap -def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ - Return wave-vectors for a given shape +def fftk(k_array): """ - k = [] - for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi - kdshape = np.ones(len(shape), dtype='int') - if symmetric and d == len(shape) - 1: - kd = kd[:shape[d] // 2 + 1] - kdshape[d] = len(kd) - 
def fftk(k_array):
    """
    Generate Fourier-space wave numbers for a given array.

    Args:
        k_array: Array handed to ``jaxdecomp.fftfreq3d``.
            NOTE(review): presumably the Fourier-space mesh returned by
            the distributed FFT -- confirm against callers.

    Returns:
        tuple: The three arrays returned by ``fftfreq3d``, as
            ``(kx, ky, kz)``.
    """
    kx, ky, kz = fftfreq3d(k_array)
    # to the order of dimensions in the transposed FFT
    return kx, ky, kz


def interpolate_power_spectrum(input, k, pk, sharding=None):
    """
    Linearly interpolate a tabulated spectrum onto the values of ``input``.

    Args:
        input: Array of points at which to evaluate the interpolation.
        k: Sample points of the tabulated spectrum (``jnp.interp`` ``xp``).
        pk: Tabulated values at ``k`` (``jnp.interp`` ``fp``).
        sharding: Optional object exposing ``.mesh`` and ``.spec``. When
            given, the interpolation is wrapped with ``autoshmap`` using
            the specs returned by ``jaxdecomp.get_output_specs`` for an
            FFT.

    Returns:
        Array with the shape of ``input`` holding the interpolated values.
    """

    def pk_fn(x):
        # jnp.interp works on 1D input: flatten, interpolate, restore shape.
        return jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

    gpu_mesh = sharding.mesh if sharding is not None else None
    specs = sharding.spec if sharding is not None else P()
    out_specs = P(*get_output_specs(
        FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()

    return autoshmap(pk_fn,
                     gpu_mesh=gpu_mesh,
                     in_specs=out_specs,
                     out_specs=out_specs)(input)


def invlaplace_kernel(kvec, fd=False):
    """
    Compute the inverse Laplace kernel.

    cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)

    Parameters
    -----------
    kvec: list
        List of wave-vectors
    fd: bool
        Finite difference kernel. When True, each component ``ki`` is
        replaced by ``ki * sinc(ki / (2 * pi))`` before squaring.

    Returns
    --------
    wts: array
        Kernel values ``-1 / k**2``, with 0 where ``k**2 == 0``
    """
    if fd:
        kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
    else:
        kk = sum(ki**2 for ki in kvec)
    # Substitute 1 for the zero mode so the division never sees a zero;
    # the outer `where` then forces that mode to 0.
    kk_nozeros = jnp.where(kk == 0, 1, kk)
    return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -19,48 +28,106 @@ def cic_paint(mesh, positions, weight=None): kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] if weight is not None: - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + if jnp.isscalar(weight): + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + else: + kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), + kernel) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 8, 3]).astype('int32'), - jnp.array(mesh.shape)) + jnp.array(grid_mesh.shape)) dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]), - dnums) + mesh = lax.scatter_add(grid_mesh, neighboor_coords, + kernel.reshape([-1, 8]), dnums) return mesh -def cic_read(mesh, positions): +@partial(jax.jit, static_argnums=(3, 4)) +def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): + + positions = positions.reshape((*grid_mesh.shape, 3)) + + halo_size, halo_extents = get_halo_size(halo_size, sharding) + grid_mesh = slice_pad(grid_mesh, halo_size, sharding) + + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + grid_mesh = autoshmap(_cic_paint_impl, + gpu_mesh=gpu_mesh, + in_specs=(spec, spec, P()), + out_specs=spec)(grid_mesh, positions, weight) + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + grid_mesh = slice_unpad(grid_mesh, halo_size, sharding) + + return grid_mesh + + +def _cic_read_impl(grid_mesh, positions): """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ + mesh: [nx, ny, nz] + positions: [nx,ny,nz, 3] + """ + # Save original shape for reshaping output later + 
original_shape = positions.shape + # Reshape positions to a flat list of 3D coordinates + positions = positions.reshape([-1, 3]) + # Expand dimensions to calculate neighbor coordinates positions = jnp.expand_dims(positions, 1) + # Floor the positions to get the base grid cell for each particle floor = jnp.floor(positions) + # Define connections to calculate all neighbor coordinates connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - + # Calculate the 8 neighboring coordinates neighboor_coords = floor + connection + # Calculate kernel weights based on distance from each neighboring coordinate kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - + # Modulo operation to wrap around edges if necessary neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), - jnp.array(mesh.shape)) + jnp.array(grid_mesh.shape)) + # Ensure grid_mesh shape is as expected + # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel + return (grid_mesh[neighboor_coords[..., 0], + neighboor_coords[..., 1], + neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable - return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], - neighboor_coords[..., 3]] * kernel).sum(axis=-1) + +@partial(jax.jit, static_argnums=(2, 3)) +def cic_read(grid_mesh, positions, halo_size=0, sharding=None): + + original_shape = positions.shape + positions = positions.reshape((*grid_mesh.shape, 3)) + + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) + grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + + displacement = 
autoshmap(_cic_read_impl, + gpu_mesh=gpu_mesh, + in_specs=(spec, spec), + out_specs=spec)(grid_mesh, positions) + + return displacement.reshape(original_shape[:-1]) def cic_paint_2d(mesh, positions, weight): """ Paints positions onto a 2d mesh - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - """ + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -84,17 +151,109 @@ def cic_paint_2d(mesh, positions, weight): return mesh +def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): + + halo_x, _ = halo_size[0] + halo_y, _ = halo_size[1] + + original_shape = displacements.shape + particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32') + if not jnp.isscalar(weight): + if weight.shape != original_shape[:-1]: + raise ValueError("Weight shape must match particle shape") + else: + weight = weight.flatten() + # Padding is forced to be zero in a single gpu run + + a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), + jnp.arange(particle_mesh.shape[1]), + jnp.arange(particle_mesh.shape[2]), + indexing='ij') + + particle_mesh = jnp.pad(particle_mesh, halo_size) + pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) + return scatter(pmid.reshape([-1, 3]), + displacements.reshape([-1, 3]), + particle_mesh, + chunk_size=2**24, + val=weight) + + +@partial(jax.jit, static_argnums=(1, 2, 4)) +def cic_paint_dx(displacements, + halo_size=0, + sharding=None, + weight=1.0, + chunk_size=2**24): + + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) + + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + grid_mesh = autoshmap(partial(_cic_paint_dx_impl, + halo_size=halo_size, + weight=weight, + chunk_size=chunk_size), + gpu_mesh=gpu_mesh, + in_specs=spec, + 
out_specs=spec)(displacements) + + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + grid_mesh = slice_unpad(grid_mesh, halo_size, sharding) + return grid_mesh + + +def _cic_read_dx_impl(grid_mesh, disp, halo_size): + + halo_x, _ = halo_size[0] + halo_y, _ = halo_size[1] + + original_shape = [ + dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size) + ] + a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), + jnp.arange(original_shape[1]), + jnp.arange(original_shape[2]), + indexing='ij') + + pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) + + pmid = pmid.reshape([-1, 3]) + disp = disp.reshape([-1, 3]) + + return gather(pmid, disp, grid_mesh).reshape(original_shape) + + +@partial(jax.jit, static_argnums=(2, 3)) +def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None): + + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) + grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size), + gpu_mesh=gpu_mesh, + in_specs=(spec), + out_specs=spec)(grid_mesh, disp) + + return displacements + + def compensate_cic(field): """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ - nc = field.shape - kvec = fftk(nc) + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ + delta_k = fft3d(field) - delta_k = jnp.fft.rfftn(field) + kvec = fftk(delta_k) delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) + return ifft3d(delta_k) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py new file mode 100644 index 0000000..cf68f9d --- 
def _chunk_split(ptcl_num, chunk_size, *arrays):
    """Split and reshape particle arrays into chunks and remainders, with the remainders
    preceding the chunks. 0D ones are duplicated as full arrays in the chunks.

    Returns ``(remainder, chunks)``; ``remainder`` is ``None`` when
    ``ptcl_num`` is a multiple of the chunk size.
    """
    chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
    # With zero particles the chunk size would be 0 and the modulo and
    # division below would raise; one-particle chunks give zero chunks.
    chunk_size = max(chunk_size, 1)
    remainder_size = ptcl_num % chunk_size
    chunk_num = ptcl_num // chunk_size

    remainder = None
    chunks = arrays
    if remainder_size:
        remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
        chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]

    # `scan` triggers errors in scatter and gather without the `full`
    chunks = [
        x.reshape(chunk_num, chunk_size, *x.shape[1:])
        if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
    ]

    return remainder, chunks


def enmesh(base_indices, displacements, cell_size, base_shape, offset,
           new_cell_size, new_shape):
    """Multilinear enmeshing.

    Computes, for each particle, the indices of its ``2**dim`` neighbouring
    cells and the multilinear weight of each.

    Args:
        base_indices: Integer base cell indices, ``[npart, dim]``.
        displacements: Displacements from the base cells, ``[npart, dim]``.
        cell_size: Cell size of the base grid.
        base_shape: Shape of the base grid, or ``None``; when given,
            indices are wrapped modulo this shape.
        offset: Offset subtracted from the particle positions.
        new_cell_size: Cell size of the target mesh, or ``None`` to enmesh
            on the base grid.
        new_shape: Shape of the target mesh, or ``None``.

    Returns:
        ``(new_indices, weights)`` of shapes ``[npart, 2**dim, dim]`` and
        ``[npart, 2**dim]``.
    """
    base_indices = jnp.asarray(base_indices)
    displacements = jnp.asarray(displacements)
    with jax.experimental.enable_x64():
        cell_size = jnp.float64(
            cell_size) if new_cell_size is not None else jnp.array(
                cell_size, dtype=displacements.dtype)
        if base_shape is not None:
            base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
            offset = jnp.float64(offset)
        if new_cell_size is not None:
            new_cell_size = jnp.float64(new_cell_size)
        if new_shape is not None:
            new_shape = jnp.array(new_shape, dtype=base_indices.dtype)

    spatial_dim = base_indices.shape[1]
    # Row n holds the bits of n: all 0/1 offsets to the 2**dim neighbours.
    neighbor_offsets = (
        jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
        jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1

    if new_cell_size is not None:
        particle_positions = base_indices * cell_size + displacements - offset
        particle_positions = particle_positions[:, jnp.
                                                newaxis]  # insert neighbor axis
        new_indices = particle_positions + neighbor_offsets * new_cell_size  # multilinear

        if base_shape is not None:
            grid_length = base_shape * cell_size
            new_indices %= grid_length

        new_indices //= new_cell_size
        new_displacements = particle_positions - new_indices * new_cell_size

        if base_shape is not None:
            new_displacements -= jnp.rint(
                new_displacements / grid_length
            ) * grid_length  # also abs(new_displacements) < new_cell_size is expected

        new_indices = new_indices.astype(base_indices.dtype)
        new_displacements = new_displacements.astype(displacements.dtype)
        new_cell_size = new_cell_size.astype(displacements.dtype)

        new_displacements /= new_cell_size
    else:
        offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
        base_indices -= offset_indices.astype(base_indices.dtype)
        displacements -= offset_displacements.astype(displacements.dtype)

        # insert neighbor axis
        base_indices = base_indices[:, jnp.newaxis]
        displacements = displacements[:, jnp.newaxis]

        # multilinear
        displacements /= cell_size
        new_indices = jnp.floor(displacements).astype(base_indices.dtype)
        new_indices += neighbor_offsets
        new_displacements = displacements - new_indices
        new_indices += base_indices

        if base_shape is not None:
            new_indices %= base_shape

    weights = 1 - jnp.abs(new_displacements)

    if base_shape is None and new_shape is not None:  # all new_indices >= 0 if base_shape is not None
        new_indices = jnp.where(new_indices < 0, new_shape, new_indices)

    weights = weights.prod(axis=-1)

    return new_indices, weights


def _scatter_chunk(carry, chunk):
    """Scatter-add one chunk of particles onto the mesh (``scan`` body)."""
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk
    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)
    # scatter
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
    carry = mesh, offset, cell_size, mesh_shape
    return carry, None


def scatter(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1.,
            offset=0,
            cell_size=1.):
    """Scatter particle values onto ``mesh`` with multilinear weights.

    Args:
        pmid: Integer base cell indices, ``[npart, dim]``.
        disp: Displacements from the base cells, ``[npart, dim]``.
        mesh: Mesh to add onto.
        chunk_size: Maximum number of particles processed per step.
        val: Scalar or per-particle values to deposit.
        offset: Offset forwarded to ``enmesh``.
        cell_size: Cell size forwarded to ``enmesh``.

    Returns:
        The mesh with all particle contributions added.
    """
    ptcl_num, _ = pmid.shape
    val = jnp.asarray(val)
    mesh = jnp.asarray(mesh)
    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
    carry = mesh, offset, cell_size, mesh.shape
    if remainder is not None:
        carry = _scatter_chunk(carry, remainder)[0]
    carry = scan(_scatter_chunk, carry, chunks)[0]
    mesh = carry[0]
    return mesh


def _chunk_cat(remainder_array, chunked_array):
    """Reshape and concatenate one remainder and one chunked particle arrays."""
    array = chunked_array.reshape(-1, *chunked_array.shape[2:])

    if remainder_array is not None:
        array = jnp.concatenate((remainder_array, array), axis=0)

    return array


def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
    """Gather mesh values at the particles with multilinear weights.

    Args:
        pmid: Integer base cell indices, ``[npart, dim]``.
        disp: Displacements from the base cells, ``[npart, dim]``.
        mesh: Mesh to read from.
        chunk_size: Maximum number of particles processed per step.
        val: Initial values the gathered values are added to.
        offset: Offset forwarded to ``enmesh``.
        cell_size: Cell size forwarded to ``enmesh``.

    Returns:
        Array with one gathered value per particle.

    Raises:
        ValueError: If the channel shape of ``mesh`` and ``val`` differ.
    """
    ptcl_num, spatial_ndim = pmid.shape

    mesh = jnp.asarray(mesh)

    val = jnp.asarray(val)

    if mesh.shape[spatial_ndim:] != val.shape[1:]:
        raise ValueError('channel shape mismatch: '
                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)

    carry = mesh, offset, cell_size, mesh.shape
    val_0 = None
    if remainder is not None:
        val_0 = _gather_chunk(carry, remainder)[1]
    val = scan(_gather_chunk, carry, chunks)[1]

    val = _chunk_cat(val_0, val)

    return val


def _gather_chunk(carry, chunk):
    """Gather mesh values for one chunk of particles (``scan`` body)."""
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk

    spatial_ndim = pmid.shape[1]

    spatial_shape = mesh.shape[:spatial_ndim]
    chan_ndim = mesh.ndim - spatial_ndim
    chan_axis = tuple(range(-chan_ndim, 0))

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)

    # gather
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    frac = jnp.expand_dims(frac, chan_axis)
    val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)

    return carry, val
in range(spatial_ndim)) + frac = jnp.expand_dims(frac, chan_axis) + val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1) + + return carry, val diff --git a/jaxpm/plotting.py b/jaxpm/plotting.py new file mode 100644 index 0000000..9fe4d8e --- /dev/null +++ b/jaxpm/plotting.py @@ -0,0 +1,129 @@ +import matplotlib.pyplot as plt +import numpy as np + + +def plot_fields(fields_dict, sum_over=None): + """ + Plots sum projections of 3D fields along different axes, + slicing only the first `sum_over` elements along each axis. + + Args: + - fields: list of 3D arrays representing fields to plot + - names: list of names for each field, used in titles + - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8) + """ + sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8 + nb_rows = len(fields_dict) + nb_cols = 3 + fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows)) + + def plot_subplots(proj_axis, field, row, title): + slicing = [slice(None)] * field.ndim + slicing[proj_axis] = slice(None, sum_over) + slicing = tuple(slicing) + + # Sum projection over the specified axis and plot + axes[row, proj_axis].imshow( + field[slicing].sum(axis=proj_axis) + 1, + cmap='magma', + extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]]) + axes[row, proj_axis].set_xlabel('Mpc/h') + axes[row, proj_axis].set_ylabel('Mpc/h') + axes[row, proj_axis].set_title(title) + + # Plot each field across the three axes + for i, (name, field) in enumerate(fields_dict.items()): + for proj_axis in range(3): + plot_subplots(proj_axis, field, i, + f"{name} projection {proj_axis}") + + plt.tight_layout() + plt.show() + + +def plot_fields_single_projection(fields_dict, + sum_over=None, + project_axis=0, + vmin=None, + vmax=None, + colorbar=False): + """ + Plots a single projection (along axis 0) of 3D fields in a grid, + summing over the first `sum_over` elements along the 0-axis, with 4 images per row. 
+ + Args: + - fields_dict: dictionary where keys are field names and values are 3D arrays + - sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8) + """ + sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8 + nb_fields = len(fields_dict) + nb_cols = 4 # Set number of images per row + nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows + + fig, axes = plt.subplots(nb_rows, + nb_cols, + figsize=(5 * nb_cols, 5 * nb_rows)) + axes = np.atleast_2d(axes) # Ensure axes is always a 2D array + + for i, (name, field) in enumerate(fields_dict.items()): + row, col = divmod(i, nb_cols) + + # Define the slice for the 0-axis projection + slicing = [slice(None)] * field.ndim + slicing[project_axis] = slice(None, sum_over) + slicing = tuple(slicing) + + # Sum projection over axis 0 and plot + a = axes[row, + col].imshow(field[slicing].sum(axis=project_axis) + 1, + cmap='magma', + extent=[0, field.shape[1], 0, field.shape[2]], + vmin=vmin, + vmax=vmax) + axes[row, col].set_xlabel('Mpc/h') + axes[row, col].set_ylabel('Mpc/h') + axes[row, col].set_title(f"{name} projection 0") + if colorbar: + fig.colorbar(a, ax=axes[row, col], shrink=0.7) + + # Remove any empty subplots + for j in range(i + 1, nb_rows * nb_cols): + fig.delaxes(axes.flatten()[j]) + + plt.tight_layout() + plt.show() + + +def stack_slices(array): + """ + Stacks 2D slices of an array into a single array based on provided partition dimensions. + + Args: + - array_slices: a 2D list of array slices (list of lists format) where + array_slices[i][j] is the slice located at row i, column j in the grid. + - pdims: a tuple representing the grid dimensions (rows, columns). + + Returns: + - A single array constructed by stacking the slices. 
+ """ + # Initialize an empty list to store the vertically stacked rows + pdims = array.sharding.mesh.devices.shape + + field_slices = [] + + # Iterate over rows in pdims[0] + for i in range(pdims[0]): + row_slices = [] + + # Iterate over columns in pdims[1] + for j in range(pdims[1]): + slice_index = i * pdims[0] + j + row_slices.append(array.addressable_data(slice_index)) + # Stack the current row of slices vertically + stacked_row = np.hstack(row_slices) + field_slices.append(stacked_row) + + # Stack all rows horizontally to form the full array + full_array = np.vstack(field_slices) + + return full_array diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 4aedef5..e34d584 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,50 +1,92 @@ -import jax import jax.numpy as jnp import jax_cosmo as jc -from jax_cosmo import Cosmology -from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a -from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel -from jaxpm.painting import cic_paint, cic_read +from jaxpm.distributed import fft3d, ifft3d, normal_field +from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, + growth_rate, growth_rate_second) +from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, + invlaplace_kernel, longrange_kernel) +from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx - -def pm_forces(positions, mesh_shape, delta=None, r_split=0): +def pm_forces(positions, + mesh_shape=None, + delta=None, + r_split=0, + paint_absolute_pos=True, + halo_size=0, + sharding=None): """ Computes gravitational forces on particles using a PM scheme """ + if mesh_shape is None: + assert (delta is not None),\ + "If mesh_shape is not provided, delta should be provided" + mesh_shape = delta.shape + + if paint_absolute_pos: + paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape, + device=sharding), + pos, + halo_size=halo_size, + 
sharding=sharding) + read_fn = lambda grid_mesh, pos: cic_read( + grid_mesh, pos, halo_size=halo_size, sharding=sharding) + else: + paint_fn = lambda disp: cic_paint_dx( + disp, halo_size=halo_size, sharding=sharding) + read_fn = lambda grid_mesh, disp: cic_read_dx( + grid_mesh, disp, halo_size=halo_size, sharding=sharding) + if delta is None: - delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions)) + field = paint_fn(positions) + delta_k = fft3d(field) elif jnp.isrealobj(delta): - delta_k = jnp.fft.rfftn(delta) + delta_k = fft3d(delta) else: delta_k = delta + kvec = fftk(delta_k) # Computes gravitational potential - kvec = fftk(mesh_shape) - pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel( + kvec, r_split=r_split) # Computes gravitational forces - return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions) - for i in range(3)], axis=-1) + forces = jnp.stack([ + read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions + ) for i in range(3)], axis=-1) # yapf: disable + + return forces -def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1): +def lpt(cosmo, + initial_conditions, + particles=None, + a=0.1, + halo_size=0, + sharding=None, + order=1): """ - Computes first and second order LPT displacement and momentum, + Computes first and second order LPT displacement and momentum, e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) """ + paint_absolute_pos = particles is not None + if particles is None: + particles = jnp.zeros_like(initial_conditions, + shape=(*initial_conditions.shape, 3)) + a = jnp.atleast_1d(a) - E = jnp.sqrt(jc.background.Esqr(cosmo, a)) - delta_k = jnp.fft.rfftn(init_mesh) # TODO: pass the modes directly to save one or two fft? 
- mesh_shape = init_mesh.shape - - init_force = pm_forces(positions, mesh_shape, delta=delta_k) - dx = growth_factor(cosmo, a) * init_force + E = jnp.sqrt(jc.background.Esqr(cosmo, a)) + delta_k = fft3d(initial_conditions) + initial_force = pm_forces(particles, + delta=delta_k, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) + dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * E * dx - f = a**2 * E * dGfa(cosmo, a) * init_force - + f = a**2 * E * dGfa(cosmo, a) * initial_force if order == 2: - kvec = fftk(mesh_shape) + kvec = fftk(delta_k) pot_k = delta_k * invlaplace_kernel(kvec) delta2 = 0 @@ -54,47 +96,58 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1): # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) nabla_i_nabla_i = gradient_kernel(kvec, i)**2 - shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k) - delta2 += shear_ii * shear_acc + shear_ii = ifft3d(nabla_i_nabla_i * pot_k) + delta2 += shear_ii * shear_acc shear_acc += shear_ii # for kj in kvec[i+1:]: - for j in range(i+1, 3): + for j in range(i + 1, 3): # Substract squared strict-up-triangle terms # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 - nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j) - delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2 + nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel( + kvec, j) + delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2 - init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2)) + delta_k2 = fft3d(delta2) + init_force2 = pm_forces(particles, + delta=delta_k2, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second - dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2 + dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2 p2 = a**2 * growth_rate_second(cosmo, a) * E 
* dx2 f2 = a**2 * E * dGf2a(cosmo, a) * init_force2 dx += dx2 - p += p2 - f += f2 + p += p2 + f += f2 return dx, p, f -def linear_field(mesh_shape, box_size, pk, seed): +def linear_field(mesh_shape, box_size, pk, seed, sharding=None): """ Generate initial conditions. """ - kvec = fftk(mesh_shape) + # Initialize a random field with one slice on each gpu + field = normal_field(mesh_shape, seed=seed, sharding=sharding) + field = fft3d(field) + kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( box_size[0] * box_size[1] * box_size[2]) - field = jax.random.normal(seed, mesh_shape) - field = jnp.fft.rfftn(field) * pkmesh**0.5 - field = jnp.fft.irfftn(field) + field = field * (pkmesh)**0.5 + field = ifft3d(field) return field -def make_ode_fn(mesh_shape): +def make_ode_fn(mesh_shape, + paint_absolute_pos=True, + halo_size=0, + sharding=None): def nbody_ode(state, a, cosmo): """ @@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape): """ pos, vel = state - forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m + forces = pm_forces(pos, + mesh_shape=mesh_shape, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. 
/ (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel @@ -114,20 +171,28 @@ def make_ode_fn(mesh_shape): return nbody_ode -def get_ode_fn(cosmo:Cosmology, mesh_shape): + +def make_diffrax_ode(cosmo, + mesh_shape, + paint_absolute_pos=True, + halo_size=0, + sharding=None): def nbody_ode(a, state, args): """ - State is an array [position, velocities] - - Compatible with [Diffrax API](https://docs.kidger.site/diffrax/) + state is a tuple (position, velocities) """ pos, vel = state - forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m + + forces = pm_forces(pos, + mesh_shape=mesh_shape, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces @@ -138,51 +203,57 @@ def get_ode_fn(cosmo:Cosmology, mesh_shape): def pgd_correction(pos, mesh_shape, params): """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, + improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 args: pos: particle positions [npart, 3] params: [alpha, kl, ks] pgd parameters """ - kvec = fftk(mesh_shape) delta = cic_paint(jnp.zeros(mesh_shape), pos) + delta_k = fft3d(delta) + kvec = fftk(delta_k) alpha, kl, ks = params - delta_k = jnp.fft.rfftn(delta) - PGD_range=PGD_kernel(kvec, kl, ks) - - pot_k_pgd=(delta_k * invlaplace_kernel(kvec))*PGD_range + PGD_range = PGD_kernel(kvec, kl, ks) + + pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range + + forces_pgd = jnp.stack([ + cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos) + for i in range(3) + ], + axis=-1) + + dpos_pgd = forces_pgd * alpha - forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), 
pos) - for i in range(3)],axis=-1) - - dpos_pgd = forces_pgd*alpha - return dpos_pgd def make_neural_ode_fn(model, mesh_shape): - def neural_nbody_ode(state, a, cosmo:Cosmology, params): + + def neural_nbody_ode(state, a, cosmo: Cosmology, params): """ state is a tuple (position, velocities) """ pos, vel = state - kvec = fftk(mesh_shape) - delta = cic_paint(jnp.zeros(mesh_shape), pos) - - delta_k = jnp.fft.rfftn(delta) + delta_k = fft3d(delta) + kvec = fftk(delta_k) # Computes gravitational potential - pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, + r_split=0) # Apply a correction filter - kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) - pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) + kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec)) + pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a))) # Computes gravitational forces - forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos) - for i in range(3)],axis=-1) + forces = jnp.stack([ + cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos) + for i in range(3) + ], + axis=-1) forces = forces * 1.5 * cosmo.Omega_m @@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape): dvel = 1. 
/ (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces return dpos, dvel + return neural_nbody_ode diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 1593ba0..96faeea 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,47 +1,168 @@ +from functools import partial + import jax.numpy as jnp import numpy as np from jax.scipy.stats import norm +from scipy.special import legendre -__all__ = ['power_spectrum'] +__all__ = [ + 'power_spectrum', 'transfer', 'coherence', 'pktranscoh', + 'cross_correlation_coefficients', 'gaussian_smoothing' +] -def _initialize_pk(shape, boxsize, kmin, dk): +def _initialize_pk(mesh_shape, box_shape, kedges, los): """ - Helper function to initialize various (fixed) values for powerspectra... not differentiable! + Parameters + ---------- + mesh_shape : tuple of int + Shape of the mesh grid. + box_shape : tuple of float + Physical dimensions of the box. + kedges : None, int, float, or list + If None, set dk to twice the minimum. + If int, specifies number of edges. + If float, specifies dk. + los : array_like + Line-of-sight vector. + + Returns + ------- + dig : ndarray + Indices of the bins to which each value in input array belongs. + kcount : ndarray + Count of values in each bin. + kedges : ndarray + Edges of the bins. + mumesh : ndarray + Mu values for the mesh grid. """ - I = np.eye(len(shape), dtype='int') * -2 + 1 + kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist - W = np.empty(shape, dtype='f4') - W[...] 
= 2.0 - W[..., 0] = 1.0 - W[..., -1] = 1.0 + if isinstance(kedges, None | int | float): + if kedges is None: + dk = 2 * np.pi / np.min( + box_shape) * 2 # twice the minimum wavenumber + if isinstance(kedges, int): + dk = kmax / (kedges + 1) # final number of bins will be kedges-1 + elif isinstance(kedges, float): + dk = kedges + kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2 - kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2 - kedges = np.arange(kmin, kmax, dk) + kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1 + kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape) + for m, l, kshape in zip(mesh_shape, box_shape, kshapes)] + kmesh = sum(ki**2 for ki in kvec)**0.5 - k = [ - np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape) - for N, L, kshape, pkshape in zip(shape, boxsize, I, shape) - ] - kmag = sum(ki**2 for ki in k)**0.5 + dig = np.digitize(kmesh.reshape(-1), kedges) + kcount = np.bincount(dig, minlength=len(kedges) + 1) - xsum = np.zeros(len(kedges) + 1) - Nsum = np.zeros(len(kedges) + 1) + # Central value of each bin + # kavg = (kedges[1:] + kedges[:-1]) / 2 + kavg = np.bincount( + dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount + kavg = kavg[1:-1] - dig = np.digitize(kmag.flat, kedges) + if los is None: + mumesh = 1. 
+ else: + mumesh = sum(ki * losi for ki, losi in zip(kvec, los)) + kmesh_nozeros = np.where(kmesh == 0, 1, kmesh) + mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros) - xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size) - Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size) - return dig, Nsum, xsum, W, k, kedges + return dig, kcount, kavg, mumesh -def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): +def power_spectrum(mesh, + mesh2=None, + box_shape=None, + kedges: int | float | list = None, + multipoles=0, + los=[0., 0., 1.]): """ - Calculate the powerspectra given real space field + Compute the auto and cross spectrum of 3D fields, with multipoles. + """ + # Initialize + mesh_shape = np.array(mesh.shape) + if box_shape is None: + box_shape = mesh_shape + else: + box_shape = np.asarray(box_shape) + + if multipoles == 0: + los = None + else: + los = np.asarray(los) + los = los / np.linalg.norm(los) + poles = np.atleast_1d(multipoles) + dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges, + los) + n_bins = len(kavg) + 2 + + # FFTs + meshk = jnp.fft.fftn(mesh, norm='ortho') + if mesh2 is None: + mmk = meshk.real**2 + meshk.imag**2 + else: + mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj() + + # Sum powers + pk = jnp.empty((len(poles), n_bins)) + for i_ell, ell in enumerate(poles): + weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1) + if mesh2 is None: + psum = jnp.bincount(dig, weights=weights, length=n_bins) + else: # XXX: bincount is really slow with complex numbers + psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins) + psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins) + psum = (psum_real**2 + psum_imag**2)**.5 + pk = pk.at[i_ell].set(psum) + + # Normalization and conversion from cell units to [Mpc/h]^3 + pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod() + + # pk = jnp.concatenate([kavg[None], pk]) + if np.ndim(multipoles) == 
0: + return kavg, pk[0] + else: + return kavg, pk + + +def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None): + pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges) + ks, pk0 = pk_fn(mesh0) + ks, pk1 = pk_fn(mesh1) + return ks, (pk1 / pk0)**.5 + + +def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None): + pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges) + ks, pk01 = pk_fn(mesh0, mesh1) + ks, pk0 = pk_fn(mesh0) + ks, pk1 = pk_fn(mesh1) + return ks, pk01 / (pk0 * pk1)**.5 + + +def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None): + pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges) + ks, pk01 = pk_fn(mesh0, mesh1) + ks, pk0 = pk_fn(mesh0) + ks, pk1 = pk_fn(mesh1) + return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5 + + +def cross_correlation_coefficients(field_a, + field_b, + kmin=5, + dk=0.5, + boxsize=False): + """ + Calculate the cross correlation coefficients given two real space field Args: - field: real valued field + field_a: real valued field + field_b: real valued field kmin: minimum k-value for binned powerspectra dk: differential in each kbin boxsize: length of each boxlength (can be strangly shaped?) 
@@ -49,20 +170,21 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): Returns: kbins: the central value of the bins for plotting - power: real valued array of power in each bin + P / norm: normalized cross correlation coefficient between two field a and b """ - shape = field.shape + shape = field_a.shape nx, ny, nz = shape #initialze values related to powerspectra (mode bins and weights) dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) #fast fourier transform - fft_image = jnp.fft.fftn(field) + fft_image_a = jnp.fft.fftn(field_a) + fft_image_b = jnp.fft.fftn(field_b) #absolute value of fast fourier transform - pk = jnp.real(fft_image * jnp.conj(fft_image)) + pk = fft_image_a * jnp.conj(fft_image_b) #calculating powerspectra real = jnp.real(pk).reshape([-1]) @@ -83,55 +205,6 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False): return kbins, P / norm -def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False): - """ - Calculate the cross correlation coefficients given two real space field - - Args: - - field_a: real valued field - field_b: real valued field - kmin: minimum k-value for binned powerspectra - dk: differential in each kbin - boxsize: length of each boxlength (can be strangly shaped?) 
- - Returns: - - kbins: the central value of the bins for plotting - P / norm: normalized cross correlation coefficient between two field a and b - - """ - shape = field_a.shape - nx, ny, nz = shape - - #initialze values related to powerspectra (mode bins and weights) - dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk) - - #fast fourier transform - fft_image_a = jnp.fft.fftn(field_a) - fft_image_b = jnp.fft.fftn(field_b) - - #absolute value of fast fourier transform - pk = fft_image_a * jnp.conj(fft_image_b) - - #calculating powerspectra - real = jnp.real(pk).reshape([-1]) - imag = jnp.imag(pk).reshape([-1]) - - Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j - Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size) - - P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32') - - #normalization for powerspectra - norm = np.prod(np.array(shape[:])).astype('float32')**2 - - #find central values of each bin - kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2 - - return kbins, P / norm - - def gaussian_smoothing(im, sigma): """ im: 2d image diff --git a/notebooks/05-MultiHost_PM.py b/notebooks/05-MultiHost_PM.py new file mode 100644 index 0000000..da3964e --- /dev/null +++ b/notebooks/05-MultiHost_PM.py @@ -0,0 +1,179 @@ +import os + +os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax +import jax + +jax.distributed.initialize() +rank = jax.process_index() +size = jax.process_count() +if rank == 0: + print(f"SIZE is {jax.device_count()}") + +import argparse +from functools import partial + +import jax.numpy as jnp +import jax_cosmo as jc +import numpy as np +from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, + PIDController, SaveAt, diffeqsolve) +from jax.experimental.mesh_utils import create_device_mesh +from jax.experimental.multihost_utils import process_allgather +from jax.sharding import Mesh, NamedSharding +from jax.sharding import 
PartitionSpec as P + +from jaxpm.kernels import interpolate_power_spectrum +from jaxpm.painting import cic_paint_dx +from jaxpm.pm import linear_field, lpt, make_diffrax_ode + +all_gather = partial(process_allgather, tiled=True) + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Run a cosmological simulation with JAX.") + parser.add_argument( + "-p", + "--pdims", + type=int, + nargs=2, + default=[1, jax.devices()], + help="Processor grid dimensions as two integers (e.g., 2 4).") + parser.add_argument( + "-m", + "--mesh_shape", + type=int, + nargs=3, + default=[512, 512, 512], + help="Shape of the simulation mesh as three values (e.g., 512 512 512)." + ) + parser.add_argument( + "-b", + "--box_size", + type=float, + nargs=3, + default=[500.0, 500.0, 500.0], + help= + "Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)." + ) + parser.add_argument( + "-st", + "--snapshots", + type=int, + default=2, + help="Number of snapshots to save during the simulation.") + parser.add_argument("-H", + "--halo_size", + type=int, + default=64, + help="Halo size for the simulation.") + parser.add_argument("-s", + "--solver", + type=str, + choices=['leapfrog', 'dopri8'], + default='leapfrog', + help="ODE solver choice: 'leapfrog' or 'dopri8'.") + return parser.parse_args() + + +def create_mesh_and_sharding(pdims): + devices = create_device_mesh(pdims) + mesh = Mesh(devices, axis_names=('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + return mesh, sharding + + +@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6)) +def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, + solver_choice, nb_snapshots, sharding): + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) + + initial_conditions = linear_field(mesh_shape, + box_size, + pk_fn, + seed=jax.random.PRNGKey(0), + sharding=sharding) + + cosmo 
= jc.Planck15(Omega_c=omega_c, sigma8=sigma8) + + dx, p, _ = lpt(cosmo, + initial_conditions, + a=0.1, + halo_size=halo_size, + sharding=sharding) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + + # Choose solver + solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5() + stepsize_controller = ConstantStepSize( + ) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5) + res = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1., + dt0=0.01, + y0=jnp.stack([dx, p], axis=0), + args=cosmo, + saveat=SaveAt(ts=jnp.linspace(0.2, 1., nb_snapshots)), + stepsize_controller=stepsize_controller) + + ode_fields = [ + cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) + for sol in res.ys + ] + lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding) + return initial_conditions, lpt_field, ode_fields, res.stats + + +def main(): + args = parse_arguments() + mesh_shape = args.mesh_shape + box_size = args.box_size + halo_size = args.halo_size + solver_choice = args.solver + nb_snapshots = args.snapshots + + sharding = create_mesh_and_sharding(args.pdims) + + initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation( + 0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size, + solver_choice, nb_snapshots, sharding) + + if rank == 0: + os.makedirs("fields", exist_ok=True) + print(f"[{rank}] Simulation done") + print(f"Solver stats: {solver_stats}") + + # Save initial conditions + initial_conditions_g = all_gather(initial_conditions) + if rank == 0: + print(f"[{rank}] Saving initial_conditions") + np.save("fields/initial_conditions.npy", initial_conditions_g) + print(f"[{rank}] initial_conditions saved") + del initial_conditions_g, initial_conditions + + # Save LPT displacements + lpt_displacements_g = all_gather(lpt_displacements) + if rank == 0: + print(f"[{rank}] Saving lpt_displacements") + np.save("fields/lpt_displacements.npy", lpt_displacements_g) + 
print(f"[{rank}] lpt_displacements saved") + del lpt_displacements_g, lpt_displacements + + # Save each ODE solution separately + for i, sol in enumerate(ode_solutions): + sol_g = all_gather(sol) + if rank == 0: + print(f"[{rank}] Saving ode_solution_{i}") + np.save(f"fields/ode_solution_{i}.npy", sol_g) + print(f"[{rank}] ode_solution_{i} saved") + del sol_g + + +if __name__ == "__main__": + main() diff --git a/notebooks/06-Animating_PM_Fields.ipynb.REMOVED.git-id b/notebooks/06-Animating_PM_Fields.ipynb.REMOVED.git-id new file mode 100644 index 0000000..aa7a5e1 --- /dev/null +++ b/notebooks/06-Animating_PM_Fields.ipynb.REMOVED.git-id @@ -0,0 +1 @@ +c4a44973e4f11841a8c14f4d200e7e87887419aa \ No newline at end of file diff --git a/notebooks/README.md b/notebooks/README.md new file mode 100644 index 0000000..43d9d0b --- /dev/null +++ b/notebooks/README.md @@ -0,0 +1,39 @@ +# Particle Mesh Simulation with JAXPM on Multi-GPU and Multi-Host Systems + +This collection of notebooks demonstrates how to perform Particle Mesh (PM) simulations using **JAXPM**, leveraging JAX for efficient computation on multi-GPU and multi-host systems. Each notebook progressively covers different setups, from single-GPU simulations to advanced, distributed, multi-host simulations across multiple nodes. + +## Table of Contents + +1. **[Single-GPU Particle Mesh Simulation](01-Introduction.ipynb)** + - Introduction to basic PM simulations on a single GPU. + - Uses JAXPM to run simulations with absolute particle positions and Cloud-in-Cell (CIC) painting. + +2. **[Advanced Particle Mesh Simulation on a Single GPU](02-Advanced_usage.ipynb)** + - Explore using diffrax solvers in the ODE step. + - Explores second order Lagrangian Perturbation Theory (LPT) simulations. + - Introduces weighted density field projections + +3. **[Multi-GPU Particle Mesh Simulation with Halo Exchange](03-MultiGPU_PM_Halo.ipynb)** + - Extends PM simulation to multi-GPU setups with halo exchange. 
+ - Uses sharding and device mesh configurations to manage distributed data across GPUs. + +4. **[Multi-GPU Particle Mesh Simulation with Advanced Solvers](04-MultiGPU_PM_Solvers.ipynb)** + - Compares different ODE solvers (Leapfrog and Dopri5) in multi-GPU simulations. + - Highlights performance, memory considerations, and solver impact on simulation quality. + +5. **[Multi-Host Particle Mesh Simulation](05-MultiHost_PM.ipynb)** + - Extends PM simulations to multi-host, multi-GPU setups for large-scale simulations. + - Guides through job submission, device initialization, and retrieving results across nodes. + +## Getting Started + +Each notebook includes installation instructions and guidelines for configuring JAXPM and required dependencies. Follow the setup instructions in each notebook to ensure an optimal environment. + +## Requirements + +- **JAXPM** (included in the installation commands within notebooks) +- **Diffrax** for ODE solvers +- **JAX** with CUDA support for multi-GPU or TPU setups +- **SLURM** for job scheduling on clusters (if running multi-host setups) + +> **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters. 
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9fb84d2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "JaxPM" +version = "0.0.1" +description = "A dead simple FastPM implementation in JAX" +authors = [{ name = "JaxPM developers" }] +readme = "README.md" +requires-python = ">=3.9" +license = { file = "LICENSE" } +urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" } +dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"] + +[project.optional-dependencies] +test = [ + "jax>=0.4.30", + "numpy", + "jax_cosmo", + "jaxdecomp>=0.2.2", + "pytest>=8.0.0", + "pfft-python @ git+https://github.com/MP-Gadget/pfft-python", + "pmesh @ git+https://github.com/MP-Gadget/pmesh", + "fastpm @ git+https://github.com/ASKabalan/fastpm-python", + "diffrax" +] + +[tool.setuptools] +packages = ["jaxpm"] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..9f020be --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + distributed: mark a test as distributed + single_device: mark a test as single_device diff --git a/setup.py b/setup.py deleted file mode 100644 index a58759a..0000000 --- a/setup.py +++ /dev/null @@ -1,11 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name='JaxPM', - version='0.0.1', - url='https://github.com/DifferentiableUniverseInitiative/JaxPM', - author='JaxPM developers', - description='A dead simple FastPM implementation in JAX', - packages=find_packages(), - install_requires=['jax', 'jax_cosmo'], -) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6d91684 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,175 @@ +# Parameterized fixture for mesh_shape +import os + +import pytest + +os.environ["EQX_ON_ERROR"] = "nan" +setup_done = False +on_cluster = False + + +def is_on_cluster(): + global 
on_cluster + return on_cluster + + +def initialize_distributed(): + global setup_done + global on_cluster + if not setup_done: + if "SLURM_JOB_ID" in os.environ: + on_cluster = True + print("Running on cluster") + import jax + jax.distributed.initialize() + setup_done = True + on_cluster = True + else: + print("Running locally") + setup_done = True + on_cluster = False + os.environ["JAX_PLATFORM_NAME"] = "cpu" + os.environ[ + "XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax + + +@pytest.fixture( + scope="session", + params=[ + ((32, 32, 32), (256., 256., 256.)), # BOX + ((32, 32, 64), (256., 256., 512.)), # RECTANGULAR + ]) +def simulation_config(request): + return request.param + + +@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8]) +def lpt_scale_factor(request): + return request.param + + +@pytest.fixture(scope="session") +def cosmo(): + from functools import partial + + from jax_cosmo import Cosmology + Planck18 = partial( + Cosmology, + # Omega_m = 0.3111 + Omega_c=0.2607, + Omega_b=0.0490, + Omega_k=0.0, + h=0.6766, + n_s=0.9665, + sigma8=0.8102, + w0=-1.0, + wa=0.0, + ) + + return Planck18() + + +@pytest.fixture(scope="session") +def particle_mesh(simulation_config): + from pmesh.pm import ParticleMesh + mesh_shape, box_shape = simulation_config + return ParticleMesh(BoxSize=box_shape, Nmesh=mesh_shape, dtype='f4') + + +@pytest.fixture(scope="session") +def fpm_initial_conditions(cosmo, particle_mesh): + import jax_cosmo as jc + import numpy as np + from jax import numpy as jnp + + # Generate initial particle positions + grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype( + np.float32) + # Interpolate with linear_matter spectrum to get initial density field + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power(cosmo, k) + + def pk_fn(x): + return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape) + + whitec = particle_mesh.generate_whitenoise(42, + type='complex', + unitary=False) + lineark = 
whitec.apply(lambda k, v: pk_fn(sum(ki**2 for ki in k)**0.5)**0.5 + * v * (1 / v.BoxSize).prod()**0.5) + init_mesh = lineark.c2r().value # XXX + + return lineark, grid, init_mesh + + +@pytest.fixture(scope="session") +def initial_conditions(fpm_initial_conditions): + _, _, init_mesh = fpm_initial_conditions + return init_mesh + + +@pytest.fixture(scope="session") +def solver(cosmo, particle_mesh): + from fastpm.core import Cosmology as FastPMCosmology + from fastpm.core import Solver + ref_cosmo = FastPMCosmology(cosmo) + return Solver(particle_mesh, ref_cosmo, B=1) + + +@pytest.fixture(scope="session") +def fpm_lpt1(solver, fpm_initial_conditions, lpt_scale_factor): + + lineark, grid, _ = fpm_initial_conditions + statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=1) + return statelpt + + +@pytest.fixture(scope="session") +def fpm_lpt1_field(fpm_lpt1, particle_mesh): + return particle_mesh.paint(fpm_lpt1.X).value + + +@pytest.fixture(scope="session") +def fpm_lpt2(solver, fpm_initial_conditions, lpt_scale_factor): + + lineark, grid, _ = fpm_initial_conditions + statelpt = solver.lpt(lineark, grid, lpt_scale_factor, order=2) + return statelpt + + +@pytest.fixture(scope="session") +def fpm_lpt2_field(fpm_lpt2, particle_mesh): + return particle_mesh.paint(fpm_lpt2.X).value + + +@pytest.fixture(scope="session") +def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor): + import numpy as np + from fastpm.core import leapfrog + + if lpt_scale_factor == 0.8: + pytest.skip("Do not run nbody simulation from scale factor 0.8") + + stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True) + + finalstate = solver.nbody(fpm_lpt1, leapfrog(stages)) + fpm_mesh = particle_mesh.paint(finalstate.X).value + + return fpm_mesh + + +@pytest.fixture(scope="session") +def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor): + import numpy as np + from fastpm.core import leapfrog + + if lpt_scale_factor == 0.8: + pytest.skip("Do not run nbody 
simulation from scale factor 0.8") + + stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True) + + finalstate = solver.nbody(fpm_lpt2, leapfrog(stages)) + fpm_mesh = particle_mesh.paint(finalstate.X).value + + return fpm_mesh diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..0b85161 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,13 @@ +import jax.numpy as jnp + + +def MSE(x, y): + return jnp.mean((x - y)**2) + + +def MSE_3D(x, y): + return ((x - y)**2).mean(axis=0) + + +def MSRE(x, y): + return jnp.mean(((x - y) / y)**2) diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py new file mode 100644 index 0000000..6d17939 --- /dev/null +++ b/tests/test_against_fpm.py @@ -0,0 +1,155 @@ +import pytest +from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from helpers import MSE, MSRE +from jax import numpy as jnp + +from jaxpm.distributed import uniform_particles +from jaxpm.painting import cic_paint, cic_paint_dx +from jaxpm.pm import lpt, make_diffrax_ode +from jaxpm.utils import power_spectrum + +_TOLERANCE = 1e-4 +_PM_TOLERANCE = 1e-3 + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, + fpm_lpt1_field, fpm_lpt2_field, cosmo, order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + particles = uniform_particles(mesh_shape) + + # Initial displacement + dx, _, _ = lpt(cosmo, + initial_conditions, + particles, + a=lpt_scale_factor, + order=order) + + fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field + + lpt_field = cic_paint(jnp.zeros(mesh_shape), particles + dx) + _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) + + assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) 
+def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, + fpm_lpt1_field, fpm_lpt2_field, cosmo, order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + # Initial displacement + dx, _, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) + + lpt_field = cic_paint_dx(dx) + + fpm_ref_field = fpm_lpt1_field if order == 1 else fpm_lpt2_field + + _, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) + + assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_nbody_absolute(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + particles = uniform_particles(mesh_shape) + + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=lpt_scale_factor, + order=order) + + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([particles + dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) + + fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 + + _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) + + assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_nbody_relative(simulation_config, initial_conditions, + 
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + + mesh_shape, box_shape = simulation_config + cosmo._workspace = {} + + # Initial displacement + dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + + solver = Dopri5() + controller = PIDController(rtol=1e-9, + atol=1e-9, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint_dx(solutions.ys[-1, 0]) + + fpm_ref_field = nbody_from_lpt1 if order == 1 else nbody_from_lpt2 + + _, jpm_ps = power_spectrum(final_field, box_shape=box_shape) + _, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape) + + assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE + assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py new file mode 100644 index 0000000..fd683ab --- /dev/null +++ b/tests/test_distributed_pm.py @@ -0,0 +1,152 @@ +from conftest import initialize_distributed + +initialize_distributed() # ignore : E402 + +import jax # noqa : E402 +import jax.numpy as jnp # noqa : E402 +import pytest # noqa : E402 +from diffrax import SaveAt # noqa : E402 +from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve +from helpers import MSE # noqa : E402 +from jax import lax # noqa : E402 +from jax.experimental.multihost_utils import process_allgather # noqa : E402 +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P # noqa : E402 + +from jaxpm.distributed import uniform_particles # noqa : E402 +from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 +from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 + +_TOLERANCE = 3.0 # 🙃🙃 + + +@pytest.mark.distributed 
+@pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("absolute_painting", [True, False]) +def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, + absolute_painting): + + mesh_shape, box_shape = simulation_config + # SINGLE DEVICE RUN + cosmo._workspace = {} + if absolute_painting: + particles = uniform_particles(mesh_shape) + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=0.1, + order=order) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + y0 = jnp.stack([particles + dx, p]) + else: + dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=order) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + y0 = jnp.stack([dx, p]) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + single_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), + solutions.ys[-1, 0]) + else: + single_device_final_field = cic_paint_dx(solutions.ys[-1, 0]) + + print("Done with single device run") + # MULTI DEVICE RUN + + mesh = jax.make_mesh((1, 8), ('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + halo_size = mesh_shape[0] // 2 + + initial_conditions = lax.with_sharding_constraint(initial_conditions, + sharding) + + print(f"sharded initial conditions {initial_conditions.sharding}") + + cosmo._workspace = {} + if absolute_painting: + particles = uniform_particles(mesh_shape, sharding=sharding) + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, + mesh_shape, + halo_size=halo_size, + sharding=sharding)) + + y0 = jnp.stack([particles + dx, p]) + else: + dx, p, _ = 
lpt(cosmo, + initial_conditions, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, + mesh_shape, + paint_absolute_pos=False, + halo_size=halo_size, + sharding=sharding)) + y0 = jnp.stack([dx, p]) + + solver = Dopri5() + controller = PIDController(rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=0.1, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), + solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + else: + multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) + + multi_device_final_field = process_allgather(multi_device_final_field, + tiled=True) + + mse = MSE(single_device_final_field, multi_device_final_field) + print(f"MSE is {mse}") + + assert mse < _TOLERANCE