mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-19 01:20:55 +00:00
get local shape and zeros can be used by users
This commit is contained in:
parent
19011d0712
commit
a757b62f4b
1 changed files with 3 additions and 2 deletions
|
@ -108,7 +108,7 @@ def slice_unpad(x, pad_width, sharding):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_local_shape(mesh_shape, sharding):
|
def get_local_shape(mesh_shape, sharding=None):
|
||||||
""" Helper function to get the local size of a mesh given the global size.
|
""" Helper function to get the local size of a mesh given the global size.
|
||||||
"""
|
"""
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
@ -122,10 +122,11 @@ def get_local_shape(mesh_shape, sharding):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def zeros(mesh_shape, sharding):
|
def zeros(mesh_shape, sharding=None):
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
|
spec = sharding.spec
|
||||||
return shard_map(
|
return shard_map(
|
||||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||||
mesh=gpu_mesh,
|
mesh=gpu_mesh,
|
||||||
|
|
Loading…
Add table
Reference in a new issue