mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
pm ok
This commit is contained in:
parent
055ceedb7e
commit
179030377b
4 changed files with 63 additions and 38 deletions
56
jaxpm/ops.py
56
jaxpm/ops.py
|
@ -5,7 +5,10 @@ import jaxdecomp
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
from functools import partial
|
||||
|
||||
from jax import jit
|
||||
from jax.experimental import mesh_utils, multihost_utils
|
||||
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||
from jax.experimental.shard_map import shard_map
|
||||
@dataclass
|
||||
class ShardingInfo:
|
||||
"""Class for keeping track of the distribution strategy"""
|
||||
|
@ -31,22 +34,41 @@ def ifft3d(arr, sharding_info=None):
|
|||
arr = jaxdecomp.pifft3d(arr)
|
||||
return arr
|
||||
|
||||
def halo_reduce(arr, sharding_info=None):
|
||||
if sharding_info is None:
|
||||
|
||||
|
||||
def halo_reduce(arr, halo_size , gpu_mesh):
|
||||
|
||||
with gpu_mesh:
|
||||
arr = jaxdecomp.halo_exchange(arr,
|
||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||
halo_periods=(True,True,True))
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||
def apply_correction_x(arr):
|
||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||
|
||||
return arr
|
||||
halo_size = sharding_info.halo_extents[0]
|
||||
global_shape = sharding_info.global_shape
|
||||
arr = jaxdecomp.halo_exchange(arr,
|
||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||
halo_periods=(True,True,True))
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||
def apply_correction_y(arr):
|
||||
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||
|
||||
return arr
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||
def un_pad(arr):
|
||||
return arr[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
|
||||
# Apply correction along x
|
||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||
|
||||
arr = apply_correction_x(arr)
|
||||
# Apply correction along y
|
||||
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||
arr = apply_correction_y(arr)
|
||||
|
||||
arr = un_pad(arr)
|
||||
|
||||
|
||||
return arr
|
||||
|
||||
|
@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None):
|
|||
|
||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||
|
||||
def zeros(shape, sharding_info=None):
|
||||
def zeros(mesh , shape, sharding_info=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
"""
|
||||
if sharding_info is None:
|
||||
return jnp.zeros(shape)
|
||||
|
||||
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||
zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \
|
||||
sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||
|
||||
gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y'))
|
||||
return gspmd_zeros
|
||||
|
||||
|
||||
def normal(key, shape, sharding_info=None):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue