mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
add a user facing function to create uniform particle grid
This commit is contained in:
parent
a757b62f4b
commit
f3b431aa74
1 changed files with 46 additions and 17 deletions
|
@ -136,7 +136,51 @@ def zeros(mesh_shape, sharding=None):
|
||||||
return jnp.zeros(mesh_shape)
|
return jnp.zeros(mesh_shape)
|
||||||
|
|
||||||
|
|
||||||
def normal_field(mesh_shape, seed, sharding):
|
def __axis_names(spec):
|
||||||
|
if len(spec) == 1:
|
||||||
|
x_axis, = spec
|
||||||
|
y_axis = None
|
||||||
|
single_axis = True
|
||||||
|
elif len(spec) == 2:
|
||||||
|
x_axis, y_axis = spec
|
||||||
|
if y_axis == None:
|
||||||
|
single_axis = True
|
||||||
|
elif x_axis == None:
|
||||||
|
x_axis = y_axis
|
||||||
|
single_axis = True
|
||||||
|
else:
|
||||||
|
single_axis = False
|
||||||
|
else:
|
||||||
|
raise ValueError("Only 1 or 2 axis sharding is supported")
|
||||||
|
return x_axis, y_axis, single_axis
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_particles(mesh_shape, sharding=None):
|
||||||
|
|
||||||
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
|
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||||
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
|
spec = sharding.spec
|
||||||
|
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||||
|
|
||||||
|
def particles():
|
||||||
|
x_indx = lax.axis_index(x_axis)
|
||||||
|
y_indx = 0 if single_axis else lax.axis_index(y_axis)
|
||||||
|
|
||||||
|
x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
|
||||||
|
y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
|
||||||
|
z = jnp.arange(local_mesh_shape[2])
|
||||||
|
return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)
|
||||||
|
|
||||||
|
return shard_map(particles, mesh=gpu_mesh, in_specs=(),
|
||||||
|
out_specs=spec)()
|
||||||
|
else:
|
||||||
|
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
|
||||||
|
indexing='ij'),
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def normal_field(mesh_shape, seed, sharding=None):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
if not gpu_mesh is None and not (gpu_mesh.empty):
|
||||||
|
@ -147,23 +191,8 @@ def normal_field(mesh_shape, seed, sharding):
|
||||||
# process_index is multi_host only
|
# process_index is multi_host only
|
||||||
# to make the code work both in multi host and single controller we can do this trick
|
# to make the code work both in multi host and single controller we can do this trick
|
||||||
keys = jax.random.split(seed, size)
|
keys = jax.random.split(seed, size)
|
||||||
|
|
||||||
spec = sharding.spec
|
spec = sharding.spec
|
||||||
if len(spec) == 1:
|
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||||
x_axis, = spec
|
|
||||||
y_axis = None
|
|
||||||
single_axis = True
|
|
||||||
elif len(spec) == 2:
|
|
||||||
x_axis, y_axis = spec
|
|
||||||
if y_axis == None:
|
|
||||||
single_axis = True
|
|
||||||
elif x_axis == None:
|
|
||||||
x_axis = y_axis
|
|
||||||
single_axis = True
|
|
||||||
else:
|
|
||||||
single_axis = False
|
|
||||||
else:
|
|
||||||
raise ValueError("Only 1 or 2 axis sharding is supported")
|
|
||||||
|
|
||||||
def normal(keys, shape, dtype):
|
def normal(keys, shape, dtype):
|
||||||
idx = lax.axis_index(x_axis)
|
idx = lax.axis_index(x_axis)
|
||||||
|
|
Loading…
Add table
Reference in a new issue