Fix two power spectrum bugs about freq shift and flatten
This commit is contained in:
parent
6c67eaa788
commit
79b28561d5
@ -18,32 +18,30 @@ def power(x):
|
|||||||
|
|
||||||
x = torch.rfft(x, signal_ndim)
|
x = torch.rfft(x, signal_ndim)
|
||||||
P = x.pow(2).sum(dim=-1)
|
P = x.pow(2).sum(dim=-1)
|
||||||
|
P = P.mean(dim=0)
|
||||||
|
P = P.sum(dim=0)
|
||||||
del x
|
del x
|
||||||
|
|
||||||
batch_ndim = P.dim() - signal_ndim - 1
|
|
||||||
if batch_ndim > 0:
|
|
||||||
P = P.mean(tuple(range(batch_ndim)))
|
|
||||||
if P.dim() > signal_ndim:
|
|
||||||
P = P.sum(dim=0)
|
|
||||||
P = P.flatten()
|
|
||||||
|
|
||||||
k = [torch.arange(d, dtype=torch.float32, device=P.device)
|
k = [torch.arange(d, dtype=torch.float32, device=P.device)
|
||||||
for d in P.shape]
|
for d in P.shape]
|
||||||
|
k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
|
||||||
k = torch.meshgrid(*k)
|
k = torch.meshgrid(*k)
|
||||||
k = torch.stack(k, dim=0)
|
k = torch.stack(k, dim=0)
|
||||||
k = k.norm(p=2, dim=0)
|
k = k.norm(p=2, dim=0)
|
||||||
k = k.flatten()
|
|
||||||
|
|
||||||
N = torch.full_like(P, 2, dtype=torch.int32)
|
N = torch.full_like(P, 2, dtype=torch.int32)
|
||||||
N[..., 0] = 1
|
N[..., 0] = 1
|
||||||
if even:
|
if even:
|
||||||
N[..., -1] = 1
|
N[..., -1] = 1
|
||||||
|
|
||||||
|
k = k.flatten()
|
||||||
|
P = P.flatten()
|
||||||
N = N.flatten()
|
N = N.flatten()
|
||||||
|
|
||||||
kbin = k.ceil().to(torch.int32)
|
kbin = k.ceil().to(torch.int32)
|
||||||
k = torch.bincount(kbin, weights=k * N)
|
k = torch.bincount(kbin, weights=k * N)
|
||||||
P = torch.bincount(kbin, weights=P * N)
|
P = torch.bincount(kbin, weights=P * N)
|
||||||
N = torch.bincount(kbin, weights=N)
|
N = torch.bincount(kbin, weights=N).round().to(torch.int32)
|
||||||
del kbin
|
del kbin
|
||||||
|
|
||||||
# drop k=0 mode and cut at kmax (smallest Nyquist)
|
# drop k=0 mode and cut at kmax (smallest Nyquist)
|
||||||
|
Loading…
Reference in New Issue
Block a user