diff --git a/dev/jaxdecomp.py b/dev/jaxdecomp.py index 14b249b..ddb19e5 100644 --- a/dev/jaxdecomp.py +++ b/dev/jaxdecomp.py @@ -1,4 +1,5 @@ import argparse + import jax import numpy as np @@ -9,15 +10,17 @@ size = jax.process_count() import jax.numpy as jnp import jax_cosmo as jc -from jaxpm.pm import linear_field, lpt -from jaxpm.painting import cic_paint from jax.experimental import mesh_utils from jax.sharding import Mesh -mesh_shape= [256, 256, 256] -box_size = [256.,256.,256.] +from jaxpm.painting import cic_paint +from jaxpm.pm import linear_field, lpt + +mesh_shape = [256, 256, 256] +box_size = [256., 256., 256.] snapshots = jnp.linspace(0.1, 1., 2) + @jax.jit def run_simulation(omega_c, sigma8, seed): # Create a cosmology @@ -25,38 +28,42 @@ def run_simulation(omega_c, sigma8, seed): # Create a small function to generate the matter power spectrum k = jnp.logspace(-4, 1, 128) - pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) - pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape) + pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk + ).reshape(x.shape) # Create initial conditions initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed) - - # Initialize particle displacements + + # Initialize particle displacements dx, p, f = lpt(cosmo, initial_conditions, 1.0) field = cic_paint(jnp.zeros_like(initial_conditions), dx) return field + def main(args): - # Setting up distributed random numbers - master_key = jax.random.PRNGKey(42) - key = jax.random.split(master_key, size)[rank] + # Setting up distributed random numbers + master_key = jax.random.PRNGKey(42) + key = jax.random.split(master_key, size)[rank] - # Create computing mesh and sharding information - devices = mesh_utils.create_device_mesh((2,2)) - mesh = Mesh(devices.T, axis_names=('x', 'y')) + # Create computing mesh and sharding information + devices = mesh_utils.create_device_mesh((2, 2)) + mesh = Mesh(devices.T, axis_names=('x', 'y')) - # Run the simulation on the compute mesh - with mesh: - field = run_simulation(0.32, 0.8, key) + # Run the simulation on the compute mesh + with mesh: + field = run_simulation(0.32, 0.8, key) + + print('done') + np.save(f'field_{rank}.npy', field.addressable_data(0)) + + # Closing distributed jax + jax.distributed.shutdown() - print('done') - np.save(f'field_{rank}.npy', field.addressable_data(0)) - - # Closing distributed jax - jax.distributed.shutdown() if __name__ == '__main__': - parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") - args = parser.parse_args() - main(args) \ No newline at end of file + parser = argparse.ArgumentParser("Distributed LPT N-body simulation.") + args = parser.parse_args() + main(args) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 9a81440..398e0ed 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -16,11 +16,11 @@ from jax.experimental.shard_map import shard_map def autoshmap(f: Callable, - in_specs: Specs, - out_specs: Specs, - check_rep: bool = True, - auto: frozenset[AxisName] = frozenset()): - """Helper function to wrap the provided function in a shard map if + in_specs: Specs, + out_specs: Specs, + check_rep: bool = True, + auto: frozenset[AxisName] = frozenset()): + """Helper function to wrap the provided function in a shard map if the code is being executed in a mesh context.""" mesh = mesh_lib.thread_resources.env.physical_mesh if mesh.empty: @@ -28,23 +28,28 @@ def autoshmap(f: Callable, else: return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) + def fft3d(x): - if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): return jaxdecomp.pfft3d(x.astype(jnp.complex64)) else: return jnp.fft.rfftn(x) + def ifft3d(x): - if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): return jaxdecomp.pifft3d(x).real else: return jnp.fft.irfftn(x) + def get_local_shape(mesh_shape): - """ Helper function to get the local size of a mesh given the global size. + """ Helper function to get the local size of a mesh given the global size. """ - if mesh_lib.thread_resources.env.physical_mesh.empty: - return mesh_shape - else: - pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape - return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]] \ No newline at end of file + if mesh_lib.thread_resources.env.physical_mesh.empty: + return mesh_shape + else: + pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape + return [ + mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2] + ] diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 64001f5..a4d83d6 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,12 +1,14 @@ -from jaxpm.distributed import autoshmap -from jax.sharding import PartitionSpec as P from functools import partial + import jax.numpy as jnp import numpy as np +from jax.sharding import PartitionSpec as P + +from jaxpm.distributed import autoshmap def fftk(shape, dtype=np.float32): - """ + """ Generate Fourier transform wave numbers for a given mesh. Args: @@ -16,18 +18,19 @@ def fftk(shape, dtype=np.float32): list: List of wave number arrays for each dimension in the order [kx, ky, kz]. """ - kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape] - @partial( - autoshmap, - in_specs=(P('x'), P('y'), P(None)), - out_specs=(P('x'), P(None, 'y'), P(None))) - def get_kvec(ky, kz, kx): - return (ky.reshape([-1, 1, 1]), - kz.reshape([1, -1, 1]), - kx.reshape([1, 1, -1])) # yapf: disable - ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds - # to the order of dimensions in the transposed FFT - return kx, ky, kz + kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape] + + @partial(autoshmap, + in_specs=(P('x'), P('y'), P(None)), + out_specs=(P('x'), P(None, 'y'), P(None))) + def get_kvec(ky, kz, kx): + return (ky.reshape([-1, 1, 1]), + kz.reshape([1, -1, 1]), + kx.reshape([1, 1, -1])) # yapf: disable + ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds + # to the order of dimensions in the transposed FFT + return kx, ky, kz + def gradient_kernel(kvec, direction, order=1): """ diff --git a/jaxpm/painting.py b/jaxpm/painting.py index bacaf46..aadae9e 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,26 +1,28 @@ +from functools import partial + import jax import jax.lax as lax import jax.numpy as jnp - -from jaxpm.kernels import cic_compensation, fftk from jax.sharding import PartitionSpec as P -from functools import partial -from jaxpm.distributed import autoshmap -@partial(autoshmap, - in_specs=(P('x', 'y'), P('x','y'), P('x','y')), - out_specs=P('x', 'y')) +from jaxpm.distributed import autoshmap +from jaxpm.kernels import cic_compensation, fftk + + +@partial(autoshmap, + in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')), + out_specs=P('x', 'y')) def cic_paint(mesh, displacement, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] """ part_shape = displacement.shape - positions = jnp.stack(jnp.meshgrid( - jnp.arange(part_shape[0]), - jnp.arange(part_shape[1]), - jnp.arange(part_shape[2]), - indexing='ij'), axis=-1) + displacement + positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), + jnp.arange(part_shape[1]), + jnp.arange(part_shape[2]), + indexing='ij'), + axis=-1) + displacement positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) @@ -46,9 +48,7 @@ def cic_paint(mesh, displacement, weight=None): return mesh -@partial(autoshmap, - in_specs=(P('x', 'y'), P('x','y')), - out_specs=P('x', 'y')) +@partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y')) def cic_read(mesh, displacement): """ Paints positions onto mesh mesh: [nx, ny, nz] @@ -56,11 +56,11 @@ def cic_read(mesh, displacement): """ # Compute the position of the particles on a regular grid part_shape = displacement.shape - positions = jnp.stack(jnp.meshgrid( - jnp.arange(part_shape[0]), - jnp.arange(part_shape[1]), - jnp.arange(part_shape[2]), - indexing='ij'), axis=-1) + displacement + positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), + jnp.arange(part_shape[1]), + jnp.arange(part_shape[2]), + indexing='ij'), + axis=-1) + displacement positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) @@ -75,7 +75,8 @@ def cic_read(mesh, displacement): jnp.array(mesh.shape)) return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], - neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1]) + neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape( + displacement.shape[:-1]) def cic_paint_2d(mesh, positions, weight): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 30a143c..fe16450 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,15 +1,16 @@ +from functools import partial + import jax import jax.numpy as jnp import jax_cosmo as jc from jax.sharding import PartitionSpec as P +from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d from jaxpm.growth import dGfa, growth_factor, growth_rate from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, longrange_kernel) from jaxpm.painting import cic_paint, cic_read -from jaxpm.distributed import fft3d, ifft3d, autoshmap, get_local_shape -from functools import partial def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): """ @@ -100,28 +101,28 @@ def make_ode_fn(mesh_shape): return nbody_ode -def pgd_correction(pos, params): - """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 - args: - pos: particle positions [npart, 3] - params: [alpha, kl, ks] pgd parameters - """ - kvec = fftk(mesh_shape) +def pgd_correction(pos, params): + """ + improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 + args: + pos: particle positions [npart, 3] + params: [alpha, kl, ks] pgd parameters + """ + kvec = fftk(mesh_shape) - delta = cic_paint(jnp.zeros(mesh_shape), pos) - alpha, kl, ks = params - delta_k = jnp.fft.rfftn(delta) - PGD_range = PGD_kernel(kvec, kl, ks) + delta = cic_paint(jnp.zeros(mesh_shape), pos) + alpha, kl, ks = params + delta_k = jnp.fft.rfftn(delta) + PGD_range = PGD_kernel(kvec, kl, ks) - pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range + pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range - forces_pgd = jnp.stack([ - cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos) - for i in range(3) - ], - axis=-1) + forces_pgd = jnp.stack([ + cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos) + for i in range(3) + ], + axis=-1) - dpos_pgd = forces_pgd * alpha + dpos_pgd = forces_pgd * alpha - return dpos_pgd \ No newline at end of file + return dpos_pgd