apply formatting

This commit is contained in:
Wassim KABALAN 2024-10-27 00:52:14 +02:00
parent 11f7e90066
commit 4342279817
5 changed files with 26 additions and 22 deletions

View file

@ -40,7 +40,7 @@ def ifft3d(x):
def get_halo_size(halo_size, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is None or gpu_mesh.empty:
zero_ext = (0, 0, 0)
zero_ext = (0, 0)
zero_tuple = (0, 0)
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
else:

View file

@ -5,8 +5,8 @@ import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
slice_pad, slice_unpad, fft3d, ifft3d)
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
ifft3d, slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter