mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
get local shape and zeros can be used by users
This commit is contained in:
parent
19011d0712
commit
a757b62f4b
1 changed files with 3 additions and 2 deletions
|
@ -108,7 +108,7 @@ def slice_unpad(x, pad_width, sharding):
|
|||
return x
|
||||
|
||||
|
||||
def get_local_shape(mesh_shape, sharding):
|
||||
def get_local_shape(mesh_shape, sharding=None):
|
||||
""" Helper function to get the local size of a mesh given the global size.
|
||||
"""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
|
@ -122,10 +122,11 @@ def get_local_shape(mesh_shape, sharding):
|
|||
]
|
||||
|
||||
|
||||
def zeros(mesh_shape, sharding):
|
||||
def zeros(mesh_shape, sharding=None):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
return shard_map(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
mesh=gpu_mesh,
|
||||
|
|
Loading…
Add table
Reference in a new issue