From a2811c06065a69804253843cacaed00a517e52a5 Mon Sep 17 00:00:00 2001
From: EiffL <fr.eiffel@gmail.com>
Date: Tue, 9 Jul 2024 14:54:34 -0400
Subject: [PATCH] Applying formatting

---
 .pre-commit-config.yaml               |  17 +++
 design.md                             |   4 +-
 dev/test_pfft.py                      |  77 ++++++----
 dev/test_script.py                    |  69 ++++-----
 jaxpm/experimental/distributed_ops.py | 193 +++++++++++++++++---------
 jaxpm/experimental/distributed_pm.py  |  36 +++--
 jaxpm/growth.py                       |  65 ++++-----
 jaxpm/kernels.py                      | 116 ++++++++--------
 jaxpm/lensing.py                      |  64 ++++-----
 jaxpm/nn.py                           |  60 ++++----
 jaxpm/painting.py                     | 128 ++++++++---------
 jaxpm/pm.py                           |  58 +++++---
 jaxpm/utils.py                        | 117 ++++++++--------
 notebooks/Introduction.ipynb          |   2 +-
 setup.py                              |   6 +-
 15 files changed, 566 insertions(+), 446 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..d476f32
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,17 @@
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+    -   id: check-yaml
+    -   id: end-of-file-fixer
+    -   id: trailing-whitespace
+-   repo: https://github.com/google/yapf
+    rev: v0.40.2
+    hooks:
+    - id: yapf
+      args: ['--parallel', '--in-place']
+-   repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        name: isort (python)
\ No newline at end of file
diff --git a/design.md b/design.md
index 329270a..a0727a1 100644
--- a/design.md
+++ b/design.md
@@ -4,14 +4,14 @@ This document aims to detail some of the API, implementation choices, and intern
 
 ## Objective
 
-Provide a user-friendly framework for distributed Particle-Mesh N-body simulations. 
+Provide a user-friendly framework for distributed Particle-Mesh N-body simulations.
 
 ## Related Work
 
 This project would be the latest iteration of a number of past libraries that have provided differentiable N-body models.
 
 - [FlowPM](https://github.com/DifferentiableUniverseInitiative/flowpm): TensorFlow
-- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD 
+- [vmad FastPM](https://github.com/rainwoodman/vmad/blob/master/vmad/lib/fastpm.py): VMAD
 - Borg
 
 
diff --git a/dev/test_pfft.py b/dev/test_pfft.py
index 873c238..5a956d8 100644
--- a/dev/test_pfft.py
+++ b/dev/test_pfft.py
@@ -1,57 +1,80 @@
 # Can be executed with:
 # srun  -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
-import jax 
+from functools import partial
+
+import jax
+import jax.lax as lax
 import jax.numpy as jnp
 import numpy as np
-import jax.lax as lax
-from jax.experimental.maps import xmap
-from jax.experimental.maps import Mesh
+from jax.experimental.maps import Mesh, xmap
 from jax.experimental.pjit import PartitionSpec, pjit
-from functools import partial
 
 jax.distributed.initialize()
 
 cube_size = 2048
 
+
 @partial(xmap,
          in_axes=[...],
-         out_axes=['x','y', ...],
-         axis_sizes={'x':cube_size, 'y':cube_size},
-         axis_resources={'x': 'nx', 'y':'ny',
-                         'key_x':'nx', 'key_y':'ny'})
+         out_axes=['x', 'y', ...],
+         axis_sizes={
+             'x': cube_size,
+             'y': cube_size
+         },
+         axis_resources={
+             'x': 'nx',
+             'y': 'ny',
+             'key_x': 'nx',
+             'key_y': 'ny'
+         })
 def pnormal(key):
     return jax.random.normal(key, shape=[cube_size])
 
+
 @partial(xmap,
-         in_axes={0:'x', 1:'y'},
-         out_axes=['x','y', ...],
-         axis_resources={'x': 'nx', 'y': 'ny'})
+         in_axes={
+             0: 'x',
+             1: 'y'
+         },
+         out_axes=['x', 'y', ...],
+         axis_resources={
+             'x': 'nx',
+             'y': 'ny'
+         })
 @jax.jit
 def pfft3d(mesh):
     # [x, y, z]
-    mesh = jnp.fft.fft(mesh) # Transform on z
-    mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
-    mesh = jnp.fft.fft(mesh) # Transform on x
-    mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
-    mesh = jnp.fft.fft(mesh) # Transform on y
+    mesh = jnp.fft.fft(mesh)  # Transform on z
+    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now x is exposed, [z,y,x]
+    mesh = jnp.fft.fft(mesh)  # Transform on x
+    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now y is exposed, [z,x,y]
+    mesh = jnp.fft.fft(mesh)  # Transform on y
     # [z, x, y]
     return mesh
 
+
 @partial(xmap,
-         in_axes={0:'x', 1:'y'},
-         out_axes=['x','y', ...],
-         axis_resources={'x': 'nx', 'y': 'ny'})
+         in_axes={
+             0: 'x',
+             1: 'y'
+         },
+         out_axes=['x', 'y', ...],
+         axis_resources={
+             'x': 'nx',
+             'y': 'ny'
+         })
 @jax.jit
 def pifft3d(mesh):
     # [z, x, y]
-    mesh = jnp.fft.ifft(mesh) # Transform on y
-    mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
-    mesh = jnp.fft.ifft(mesh) # Transform on x
-    mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
-    mesh = jnp.fft.ifft(mesh) # Transform on z
+    mesh = jnp.fft.ifft(mesh)  # Transform on y
+    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now x is exposed, [z,y,x]
+    mesh = jnp.fft.ifft(mesh)  # Transform on x
+    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now z is exposed, [x,y,z]
+    mesh = jnp.fft.ifft(mesh)  # Transform on z
     # [x, y, z]
     return mesh
 
+
 key = jax.random.PRNGKey(42)
 # keys = jax.random.split(key, 4).reshape((2,2,2))
 
@@ -68,6 +91,6 @@ with Mesh(devices, ('nx', 'ny')):
 #     mesh = pnormal(key)
 #     kmesh = pfft3d(mesh)
 #     kmesh.block_until_ready()
-# jax.profiler.stop_trace()    
+# jax.profiler.stop_trace()
 
-print('Done')
\ No newline at end of file
+print('Done')
diff --git a/dev/test_script.py b/dev/test_script.py
index a9566c2..4f3ca06 100644
--- a/dev/test_script.py
+++ b/dev/test_script.py
@@ -1,48 +1,53 @@
 # Start this script with:
 # mpirun -np 4 python test_script.py
 import os
+
 os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
-import matplotlib.pylab as plt
-import jax 
-import numpy as np
-import jax.numpy as jnp
+import jax
 import jax.lax as lax
