diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 510c002..21694e5 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -108,7 +108,7 @@ def slice_unpad(x, pad_width, sharding): return x -def get_local_shape(mesh_shape, sharding): +def get_local_shape(mesh_shape, sharding=None): """ Helper function to get the local size of a mesh given the global size. """ gpu_mesh = sharding.mesh if sharding is not None else None @@ -122,10 +122,11 @@ def get_local_shape(mesh_shape, sharding): ] -def zeros(mesh_shape, sharding): +def zeros(mesh_shape, sharding=None): gpu_mesh = sharding.mesh if sharding is not None else None if not gpu_mesh is None and not (gpu_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape, sharding) + spec = sharding.spec return shard_map( partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), mesh=gpu_mesh,