mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
quick fix in kernels
This commit is contained in:
parent
4d944f01f2
commit
a8b194f326
1 changed files with 2 additions and 1 deletions
|
@ -33,7 +33,8 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
|
|||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
specs = sharding.spec if sharding is not None else P()
|
||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
|
||||
out_specs = P(*get_output_specs(
|
||||
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
|
||||
|
||||
return autoshmap(pk_fn,
|
||||
gpu_mesh=gpu_mesh,
|
||||
|
|
Loading…
Add table
Reference in a new issue