Add narrow_cast and narrow_by
This commit is contained in:
parent
28ec245a7a
commit
c956dca795
@ -3,7 +3,7 @@ from .vnet import VNet, VNetFat
|
|||||||
from .pyramid import PyramidNet
|
from .pyramid import PyramidNet
|
||||||
from .patchgan import PatchGAN, PatchGAN42
|
from .patchgan import PatchGAN, PatchGAN42
|
||||||
|
|
||||||
from .conv import narrow_like
|
from .narrow import narrow_by, narrow_cast, narrow_like
|
||||||
|
|
||||||
from .lag2eul import Lag2Eul
|
from .lag2eul import Lag2Eul
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .narrow import narrow_like
|
||||||
from .swish import Swish
|
from .swish import Swish
|
||||||
|
|
||||||
|
|
||||||
@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
|
|||||||
x = self.act(x)
|
x = self.act(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def narrow_like(a, b):
|
|
||||||
"""Narrow a to be like b.
|
|
||||||
|
|
||||||
Try to be symmetric but cut more on the right for odd difference,
|
|
||||||
consistent with the downsampling.
|
|
||||||
"""
|
|
||||||
for d in range(2, a.dim()):
|
|
||||||
width = a.shape[d] - b.shape[d]
|
|
||||||
half_width = width // 2
|
|
||||||
a = a.narrow(d, half_width, a.shape[d] - width)
|
|
||||||
return a
|
|
||||||
|
43
map2map/models/narrow.py
Normal file
43
map2map/models/narrow.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def narrow_by(a, c):
|
||||||
|
"""Narrow a by size c symmetrically on all edges.
|
||||||
|
"""
|
||||||
|
for d in range(2, a.dim()):
|
||||||
|
a = a.narrow(d, c, a.shape[d] - 2 * c)
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
def narrow_cast(*tensors):
|
||||||
|
"""Narrow each tensor to the minimum length in each dimension.
|
||||||
|
|
||||||
|
Try to be symmetric but cut more on the right for odd difference
|
||||||
|
"""
|
||||||
|
dim_max = max(a.dim() for a in tensors)
|
||||||
|
|
||||||
|
len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
|
||||||
|
|
||||||
|
casted_tensors = []
|
||||||
|
for a in tensors:
|
||||||
|
for d in range(2, dim_max):
|
||||||
|
width = a.shape[d] - len_min[d]
|
||||||
|
half_width = width // 2
|
||||||
|
a = a.narrow(d, half_width, a.shape[d] - width)
|
||||||
|
|
||||||
|
casted_tensors.append(a)
|
||||||
|
|
||||||
|
return casted_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def narrow_like(a, b):
|
||||||
|
"""Narrow a to be like b.
|
||||||
|
|
||||||
|
Try to be symmetric but cut more on the right for odd difference
|
||||||
|
"""
|
||||||
|
for d in range(2, a.dim()):
|
||||||
|
width = a.shape[d] - b.shape[d]
|
||||||
|
half_width = width // 2
|
||||||
|
a = a.narrow(d, half_width, a.shape[d] - width)
|
||||||
|
return a
|
@ -1,7 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .conv import ConvBlock, narrow_like
|
from .conv import ConvBlock
|
||||||
|
from .narrow import narrow_like
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .conv import ConvBlock, ResBlock, narrow_like
|
from .conv import ConvBlock, ResBlock
|
||||||
|
from .narrow import narrow_like
|
||||||
|
|
||||||
|
|
||||||
class VNet(nn.Module):
|
class VNet(nn.Module):
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
|
import sys
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import sys
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_like
|
from .models import narrow_cast
|
||||||
from .utils import import_attr, load_model_state_dict
|
from .utils import import_attr, load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
@ -58,12 +58,7 @@ def test(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i, (input, target) in enumerate(test_loader):
|
for i, (input, target) in enumerate(test_loader):
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if args.pad > 0: # FIXME
|
input, output, target = narrow_cast(input, output, target)
|
||||||
output = narrow_like(output, target)
|
|
||||||
input = narrow_like(input, target)
|
|
||||||
else:
|
|
||||||
target = narrow_like(target, output)
|
|
||||||
input = narrow_like(input, output)
|
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from .data import FieldDataset, GroupedRandomSampler
|
from .data import FieldDataset, GroupedRandomSampler
|
||||||
from .data.figures import plt_slices
|
from .data.figures import plt_slices
|
||||||
from . import models
|
from . import models
|
||||||
from .models import (narrow_like,
|
from .models import (narrow_cast,
|
||||||
adv_model_wrapper, adv_criterion_wrapper,
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
add_spectral_norm, rm_spectral_norm,
|
add_spectral_norm, rm_spectral_norm,
|
||||||
InstanceNoise)
|
InstanceNoise)
|
||||||
@ -333,12 +333,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
target = target.to(device, non_blocking=True)
|
target = target.to(device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
|
if epoch == 0 and i == 0 and rank == 0:
|
||||||
|
print('input.shape =', input.shape)
|
||||||
|
print('output.shape =', output.shape)
|
||||||
|
print('target.shape =', target.shape, flush=True)
|
||||||
|
|
||||||
target = narrow_like(target, output) # FIXME pad
|
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
input = narrow_like(input, output)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
@ -475,11 +478,10 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
||||||
target = narrow_like(target, output) # FIXME pad
|
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
input = narrow_like(input, output)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
Loading…
Reference in New Issue
Block a user