Add linear resamplers

This commit is contained in:
Yin Li 2020-07-11 01:37:23 -04:00
parent c956dca795
commit 873258d8a7
3 changed files with 54 additions and 7 deletions

View File

@ -4,6 +4,7 @@ from .pyramid import PyramidNet
from .patchgan import PatchGAN, PatchGAN42
from .narrow import narrow_by, narrow_cast, narrow_like
from .resample import resample, Resampler
from .lag2eul import Lag2Eul

View File

@ -0,0 +1,46 @@
import torch.nn as nn
import torch.nn.functional as F
from .narrow import narrow_by
def resample(x, scale_factor, narrow=True):
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
ndim = x.dim() - 2
mode = modes[ndim]
x = F.interpolate(x, scale_factor=scale_factor,
mode=mode, align_corners=False)
if scale_factor > 1 and narrow == True:
edges = round(scale_factor) // 2
edges = max(edges, 1)
x = narrow_by(x, edges)
return x
class Resampler(nn.Module):
"""Resampling, upsampling or downsampling.
By default discard the inaccurate edges when upsampling.
"""
def __init__(self, ndim, scale_factor, narrow=True):
super().__init__()
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
self.mode = modes[ndim]
self.scale_factor = scale_factor
self.narrow = narrow
def forward(self, x):
x = F.interpolate(x, scale_factor=self.scale_factor,
mode=self.mode, align_corners=False)
if self.scale_factor > 1 and self.narrow == True:
edges = round(self.scale_factor) // 2
edges = max(edges, 1)
x = narrow_by(x, edges)
return x

View File

@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler
from .data.figures import plt_slices
from . import models
from .models import (narrow_cast,
from .models import (narrow_cast, resample
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
@ -338,9 +338,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
print('output.shape =', output.shape)
print('target.shape =', target.shape, flush=True)
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target)
@ -478,9 +478,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
output = model(input)
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target)