mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Shared operation in fourrier space now take inverted sharding axis for
slabs
This commit is contained in:
parent
8c5bd76c33
commit
75604d2396
2 changed files with 23 additions and 8 deletions
|
@ -46,7 +46,7 @@ def fftk(shape, dtype=np.float32):
|
|||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x'), P('y'), P(None)),
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)))
|
||||
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True)
|
||||
def get_kvec(ky, kz, kx):
|
||||
return (ky.reshape([-1, 1, 1]),
|
||||
kz.reshape([1, -1, 1]),
|
||||
|
@ -73,7 +73,7 @@ def interpolate_power_spectrum(input, k, pk):
|
|||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input)
|
||||
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue