mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
quick fix in kernels
This commit is contained in:
parent
4d944f01f2
commit
a8b194f326
1 changed files with 2 additions and 1 deletions
|
@ -33,7 +33,8 @@ def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||||
|
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
specs = sharding.spec if sharding is not None else P()
|
specs = sharding.spec if sharding is not None else P()
|
||||||
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
|
out_specs = P(*get_output_specs(
|
||||||
|
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
|
||||||
|
|
||||||
return autoshmap(pk_fn,
|
return autoshmap(pk_fn,
|
||||||
gpu_mesh=gpu_mesh,
|
gpu_mesh=gpu_mesh,
|
||||||
|
|
Loading…
Add table
Reference in a new issue