Add aboucaud comments

This commit is contained in:
Wassim Kabalan 2024-12-08 23:11:11 +01:00
parent adaf7d236d
commit d8c68ace7a
10 changed files with 26 additions and 777 deletions

View file

@ -10,13 +10,13 @@ import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P
def autoshmap(
f: Callable,
gpu_mesh: Mesh | None,
gpu_mesh: Mesh | AbstractMesh | None,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = False,
@ -122,7 +122,7 @@ def get_local_shape(mesh_shape, sharding=None):
]
def __axis_names(spec):
def _axis_names(spec):
if len(spec) == 1:
x_axis, = spec
y_axis = None
@ -147,7 +147,7 @@ def uniform_particles(mesh_shape, sharding=None):
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
x_axis, y_axis, single_axis = _axis_names(spec)
def particles():
x_indx = lax.axis_index(x_axis)
@ -178,7 +178,7 @@ def normal_field(mesh_shape, seed, sharding=None):
# to make the code work both in multi host and single controller we can do this trick
keys = jax.random.split(seed, size)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
x_axis, y_axis, single_axis = _axis_names(spec)
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)