Move UNet with ResBlock to VNet and Revert UNet to the previous simple version
This commit is contained in:
parent
9e039c0407
commit
7b6ff73be1
@ -18,6 +18,7 @@ def get_args():
|
||||
def add_common_args(parser):
|
||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||
'of normalization functions from data.norms')
|
||||
parser.add_argument('--model', required=True, help='model from models')
|
||||
parser.add_argument('--criterion', default='MSELoss',
|
||||
help='model criterion from torch.nn')
|
||||
parser.add_argument('--load-state', default='', type=str,
|
||||
|
@ -1,2 +1,3 @@
|
||||
from .unet import UNet
|
||||
from .vnet import VNet
|
||||
from .conv import narrow_like
|
||||
|
@ -57,22 +57,24 @@ class ConvBlock(nn.Module):
|
||||
|
||||
class ResBlock(ConvBlock):
|
||||
"""Residual convolution blocks of the form specified by `seq`. Input is
|
||||
added to the residual followed by an activation.
|
||||
added to the residual followed by an activation (trailing `'A'` in `seq`).
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels=None, mid_channels=None,
|
||||
seq='CBACB'):
|
||||
if 'U' in seq or 'D' in seq:
|
||||
raise NotImplementedError('upsample and downsample layers '
|
||||
'not supported yet')
|
||||
|
||||
seq='CBACBA'):
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
self.skip = None
|
||||
else:
|
||||
self.skip = nn.Conv3d(in_channels, out_channels, 1)
|
||||
|
||||
if 'U' in seq or 'D' in seq:
|
||||
raise NotImplementedError('upsample and downsample layers '
|
||||
'not supported yet')
|
||||
|
||||
assert seq[-1] == 'A', 'block must end with activation'
|
||||
|
||||
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
|
||||
seq=seq)
|
||||
seq=seq[:-1])
|
||||
|
||||
self.act = nn.PReLU()
|
||||
|
||||
|
@ -1,39 +1,24 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv import ConvBlock, ResBlock, narrow_like
|
||||
from .conv import ConvBlock, narrow_like
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self.conv_0l = nn.Sequential(
|
||||
ConvBlock(in_channels, 64, seq='CA'),
|
||||
ResBlock(64, seq='CBACBACB'),
|
||||
)
|
||||
self.down_0l = ConvBlock(64, 128, seq='DBA')
|
||||
self.conv_1l = nn.Sequential(
|
||||
ResBlock(128, seq='CBACB'),
|
||||
ResBlock(128, seq='CBACB'),
|
||||
)
|
||||
self.down_1l = ConvBlock(128, 256, seq='DBA')
|
||||
self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
|
||||
self.down_0l = ConvBlock(64, 64, seq='BADBA')
|
||||
self.conv_1l = ConvBlock(64, 64, seq='CBAC')
|
||||
self.down_1l = ConvBlock(64, 64, seq='BADBA')
|
||||
|
||||
self.conv_2c = nn.Sequential(
|
||||
ResBlock(256, seq='CBACB'),
|
||||
ResBlock(256, seq='CBACB'),
|
||||
)
|
||||
self.conv_2c = ConvBlock(64, 64, seq='CBAC')
|
||||
|
||||
self.up_1r = ConvBlock(256, 128, seq='UBA')
|
||||
self.conv_1r = nn.Sequential(
|
||||
ResBlock(256, seq='CBACB'),
|
||||
ResBlock(256, seq='CBACB'),
|
||||
)
|
||||
self.up_0r = ConvBlock(256, 64, seq='UBA')
|
||||
self.conv_0r = nn.Sequential(
|
||||
ResBlock(128, seq='CBACBAC'),
|
||||
ConvBlock(128, out_channels, seq='C')
|
||||
)
|
||||
self.up_1r = ConvBlock(64, 64, seq='BAUBA')
|
||||
self.conv_1r = ConvBlock(128, 64, seq='CBAC')
|
||||
self.up_0r = ConvBlock(64, 64, seq='BAUBA')
|
||||
self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
|
||||
|
||||
def forward(self, x):
|
||||
y0 = self.conv_0l(x)
|
||||
|
59
map2map/models/vnet.py
Normal file
59
map2map/models/vnet.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv import ConvBlock, ResBlock, narrow_like
|
||||
|
||||
|
||||
class VNet(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self.conv_0l = nn.Sequential(
|
||||
ConvBlock(in_channels, 64, seq='CA'),
|
||||
ResBlock(64, seq='CBACBACBA'),
|
||||
)
|
||||
self.down_0l = ConvBlock(64, 128, seq='DBA')
|
||||
self.conv_1l = nn.Sequential(
|
||||
ResBlock(128, seq='CBACBA'),
|
||||
ResBlock(128, seq='CBACBA'),
|
||||
)
|
||||
self.down_1l = ConvBlock(128, 256, seq='DBA')
|
||||
|
||||
self.conv_2c = nn.Sequential(
|
||||
ResBlock(256, seq='CBACBA'),
|
||||
ResBlock(256, seq='CBACBA'),
|
||||
)
|
||||
|
||||
self.up_1r = ConvBlock(256, 128, seq='UBA')
|
||||
self.conv_1r = nn.Sequential(
|
||||
ResBlock(256, seq='CBACBA'),
|
||||
ResBlock(256, seq='CBACBA'),
|
||||
)
|
||||
self.up_0r = ConvBlock(256, 64, seq='UBA')
|
||||
self.conv_0r = nn.Sequential(
|
||||
ResBlock(128, seq='CBACBACA'),
|
||||
ConvBlock(128, out_channels, seq='C')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y0 = self.conv_0l(x)
|
||||
x = self.down_0l(y0)
|
||||
|
||||
y1 = self.conv_1l(x)
|
||||
x = self.down_1l(y1)
|
||||
|
||||
x = self.conv_2c(x)
|
||||
|
||||
x = self.up_1r(x)
|
||||
y1 = narrow_like(y1, x)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_1r(x)
|
||||
|
||||
x = self.up_0r(x)
|
||||
y0 = narrow_like(y0, x)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_0r(x)
|
||||
|
||||
return x
|
@ -3,7 +3,8 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .data import FieldDataset
|
||||
from .models import UNet, narrow_like
|
||||
from . import models
|
||||
from .models import narrow_like
|
||||
|
||||
|
||||
def test(args):
|
||||
@ -23,7 +24,7 @@ def test(args):
|
||||
|
||||
in_channels, out_channels = test_dataset.channels
|
||||
|
||||
model = UNet(in_channels, out_channels)
|
||||
model = models.__dict__[args.model](in_channels, out_channels)
|
||||
criterion = torch.nn.__dict__[args.criterion]()
|
||||
|
||||
device = torch.device('cpu')
|
||||
|
@ -10,7 +10,8 @@ from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .data import FieldDataset
|
||||
from .models import UNet, narrow_like
|
||||
from . import models
|
||||
from .models import narrow_like
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
@ -82,7 +83,7 @@ def gpu_worker(local_rank, args):
|
||||
|
||||
in_channels, out_channels = train_dataset.channels
|
||||
|
||||
model = UNet(in_channels, out_channels)
|
||||
model = models.__dict__[args.model](in_channels, out_channels)
|
||||
model.to(args.device)
|
||||
model = DistributedDataParallel(model, device_ids=[args.device])
|
||||
|
||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
||||
m2m.py test \
|
||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||
--norms cosmology.dis \
|
||||
--norms cosmology.dis --model VNet \
|
||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||
--load-state best_model.pth
|
||||
|
||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--norms cosmology.dis --augment \
|
||||
--norms cosmology.dis --augment --model VNet \
|
||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||
# --load-state checkpoint.pth
|
||||
|
||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
||||
m2m.py test \
|
||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||
--norms cosmology.vel \
|
||||
--norms cosmology.vel --model VNet \
|
||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||
--load-state best_model.pth
|
||||
|
||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--norms cosmology.vel --augment \
|
||||
--norms cosmology.vel --augment --model VNet \
|
||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||
# --load-state checkpoint.pth
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user