apply formating

This commit is contained in:
Wassim KABALAN 2024-10-27 03:50:34 +01:00
parent c93894f561
commit 19011d0712
5 changed files with 22 additions and 15 deletions

View file

@ -5,7 +5,7 @@ import jax_cosmo as jc
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
normal_field,zeros)
normal_field, zeros)
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
@ -29,8 +29,10 @@ def pm_forces(positions,
mesh_shape = delta.shape
if paint_particles:
paint_fn = lambda x: cic_paint(
zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding)
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
x,
halo_size=halo_size,
sharding=sharding)
read_fn = lambda x: cic_read(
x, positions, halo_size=halo_size, sharding=sharding)
else: