Add narrow_cast and narrow_by

This commit is contained in:
Yin Li 2020-07-11 01:24:20 -04:00
parent 28ec245a7a
commit c956dca795
7 changed files with 59 additions and 29 deletions

View File

@ -3,7 +3,7 @@ from .vnet import VNet, VNetFat
from .pyramid import PyramidNet
from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like
from .narrow import narrow_by, narrow_cast, narrow_like
from .lag2eul import Lag2Eul

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from .narrow import narrow_like
from .swish import Swish
@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
x = self.act(x)
return x
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference,
consistent with the downsampling.
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

43
map2map/models/narrow.py Normal file
View File

@ -0,0 +1,43 @@
import torch
import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
for d in range(2, a.dim()):
a = a.narrow(d, c, a.shape[d] - 2 * c)
return a
def narrow_cast(*tensors):
"""Narrow each tensor to the minimum length in each dimension.
Try to be symmetric but cut more on the right for odd difference
"""
dim_max = max(a.dim() for a in tensors)
len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
casted_tensors = []
for a in tensors:
for d in range(2, dim_max):
width = a.shape[d] - len_min[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
casted_tensors.append(a)
return casted_tensors
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

View File

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, narrow_like
from .conv import ConvBlock
from .narrow import narrow_like
class UNet(nn.Module):

View File

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like
from .conv import ConvBlock, ResBlock
from .narrow import narrow_like
class VNet(nn.Module):

View File

@ -1,12 +1,12 @@
import sys
from pprint import pprint
import numpy as np
import torch
import sys
from torch.utils.data import DataLoader
from .data import FieldDataset
from . import models
from .models import narrow_like
from .models import narrow_cast
from .utils import import_attr, load_model_state_dict
@ -58,12 +58,7 @@ def test(args):
with torch.no_grad():
for i, (input, target) in enumerate(test_loader):
output = model(input)
if args.pad > 0: # FIXME
output = narrow_like(output, target)
input = narrow_like(input, target)
else:
target = narrow_like(target, output)
input = narrow_like(input, output)
input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target)

View File

@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler
from .data.figures import plt_slices
from . import models
from .models import (narrow_like,
from .models import (narrow_cast,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
@ -333,12 +333,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
target = target.to(device, non_blocking=True)
output = model(input)
if epoch == 0 and i == 0 and rank == 0:
print('input.shape =', input.shape)
print('output.shape =', output.shape)
print('target.shape =', target.shape, flush=True)
target = narrow_like(target, output) # FIXME pad
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target)
epoch_loss[0] += loss.item()
@ -475,11 +478,10 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
output = model(input)
target = narrow_like(target, output) # FIXME pad
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target)
epoch_loss[0] += loss.item()