mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
Add aboucaud comments
This commit is contained in:
parent
adaf7d236d
commit
d8c68ace7a
10 changed files with 26 additions and 777 deletions
|
@ -10,13 +10,13 @@ import jax.numpy as jnp
|
|||
import jaxdecomp
|
||||
from jax import lax
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import Mesh
|
||||
from jax.sharding import AbstractMesh, Mesh
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
|
||||
def autoshmap(
|
||||
f: Callable,
|
||||
gpu_mesh: Mesh | None,
|
||||
gpu_mesh: Mesh | AbstractMesh | None,
|
||||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
check_rep: bool = False,
|
||||
|
@ -122,7 +122,7 @@ def get_local_shape(mesh_shape, sharding=None):
|
|||
]
|
||||
|
||||
|
||||
def __axis_names(spec):
|
||||
def _axis_names(spec):
|
||||
if len(spec) == 1:
|
||||
x_axis, = spec
|
||||
y_axis = None
|
||||
|
@ -147,7 +147,7 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||
|
||||
def particles():
|
||||
x_indx = lax.axis_index(x_axis)
|
||||
|
@ -178,7 +178,7 @@ def normal_field(mesh_shape, seed, sharding=None):
|
|||
# to make the code work both in multi host and single controller we can do this trick
|
||||
keys = jax.random.split(seed, size)
|
||||
spec = sharding.spec
|
||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||
x_axis, y_axis, single_axis = _axis_names(spec)
|
||||
|
||||
def normal(keys, shape, dtype):
|
||||
idx = lax.axis_index(x_axis)
|
||||
|
|
|
@ -3,6 +3,7 @@ from functools import partial
|
|||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
||||
|
@ -11,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
|
|||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
||||
def cic_paint_impl(grid_mesh, positions, weight=None):
|
||||
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
|
@ -54,9 +55,9 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
|||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
grid_mesh = autoshmap(cic_paint_impl,
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
grid_mesh = autoshmap(_cic_paint_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec, P()),
|
||||
out_specs=spec)(grid_mesh, positions, weight)
|
||||
|
@ -68,7 +69,7 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
|||
return grid_mesh
|
||||
|
||||
|
||||
def cic_read_impl(grid_mesh, positions):
|
||||
def _cic_read_impl(grid_mesh, positions):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [nx,ny,nz, 3]
|
||||
|
@ -110,10 +111,10 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
|||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
displacement = autoshmap(_cic_read_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec),
|
||||
out_specs=spec)(grid_mesh, positions)
|
||||
|
@ -150,7 +151,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
@ -187,9 +188,9 @@ def cic_paint_dx(displacements,
|
|||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
grid_mesh = autoshmap(partial(cic_paint_dx_impl,
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||
halo_size=halo_size,
|
||||
weight=weight,
|
||||
chunk_size=chunk_size),
|
||||
|
@ -204,7 +205,7 @@ def cic_paint_dx(displacements,
|
|||
return grid_mesh
|
||||
|
||||
|
||||
def cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
@ -233,9 +234,9 @@ def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
|||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||
displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec),
|
||||
out_specs=spec)(grid_mesh, disp)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue