mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
Shared operation in fourrier space now take inverted sharding axis for
slabs
This commit is contained in:
parent
8c5bd76c33
commit
75604d2396
2 changed files with 23 additions and 8 deletions
|
@ -35,15 +35,26 @@ def autoshmap(f: Callable,
|
|||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
check_rep: bool = True,
|
||||
auto: frozenset[AxisName] = frozenset()):
|
||||
auto: frozenset[AxisName] = frozenset(),
|
||||
in_fourrier_space=False) -> Callable:
|
||||
"""Helper function to wrap the provided function in a shard map if
|
||||
the code is being executed in a mesh context."""
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if mesh.empty:
|
||||
return f
|
||||
else:
|
||||
if in_fourrier_space and 1 in mesh.devices.shape:
|
||||
in_specs , out_specs = switch_specs((in_specs , out_specs))
|
||||
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||
|
||||
def switch_specs(specs):
|
||||
if isinstance(specs, P):
|
||||
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
|
||||
return P(*new_axes)
|
||||
elif isinstance(specs, tuple):
|
||||
return tuple(switch_specs(sub_spec) for sub_spec in specs)
|
||||
else:
|
||||
raise TypeError("Element must be either a PartitionSpec or a tuple")
|
||||
|
||||
def fft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
|
@ -87,17 +98,21 @@ def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
|||
def slice_unpad_impl(x, pad_width):
|
||||
|
||||
halo_x, _ = pad_width[0]
|
||||
halo_y, _ = pad_width[0]
|
||||
|
||||
halo_y, _ = pad_width[1]
|
||||
# Apply corrections along x
|
||||
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
|
||||
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
|
||||
# Apply corrections along y
|
||||
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||
|
||||
return x[halo_x:-halo_x, halo_y:-halo_y, :]
|
||||
|
||||
|
||||
unpad_slice = [slice(None)] * 3
|
||||
if halo_x > 0:
|
||||
unpad_slice[0] = slice(halo_x , -halo_x)
|
||||
if halo_y > 0:
|
||||
unpad_slice[1] = slice(halo_y , -halo_y)
|
||||
|
||||
return x[tuple(unpad_slice)]
|
||||
|
||||
def slice_pad(x, pad_width):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
|
|
|
@ -46,7 +46,7 @@ def fftk(shape, dtype=np.float32):
|
|||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x'), P('y'), P(None)),
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)))
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True)
|
||||
def get_kvec(ky, kz, kx):
|
||||
return (ky.reshape([-1, 1, 1]),
|
||||
kz.reshape([1, -1, 1]),
|
||||
|
@ -73,7 +73,7 @@ def interpolate_power_spectrum(input, k, pk):
|
|||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input)
|
||||
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
|
|
Loading…
Add table
Reference in a new issue