mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
Format
This commit is contained in:
parent
7823fdaf98
commit
af29c4005d
7 changed files with 68 additions and 63 deletions
|
@ -222,12 +222,11 @@ def cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
pmid = pmid.reshape([-1, 3])
|
||||
disp = disp.reshape([-1, 3])
|
||||
|
||||
return gather(pmid, disp,
|
||||
grid_mesh).reshape(original_shape)
|
||||
return gather(pmid, disp, grid_mesh).reshape(original_shape)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
|
||||
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
|
@ -239,7 +238,7 @@ def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
|
|||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec),
|
||||
out_specs=spec)(grid_mesh , disp)
|
||||
out_specs=spec)(grid_mesh, disp)
|
||||
|
||||
return displacements
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue