Move UNet with ResBlock to VNet and Revert UNet to the previous simple version

This commit is contained in:
Yin Li 2019-12-09 21:53:27 -05:00
parent 9e039c0407
commit 7b6ff73be1
11 changed files with 90 additions and 40 deletions

View file

@ -18,6 +18,7 @@ def get_args():
def add_common_args(parser):
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
'of normalization functions from data.norms')
parser.add_argument('--model', required=True, help='model from models')
parser.add_argument('--criterion', default='MSELoss',
help='model criterion from torch.nn')
parser.add_argument('--load-state', default='', type=str,

View file

@ -1,2 +1,3 @@
from .unet import UNet
from .vnet import VNet
from .conv import narrow_like

View file

@ -57,22 +57,24 @@ class ConvBlock(nn.Module):
class ResBlock(ConvBlock):
"""Residual convolution blocks of the form specified by `seq`. Input is
added to the residual followed by an activation.
added to the residual followed by an activation (trailing `'A'` in `seq`).
"""
def __init__(self, in_channels, out_channels=None, mid_channels=None,
seq='CBACB'):
if 'U' in seq or 'D' in seq:
raise NotImplementedError('upsample and downsample layers '
'not supported yet')
seq='CBACBA'):
if out_channels is None:
out_channels = in_channels
self.skip = None
else:
self.skip = nn.Conv3d(in_channels, out_channels, 1)
if 'U' in seq or 'D' in seq:
raise NotImplementedError('upsample and downsample layers '
'not supported yet')
assert seq[-1] == 'A', 'block must end with activation'
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
seq=seq)
seq=seq[:-1])
self.act = nn.PReLU()

View file

@ -1,39 +1,24 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like
from .conv import ConvBlock, narrow_like
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv_0l = nn.Sequential(
ConvBlock(in_channels, 64, seq='CA'),
ResBlock(64, seq='CBACBACB'),
)
self.down_0l = ConvBlock(64, 128, seq='DBA')
self.conv_1l = nn.Sequential(
ResBlock(128, seq='CBACB'),
ResBlock(128, seq='CBACB'),
)
self.down_1l = ConvBlock(128, 256, seq='DBA')
self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
self.down_0l = ConvBlock(64, 64, seq='BADBA')
self.conv_1l = ConvBlock(64, 64, seq='CBAC')
self.down_1l = ConvBlock(64, 64, seq='BADBA')
self.conv_2c = nn.Sequential(
ResBlock(256, seq='CBACB'),
ResBlock(256, seq='CBACB'),
)
self.conv_2c = ConvBlock(64, 64, seq='CBAC')
self.up_1r = ConvBlock(256, 128, seq='UBA')
self.conv_1r = nn.Sequential(
ResBlock(256, seq='CBACB'),
ResBlock(256, seq='CBACB'),
)
self.up_0r = ConvBlock(256, 64, seq='UBA')
self.conv_0r = nn.Sequential(
ResBlock(128, seq='CBACBAC'),
ConvBlock(128, out_channels, seq='C')
)
self.up_1r = ConvBlock(64, 64, seq='BAUBA')
self.conv_1r = ConvBlock(128, 64, seq='CBAC')
self.up_0r = ConvBlock(64, 64, seq='BAUBA')
self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
def forward(self, x):
y0 = self.conv_0l(x)

59
map2map/models/vnet.py Normal file
View file

@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like
class VNet(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv_0l = nn.Sequential(
ConvBlock(in_channels, 64, seq='CA'),
ResBlock(64, seq='CBACBACBA'),
)
self.down_0l = ConvBlock(64, 128, seq='DBA')
self.conv_1l = nn.Sequential(
ResBlock(128, seq='CBACBA'),
ResBlock(128, seq='CBACBA'),
)
self.down_1l = ConvBlock(128, 256, seq='DBA')
self.conv_2c = nn.Sequential(
ResBlock(256, seq='CBACBA'),
ResBlock(256, seq='CBACBA'),
)
self.up_1r = ConvBlock(256, 128, seq='UBA')
self.conv_1r = nn.Sequential(
ResBlock(256, seq='CBACBA'),
ResBlock(256, seq='CBACBA'),
)
self.up_0r = ConvBlock(256, 64, seq='UBA')
self.conv_0r = nn.Sequential(
ResBlock(128, seq='CBACBACA'),
ConvBlock(128, out_channels, seq='C')
)
def forward(self, x):
y0 = self.conv_0l(x)
x = self.down_0l(y0)
y1 = self.conv_1l(x)
x = self.down_1l(y1)
x = self.conv_2c(x)
x = self.up_1r(x)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_1r(x)
x = self.up_0r(x)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_0r(x)
return x

View file

@ -3,7 +3,8 @@ import torch
from torch.utils.data import DataLoader
from .data import FieldDataset
from .models import UNet, narrow_like
from . import models
from .models import narrow_like
def test(args):
@ -23,7 +24,7 @@ def test(args):
in_channels, out_channels = test_dataset.channels
model = UNet(in_channels, out_channels)
model = models.__dict__[args.model](in_channels, out_channels)
criterion = torch.nn.__dict__[args.criterion]()
device = torch.device('cpu')

View file

@ -10,7 +10,8 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset
from .models import UNet, narrow_like
from . import models
from .models import narrow_like
def node_worker(args):
@ -82,7 +83,7 @@ def gpu_worker(local_rank, args):
in_channels, out_channels = train_dataset.channels
model = UNet(in_channels, out_channels)
model = models.__dict__[args.model](in_channels, out_channels)
model.to(args.device)
model = DistributedDataParallel(model, device_ids=[args.device])

View file

@ -37,7 +37,7 @@ tgt_files="$files"
m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--norms cosmology.dis \
--norms cosmology.dis --model VNet \
--batches 1 --loader-workers 0 --pad-or-crop 40 \
--load-state best_model.pth

View file

@ -42,7 +42,7 @@ srun m2m.py train \
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--norms cosmology.dis --augment \
--norms cosmology.dis --augment --model VNet \
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
# --load-state checkpoint.pth

View file

@ -37,7 +37,7 @@ tgt_files="$files"
m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--norms cosmology.vel \
--norms cosmology.vel --model VNet \
--batches 1 --loader-workers 0 --pad-or-crop 40 \
--load-state best_model.pth

View file

@ -42,7 +42,7 @@ srun m2m.py train \
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--norms cosmology.vel --augment \
--norms cosmology.vel --augment --model VNet \
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
# --load-state checkpoint.pth