diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 3083f08..f8797f2 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter -def _cic_paint_impl(grid_mesh, positions, weight=None): +def _cic_paint_impl(grid_mesh, positions, weight=1.): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] @@ -27,12 +27,11 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - if weight is not None: - if jnp.isscalar(weight): - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) - else: - kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), - kernel) + if jnp.isscalar(weight): + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + else: + kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), + kernel) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 8, 3]).astype('int32'), @@ -48,7 +47,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): @partial(jax.jit, static_argnums=(3, 4)) -def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): +def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None): + + if sharding is not None: + print(""" + WARNING : absolute painting is not recommended in multi-device mode. + Please use relative painting instead. + """) positions = positions.reshape((*grid_mesh.shape, 3)) @@ -57,9 +62,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + weight_spec = P() if jnp.isscalar(weight) else spec + grid_mesh = autoshmap(_cic_paint_impl, gpu_mesh=gpu_mesh, - in_specs=(spec, spec, P()), + in_specs=(spec, spec, weight_spec), out_specs=spec)(grid_mesh, positions, weight) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, @@ -151,7 +158,7 @@ def cic_paint_2d(mesh, positions, weight): return mesh -def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): +def _cic_paint_dx_impl(displacements, weight=1. , halo_size=0 , chunk_size=2**24): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] @@ -190,13 +197,13 @@ def cic_paint_dx(displacements, gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + weight_spec = P() if jnp.isscalar(weight) else spec grid_mesh = autoshmap(partial(_cic_paint_dx_impl, halo_size=halo_size, - weight=weight, chunk_size=chunk_size), gpu_mesh=gpu_mesh, - in_specs=spec, - out_specs=spec)(displacements) + in_specs=(spec, weight_spec), + out_specs=spec)(displacements , weight) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents,