Fix rfft call broken by pytorch breaking BC
This commit is contained in:
parent
4898724ac1
commit
002e249925
@ -13,10 +13,15 @@ def power(x):
|
|||||||
frequency of the input.
|
frequency of the input.
|
||||||
"""
|
"""
|
||||||
signal_ndim = x.dim() - 2
|
signal_ndim = x.dim() - 2
|
||||||
kmax = min(d for d in x.shape[-signal_ndim:]) // 2
|
signal_size = x.shape[-signal_ndim:]
|
||||||
|
kmax = min(s for s in signal_size) // 2
|
||||||
even = x.shape[-1] % 2 == 0
|
even = x.shape[-1] % 2 == 0
|
||||||
|
|
||||||
x = torch.rfft(x, signal_ndim)
|
try:
|
||||||
|
x = torch.fft.rfftn(x, s=signal_size) # new version broke BC
|
||||||
|
except AttributeError:
|
||||||
|
x = torch.rfft(x, signal_ndim)
|
||||||
|
|
||||||
P = x.pow(2).sum(dim=-1)
|
P = x.pow(2).sum(dim=-1)
|
||||||
P = P.mean(dim=0)
|
P = P.mean(dim=0)
|
||||||
P = P.sum(dim=0)
|
P = P.sum(dim=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user