+import jax.numpy as jnp
+import matplotlib.pylab as plt
+import numpy as np
+import tensorflow_probability as tfp
 from jax.experimental.maps import mesh, xmap
 from jax.experimental.pjit import PartitionSpec, pjit
-import tensorflow_probability as tfp; tfp = tfp.substrates.jax
+
+tfp = tfp.substrates.jax
 tfd = tfp.distributions
 
+
 def cic_paint(mesh, positions):
-  """ Paints positions onto mesh
+    """ Paints positions onto mesh
   mesh: [nx, ny, nz]
   positions: [npart, 3]
   """
-  positions = jnp.expand_dims(positions, 1)
-  floor = jnp.floor(positions)
-  connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], 
-                           [0., 0, 1], [1., 1, 0], [1., 0, 1], 
-                           [0., 1, 1], [1., 1, 1]]])
+    positions = jnp.expand_dims(positions, 1)
+    floor = jnp.floor(positions)
+    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
+                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
 
-  neighboor_coords = floor + connection
-  kernel = 1. - jnp.abs(positions - neighboor_coords)
-  kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  
+    neighboor_coords = floor + connection
+    kernel = 1. - jnp.abs(positions - neighboor_coords)
+    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
+
+    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
+                                            inserted_window_dims=(0, 1, 2),
+                                            scatter_dims_to_operand_dims=(0, 1,
+                                                                          2))
+    mesh = lax.scatter_add(
+        mesh,
+        neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
+        kernel.reshape([-1, 8]), dnums)
+    return mesh
 
-  dnums = jax.lax.ScatterDimensionNumbers(
-    update_window_dims=(),
-    inserted_window_dims=(0, 1, 2),
-    scatter_dims_to_operand_dims=(0, 1, 2))
-  mesh = lax.scatter_add(mesh, 
-                         neighboor_coords.reshape([-1,8,3]).astype('int32'), 
-                         kernel.reshape([-1,8]),
-                         dnums)
-  return mesh
 
 # And let's draw some points from some 3D distribution
-dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
+dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
+                                  scale_identity_multiplier=3.)
 pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
 
 f = pjit(lambda x: cic_paint(x, pos),
-         in_axis_resources=PartitionSpec('x', 'y', 'z'), 
+         in_axis_resources=PartitionSpec('x', 'y', 'z'),
          out_axis_resources=None)
 
 devices = np.array(jax.devices()).reshape((2, 2, 1))
@@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
 m = jnp.zeros([32, 32, 32])
 
 with mesh(devices, ('x', 'y', 'z')):
-  # Shard the mesh, I'm not sure this is absolutely necessary
-  m = pjit(lambda x: x,
-           in_axis_resources=None,
-           out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
+    # Shard the mesh, I'm not sure this is absolutely necessary
+    m = pjit(lambda x: x,
+             in_axis_resources=None,
+             out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
 
-  # Apply the sharded CiC function
-  res = f(m)
+    # Apply the sharded CiC function
+    res = f(m)
 
 plt.imshow(res.sum(axis=2))
-plt.show()
\ No newline at end of file
+plt.show()
diff --git a/jaxpm/experimental/distributed_ops.py b/jaxpm/experimental/distributed_ops.py
index 6417f35..a06b03c 100644
--- a/jaxpm/experimental/distributed_ops.py
+++ b/jaxpm/experimental/distributed_ops.py
@@ -1,11 +1,12 @@
-import jax
-import jax.numpy as jnp
-import jax.lax as lax
 from functools import partial
-from jax.experimental.maps import xmap
-from jax.experimental.pjit import pjit, PartitionSpec
 
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
 import jax_cosmo as jc
+from jax.experimental.maps import xmap
+from jax.experimental.pjit import PartitionSpec, pjit
+
 import jaxpm.painting as paint
 
 # TODO: add a way to configure axis resources from command line
@@ -14,35 +15,59 @@ mesh_size = {'nx': 2, 'ny': 2}
 
 
 @partial(xmap,
-         in_axes=({0: 'x', 2: 'y'},
-                  {0: 'x', 2: 'y'},
-                  {0: 'x', 2: 'y'}),
-         out_axes=({0: 'x', 2: 'y'}),
+         in_axes=({
+             0: 'x',
+             2: 'y'
+         }, {
+             0: 'x',
+             2: 'y'
+         }, {
+             0: 'x',
+             2: 'y'
+         }),
+         out_axes=({
+             0: 'x',
+             2: 'y'
+         }),
          axis_resources=axis_resources)
 def stack3d(a, b, c):
     return jnp.stack([a, b, c], axis=-1)
 
 
 @partial(xmap,
-         in_axes=({0: 'x', 2: 'y'},[...]),
-         out_axes=({0: 'x', 2: 'y'}),
+         in_axes=({
+             0: 'x',
+             2: 'y'
+         }, [...]),
+         out_axes=({
+             0: 'x',
+             2: 'y'
+         }),
          axis_resources=axis_resources)
 def scalar_multiply(a, factor):
     return a * factor
 
 
 @partial(xmap,
-         in_axes=({0: 'x', 2: 'y'},
-                  {0: 'x', 2: 'y'}),
-         out_axes=({0: 'x', 2: 'y'}),
+         in_axes=({
+             0: 'x',
+             2: 'y'
+         }, {
+             0: 'x',
+             2: 'y'
+         }),
+         out_axes=({
+             0: 'x',
+             2: 'y'
+         }),
          axis_resources=axis_resources)
 def add(a, b):
     return a + b
 
 
 @partial(xmap,
-         in_axes=['x', 'y',...],
-         out_axes=['x', 'y',...],
+         in_axes=['x', 'y', ...],
+         out_axes=['x', 'y', ...],
          axis_resources=axis_resources)
 def fft3d(mesh):
     """ Performs a 3D complex Fourier transform
@@ -51,7 +76,7 @@ def fft3d(mesh):
         mesh: a real 3D tensor of shape [Nx, Ny, Nz]
 
     Returns:
-        3D FFT of the input, note that the dimensions of the output 
+        3D FFT of the input, note that the dimensions of the output
         are tranposed.
     """
     mesh = jnp.fft.fft(mesh)
@@ -62,8 +87,8 @@ def fft3d(mesh):
 
 
 @partial(xmap,
-         in_axes=['x', 'y',...],
-         out_axes=['x', 'y',...],
+         in_axes=['x', 'y', ...],
+         out_axes=['x', 'y', ...],
          axis_resources=axis_resources)
 def ifft3d(mesh):
     mesh = jnp.fft.ifft(mesh)
@@ -72,10 +97,15 @@ def ifft3d(mesh):
     mesh = lax.all_to_all(mesh, 'x', 0, 0)
     return jnp.fft.ifft(mesh).real
 
+
 def normal(key, shape=[]):
+
     @partial(xmap,
-             in_axes=['x', 'y',...],
-             out_axes={0: 'x', 2: 'y'},
+             in_axes=['x', 'y', ...],
+             out_axes={
+                 0: 'x',
+                 2: 'y'
+             },
              axis_resources=axis_resources)
     def fn(key):
         """ Generate a distributed random normal distributions
