Add narrow_cast and narrow_by

This commit is contained in:
Yin Li 2020-07-11 01:24:20 -04:00
parent 28ec245a7a
commit c956dca795
7 changed files with 59 additions and 29 deletions

View File

@ -3,7 +3,7 @@ from .vnet import VNet, VNetFat
from .pyramid import PyramidNet from .pyramid import PyramidNet
from .patchgan import PatchGAN, PatchGAN42 from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like from .narrow import narrow_by, narrow_cast, narrow_like
from .lag2eul import Lag2Eul from .lag2eul import Lag2Eul

View File

@ -1,6 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .narrow import narrow_like
from .swish import Swish from .swish import Swish
@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
x = self.act(x) x = self.act(x)
return x return x
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference,
consistent with the downsampling.
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

43
map2map/models/narrow.py Normal file
View File

@ -0,0 +1,43 @@
import torch
import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
for d in range(2, a.dim()):
a = a.narrow(d, c, a.shape[d] - 2 * c)
return a
def narrow_cast(*tensors):
"""Narrow each tensor to the minimum length in each dimension.
Try to be symmetric but cut more on the right for odd difference
"""
dim_max = max(a.dim() for a in tensors)
len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
casted_tensors = []
for a in tensors:
for d in range(2, dim_max):
width = a.shape[d] - len_min[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
casted_tensors.append(a)
return casted_tensors
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .conv import ConvBlock, narrow_like from .conv import ConvBlock
from .narrow import narrow_like
class UNet(nn.Module): class UNet(nn.Module):

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like from .conv import ConvBlock, ResBlock
from .narrow import narrow_like
class VNet(nn.Module): class VNet(nn.Module):

View File

@ -1,12 +1,12 @@
import sys
from pprint import pprint from pprint import pprint
import numpy as np import numpy as np
import torch import torch
import sys
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from .data import FieldDataset from .data import FieldDataset
from . import models from . import models
from .models import narrow_like from .models import narrow_cast
from .utils import import_attr, load_model_state_dict from .utils import import_attr, load_model_state_dict
@ -58,12 +58,7 @@ def test(args):
with torch.no_grad(): with torch.no_grad():
for i, (input, target) in enumerate(test_loader): for i, (input, target) in enumerate(test_loader):
output = model(input) output = model(input)
if args.pad > 0: # FIXME input, output, target = narrow_cast(input, output, target)
output = narrow_like(output, target)
input = narrow_like(input, target)
else:
target = narrow_like(target, output)
input = narrow_like(input, output)
loss = criterion(output, target) loss = criterion(output, target)

View File

@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler from .data import FieldDataset, GroupedRandomSampler
from .data.figures import plt_slices from .data.figures import plt_slices
from . import models from . import models
from .models import (narrow_like, from .models import (narrow_cast,
adv_model_wrapper, adv_criterion_wrapper, adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm, add_spectral_norm, rm_spectral_norm,
InstanceNoise) InstanceNoise)
@ -333,12 +333,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
target = target.to(device, non_blocking=True) target = target.to(device, non_blocking=True)
output = model(input) output = model(input)
if epoch == 0 and i == 0 and rank == 0:
print('input.shape =', input.shape)
print('output.shape =', output.shape)
print('target.shape =', target.shape, flush=True)
target = narrow_like(target, output) # FIXME pad
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input, input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest') scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output) input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target) loss = criterion(output, target)
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()
@ -475,11 +478,10 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
output = model(input) output = model(input)
target = narrow_like(target, output) # FIXME pad
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input, input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest') scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output) input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target) loss = criterion(output, target)
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()