update code

This commit is contained in:
Wassim Kabalan 2024-12-06 18:56:24 +01:00
parent e0c118a540
commit 21373b89ee
7 changed files with 84 additions and 100 deletions

View file

@ -82,7 +82,7 @@ def slice_unpad_impl(x, pad_width):
def slice_pad(x, pad_width, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty) and (
if gpu_mesh is not None and not (gpu_mesh.empty) and (
pad_width[0][0] > 0 or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
@ -96,7 +96,7 @@ def slice_pad(x, pad_width, sharding):
def slice_unpad(x, pad_width, sharding):
mesh = sharding.mesh if sharding is not None else None
if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
@ -122,20 +122,6 @@ def get_local_shape(mesh_shape, sharding=None):
]
def zeros(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
return shard_map(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
mesh=gpu_mesh,
in_specs=(),
out_specs=spec)() # yapf: disable
else:
return jnp.zeros(mesh_shape)
def __axis_names(spec):
if len(spec) == 1:
x_axis, = spec
@ -158,7 +144,7 @@ def __axis_names(spec):
def uniform_particles(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
@ -183,7 +169,7 @@ def uniform_particles(mesh_shape, sharding=None):
def normal_field(mesh_shape, seed, sharding=None):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
size = jax.device_count()