Add aboucaud comments

This commit is contained in:
Wassim Kabalan 2024-12-08 23:11:11 +01:00
parent adaf7d236d
commit d8c68ace7a
10 changed files with 26 additions and 777 deletions

View file

@ -3,6 +3,7 @@ from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
@ -11,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def cic_paint_impl(grid_mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -54,9 +55,9 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
grid_mesh = autoshmap(cic_paint_impl,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
grid_mesh = autoshmap(_cic_paint_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec, P()),
out_specs=spec)(grid_mesh, positions, weight)
@ -68,7 +69,7 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
return grid_mesh
def cic_read_impl(grid_mesh, positions):
def _cic_read_impl(grid_mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [nx,ny,nz, 3]
@ -110,10 +111,10 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
displacement = autoshmap(cic_read_impl,
displacement = autoshmap(_cic_read_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),
out_specs=spec)(grid_mesh, positions)
@ -150,7 +151,7 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -187,9 +188,9 @@ def cic_paint_dx(displacements,
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
grid_mesh = autoshmap(partial(cic_paint_dx_impl,
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
halo_size=halo_size,
weight=weight,
chunk_size=chunk_size),
@ -204,7 +205,7 @@ def cic_paint_dx(displacements,
return grid_mesh
def cic_read_dx_impl(grid_mesh, disp, halo_size):
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -233,9 +234,9 @@ def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh, disp)