mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
quick fix in distributed
This commit is contained in:
parent
a5b267bd63
commit
56ffd263f6
1 changed files with 4 additions and 4 deletions
|
@ -152,10 +152,10 @@ def normal_field(mesh_shape, seed, sharding):
|
||||||
|
|
||||||
def normal(keys, shape, dtype):
|
def normal(keys, shape, dtype):
|
||||||
idx = lax.axis_index(x_axis)
|
idx = lax.axis_index(x_axis)
|
||||||
if single_axis:
|
if not single_axis:
|
||||||
y_index = lax.axis_index(y_axis)
|
y_index = lax.axis_index(y_axis)
|
||||||
x_size = lax.psum(1, axis_name=x_axis)
|
x_size = lax.psum(1, axis_name=x_axis)
|
||||||
idx += y_index * x_size
|
idx += y_index * x_size
|
||||||
|
|
||||||
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue