get local shape and zeros can be used by users

This commit is contained in:
Wassim KABALAN 2024-10-30 01:55:20 +01:00
parent 19011d0712
commit a757b62f4b

View file

@ -108,7 +108,7 @@ def slice_unpad(x, pad_width, sharding):
return x
def get_local_shape(mesh_shape, sharding):
def get_local_shape(mesh_shape, sharding=None):
""" Helper function to get the local size of a mesh given the global size.
"""
gpu_mesh = sharding.mesh if sharding is not None else None
@ -122,10 +122,11 @@ def get_local_shape(mesh_shape, sharding):
]
def zeros(mesh_shape, sharding):
def zeros(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
return shard_map(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
mesh=gpu_mesh,