mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-28 16:11:11 +00:00
Merge cf979f3250
into cb2a7ab17f
This commit is contained in:
commit
7834305478
1 changed files with 16 additions and 1 deletions
|
@ -79,6 +79,21 @@ def slice_unpad_impl(x, pad_width):
|
||||||
|
|
||||||
return x[tuple(unpad_slice)]
|
return x[tuple(unpad_slice)]
|
||||||
|
|
||||||
|
def slice_pad_impl(x, pad_width):
|
||||||
|
|
||||||
|
x = jnp.pad(x, pad_width=pad_width)
|
||||||
|
|
||||||
|
halo_x, _ = pad_width[0]
|
||||||
|
halo_y, _ = pad_width[1]
|
||||||
|
|
||||||
|
# Apply corrections along x
|
||||||
|
x = x.at[halo_x // 2:halo_x].add(x[halo_x:halo_x + halo_x // 2])
|
||||||
|
x = x.at[-halo_x:-halo_x // 2].add(x[-(halo_x + halo_x // 2):-halo_x])
|
||||||
|
# Apply corrections along y
|
||||||
|
x = x.at[:, halo_y // 2:halo_y].add(x[:, halo_y:halo_y + halo_y // 2])
|
||||||
|
x = x.at[:, -halo_y:-halo_y // 2].add(x[:, -(halo_y + halo_y // 2):-halo_y])
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
def slice_pad(x, pad_width, sharding):
|
def slice_pad(x, pad_width, sharding):
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
@ -86,7 +101,7 @@ def slice_pad(x, pad_width, sharding):
|
||||||
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
||||||
assert sharding is not None
|
assert sharding is not None
|
||||||
spec = sharding.spec
|
spec = sharding.spec
|
||||||
return shard_map((partial(jnp.pad, pad_width=pad_width)),
|
return shard_map((partial(slice_pad_impl, pad_width=pad_width)),
|
||||||
mesh=gpu_mesh,
|
mesh=gpu_mesh,
|
||||||
in_specs=spec,
|
in_specs=spec,
|
||||||
out_specs=spec)(x)
|
out_specs=spec)(x)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue