Mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
Synced 2025-04-11 21:50:55 +00:00
add correct annotations for weights in painting and warning for cic_paint in distributed pm
This commit is contained in:
parent
f8325b1c67
commit
91d3292923
1 changed file with 20 additions and 13 deletions
|
@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
|
||||||
from jaxpm.painting_utils import gather, scatter
|
from jaxpm.painting_utils import gather, scatter
|
||||||
|
|
||||||
|
|
||||||
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
def _cic_paint_impl(grid_mesh, positions, weight=1.):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
displacement field: [nx, ny, nz, 3]
|
displacement field: [nx, ny, nz, 3]
|
||||||
|
@ -27,12 +27,11 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
if weight is not None:
|
if jnp.isscalar(weight):
|
||||||
if jnp.isscalar(weight):
|
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
else:
|
||||||
else:
|
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
kernel)
|
||||||
kernel)
|
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(
|
neighboor_coords = jnp.mod(
|
||||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||||
|
@ -48,7 +47,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(3, 4))
|
@partial(jax.jit, static_argnums=(3, 4))
|
||||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
|
||||||
|
|
||||||
|
if sharding is not None:
|
||||||
|
print("""
|
||||||
|
WARNING : absolute painting is not recommended in multi-device mode.
|
||||||
|
Please use relative painting instead.
|
||||||
|
""")
|
||||||
|
|
||||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||||
|
|
||||||
|
@ -57,9 +62,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||||
|
|
||||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||||
|
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||||
|
|
||||||
grid_mesh = autoshmap(_cic_paint_impl,
|
grid_mesh = autoshmap(_cic_paint_impl,
|
||||||
gpu_mesh=gpu_mesh,
|
gpu_mesh=gpu_mesh,
|
||||||
in_specs=(spec, spec, P()),
|
in_specs=(spec, spec, weight_spec),
|
||||||
out_specs=spec)(grid_mesh, positions, weight)
|
out_specs=spec)(grid_mesh, positions, weight)
|
||||||
grid_mesh = halo_exchange(grid_mesh,
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
halo_extents=halo_extents,
|
halo_extents=halo_extents,
|
||||||
|
@ -151,7 +158,7 @@ def cic_paint_2d(mesh, positions, weight):
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
def _cic_paint_dx_impl(displacements, weight=1. , halo_size=0 , chunk_size=2**24):
|
||||||
|
|
||||||
halo_x, _ = halo_size[0]
|
halo_x, _ = halo_size[0]
|
||||||
halo_y, _ = halo_size[1]
|
halo_y, _ = halo_size[1]
|
||||||
|
@ -190,13 +197,13 @@ def cic_paint_dx(displacements,
|
||||||
|
|
||||||
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
|
||||||
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
|
||||||
|
weight_spec = P() if jnp.isscalar(weight) else spec
|
||||||
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
weight=weight,
|
|
||||||
chunk_size=chunk_size),
|
chunk_size=chunk_size),
|
||||||
gpu_mesh=gpu_mesh,
|
gpu_mesh=gpu_mesh,
|
||||||
in_specs=spec,
|
in_specs=(spec, weight_spec),
|
||||||
out_specs=spec)(displacements)
|
out_specs=spec)(displacements , weight)
|
||||||
|
|
||||||
grid_mesh = halo_exchange(grid_mesh,
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
halo_extents=halo_extents,
|
halo_extents=halo_extents,
|
||||||
|
|
Loading…
Add table
Reference in a new issue