temp commit

This commit is contained in:
Wassim KABALAN 2024-04-19 01:11:25 +02:00
parent 6ca4c9191e
commit 055ceedb7e
5 changed files with 220 additions and 110 deletions

View file

@ -1,10 +1,10 @@
# Module for custom ops, typically mpi4jax
import jax
import jax.numpy as jnp
import mpi4jax
import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
from functools import partial
@dataclass
class ShardingInfo:
@ -21,22 +21,16 @@ def fft3d(arr, sharding_info=None):
if sharding_info is None:
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
else:
arr = jaxdecomp.pfft3d(arr,
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
arr = jaxdecomp.pfft3d(arr)
return arr
def ifft3d(arr, sharding_info=None):
if sharding_info is None:
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
else:
arr = jaxdecomp.pifft3d(arr,
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
arr = jaxdecomp.pifft3d(arr)
return arr
def halo_reduce(arr, sharding_info=None):
if sharding_info is None:
return arr
@ -44,11 +38,7 @@ def halo_reduce(arr, sharding_info=None):
global_shape = sharding_info.global_shape
arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True),
pdims=sharding_info.pdims,
global_shape=(global_shape[0]+2*halo_size,
global_shape[1]+halo_size,
global_shape[2]))
halo_periods=(True,True,True))
# Apply correction along x
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
@ -70,7 +60,6 @@ def meshgrid3d(shape, sharding_info=None):
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None):
""" Initialize an array of given global shape
partitionned if need be accross dimensions.