mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
quick fix in distributed
This commit is contained in:
parent
a5b267bd63
commit
56ffd263f6
1 changed files with 4 additions and 4 deletions
|
@ -152,10 +152,10 @@ def normal_field(mesh_shape, seed, sharding):
|
|||
|
||||
def normal(keys, shape, dtype):
|
||||
idx = lax.axis_index(x_axis)
|
||||
if single_axis:
|
||||
y_index = lax.axis_index(y_axis)
|
||||
x_size = lax.psum(1, axis_name=x_axis)
|
||||
idx += y_index * x_size
|
||||
if not single_axis:
|
||||
y_index = lax.axis_index(y_axis)
|
||||
x_size = lax.psum(1, axis_name=x_axis)
|
||||
idx += y_index * x_size
|
||||
|
||||
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue