Add aboucaud comments

This commit is contained in:
Wassim Kabalan 2024-12-08 23:11:11 +01:00
parent adaf7d236d
commit d8c68ace7a
10 changed files with 26 additions and 777 deletions

View file

@ -10,13 +10,13 @@ import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P
def autoshmap(
f: Callable,
gpu_mesh: Mesh | None,
gpu_mesh: Mesh | AbstractMesh | None,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = False,
@ -122,7 +122,7 @@ def get_local_shape(mesh_shape, sharding=None):
]
def __axis_names(spec):
def _axis_names(spec):
if len(spec) == 1:
x_axis, = spec
y_axis = None
@ -147,7 +147,7 @@ def uniform_particles(mesh_shape, sharding=None):
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
x_axis, y_axis, single_axis = _axis_names(spec)
def particles():
x_indx = lax.axis_index(x_axis)
@ -178,7 +178,7 @@ def normal_field(mesh_shape, seed, sharding=None):
# to make the code work both in multi host and single controller we can do this trick
keys = jax.random.split(seed, size)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
x_axis, y_axis, single_axis = _axis_names(spec)
def normal(keys, shape, dtype):
idx = lax.axis_index(x_axis)

View file

@ -3,6 +3,7 @@ from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
@ -11,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def cic_paint_impl(grid_mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -54,9 +55,9 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
grid_mesh = autoshmap(cic_paint_impl,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
out_specs=spec)(grid_mesh, positions, weight)
@ -68,7 +69,7 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
return grid_mesh
def cic_read_impl(grid_mesh, positions):
def _cic_read_impl(grid_mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [nx,ny,nz, 3]
@ -110,10 +111,10 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
displacement = autoshmap(cic_read_impl,
displacement = autoshmap(_cic_read_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),
out_specs=spec)(grid_mesh, positions)
@ -150,7 +151,7 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -187,9 +188,9 @@ def cic_paint_dx(displacements,
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
grid_mesh = autoshmap(partial(cic_paint_dx_impl,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
@ -204,7 +205,7 @@ def cic_paint_dx(displacements,
return grid_mesh
def cic_read_dx_impl(grid_mesh, disp, halo_size):
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -233,9 +234,9 @@ def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh, disp)