mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
make normal_field work with single controller
This commit is contained in:
parent
ff1c5e8362
commit
0ce7219ad8
1 changed files with 18 additions and 9 deletions
|
@ -14,6 +14,7 @@ from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax import lax
|
||||||
from jax._src import mesh as mesh_lib
|
from jax._src import mesh as mesh_lib
|
||||||
from jax.experimental.shard_map import shard_map
|
from jax.experimental.shard_map import shard_map
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
@ -139,19 +140,27 @@ def get_local_shape(mesh_shape):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def normal_field(mesh_shape, seed=None):
|
def normal_field(mesh_shape, seed):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
local_mesh_shape = get_local_shape(mesh_shape)
|
local_mesh_shape = get_local_shape(mesh_shape)
|
||||||
if seed is None:
|
|
||||||
key = None
|
size = jax.device_count()
|
||||||
else:
|
# rank = jax.process_index()
|
||||||
size = jax.process_count()
|
# process_index is multi_host only
|
||||||
rank = jax.process_index()
|
# to make the code work both in multi host and single controller we can do this trick
|
||||||
key = jax.random.split(seed, size)[rank]
|
keys = jax.random.split(seed, size)
|
||||||
|
|
||||||
|
def normal(keys, shape, dtype):
|
||||||
|
x_index = lax.axis_index('x')
|
||||||
|
y_index = lax.axis_index('y')
|
||||||
|
x_size = lax.psum(1, axis_name='x')
|
||||||
|
idx = x_index + y_index * x_size
|
||||||
|
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
|
||||||
|
|
||||||
return autoshmap(
|
return autoshmap(
|
||||||
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
|
partial(normal, shape=local_mesh_shape, dtype='float32'),
|
||||||
in_specs=P(None),
|
in_specs=P(None),
|
||||||
out_specs=P('x', 'y'))(key) # yapf: disable
|
out_specs=P('x', 'y'))(keys) # yapf: disable
|
||||||
else:
|
else:
|
||||||
return jax.random.normal(shape=mesh_shape, key=seed)
|
return jax.random.normal(shape=mesh_shape, key=seed)
|
||||||
|
|
Loading…
Add table
Reference in a new issue