add distributed zeros

This commit is contained in:
Wassim KABALAN 2024-10-27 03:48:17 +01:00
parent 4342279817
commit cc4f310508

View file

@ -117,9 +117,20 @@ def get_local_shape(mesh_shape, sharding):
else:
pdims = gpu_mesh.devices.shape
return [
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:]
]
def zeros(mesh_shape , sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
return shard_map(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
mesh=gpu_mesh,
in_specs=(),
out_specs=spec)() # yapf: disable
else:
return jnp.zeros(mesh_shape)
def normal_field(mesh_shape, seed, sharding):
"""Generate a Gaussian random field with the given power spectrum."""