diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index dbbb8fd..4007355 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -117,9 +117,20 @@ def get_local_shape(mesh_shape, sharding): else: pdims = gpu_mesh.devices.shape return [ - mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2] + mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:] ] +def zeros(mesh_shape , sharding): + gpu_mesh = sharding.mesh if sharding is not None else None + if not gpu_mesh is None and not (gpu_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape, sharding) + return shard_map( + partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), + mesh=gpu_mesh, + in_specs=(), + out_specs=spec)() # yapf: disable + else: + return jnp.zeros(mesh_shape) def normal_field(mesh_shape, seed, sharding): """Generate a Gaussian random field with the given power spectrum."""