diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 21694e5..b8d888e 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -136,7 +136,51 @@ def zeros(mesh_shape, sharding=None): return jnp.zeros(mesh_shape) -def normal_field(mesh_shape, seed, sharding): +def __axis_names(spec): + if len(spec) == 1: + x_axis, = spec + y_axis = None + single_axis = True + elif len(spec) == 2: + x_axis, y_axis = spec + if y_axis == None: + single_axis = True + elif x_axis == None: + x_axis = y_axis + single_axis = True + else: + single_axis = False + else: + raise ValueError("Only 1 or 2 axis sharding is supported") + return x_axis, y_axis, single_axis + + +def uniform_particles(mesh_shape, sharding=None): + + gpu_mesh = sharding.mesh if sharding is not None else None + if not gpu_mesh is None and not (gpu_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape, sharding) + spec = sharding.spec + x_axis, y_axis, single_axis = __axis_names(spec) + + def particles(): + x_indx = lax.axis_index(x_axis) + y_indx = 0 if single_axis else lax.axis_index(y_axis) + + x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0] + y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1] + z = jnp.arange(local_mesh_shape[2]) + return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1) + + return shard_map(particles, mesh=gpu_mesh, in_specs=(), + out_specs=spec)() + else: + return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape], + indexing='ij'), + axis=-1) + + +def normal_field(mesh_shape, seed, sharding=None): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if not gpu_mesh is None and not (gpu_mesh.empty): @@ -147,23 +191,8 @@ def normal_field(mesh_shape, seed, sharding): # process_index is multi_host only # to make the code work both in multi host and single controller we can do this trick keys = jax.random.split(seed, size) - spec = sharding.spec - if len(spec) == 1: - x_axis, = spec - y_axis = None - single_axis = True - elif len(spec) == 2: - x_axis, y_axis = spec - if y_axis == None: - single_axis = True - elif x_axis == None: - x_axis = y_axis - single_axis = True - else: - single_axis = False - else: - raise ValueError("Only 1 or 2 axis sharding is supported") + x_axis, y_axis, single_axis = __axis_names(spec) def normal(keys, shape, dtype): idx = lax.axis_index(x_axis)