From 0ce7219ad881ac5ea40d9b440aced0aa25719390 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 21 Oct 2024 13:59:00 -0400 Subject: [PATCH] make normal_field work with single controller --- jaxpm/distributed.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 04263ea..f4fad8a 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -14,6 +14,7 @@ from functools import partial import jax import jax.numpy as jnp +from jax import lax from jax._src import mesh as mesh_lib from jax.experimental.shard_map import shard_map from jax.sharding import PartitionSpec as P @@ -139,19 +140,27 @@ def get_local_shape(mesh_shape): ] -def normal_field(mesh_shape, seed=None): +def normal_field(mesh_shape, seed): """Generate a Gaussian random field with the given power spectrum.""" if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape) - if seed is None: - key = None - else: - size = jax.process_count() - rank = jax.process_index() - key = jax.random.split(seed, size)[rank] + + size = jax.device_count() + # rank = jax.process_index() + # process_index is multi_host only + # to make the code work both in multi host and single controller we can do this trick + keys = jax.random.split(seed, size) + + def normal(keys, shape, dtype): + x_index = lax.axis_index('x') + y_index = lax.axis_index('y') + x_size = lax.psum(1, axis_name='x') + idx = x_index + y_index * x_size + return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype) + return autoshmap( - partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'), + partial(normal, shape=local_mesh_shape, dtype='float32'), in_specs=P(None), - out_specs=P('x', 'y'))(key) # yapf: disable + out_specs=P('x', 'y'))(keys) # yapf: disable else: return jax.random.normal(shape=mesh_shape, key=seed)