mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 19:50:55 +00:00
226 lines
6.4 KiB
Python
226 lines
6.4 KiB
Python
from abc import ABCMeta, abstractmethod
|
|
from dataclasses import dataclass
|
|
from functools import partial
|
|
from inspect import signature
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import jaxdecomp
|
|
import numpy as np
|
|
from jax.experimental.shard_map import shard_map
|
|
from jax.sharding import Mesh, NamedSharding
|
|
from jax.sharding import PartitionSpec as P
|
|
from jaxdecomp import halo_exchange
|
|
from jaxdecomp.fft import pfft3d, pifft3d
|
|
|
|
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
|
|
ShardedOperator, register_operator)
|
|
|
|
|
|
class FFTOperator(CustomPartionedOperator):
    """Forward FFT operator, registered under the name ``'fftn'``.

    Uses ``jnp.fft.fftn`` on a single device and jaxdecomp's distributed
    ``pfft3d`` when running on several devices.
    """

    name = 'fftn'

    def single_gpu_impl(x):
        # Non-distributed FFT over every axis of ``x``.
        transformed = jnp.fft.fftn(x)
        return transformed

    def multi_gpu_impl(x):
        # Distributed 3D FFT provided by jaxdecomp.
        transformed = pfft3d(x)
        return transformed
|
|
|
|
|
|
class IFFTOperator(CustomPartionedOperator):
    """Inverse FFT operator, registered under the name ``'ifftn'``.

    Uses ``jnp.fft.ifftn`` on a single device and jaxdecomp's distributed
    ``pifft3d`` when running on several devices.
    """

    name = 'ifftn'

    def single_gpu_impl(x):
        # Non-distributed inverse FFT over every axis of ``x``.
        restored = jnp.fft.ifftn(x)
        return restored

    def multi_gpu_impl(x):
        # Distributed inverse 3D FFT provided by jaxdecomp.
        restored = pifft3d(x)
        return restored
|
|
|
|
|
|
class HaloExchangeOperator(CustomPartionedOperator):
    """Halo-exchange operator, registered under the name ``'halo_exchange'``.

    The identity on a single device; delegates to jaxdecomp's
    ``halo_exchange`` otherwise.
    """

    name = 'halo_exchange'

    def single_gpu_impl(x):
        # A single GPU has no neighbours to exchange halos with, so this is
        # the identity; inside a jit it will be optimized out.
        return x

    def multi_gpu_impl(x):
        # Let jaxdecomp perform the actual exchange between shards.
        exchanged = halo_exchange(x)
        return exchanged
|
|
|
|
|
|
# The padding and unpadding operators are no-ops on a single GPU, because no
# halo exchange happens there. Inside a jit, these identity calls will be
# optimized out.
|
|
|
|
|
|
class PaddingOperator(ShardedOperator):
    """Padding operator, registered under the name ``'slice_pad'``.

    The identity on a single device; applies ``jnp.pad`` in the multi-device
    implementation (presumably per shard via ``ShardedOperator`` — confirm
    against spmd_config).
    """

    name = 'slice_pad'

    def single_gpu_impl(x, pad_width):
        # No halo exchange on a single GPU, hence nothing to pad.
        return x

    def multi_gpu_impl(x, pad_width):
        # jnp.pad with its default mode, i.e. constant (zero) padding.
        padded = jnp.pad(x, pad_width)
        return padded

    def infer_sharding_from_base_sharding(base_sharding):
        # Inputs: ``x`` follows the base sharding, ``pad_width`` is
        # replicated (empty PartitionSpec). Output keeps the base sharding.
        return (base_sharding, P()), base_sharding
