mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
# Module for custom ops with optional distributed implementations (via jaxdecomp)
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import jaxdecomp
|
|
from dataclasses import dataclass
|
|
from typing import Tuple
|
|
from functools import partial
|
|
|
|
@dataclass()
class ShardingInfo:
    """Record describing how arrays are distributed across devices.

    Fields:
      global_shape: shape of the full, undistributed array (three entries).
      pdims: processor-grid dimensions. Elsewhere in this module, global axis 0
        is divided by ``pdims[1]`` and global axis 1 by ``pdims[0]`` to obtain
        the local slab extents.
      halo_extents: halo sizes (three entries); only the first entry is read
        by ``halo_reduce`` in this module.
      rank: defaults to 0; presumably the rank of the current process --
        TODO confirm against callers.
    """

    global_shape: Tuple[int, int, int]  # full array shape
    pdims: Tuple[int, int]  # processor grid
    halo_extents: Tuple[int, int, int]  # halo widths
    rank: int = 0  # defaults to 0
|
|
|
|
|
|
def fft3d(arr, sharding_info=None):
    """Forward 3D FFT; the returned array has its axes transposed.

    Args:
      arr: array to transform.
      sharding_info: ShardingInfo or None. When given, the transform is
        delegated to ``jaxdecomp.pfft3d``.

    Returns:
      Without ``sharding_info``: ``fftn(arr)`` with axes permuted by
      ``(1, 2, 0)``. With it: whatever ``jaxdecomp.pfft3d`` returns.
    """
    if sharding_info is not None:
        # Distributed path: hand the work to jaxdecomp's parallel FFT.
        return jaxdecomp.pfft3d(arr)

    # Single-device path: transform over every axis, then permute the axes.
    # NOTE(review): the permutation presumably mirrors the output layout of
    # jaxdecomp.pfft3d so both paths agree -- TODO confirm.
    spectrum = jnp.fft.fftn(arr)
    return spectrum.transpose([1, 2, 0])
|
|
|
|
def ifft3d(arr, sharding_info=None):
    """Inverse 3D FFT, undoing the axis transposition applied by ``fft3d``.

    Args:
      arr: spectrum to invert.
      sharding_info: ShardingInfo or None. When given, the transform is
        delegated to ``jaxdecomp.pifft3d``.

    Returns:
      Without ``sharding_info``: ``ifftn`` of ``arr`` after permuting its axes
      by ``(2, 0, 1)``, the inverse of the ``(1, 2, 0)`` permutation used in
      the forward transform. With it: whatever ``jaxdecomp.pifft3d`` returns.
    """
    if sharding_info is not None:
        # Distributed path: parallel inverse FFT from jaxdecomp.
        return jaxdecomp.pifft3d(arr)

    # Restore the original axis order first, then invert over every axis.
    restored = arr.transpose([2, 0, 1])
    return jnp.fft.ifftn(restored)
|
|
|
|
def halo_reduce(arr, sharding_info=None):
    """Exchange halos and accumulate halo contributions into adjacent cells.

    Without ``sharding_info`` there is nothing to exchange and ``arr`` is
    returned unchanged (the same object).

    Args:
      arr: local array slab. NOTE(review): assumed to carry ``halo_size``
        padding cells on both sides of its first two axes -- TODO confirm
        against callers.
      sharding_info: ShardingInfo or None. Only ``halo_extents[0]`` is read;
        that single value is used for both the x and y axes.

    Returns:
      The array after ``jaxdecomp.halo_exchange``, with the outermost
      ``halo_size // 2`` planes on each side of axes 0 and 1 added into the
      planes starting ``halo_size`` cells in from that edge.
    """
    if sharding_info is None:
        return arr

    halo_size = sharding_info.halo_extents[0]
    half = halo_size // 2

    arr = jaxdecomp.halo_exchange(arr,
                                  halo_extents=(half, half, 0),
                                  halo_periods=(True, True, True))

    # Slice bounds are explicit non-negative indices. The earlier form
    # ``arr[-halo_size//2:]`` parses as ``(-halo_size)//2``. That selects one
    # extra row for odd halo sizes (shape mismatch with the target slice).
    # When ``half == 0`` it selects the whole array (a ``-0`` start).
    nx, ny = arr.shape[0], arr.shape[1]

    # Apply correction along x
    arr = arr.at[halo_size:halo_size + half].add(arr[:half])
    arr = arr.at[nx - halo_size - half:nx - halo_size].add(arr[nx - half:])

    # Apply correction along y
    arr = arr.at[:, halo_size:halo_size + half].add(arr[:, :half])
    arr = arr.at[:, ny - halo_size - half:ny - halo_size].add(arr[:, ny - half:])

    return arr