@@ -83,99 +113,126 @@ def normal(key, shape=[]):
             key: array of random keys with same layout as computational mesh
             shape: logical shape of array to sample
         """
-        return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
-                                             shape[1]//mesh_size['ny']]+shape[2:])
+        return jax.random.normal(
+            key,
+            shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
+            shape[2:])
+
     return fn(key)
 
 
 @partial(xmap,
-         in_axes=(['x', 'y', ...],
-                  [['x'], ['y'], [...]], [...], [...]),
+         in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
          out_axes=['x', 'y', ...],
          axis_resources=axis_resources)
 @jax.jit
 def scale_by_power_spectrum(kfield, kvec, k, pk):
     kx, ky, kz = kvec
-    kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
+    kk = jnp.sqrt(kx**2 + ky**2 + kz**2)
     return kfield * jc.scipy.interpolate.interp(kk, k, pk)
 
 
 @partial(xmap,
-         in_axes=(['x', 'y', 'z'],
-                  [['x'], ['y'], ['z']]),
+         in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
          out_axes=(['x', 'y', 'z']),
          axis_resources=axis_resources)
 def gradient_laplace_kernel(kfield, kvec):
     kx, ky, kz = kvec
     kk = (kx**2 + ky**2 + kz**2)
-    kernel = jnp.where(kk == 0, 1., 1./kk)
-    return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
-            kfield * kernel * 1j * 1 / 6.0 *
-            (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
-            kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
+    kernel = jnp.where(kk == 0, 1., 1. / kk)
+    return (kfield * kernel * 1j * 1 / 6.0 *
+            (8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 /
+            6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j *
+            1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
 
 
 @partial(xmap,
          in_axes=([...]),
-         out_axes={0: 'x', 2: 'y'},
-         axis_sizes={'x': mesh_size['nx'],
-                     'y': mesh_size['ny']},
+         out_axes={
+             0: 'x',
+             2: 'y'
+         },
+         axis_sizes={
+             'x': mesh_size['nx'],
+             'y': mesh_size['ny']
+         },
          axis_resources=axis_resources)
 def meshgrid(x, y, z):
-    """ Generates a mesh grid of appropriate size for the 
+    """ Generates a mesh grid of appropriate size for the
     computational mesh we have.
     """
-    return jnp.stack(jnp.meshgrid(x, 
-                                  y,
-                                  z), axis=-1)
+    return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
 
 
 def cic_paint(pos, mesh_shape, halo_size=0):
+
     @partial(xmap,
-             in_axes=({0: 'x', 2: 'y'}),
-             out_axes=({0: 'x', 2: 'y'}),
+             in_axes=({
+                 0: 'x',
+                 2: 'y'
+             }),
+             out_axes=({
+                 0: 'x',
+                 2: 'y'
+             }),
              axis_resources=axis_resources)
     def fn(pos):
 
-        mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
-                          mesh_shape[1]//mesh_size['ny']+2*halo_size]
-                         + mesh_shape[2:])
+        mesh = jnp.zeros([
+            mesh_shape[0] // mesh_size['nx'] +
+            2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size
+        ] + mesh_shape[2:])
 
         # Paint particles
-        mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) +
-                             jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
+        mesh = paint.cic_paint(
+            mesh,
+            pos.reshape(-1, 3) +
+            jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
 
         # Perform halo exchange
         # Halo exchange along x
-        left = lax.pshuffle(mesh[-2*halo_size:],
+        left = lax.pshuffle(mesh[-2 * halo_size:],
                             perm=range(mesh_size['nx'])[::-1],
                             axis_name='x')
-        right = lax.pshuffle(mesh[:2*halo_size],
+        right = lax.pshuffle(mesh[:2 * halo_size],
                              perm=range(mesh_size['nx'])[::-1],
                              axis_name='x')
-        mesh = mesh.at[:2*halo_size].add(left)
-        mesh = mesh.at[-2*halo_size:].add(right)
+        mesh = mesh.at[:2 * halo_size].add(left)
+        mesh = mesh.at[-2 * halo_size:].add(right)
 
         # Halo exchange along y
-        left = lax.pshuffle(mesh[:, -2*halo_size:],
+        left = lax.pshuffle(mesh[:, -2 * halo_size:],
                             perm=range(mesh_size['ny'])[::-1],
                             axis_name='y')
-        right = lax.pshuffle(mesh[:, :2*halo_size],
+        right = lax.pshuffle(mesh[:, :2 * halo_size],
                              perm=range(mesh_size['ny'])[::-1],
                              axis_name='y')
-        mesh = mesh.at[:, :2*halo_size].add(left)
-        mesh = mesh.at[:, -2*halo_size:].add(right)
+        mesh = mesh.at[:, :2 * halo_size].add(left)
+        mesh = mesh.at[:, -2 * halo_size:].add(right)
 
         # removing halo and returning mesh
         return mesh[halo_size:-halo_size, halo_size:-halo_size]
 
     return fn(pos)
 
+
 def cic_read(mesh, pos, halo_size=0):
+
     @partial(xmap,
-             in_axes=({0: 'x', 2: 'y'},
-                      {0: 'x', 2: 'y'},),
-             out_axes=({0: 'x', 2: 'y'}),
+             in_axes=(
+                 {
+                     0: 'x',
+                     2: 'y'
+                 },
+                 {
+                     0: 'x',
+                     2: 'y'
+                 },
+             ),
+             out_axes=({
+                 0: 'x',
+                 2: 'y'
+             }),
              axis_resources=axis_resources)
     def fn(mesh, pos):
 
@@ -198,11 +255,13 @@ def cic_read(mesh, pos, halo_size=0):
         mesh = jnp.concatenate([left, mesh, right], axis=1)
 
         # Reading field at particles positions
-        res = paint.cic_read(mesh, pos.reshape(-1, 3) +
-                                    jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
+        res = paint.cic_read(
+            mesh,
+            pos.reshape(-1, 3) +
+            jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
 
         return res.reshape(pos.shape[:-1])
-    
+
     return fn(mesh, pos)
 
 
@@ -211,12 +270,14 @@ def cic_read(mesh, pos, halo_size=0):
          out_axis_resources=PartitionSpec('nx', None, 'ny', None))
 def reshape_dense_to_split(x):
     """ Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
-    Changes the logical shape of the array, but no shuffling of the 
+    Changes the logical shape of the array, but no shuffling of the
     data should be necessary
     """
     shape = list(x.shape)
-    return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
-                      mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
+    return x.reshape([
+        mesh_size['nx'], shape[0] //
+        mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny']
+    ] + shape[2:])
 
 
 @partial(pjit,
@@ -224,8 +285,8 @@ def reshape_dense_to_split(x):
          out_axis_resources=PartitionSpec('nx', 'ny'))
 def reshape_split_to_dense(x):
     """ Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
-    Changes the logical shape of the array, but no shuffling of the 
+    Changes the logical shape of the array, but no shuffling of the
     data should be necessary
     """
     shape = list(x.shape)
-    return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
+    return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:])
diff --git a/jaxpm/experimental/distributed_pm.py b/jaxpm/experimental/distributed_pm.py
index ef3e48c..b633cf7 100644
--- a/jaxpm/experimental/distributed_pm.py
+++ b/jaxpm/experimental/distributed_pm.py
@@ -1,13 +1,14 @@
-import jax
-from jax.lax import linear_solve_p
-import jax.numpy as jnp
-from jax.experimental.maps import xmap
 from functools import partial
-import jax_cosmo as jc
 
-from jaxpm.kernels import fftk
+import jax
+import jax.numpy as jnp
+import jax_cosmo as jc
+from jax.experimental.maps import xmap
+from jax.lax import linear_solve_p
+
 import jaxpm.experimental.distributed_ops as dops
-from jaxpm.growth import growth_factor, growth_rate, dGfa
+from jaxpm.growth import dGfa, growth_factor, growth_rate
+from jaxpm.kernels import fftk
 
 
 def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
@@ -25,8 +26,10 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
     forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
 
     # Recovers forces at particle positions
-    forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
-                            positions, halo_size) for f in forces_k]
+    forces = [
+        dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions,
+                      halo_size) for f in forces_k
+    ]
 
     return dops.stack3d(*forces)
 
@@ -44,12 +47,14 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
     field = dops.fft3d(dops.reshape_split_to_dense(field))
 
     # Rescaling k to physical units
-    kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
-            for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
+    kvec = [
+        k.squeeze() / box_size[i] * mesh_shape[i]
+        for i, k in enumerate(fftk(mesh_shape, symmetric=False))
+    ]
     k = jnp.logspace(-4, 2, 256)
     pk = jc.power.linear_matter_power(cosmo, k)
-    pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
-               ) / (box_size[0] * box_size[1] * box_size[2])
+    pk = pk * (mesh_shape[0] * mesh_shape[1] *
+               mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
 
     field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
 
@@ -66,8 +71,9 @@ def lpt(cosmo, initial_conditions, positions, a):
     initial_force = pm_forces(positions, delta_k=initial_conditions)
     a = jnp.atleast_1d(a)
     dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
-    p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
-                             jnp.sqrt(jc.background.Esqr(cosmo, a)))
+    p = dops.scalar_multiply(
+        dx,
+        a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)))
     return dx, p
 
 
