mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
Allow applying weights with relative cic paint
This commit is contained in:
parent
b3a264ad53
commit
b09580d59e
2 changed files with 28 additions and 12 deletions
|
@ -16,6 +16,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
"""
|
||||
|
||||
positions = positions.reshape([-1, 3])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
|
@ -26,7 +27,11 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||
kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
|
@ -144,14 +149,18 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def cic_paint_dx_impl(displacements, halo_size):
|
||||
def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
else:
|
||||
weight = weight.flatten()
|
||||
# Padding is forced to be zero in a single gpu run
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
|
@ -161,18 +170,28 @@ def cic_paint_dx_impl(displacements, halo_size):
|
|||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh)
|
||||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
chunk_size=2**24,
|
||||
val=weight)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, 2))
|
||||
def cic_paint_dx(displacements, halo_size=0, sharding=None):
|
||||
@partial(jax.jit, static_argnums=(1, 2, 4))
|
||||
def cic_paint_dx(displacements,
|
||||
halo_size=0,
|
||||
sharding=None,
|
||||
weight=1.0,
|
||||
chunk_size=2**24):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
grid_mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||
grid_mesh = autoshmap(partial(cic_paint_dx_impl,
|
||||
halo_size=halo_size,
|
||||
weight=weight,
|
||||
chunk_size=chunk_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=spec,
|
||||
out_specs=spec)(displacements)
|
||||
|
|
|
@ -104,8 +104,7 @@ def _scatter_chunk(carry, chunk):
|
|||
spatial_shape)
|
||||
# scatter
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
mesh = mesh.at[ind].add(val * frac)
|
||||
|
||||
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
|
||||
carry = mesh, offset, cell_size, mesh_shape
|
||||
return carry, None
|
||||
|
||||
|
@ -117,11 +116,9 @@ def scatter(pmid,
|
|||
val=1.,
|
||||
offset=0,
|
||||
cell_size=1.):
|
||||
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
val = jnp.asarray(val)
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
if remainder is not None:
|
||||
|
|
Loading…
Add table
Reference in a new issue