mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
update painting functions to accept pytrees
This commit is contained in:
parent
7c3577ea71
commit
f5755b4b5d
2 changed files with 53 additions and 43 deletions
|
@ -19,37 +19,37 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
positions = positions.reshape([-1, 3])
|
positions = positions.reshape([-1, 3])
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions)
|
||||||
floor = jnp.floor(positions)
|
floor = jax.tree.map(jnp.floor , positions)
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords))
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
if jnp.isscalar(weight):
|
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
|
||||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1)
|
||||||
|
, k) , kernel , weight)
|
||||||
else:
|
else:
|
||||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight)
|
||||||
kernel)
|
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(
|
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords)
|
||||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
|
||||||
jnp.array(grid_mesh.shape))
|
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||||
inserted_window_dims=(0, 1, 2),
|
inserted_window_dims=(0, 1, 2),
|
||||||
scatter_dims_to_operand_dims=(0, 1,
|
scatter_dims_to_operand_dims=(0, 1,
|
||||||
2))
|
2))
|
||||||
mesh = lax.scatter_add(grid_mesh, neighboor_coords,
|
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel)
|
||||||
kernel.reshape([-1, 8]), dnums)
|
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(3, 4))
|
@partial(jax.jit, static_argnums=(3, 4))
|
||||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||||
|
|
||||||
|
positions_structure = jax.tree_structure(positions)
|
||||||
|
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
|
||||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||||
|
|
||||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||||
|
@ -79,25 +79,25 @@ def _cic_read_impl(grid_mesh, positions):
|
||||||
# Reshape positions to a flat list of 3D coordinates
|
# Reshape positions to a flat list of 3D coordinates
|
||||||
positions = positions.reshape([-1, 3])
|
positions = positions.reshape([-1, 3])
|
||||||
# Expand dimensions to calculate neighbor coordinates
|
# Expand dimensions to calculate neighbor coordinates
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions)
|
||||||
# Floor the positions to get the base grid cell for each particle
|
# Floor the positions to get the base grid cell for each particle
|
||||||
floor = jnp.floor(positions)
|
floor = jax.tree.map(jnp.floor , positions)
|
||||||
# Define connections to calculate all neighbor coordinates
|
# Define connections to calculate all neighbor coordinates
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||||
# Calculate the 8 neighboring coordinates
|
# Calculate the 8 neighboring coordinates
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
# Calculate kernel weights based on distance from each neighboring coordinate
|
# Calculate kernel weights based on distance from each neighboring coordinate
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
# Modulo operation to wrap around edges if necessary
|
# Modulo operation to wrap around edges if necessary
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32')
|
||||||
jnp.array(grid_mesh.shape))
|
,jnp.array(grid_mesh.shape)) , neighboor_coords)
|
||||||
|
|
||||||
# Ensure grid_mesh shape is as expected
|
# Ensure grid_mesh shape is as expected
|
||||||
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
||||||
return (grid_mesh[neighboor_coords[..., 0],
|
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel)
|
||||||
neighboor_coords[..., 1],
|
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
||||||
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
|
||||||
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(2, 3))
|
@partial(jax.jit, static_argnums=(2, 3))
|
||||||
|
@ -157,7 +157,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||||
halo_y, _ = halo_size[1]
|
halo_y, _ = halo_size[1]
|
||||||
|
|
||||||
original_shape = displacements.shape
|
original_shape = displacements.shape
|
||||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements)
|
||||||
if not jnp.isscalar(weight):
|
if not jnp.isscalar(weight):
|
||||||
if weight.shape != original_shape[:-1]:
|
if weight.shape != original_shape[:-1]:
|
||||||
raise ValueError("Weight shape must match particle shape")
|
raise ValueError("Weight shape must match particle shape")
|
||||||
|
@ -165,13 +165,13 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
||||||
weight = weight.flatten()
|
weight = weight.flatten()
|
||||||
# Padding is forced to be zero in a single gpu run
|
# Padding is forced to be zero in a single gpu run
|
||||||
|
|
||||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
|
||||||
jnp.arange(particle_mesh.shape[1]),
|
jnp.arange(x.shape[1]),
|
||||||
jnp.arange(particle_mesh.shape[2]),
|
jnp.arange(x.shape[2]),
|
||||||
indexing='ij')
|
indexing='ij') , axis=0), particle_mesh)
|
||||||
|
|
||||||
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
|
||||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||||
return scatter(pmid.reshape([-1, 3]),
|
return scatter(pmid.reshape([-1, 3]),
|
||||||
displacements.reshape([-1, 3]),
|
displacements.reshape([-1, 3]),
|
||||||
particle_mesh,
|
particle_mesh,
|
||||||
|
@ -217,9 +217,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||||
jnp.arange(original_shape[1]),
|
jnp.arange(original_shape[1]),
|
||||||
jnp.arange(original_shape[2]),
|
jnp.arange(original_shape[2]),
|
||||||
indexing='ij')
|
indexing='ij')
|
||||||
|
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||||
|
jnp.arange(original_shape[1]),
|
||||||
|
jnp.arange(original_shape[2]),
|
||||||
|
indexing='ij') , axis=0), grid_mesh)
|
||||||
|
|
||||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||||
|
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
disp = disp.reshape([-1, 3])
|
disp = disp.reshape([-1, 3])
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
new_cell_size, new_shape):
|
new_cell_size, new_shape):
|
||||||
"""Multilinear enmeshing."""
|
"""Multilinear enmeshing."""
|
||||||
base_indices = jnp.asarray(base_indices)
|
base_indices = jax.tree.map(jnp.asarray , base_indices)
|
||||||
displacements = jnp.asarray(displacements)
|
displacements = jax.tree.map(jnp.asarray , displacements)
|
||||||
with jax.experimental.enable_x64():
|
with jax.experimental.enable_x64():
|
||||||
cell_size = jnp.float64(
|
cell_size = jnp.float64(
|
||||||
cell_size) if new_cell_size is not None else jnp.array(
|
cell_size) if new_cell_size is not None else jnp.array(
|
||||||
|
@ -61,7 +61,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
new_displacements = particle_positions - new_indices * new_cell_size
|
new_displacements = particle_positions - new_indices * new_cell_size
|
||||||
|
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
new_displacements -= jnp.rint(
|
new_displacements -= jax.tree.map(jnp.rint ,
|
||||||
new_displacements / grid_length
|
new_displacements / grid_length
|
||||||
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
new_indices %= base_shape
|
new_indices %= base_shape
|
||||||
|
|
||||||
weights = 1 - jnp.abs(new_displacements)
|
weights = 1 - jax.tree.map(jnp.abs , new_displacements)
|
||||||
|
|
||||||
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||||
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||||
|
@ -109,8 +109,11 @@ def _scatter_chunk(carry, chunk):
|
||||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||||
spatial_shape)
|
spatial_shape)
|
||||||
# scatter
|
# scatter
|
||||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
||||||
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
|
mesh_structure = jax.tree_structure(mesh)
|
||||||
|
val_flat = jax.tree.leaves(val)
|
||||||
|
val_tree = jax.tree_unflatten(mesh_structure, val_flat)
|
||||||
|
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
|
||||||
carry = mesh, offset, cell_size, mesh_shape
|
carry = mesh, offset, cell_size, mesh_shape
|
||||||
return carry, None
|
return carry, None
|
||||||
|
|
||||||
|
@ -122,9 +125,10 @@ def scatter(pmid,
|
||||||
val=1.,
|
val=1.,
|
||||||
offset=0,
|
offset=0,
|
||||||
cell_size=1.):
|
cell_size=1.):
|
||||||
|
|
||||||
ptcl_num, spatial_ndim = pmid.shape
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
val = jnp.asarray(val)
|
val = jax.tree.map(jnp.asarray , val)
|
||||||
mesh = jnp.asarray(mesh)
|
mesh = jax.tree.map(jnp.asarray , mesh)
|
||||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
carry = mesh, offset, cell_size, mesh.shape
|
carry = mesh, offset, cell_size, mesh.shape
|
||||||
if remainder is not None:
|
if remainder is not None:
|
||||||
|
@ -147,9 +151,9 @@ def _chunk_cat(remainder_array, chunked_array):
|
||||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||||
ptcl_num, spatial_ndim = pmid.shape
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
|
||||||
mesh = jnp.asarray(mesh)
|
mesh = jax.tree.map(jnp.asarray , mesh)
|
||||||
|
|
||||||
val = jnp.asarray(val)
|
val = jax.tree.map(jnp.asarray , val)
|
||||||
|
|
||||||
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||||
raise ValueError('channel shape mismatch: '
|
raise ValueError('channel shape mismatch: '
|
||||||
|
@ -183,8 +187,11 @@ def _gather_chunk(carry, chunk):
|
||||||
spatial_shape)
|
spatial_shape)
|
||||||
|
|
||||||
# gather
|
# gather
|
||||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
||||||
frac = jnp.expand_dims(frac, chan_axis)
|
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
|
||||||
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
|
ind_structure = jax.tree_structure(ind)
|
||||||
|
frac_structure = jax.tree_structure(frac)
|
||||||
|
mesh_structure = jax.tree_structure(mesh)
|
||||||
|
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
|
||||||
|
|
||||||
return carry, val
|
return carry, val
|
||||||
|
|
Loading…
Add table
Reference in a new issue