diff --git a/jaxpm/growth.py b/jaxpm/growth.py
index 0be4718..5b6908c 100644
--- a/jaxpm/growth.py
+++ b/jaxpm/growth.py
@@ -1,8 +1,8 @@
 import jax.numpy as np
-
+from jax_cosmo.background import *
 from jax_cosmo.scipy.interpolate import interp
 from jax_cosmo.scipy.ode import odeint
-from jax_cosmo.background import *
+
 
 def E(cosmo, a):
     r"""Scale factor dependent factor E(a) in the Hubble
@@ -52,12 +52,8 @@ def df_de(cosmo, a, epsilon=1e-5):
     \frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)-
     \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
     """
-    return (
-        3
-        * cosmo.wa
-        * (np.log(a - epsilon) - (a - 1) / (a - epsilon))
-        / np.power(np.log(a - epsilon), 2)
-    )
+    return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
+            np.power(np.log(a - epsilon), 2))
 
 
 def dEa(cosmo, a):
@@ -89,15 +85,11 @@ def dEa(cosmo, a):
     where :math:`f(a)` is the Dark Energy evolution parameter computed
     by :py:meth:`.f_de`.
     """
-    return (
-        0.5
-        * (
-            -3 * cosmo.Omega_m * np.power(a, -4)
-            - 2 * cosmo.Omega_k * np.power(a, -3)
-            + df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))
-        )
-        / np.power(Esqr(cosmo, a), 0.5)
-    )
+    return (0.5 *
+            (-3 * cosmo.Omega_m * np.power(a, -4) -
+             2 * cosmo.Omega_k * np.power(a, -3) +
+             df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a))) /
+            np.power(Esqr(cosmo, a), 0.5))
 
 
 def growth_factor(cosmo, a):
@@ -155,8 +147,7 @@ def growth_factor_second(cosmo, a):
     """
     if cosmo._flags["gamma_growth"]:
         raise NotImplementedError(
-            "Gamma growth rate is not implemented for second order growth!"
-        )
+            "Gamma growth rate is not implemented for second order growth!")
         return None
     else:
         return _growth_factor_second_ODE(cosmo, a)
@@ -228,8 +219,7 @@ def growth_rate_second(cosmo, a):
     """
     if cosmo._flags["gamma_growth"]:
         raise NotImplementedError(
-            "Gamma growth factor is not implemented for second order growth!"
-        )
+            "Gamma growth factor is not implemented for second order growth!")
         return None
     else:
         return _growth_rate_second_ODE(cosmo, a)
@@ -258,23 +248,19 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
         atab = np.logspace(log10_amin, 0.0, steps)
 
         def D_derivs(y, x):
-            q = (
-                2.0
-                - 0.5
-                * (
-                    Omega_m_a(cosmo, x)
-                    + (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x)
-                )
-            ) / x
+            q = (2.0 - 0.5 *
+                 (Omega_m_a(cosmo, x) +
+                  (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
             r = 1.5 * Omega_m_a(cosmo, x) / x / x
 
             g1, g2 = y[0]
             f1, f2 = y[1]
             dy1da = [f1, -q * f1 + r * g1]
-            dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2]
+            dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
             return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
 
-        y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]])
+        y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
+                       [1.0, -6.0 / 7 * atab[0]]])
         y = odeint(D_derivs, y0, atab)
 
         # compute second order derivatives growth
@@ -473,8 +459,7 @@ def _growth_rate_gamma(cosmo, a):
 
     see :cite:`2019:Euclid Preparation VII, eqn.32`
     """
-    return Omega_m_a(cosmo, a) ** cosmo.gamma
-
+    return Omega_m_a(cosmo, a)**cosmo.gamma
 
 
 def Gf(cosmo, a):
@@ -503,7 +488,7 @@ def Gf(cosmo, a):
     """
     f1 = growth_rate(cosmo, a)
     g1 = growth_factor(cosmo, a)
-    D1f = f1*g1/ a
+    D1f = f1 * g1 / a
     return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
 
 
@@ -532,7 +517,7 @@ def Gf2(cosmo, a):
     """
     f2 = growth_rate_second(cosmo, a)
     g2 = growth_factor_second(cosmo, a)
