Move UNet with ResBlock to VNet and Revert UNet to the previous simple version

Yin Li 2019-12-09 21:53:27 -05:00
parent 9e039c0407
commit 7b6ff73be1
11 changed files with 90 additions and 40 deletions


@@ -18,6 +18,7 @@ def get_args():
 def add_common_args(parser):
     parser.add_argument('--norms', type=str_list, help='comma-sep. list '
             'of normalization functions from data.norms')
+    parser.add_argument('--model', required=True, help='model from models')
     parser.add_argument('--criterion', default='MSELoss',
             help='model criterion from torch.nn')
     parser.add_argument('--load-state', default='', type=str,
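
Note: --model is added with required=True, so every existing invocation must now name a model explicitly (the job script hunks below add --model VNet). A minimal, self-contained sketch of the new flag on its own, with a hard-coded argument list purely for illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, help='model from models')

# e.g. the updated job scripts below pass: --model VNet
args = parser.parse_args(['--model', 'VNet'])
print(args.model)  # 'VNet'; the train/test hunks below look this name up
                   # via models.__dict__[args.model]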


@@ -1,2 +1,3 @@
 from .unet import UNet
+from .vnet import VNet
 from .conv import narrow_like


@@ -57,22 +57,24 @@ class ConvBlock(nn.Module):
 class ResBlock(ConvBlock):
     """Residual convolution blocks of the form specified by `seq`. Input is
-    added to the residual followed by an activation.
+    added to the residual followed by an activation (trailing `'A'` in `seq`).
     """
     def __init__(self, in_channels, out_channels=None, mid_channels=None,
-            seq='CBACB'):
-        if 'U' in seq or 'D' in seq:
-            raise NotImplementedError('upsample and downsample layers '
-                    'not supported yet')
+            seq='CBACBA'):
         if out_channels is None:
             out_channels = in_channels
             self.skip = None
         else:
             self.skip = nn.Conv3d(in_channels, out_channels, 1)

+        if 'U' in seq or 'D' in seq:
+            raise NotImplementedError('upsample and downsample layers '
+                    'not supported yet')
+
+        assert seq[-1] == 'A', 'block must end with activation'
+
         super().__init__(in_channels, out_channels, mid_channels=mid_channels,
-                seq=seq)
+                seq=seq[:-1])

         self.act = nn.PReLU()
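
The new convention makes the trailing 'A' in seq stand for the activation applied after the skip addition (per the updated docstring), while everything before it is built by ConvBlock. A tiny sketch of how the constructor now splits the default sequence; ConvBlock's internals and ResBlock.forward are outside this diff, so only the split is shown:

seq = 'CBACBA'                              # new default
assert 'U' not in seq and 'D' not in seq    # resampling still unsupported
assert seq[-1] == 'A', 'block must end with activation'
body = seq[:-1]                             # 'CBACB' is what ConvBlock builds
print(body)                                 # the final 'A' becomes self.act = nn.PReLU(),
                                            # applied after the input is added to the residual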


@@ -1,39 +1,24 @@
 import torch
 import torch.nn as nn

-from .conv import ConvBlock, ResBlock, narrow_like
+from .conv import ConvBlock, narrow_like


 class UNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()

-        self.conv_0l = nn.Sequential(
-            ConvBlock(in_channels, 64, seq='CA'),
-            ResBlock(64, seq='CBACBACB'),
-        )
-        self.down_0l = ConvBlock(64, 128, seq='DBA')
-        self.conv_1l = nn.Sequential(
-            ResBlock(128, seq='CBACB'),
-            ResBlock(128, seq='CBACB'),
-        )
-        self.down_1l = ConvBlock(128, 256, seq='DBA')
-
-        self.conv_2c = nn.Sequential(
-            ResBlock(256, seq='CBACB'),
-            ResBlock(256, seq='CBACB'),
-        )
-
-        self.up_1r = ConvBlock(256, 128, seq='UBA')
-        self.conv_1r = nn.Sequential(
-            ResBlock(256, seq='CBACB'),
-            ResBlock(256, seq='CBACB'),
-        )
-        self.up_0r = ConvBlock(256, 64, seq='UBA')
-        self.conv_0r = nn.Sequential(
-            ResBlock(128, seq='CBACBAC'),
-            ConvBlock(128, out_channels, seq='C')
-        )
+        self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
+        self.down_0l = ConvBlock(64, 64, seq='BADBA')
+        self.conv_1l = ConvBlock(64, 64, seq='CBAC')
+        self.down_1l = ConvBlock(64, 64, seq='BADBA')
+
+        self.conv_2c = ConvBlock(64, 64, seq='CBAC')
+
+        self.up_1r = ConvBlock(64, 64, seq='BAUBA')
+        self.conv_1r = ConvBlock(128, 64, seq='CBAC')
+        self.up_0r = ConvBlock(64, 64, seq='BAUBA')
+        self.conv_0r = ConvBlock(128, out_channels, seq='CAC')

     def forward(self, x):
         y0 = self.conv_0l(x)
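
After the revert, UNet no longer widens the channels at each level: everything stays at 64 features, with resampling folded into the ConvBlock sequences ('D'/'U' inside 'BADBA'/'BAUBA'), and the concatenated skip connections explain the 128 input channels of conv_1r and conv_0r. A channel-count walkthrough, assuming the unchanged forward follows the same concatenation pattern as VNet.forward in the new file below:

# in_channels -> conv_0l ('CAC')   -> 64    (saved as skip y0)
# 64          -> down_0l ('BADBA') -> 64
# 64          -> conv_1l ('CBAC')  -> 64    (saved as skip y1)
# 64          -> down_1l ('BADBA') -> 64
# 64          -> conv_2c ('CBAC')  -> 64
# 64          -> up_1r   ('BAUBA') -> 64, cat with y1 -> 128
# 128         -> conv_1r ('CBAC')  -> 64
# 64          -> up_0r   ('BAUBA') -> 64, cat with y0 -> 128
# 128         -> conv_0r ('CAC')   -> out_channels
# (spatial sizes shrink along the way if ConvBlock's convolutions are unpadded,
#  which is why the skips go through narrow_like before concatenation)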

map2map/models/vnet.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+
+from .conv import ConvBlock, ResBlock, narrow_like
+
+
+class VNet(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.conv_0l = nn.Sequential(
+            ConvBlock(in_channels, 64, seq='CA'),
+            ResBlock(64, seq='CBACBACBA'),
+        )
+        self.down_0l = ConvBlock(64, 128, seq='DBA')
+        self.conv_1l = nn.Sequential(
+            ResBlock(128, seq='CBACBA'),
+            ResBlock(128, seq='CBACBA'),
+        )
+        self.down_1l = ConvBlock(128, 256, seq='DBA')
+
+        self.conv_2c = nn.Sequential(
+            ResBlock(256, seq='CBACBA'),
+            ResBlock(256, seq='CBACBA'),
+        )
+
+        self.up_1r = ConvBlock(256, 128, seq='UBA')
+        self.conv_1r = nn.Sequential(
+            ResBlock(256, seq='CBACBA'),
+            ResBlock(256, seq='CBACBA'),
+        )
+        self.up_0r = ConvBlock(256, 64, seq='UBA')
+        self.conv_0r = nn.Sequential(
+            ResBlock(128, seq='CBACBACA'),
+            ConvBlock(128, out_channels, seq='C')
+        )
+
+    def forward(self, x):
+        y0 = self.conv_0l(x)
+        x = self.down_0l(y0)
+        y1 = self.conv_1l(x)
+        x = self.down_1l(y1)
+
+        x = self.conv_2c(x)
+
+        x = self.up_1r(x)
+        y1 = narrow_like(y1, x)
+        x = torch.cat([y1, x], dim=1)
+        del y1
+        x = self.conv_1r(x)
+
+        x = self.up_0r(x)
+        y0 = narrow_like(y0, x)
+        x = torch.cat([y0, x], dim=1)
+        del y0
+        x = self.conv_0r(x)
+
+        return x
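
A minimal usage sketch for the new class (assumes the map2map package is importable; the input size is illustrative only, and the output is expected to be spatially narrower than the input since the skips are cropped with narrow_like, which suggests unpadded convolutions):

import torch
from map2map.models import VNet  # exported via models/__init__.py above

model = VNet(3, 3)                    # e.g. a 3-component field in and out
x = torch.randn(1, 3, 128, 128, 128)  # one 3D box; size chosen for illustration
with torch.no_grad():
    y = model(x)
print(y.shape)                        # (1, 3, d, d, d) with d < 128 if convs are unpadded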


@@ -3,7 +3,8 @@ import torch
 from torch.utils.data import DataLoader

 from .data import FieldDataset
-from .models import UNet, narrow_like
+from . import models
+from .models import narrow_like


 def test(args):
@@ -23,7 +24,7 @@ def test(args):
     in_channels, out_channels = test_dataset.channels

-    model = UNet(in_channels, out_channels)
+    model = models.__dict__[args.model](in_channels, out_channels)
     criterion = torch.nn.__dict__[args.criterion]()

     device = torch.device('cpu')


@@ -10,7 +10,8 @@ from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

 from .data import FieldDataset
-from .models import UNet, narrow_like
+from . import models
+from .models import narrow_like


 def node_worker(args):
@@ -82,7 +83,7 @@ def gpu_worker(local_rank, args):
     in_channels, out_channels = train_dataset.channels

-    model = UNet(in_channels, out_channels)
+    model = models.__dict__[args.model](in_channels, out_channels)
     model.to(args.device)
     model = DistributedDataParallel(model, device_ids=[args.device])
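
Both entry points now resolve the network by name instead of hard-coding UNet, so adding a model only requires exporting it from models/__init__.py. A hedged sketch of what the lookup amounts to (placeholder channel counts; an unknown name simply raises KeyError):

from map2map import models       # package import assumed, mirroring the hunks above

in_channels, out_channels = 3, 3  # placeholders for illustration

# equivalent to the old hard-coded UNet(in_channels, out_channels) when
# args.model == 'UNet'; passing --model VNet selects the new network instead
model = models.__dict__['VNet'](in_channels, out_channels)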


@@ -37,7 +37,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.dis \
+    --norms cosmology.dis --model VNet \
     --batches 1 --loader-workers 0 --pad-or-crop 40 \
     --load-state best_model.pth


@@ -42,7 +42,7 @@ srun m2m.py train \
     --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
-    --norms cosmology.dis --augment \
+    --norms cosmology.dis --augment --model VNet \
     --epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
     # --load-state checkpoint.pth


@@ -37,7 +37,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.vel \
+    --norms cosmology.vel --model VNet \
    --batches 1 --loader-workers 0 --pad-or-crop 40 \
     --load-state best_model.pth


@@ -42,7 +42,7 @@ srun m2m.py train \
     --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
-    --norms cosmology.vel --augment \
+    --norms cosmology.vel --augment --model VNet \
     --epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
     # --load-state checkpoint.pth