diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 1a9e38f..14388ea 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -17,13 +17,13 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None): # ix = sharding_info[0].Get_rank() # ny = sharding_info[1].Get_size() # iy = sharding_info[1].Get_rank() - ix = sharding_info.rank + ix = sharding_info.rank % nx iy = 0 shape = sharding_info.global_shape for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi + kd = jnp.fft.fftfreq(shape[d]) + kd *= 2 * jnp.pi if symmetric and d == len(shape) - 1: kd = kd[:shape[d] // 2 + 1] @@ -38,12 +38,8 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None): return k -@partial(xmap, - in_axes=[['x', 'y', ...], - [['x'], ['y'], [...]]], - out_axes=['x', 'y', ...]) -def apply_gradient_laplace(kfield, kvec): - kx, ky, kz = kvec +@jax.jit +def apply_gradient_laplace(kfield, kx, ky, kz): kk = (kx**2 + ky**2 + kz**2) kernel = jnp.where(kk == 0, 1., 1./kk) return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), diff --git a/jaxpm/ops.py b/jaxpm/ops.py index ca78b82..41879a5 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -1,10 +1,10 @@ # Module for custom ops, typically mpi4jax import jax import jax.numpy as jnp -import mpi4jax import jaxdecomp from dataclasses import dataclass from typing import Tuple +from functools import partial @dataclass class ShardingInfo: @@ -21,22 +21,16 @@ def fft3d(arr, sharding_info=None): if sharding_info is None: arr = jnp.fft.fftn(arr).transpose([1, 2, 0]) else: - arr = jaxdecomp.pfft3d(arr, - pdims=sharding_info.pdims, - global_shape=sharding_info.global_shape) + arr = jaxdecomp.pfft3d(arr) return arr - def ifft3d(arr, sharding_info=None): if sharding_info is None: arr = jnp.fft.ifftn(arr.transpose([2, 0, 1])) else: - arr = jaxdecomp.pifft3d(arr, - pdims=sharding_info.pdims, - global_shape=sharding_info.global_shape) + arr = jaxdecomp.pifft3d(arr) return arr - def halo_reduce(arr, sharding_info=None): if sharding_info is None: return arr @@ -44,11 +38,7 @@ def halo_reduce(arr, sharding_info=None): global_shape = sharding_info.global_shape arr = jaxdecomp.halo_exchange(arr, halo_extents=(halo_size//2, halo_size//2, 0), - halo_periods=(True,True,True), - pdims=sharding_info.pdims, - global_shape=(global_shape[0]+2*halo_size, - global_shape[1]+halo_size, - global_shape[2])) + halo_periods=(True,True,True)) # Apply correction along x arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2]) @@ -70,7 +60,6 @@ def meshgrid3d(shape, sharding_info=None): return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) - def zeros(shape, sharding_info=None): """ Initialize an array of given global shape partitionned if need be accross dimensions. diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 71c3671..88edd90 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -4,86 +4,124 @@ import jax.lax as lax from jaxpm.ops import halo_reduce from jaxpm.kernels import fftk, cic_compensation +import jaxdecomp +from functools import partial +from jax.sharding import Mesh, PartitionSpec as P,NamedSharding +from jax.experimental.shard_map import shard_map -def cic_paint(mesh, positions, halo_size=0, sharding_info=None): +@partial(jax.jit,static_argnums=(1)) +def add_halo(positions , halo_size): + positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) + return positions + + + +def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] """ + print(f" positions {positions.shape}") if sharding_info is not None: - # Add some padding for the halo exchange - mesh = jnp.pad(mesh, [[halo_size, halo_size], - [halo_size, halo_size], - [0, 0]]) - positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) + @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'), + out_specs=P('z', 'y')) + def sharded_pad(arr): + padded = jnp.pad(arr,pad_width=((halo_size, halo_size), (halo_size, halo_size), (0, 0))) + return padded + + # Add some padding for the halo exchange + with gpu_mesh: + nbody_mesh = sharded_pad(nbody_mesh) + positions = add_halo(positions , halo_size) + + with gpu_mesh: + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + with gpu_mesh: + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - neighboor_coords = jnp.mod(neighboor_coords.reshape( - [-1, 8, 3]).astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = jnp.mod(neighboor_coords.reshape( + [-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape)) dnums = jax.lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, - neighboor_coords, - kernel.reshape([-1, 8]), - dnums) + + with gpu_mesh: + nbody_mesh = lax.scatter_add(nbody_mesh, + neighboor_coords, + kernel.reshape([-1, 8]), + dnums) if sharding_info == None: - return mesh + return nbody_mesh else: - mesh = halo_reduce(mesh, sharding_info) - return mesh[halo_size:-halo_size, halo_size:-halo_size] + with gpu_mesh : + nbody_mesh = halo_reduce(nbody_mesh, sharding_info) + nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size] + + return nbody_mesh + +@jax.jit +def reduce_and_sum(mesh,neighboor_coords,kernel): + return (mesh[neighboor_coords[..., 0], + neighboor_coords[..., 1], + neighboor_coords[..., 3]]*kernel).sum(axis=-1) -def cic_read(mesh, positions, halo_size=0, sharding_info=None): + +def cic_read(gpu_mesh , mesh, positions, halo_size=0, sharding_info=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] """ + @partial(shard_map, mesh=gpu_mesh, in_specs=(P('z', 'y'),P()), + out_specs=P('z', 'y')) + def sharded_pad(arr , padding_width): + return jnp.pad(arr,pad_width=padding_width) if sharding_info is not None: # Add some padding and perfom hao exchange to retrieve # neighboring regions - mesh = jnp.pad(mesh, [[halo_size, halo_size], - [halo_size, halo_size], - [0, 0]]) + # mesh = halo_reduce(mesh, sharding_info) - import jaxdecomp - mesh = jaxdecomp.halo_exchange(mesh, - halo_extents=sharding_info.halo_extents, - halo_periods=(True,True,True), - pdims=sharding_info.pdims, - global_shape=sharding_info.global_shape) - positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) + with gpu_mesh: + padding_width = jnp.array([(halo_size, halo_size), (halo_size, halo_size), (0, 0)]) + #mesh = sharded_pad(mesh,padding_width) + mesh = jaxdecomp.halo_exchange(mesh, + halo_extents=sharding_info.halo_extents, + halo_periods=(True,True,True)) + positions = add_halo(positions , halo_size) + + with gpu_mesh: + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + with gpu_mesh: + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - neighboor_coords = jnp.mod( - neighboor_coords.astype('int32'), jnp.array(mesh.shape)) + neighboor_coords = jnp.mod( + neighboor_coords.astype('int32'), jnp.array(mesh.shape)) - return (mesh[neighboor_coords[..., 0], - neighboor_coords[..., 1], - neighboor_coords[..., 3]]*kernel).sum(axis=-1) + reduced = reduce_and_sum(mesh,neighboor_coords,kernel) + + return reduced def cic_paint_2d(mesh, positions, weight): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 672f7b6..35532b3 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -8,9 +8,14 @@ from jaxpm.ops import fft3d, ifft3d, zeros, normal from jaxpm.kernels import fftk, apply_gradient_laplace from jaxpm.painting import cic_paint, cic_read from jaxpm.growth import growth_factor, growth_rate, dGfa +from jax.experimental import mesh_utils, multihost_utils +from jax.sharding import Mesh, PartitionSpec as P,NamedSharding +from jax.experimental.shard_map import shard_map +from functools import partial -def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None): + +def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None): """ Computes gravitational forces on particles using a PM scheme """ @@ -22,38 +27,88 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_in # Computes gravitational forces kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info) - forces_k = apply_gradient_laplace(delta_k, kvec) - # Interpolate forces at the position of particles - return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real, - positions, halo_size=halo_size, sharding_info=sharding_info) - for i in range(3)], axis=-1) + local_kx = kvec[0] + local_ky = kvec[1] + replicated_kz = kvec[2] + + gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z')) + gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y')) + + @partial(jax.jit,static_argnums=(1)) + def ifft3d_c2r(forces_k , i): + return ifft3d(forces_k[..., i], sharding_info=sharding_info).real + + forces = [] + with mesh: + forces_k = apply_gradient_laplace(delta_k, gspmd_kx , gspmd_ky , replicated_kz) + # Interpolate forces at the position of particles + + for i in range(3): + with mesh: + ifft_forces = ifft3d_c2r(forces_k , i) + + force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info) + forces.append(force) + print(f"Shape {ifft_forces.shape}") + + return jnp.stack(forces) -def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None): + +def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None): """ Computes first order LPT displacement """ - initial_force = pm_forces( + initial_force = pm_forces(mesh, positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info) a = jnp.atleast_1d(a) - dx = growth_factor(cosmo, a) * initial_force - p = a**2 * growth_rate(cosmo, a) * \ - jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx - f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \ - dGfa(cosmo, a) * initial_force + + print(f"Shape initial {initial_conditions.shape}") + + @jax.jit + def compute_dx(cosmo , i_force): + return growth_factor(cosmo, a) * i_force + + @jax.jit + def compute_p(cosmo , dx): + return a**2 * growth_rate(cosmo, a) * \ + jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx + + @jax.jit + def compute_f(cosmo , initial_force): + return a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \ + dGfa(cosmo, a) * initial_force + + with mesh: + dx = compute_dx(cosmo , initial_force) + p = compute_p(cosmo , dx) + f = compute_f(cosmo , initial_force) + return dx, p, f -def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None): +@jax.jit +def interpolate(kfield, kx, ky, kz , k , pk): + return kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2), k, jnp.sqrt(pk)) + + +def linear_field(cosmo, mesh, mesh_shape, box_size, key, sharding_info=None): """ Generate initial conditions in Fourier space. """ # Sample normal field - field = normal(key, mesh_shape, sharding_info=sharding_info) + pdims = sharding_info.pdims + slice_shape = (mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0],mesh_shape[2]) + + slice_field = normal(key, slice_shape, sharding_info=sharding_info) + + field = multihost_utils.host_local_array_to_global_array( + slice_field, mesh, P('z', 'y')) # Transform to Fourier space - kfield = fft3d(field, sharding_info=sharding_info) + with mesh : + kfield = fft3d(field, sharding_info=sharding_info) # Rescaling k to physical units kvec = [k / box_size[i] * mesh_shape[i] @@ -68,11 +123,17 @@ def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None): ) / (box_size[0] * box_size[1] * box_size[2]) # Multipliyng the field by the proper power spectrum - kfield = xmap(lambda kfield, kx, ky, kz: - kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2), - k, jnp.sqrt(pk)), - in_axes=(('x', 'y', ...), ['x'], ['y'], [...]), - out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2]) + + local_kx = kvec[0] + local_ky = kvec[1] + replicated_kz = kvec[2] + + gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z')) + gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y')) + + + with mesh: + kfield = interpolate(kfield,gspmd_kx, gspmd_ky, replicated_kz ,k, pk) return kfield diff --git a/scripts/test_nbody.py b/scripts/test_nbody.py index 8bf8092..4ce6d95 100644 --- a/scripts/test_nbody.py +++ b/scripts/test_nbody.py @@ -10,19 +10,20 @@ from jaxpm.painting import cic_paint from jax.experimental.ode import odeint import jax_cosmo as jc import time - +from jax.experimental import mesh_utils, multihost_utils +from jax.sharding import Mesh, PartitionSpec as P,NamedSharding +from functools import partial ### Setting up a whole bunch of things ####### # Create communicators world = MPI.COMM_WORLD rank = world.Get_rank() size = world.Get_size() +jax.config.update("jax_enable_x64", True) + # Here we assume clients are on the same node, so we restrict which device # they can use based on their rank -os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1) - - -jaxdecomp.init() +jax.distributed.initialize() # Setup random keys master_key = jax.random.PRNGKey(42) @@ -35,37 +36,57 @@ mesh_shape = (N, N, N) box_size = [500, 500, 500] # Mpc/h halo_size = 32 sharding_info = ShardingInfo(global_shape=mesh_shape, - pdims=(1,2), + pdims=(2,2), halo_extents=(halo_size, halo_size, 0), rank=rank) cosmo = jc.Planck15() a = 0.1 +devices = mesh_utils.create_device_mesh(sharding_info.pdims[::-1]) +mesh = Mesh(devices, axis_names=('z', 'y')) + +initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key, + sharding_info=sharding_info) + @jax.jit -def run_sim(cosmo, key): - initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, - sharding_info=sharding_info) +def ifft3d_c2r(initial_conditions): + return ifft3d(initial_conditions, sharding_info=sharding_info).real - init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real +@jax.jit +def compute_displacement(p , dx): + return p + dx - # Initialize particles - pos = meshgrid3d(mesh_shape, sharding_info=sharding_info) + + +def run_sim(mesh , initial_conditions, cosmo, key): + + with mesh: + init_field = ifft3d_c2r(initial_conditions) + + # Initialize particles + pos = meshgrid3d(mesh_shape, sharding_info=sharding_info) # Initial displacement by LPT cosmo = jc.Planck15() - dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info) + + dx, p, f = lpt(mesh , cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info) # And now, we run an actual nbody - res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info), - [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, - rtol=1e-3, atol=1e-3) - # Painting on a new mesh - field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info), - res[0][-1], halo_size, sharding_info=sharding_info) - - # field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info), - # pos+dx, halo_size, sharding_info=sharding_info) + #res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info), + # [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, + # rtol=1e-3, atol=1e-3) + ## Painting on a new mesh + print(f"shape of p {p.shape}") + print(f"shape of dx {dx.shape}") + with mesh: + displacement = compute_displacement(p , dx) + + empty_field = zeros(mesh_shape, sharding_info=sharding_info) + + field = cic_paint(mesh , empty_field, + displacement, halo_size, sharding_info=sharding_info) + return init_field, field # initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, @@ -87,7 +108,8 @@ def run_sim(cosmo, key): # pos+dx, halo_size, sharding_info=sharding_info) # # Recover the real space initial conditions -init_field, field = run_sim(cosmo, key) +run_sim(mesh , initial_conditions,cosmo, key) +#init_field, field = run_sim(mesh , initial_conditions,cosmo, key) # import jaxdecomp # field = jaxdecomp.halo_exchange(field, @@ -102,6 +124,10 @@ init_field, field = run_sim(cosmo, key) # time2 = time.time() # if rank == 0: -onp.save('simulation_%d.npy'%rank, field) +#onp.save('simulation_%d.npy'%rank, field) # print('Done in', time2-time1) + +print("Done") + +jaxdecomp.finalize() \ No newline at end of file