Fix two power spectrum bugs about freq shift and flatten

This commit is contained in:
Yin Li 2020-08-23 14:48:02 -05:00
parent 6c67eaa788
commit 79b28561d5

View file

@ -18,32 +18,30 @@ def power(x):
x = torch.rfft(x, signal_ndim)
P = x.pow(2).sum(dim=-1)
P = P.mean(dim=0)
P = P.sum(dim=0)
del x
batch_ndim = P.dim() - signal_ndim - 1
if batch_ndim > 0:
P = P.mean(tuple(range(batch_ndim)))
if P.dim() > signal_ndim:
P = P.sum(dim=0)
P = P.flatten()
k = [torch.arange(d, dtype=torch.float32, device=P.device)
for d in P.shape]
k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
k = torch.meshgrid(*k)
k = torch.stack(k, dim=0)
k = k.norm(p=2, dim=0)
k = k.flatten()
N = torch.full_like(P, 2, dtype=torch.int32)
N[..., 0] = 1
if even:
N[..., -1] = 1
k = k.flatten()
P = P.flatten()
N = N.flatten()
kbin = k.ceil().to(torch.int32)
k = torch.bincount(kbin, weights=k * N)
P = torch.bincount(kbin, weights=P * N)
N = torch.bincount(kbin, weights=N)
N = torch.bincount(kbin, weights=N).round().to(torch.int32)
del kbin
# drop k=0 mode and cut at kmax (smallest Nyquist)