mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
update distributed ops for new jd
This commit is contained in:
parent
cb2a7ab17f
commit
7c3577ea71
1 changed files with 4 additions and 1 deletions
|
@ -79,6 +79,9 @@ def slice_unpad_impl(x, pad_width):
|
|||
|
||||
return x[tuple(unpad_slice)]
|
||||
|
||||
def slice_pad_impl(x, pad_width):
|
||||
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)
|
||||
|
||||
|
||||
def slice_pad(x, pad_width, sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
|
@ -86,7 +89,7 @@ def slice_pad(x, pad_width, sharding):
|
|||
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
||||
assert sharding is not None
|
||||
spec = sharding.spec
|
||||
return shard_map((partial(jnp.pad, pad_width=pad_width)),
|
||||
return shard_map((partial(slice_pad_impl, pad_width=pad_width)),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(x)
|
||||
|
|
Loading…
Add table
Reference in a new issue