|
|
|
|
|
|
def meshgrid3d(shape, sharding_info=None):
    """Return the integer coordinates of every cell of a 3D grid as ``(N, 3)``.

    Args:
      shape: grid shape. Only its last three entries are used, and only when
        ``sharding_info`` is None. (This previously used ``shape[2:]``, which
        produced three axes only for 5-element shapes. For those shapes the
        result is unchanged.)
      sharding_info: ShardingInfo or None. When given, ``shape`` is ignored
        and the local slab extents are ``global_shape[0] // pdims[1]``,
        ``global_shape[1] // pdims[0]`` and ``global_shape[2]``.

    Returns:
      Array of shape ``(N, 3)`` whose rows are ``(i, j, k)`` cell indices.
    """
    if sharding_info is not None:
        coords = [jnp.arange(sharding_info.global_shape[0] // sharding_info.pdims[1]),
                  jnp.arange(sharding_info.global_shape[1] // sharding_info.pdims[0]),
                  jnp.arange(sharding_info.global_shape[2])]
    else:
        coords = [jnp.arange(s) for s in shape[-3:]]

    # jnp.meshgrid defaults to indexing='xy', so the first two axes are swapped
    # in the enumeration order. Each row still holds (i, j, k) along axes 0, 1, 2.
    return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
|
|
|
def zeros(shape, sharding_info=None):
    """Allocate a zero-filled array, partitioned across devices if needed.

    Args:
      shape: array shape, used only when ``sharding_info`` is None.
      sharding_info: ShardingInfo or None. When given, ``shape`` is ignored
        and the local slab shape is used instead. That shape is
        ``global_shape[0] // pdims[1]``, then ``global_shape[1] // pdims[0]``,
        followed by the remaining global axes unchanged.

    Returns:
      A ``jnp.zeros`` array of the selected shape (default dtype).
    """
    if sharding_info is not None:
        gshape = sharding_info.global_shape
        local_shape = [gshape[0] // sharding_info.pdims[1],
                       gshape[1] // sharding_info.pdims[0],
                       *gshape[2:]]
        return jnp.zeros(local_shape)
    return jnp.zeros(shape)
|
|
|
|
|
|
def normal(key, shape, sharding_info=None):
    """Draw standard-normal samples for the given (global) shape.

    Args:
      key: JAX PRNG key passed straight to ``jax.random.normal``.
      shape: sample shape, used only when ``sharding_info`` is None.
      sharding_info: ShardingInfo or None. When given, ``shape`` is ignored
        and the local slab shape is used. That shape is
        ``(global_shape[0] // pdims[1], global_shape[1] // pdims[0],
        global_shape[2])``; any further global axes are not included.

    Returns:
      The array produced by ``jax.random.normal``.
    """
    if sharding_info is not None:
        gshape = sharding_info.global_shape
        local_shape = (gshape[0] // sharding_info.pdims[1],
                       gshape[1] // sharding_info.pdims[0],
                       gshape[2])
        return jax.random.normal(key, local_shape)
    return jax.random.normal(key, shape)
|