mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
add halo exchange and slice pad
This commit is contained in:
parent
319942a6bc
commit
ac86468c7c
1 changed files with 42 additions and 2 deletions
|
@ -13,7 +13,8 @@ except ImportError:
|
|||
import jax.numpy as jnp
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
from functools import partial
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
def autoshmap(f: Callable,
|
||||
in_specs: Specs,
|
||||
|
@ -34,13 +35,52 @@ def fft3d(x):
|
|||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||
else:
|
||||
return jnp.fft.rfftn(x)
|
||||
|
||||
|
||||
|
||||
def ifft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pifft3d(x).real
|
||||
else:
|
||||
return jnp.fft.irfftn(x)
|
||||
|
||||
def halo_exchange(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.halo_exchange(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))
|
||||
def slice_pad_impl(x, pad_width):
|
||||
return jnp.pad(x, pad_width)
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))
|
||||
def slice_unpad_impl(x, pad_width):
|
||||
halo_x, _ = pad_width[0]
|
||||
halo_y, _ = pad_width[0]
|
||||
|
||||
# Apply corrections along x
|
||||
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
|
||||
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
|
||||
# Apply corrections along y
|
||||
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||
return x
|
||||
|
||||
def slice_pad(x, pad_width):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return slice_pad_impl(x, pad_width)
|
||||
else:
|
||||
return x
|
||||
|
||||
def slice_unpad(x, pad_width):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return slice_unpad_impl(x, pad_width)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_local_shape(mesh_shape):
|
||||
|
|
Loading…
Add table
Reference in a new issue