|
|
|
|
|
|
class UnpaddingOperator(ShardedOperator):
    """Unpadding operator, registered under the name ``'slice_unpad'``.

    The identity on a single device. In the multi-device implementation it
    adds the outer half of each x/y halo into the adjacent interior cells and
    then crops the halos away.
    """

    name = 'slice_unpad'

    def single_gpu_impl(x, pad_width):
        # Nothing was padded on a single GPU, so there is nothing to remove.
        return x

    def multi_gpu_impl(x, pad_width):
        """Fold halo contributions inward and strip the halos on axes 0 and 1.

        Args:
          x: array with at least 3 dimensions whose axes 0 and 1 carry a halo
            of ``pad_width[0][0]`` and ``pad_width[1][0]`` cells per side.
          pad_width: ``jnp.pad``-style ``((before, after), ...)`` widths. Only
            the ``before`` widths of the first two axes are read.

        Returns:
          ``x`` with the halo cells removed along axes 0 and 1.
        """
        # WARNING: unequal halo sizes are not supported; the ``before`` width
        # of each axis is used on both sides.
        halo_x, _ = pad_width[0]
        halo_y, _ = pad_width[1]
        half_x = halo_x // 2
        half_y = halo_y // 2
        nx, ny = x.shape[0], x.shape[1]

        # End indices are derived from the axis length instead of negative
        # offsets: ``x[-h // 2:]`` parses as ``x[(-h) // 2:]`` (one cell too
        # many for odd h) and ``x[-0:]`` / ``x[h:-0]`` misbehave for halos
        # of 1 and 0.

        # Apply corrections along x
        x = x.at[halo_x:halo_x + half_x].add(x[:half_x])
        x = x.at[nx - halo_x - half_x:nx - halo_x].add(x[nx - half_x:])
        # Apply corrections along y
        x = x.at[:, halo_y:halo_y + half_y].add(x[:, :half_y])
        x = x.at[:, ny - halo_y - half_y:ny - halo_y].add(x[:, ny - half_y:])

        return x[halo_x:nx - halo_x, halo_y:ny - halo_y, :]

    def infer_sharding_from_base_sharding(base_sharding):
        # ``x`` follows the base sharding, ``pad_width`` is replicated;
        # the result keeps the base sharding.
        in_spec = base_sharding, P()
        out_spec = base_sharding

        return in_spec, out_spec
