diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index bdae517..a31854f 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -152,10 +152,10 @@ def normal_field(mesh_shape, seed, sharding): def normal(keys, shape, dtype): idx = lax.axis_index(x_axis) - if single_axis: - y_index = lax.axis_index(y_axis) - x_size = lax.psum(1, axis_name=x_axis) - idx += y_index * x_size + if not single_axis: + y_index = lax.axis_index(y_axis) + x_size = lax.psum(1, axis_name=x_axis) + idx += y_index * x_size return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)