roll back painting

This commit is contained in:
Wassim KABALAN 2024-08-05 19:37:33 +02:00
parent 9b21eed3b5
commit 30060e82ea

View file

@ -6,18 +6,10 @@ from jaxpm.kernels import cic_compensation, fftk
def cic_paint(mesh, positions, weight=None):
"""
Paint positions onto mesh
Parameters:
-----------
mesh: [nx, ny, nz]
positions: [npart, 3]
Returns:
--------
mesh: [nx, ny, nz]
"""
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -43,18 +35,10 @@ def cic_paint(mesh, positions, weight=None):
def cic_read(mesh, positions):
"""
Read mesh at positions
Parameters:
-----------
mesh: [nx, ny, nz]
positions: [npart, 3]
Returns:
--------
values: [npart]
"""
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -72,19 +56,11 @@ def cic_read(mesh, positions):
def cic_paint_2d(mesh, positions, weight):
"""
Paints positions onto 2d mesh
Parameters:
-----------
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
Returns:
--------
mesh: [nx, ny]
"""
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -110,12 +86,12 @@ def cic_paint_2d(mesh, positions, weight):
def compensate_cic(field):
"""
Compensate for CiC painting
Args:
field: input 3D cic-painted field
Returns:
compensated_field
"""
Compensate for CiC painting
Args:
field: input 3D cic-painted field
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)