fixed a whole lot of issues

This commit is contained in:
EiffL 2022-10-22 15:58:32 -04:00
parent 429813ad92
commit 72ae0fd88f
5 changed files with 251 additions and 155 deletions

View file

@ -6,7 +6,7 @@ from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
def cic_paint(mesh, positions, halo_size=0, comms=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
@ -43,11 +43,11 @@ def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
if comms == None:
return mesh
else:
mesh, token = halo_reduce(mesh, halo_size, token, comms)
mesh = halo_reduce(mesh, halo_size, comms)
return mesh[halo_size:-halo_size, halo_size:-halo_size]
def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
def cic_read(mesh, positions, halo_size=0, comms=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
@ -59,7 +59,7 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
mesh, token = halo_reduce(mesh, halo_size, token, comms)
mesh = halo_reduce(mesh, halo_size, comms)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
@ -75,14 +75,9 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
neighboor_coords = jnp.mod(
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
res = (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
if comms is not None:
return res
else:
return res, token
return (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight):