mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 00:51:11 +00:00
fixed a whole lot of issues
This commit is contained in:
parent
429813ad92
commit
72ae0fd88f
5 changed files with 251 additions and 155 deletions
|
@ -6,7 +6,7 @@ from jaxpm.ops import halo_reduce
|
|||
from jaxpm.kernels import fftk, cic_compensation
|
||||
|
||||
|
||||
def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
|
||||
def cic_paint(mesh, positions, halo_size=0, comms=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
|
@ -43,11 +43,11 @@ def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
|
|||
if comms == None:
|
||||
return mesh
|
||||
else:
|
||||
mesh, token = halo_reduce(mesh, halo_size, token, comms)
|
||||
mesh = halo_reduce(mesh, halo_size, comms)
|
||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
|
||||
def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
||||
def cic_read(mesh, positions, halo_size=0, comms=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
|
@ -59,7 +59,7 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
|||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||
[halo_size, halo_size],
|
||||
[0, 0]])
|
||||
mesh, token = halo_reduce(mesh, halo_size, token, comms)
|
||||
mesh = halo_reduce(mesh, halo_size, comms)
|
||||
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
||||
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
|
@ -75,14 +75,9 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
|||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
||||
|
||||
res = (mesh[neighboor_coords[..., 0],
|
||||
neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
|
||||
|
||||
if comms is not None:
|
||||
return res
|
||||
else:
|
||||
return res, token
|
||||
return (mesh[neighboor_coords[..., 0],
|
||||
neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
|
||||
|
||||
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue