update distributed ops for new jd

Wassim Kabalan 2025-01-18 01:12:53 +01:00
parent cb2a7ab17f
commit 7c3577ea71


@@ -79,6 +79,9 @@ def slice_unpad_impl(x, pad_width):
     return x[tuple(unpad_slice)]
+def slice_pad_impl(x, pad_width):
+    return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)
 def slice_pad(x, pad_width, sharding):
     gpu_mesh = sharding.mesh if sharding is not None else None
@@ -86,7 +89,7 @@ def slice_pad(x, pad_width, sharding):
             pad_width[0][0] > 0 or pad_width[1][0] > 0):
         assert sharding is not None
         spec = sharding.spec
-        return shard_map((partial(jnp.pad, pad_width=pad_width)),
+        return shard_map((partial(slice_pad_impl, pad_width=pad_width)),
                          mesh=gpu_mesh,
                          in_specs=spec,
                          out_specs=spec)(x)
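
The change swaps the raw jnp.pad for slice_pad_impl, which routes the padding through jax.tree.map so it is applied to every leaf of a pytree rather than only to a single array. Below is a minimal, self-contained sketch of that shard_map + per-leaf padding pattern; the mesh axis name "x", the array sizes, and the helper name pad_impl are illustrative assumptions, not taken from the repository.

    # Sketch of per-leaf padding applied shard-by-shard via shard_map.
    # Assumptions: mesh axis name "x", array sizes, and the helper name
    # `pad_impl` are illustrative and not part of the repository code.
    from functools import partial

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


    def pad_impl(x, pad_width):
        # jax.tree.map applies jnp.pad to every leaf, so x may be a single
        # array or a pytree of arrays.
        return jax.tree.map(lambda leaf: jnp.pad(leaf, pad_width), x)


    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("x",))
    spec = P("x")
    sharding = NamedSharding(mesh, spec)

    # Global array whose leading axis is divisible by the number of devices.
    n_dev = len(devices)
    x = jax.device_put(jnp.arange(4.0 * n_dev).reshape(2 * n_dev, 2), sharding)
    pad_width = ((1, 1), (0, 0))

    # Each device pads only its local shard, mirroring slice_pad in the diff.
    padded = shard_map(partial(pad_impl, pad_width=pad_width),
                       mesh=mesh,
                       in_specs=spec,
                       out_specs=spec)(x)
    print(padded.shape)  # (4 * n_dev, 2): every local shard gains 1+1 rows

With this wrapping, the same shard_map call works whether x is one array or a pytree of arrays, since jax.tree.map pads each leaf of the local shard independently.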