mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
updates for FullField
This commit is contained in:
parent
34bf577b1b
commit
9b8e42aa00
2 changed files with 25 additions and 12 deletions
|
@ -140,26 +140,39 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
positions = positions.reshape([-1, 2])
|
||||
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
|
||||
floor = jax.tree.map(jnp.floor, positions)
|
||||
connection = jnp.array([[[0, 0], [1., 0], [0., 1], [1., 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[..., jnp.newaxis]
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||
jnp.array(mesh.shape))
|
||||
if weight is not None:
|
||||
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
|
||||
kernel = jax.tree.map(
|
||||
lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
|
||||
kernel, weight)
|
||||
else:
|
||||
kernel = jax.tree.map(
|
||||
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
|
||||
kernel, weight)
|
||||
|
||||
neighboor_coords = jax.tree.map(
|
||||
lambda nc: jnp.mod(
|
||||
nc.reshape([-1, 4, 2]).astype('int32'), jnp.array(mesh.shape)
|
||||
), neighboor_coords)
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1),
|
||||
scatter_dims_to_operand_dims=(0,
|
||||
1))
|
||||
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
|
||||
dnums)
|
||||
mesh = jax.tree.map(
|
||||
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 4]), dnums),
|
||||
mesh, neighboor_coords, kernel)
|
||||
|
||||
|
||||
return mesh
|
||||
|
||||
|
||||
|
|
|
@ -221,4 +221,4 @@ def gaussian_smoothing(im, sigma):
|
|||
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
|
||||
filter /= filter[0, 0]
|
||||
|
||||
return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
|
||||
return jax.tree.map(lambda im : jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real , im)
|
||||
|
|
Loading…
Add table
Reference in a new issue