|
|
|
|
|
|
class NormalFieldOperator(CallBackOperator):
    """Standard-normal field sampler, registered under the name ``'normal'``."""

    name = 'normal'

    def single_gpu_impl(shape, key, dtype='float32'):
        # Sample the whole field in one call.
        return jax.random.normal(key, shape, dtype=dtype)

    def multi_gpu_impl(shape, key, dtype='float32', base_sharding=None):
        """Sample a standard-normal field directly in sharded form.

        Every addressable device samples only its own local block. The key is
        folded with the block's position in the (axis 0, axis 1) shard grid,
        so distinct blocks get independent samples. Devices that hold a
        replica of the same block use the same key and therefore hold
        identical data.

        Args:
          shape: global field shape.
          key: PRNG key.
          dtype: dtype forwarded to ``jax.random.normal``.
          base_sharding: ``NamedSharding`` describing the global layout.

        Returns:
          A global ``jax.Array`` of ``shape`` laid out with ``base_sharding``.
        """
        assert (isinstance(base_sharding, NamedSharding))
        sharding = NormalFieldOperator.shardings_to_use_in_impl(base_sharding)
        global_shape = tuple(shape)

        def get_axis_size(index):
            # Number of mesh devices that split array axis ``index``.
            if index >= len(sharding.spec) or sharding.spec[index] is None:
                return 1
            return sharding.mesh.shape[sharding.spec[index]]

        pdims = [get_axis_size(i) for i in range(2)]
        # Axis i is split pdims[i] ways.
        local_mesh_shape = (global_shape[0] // pdims[0],
                            global_shape[1] // pdims[1], *global_shape[2:])
        unsharded = pdims[0] * pdims[1] == 1

        arrays = []
        for device, index in sharding.addressable_devices_indices_map(
                global_shape).items():
            # Position of this device's block in the shard grid.
            block_0 = (index[0].start or 0) // local_mesh_shape[0]
            block_1 = (index[1].start or 0) // local_mesh_shape[1]
            # Keep the raw key when nothing is split so the result matches
            # single_gpu_impl.
            block_key = key if unsharded else jax.random.fold_in(
                key, block_0 * pdims[1] + block_1)
            local_field = jax.random.normal(block_key,
                                            local_mesh_shape,
                                            dtype=dtype)
            arrays.append(jax.device_put(local_field, device))

        return jax.make_array_from_single_device_arrays(shape=global_shape,
                                                        sharding=sharding,
                                                        arrays=arrays)

    def shardings_to_use_in_impl(base_sharding):
        # The field uses the base sharding unchanged.
        return base_sharding
|
|
|
|
|
|
class FFTKOperator(CallBackOperator):
    """Fourier wave-number operator, registered under the name ``'fftk'``."""

    name = 'fftk'

    def single_gpu_impl(shape, symmetric=True, finite=False, dtype=np.float32):
        """Return one broadcastable wave-number array per axis of ``shape``.

        Args:
          shape: mesh shape.
          symmetric: if True, keep only the first ``n // 2 + 1`` frequencies
            along the last axis.
          finite: accepted for interface compatibility; not used here.
          dtype: dtype of the returned arrays.

        Returns:
          A list of numpy arrays; entry ``d`` holds
          ``2 * pi * np.fft.fftfreq(shape[d])`` (possibly truncated, see
          ``symmetric``), reshaped so it varies only along axis ``d``.
        """
        ndim = len(shape)
        k = []
        for d, n in enumerate(shape):
            kd = 2 * np.pi * np.fft.fftfreq(n)
            if symmetric and d == ndim - 1:
                kd = kd[:n // 2 + 1]
            kd_shape = [1] * ndim
            kd_shape[d] = len(kd)
            k.append(kd.reshape(kd_shape).astype(dtype))
        return k

    def multi_gpu_impl(shape,
                       symmetric=True,
                       finite=False,
                       dtype=np.float32,
                       base_sharding=None):
        """Return the wave numbers with axes 0 and 1 laid out as global arrays.

        The axis-0 and axis-1 vectors become ``jax.Array``s sharded like the
        corresponding axis of ``base_sharding``; the last-axis vector is
        returned as a plain numpy array.
        """
        assert (isinstance(base_sharding, NamedSharding))
        kvec = FFTKOperator.single_gpu_impl(shape, symmetric, finite, dtype)

        # The shardings are derived from the base sharding (not the shape).
        z_sharding, y_sharding = FFTKOperator.shardings_to_use_in_impl(
            base_sharding)

        kz = kvec[0].reshape([-1, 1, 1])
        ky = kvec[1].reshape([1, -1, 1])

        return [
            jax.make_array_from_callback((shape[0], 1, 1),
                                         sharding=z_sharding,
                                         data_callback=lambda idx: kz[idx]),
            jax.make_array_from_callback((1, shape[1], 1),
                                         sharding=y_sharding,
                                         data_callback=lambda idx: ky[idx]),
            kvec[2].reshape([1, 1, -1]),
        ]

    @staticmethod
    def shardings_to_use_in_impl(base_sharding):
        """Return the shardings for the axis-0 and axis-1 wave-number arrays.

        Each keeps only the base partitioning of its own axis. ``NamedSharding``
        requires the mesh as its first argument, so it is taken from
        ``base_sharding``.
        """
        mesh = base_sharding.mesh
        spec = base_sharding.spec
        # A PartitionSpec may list fewer entries than the array has axes.
        axis_0 = spec[0] if len(spec) > 0 else None
        axis_1 = spec[1] if len(spec) > 1 else None

        z_sharding = NamedSharding(mesh, P(axis_0, None, None))
        y_sharding = NamedSharding(mesh, P(None, axis_1, None))

        return z_sharding, y_sharding
|
|
|
|
|
|
class GenerateParticlesOperator(CallBackOperator):
    """Initial particle positions on the mesh grid.

    Registered under the name ``'generate_initial_positions'``.
    """

    name = 'generate_initial_positions'

    def single_gpu_impl(shape):
        """Return an array of shape ``(*shape, len(shape))`` with ``out[idx] == idx``."""
        # indexing='ij' keeps axis d aligned with coordinate d, matching
        # multi_gpu_impl; meshgrid's default 'xy' swaps the first two axes.
        return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in shape],
                                      indexing='ij'),
                         axis=-1)

    def multi_gpu_impl(shape, base_sharding=None):
        """Build the ``(*shape, 3)`` grid-position array directly in sharded form.

        Each shard's callback materializes only the positions inside the
        global slices it owns.
        """
        assert (isinstance(base_sharding, NamedSharding))
        sharding = GenerateParticlesOperator.shardings_to_use_in_impl(
            base_sharding)

        def local_positions(index):
            # ``index`` holds one global slice per output axis; the trailing
            # coordinate axis is assumed unsplit.
            axes = [jnp.arange(shape[d])[index[d]] for d in range(3)]
            return jnp.stack(jnp.meshgrid(*axes, indexing='ij'), axis=-1)

        return jax.make_array_from_callback(shape=(*shape, 3),
                                            sharding=sharding,
                                            data_callback=local_positions)

    def shardings_to_use_in_impl(base_sharding):
        # Positions use the base sharding unchanged.
        return base_sharding
|
|
|
|
|
|
# Register every operator defined above, in definition order.
for _operator in (FFTOperator, IFFTOperator, HaloExchangeOperator,
                  PaddingOperator, UnpaddingOperator, NormalFieldOperator,
                  FFTKOperator, GenerateParticlesOperator):
    register_operator(_operator)
del _operator
|