From 56ffd263f682c902c1b844d4c00eac57b18538ac Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 12:07:06 -0400 Subject: [PATCH] quick fix in distributed --- jaxpm/distributed.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index bdae517..a31854f 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -152,10 +152,10 @@ def normal_field(mesh_shape, seed, sharding): def normal(keys, shape, dtype): idx = lax.axis_index(x_axis) - if single_axis: - y_index = lax.axis_index(y_axis) - x_size = lax.psum(1, axis_name=x_axis) - idx += y_index * x_size + if not single_axis: + y_index = lax.axis_index(y_axis) + x_size = lax.psum(1, axis_name=x_axis) + idx += y_index * x_size return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)