Fix rfft call broken by pytorch breaking BC

This commit is contained in:
Yin Li 2021-04-12 12:43:29 -04:00
parent 4898724ac1
commit 002e249925

View File

@ -13,10 +13,15 @@ def power(x):
frequency of the input. frequency of the input.
""" """
signal_ndim = x.dim() - 2 signal_ndim = x.dim() - 2
kmax = min(d for d in x.shape[-signal_ndim:]) // 2 signal_size = x.shape[-signal_ndim:]
kmax = min(s for s in signal_size) // 2
even = x.shape[-1] % 2 == 0 even = x.shape[-1] % 2 == 0
try:
x = torch.fft.rfftn(x, s=signal_size) # new version broke BC
except AttributeError:
x = torch.rfft(x, signal_ndim) x = torch.rfft(x, signal_ndim)
P = x.pow(2).sum(dim=-1) P = x.pow(2).sum(dim=-1)
P = P.mean(dim=0) P = P.mean(dim=0)
P = P.sum(dim=0) P = P.sum(dim=0)