JaxPM/jaxpm/_src/base_ops.py
2024-07-08 00:23:35 +02:00

226 lines
6.4 KiB
Python

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from functools import partial
from inspect import signature
import jax
import jax.numpy as jnp
import jaxdecomp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomp import halo_exchange
from jaxdecomp.fft import pfft3d, pifft3d
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
ShardedOperator, register_operator)
class FFTOperator(CustomPartionedOperator):
name = 'fftn'
def single_gpu_impl(x):
return jnp.fft.fftn(x)
def multi_gpu_impl(x):
return pfft3d(x)
class IFFTOperator(CustomPartionedOperator):
name = 'ifftn'
def single_gpu_impl(x):
return jnp.fft.ifftn(x)
def multi_gpu_impl(x):
return pifft3d(x)
class HaloExchangeOperator(CustomPartionedOperator):
name = 'halo_exchange'
# Halo exchange does nothing on a single GPU
# Inside a jit , this will be optimized out
def single_gpu_impl(x):
return x
def multi_gpu_impl(x):
return halo_exchange(x)
# Padding and unpadding operators should not do anything in case of single GPU
# Since there is no halo exchange for a single GPU
# Inside a jit , this will be optimized out
class PaddingOperator(ShardedOperator):
name = 'slice_pad'
def single_gpu_impl(x, pad_width):
return x
def multi_gpu_impl(x, pad_width):
return jnp.pad(x, pad_width)
def infer_sharding_from_base_sharding(base_sharding):
in_spec = base_sharding, P()
out_spec = base_sharding
return in_spec, out_spec
class UnpaddingOperator(ShardedOperator):
name = 'slice_unpad'
def single_gpu_impl(x, pad_width):
return x
def multi_gpu_impl(x, pad_width):
# WARNING : unequal halo size is not supported
halo_x, _ = pad_width[0]
halo_y, _ = pad_width[0]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
return x[halo_x:-halo_x, halo_y:-halo_y, :]
def infer_sharding_from_base_sharding(base_sharding):
in_spec = base_sharding, P()
out_spec = base_sharding
return in_spec, out_spec
class NormalFieldOperator(CallBackOperator):
name = 'normal'
def single_gpu_impl(shape, key, dtype='float32'):
return jax.random.normal(key, shape, dtype=dtype)
def multi_gpu_impl(shape, key, dtype='float32', base_sharding=None):
assert (isinstance(base_sharding, NamedSharding))
sharding = NormalFieldOperator.shardings_to_use_in_impl(base_sharding)
def get_axis_size(sharding, index):
axis_name = sharding.spec[index]
if axis_name == None:
return 1
else:
return sharding.mesh.shape[sharding.spec[index]]
pdims = [get_axis_size(sharding, i) for i in range(2)]
local_mesh_shape = [
shape[0] // pdims[1], shape[1] // pdims[0], shape[2]
]
return jax.make_array_from_single_device_arrays(
shape=shape,
sharding=sharding,
arrays=[jax.random.normal(key, local_mesh_shape, dtype=dtype)])
def shardings_to_use_in_impl(base_sharding):
return base_sharding
class FFTKOperator(CallBackOperator):
name = 'fftk'
def single_gpu_impl(shape, symmetric=True, finite=False, dtype=np.float32):
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
def multi_gpu_impl(shape,
symmetric=True,
finite=False,
dtype=np.float32,
base_sharding=None):
assert (isinstance(base_sharding, NamedSharding))
kvec = FFTKOperator.single_gpu_impl(shape, symmetric, finite, dtype)
z_sharding, y_sharding = FFTKOperator.shardings_to_use_in_impl(shape)
return [
jax.make_array_from_callback(
(shape[0], 1, 1),
sharding=z_sharding,
data_callback=lambda x: kvec[0].reshape([-1, 1, 1])[x]),
jax.make_array_from_callback(
(1, shape[1], 1),
sharding=y_sharding,
data_callback=lambda x: kvec[1].reshape([1, -1, 1])[x]),
kvec[2].reshape([1, 1, -1])
]
@staticmethod
def shardings_to_use_in_impl(base_sharding):
spec = base_sharding.spec
z_sharding = NamedSharding(P(spec[0], None, None))
y_sharding = NamedSharding(P(None, spec[1], None))
return z_sharding, y_sharding
class GenerateParticlesOperator(CallBackOperator):
name = 'generate_initial_positions'
def single_gpu_impl(shape):
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in shape]),
axis=-1)
def multi_gpu_impl(shape, base_sharding=None):
assert (isinstance(base_sharding, NamedSharding))
sharding = GenerateParticlesOperator.shardings_to_use_in_impl(
base_sharding)
return jax.make_array_from_callback(
shape=tuple([*shape, 3]),
sharding=sharding,
data_callback=lambda x: jnp.stack(jnp.meshgrid(
jnp.arange(shape[0])[x[0]],
jnp.arange(shape[1])[x[1]],
jnp.arange(shape[2]),
indexing='ij'),
axis=-1))
def shardings_to_use_in_impl(base_sharding):
return base_sharding
register_operator(FFTOperator)
register_operator(IFFTOperator)
register_operator(HaloExchangeOperator)
register_operator(PaddingOperator)
register_operator(UnpaddingOperator)
register_operator(NormalFieldOperator)
register_operator(FFTKOperator)
register_operator(GenerateParticlesOperator)