JaxPM/jaxpm/ops.py
Wassim KABALAN 055ceedb7e temp commit
2024-04-19 01:11:25 +02:00

81 lines
2.8 KiB
Python

# Module for custom ops (FFTs, halo exchange, sharded array creation); the distributed paths use jaxdecomp
import jax
import jax.numpy as jnp
import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
from functools import partial
@dataclass
class ShardingInfo:
    """Describes how a 3D array is split across processes.

    The helper functions in this module take an optional instance of this
    class; when it is ``None`` they fall back to their single-device path.
    """

    # Full (undistributed) array shape, three integers.
    global_shape: Tuple[int, int, int]

    # Process grid, two integers. The helpers in this module divide
    # ``global_shape[0]`` by ``pdims[1]`` and ``global_shape[1]`` by
    # ``pdims[0]`` to obtain the local slab size.
    pdims: Tuple[int, int]

    # Halo widths, three integers. ``halo_reduce`` only reads entry 0.
    halo_extents: Tuple[int, int, int]

    # Rank of the owning process; defaults to 0.
    rank: int = 0
def fft3d(arr, sharding_info=None):
    """Forward 3D FFT of ``arr``; the result has its axes permuted.

    Without ``sharding_info`` the transform is ``jnp.fft.fftn`` followed by
    the axis permutation ``[1, 2, 0]``. With ``sharding_info`` the call is
    forwarded to ``jaxdecomp.pfft3d``.
    """
    if sharding_info is not None:
        return jaxdecomp.pfft3d(arr)

    spectrum = jnp.fft.fftn(arr)
    return spectrum.transpose([1, 2, 0])
def ifft3d(arr, sharding_info=None):
    """Inverse 3D FFT of an array whose axes are in ``fft3d`` output order.

    Without ``sharding_info`` the axes are first permuted with ``[2, 0, 1]``
    (which undoes the ``[1, 2, 0]`` permutation applied by ``fft3d``) and
    ``jnp.fft.ifftn`` is applied. With ``sharding_info`` the call is
    forwarded to ``jaxdecomp.pifft3d``.
    """
    if sharding_info is not None:
        return jaxdecomp.pifft3d(arr)

    reordered = arr.transpose([2, 0, 1])
    return jnp.fft.ifftn(reordered)
def halo_reduce(arr, sharding_info=None):
    """Run a halo exchange on ``arr`` and add the edge slabs back inward.

    When ``sharding_info`` is ``None`` the array is returned untouched.

    Otherwise the halo width is read from ``sharding_info.halo_extents[0]``
    (the same width is used for the first and second axes).
    ``jaxdecomp.halo_exchange`` is called with half that width on the first
    two axes, zero on the third, and all three axes flagged periodic. Then,
    along axis 0 and afterwards along axis 1, the outermost ``half`` entries
    at each end are added onto the ``half`` entries that start ``halo_size``
    in from that end.

    NOTE(review): this presumably expects ``arr`` to already carry
    ``halo_size`` padding entries on each side of the first two axes --
    confirm against the caller.
    """
    if sharding_info is None:
        return arr

    halo_size = sharding_info.halo_extents[0]
    # Compute the half width once. Writing ``-halo_size // 2`` inline parses
    # as ``(-halo_size) // 2``, which rounds away from zero for odd widths
    # and makes the source slab one entry wider than its destination.
    half = halo_size // 2

    arr = jaxdecomp.halo_exchange(arr,
                                  halo_extents=(half, half, 0),
                                  halo_periods=(True, True, True))

    if half == 0:
        # Nothing to fold back; also avoids ``arr[-0:]`` selecting the
        # whole axis instead of an empty slab.
        return arr

    # Apply correction along x
    arr = arr.at[halo_size:halo_size + half].add(arr[:half])
    arr = arr.at[-halo_size - half:-halo_size].add(arr[-half:])

    # Apply correction along y (operates on the x-corrected array)
    arr = arr.at[:, halo_size:halo_size + half].add(arr[:, :half])
    arr = arr.at[:, -halo_size - half:-halo_size].add(arr[:, -half:])

    return arr
def meshgrid3d(shape, sharding_info=None):
    """Return integer grid coordinates as an array of shape ``(N, 3)``.

    With ``sharding_info`` the grid extents are the local slab size:
    ``global_shape[0] // pdims[1]``, ``global_shape[1] // pdims[0]`` and
    ``global_shape[2]``; ``shape`` is ignored in that case.

    Without ``sharding_info`` the extents are the last three entries of
    ``shape``, so both a plain ``(nx, ny, nz)`` shape and a longer shape with
    leading entries are accepted.

    The coordinate arrays come from ``jnp.meshgrid`` with its default
    indexing, are stacked on a new last axis and flattened to rows of three.
    """
    if sharding_info is not None:
        gshape = sharding_info.global_shape
        pdims = sharding_info.pdims
        extents = (gshape[0] // pdims[1], gshape[1] // pdims[0], gshape[2])
    else:
        # Take the trailing three entries: ``shape[2:]`` would leave a single
        # axis for a 3-entry shape and break the reshape to (-1, 3) below.
        extents = tuple(shape[-3:])

    coords = [jnp.arange(n) for n in extents]
    return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None):
    """Create a zero-filled array for the given global shape.

    Without ``sharding_info`` this is ``jnp.zeros(shape)``. With it, ``shape``
    is ignored and the array has the local slab shape derived from
    ``sharding_info``: the first axis is ``global_shape[0] // pdims[1]``, the
    second is ``global_shape[1] // pdims[0]``, and the remaining axes are
    taken from ``global_shape`` unchanged.
    """
    if sharding_info is not None:
        full = sharding_info.global_shape
        procs = sharding_info.pdims
        local_shape = [full[0] // procs[1], full[1] // procs[0]]
        local_shape += list(full[2:])
        return jnp.zeros(local_shape)

    return jnp.zeros(shape)
def normal(key, shape, sharding_info=None):
    """Draw standard-normal samples for the given global shape.

    Without ``sharding_info`` this is ``jax.random.normal(key, shape)``. With
    it, ``shape`` is ignored and the samples have the local slab shape
    ``[global_shape[0] // pdims[1], global_shape[1] // pdims[0],
    global_shape[2]]``.

    ``key`` is passed to ``jax.random.normal`` as given; ``sharding_info.rank``
    is not consulted here.
    NOTE(review): callers presumably need to supply a distinct key per
    process to get independent slabs -- confirm.
    """
    if sharding_info is None:
        return jax.random.normal(key, shape)

    full = sharding_info.global_shape
    procs = sharding_info.pdims
    local_shape = [full[0] // procs[1], full[1] // procs[0], full[2]]
    return jax.random.normal(key, local_shape)