quick fix in kernels

This commit is contained in:
Wassim KABALAN 2024-10-22 12:15:17 -04:00
parent 4d944f01f2
commit a8b194f326

View file

@ -33,7 +33,8 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
out_specs = P(*get_output_specs(
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
return autoshmap(pk_fn,
gpu_mesh=gpu_mesh,