mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
update code
This commit is contained in:
parent
e0c118a540
commit
21373b89ee
7 changed files with 84 additions and 100 deletions
|
@ -82,7 +82,7 @@ def slice_unpad_impl(x, pad_width):
|
|||
|
||||
def slice_pad(x, pad_width, sharding):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty) and (
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty) and (
|
||||
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
||||
assert sharding is not None
|
||||
spec = sharding.spec
|
||||
|
@ -96,7 +96,7 @@ def slice_pad(x, pad_width, sharding):
|
|||
|
||||
def slice_unpad(x, pad_width, sharding):
|
||||
mesh = sharding.mesh if sharding is not None else None
|
||||
if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0
|
||||
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
|
||||
or pad_width[1][0] > 0):
|
||||
assert sharding is not None
|
||||
spec = sharding.spec
|
||||
|
@ -122,20 +122,6 @@ def get_local_shape(mesh_shape, sharding=None):
|
|||
]
|
||||
|
||||
|
||||
def zeros(mesh_shape, sharding=None):
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
return shard_map(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
else:
|
||||
return jnp.zeros(mesh_shape)
|
||||
|
||||
|
||||
def __axis_names(spec):
|
||||
if len(spec) == 1:
|
||||
x_axis, = spec
|
||||
|
@ -158,7 +144,7 @@ def __axis_names(spec):
|
|||
def uniform_particles(mesh_shape, sharding=None):
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
|
@ -183,7 +169,7 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
def normal_field(mesh_shape, seed, sharding=None):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
|
||||
size = jax.device_count()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue