fix painting issue with read_cic

This commit is contained in:
Wassim KABALAN 2024-10-26 18:52:51 +02:00
parent 0f833f0cb4
commit d2f1eb2fa4

View file

@ -11,17 +11,11 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def cic_paint_impl(grid_mesh, displacement, weight=None):
def cic_paint_impl(grid_mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
"""
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'),
axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
@ -47,7 +41,7 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
return mesh
@partial(jax.jit, static_argnums=(2, 3, 4))
#@partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
@ -66,43 +60,46 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
print(f"shape of grid_mesh: {grid_mesh.shape}")
return grid_mesh
def cic_read_impl(mesh, displacement):
def cic_read_impl(grid_mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement: [nx,ny,nz, 3]
positions: [nx,ny,nz, 3]
"""
# Compute the position of the particles on a regular grid
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'),
axis=-1) + displacement
# Save original shape for reshaping output later
original_shape = positions.shape
# Reshape positions to a flat list of 3D coordinates
positions = positions.reshape([-1, 3])
# Expand dimensions to calculate neighbor coordinates
positions = jnp.expand_dims(positions, 1)
# Floor the positions to get the base grid cell for each particle
floor = jnp.floor(positions)
# Define connections to calculate all neighbor coordinates
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
# Calculate the 8 neighboring coordinates
neighboor_coords = floor + connection
# Calculate kernel weights based on distance from each neighboring coordinate
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(
displacement.shape[:-1])
jnp.array(grid_mesh.shape))
# Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
return (grid_mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
original_shape = positions.shape
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
@ -114,9 +111,8 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),
out_specs=spec)(grid_mesh, positions)
print(f"shape of displacement: {displacement.shape}")
return displacement
return displacement.reshape(original_shape[:-1])
def cic_paint_2d(mesh, positions, weight):