From 0433c615f3a60ea11fd392b782b864159dee2317 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 12:15:37 -0400 Subject: [PATCH] quick fix in distributed --- jaxpm/distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index a31854f..05444d9 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -153,9 +153,9 @@ def normal_field(mesh_shape, seed, sharding): def normal(keys, shape, dtype): idx = lax.axis_index(x_axis) if not single_axis: - y_index = lax.axis_index(y_axis) - x_size = lax.psum(1, axis_name=x_axis) - idx += y_index * x_size + y_index = lax.axis_index(y_axis) + x_size = lax.psum(1, axis_name=x_axis) + idx += y_index * x_size return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)