mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
add distributed zeros
This commit is contained in:
parent
4342279817
commit
cc4f310508
1 changed files with 12 additions and 1 deletions
|
@ -117,9 +117,20 @@ def get_local_shape(mesh_shape, sharding):
|
|||
else:
|
||||
pdims = gpu_mesh.devices.shape
|
||||
return [
|
||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:]
|
||||
]
|
||||
|
||||
def zeros(mesh_shape , sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
return shard_map(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
else:
|
||||
return jnp.zeros(mesh_shape)
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
|
|
Loading…
Add table
Reference in a new issue