From 79b28561d55e664c4c83ada60bfe302735417478 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Sun, 23 Aug 2020 14:48:02 -0500 Subject: [PATCH] Fix two power spectrum bugs about freq shift and flatten --- map2map/models/power.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/map2map/models/power.py b/map2map/models/power.py index 5447c61..d65acc1 100644 --- a/map2map/models/power.py +++ b/map2map/models/power.py @@ -18,32 +18,30 @@ def power(x): x = torch.rfft(x, signal_ndim) P = x.pow(2).sum(dim=-1) + P = P.mean(dim=0) + P = P.sum(dim=0) del x - batch_ndim = P.dim() - signal_ndim - 1 - if batch_ndim > 0: - P = P.mean(tuple(range(batch_ndim))) - if P.dim() > signal_ndim: - P = P.sum(dim=0) - P = P.flatten() - k = [torch.arange(d, dtype=torch.float32, device=P.device) for d in P.shape] + k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]] k = torch.meshgrid(*k) k = torch.stack(k, dim=0) k = k.norm(p=2, dim=0) - k = k.flatten() N = torch.full_like(P, 2, dtype=torch.int32) N[..., 0] = 1 if even: N[..., -1] = 1 + + k = k.flatten() + P = P.flatten() N = N.flatten() kbin = k.ceil().to(torch.int32) k = torch.bincount(kbin, weights=k * N) P = torch.bincount(kbin, weights=P * N) - N = torch.bincount(kbin, weights=N) + N = torch.bincount(kbin, weights=N).round().to(torch.int32) del kbin # drop k=0 mode and cut at kmax (smallest Nyquist)