quick fix in distributed

This commit is contained in:
Wassim KABALAN 2024-10-22 12:15:37 -04:00
parent a8b194f326
commit 0433c615f3

View file

@ -153,9 +153,9 @@ def normal_field(mesh_shape, seed, sharding):
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)
if not single_axis:
y_index = lax.axis_index(y_axis)
x_size = lax.psum(1, axis_name=x_axis)
idx += y_index * x_size
y_index = lax.axis_index(y_axis)
x_size = lax.psum(1, axis_name=x_axis)
idx += y_index * x_size
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)