This commit is contained in:
Wassim KABALAN 2024-04-19 10:32:38 +02:00
parent 055ceedb7e
commit 179030377b
4 changed files with 63 additions and 38 deletions

View file

@ -5,7 +5,10 @@ import jaxdecomp
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple
from functools import partial from functools import partial
from jax import jit
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
@dataclass @dataclass
class ShardingInfo: class ShardingInfo:
"""Class for keeping track of the distribution strategy""" """Class for keeping track of the distribution strategy"""
@ -31,25 +34,44 @@ def ifft3d(arr, sharding_info=None):
arr = jaxdecomp.pifft3d(arr) arr = jaxdecomp.pifft3d(arr)
return arr return arr
def halo_reduce(arr, sharding_info=None):
if sharding_info is None:
return arr def halo_reduce(arr, halo_size , gpu_mesh):
halo_size = sharding_info.halo_extents[0]
global_shape = sharding_info.global_shape with gpu_mesh:
arr = jaxdecomp.halo_exchange(arr, arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0), halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True)) halo_periods=(True,True,True))
# Apply correction along x @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def apply_correction_x(arr):
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2]) arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:]) arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
# Apply correction along y return arr
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def apply_correction_y(arr):
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :]) arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :]) arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
return arr return arr
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def un_pad(arr):
return arr[halo_size:-halo_size, halo_size:-halo_size]
# Apply correction along x
arr = apply_correction_x(arr)
# Apply correction along y
arr = apply_correction_y(arr)
arr = un_pad(arr)
return arr
def meshgrid3d(shape, sharding_info=None): def meshgrid3d(shape, sharding_info=None):
if sharding_info is not None: if sharding_info is not None:
@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None):
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None): def zeros(mesh , shape, sharding_info=None):
""" Initialize an array of given global shape """ Initialize an array of given global shape
partitionned if need be accross dimensions. partitionned if need be accross dimensions.
""" """
if sharding_info is None: if sharding_info is None:
return jnp.zeros(shape) return jnp.zeros(shape)
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:])) zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \
sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y'))
return gspmd_zeros
def normal(key, shape, sharding_info=None): def normal(key, shape, sharding_info=None):

View file

@ -1,4 +1,5 @@
import jax import jax
from jax import jit
import jax.numpy as jnp import jax.numpy as jnp
import jax.lax as lax import jax.lax as lax
@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
print(f" positions {positions.shape}")
if sharding_info is not None: if sharding_info is not None:
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'), @partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
# Add some padding for the halo exchange # Add some padding for the halo exchange
with gpu_mesh: with gpu_mesh:
nbody_mesh = sharded_pad(nbody_mesh) nbody_mesh = sharded_pad(nbody_mesh)
positions = add_halo(positions , halo_size) positions = add_halo(positions , halo_size)
with gpu_mesh: with gpu_mesh:
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jit(jnp.floor)(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]]) [0., 1, 1], [1., 1, 1]]])
@jit
def compute_kernels(positions , neighboor_coords):
kernel = (1. - jnp.abs(positions - neighboor_coords))
return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2])
with gpu_mesh: with gpu_mesh:
neighboor_coords = floor + connection neighboor_coords = jit(jnp.add)(floor , connection)
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = compute_kernels(positions , neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.reshape( neighboor_coords = jnp.mod(neighboor_coords.reshape(
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape)) [-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
if sharding_info == None: if sharding_info == None:
return nbody_mesh return nbody_mesh
else: else:
with gpu_mesh : nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh)
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
return nbody_mesh return nbody_mesh

View file

@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info) force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
forces.append(force) forces.append(force)
print(f"Shape {ifft_forces.shape}")
return jnp.stack(forces) return jnp.stack(forces , axis=-1)
@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info) positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
print(f"Shape initial {initial_conditions.shape}")
@jax.jit @jax.jit
def compute_dx(cosmo , i_force): def compute_dx(cosmo , i_force):
@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
p = compute_p(cosmo , dx) p = compute_p(cosmo , dx)
f = compute_f(cosmo , initial_force) f = compute_f(cosmo , initial_force)
return dx, p, f return dx, p, f

View file

@ -1,6 +1,7 @@
from mpi4py import MPI from mpi4py import MPI
import os import os
import jax import jax
from jax import jit
import jax.numpy as jnp import jax.numpy as jnp
import numpy as onp import numpy as onp
import jaxdecomp import jaxdecomp
@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
def ifft3d_c2r(initial_conditions): def ifft3d_c2r(initial_conditions):
return ifft3d(initial_conditions, sharding_info=sharding_info).real return ifft3d(initial_conditions, sharding_info=sharding_info).real
@jax.jit
def compute_displacement(p , dx):
return p + dx
def run_sim(mesh , initial_conditions, cosmo, key): def run_sim(mesh , initial_conditions, cosmo, key):
with mesh: with mesh:
@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key):
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, # [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
# rtol=1e-3, atol=1e-3) # rtol=1e-3, atol=1e-3)
## Painting on a new mesh ## Painting on a new mesh
print(f"shape of p {p.shape}")
print(f"shape of dx {dx.shape}")
with mesh: with mesh:
displacement = compute_displacement(p , dx) displacement = jit(jnp.add)(p , dx)
empty_field = zeros(mesh_shape, sharding_info=sharding_info) empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info)
field = cic_paint(mesh , empty_field, field = cic_paint(mesh , empty_field,
displacement, halo_size, sharding_info=sharding_info) displacement, halo_size, sharding_info=sharding_info)
@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key):
# pos+dx, halo_size, sharding_info=sharding_info) # pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions # # Recover the real space initial conditions
run_sim(mesh , initial_conditions,cosmo, key) init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key) #init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
# import jaxdecomp # import jaxdecomp
@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key)
# init_field.block_until_ready() # init_field.block_until_ready()
# time2 = time.time() # time2 = time.time()
# if rank == 0: # if rank == 0:
#onp.save('simulation_%d.npy'%rank, field) onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16))
onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16))
# print('Done in', time2-time1) # print('Done in', time2-time1)