diff --git a/map2map/models/power.py b/map2map/models/power.py index 9b72d2a..81fb7ac 100644 --- a/map2map/models/power.py +++ b/map2map/models/power.py @@ -19,10 +19,11 @@ def power(x): try: x = torch.fft.rfftn(x, s=signal_size) # new version broke BC + P = x.real.square() + x.imag.square() except AttributeError: x = torch.rfft(x, signal_ndim) + P = x.square().sum(dim=-1) - P = x.pow(2).sum(dim=-1) P = P.mean(dim=0) P = P.sum(dim=0) del x