Add power spectrum module
This commit is contained in:
parent
5d22594ede
commit
afaf4675fe
2 changed files with 58 additions and 0 deletions
|
@ -6,6 +6,7 @@ from .narrow import narrow_by, narrow_cast, narrow_like
|
|||
from .resample import resample, Resampler
|
||||
|
||||
from .lag2eul import Lag2Eul
|
||||
from .power import power
|
||||
|
||||
from .dice import DiceLoss, dice_loss
|
||||
|
||||
|
|
57
map2map/models/power.py
Normal file
57
map2map/models/power.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
|
||||
from .lag2eul import lag2eul
|
||||
|
||||
|
||||
def power(x):
|
||||
"""Compute power spectra of input fields
|
||||
|
||||
Each field should have batch and channel dimensions followed by spatial
|
||||
dimensions. Powers are summed over channels, and averaged over batches.
|
||||
|
||||
Power is not normalized. Wavevectors are in unit of the fundamental
|
||||
frequency of the input.
|
||||
"""
|
||||
signal_ndim = x.dim() - 2
|
||||
kmax = min(d for d in x.shape[-signal_ndim:]) // 2
|
||||
even = x.shape[-1] % 2 == 0
|
||||
|
||||
x = torch.rfft(x, signal_ndim)
|
||||
P = x.pow(2).sum(dim=-1)
|
||||
del x
|
||||
|
||||
batch_ndim = P.dim() - signal_ndim - 1
|
||||
if batch_ndim > 0:
|
||||
P = P.mean(tuple(range(batch_ndim)))
|
||||
if P.dim() > signal_ndim:
|
||||
P = P.sum(dim=0)
|
||||
P = P.flatten()
|
||||
|
||||
k = [torch.arange(d, dtype=torch.float32, device=P.device)
|
||||
for d in P.shape]
|
||||
k = torch.meshgrid(*k)
|
||||
k = torch.stack(k, dim=0)
|
||||
k = k.norm(p=2, dim=0)
|
||||
k = k.flatten()
|
||||
|
||||
N = torch.full_like(P, 2, dtype=torch.int32)
|
||||
N[..., 0] = 1
|
||||
if even:
|
||||
N[..., -1] = 1
|
||||
N = N.flatten()
|
||||
|
||||
kbin = k.ceil().to(torch.int32)
|
||||
k = torch.bincount(kbin, weights=k * N)
|
||||
P = torch.bincount(kbin, weights=P * N)
|
||||
N = torch.bincount(kbin, weights=N)
|
||||
del kbin
|
||||
|
||||
# drop k=0 mode and cut at kmax (smallest Nyquist)
|
||||
k = k[1:1+kmax]
|
||||
P = P[1:1+kmax]
|
||||
N = N[1:1+kmax]
|
||||
|
||||
k /= N
|
||||
P /= N
|
||||
|
||||
return k, P, N
|
Loading…
Reference in a new issue