Add styled VNet and Fix bugs
Co-authored-by: Drew Jamieson <drew.s.jamieson@gmail.com>
This commit is contained in:
parent
f5bd657625
commit
17d8e95870
5 changed files with 197 additions and 31 deletions
|
@ -46,7 +46,7 @@ class FieldDataset(Dataset):
|
|||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
in_pad=0, tgt_pad=0, scale_factor=1):
|
||||
self.param_files = sorted(param_pattern)
|
||||
self.param_files = sorted(glob(param_pattern))
|
||||
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
@ -61,6 +61,7 @@ class FieldDataset(Dataset):
|
|||
if self.nfile == 0:
|
||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
||||
|
||||
self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
|
||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
for f in self.in_files[0]]
|
||||
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
|
@ -143,7 +144,7 @@ class FieldDataset(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
ifile, icrop = divmod(idx, self.ncrop)
|
||||
|
||||
params = np.loadtxt(self.param_files[idx])
|
||||
params = np.loadtxt(self.param_files[ifile])
|
||||
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
||||
|
||||
|
@ -159,6 +160,7 @@ class FieldDataset(Dataset):
|
|||
self.tgt_pad,
|
||||
self.size * self.scale_factor)
|
||||
|
||||
params = torch.from_numpy(params).to(torch.float32)
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||
|
||||
|
|
|
@ -82,7 +82,7 @@ class ConvElr3d(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class ConvMod3d(nn.Module):
|
||||
class ConvStyled3d(nn.Module):
|
||||
"""Convolution layer with modulation and demodulation, from StyleGAN2.
|
||||
|
||||
Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
|
||||
|
@ -158,6 +158,22 @@ class ConvMod3d(nn.Module):
|
|||
x = x.reshape(1, N * Cin, *DHWin)
|
||||
x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
|
||||
_, _, *DHWout = x.shape
|
||||
x = x.reshape(N, Cout, *DHWout)
|
||||
x = x.reshape(N, Cout, *DHWout)
|
||||
|
||||
return x
|
||||
|
||||
class BatchNormStyled3d(nn.BatchNorm3d):
    """3D batch normalization with a styled-layer call signature.

    The layer behaves exactly like `nn.BatchNorm3d`; the extra style argument
    exists only so that it can be called as ``layer(x, s)`` alongside the
    other styled layers.
    """
    def forward(self, x, s):
        # `s` is accepted for interface uniformity and deliberately ignored.
        normalized = super().forward(x)
        return normalized
|
||||
|
||||
class LeakyReLUStyled(nn.LeakyReLU):
    """Leaky ReLU with a styled-layer call signature.

    The activation is the plain `nn.LeakyReLU`; the extra style argument
    exists only so that it can be called as ``layer(x, s)`` alongside the
    other styled layers.
    """
    def forward(self, x, s):
        # `s` is accepted for interface uniformity and deliberately ignored.
        activated = super().forward(x)
        return activated
|
||||
|
|
141
map2map/models/styled_conv.py
Normal file
141
map2map/models/styled_conv.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
import warnings
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .narrow import narrow_like
|
||||
from .swish import Swish
|
||||
|
||||
from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
|
||||
|
||||
class ConvStyledBlock(nn.Module):
    """Convolution blocks of the form specified by `seq`.

    Every layer in the block is called as ``layer(x, s)``, where `s` is the
    style array handed to `forward`.

    `seq` types:
        'C': convolution specified by `kernel_size` and `stride`
        'B': normalization (to be renamed to 'N')
        'A': activation
        'U': upsampling transposed convolution of kernel size 2 and stride 2
        'D': downsampling convolution of kernel size 2 and stride 2

    Parameters:
        style_size: forwarded to every `ConvStyled3d` as its first argument.
        in_chan: number of input channels of the first convolution.
        out_chan: number of output channels of the last convolution;
            defaults to `in_chan`.
        mid_chan: number of channels between consecutive convolutions;
            defaults to ``max(in_chan, out_chan)``.
        kernel_size, stride: used by the 'C' layers only.
        seq: string of layer codes, built in order.

    Raises:
        NotImplementedError: if `seq` contains an unknown layer code.
    """
    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
                 kernel_size=3, stride=1, seq='CBA'):
        super().__init__()

        if out_chan is None:
            out_chan = in_chan
        if mid_chan is None:
            mid_chan = max(in_chan, out_chan)

        self.style_size = style_size
        self.in_chan = in_chan
        self.out_chan = out_chan
        # Always store mid_chan: it used to be assigned only when the argument
        # was omitted, so an explicit mid_chan made _setup_conv() fail with
        # AttributeError.
        self.mid_chan = mid_chan
        self.kernel_size = kernel_size
        self.stride = stride

        # Bookkeeping consumed by _get_layer()/_setup_conv() while building:
        # norm_chan tracks the channel count a following 'B' layer must match.
        self.norm_chan = in_chan
        self.idx_conv = 0
        self.num_conv = sum(seq.count(l) for l in ('U', 'D', 'C'))

        layers = [self._get_layer(l) for l in seq]

        self.convs = nn.ModuleList(layers)

    def _get_layer(self, l):
        """Build the layer for one `seq` code, advancing the conv bookkeeping."""
        if l == 'U':
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(self.style_size, in_chan, out_chan, 2,
                                stride=2, resample='U')
        elif l == 'D':
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(self.style_size, in_chan, out_chan, 2,
                                stride=2, resample='D')
        elif l == 'C':
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(self.style_size, in_chan, out_chan,
                                self.kernel_size, stride=self.stride)
        elif l == 'B':
            return BatchNormStyled3d(self.norm_chan)
        elif l == 'A':
            return LeakyReLUStyled()
        else:
            raise NotImplementedError('layer type {} not supported'.format(l))

    def _setup_conv(self):
        """Return ``(in_chan, out_chan)`` for the next convolution in `seq`.

        The first convolution reads `self.in_chan`, the last one writes
        `self.out_chan`, and everything in between uses `self.mid_chan`.
        """
        self.idx_conv += 1

        in_chan = out_chan = self.mid_chan
        if self.idx_conv == 1:
            in_chan = self.in_chan
        if self.idx_conv == self.num_conv:
            out_chan = self.out_chan

        # A normalization layer placed after this conv must match its output.
        self.norm_chan = out_chan

        return in_chan, out_chan

    def forward(self, x, s):
        """Apply every layer in order, passing the style `s` to each one."""
        for l in self.convs:
            x = l(x, s)
        return x
|
||||
|
||||
|
||||
class ResStyledBlock(ConvStyledBlock):
    """Residual variant of `ConvStyledBlock`, laid out according to `seq`.

    The block input travels along a skip connection and is summed with the
    residual branch; an optional activation follows the sum.

    When `out_chan` is omitted the skip connection is the identity. Passing
    `out_chan` (even equal to `in_chan`) switches it to a size 1
    "convolution".

    If `seq` ends with `'A'`, that activation runs after the addition when
    `last_act` is true and before it otherwise. `last_act` defaults to
    `seq[-1] == 'A'`.

    See `ConvStyledBlock` for `seq` types.
    """
    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
                 seq='CBACBA', last_act=None):
        if last_act is None:
            last_act = seq[-1] == 'A'
        elif last_act and seq[-1] != 'A':
            # Nothing to move behind the addition, so the request is dropped.
            warnings.warn(
                'Disabling last_act without trailing activation in seq',
                RuntimeWarning,
            )
            last_act = False

        # The trailing activation, when deferred, is held in `self.act`
        # instead of the residual branch.
        branch_seq = seq[:-1] if last_act else seq

        super().__init__(style_size, in_chan, out_chan=out_chan,
                         mid_chan=mid_chan, seq=branch_seq)

        self.act = LeakyReLUStyled() if last_act else None

        self.skip = (None if out_chan is None
                     else ConvStyled3d(style_size, in_chan, out_chan, 1))

        if any(code in branch_seq for code in ('U', 'D')):
            raise NotImplementedError('upsample and downsample layers '
                                      'not supported yet')

    def forward(self, x, s):
        shortcut = x if self.skip is None else self.skip(x, s)

        for layer in self.convs:
            x = layer(x, s)

        # presumably narrow_like crops the shortcut to the residual's size
        # -- confirm against .narrow
        x += narrow_like(shortcut, x)

        if self.act is None:
            return x
        return self.act(x, s)
|
|
@ -1,12 +1,12 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv import ConvBlock, ResBlock
|
||||
from .styled_conv import ConvStyledBlock, ResStyledBlock
|
||||
from .narrow import narrow_by
|
||||
|
||||
|
||||
class StyledVNet(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
|
||||
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
|
||||
"""V-Net like network with styles
|
||||
|
||||
See `vnet.VNet`.
|
||||
|
@ -15,43 +15,43 @@ class StyledVNet(nn.Module):
|
|||
|
||||
# activate non-identity skip connection in residual block
|
||||
# by explicitly setting out_chan
|
||||
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
|
||||
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||
self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
|
||||
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA')
|
||||
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
|
||||
|
||||
self.conv_c = ResBlock(64, 64, seq='CBACBA')
|
||||
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
|
||||
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
|
||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
||||
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
|
||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
||||
|
||||
self.bypass = in_chan == out_chan
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
x0 = x
|
||||
|
||||
y0 = self.conv_l0(x)
|
||||
x = self.down_l0(y0)
|
||||
y0 = self.conv_l0(x, s)
|
||||
x = self.down_l0(y0, s)
|
||||
|
||||
y1 = self.conv_l1(x)
|
||||
x = self.down_l1(y1)
|
||||
y1 = self.conv_l1(x, s)
|
||||
x = self.down_l1(y1, s)
|
||||
|
||||
x = self.conv_c(x)
|
||||
x = self.conv_c(x, s)
|
||||
|
||||
x = self.up_r1(x)
|
||||
x = self.up_r1(x, s)
|
||||
y1 = narrow_by(y1, 4)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x)
|
||||
x = self.conv_r1(x, s)
|
||||
|
||||
x = self.up_r0(x)
|
||||
x = self.up_r0(x, s)
|
||||
y0 = narrow_by(y0, 16)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r0(x)
|
||||
x = self.conv_r0(x, s)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_by(x0, 20)
|
||||
|
|
|
@ -59,6 +59,7 @@ def gpu_worker(local_rank, node, args):
|
|||
dist_init(rank, args)
|
||||
|
||||
train_dataset = FieldDataset(
|
||||
param_pattern=args.train_param_pattern,
|
||||
in_patterns=args.train_in_patterns,
|
||||
tgt_patterns=args.train_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
|
@ -90,6 +91,7 @@ def gpu_worker(local_rank, node, args):
|
|||
|
||||
if args.val:
|
||||
val_dataset = FieldDataset(
|
||||
param_pattern=args.val_param_pattern,
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
|
@ -119,10 +121,12 @@ def gpu_worker(local_rank, node, args):
|
|||
pin_memory=True,
|
||||
)
|
||||
|
||||
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||
args.param_dim = train_dataset.param_dim
|
||||
args.in_chan = train_dataset.in_chan
|
||||
args.out_chan = train_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(sum(args.in_chan), sum(args.out_chan),
|
||||
model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
|
||||
scale_factor=args.scale_factor)
|
||||
model.to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
|
@ -238,14 +242,16 @@ def train(epoch, loader, model, criterion,
|
|||
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
for i, (input, target) in enumerate(loader):
|
||||
for i, (param, input, target) in enumerate(loader):
|
||||
batch = epoch * len(loader) + i + 1
|
||||
|
||||
param = param.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input)
|
||||
output = model(input, param)
|
||||
if batch == 1 and rank == 0:
|
||||
print('param shape :', param.shape)
|
||||
print('input shape :', input.shape)
|
||||
print('output shape :', output.shape)
|
||||
print('target shape :', target.shape)
|
||||
|
@ -330,11 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
for input, target in loader:
|
||||
for param, input, target in loader:
|
||||
param = param.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input)
|
||||
output = model(input, param)
|
||||
|
||||
if (hasattr(model.module, 'scale_factor')
|
||||
and model.module.scale_factor != 1):
|
||||
|
|
Loading…
Reference in a new issue