diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 7b46949..fb5dbd5 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -6,18 +6,10 @@ from jaxpm.kernels import cic_compensation, fftk def cic_paint(mesh, positions, weight=None): - """ - Paint positions onto mesh - - Parameters: - ----------- - mesh: [nx, ny, nz] - positions: [npart, 3] - - Returns: - -------- - mesh: [nx, ny, nz] - """ + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -43,18 +35,10 @@ def cic_paint(mesh, positions, weight=None): def cic_read(mesh, positions): - """ - Read mesh at positions - - Parameters: - ----------- - mesh: [nx, ny, nz] - positions: [npart, 3] - - Returns: - -------- - values: [npart] - """ + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -72,19 +56,11 @@ def cic_read(mesh, positions): def cic_paint_2d(mesh, positions, weight): - """ - Paints positions onto 2d mesh - - Parameters: - ----------- - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - - Returns: - -------- - mesh: [nx, ny] - """ + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -110,12 +86,12 @@ def cic_paint_2d(mesh, positions, weight): def compensate_cic(field): """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ nc = field.shape kvec = fftk(nc)