quick fix in distributed

This commit is contained in:
Wassim KABALAN 2024-10-22 12:07:06 -04:00
parent a5b267bd63
commit 56ffd263f6

View file

@ -152,10 +152,10 @@ def normal_field(mesh_shape, seed, sharding):
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)
if single_axis:
y_index = lax.axis_index(y_axis)
x_size = lax.psum(1, axis_name=x_axis)
idx += y_index * x_size
if not single_axis:
y_index = lax.axis_index(y_axis)
x_size = lax.psum(1, axis_name=x_axis)
idx += y_index * x_size
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)