add a user facing function to create uniform particle grid

This commit is contained in:
Wassim KABALAN 2024-10-30 01:55:47 +01:00
parent a757b62f4b
commit f3b431aa74

View file

@ -136,7 +136,51 @@ def zeros(mesh_shape, sharding=None):
return jnp.zeros(mesh_shape)
def normal_field(mesh_shape, seed, sharding):
def __axis_names(spec):
if len(spec) == 1:
x_axis, = spec
y_axis = None
single_axis = True
elif len(spec) == 2:
x_axis, y_axis = spec
if y_axis == None:
single_axis = True
elif x_axis == None:
x_axis = y_axis
single_axis = True
else:
single_axis = False
else:
raise ValueError("Only 1 or 2 axis sharding is supported")
return x_axis, y_axis, single_axis
def uniform_particles(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
def particles():
x_indx = lax.axis_index(x_axis)
y_indx = 0 if single_axis else lax.axis_index(y_axis)
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
z = jnp.arange(local_mesh_shape[2])
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
out_specs=spec)()
else:
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
indexing='ij'),
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
@ -147,23 +191,8 @@ def normal_field(mesh_shape, seed, sharding):
# process_index is multi_host only
# to make the code work both in multi host and single controller we can do this trick
keys = jax.random.split(seed, size)
spec = sharding.spec
if len(spec) == 1:
x_axis, = spec
y_axis = None
single_axis = True
elif len(spec) == 2:
x_axis, y_axis = spec
if y_axis == None:
single_axis = True
elif x_axis == None:
x_axis = y_axis
single_axis = True
else:
single_axis = False
else:
raise ValueError("Only 1 or 2 axis sharding is supported")
x_axis, y_axis, single_axis = __axis_names(spec)
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)