This commit is contained in:
Wassim KABALAN 2024-04-19 10:32:38 +02:00
parent 055ceedb7e
commit 179030377b
4 changed files with 63 additions and 38 deletions

View file

@@ -5,7 +5,10 @@ import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
from functools import partial
from jax import jit
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
@dataclass
class ShardingInfo:
"""Class for keeping track of the distribution strategy"""
@@ -31,22 +34,41 @@ def ifft3d(arr, sharding_info=None):
arr = jaxdecomp.pifft3d(arr)
return arr
# NOTE(review): this span appears to interleave two versions of halo_reduce
# (a diff view whose +/- markers and indentation were lost). The first
# signature takes `sharding_info`; the second takes `halo_size` and `gpu_mesh`.
# Which statement belongs to which version is not marked here -- confirm
# against the repository before relying on this listing.
def halo_reduce(arr, sharding_info=None):
if sharding_info is None:
# Second signature: the halo width and the device mesh are passed explicitly.
def halo_reduce(arr, halo_size , gpu_mesh):
# The halo exchange below is issued with `gpu_mesh` entered as a context manager.
with gpu_mesh:
# Exchange extents are halo_size//2 on the first two axes and 0 on the third;
# all three axes are flagged periodic.
arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True))
# shard_map wrapper: input and output are both partitioned as P('z', 'y').
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def apply_correction_x(arr):
# Add the first halo_size//2 entries along axis 0 into the slice
# [halo_size, halo_size + halo_size//2).
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
# Mirror image at the far end: add the last halo_size//2 entries into the
# slice [-halo_size - halo_size//2, -halo_size).
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
return arr
# Here the halo width is read from the first entry of sharding_info.halo_extents.
halo_size = sharding_info.halo_extents[0]
# global_shape is assigned but not read by any other line visible in this span.
global_shape = sharding_info.global_shape
# Same halo_exchange call as above, this time outside any mesh context.
arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True))
# shard_map wrapper with the same P('z', 'y') partitioning as apply_correction_x.
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def apply_correction_y(arr):
# Same two additions as apply_correction_x, applied along axis 1.
# The trailing [:, :] is a full slice of the already-sliced operand.
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
return arr
# shard_map wrapper with the same P('z', 'y') partitioning as the two above.
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def un_pad(arr):
# Drop halo_size entries from both ends of the first two axes.
return arr[halo_size:-halo_size, halo_size:-halo_size]
# Apply correction along x
# Inline form of the axis-0 correction (same expressions as apply_correction_x).
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
# Helper form of the same correction.
arr = apply_correction_x(arr)
# Apply correction along y
# Inline form of the axis-1 correction (same expressions as apply_correction_y).
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
# Helper form of the same correction.
arr = apply_correction_y(arr)
# Strip the halo border via un_pad before returning.
arr = un_pad(arr)
return arr
@@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None):
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
# NOTE(review): two signatures appear back to back, and the sharded branch is
# present in two forms (a diff view whose +/- markers and indentation were
# lost). The second signature adds a leading `mesh` parameter. Which lines
# belong to which version is not marked here -- confirm against the repository.
def zeros(shape, sharding_info=None):
def zeros(mesh , shape, sharding_info=None):
""" Initialize an array of given global shape
partitioned if need be across dimensions.
"""
# Without sharding info, return a plain zero array of the requested shape.
if sharding_info is None:
return jnp.zeros(shape)
# Single-expression form of the sharded branch. The slice shape is
# [global_shape[0] // pdims[1], global_shape[1] // pdims[0], *global_shape[2:]];
# note that axis 0 is divided by pdims[1] and axis 1 by pdims[0], and that the
# `shape` argument is not consulted here.
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
# Same slice shape as above, kept as a named intermediate.
zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \
sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
# Hand the slice to multihost_utils together with `mesh` and the P('z', 'y')
# partition spec; presumably this yields the globally sharded array -- confirm
# against the JAX documentation.
gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y'))
return gspmd_zeros
def normal(key, shape, sharding_info=None):