apply formating

This commit is contained in:
Wassim KABALAN 2024-10-27 03:50:34 +01:00
parent c93894f561
commit 19011d0712
5 changed files with 22 additions and 15 deletions

View file

@ -117,10 +117,12 @@ def get_local_shape(mesh_shape, sharding):
else:
pdims = gpu_mesh.devices.shape
return [
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:]
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
*mesh_shape[2:]
]
def zeros(mesh_shape , sharding):
def zeros(mesh_shape, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
@ -132,6 +134,7 @@ def zeros(mesh_shape , sharding):
else:
return jnp.zeros(mesh_shape)
def normal_field(mesh_shape, seed, sharding):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None