Fix two power spectrum bugs about freq shift and flatten
This commit is contained in:
parent
6c67eaa788
commit
79b28561d5
@ -18,32 +18,30 @@ def power(x):
|
||||
|
||||
x = torch.rfft(x, signal_ndim)
|
||||
P = x.pow(2).sum(dim=-1)
|
||||
del x
|
||||
|
||||
batch_ndim = P.dim() - signal_ndim - 1
|
||||
if batch_ndim > 0:
|
||||
P = P.mean(tuple(range(batch_ndim)))
|
||||
if P.dim() > signal_ndim:
|
||||
P = P.mean(dim=0)
|
||||
P = P.sum(dim=0)
|
||||
P = P.flatten()
|
||||
del x
|
||||
|
||||
k = [torch.arange(d, dtype=torch.float32, device=P.device)
|
||||
for d in P.shape]
|
||||
k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
|
||||
k = torch.meshgrid(*k)
|
||||
k = torch.stack(k, dim=0)
|
||||
k = k.norm(p=2, dim=0)
|
||||
k = k.flatten()
|
||||
|
||||
N = torch.full_like(P, 2, dtype=torch.int32)
|
||||
N[..., 0] = 1
|
||||
if even:
|
||||
N[..., -1] = 1
|
||||
|
||||
k = k.flatten()
|
||||
P = P.flatten()
|
||||
N = N.flatten()
|
||||
|
||||
kbin = k.ceil().to(torch.int32)
|
||||
k = torch.bincount(kbin, weights=k * N)
|
||||
P = torch.bincount(kbin, weights=P * N)
|
||||
N = torch.bincount(kbin, weights=N)
|
||||
N = torch.bincount(kbin, weights=N).round().to(torch.int32)
|
||||
del kbin
|
||||
|
||||
# drop k=0 mode and cut at kmax (smallest Nyquist)
|
||||
|
Loading…
Reference in New Issue
Block a user