diff --git a/jaxpm/ops.py b/jaxpm/ops.py index 41879a5..0d16b26 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -5,7 +5,10 @@ import jaxdecomp from dataclasses import dataclass from typing import Tuple from functools import partial - +from jax import jit +from jax.experimental import mesh_utils, multihost_utils +from jax.sharding import Mesh, PartitionSpec as P,NamedSharding +from jax.experimental.shard_map import shard_map @dataclass class ShardingInfo: """Class for keeping track of the distribution strategy""" @@ -31,22 +34,41 @@ def ifft3d(arr, sharding_info=None): arr = jaxdecomp.pifft3d(arr) return arr -def halo_reduce(arr, sharding_info=None): - if sharding_info is None: + + +def halo_reduce(arr, halo_size , gpu_mesh): + + with gpu_mesh: + arr = jaxdecomp.halo_exchange(arr, + halo_extents=(halo_size//2, halo_size//2, 0), + halo_periods=(True,True,True)) + + @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y')) + def apply_correction_x(arr): + arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2]) + arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:]) + return arr - halo_size = sharding_info.halo_extents[0] - global_shape = sharding_info.global_shape - arr = jaxdecomp.halo_exchange(arr, - halo_extents=(halo_size//2, halo_size//2, 0), - halo_periods=(True,True,True)) + + @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y')) + def apply_correction_y(arr): + arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :]) + arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :]) + + return arr + + @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y')) + def un_pad(arr): + return arr[halo_size:-halo_size, halo_size:-halo_size] + # Apply correction along x - arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2]) - arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:]) - + arr = apply_correction_x(arr) # Apply correction along y - arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :]) - arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :]) + arr = apply_correction_y(arr) + + arr = un_pad(arr) + return arr @@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None): return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) -def zeros(shape, sharding_info=None): +def zeros(mesh , shape, sharding_info=None): """ Initialize an array of given global shape partitionned if need be accross dimensions. """ if sharding_info is None: return jnp.zeros(shape) - return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:])) + zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \ + sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:])) + + gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y')) + return gspmd_zeros def normal(key, shape, sharding_info=None): diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 88edd90..d8e9492 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,4 +1,5 @@ import jax +from jax import jit import jax.numpy as jnp import jax.lax as lax @@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None): mesh: [nx, ny, nz] positions: [npart, 3] """ - print(f" positions {positions.shape}") if sharding_info is not None: @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'), @@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None): # Add some padding for the halo exchange with gpu_mesh: nbody_mesh = sharded_pad(nbody_mesh) + positions = add_halo(positions , halo_size) with gpu_mesh: positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) + floor = jit(jnp.floor)(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) + + @jit + def compute_kernels(positions , neighboor_coords): + kernel = (1. - jnp.abs(positions - neighboor_coords)) + return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2]) with gpu_mesh: - neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) - kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + neighboor_coords = jit(jnp.add)(floor , connection) + kernel = compute_kernels(positions , neighboor_coords) neighboor_coords = jnp.mod(neighboor_coords.reshape( [-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape)) @@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None): if sharding_info == None: return nbody_mesh else: - with gpu_mesh : - nbody_mesh = halo_reduce(nbody_mesh, sharding_info) - nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size] + nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh) return nbody_mesh diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 35532b3..6c67215 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info) forces.append(force) - print(f"Shape {ifft_forces.shape}") - return jnp.stack(forces) + return jnp.stack(forces , axis=-1) @@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info) a = jnp.atleast_1d(a) - print(f"Shape initial {initial_conditions.shape}") @jax.jit def compute_dx(cosmo , i_force): @@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf p = compute_p(cosmo , dx) f = compute_f(cosmo , initial_force) + + return dx, p, f diff --git a/scripts/test_nbody.py b/scripts/test_nbody.py index 4ce6d95..d405dc9 100644 --- a/scripts/test_nbody.py +++ b/scripts/test_nbody.py @@ -1,6 +1,7 @@ from mpi4py import MPI import os import jax +from jax import jit import jax.numpy as jnp import numpy as onp import jaxdecomp @@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key, def ifft3d_c2r(initial_conditions): return ifft3d(initial_conditions, sharding_info=sharding_info).real -@jax.jit -def compute_displacement(p , dx): - return p + dx - - - def run_sim(mesh , initial_conditions, cosmo, key): with mesh: @@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key): # [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, # rtol=1e-3, atol=1e-3) ## Painting on a new mesh - print(f"shape of p {p.shape}") - print(f"shape of dx {dx.shape}") with mesh: - displacement = compute_displacement(p , dx) + displacement = jit(jnp.add)(p , dx) - empty_field = zeros(mesh_shape, sharding_info=sharding_info) + empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info) field = cic_paint(mesh , empty_field, displacement, halo_size, sharding_info=sharding_info) @@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key): # pos+dx, halo_size, sharding_info=sharding_info) # # Recover the real space initial conditions -run_sim(mesh , initial_conditions,cosmo, key) +init_field, field = run_sim(mesh , initial_conditions,cosmo, key) #init_field, field = run_sim(mesh , initial_conditions,cosmo, key) # import jaxdecomp @@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key) # init_field.block_until_ready() # time2 = time.time() + + # if rank == 0: -#onp.save('simulation_%d.npy'%rank, field) +onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16)) +onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16)) # print('Done in', time2-time1)