From 7c3577ea715de17bdf75ab75a93ff8f0173145ca Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 18 Jan 2025 01:12:53 +0100 Subject: [PATCH] update distributed ops for new jd --- jaxpm/distributed.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 721a971..af9aecd 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -79,6 +79,9 @@ def slice_unpad_impl(x, pad_width): return x[tuple(unpad_slice)] +def slice_pad_impl(x, pad_width): + return jax.tree.map(lambda x: jnp.pad(x, pad_width), x) + def slice_pad(x, pad_width, sharding): gpu_mesh = sharding.mesh if sharding is not None else None @@ -86,7 +89,7 @@ def slice_pad(x, pad_width, sharding): pad_width[0][0] > 0 or pad_width[1][0] > 0): assert sharding is not None spec = sharding.spec - return shard_map((partial(jnp.pad, pad_width=pad_width)), + return shard_map((partial(slice_pad_impl, pad_width=pad_width)), mesh=gpu_mesh, in_specs=spec, out_specs=spec)(x)