From a8b194f32639be456eccd3614d915c815d3869e1 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 12:15:17 -0400 Subject: [PATCH] quick fix in kernels --- jaxpm/kernels.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 8123ad5..dad13ba 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -33,7 +33,8 @@ def interpolate_power_spectrum(input, k, pk, sharding=None): gpu_mesh = sharding.mesh if sharding is not None else None specs = sharding.spec if sharding is not None else P() - out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh)) + out_specs = P(*get_output_specs( + FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P() return autoshmap(pk_fn, gpu_mesh=gpu_mesh,