Fix rfft call broken by pytorch breaking BC
This commit is contained in:
parent
4898724ac1
commit
002e249925
@ -13,10 +13,15 @@ def power(x):
|
||||
frequency of the input.
|
||||
"""
|
||||
signal_ndim = x.dim() - 2
|
||||
kmax = min(d for d in x.shape[-signal_ndim:]) // 2
|
||||
signal_size = x.shape[-signal_ndim:]
|
||||
kmax = min(s for s in signal_size) // 2
|
||||
even = x.shape[-1] % 2 == 0
|
||||
|
||||
x = torch.rfft(x, signal_ndim)
|
||||
try:
|
||||
x = torch.fft.rfftn(x, s=signal_size) # new version broke BC
|
||||
except AttributeError:
|
||||
x = torch.rfft(x, signal_ndim)
|
||||
|
||||
P = x.pow(2).sum(dim=-1)
|
||||
P = P.mean(dim=0)
|
||||
P = P.sum(dim=0)
|
||||
|
Loading…
Reference in New Issue
Block a user