Add narrow_cast and narrow_by
commit c956dca795
parent 28ec245a7a
@@ -3,7 +3,7 @@ from .vnet import VNet, VNetFat
 from .pyramid import PyramidNet
 from .patchgan import PatchGAN, PatchGAN42

-from .conv import narrow_like
+from .narrow import narrow_by, narrow_cast, narrow_like

 from .lag2eul import Lag2Eul
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn

+from .narrow import narrow_like
 from .swish import Swish
@@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
         x = self.act(x)

         return x
-
-
-def narrow_like(a, b):
-    """Narrow a to be like b.
-
-    Try to be symmetric but cut more on the right for odd difference,
-    consistent with the downsampling.
-    """
-    for d in range(2, a.dim()):
-        width = a.shape[d] - b.shape[d]
-        half_width = width // 2
-        a = a.narrow(d, half_width, a.shape[d] - width)
-    return a
map2map/models/narrow.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+
+
+def narrow_by(a, c):
+    """Narrow a by size c symmetrically on all edges.
+    """
+    for d in range(2, a.dim()):
+        a = a.narrow(d, c, a.shape[d] - 2 * c)
+    return a
+
+
+def narrow_cast(*tensors):
+    """Narrow each tensor to the minimum length in each dimension.
+
+    Try to be symmetric but cut more on the right for odd difference
+    """
+    dim_max = max(a.dim() for a in tensors)
+
+    len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
+
+    casted_tensors = []
+    for a in tensors:
+        for d in range(2, dim_max):
+            width = a.shape[d] - len_min[d]
+            half_width = width // 2
+            a = a.narrow(d, half_width, a.shape[d] - width)
+
+        casted_tensors.append(a)
+
+    return casted_tensors
+
+
+def narrow_like(a, b):
+    """Narrow a to be like b.
+
+    Try to be symmetric but cut more on the right for odd difference
+    """
+    for d in range(2, a.dim()):
+        width = a.shape[d] - b.shape[d]
+        half_width = width // 2
+        a = a.narrow(d, half_width, a.shape[d] - width)
+    return a
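For reference, a minimal usage sketch of the three helpers added here, on made-up 5-d tensors (batch, channel, and three spatial dimensions). The shapes are illustrative only, and the import assumes the map2map package is on the path:

import torch

from map2map.models import narrow_by, narrow_cast, narrow_like

a = torch.zeros(1, 2, 10, 10, 10)  # dummy (N, C, D, H, W) tensor
b = torch.zeros(1, 2, 7, 7, 7)

# narrow_by trims c cells from every spatial edge: 10 - 2*2 = 6
print(narrow_by(a, 2).shape)    # torch.Size([1, 2, 6, 6, 6])

# narrow_like trims a to b's spatial shape; the odd difference of 3
# cuts 1 cell on the left and 2 on the right
print(narrow_like(a, b).shape)  # torch.Size([1, 2, 7, 7, 7])

# narrow_cast trims every tensor to the minimum length per dimension
x, y = narrow_cast(a, b)
print(x.shape, y.shape)         # both torch.Size([1, 2, 7, 7, 7])

narrow_cast returns a list, so unpacking it as above mirrors how test.py and train.py consume it below.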
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn

-from .conv import ConvBlock, narrow_like
+from .conv import ConvBlock
+from .narrow import narrow_like


 class UNet(nn.Module):
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn

-from .conv import ConvBlock, ResBlock, narrow_like
+from .conv import ConvBlock, ResBlock
+from .narrow import narrow_like


 class VNet(nn.Module):
@@ -1,12 +1,12 @@
+import sys
 from pprint import pprint
 import numpy as np
 import torch
-import sys
 from torch.utils.data import DataLoader

 from .data import FieldDataset
 from . import models
-from .models import narrow_like
+from .models import narrow_cast
 from .utils import import_attr, load_model_state_dict
@@ -58,12 +58,7 @@ def test(args):
     with torch.no_grad():
         for i, (input, target) in enumerate(test_loader):
             output = model(input)
-            if args.pad > 0:  # FIXME
-                output = narrow_like(output, target)
-                input = narrow_like(input, target)
-            else:
-                target = narrow_like(target, output)
-                input = narrow_like(input, output)
+            input, output, target = narrow_cast(input, output, target)

             loss = criterion(output, target)
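A rough sketch (invented shapes, not from the commit) of what the single narrow_cast call above does: it trims input, output, and target to their common spatial size, covering both branches of the removed if/else in one step:

import torch

from map2map.models import narrow_cast

# pretend the padded input and target are larger than the prediction
input = torch.zeros(1, 3, 68, 68, 68)
output = torch.zeros(1, 3, 64, 64, 64)
target = torch.zeros(1, 3, 68, 68, 68)

input, output, target = narrow_cast(input, output, target)
# all three now share the minimum spatial size, so the criterion
# sees matching shapes whichever tensor happened to be largest
assert input.shape == output.shape == target.shape == (1, 3, 64, 64, 64)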
@@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
 from .data import FieldDataset, GroupedRandomSampler
 from .data.figures import plt_slices
 from . import models
-from .models import (narrow_like,
+from .models import (narrow_cast,
                      adv_model_wrapper, adv_criterion_wrapper,
                      add_spectral_norm, rm_spectral_norm,
                      InstanceNoise)
@@ -333,12 +333,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         target = target.to(device, non_blocking=True)

         output = model(input)
+        if epoch == 0 and i == 0 and rank == 0:
+            print('input.shape =', input.shape)
+            print('output.shape =', output.shape)
+            print('target.shape =', target.shape, flush=True)

-        target = narrow_like(target, output)  # FIXME pad
         if hasattr(model, 'scale_factor') and model.scale_factor != 1:
             input = F.interpolate(input,
                 scale_factor=model.scale_factor, mode='nearest')
-        input = narrow_like(input, output)
+        input, output, target = narrow_cast(input, output, target)

         loss = criterion(output, target)
         epoch_loss[0] += loss.item()
@@ -475,11 +478,10 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
             output = model(input)

-            target = narrow_like(target, output)  # FIXME pad
             if hasattr(model, 'scale_factor') and model.scale_factor != 1:
                 input = F.interpolate(input,
                     scale_factor=model.scale_factor, mode='nearest')
-            input = narrow_like(input, output)
+            input, output, target = narrow_cast(input, output, target)

             loss = criterion(output, target)
             epoch_loss[0] += loss.item()
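The same call also covers the upsampling case handled by the scale_factor branch kept above: the input is first interpolated to the output resolution, then narrow_cast aligns everything. A sketch with made-up sizes:

import torch
import torch.nn.functional as F

from map2map.models import narrow_cast

input = torch.zeros(1, 3, 32, 32, 32)    # low-resolution input
output = torch.zeros(1, 3, 60, 60, 60)   # model output shrunk by valid convolutions
target = torch.zeros(1, 3, 64, 64, 64)   # high-resolution target

# bring the input to the output resolution before narrowing
input = F.interpolate(input, scale_factor=2, mode='nearest')  # now 64^3

input, output, target = narrow_cast(input, output, target)
assert input.shape == output.shape == target.shape == (1, 3, 60, 60, 60)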