Add styled VNet and Fix bugs

Co-authored-by: Drew Jamieson <drew.s.jamieson@gmail.com>
Author: Yin Li, 2021-03-16 14:31:29 -07:00
parent f5bd657625
commit 17d8e95870
5 changed files with 197 additions and 31 deletions


@@ -46,7 +46,7 @@ class FieldDataset(Dataset):
                  augment=False, aug_shift=None, aug_add=None, aug_mul=None,
                  crop=None, crop_start=None, crop_stop=None, crop_step=None,
                  in_pad=0, tgt_pad=0, scale_factor=1):
-        self.param_files = sorted(param_pattern)
+        self.param_files = sorted(glob(param_pattern))
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
 
@@ -61,6 +61,7 @@ class FieldDataset(Dataset):
         if self.nfile == 0:
             raise FileNotFoundError('file not found for {}'.format(in_patterns))
 
+        self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
         self.in_chan = [np.load(f, mmap_mode='r').shape[0]
                         for f in self.in_files[0]]
         self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
@@ -143,7 +144,7 @@ class FieldDataset(Dataset):
     def __getitem__(self, idx):
         ifile, icrop = divmod(idx, self.ncrop)
 
-        params = np.loadtxt(self.param_files[idx])
+        params = np.loadtxt(self.param_files[ifile])
         in_fields = [np.load(f) for f in self.in_files[ifile]]
         tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
@@ -159,6 +160,7 @@ class FieldDataset(Dataset):
                           self.tgt_pad,
                           self.size * self.scale_factor)
 
+        params = torch.from_numpy(params).to(torch.float32)
         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
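Note on the two fixes above: `FieldDataset` enumerates crops rather than files, so a flat sample index must be split into a file index and a crop index before any per-file lookup. Loading `param_files[idx]` instead of `param_files[ifile]` would index past the end of the file list whenever there is more than one crop per file. A minimal standalone sketch of that indexing, with illustrative numbers:

    # Sketch: flat sample index -> (file, crop), as in __getitem__ above.
    nfile, ncrop = 2, 8
    for idx in range(nfile * ncrop):
        ifile, icrop = divmod(idx, ncrop)
        # per-file lookups (params, field paths) must use ifile, not idx:
        # idx ranges over nfile * ncrop samples, but only nfile files exist
        assert ifile < nfile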


@@ -82,7 +82,7 @@ class ConvElr3d(nn.Module):
         return x
 
 
-class ConvMod3d(nn.Module):
+class ConvStyled3d(nn.Module):
     """Convolution layer with modulation and demodulation, from StyleGAN2.
 
     Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
@@ -158,6 +158,22 @@ class ConvMod3d(nn.Module):
         x = x.reshape(1, N * Cin, *DHWin)
         x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
         _, _, *DHWout = x.shape
-        x = x.reshape(N, Cout, *DHWout)
+        x = x.reshape(N, Cout, *DHWout)
         return x
+
+
+class BatchNormStyled3d(nn.BatchNorm3d):
+    """Trivially does standard batch normalization, but accepts a second
+    argument for the style array, which is not used.
+    """
+    def forward(self, x, s):
+        return super().forward(x)
+
+
+class LeakyReLUStyled(nn.LeakyReLU):
+    """Trivially evaluates standard leaky ReLU, but accepts a second
+    argument for the style array, which is not used.
+    """
+    def forward(self, x, s):
+        return super().forward(x)
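For context on `ConvStyled3d`: in StyleGAN2, a style vector scales the convolution weight per input channel (modulation), and the weight is then rescaled so each output channel has unit norm (demodulation); the `reshape(1, N * Cin, ...)` plus `groups=N` trick in the hunk above then applies a different modulated weight to every sample of the batch in a single grouped convolution. A minimal sketch of the weight math for one sample (an illustration, not the layer's actual code):

    import torch

    def modulate_demodulate(w, s, eps=1e-8):
        # w: (Cout, Cin, kD, kH, kW) conv weight; s: (Cin,) style-derived scales
        w = w * s.reshape(1, -1, 1, 1, 1)  # modulate input channels
        demod = torch.rsqrt(w.pow(2).sum(dim=[1, 2, 3, 4], keepdim=True) + eps)
        return w * demod                   # demodulate each output channel

The pass-through `BatchNormStyled3d` and `LeakyReLUStyled` wrappers exist so that every layer in a styled block shares the same `forward(x, s)` signature.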


@@ -0,0 +1,141 @@
+import warnings
+
+import torch
+import torch.nn as nn
+
+from .narrow import narrow_like
+from .swish import Swish
+from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
+
+
+class ConvStyledBlock(nn.Module):
+    """Convolution blocks of the form specified by `seq`.
+
+    `seq` types:
+    'C': convolution specified by `kernel_size` and `stride`
+    'B': normalization (to be renamed to 'N')
+    'A': activation
+    'U': upsampling transposed convolution of kernel size 2 and stride 2
+    'D': downsampling convolution of kernel size 2 and stride 2
+    """
+    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
+                 kernel_size=3, stride=1, seq='CBA'):
+        super().__init__()
+
+        if out_chan is None:
+            out_chan = in_chan
+
+        self.style_size = style_size
+        self.in_chan = in_chan
+        self.out_chan = out_chan
+        if mid_chan is None:
+            mid_chan = max(in_chan, out_chan)
+        self.mid_chan = mid_chan  # always set, so a passed mid_chan takes effect
+        self.kernel_size = kernel_size
+        self.stride = stride
+
+        self.norm_chan = in_chan
+        self.idx_conv = 0
+        self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
+
+        layers = [self._get_layer(l) for l in seq]
+        self.convs = nn.ModuleList(layers)
+
+    def _get_layer(self, l):
+        if l == 'U':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, 2,
+                                stride=2, resample='U')
+        elif l == 'D':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, 2,
+                                stride=2, resample='D')
+        elif l == 'C':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan,
+                                self.kernel_size, stride=self.stride)
+        elif l == 'B':
+            return BatchNormStyled3d(self.norm_chan)
+        elif l == 'A':
+            return LeakyReLUStyled()
+        else:
+            raise NotImplementedError('layer type {} not supported'.format(l))
+
+    def _setup_conv(self):
+        self.idx_conv += 1
+
+        in_chan = out_chan = self.mid_chan
+        if self.idx_conv == 1:
+            in_chan = self.in_chan
+        if self.idx_conv == self.num_conv:
+            out_chan = self.out_chan
+
+        self.norm_chan = out_chan
+
+        return in_chan, out_chan
+
+    def forward(self, x, s):
+        for l in self.convs:
+            x = l(x, s)
+        return x
+
+
+class ResStyledBlock(ConvStyledBlock):
+    """Residual convolution blocks of the form specified by `seq`.
+    Input, via a skip connection, is added to the residual, followed by an
+    optional activation.
+
+    The skip connection is identity if `out_chan` is omitted, otherwise it
+    uses a size-1 "convolution", i.e. one can trigger the latter by setting
+    `out_chan` even if it equals `in_chan`.
+
+    A trailing `'A'` in `seq` can operate either before or after the
+    addition, depending on the boolean value of `last_act`, defaulting to
+    `seq[-1] == 'A'`.
+
+    See `ConvStyledBlock` for `seq` types.
+    """
+    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
+                 seq='CBACBA', last_act=None):
+        if last_act is None:
+            last_act = seq[-1] == 'A'
+        elif last_act and seq[-1] != 'A':
+            warnings.warn(
+                'Disabling last_act without trailing activation in seq',
+                RuntimeWarning,
+            )
+            last_act = False
+
+        if last_act:
+            seq = seq[:-1]
+
+        super().__init__(style_size, in_chan, out_chan=out_chan,
+                         mid_chan=mid_chan, seq=seq)
+
+        if last_act:
+            self.act = LeakyReLUStyled()
+        else:
+            self.act = None
+
+        if out_chan is None:
+            self.skip = None
+        else:
+            self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
+
+        if 'U' in seq or 'D' in seq:
+            raise NotImplementedError('upsample and downsample layers '
+                                      'not supported yet')
+
+    def forward(self, x, s):
+        y = x
+
+        if self.skip is not None:
+            y = self.skip(y, s)
+
+        for l in self.convs:
+            x = l(x, s)
+
+        y = narrow_like(y, x)
+        x += y
+
+        if self.act is not None:
+            x = self.act(x, s)
+
+        return x
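A quick usage sketch for these blocks (the shapes and numbers are illustrative, assuming `ConvStyled3d` takes an `(N, style_size)` style batch and uses unpadded convolutions, as the narrowing elsewhere in this commit suggests):

    import torch

    block = ResStyledBlock(style_size=2, in_chan=3, out_chan=8, seq='CBACBA')
    x = torch.randn(1, 3, 16, 16, 16)  # (N, C, D, H, W) field crop
    s = torch.randn(1, 2)              # (N, style_size) style vector
    y = block(x, s)                    # two unpadded k=3 convs shave 4 voxels
    print(y.shape)                     # expect torch.Size([1, 8, 12, 12, 12])

Setting `out_chan` explicitly makes the skip a size-1 styled convolution, and `narrow_like` trims that skip to match the narrowed residual before the addition.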


@@ -1,12 +1,12 @@
 import torch
 import torch.nn as nn
 
-from .conv import ConvBlock, ResBlock
+from .styled_conv import ConvStyledBlock, ResStyledBlock
 from .narrow import narrow_by
 
 
 class StyledVNet(nn.Module):
-    def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
+    def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
         """V-Net like network with styles
 
         See `vnet.VNet`.
@@ -15,43 +15,43 @@ class StyledVNet(nn.Module):
         # activate non-identity skip connection in residual block
         # by explicitly setting out_chan
-        self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
-        self.down_l0 = ConvBlock(64, seq='DBA')
-        self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
-        self.down_l1 = ConvBlock(64, seq='DBA')
+        self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA')
+        self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
+        self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
+        self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
 
-        self.conv_c = ResBlock(64, 64, seq='CBACBA')
+        self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
 
-        self.up_r1 = ConvBlock(64, seq='UBA')
-        self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
-        self.up_r0 = ConvBlock(64, seq='UBA')
-        self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
+        self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
+        self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
+        self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
+        self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
 
         self.bypass = in_chan == out_chan
 
-    def forward(self, x):
+    def forward(self, x, s):
         if self.bypass:
             x0 = x
 
-        y0 = self.conv_l0(x)
-        x = self.down_l0(y0)
+        y0 = self.conv_l0(x, s)
+        x = self.down_l0(y0, s)
 
-        y1 = self.conv_l1(x)
-        x = self.down_l1(y1)
+        y1 = self.conv_l1(x, s)
+        x = self.down_l1(y1, s)
 
-        x = self.conv_c(x)
+        x = self.conv_c(x, s)
 
-        x = self.up_r1(x)
+        x = self.up_r1(x, s)
         y1 = narrow_by(y1, 4)
         x = torch.cat([y1, x], dim=1)
         del y1
-        x = self.conv_r1(x)
+        x = self.conv_r1(x, s)
 
-        x = self.up_r0(x)
+        x = self.up_r0(x, s)
         y0 = narrow_by(y0, 16)
         x = torch.cat([y0, x], dim=1)
         del y0
-        x = self.conv_r0(x)
+        x = self.conv_r0(x, s)
 
         if self.bypass:
             x0 = narrow_by(x0, 20)
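End to end, the styled V-Net now threads one style vector through every block. A hypothetical call (the import path and the 128-cubed crop size are assumptions, not from the diff):

    import torch
    from map2map.models import StyledVNet  # import path assumed

    model = StyledVNet(style_size=2, in_chan=1, out_chan=1)
    x = torch.randn(1, 1, 128, 128, 128)  # input field crop
    s = torch.randn(1, 2)                 # e.g. cosmological parameters
    y = model(x, s)                       # expect shape (1, 1, 88, 88, 88)

With unpadded convolutions each side loses 20 voxels (128 -> 88), which is exactly the `narrow_by(x0, 20)` applied to the bypass, and the offsets of 4 and 16 on `y1` and `y0` match the extra narrowing accumulated at the deeper and shallower skip levels.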


@@ -59,6 +59,7 @@ def gpu_worker(local_rank, node, args):
     dist_init(rank, args)
 
     train_dataset = FieldDataset(
+        param_pattern=args.train_param_pattern,
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        in_norms=args.in_norms,
@@ -90,6 +91,7 @@ def gpu_worker(local_rank, node, args):
 
     if args.val:
         val_dataset = FieldDataset(
+            param_pattern=args.val_param_pattern,
            in_patterns=args.val_in_patterns,
            tgt_patterns=args.val_tgt_patterns,
            in_norms=args.in_norms,
@@ -119,10 +121,12 @@ def gpu_worker(local_rank, node, args):
         pin_memory=True,
     )
 
-    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
+    args.param_dim = train_dataset.param_dim
+    args.in_chan = train_dataset.in_chan
+    args.out_chan = train_dataset.tgt_chan
 
     model = import_attr(args.model, models, callback_at=args.callback_at)
-    model = model(sum(args.in_chan), sum(args.out_chan),
+    model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
                   scale_factor=args.scale_factor)
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
@@ -238,14 +242,16 @@ def train(epoch, loader, model, criterion,
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
-    for i, (input, target) in enumerate(loader):
+    for i, (param, input, target) in enumerate(loader):
         batch = epoch * len(loader) + i + 1
 
+        param = param.to(device, non_blocking=True)
         input = input.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
 
-        output = model(input)
+        output = model(input, param)
         if batch == 1 and rank == 0:
+            print('param shape :', param.shape)
            print('input shape :', input.shape)
            print('output shape :', output.shape)
            print('target shape :', target.shape)
@@ -330,11 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
     with torch.no_grad():
-        for input, target in loader:
+        for param, input, target in loader:
+            param = param.to(device, non_blocking=True)
             input = input.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
 
-            output = model(input)
+            output = model(input, param)
 
             if (hasattr(model.module, 'scale_factor')
                     and model.module.scale_factor != 1):
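Taken together, the dataset now yields a `(param, input, target)` triple and every model call becomes style-conditioned. A minimal sketch of the new per-batch contract (variable names taken from the hunks above):

    for param, input, target in loader:   # param: (N, param_dim)
        param = param.to(device, non_blocking=True)
        input = input.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(input, param)      # style vector conditions the network

Note that `args.train_param_pattern` and `args.val_param_pattern` imply matching new command-line options, presumably added alongside this change.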