diff --git a/map2map/models/__init__.py b/map2map/models/__init__.py index 70779b0..e478742 100644 --- a/map2map/models/__init__.py +++ b/map2map/models/__init__.py @@ -4,6 +4,7 @@ from .pyramid import PyramidNet from .patchgan import PatchGAN, PatchGAN42 from .narrow import narrow_by, narrow_cast, narrow_like +from .resample import resample, Resampler from .lag2eul import Lag2Eul diff --git a/map2map/models/resample.py b/map2map/models/resample.py new file mode 100644 index 0000000..308fc80 --- /dev/null +++ b/map2map/models/resample.py @@ -0,0 +1,46 @@ +import torch.nn as nn +import torch.nn.functional as F + +from .narrow import narrow_by + + +def resample(x, scale_factor, narrow=True): + modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'} + ndim = x.dim() - 2 + mode = modes[ndim] + + x = F.interpolate(x, scale_factor=scale_factor, + mode=mode, align_corners=False) + + if scale_factor > 1 and narrow == True: + edges = round(scale_factor) // 2 + edges = max(edges, 1) + x = narrow_by(x, edges) + + return x + + +class Resampler(nn.Module): + """Resampling, upsampling or downsampling. + + By default discard the inaccurate edges when upsampling. + """ + def __init__(self, ndim, scale_factor, narrow=True): + super().__init__() + + modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'} + self.mode = modes[ndim] + + self.scale_factor = scale_factor + self.narrow = narrow + + def forward(self, x): + x = F.interpolate(x, scale_factor=self.scale_factor, + mode=self.mode, align_corners=False) + + if self.scale_factor > 1 and self.narrow == True: + edges = round(self.scale_factor) // 2 + edges = max(edges, 1) + x = narrow_by(x, edges) + + return x diff --git a/map2map/train.py b/map2map/train.py index d127229..032fcd7 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset, GroupedRandomSampler from .data.figures import plt_slices from . import models -from .models import (narrow_cast, +from .models import (narrow_cast, resample adv_model_wrapper, adv_criterion_wrapper, add_spectral_norm, rm_spectral_norm, InstanceNoise) @@ -338,9 +338,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, print('output.shape =', output.shape) print('target.shape =', target.shape, flush=True) - if hasattr(model, 'scale_factor') and model.scale_factor != 1: - input = F.interpolate(input, - scale_factor=model.scale_factor, mode='nearest') + if (hasattr(model.module, 'scale_factor') + and model.module.scale_factor != 1): + input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) loss = criterion(output, target) @@ -478,9 +478,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, output = model(input) - if hasattr(model, 'scale_factor') and model.scale_factor != 1: - input = F.interpolate(input, - scale_factor=model.scale_factor, mode='nearest') + if (hasattr(model.module, 'scale_factor') + and model.module.scale_factor != 1): + input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) loss = criterion(output, target)