mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
fix painting issue with slabs
This commit is contained in:
parent
f25eb7d465
commit
8c5bd76c33
1 changed files with 13 additions and 9 deletions
|
@ -152,7 +152,6 @@ def cic_paint_dx_impl(displacements, halo_size):
|
||||||
indexing='ij')
|
indexing='ij')
|
||||||
|
|
||||||
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
||||||
|
|
||||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh)
|
return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh)
|
||||||
|
@ -166,6 +165,7 @@ def cic_paint_dx(displacements, halo_size=0):
|
||||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||||
in_specs=(P('x', 'y')),
|
in_specs=(P('x', 'y')),
|
||||||
out_specs=P('x', 'y'))(displacements)
|
out_specs=P('x', 'y'))(displacements)
|
||||||
|
|
||||||
mesh = halo_exchange(mesh,
|
mesh = halo_exchange(mesh,
|
||||||
halo_extents=halo_extents,
|
halo_extents=halo_extents,
|
||||||
halo_periods=(True, True, True))
|
halo_periods=(True, True, True))
|
||||||
|
@ -173,16 +173,19 @@ def cic_paint_dx(displacements, halo_size=0):
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def cic_read_dx_impl(mesh):
|
def cic_read_dx_impl(mesh , halo_size):
|
||||||
|
|
||||||
original_shape = mesh.shape
|
halo_x, _ = halo_size[0]
|
||||||
|
halo_y, _ = halo_size[1]
|
||||||
|
|
||||||
|
original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)]
|
||||||
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
|
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||||
jnp.arange(original_shape[1]),
|
jnp.arange(original_shape[1]),
|
||||||
jnp.arange(original_shape[2]),
|
jnp.arange(original_shape[2]),
|
||||||
indexing='ij')
|
indexing='ij')
|
||||||
|
|
||||||
pmid = jnp.stack([a, b, c], axis=-1)
|
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||||
|
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
|
||||||
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape)
|
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape)
|
||||||
|
@ -190,15 +193,16 @@ def cic_read_dx_impl(mesh):
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(1, ))
|
@partial(jax.jit, static_argnums=(1, ))
|
||||||
def cic_read_dx(mesh, halo_size=0):
|
def cic_read_dx(mesh, halo_size=0):
|
||||||
|
# return mesh
|
||||||
halo_size, halo_extents = get_halo_size(halo_size)
|
halo_size, halo_extents = get_halo_size(halo_size)
|
||||||
mesh = slice_pad(mesh, halo_size)
|
mesh = slice_pad(mesh, halo_size)
|
||||||
mesh = halo_exchange(mesh,
|
mesh = halo_exchange(mesh,
|
||||||
halo_extents=halo_extents,
|
halo_extents=halo_extents,
|
||||||
halo_periods=(True, True, True))
|
halo_periods=(True, True, True))
|
||||||
displacements = autoshmap(cic_read_dx_impl,
|
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size),
|
||||||
in_specs=(P('x', 'y')),
|
in_specs=(P('x', 'y')),
|
||||||
out_specs=P('x', 'y'))(mesh)
|
out_specs=P('x', 'y'))(mesh)
|
||||||
|
|
||||||
return displacements
|
return displacements
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue