From a757b62f4b4ef259000706a147663e5a8b6b7b8c Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 30 Oct 2024 01:55:20 +0100 Subject: [PATCH] get local shape and zeros can be used by users --- jaxpm/distributed.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 510c002..21694e5 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -108,7 +108,7 @@ def slice_unpad(x, pad_width, sharding): return x -def get_local_shape(mesh_shape, sharding): +def get_local_shape(mesh_shape, sharding=None): """ Helper function to get the local size of a mesh given the global size. """ gpu_mesh = sharding.mesh if sharding is not None else None @@ -122,10 +122,11 @@ def get_local_shape(mesh_shape, sharding): ] -def zeros(mesh_shape, sharding): +def zeros(mesh_shape, sharding=None): gpu_mesh = sharding.mesh if sharding is not None else None if not gpu_mesh is None and not (gpu_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape, sharding) + spec = sharding.spec return shard_map( partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), mesh=gpu_mesh,