mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
update code
This commit is contained in:
parent
e0c118a540
commit
21373b89ee
7 changed files with 84 additions and 100 deletions
|
@ -204,7 +204,7 @@ def cic_paint_dx(displacements,
|
|||
return grid_mesh
|
||||
|
||||
|
||||
def cic_read_dx_impl(grid_mesh, halo_size):
|
||||
def cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
@ -220,14 +220,15 @@ def cic_read_dx_impl(grid_mesh, halo_size):
|
|||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
disp = disp.reshape([-1, 3])
|
||||
|
||||
return gather(pmid, jnp.zeros_like(pmid),
|
||||
return gather(pmid, disp,
|
||||
grid_mesh).reshape(original_shape)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, 2))
|
||||
def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
|
||||
# return mesh
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
|
@ -238,7 +239,7 @@ def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
|
|||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec),
|
||||
out_specs=spec)(grid_mesh)
|
||||
out_specs=spec)(grid_mesh , disp)
|
||||
|
||||
return displacements
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue