diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 8123ad5..dad13ba 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -33,7 +33,8 @@ def interpolate_power_spectrum(input, k, pk, sharding=None): gpu_mesh = sharding.mesh if sharding is not None else None specs = sharding.spec if sharding is not None else P() - out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh)) + out_specs = P(*get_output_specs( + FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P() return autoshmap(pk_fn, gpu_mesh=gpu_mesh,