-    D2f = f2*g2/ a
+    D2f = f2 * g2 / a
     return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)
 
 
@@ -563,13 +548,12 @@ def dGfa(cosmo, a):
     """
     f1 = growth_rate(cosmo, a)
     g1 = growth_factor(cosmo, a)
-    D1f = f1*g1/ a
+    D1f = f1 * g1 / a
     cache = cosmo._workspace['background.growth_factor']
     f1p = cache['h'] / cache['a'] * cache['g']
     f1p = interp(np.log(a), np.log(cache['a']), f1p)
     Ea = E(cosmo, a)
-    return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) +
-            3 * a**2 * Ea * D1f)
+    return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)
 
 
 def dGf2a(cosmo, a):
@@ -599,10 +583,9 @@ def dGf2a(cosmo, a):
     """
     f2 = growth_rate_second(cosmo, a)
     g2 = growth_factor_second(cosmo, a)
-    D2f = f2*g2/ a
+    D2f = f2 * g2 / a
     cache = cosmo._workspace['background.growth_factor']
     f2p = cache['h2'] / cache['a'] * cache['g2']
     f2p = interp(np.log(a), np.log(cache['a']), f2p)
     E = E(cosmo, a)
-    return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) +
-            3 * a**2 * E * D2f)
\ No newline at end of file
+    return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py
index 97d34dd..8447f8a 100644
--- a/jaxpm/kernels.py
+++ b/jaxpm/kernels.py
@@ -1,25 +1,27 @@
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
+
 
 def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
-  """ Return k_vector given a shape (nc, nc, nc) and box_size
+    """ Return k_vector given a shape (nc, nc, nc) and box_size
   """
-  k = []
-  for d in range(len(shape)):
-    kd = np.fft.fftfreq(shape[d])
-    kd *= 2 * np.pi
-    kdshape = np.ones(len(shape), dtype='int')
-    if symmetric and d == len(shape) - 1:
-      kd = kd[:shape[d] // 2 + 1]
-    kdshape[d] = len(kd)
-    kd = kd.reshape(kdshape)
+    k = []
+    for d in range(len(shape)):
+        kd = np.fft.fftfreq(shape[d])
+        kd *= 2 * np.pi
+        kdshape = np.ones(len(shape), dtype='int')
+        if symmetric and d == len(shape) - 1:
+            kd = kd[:shape[d] // 2 + 1]
+        kdshape[d] = len(kd)
+        kd = kd.reshape(kdshape)
+
+        k.append(kd.astype(dtype))
+    del kd, kdshape
+    return k
 
-    k.append(kd.astype(dtype))
-  del kd, kdshape
-  return k
 
 def gradient_kernel(kvec, direction, order=1):
-  """
+    """
   Computes the gradient kernel in the requested direction
   Parameters:
   -----------
@@ -32,20 +34,21 @@ def gradient_kernel(kvec, direction, order=1):
   wts: array
     Complex kernel
   """
-  if order == 0:
-    wts = 1j * kvec[direction]
-    wts = jnp.squeeze(wts)
-    wts[len(wts) // 2] = 0
-    wts = wts.reshape(kvec[direction].shape)
-    return wts
-  else:
-    w = kvec[direction]
-    a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
-    wts = a * 1j
-    return wts
+    if order == 0:
+        wts = 1j * kvec[direction]
+        wts = jnp.squeeze(wts)
+        wts[len(wts) // 2] = 0
+        wts = wts.reshape(kvec[direction].shape)
+        return wts
+    else:
+        w = kvec[direction]
+        a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w))
+        wts = a * 1j
+        return wts
+
 
 def laplace_kernel(kvec):
-  """
+    """
   Compute the Laplace kernel from a given K vector
   Parameters:
   -----------
@@ -56,16 +59,17 @@ def laplace_kernel(kvec):
   wts: array
     Complex kernel
   """
-  kk = sum(ki**2 for ki in kvec)
-  mask = (kk == 0).nonzero()
-  kk[mask] = 1
-  wts = 1. / kk
-  imask = (~(kk == 0)).astype(int)
-  wts *= imask
-  return wts
+    kk = sum(ki**2 for ki in kvec)
+    mask = (kk == 0).nonzero()
+    kk[mask] = 1
+    wts = 1. / kk
+    imask = (~(kk == 0)).astype(int)
+    wts *= imask
+    return wts
+
 
 def longrange_kernel(kvec, r_split):
-  """
+    """
   Computes a long range kernel
   Parameters:
   -----------
@@ -78,29 +82,31 @@ def longrange_kernel(kvec, r_split):
   wts: array
     kernel
   """
-  if r_split != 0:
-    kk = sum(ki**2 for ki in kvec)
-    return np.exp(-kk * r_split**2)
-  else:
-    return 1.
+    if r_split != 0:
+        kk = sum(ki**2 for ki in kvec)
+        return np.exp(-kk * r_split**2)
+    else:
+        return 1.
+
 
 def cic_compensation(kvec):
-  """
+    """
   Computes cic compensation kernel.
   Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
   Itself based on equation 18 (with p=2) of
         `Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
   Args:
-    kvec: array of k values in Fourier space  
+    kvec: array of k values in Fourier space
   Returns:
     v: array of kernel
   """
-  kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
-  wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
-  return wts
+    kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
+    wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
+    return wts
+
 
 def PGD_kernel(kvec, kl, ks):
-  """
+    """
   Computes the PGD kernel
   Parameters:
   -----------
@@ -115,12 +121,12 @@ def PGD_kernel(kvec, kl, ks):
   v: array
     kernel
   """
-  kk = sum(ki**2 for ki in kvec)
-  kl2 = kl**2
-  ks4 = ks**4
-  mask = (kk == 0).nonzero()
-  kk[mask] = 1
-  v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
-  imask = (~(kk == 0)).astype(int)
-  v *= imask
-  return v
\ No newline at end of file
+    kk = sum(ki**2 for ki in kvec)
+    kl2 = kl**2
+    ks4 = ks**4
+    mask = (kk == 0).nonzero()
+    kk[mask] = 1
+    v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
+    imask = (~(kk == 0)).astype(int)
+    v *= imask
+    return v
diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py
index b4beeef..0143adc 100644
--- a/jaxpm/lensing.py
+++ b/jaxpm/lensing.py
@@ -1,11 +1,12 @@
-import jax 
+import jax
 import jax.numpy as jnp
-import jax_cosmo.constants as constants
 import jax_cosmo
-
+import jax_cosmo.constants as constants
 from jax.scipy.ndimage import map_coordinates
-from jaxpm.utils import gaussian_smoothing
+
 from jaxpm.painting import cic_paint_2d
+from jaxpm.utils import gaussian_smoothing
+
 
 def density_plane(positions,
                   box_shape,
@@ -26,9 +27,11 @@ def density_plane(positions,
     xy = xy / nx * plane_resolution
 
     # Selecting only particles that fall inside the volume of interest
-    weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
+    weight = jnp.where(
+        (d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
     # Painting density plane
-    density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
+    density_plane = cic_paint_2d(
+        jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
 
     # Apply density normalization
     density_plane = density_plane / ((nx / plane_resolution) *
@@ -36,45 +39,44 @@ def density_plane(positions,
 
     # Apply Gaussian smoothing if requested
     if smoothing_sigma is not None:
-        density_plane = gaussian_smoothing(density_plane, 
-                                           smoothing_sigma)
+        density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
 
     return density_plane
 
 
-def convergence_Born(cosmo,
-                     density_planes,
-                     coords,
-                     z_source):
-  """
+def convergence_Born(cosmo, density_planes, coords, z_source):
+    """
   Compute the Born convergence
   Args:
     cosmo: `Cosmology`, cosmology object.
-    density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use 
+    density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
     coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
     z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
     name: `string`, name of the operation.
   Returns:
     `Tensor` of shape [batch_size, N, Nz], of convergence values.
   """
-  # Compute constant prefactor:
-  constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
-  # Compute comoving distance of source galaxies
-  r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
+    # Compute constant prefactor:
+    constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
+    # Compute comoving distance of source galaxies
+    r_s = jax_cosmo.background.radial_comoving_distance(
+        cosmo, 1 / (1 + z_source))
 
-  convergence = 0
-  for entry in density_planes:
-    r = entry['r']; a = entry['a']; p = entry['plane']
-    dx = entry['dx']; dz = entry['dz']
-    # Normalize density planes
-    density_normalization = dz * r / a
-    p = (p - p.mean()) * constant_factor * density_normalization
+    convergence = 0
+    for entry in density_planes:
+        r = entry['r']
+        a = entry['a']
+        p = entry['plane']
+        dx = entry['dx']
+        dz = entry['dz']
+        # Normalize density planes
+        density_normalization = dz * r / a
+        p = (p - p.mean()) * constant_factor * density_normalization
 
-    # Interpolate at the density plane coordinates
-    im = map_coordinates(p, 
-                         coords * r / dx - 0.5, 
-                         order=1, mode="wrap")
+        # Interpolate at the density plane coordinates
+        im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
 
-    convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])
+        convergence += im * jnp.clip(1. -
+                                     (r / r_s), 0, 1000).reshape([-1, 1, 1])
 
-  return convergence
+    return convergence
diff --git a/jaxpm/nn.py b/jaxpm/nn.py
index 933ea53..d8f27be 100644
--- a/jaxpm/nn.py
+++ b/jaxpm/nn.py
@@ -1,6 +1,7 @@
+import haiku as hk
 import jax
 import jax.numpy as jnp
-import haiku as hk
+
 
 def _deBoorVectorized(x, t, c, p):
     """
@@ -13,48 +14,47 @@ def _deBoorVectorized(x, t, c, p):
     c: array of control points
     p: degree of B-spline
     """
-    k = jnp.digitize(x, t) -1
-    
-    d = [c[j + k - p] for j in range(0, p+1)]
-    for r in range(1, p+1):
-        for j in range(p, r-1, -1):
-            alpha = (x - t[j+k-p]) / (t[j+1+k-r] - t[j+k-p])
-            d[j] = (1.0 - alpha) * d[j-1] + alpha * d[j]
+    k = jnp.digitize(x, t) - 1
+
+    d = [c[j + k - p] for j in range(0, p + 1)]
+    for r in range(1, p + 1):
+        for j in range(p, r - 1, -1):
+            alpha = (x - t[j + k - p]) / (t[j + 1 + k - r] - t[j + k - p])
+            d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
     return d[p]
 
 
 class NeuralSplineFourierFilter(hk.Module):
-  """A rotationally invariant filter parameterized by 
+    """A rotationally invariant filter parameterized by
   a b-spline with parameters specified by a small NN."""
 
-  def __init__(self, n_knots=8, latent_size=16, name=None):
+    def __init__(self, n_knots=8, latent_size=16, name=None):
+        """
+    n_knots: number of control points for the spline
     """
-    n_knots: number of control points for the spline  
-    """
-    super().__init__(name=name)
-    self.n_knots = n_knots
-    self.latent_size = latent_size
+        super().__init__(name=name)
+        self.n_knots = n_knots
+        self.latent_size = latent_size
 
-  def __call__(self, x, a):
-    """ 
+    def __call__(self, x, a):
+        """
     x: array, scale, normalized to fftfreq default
     a: scalar, scale factor
     """
 
-    net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
-    net = jnp.sin(hk.Linear(self.latent_size)(net))
+        net = jnp.sin(hk.Linear(self.latent_size)(jnp.atleast_1d(a)))
+        net = jnp.sin(hk.Linear(self.latent_size)(net))
 
-    w = hk.Linear(self.n_knots+1)(net) 
-    k = hk.Linear(self.n_knots-1)(net)
-    
-    # make sure the knots sum to 1 and are in the interval 0,1
-    k = jnp.concatenate([jnp.zeros((1,)),
-                        jnp.cumsum(jax.nn.softmax(k))])
+        w = hk.Linear(self.n_knots + 1)(net)
+        k = hk.Linear(self.n_knots - 1)(net)
 
-    w = jnp.concatenate([jnp.zeros((1,)),
-                         w])
+        # make sure the knots sum to 1 and are in the interval 0,1
+        k = jnp.concatenate([jnp.zeros((1, )), jnp.cumsum(jax.nn.softmax(k))])
 
-    # Augment with repeating points
-    ak = jnp.concatenate([jnp.zeros((3,)), k, jnp.ones((3,))])
+        w = jnp.concatenate([jnp.zeros((1, )), w])
 
-    return _deBoorVectorized(jnp.clip(x/jnp.sqrt(3), 0, 1-1e-4), ak, w, 3)
\ No newline at end of file
+        # Augment with repeating points
+        ak = jnp.concatenate([jnp.zeros((3, )), k, jnp.ones((3, ))])
+
+        return _deBoorVectorized(jnp.clip(x / jnp.sqrt(3), 0, 1 - 1e-4), ak, w,
+                                 3)
diff --git a/jaxpm/painting.py b/jaxpm/painting.py
index 67d54b0..fb5dbd5 100644
--- a/jaxpm/painting.py
+++ b/jaxpm/painting.py
@@ -1,98 +1,100 @@
 import jax
-import jax.numpy as jnp
 import jax.lax as lax
+import jax.numpy as jnp
+
+from jaxpm.kernels import cic_compensation, fftk
 
-from jaxpm.kernels import fftk, cic_compensation
 
 def cic_paint(mesh, positions, weight=None):
-  """ Paints positions onto mesh
+    """ Paints positions onto mesh
   mesh: [nx, ny, nz]
   positions: [npart, 3]
   """
-  positions = jnp.expand_dims(positions, 1)
-  floor = jnp.floor(positions)
-  connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], 
-                           [0., 0, 1], [1., 1, 0], [1., 0, 1], 
-                           [0., 1, 1], [1., 1, 1]]])
+    positions = jnp.expand_dims(positions, 1)
+    floor = jnp.floor(positions)
+    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
+                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
 
-  neighboor_coords = floor + connection
-  kernel = 1. - jnp.abs(positions - neighboor_coords)
-  kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  
-  if weight is not None:
+    neighboor_coords = floor + connection
+    kernel = 1. - jnp.abs(positions - neighboor_coords)
+    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
+    if weight is not None:
         kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
-  
-  neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
 
-  dnums = jax.lax.ScatterDimensionNumbers(
-    update_window_dims=(),
-    inserted_window_dims=(0, 1, 2),
-    scatter_dims_to_operand_dims=(0, 1, 2))
-  mesh = lax.scatter_add(mesh, 
-                         neighboor_coords, 
-                         kernel.reshape([-1,8]),
-                         dnums)
-  return mesh
+    neighboor_coords = jnp.mod(
+        neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
+        jnp.array(mesh.shape))
+
+    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
+                                            inserted_window_dims=(0, 1, 2),
+                                            scatter_dims_to_operand_dims=(0, 1,
+                                                                          2))
+    mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
+                           dnums)
+    return mesh
+
 
 def cic_read(mesh, positions):
-  """ Paints positions onto mesh
+    """ Paints positions onto mesh
   mesh: [nx, ny, nz]
   positions: [npart, 3]
-  """  
-  positions = jnp.expand_dims(positions, 1)
-  floor = jnp.floor(positions)
-  connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], 
-                           [0., 0, 1], [1., 1, 0], [1., 0, 1], 
-                           [0., 1, 1], [1., 1, 1]]])
+  """
+    positions = jnp.expand_dims(positions, 1)
+    floor = jnp.floor(positions)
+    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
+                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
 
-  neighboor_coords = floor + connection
-  kernel = 1. - jnp.abs(positions - neighboor_coords)
-  kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  
+    neighboor_coords = floor + connection
+    kernel = 1. - jnp.abs(positions - neighboor_coords)
+    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
 
-  neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))
+    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
+                               jnp.array(mesh.shape))
+
+    return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
+                 neighboor_coords[..., 3]] * kernel).sum(axis=-1)
 
-  return (mesh[neighboor_coords[...,0], 
-               neighboor_coords[...,1], 
-               neighboor_coords[...,3]]*kernel).sum(axis=-1)
 
 def cic_paint_2d(mesh, positions, weight):
-  """ Paints positions onto a 2d mesh
+    """ Paints positions onto a 2d mesh
   mesh: [nx, ny]
   positions: [npart, 2]
   weight: [npart]
   """
-  positions = jnp.expand_dims(positions, 1)
-  floor = jnp.floor(positions)
-  connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
+    positions = jnp.expand_dims(positions, 1)
+    floor = jnp.floor(positions)
+    connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
 
-  neighboor_coords = floor + connection
-  kernel = 1. - jnp.abs(positions - neighboor_coords)
-  kernel = kernel[..., 0] * kernel[..., 1] 
-  if weight is not None:
-    kernel = kernel * weight[...,jnp.newaxis]
-  
-  neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
+    neighboor_coords = floor + connection
+    kernel = 1. - jnp.abs(positions - neighboor_coords)
+    kernel = kernel[..., 0] * kernel[..., 1]
+    if weight is not None:
+        kernel = kernel * weight[..., jnp.newaxis]
+
+    neighboor_coords = jnp.mod(
+        neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
+        jnp.array(mesh.shape))
+
+    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
+                                            inserted_window_dims=(0, 1),
+                                            scatter_dims_to_operand_dims=(0,
+                                                                          1))
+    mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
+                           dnums)
+    return mesh
 
-  dnums = jax.lax.ScatterDimensionNumbers(
-    update_window_dims=(),
-    inserted_window_dims=(0, 1),
-    scatter_dims_to_operand_dims=(0, 1))
-  mesh = lax.scatter_add(mesh, 
-                         neighboor_coords, 
-                         kernel.reshape([-1,4]),
-                         dnums)
-  return mesh
 
 def compensate_cic(field):
-  """
+    """
   Compensate for CiC painting
   Args:
     field: input 3D cic-painted field
   Returns:
     compensated_field
   """
-  nc = field.shape
-  kvec = fftk(nc)
+    nc = field.shape
+    kvec = fftk(nc)
 
-  delta_k = jnp.fft.rfftn(field)
-  delta_k = cic_compensation(kvec) * delta_k
-  return jnp.fft.irfftn(delta_k)
+    delta_k = jnp.fft.rfftn(field)
+    delta_k = cic_compensation(kvec) * delta_k
+    return jnp.fft.irfftn(delta_k)
diff --git a/jaxpm/pm.py b/jaxpm/pm.py
index d9870f7..41ab2a7 100644
--- a/jaxpm/pm.py
+++ b/jaxpm/pm.py
@@ -1,11 +1,12 @@
 import jax
 import jax.numpy as jnp
-
 import jax_cosmo as jc
 
-from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel, PGD_kernel
+from jaxpm.growth import dGfa, growth_factor, growth_rate
+from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
+                           longrange_kernel)
 from jaxpm.painting import cic_paint, cic_read
-from jaxpm.growth import growth_factor, growth_rate, dGfa
+
 
 def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
     """
@@ -21,10 +22,14 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
         delta_k = jnp.fft.rfftn(delta)
 
     # Computes gravitational potential
-    pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split)
+    pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
+                                                              r_split=r_split)
     # Computes gravitational forces
-    return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions) 
-                      for i in range(3)],axis=-1)
+    return jnp.stack([
+        cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions)
+        for i in range(3)
+    ],
+                     axis=-1)
 
 
 def lpt(cosmo, initial_conditions, positions, a):
@@ -34,25 +39,31 @@ def lpt(cosmo, initial_conditions, positions, a):
     initial_force = pm_forces(positions, delta=initial_conditions)
     a = jnp.atleast_1d(a)
     dx = growth_factor(cosmo, a) * initial_force
-    p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
-    f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
+    p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
+                                                                   a)) * dx
+    f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo,
+                                                             a) * initial_force
     return dx, p, f
 
+
 def linear_field(mesh_shape, box_size, pk, seed):
     """
     Generate initial conditions.
     """
     kvec = fftk(mesh_shape)
-    kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
-    pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
+    kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
+                for i, kk in enumerate(kvec))**0.5
+    pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
+        box_size[0] * box_size[1] * box_size[2])
 
     field = jax.random.normal(seed, mesh_shape)
     field = jnp.fft.rfftn(field) * pkmesh**0.5
     field = jnp.fft.irfftn(field)
     return field
 
+
 def make_ode_fn(mesh_shape):
-    
+
     def nbody_ode(state, a, cosmo):
         """
         state is a tuple (position, velocities)
@@ -63,10 +74,10 @@ def make_ode_fn(mesh_shape):
 
         # Computes the update of position (drift)
         dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
-        
+
         # Computes the update of velocity (kick)
         dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
-        
+
         return dpos, dvel
 
     return nbody_ode
@@ -84,13 +95,16 @@ def pgd_correction(pos, params):
     delta = cic_paint(jnp.zeros(mesh_shape), pos)
     alpha, kl, ks = params
     delta_k = jnp.fft.rfftn(delta)
-    PGD_range=PGD_kernel(kvec, kl, ks)
-    
-    pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
+    PGD_range = PGD_kernel(kvec, kl, ks)
 
-    forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) 
-                      for i in range(3)],axis=-1)
-    
-    dpos_pgd = forces_pgd*alpha
-   
-    return dpos_pgd
\ No newline at end of file
+    pot_k_pgd = (delta_k * laplace_kernel(kvec)) * PGD_range
+
+    forces_pgd = jnp.stack([
+        cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k_pgd), pos)
+        for i in range(3)
+    ],
+                           axis=-1)
+
+    dpos_pgd = forces_pgd * alpha
+
+    return dpos_pgd
diff --git a/jaxpm/utils.py b/jaxpm/utils.py
index a01e188..fc00a79 100644
--- a/jaxpm/utils.py
+++ b/jaxpm/utils.py
@@ -1,99 +1,100 @@
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
 from jax.scipy.stats import norm
 
 __all__ = ['power_spectrum']
 
+
 def _initialize_pk(shape, boxsize, kmin, dk):
-  """
+    """
        Helper function to initialize various (fixed) values for powerspectra... not differentiable!
     """
-  I = np.eye(len(shape), dtype='int') * -2 + 1
+    I = np.eye(len(shape), dtype='int') * -2 + 1
 
-  W = np.empty(shape, dtype='f4')
-  W[...] = 2.0
-  W[..., 0] = 1.0
-  W[..., -1] = 1.0
+    W = np.empty(shape, dtype='f4')
+    W[...] = 2.0
+    W[..., 0] = 1.0
+    W[..., -1] = 1.0
 
-  kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
-  kedges = np.arange(kmin, kmax, dk)
+    kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
+    kedges = np.arange(kmin, kmax, dk)
 
-  k = [
-      np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
-      for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
-  ]
-  kmag = sum(ki**2 for ki in k)**0.5
+    k = [
+        np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
+        for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
+    ]
+    kmag = sum(ki**2 for ki in k)**0.5
 
-  xsum = np.zeros(len(kedges) + 1)
-  Nsum = np.zeros(len(kedges) + 1)
+    xsum = np.zeros(len(kedges) + 1)
+    Nsum = np.zeros(len(kedges) + 1)
 
-  dig = np.digitize(kmag.flat, kedges)
+    dig = np.digitize(kmag.flat, kedges)
 
-  xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
-  Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
-  return dig, Nsum, xsum, W, k, kedges
+    xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
+    Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
+    return dig, Nsum, xsum, W, k, kedges
 
 
 def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
-  """
+    """
     Calculate the powerspectra given real space field
-    
+
     Args:
-        
-        field: real valued field 
+
+        field: real valued field
         kmin: minimum k-value for binned powerspectra
         dk: differential in each kbin
         boxsize: length of each boxlength (can be strangly shaped?)
-    
+
     Returns:
-        
+
         kbins: the central value of the bins for plotting
         power: real valued array of power in each bin
-        
+
   """
-  shape = field.shape
-  nx, ny, nz = shape
+    shape = field.shape
+    nx, ny, nz = shape
 
-  #initialze values related to powerspectra (mode bins and weights)
-  dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
+    #initialze values related to powerspectra (mode bins and weights)
+    dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
 
-  #fast fourier transform
-  fft_image = jnp.fft.fftn(field)
+    #fast fourier transform
+    fft_image = jnp.fft.fftn(field)
 
-  #absolute value of fast fourier transform
-  pk = jnp.real(fft_image * jnp.conj(fft_image))
+    #absolute value of fast fourier transform
+    pk = jnp.real(fft_image * jnp.conj(fft_image))
 
+    #calculating powerspectra
+    real = jnp.real(pk).reshape([-1])
+    imag = jnp.imag(pk).reshape([-1])
 
-  #calculating powerspectra
-  real = jnp.real(pk).reshape([-1])
-  imag = jnp.imag(pk).reshape([-1])
+    Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
+                        length=xsum.size) * 1j
+    Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
 
-  Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
-  Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
+    P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
 
-  P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
+    #normalization for powerspectra
+    norm = np.prod(np.array(shape[:])).astype('float32')**2
 
-  #normalization for powerspectra
-  norm = np.prod(np.array(shape[:])).astype('float32')**2
+    #find central values of each bin
+    kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
 
-  #find central values of each bin
-  kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
+    return kbins, P / norm
 
-  return kbins, P / norm
 
 def gaussian_smoothing(im, sigma):
-  """
+    """
   im: 2d image
-  sigma: smoothing scale in px 
+  sigma: smoothing scale in px
   """
-  # Compute k vector
-  kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
-                                jnp.fft.fftfreq(im.shape[1])),
-                 axis=-1)
-  k = jnp.linalg.norm(kvec, axis=-1)
-  # We compute the value of the filter at frequency k
-  filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
-  filter /= filter[0,0]
-
-  return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
+    # Compute k vector
+    kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
+                                  jnp.fft.fftfreq(im.shape[1])),
+                     axis=-1)
+    k = jnp.linalg.norm(kvec, axis=-1)
+    # We compute the value of the filter at frequency k
+    filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
+    filter /= filter[0, 0]
 
+    return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
diff --git a/notebooks/Introduction.ipynb b/notebooks/Introduction.ipynb
index b1ef596..0ca27ac 100644
--- a/notebooks/Introduction.ipynb
+++ b/notebooks/Introduction.ipynb
@@ -202,4 +202,4 @@
   },
   "nbformat": 4,
   "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/setup.py b/setup.py
index 44be5a1..a58759a 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 setup(
     name='JaxPM',
@@ -6,6 +6,6 @@ setup(
     url='https://github.com/DifferentiableUniverseInitiative/JaxPM',
     author='JaxPM developers',
     description='A dead simple FastPM implementation in JAX',
-    packages=find_packages(),    
+    packages=find_packages(),
     install_requires=['jax', 'jax_cosmo'],
-)
\ No newline at end of file
+)