Add linear resamplers
This commit is contained in:
parent
c956dca795
commit
873258d8a7
@ -4,6 +4,7 @@ from .pyramid import PyramidNet
|
|||||||
from .patchgan import PatchGAN, PatchGAN42
|
from .patchgan import PatchGAN, PatchGAN42
|
||||||
|
|
||||||
from .narrow import narrow_by, narrow_cast, narrow_like
|
from .narrow import narrow_by, narrow_cast, narrow_like
|
||||||
|
from .resample import resample, Resampler
|
||||||
|
|
||||||
from .lag2eul import Lag2Eul
|
from .lag2eul import Lag2Eul
|
||||||
|
|
||||||
|
46
map2map/models/resample.py
Normal file
46
map2map/models/resample.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .narrow import narrow_by
|
||||||
|
|
||||||
|
|
||||||
|
def resample(x, scale_factor, narrow=True):
|
||||||
|
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
|
||||||
|
ndim = x.dim() - 2
|
||||||
|
mode = modes[ndim]
|
||||||
|
|
||||||
|
x = F.interpolate(x, scale_factor=scale_factor,
|
||||||
|
mode=mode, align_corners=False)
|
||||||
|
|
||||||
|
if scale_factor > 1 and narrow == True:
|
||||||
|
edges = round(scale_factor) // 2
|
||||||
|
edges = max(edges, 1)
|
||||||
|
x = narrow_by(x, edges)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Resampler(nn.Module):
|
||||||
|
"""Resampling, upsampling or downsampling.
|
||||||
|
|
||||||
|
By default discard the inaccurate edges when upsampling.
|
||||||
|
"""
|
||||||
|
def __init__(self, ndim, scale_factor, narrow=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
|
||||||
|
self.mode = modes[ndim]
|
||||||
|
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.narrow = narrow
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.interpolate(x, scale_factor=self.scale_factor,
|
||||||
|
mode=self.mode, align_corners=False)
|
||||||
|
|
||||||
|
if self.scale_factor > 1 and self.narrow == True:
|
||||||
|
edges = round(self.scale_factor) // 2
|
||||||
|
edges = max(edges, 1)
|
||||||
|
x = narrow_by(x, edges)
|
||||||
|
|
||||||
|
return x
|
@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from .data import FieldDataset, GroupedRandomSampler
|
from .data import FieldDataset, GroupedRandomSampler
|
||||||
from .data.figures import plt_slices
|
from .data.figures import plt_slices
|
||||||
from . import models
|
from . import models
|
||||||
from .models import (narrow_cast,
|
from .models import (narrow_cast, resample
|
||||||
adv_model_wrapper, adv_criterion_wrapper,
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
add_spectral_norm, rm_spectral_norm,
|
add_spectral_norm, rm_spectral_norm,
|
||||||
InstanceNoise)
|
InstanceNoise)
|
||||||
@ -338,9 +338,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
print('output.shape =', output.shape)
|
print('output.shape =', output.shape)
|
||||||
print('target.shape =', target.shape, flush=True)
|
print('target.shape =', target.shape, flush=True)
|
||||||
|
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if (hasattr(model.module, 'scale_factor')
|
||||||
input = F.interpolate(input,
|
and model.module.scale_factor != 1):
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
@ -478,9 +478,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if (hasattr(model.module, 'scale_factor')
|
||||||
input = F.interpolate(input,
|
and model.module.scale_factor != 1):
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
Loading…
Reference in New Issue
Block a user