mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
add distributed zeros
This commit is contained in:
parent
4342279817
commit
cc4f310508
1 changed files with 12 additions and 1 deletions
|
@ -117,9 +117,20 @@ def get_local_shape(mesh_shape, sharding):
|
||||||
else:
|
else:
|
||||||
pdims = gpu_mesh.devices.shape
|
pdims = gpu_mesh.devices.shape
|
||||||
return [
|
return [
|
||||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def zeros(mesh_shape , sharding):
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||||
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
|
return shard_map(
|
||||||
|
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||||
|
mesh=gpu_mesh,
|
||||||
|
in_specs=(),
|
||||||
|
out_specs=spec)() # yapf: disable
|
||||||
|
else:
|
||||||
|
return jnp.zeros(mesh_shape)
|
||||||
|
|
||||||
def normal_field(mesh_shape, seed, sharding):
|
def normal_field(mesh_shape, seed, sharding):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
|
|
Loading…
Add table
Reference in a new issue