mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
add a user facing function to create uniform particle grid
This commit is contained in:
parent
a757b62f4b
commit
f3b431aa74
1 changed files with 46 additions and 17 deletions
|
@ -136,7 +136,51 @@ def zeros(mesh_shape, sharding=None):
|
|||
return jnp.zeros(mesh_shape)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding):
|
||||
def __axis_names(spec):
|
||||
if len(spec) == 1:
|
||||
x_axis, = spec
|
||||
y_axis = None
|
||||
single_axis = True
|
||||
elif len(spec) == 2:
|
||||
x_axis, y_axis = spec
|
||||
if y_axis == None:
|
||||
single_axis = True
|
||||
elif x_axis == None:
|
||||
x_axis = y_axis
|
||||
single_axis = True
|
||||
else:
|
||||
single_axis = False
|
||||
else:
|
||||
raise ValueError("Only 1 or 2 axis sharding is supported")
|
||||
return x_axis, y_axis, single_axis
|
||||
|
||||
|
||||
def uniform_particles(mesh_shape, sharding=None):
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
|
||||
def particles():
|
||||
x_indx = lax.axis_index(x_axis)
|
||||
y_indx = 0 if single_axis else lax.axis_index(y_axis)
|
||||
|
||||
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
|
||||
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
|
||||
z = jnp.arange(local_mesh_shape[2])
|
||||
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
|
||||
|
||||
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
|
||||
out_specs=spec)()
|
||||
else:
|
||||
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
|
||||
indexing='ij'),
|
||||
axis=-1)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding=None):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||
|
@ -147,23 +191,8 @@ def normal_field(mesh_shape, seed, sharding):
|
|||
# process_index is multi_host only
|
||||
# to make the code work both in multi host and single controller we can do this trick
|
||||
keys = jax.random.split(seed, size)
|
||||
|
||||
spec = sharding.spec
|
||||
if len(spec) == 1:
|
||||
x_axis, = spec
|
||||
y_axis = None
|
||||
single_axis = True
|
||||
elif len(spec) == 2:
|
||||
x_axis, y_axis = spec
|
||||
if y_axis == None:
|
||||
single_axis = True
|
||||
elif x_axis == None:
|
||||
x_axis = y_axis
|
||||
single_axis = True
|
||||
else:
|
||||
single_axis = False
|
||||
else:
|
||||
raise ValueError("Only 1 or 2 axis sharding is supported")
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
|
||||
def normal(keys, shape, dtype):
|
||||
idx = lax.axis_index(x_axis)
|
||||
|
|
Loading…
Add table
Reference in a new issue