Move the ResBlock-based UNet to VNet and revert UNet to the previous simple version
This commit is contained in:
parent
9e039c0407
commit
7b6ff73be1
@ -18,6 +18,7 @@ def get_args():
|
|||||||
def add_common_args(parser):
|
def add_common_args(parser):
|
||||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||||
'of normalization functions from data.norms')
|
'of normalization functions from data.norms')
|
||||||
|
parser.add_argument('--model', required=True, help='model from models')
|
||||||
parser.add_argument('--criterion', default='MSELoss',
|
parser.add_argument('--criterion', default='MSELoss',
|
||||||
help='model criterion from torch.nn')
|
help='model criterion from torch.nn')
|
||||||
parser.add_argument('--load-state', default='', type=str,
|
parser.add_argument('--load-state', default='', type=str,
|
||||||
|
@ -1,2 +1,3 @@
|
|||||||
from .unet import UNet
|
from .unet import UNet
|
||||||
|
from .vnet import VNet
|
||||||
from .conv import narrow_like
|
from .conv import narrow_like
|
||||||
|
@ -57,22 +57,24 @@ class ConvBlock(nn.Module):
|
|||||||
|
|
||||||
class ResBlock(ConvBlock):
|
class ResBlock(ConvBlock):
|
||||||
"""Residual convolution blocks of the form specified by `seq`. Input is
|
"""Residual convolution blocks of the form specified by `seq`. Input is
|
||||||
added to the residual followed by an activation.
|
added to the residual followed by an activation (trailing `'A'` in `seq`).
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_channels, out_channels=None, mid_channels=None,
|
def __init__(self, in_channels, out_channels=None, mid_channels=None,
|
||||||
seq='CBACB'):
|
seq='CBACBA'):
|
||||||
if 'U' in seq or 'D' in seq:
|
|
||||||
raise NotImplementedError('upsample and downsample layers '
|
|
||||||
'not supported yet')
|
|
||||||
|
|
||||||
if out_channels is None:
|
if out_channels is None:
|
||||||
out_channels = in_channels
|
out_channels = in_channels
|
||||||
self.skip = None
|
self.skip = None
|
||||||
else:
|
else:
|
||||||
self.skip = nn.Conv3d(in_channels, out_channels, 1)
|
self.skip = nn.Conv3d(in_channels, out_channels, 1)
|
||||||
|
|
||||||
|
if 'U' in seq or 'D' in seq:
|
||||||
|
raise NotImplementedError('upsample and downsample layers '
|
||||||
|
'not supported yet')
|
||||||
|
|
||||||
|
assert seq[-1] == 'A', 'block must end with activation'
|
||||||
|
|
||||||
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
|
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
|
||||||
seq=seq)
|
seq=seq[:-1])
|
||||||
|
|
||||||
self.act = nn.PReLU()
|
self.act = nn.PReLU()
|
||||||
|
|
||||||
|
@ -1,39 +1,24 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .conv import ConvBlock, ResBlock, narrow_like
|
from .conv import ConvBlock, narrow_like
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv_0l = nn.Sequential(
|
self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
|
||||||
ConvBlock(in_channels, 64, seq='CA'),
|
self.down_0l = ConvBlock(64, 64, seq='BADBA')
|
||||||
ResBlock(64, seq='CBACBACB'),
|
self.conv_1l = ConvBlock(64, 64, seq='CBAC')
|
||||||
)
|
self.down_1l = ConvBlock(64, 64, seq='BADBA')
|
||||||
self.down_0l = ConvBlock(64, 128, seq='DBA')
|
|
||||||
self.conv_1l = nn.Sequential(
|
|
||||||
ResBlock(128, seq='CBACB'),
|
|
||||||
ResBlock(128, seq='CBACB'),
|
|
||||||
)
|
|
||||||
self.down_1l = ConvBlock(128, 256, seq='DBA')
|
|
||||||
|
|
||||||
self.conv_2c = nn.Sequential(
|
self.conv_2c = ConvBlock(64, 64, seq='CBAC')
|
||||||
ResBlock(256, seq='CBACB'),
|
|
||||||
ResBlock(256, seq='CBACB'),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.up_1r = ConvBlock(256, 128, seq='UBA')
|
self.up_1r = ConvBlock(64, 64, seq='BAUBA')
|
||||||
self.conv_1r = nn.Sequential(
|
self.conv_1r = ConvBlock(128, 64, seq='CBAC')
|
||||||
ResBlock(256, seq='CBACB'),
|
self.up_0r = ConvBlock(64, 64, seq='BAUBA')
|
||||||
ResBlock(256, seq='CBACB'),
|
self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
|
||||||
)
|
|
||||||
self.up_0r = ConvBlock(256, 64, seq='UBA')
|
|
||||||
self.conv_0r = nn.Sequential(
|
|
||||||
ResBlock(128, seq='CBACBAC'),
|
|
||||||
ConvBlock(128, out_channels, seq='C')
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y0 = self.conv_0l(x)
|
y0 = self.conv_0l(x)
|
||||||
|
59
map2map/models/vnet.py
Normal file
59
map2map/models/vnet.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .conv import ConvBlock, ResBlock, narrow_like
|
||||||
|
|
||||||
|
|
||||||
|
class VNet(nn.Module):
    """Encoder/decoder network with two skip connections, built from
    `ConvBlock` and `ResBlock` stages.

    The left (encoder) path runs at 64, 128 and then 256 channels; the right
    (decoder) path concatenates each upsampled tensor with the matching
    encoder output along the channel axis before further processing.

    Layers inside every stage are selected by the `seq` string handed to
    `ConvBlock` / `ResBlock`.
    NOTE(review): the letter meanings ('C' conv, 'B' batch norm,
    'A' activation, 'D' downsample, 'U' upsample) are presumed from the
    block definitions in `.conv` -- confirm there.

    Submodule attribute names and their construction order determine the
    `state_dict` keys, so keep both stable when editing.
    """

    def __init__(self, in_channels, out_channels):
        """Create all stages.

        in_channels: channel count of the input tensor.
        out_channels: channel count of the returned tensor.
        """
        super().__init__()

        # ---- level 0, left: lift the input to 64 channels, then refine ----
        stem = ConvBlock(in_channels, 64, seq='CA')
        self.conv_0l = nn.Sequential(stem, ResBlock(64, seq='CBACBACBA'))
        self.down_0l = ConvBlock(64, 128, seq='DBA')

        # ---- level 1, left: two residual blocks at 128 channels ----
        self.conv_1l = nn.Sequential(
            *(ResBlock(128, seq='CBACBA') for _ in range(2))
        )
        self.down_1l = ConvBlock(128, 256, seq='DBA')

        # ---- level 2, centre: two residual blocks at 256 channels ----
        self.conv_2c = nn.Sequential(
            *(ResBlock(256, seq='CBACBA') for _ in range(2))
        )

        # ---- level 1, right: 128 upsampled + 128 skip = 256 channels ----
        self.up_1r = ConvBlock(256, 128, seq='UBA')
        self.conv_1r = nn.Sequential(
            *(ResBlock(256, seq='CBACBA') for _ in range(2))
        )

        # ---- level 0, right: 64 upsampled + 64 skip = 128 channels ----
        self.up_0r = ConvBlock(256, 64, seq='UBA')
        head = ResBlock(128, seq='CBACBACA')
        # A bare convolution maps the 128 channels to the requested output.
        self.conv_0r = nn.Sequential(
            head,
            ConvBlock(128, out_channels, seq='C'),
        )

    def forward(self, x):
        """Run the network on `x` and return the result of `conv_0r`.

        NOTE(review): `narrow_like(a, b)` presumably crops `a` to the
        spatial size of `b` so the two can be concatenated -- confirm in
        `.conv`.
        """
        # Encoder path; keep the per-level outputs for the skip connections.
        skip0 = self.conv_0l(x)
        x = self.down_0l(skip0)

        skip1 = self.conv_1l(x)
        x = self.down_1l(skip1)

        x = self.conv_2c(x)

        # Decoder level 1: upsample, join with the level-1 skip on dim 1.
        x = self.up_1r(x)
        skip1 = narrow_like(skip1, x)
        x = torch.cat([skip1, x], dim=1)
        del skip1  # drop the local reference once it has been concatenated
        x = self.conv_1r(x)

        # Decoder level 0: same pattern with the level-0 skip.
        x = self.up_0r(x)
        skip0 = narrow_like(skip0, x)
        x = torch.cat([skip0, x], dim=1)
        del skip0
        x = self.conv_0r(x)

        return x
|
@ -3,7 +3,8 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from .models import UNet, narrow_like
|
from . import models
|
||||||
|
from .models import narrow_like
|
||||||
|
|
||||||
|
|
||||||
def test(args):
|
def test(args):
|
||||||
@ -23,7 +24,7 @@ def test(args):
|
|||||||
|
|
||||||
in_channels, out_channels = test_dataset.channels
|
in_channels, out_channels = test_dataset.channels
|
||||||
|
|
||||||
model = UNet(in_channels, out_channels)
|
model = models.__dict__[args.model](in_channels, out_channels)
|
||||||
criterion = torch.nn.__dict__[args.criterion]()
|
criterion = torch.nn.__dict__[args.criterion]()
|
||||||
|
|
||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
|
@ -10,7 +10,8 @@ from torch.utils.data import DataLoader
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from .models import UNet, narrow_like
|
from . import models
|
||||||
|
from .models import narrow_like
|
||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
@ -82,7 +83,7 @@ def gpu_worker(local_rank, args):
|
|||||||
|
|
||||||
in_channels, out_channels = train_dataset.channels
|
in_channels, out_channels = train_dataset.channels
|
||||||
|
|
||||||
model = UNet(in_channels, out_channels)
|
model = models.__dict__[args.model](in_channels, out_channels)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model = DistributedDataParallel(model, device_ids=[args.device])
|
model = DistributedDataParallel(model, device_ids=[args.device])
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
|||||||
m2m.py test \
|
m2m.py test \
|
||||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--norms cosmology.dis \
|
--norms cosmology.dis --model VNet \
|
||||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||||
--load-state best_model.pth
|
--load-state best_model.pth
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
|||||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--norms cosmology.dis --augment \
|
--norms cosmology.dis --augment --model VNet \
|
||||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
|||||||
m2m.py test \
|
m2m.py test \
|
||||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--norms cosmology.vel \
|
--norms cosmology.vel --model VNet \
|
||||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||||
--load-state best_model.pth
|
--load-state best_model.pth
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
|||||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--norms cosmology.vel --augment \
|
--norms cosmology.vel --augment --model VNet \
|
||||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user