Add styled VNet and fix bugs
Co-authored-by: Drew Jamieson <drew.s.jamieson@gmail.com>
parent f5bd657625
commit 17d8e95870
@@ -46,7 +46,7 @@ class FieldDataset(Dataset):
                  augment=False, aug_shift=None, aug_add=None, aug_mul=None,
                  crop=None, crop_start=None, crop_stop=None, crop_step=None,
                  in_pad=0, tgt_pad=0, scale_factor=1):
-        self.param_files = sorted(param_pattern)
+        self.param_files = sorted(glob(param_pattern))
 
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -61,6 +61,7 @@ class FieldDataset(Dataset):
         if self.nfile == 0:
             raise FileNotFoundError('file not found for {}'.format(in_patterns))
 
+        self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
         self.in_chan = [np.load(f, mmap_mode='r').shape[0]
                         for f in self.in_files[0]]
         self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
@@ -143,7 +144,7 @@ class FieldDataset(Dataset):
     def __getitem__(self, idx):
         ifile, icrop = divmod(idx, self.ncrop)
 
-        params = np.loadtxt(self.param_files[idx])
+        params = np.loadtxt(self.param_files[ifile])
         in_fields = [np.load(f) for f in self.in_files[ifile]]
         tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
 
@@ -159,6 +160,7 @@ class FieldDataset(Dataset):
                           self.tgt_pad,
                           self.size * self.scale_factor)
 
+        params = torch.from_numpy(params).to(torch.float32)
        in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
        tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
 
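
Note on the two FieldDataset fixes above: `sorted(param_pattern)` iterated over the characters of the pattern string instead of over the matching files, and `__getitem__` indexed the parameter files by the crop index rather than the file index. A tiny illustration of the first fix, with a made-up pattern string:

    from glob import glob

    pattern = 'LH*/params.txt'     # hypothetical pattern, one parameter file per simulation
    sorted(pattern)                # old behaviour: a sorted list of the pattern's characters
    sorted(glob(pattern))          # new behaviour: a sorted list of matching file paths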
@@ -82,7 +82,7 @@ class ConvElr3d(nn.Module):
         return x
 
 
-class ConvMod3d(nn.Module):
+class ConvStyled3d(nn.Module):
     """Convolution layer with modulation and demodulation, from StyleGAN2.
 
     Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
@@ -161,3 +161,19 @@ class ConvMod3d(nn.Module):
         x = x.reshape(N, Cout, *DHWout)
 
         return x
+
+
+class BatchNormStyled3d(nn.BatchNorm3d):
+    """Trivially does standard batch normalization, but accepts a second
+    argument for the style array, which is not used.
+    """
+    def forward(self, x, s):
+        return super().forward(x)
+
+
+class LeakyReLUStyled(nn.LeakyReLU):
+    """Trivially evaluates standard leaky ReLU, but accepts a second
+    argument for the style array, which is not used.
+    """
+    def forward(self, x, s):
+        return super().forward(x)
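
The two wrappers added above exist only to give normalization and activation the same `(x, s)` call signature as `ConvStyled3d`, so a block can pass the style to every layer uniformly. The body of `ConvStyled3d` is not shown in this diff; the StyleGAN2 modulation/demodulation that its docstring refers to works roughly like the sketch below (a hypothetical standalone function with assumed shapes, not the actual implementation):

    import torch
    import torch.nn.functional as F

    def modulated_conv3d(x, weight, style, eps=1e-8):
        # x: (N, Cin, D, H, W); weight: (Cout, Cin, K, K, K); style: (N, Cin)
        N, Cin = style.shape
        Cout = weight.shape[0]
        w = weight[None] * style[:, None, :, None, None, None]    # modulate per sample
        demod = torch.rsqrt(w.pow(2).sum(dim=(2, 3, 4, 5)) + eps)
        w = w * demod[:, :, None, None, None, None]               # demodulate
        x = x.reshape(1, N * Cin, *x.shape[2:])                   # fold batch into groups
        w = w.reshape(N * Cout, Cin, *weight.shape[2:])
        x = F.conv3d(x, w, groups=N)                              # per-sample kernels
        return x.reshape(N, Cout, *x.shape[2:])                   # cf. the reshape kept in the hunk above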
map2map/models/styled_conv.py (new file, 141 lines)
@@ -0,0 +1,141 @@
+import warnings
+import torch
+import torch.nn as nn
+
+from .narrow import narrow_like
+from .swish import Swish
+
+from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
+
+
+class ConvStyledBlock(nn.Module):
+    """Convolution blocks of the form specified by `seq`.
+
+    `seq` types:
+    'C': convolution specified by `kernel_size` and `stride`
+    'B': normalization (to be renamed to 'N')
+    'A': activation
+    'U': upsampling transposed convolution of kernel size 2 and stride 2
+    'D': downsampling convolution of kernel size 2 and stride 2
+    """
+    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
+                 kernel_size=3, stride=1, seq='CBA'):
+        super().__init__()
+
+        if out_chan is None:
+            out_chan = in_chan
+
+        self.style_size = style_size
+        self.in_chan = in_chan
+        self.out_chan = out_chan
+        if mid_chan is None:
+            self.mid_chan = max(in_chan, out_chan)
+        self.kernel_size = kernel_size
+        self.stride = stride
+
+        self.norm_chan = in_chan
+        self.idx_conv = 0
+        self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
+
+        layers = [self._get_layer(l) for l in seq]
+
+        self.convs = nn.ModuleList(layers)
+
+    def _get_layer(self, l):
+        if l == 'U':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
+                                resample='U')
+        elif l == 'D':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
+                                resample='D')
+        elif l == 'C':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size,
+                                stride=self.stride)
+        elif l == 'B':
+            return BatchNormStyled3d(self.norm_chan)
+        elif l == 'A':
+            return LeakyReLUStyled()
+        else:
+            raise NotImplementedError('layer type {} not supported'.format(l))
+
+    def _setup_conv(self):
+        self.idx_conv += 1
+
+        in_chan = out_chan = self.mid_chan
+        if self.idx_conv == 1:
+            in_chan = self.in_chan
+        if self.idx_conv == self.num_conv:
+            out_chan = self.out_chan
+
+        self.norm_chan = out_chan
+
+        return in_chan, out_chan
+
+    def forward(self, x, s):
+        for l in self.convs:
+            x = l(x, s)
+        return x
+
+
+class ResStyledBlock(ConvStyledBlock):
+    """Residual convolution blocks of the form specified by `seq`.
+    Input, via a skip connection, is added to the residual, followed by an
+    optional activation.
+
+    The skip connection is identity if `out_chan` is omitted, otherwise it uses
+    a size 1 "convolution", i.e. one can trigger the latter by setting
+    `out_chan` even if it equals `in_chan`.
+
+    A trailing `'A'` in `seq` can either operate before or after the addition,
+    depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`.
+
+    See `ConvStyledBlock` for `seq` types.
+    """
+    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
+                 seq='CBACBA', last_act=None):
+        if last_act is None:
+            last_act = seq[-1] == 'A'
+        elif last_act and seq[-1] != 'A':
+            warnings.warn(
+                'Disabling last_act without trailing activation in seq',
+                RuntimeWarning,
+            )
+            last_act = False
+
+        if last_act:
+            seq = seq[:-1]
+
+        super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
+
+        if last_act:
+            self.act = LeakyReLUStyled()
+        else:
+            self.act = None
+
+        if out_chan is None:
+            self.skip = None
+        else:
+            self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
+
+        if 'U' in seq or 'D' in seq:
+            raise NotImplementedError('upsample and downsample layers '
+                                      'not supported yet')
+
+    def forward(self, x, s):
+        y = x
+
+        if self.skip is not None:
+            y = self.skip(y, s)
+
+        for l in self.convs:
+            x = l(x, s)
+
+        y = narrow_like(y, x)
+        x += y
+
+        if self.act is not None:
+            x = self.act(x, s)
+
+        return x
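
A quick smoke test of the new blocks (all sizes are arbitrary, and this assumes `ConvStyled3d` accepts a `(batch, style_size)` style tensor): every layer, trivially styled or not, takes the style as a second argument, so an entire block runs off one `(x, s)` pair.

    import torch
    from map2map.models.styled_conv import ResStyledBlock

    block = ResStyledBlock(style_size=5, in_chan=3, out_chan=3, seq='CACBA')

    x = torch.randn(2, 3, 16, 16, 16)   # small 3D field crops
    s = torch.randn(2, 5)               # one style vector per sample
    y = block(x, s)                     # spatial size shrinks through the unpadded convolutions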
@@ -1,12 +1,12 @@
 import torch
 import torch.nn as nn
 
-from .conv import ConvBlock, ResBlock
+from .styled_conv import ConvStyledBlock, ResStyledBlock
 from .narrow import narrow_by
 
 
 class StyledVNet(nn.Module):
-    def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
+    def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
         """V-Net like network with styles
 
         See `vnet.VNet`.
@@ -15,43 +15,43 @@ class StyledVNet(nn.Module):
 
         # activate non-identity skip connection in residual block
         # by explicitly setting out_chan
-        self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
-        self.down_l0 = ConvBlock(64, seq='DBA')
-        self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
-        self.down_l1 = ConvBlock(64, seq='DBA')
+        self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA')
+        self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
+        self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
+        self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
 
-        self.conv_c = ResBlock(64, 64, seq='CBACBA')
+        self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
 
-        self.up_r1 = ConvBlock(64, seq='UBA')
-        self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
-        self.up_r0 = ConvBlock(64, seq='UBA')
-        self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
+        self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
+        self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
+        self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
+        self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
 
         self.bypass = in_chan == out_chan
 
-    def forward(self, x):
+    def forward(self, x, s):
         if self.bypass:
             x0 = x
 
-        y0 = self.conv_l0(x)
-        x = self.down_l0(y0)
+        y0 = self.conv_l0(x, s)
+        x = self.down_l0(y0, s)
 
-        y1 = self.conv_l1(x)
-        x = self.down_l1(y1)
+        y1 = self.conv_l1(x, s)
+        x = self.down_l1(y1, s)
 
-        x = self.conv_c(x)
+        x = self.conv_c(x, s)
 
-        x = self.up_r1(x)
+        x = self.up_r1(x, s)
         y1 = narrow_by(y1, 4)
         x = torch.cat([y1, x], dim=1)
         del y1
-        x = self.conv_r1(x)
+        x = self.conv_r1(x, s)
 
-        x = self.up_r0(x)
+        x = self.up_r0(x, s)
         y0 = narrow_by(y0, 16)
         x = torch.cat([y0, x], dim=1)
         del y0
-        x = self.conv_r0(x)
+        x = self.conv_r0(x, s)
 
         if self.bypass:
             x0 = narrow_by(x0, 20)
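
With these changes the model signature becomes `StyledVNet(style_size, in_chan, out_chan, ...)` and the forward pass takes the style alongside the field. A hedged usage sketch (sizes are arbitrary, the import path is assumed, and the input crop must be large enough to survive the narrowing of the unpadded convolutions):

    import torch
    from map2map.models import StyledVNet   # assumed export path

    model = StyledVNet(5, 3, 3)             # style_size, in_chan, out_chan

    x = torch.randn(1, 3, 128, 128, 128)    # input field crop
    s = torch.randn(1, 5)                   # e.g. simulation parameters as the style
    y = model(x, s)                         # output is narrowed relative to the input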
@@ -59,6 +59,7 @@ def gpu_worker(local_rank, node, args):
     dist_init(rank, args)
 
     train_dataset = FieldDataset(
+        param_pattern=args.train_param_pattern,
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
         in_norms=args.in_norms,
@@ -90,6 +91,7 @@ def gpu_worker(local_rank, node, args):
 
     if args.val:
         val_dataset = FieldDataset(
+            param_pattern=args.val_param_pattern,
            in_patterns=args.val_in_patterns,
            tgt_patterns=args.val_tgt_patterns,
            in_norms=args.in_norms,
@@ -119,10 +121,12 @@ def gpu_worker(local_rank, node, args):
         pin_memory=True,
     )
 
-    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
+    args.param_dim = train_dataset.param_dim
+    args.in_chan = train_dataset.in_chan
+    args.out_chan = train_dataset.tgt_chan
 
     model = import_attr(args.model, models, callback_at=args.callback_at)
-    model = model(sum(args.in_chan), sum(args.out_chan),
+    model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
                   scale_factor=args.scale_factor)
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
@@ -238,14 +242,16 @@ def train(epoch, loader, model, criterion,
 
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
-    for i, (input, target) in enumerate(loader):
+    for i, (param, input, target) in enumerate(loader):
         batch = epoch * len(loader) + i + 1
 
+        param = param.to(device, non_blocking=True)
         input = input.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
 
-        output = model(input)
+        output = model(input, param)
         if batch == 1 and rank == 0:
+            print('param shape :', param.shape)
             print('input shape :', input.shape)
             print('output shape :', output.shape)
             print('target shape :', target.shape)
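
The loader contract in the training and validation loops is now a `(param, input, target)` triple, with the parameters acting as the style argument. Roughly (the exact spatial sizes depend on padding and `scale_factor`, so the shapes below are indicative only):

    # param:  (N, param_dim)                 -- one parameter/style vector per crop
    # input:  (N, sum(args.in_chan),  D, H, W)
    # target: (N, sum(args.out_chan), D', H', W')
    for param, input, target in loader:
        param = param.to(device, non_blocking=True)
        input = input.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(input, param)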
@@ -330,11 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
     with torch.no_grad():
-        for input, target in loader:
+        for param, input, target in loader:
+            param = param.to(device, non_blocking=True)
             input = input.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
 
-            output = model(input)
+            output = model(input, param)
 
             if (hasattr(model.module, 'scale_factor')
                     and model.module.scale_factor != 1):