mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Add aboucaud comments
This commit is contained in:
parent
adaf7d236d
commit
d8c68ace7a
10 changed files with 26 additions and 777 deletions
|
@ -10,13 +10,13 @@ import jax.numpy as jnp
|
|||
import jaxdecomp
|
||||
from jax import lax
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import Mesh
|
||||
from jax.sharding import AbstractMesh, Mesh
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
|
||||
def autoshmap(
|
||||
f: Callable,
|
||||
gpu_mesh: Mesh | None,
|
||||
gpu_mesh: Mesh | AbstractMesh | None,
|
||||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
check_rep: bool = False,
|
||||
|
@ -122,7 +122,7 @@ def get_local_shape(mesh_shape, sharding=None):
|
|||
]
|
||||
|
||||
|
||||
def __axis_names(spec):
|
||||
def _axis_names(spec):
|
||||
if len(spec) == 1:
|
||||
x_axis, = spec
|
||||
y_axis = None
|
||||
|
@ -147,7 +147,7 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||
|
||||
def particles():
|
||||
x_indx = lax.axis_index(x_axis)
|
||||
|
@ -178,7 +178,7 @@ def normal_field(mesh_shape, seed, sharding=None):
|
|||
# to make the code work both in multi host and single controller we can do this trick
|
||||
keys = jax.random.split(seed, size)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||
|
||||
def normal(keys, shape, dtype):
|
||||
idx = lax.axis_index(x_axis)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue