mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
remove deprecated stuff
This commit is contained in:
parent
8e8e8964be
commit
0f833f0cb4
8 changed files with 96 additions and 831 deletions
|
@ -50,6 +50,8 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
|
|||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
|
||||
|
||||
|
@ -63,6 +65,8 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
|||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||
|
||||
print(f"shape of grid_mesh: {grid_mesh.shape}")
|
||||
return grid_mesh
|
||||
|
||||
|
||||
|
@ -97,19 +101,20 @@ def cic_read_impl(mesh, displacement):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read(mesh, displacement, halo_size=0, sharding=None):
|
||||
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
mesh = slice_pad(mesh, halo_size, sharding=sharding)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec),
|
||||
out_specs=spec)(mesh, displacement)
|
||||
out_specs=spec)(grid_mesh, positions)
|
||||
print(f"shape of displacement: {displacement.shape}")
|
||||
|
||||
return displacement
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue