mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
temp commit
This commit is contained in:
parent
6ca4c9191e
commit
055ceedb7e
5 changed files with 220 additions and 110 deletions
19
jaxpm/ops.py
19
jaxpm/ops.py
|
@ -1,10 +1,10 @@
|
|||
# Module for custom ops, typically mpi4jax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mpi4jax
|
||||
import jaxdecomp
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
from functools import partial
|
||||
|
||||
@dataclass
|
||||
class ShardingInfo:
|
||||
|
@ -21,22 +21,16 @@ def fft3d(arr, sharding_info=None):
|
|||
if sharding_info is None:
|
||||
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
|
||||
else:
|
||||
arr = jaxdecomp.pfft3d(arr,
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=sharding_info.global_shape)
|
||||
arr = jaxdecomp.pfft3d(arr)
|
||||
return arr
|
||||
|
||||
|
||||
def ifft3d(arr, sharding_info=None):
|
||||
if sharding_info is None:
|
||||
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
|
||||
else:
|
||||
arr = jaxdecomp.pifft3d(arr,
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=sharding_info.global_shape)
|
||||
arr = jaxdecomp.pifft3d(arr)
|
||||
return arr
|
||||
|
||||
|
||||
def halo_reduce(arr, sharding_info=None):
|
||||
if sharding_info is None:
|
||||
return arr
|
||||
|
@ -44,11 +38,7 @@ def halo_reduce(arr, sharding_info=None):
|
|||
global_shape = sharding_info.global_shape
|
||||
arr = jaxdecomp.halo_exchange(arr,
|
||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||
halo_periods=(True,True,True),
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=(global_shape[0]+2*halo_size,
|
||||
global_shape[1]+halo_size,
|
||||
global_shape[2]))
|
||||
halo_periods=(True,True,True))
|
||||
|
||||
# Apply correction along x
|
||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||
|
@ -70,7 +60,6 @@ def meshgrid3d(shape, sharding_info=None):
|
|||
|
||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||
|
||||
|
||||
def zeros(shape, sharding_info=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue