roll back painting

This commit is contained in:
Wassim KABALAN 2024-08-05 19:37:33 +02:00
parent 9b21eed3b5
commit 30060e82ea

View file

@ -6,18 +6,10 @@ from jaxpm.kernels import cic_compensation, fftk
def cic_paint(mesh, positions, weight=None): def cic_paint(mesh, positions, weight=None):
""" """ Paints positions onto mesh
Paint positions onto mesh mesh: [nx, ny, nz]
positions: [npart, 3]
Parameters: """
-----------
mesh: [nx, ny, nz]
positions: [npart, 3]
Returns:
--------
mesh: [nx, ny, nz]
"""
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -43,18 +35,10 @@ def cic_paint(mesh, positions, weight=None):
def cic_read(mesh, positions): def cic_read(mesh, positions):
""" """ Paints positions onto mesh
Read mesh at positions mesh: [nx, ny, nz]
positions: [npart, 3]
Parameters: """
-----------
mesh: [nx, ny, nz]
positions: [npart, 3]
Returns:
--------
values: [npart]
"""
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -72,19 +56,11 @@ def cic_read(mesh, positions):
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):
""" """ Paints positions onto a 2d mesh
Paints positions onto 2d mesh mesh: [nx, ny]
positions: [npart, 2]
Parameters: weight: [npart]
----------- """
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
Returns:
--------
mesh: [nx, ny]
"""
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -110,12 +86,12 @@ def cic_paint_2d(mesh, positions, weight):
def compensate_cic(field): def compensate_cic(field):
""" """
Compensate for CiC painting Compensate for CiC painting
Args: Args:
field: input 3D cic-painted field field: input 3D cic-painted field
Returns: Returns:
compensated_field compensated_field
""" """
nc = field.shape nc = field.shape
kvec = fftk(nc) kvec = fftk(nc)