Fix again rfft call broken by pytorch breaking BC

This commit is contained in:
Yin Li 2021-04-22 13:32:13 -04:00
parent 795cefe38e
commit 5e4633b125

View File

@ -19,10 +19,11 @@ def power(x):
try:
x = torch.fft.rfftn(x, s=signal_size) # new version broke BC
P = x.real.square() + x.imag.square()
except AttributeError:
x = torch.rfft(x, signal_ndim)
P = x.square().sum(dim=-1)
P = x.pow(2).sum(dim=-1)
P = P.mean(dim=0)
P = P.sum(dim=0)
del x