diff --git a/map2map/data/fields.py b/map2map/data/fields.py
index be95998..42973c0 100644
--- a/map2map/data/fields.py
+++ b/map2map/data/fields.py
@@ -46,7 +46,7 @@ class FieldDataset(Dataset):
                  augment=False, aug_shift=None, aug_add=None, aug_mul=None,
                  crop=None, crop_start=None, crop_stop=None, crop_step=None,
                  in_pad=0, tgt_pad=0, scale_factor=1):
-        self.param_files = sorted(param_pattern)
+        self.param_files = sorted(glob(param_pattern))
 
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -61,6 +61,7 @@ class FieldDataset(Dataset):
         if self.nfile == 0:
             raise FileNotFoundError('file not found for {}'.format(in_patterns))
 
+        self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
         self.in_chan = [np.load(f, mmap_mode='r').shape[0]
                         for f in self.in_files[0]]
         self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
@@ -143,7 +144,7 @@ class FieldDataset(Dataset):
     def __getitem__(self, idx):
         ifile, icrop = divmod(idx, self.ncrop)
 
-        params = np.loadtxt(self.param_files[idx])
+        params = np.loadtxt(self.param_files[ifile])
         in_fields = [np.load(f) for f in self.in_files[ifile]]
         tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
 
@@ -159,6 +160,7 @@ class FieldDataset(Dataset):
                 self.tgt_pad, self.size * self.scale_factor)
 
+        params = torch.from_numpy(params).to(torch.float32)
         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
diff --git a/map2map/models/style.py b/map2map/models/style.py
index 2bdaf74..94fc93e 100644
--- a/map2map/models/style.py
+++ b/map2map/models/style.py
@@ -82,7 +82,7 @@ class ConvElr3d(nn.Module):
         return x
 
 
-class ConvMod3d(nn.Module):
+class ConvStyled3d(nn.Module):
     """Convolution layer with modulation and demodulation, from StyleGAN2.
 
     Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
@@ -158,6 +158,22 @@ class ConvMod3d(nn.Module):
         x = x.reshape(1, N * Cin, *DHWin)
         x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
         _, _, *DHWout = x.shape
-        x = x.reshape(N, Cout, *DHWout)
+        x = x.reshape(N, Cout, *DHWout)
 
         return x
+
+
+class BatchNormStyled3d(nn.BatchNorm3d):
+    """Trivially applies standard batch normalization, but accepts a
+    second argument for the style array, which is unused.
+    """
+    def forward(self, x, s):
+        return super().forward(x)
+
+
+class LeakyReLUStyled(nn.LeakyReLU):
+    """Trivially evaluates standard leaky ReLU, but accepts a second
+    argument for the style array, which is unused.
+    """
+    def forward(self, x, s):
+        return super().forward(x)
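A note on the two trivial wrappers added above: they exist so that every layer in a styled block exposes the same `(x, s)` signature, letting a block's forward loop over heterogeneous layers without dispatching on type. The sketch below is illustrative only, not code from this patch:

```python
import torch
import torch.nn as nn

class BatchNormStyled3d(nn.BatchNorm3d):
    """Batch norm that accepts, and ignores, a style argument."""
    def forward(self, x, s):
        return super().forward(x)

class LeakyReLUStyled(nn.LeakyReLU):
    """Leaky ReLU that accepts, and ignores, a style argument."""
    def forward(self, x, s):
        return super().forward(x)

layers = nn.ModuleList([BatchNormStyled3d(4), LeakyReLUStyled()])
x = torch.randn(2, 4, 8, 8, 8)  # (N, C, D, H, W)
s = torch.randn(2, 5)           # (N, style_size), only used by styled convs
for l in layers:
    x = l(x, s)                 # one calling convention for every layer type
```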
diff --git a/map2map/models/styled_conv.py b/map2map/models/styled_conv.py
new file mode 100644
index 0000000..69602b4
--- /dev/null
+++ b/map2map/models/styled_conv.py
@@ -0,0 +1,141 @@
+import warnings
+import torch
+import torch.nn as nn
+
+from .narrow import narrow_like
+from .swish import Swish
+
+from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
+
+
+class ConvStyledBlock(nn.Module):
+    """Convolution blocks of the form specified by `seq`.
+
+    `seq` types:
+    'C': convolution specified by `kernel_size` and `stride`
+    'B': normalization (to be renamed to 'N')
+    'A': activation
+    'U': upsampling transposed convolution of kernel size 2 and stride 2
+    'D': downsampling convolution of kernel size 2 and stride 2
+    """
+    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
+                 kernel_size=3, stride=1, seq='CBA'):
+        super().__init__()
+
+        if out_chan is None:
+            out_chan = in_chan
+
+        self.style_size = style_size
+        self.in_chan = in_chan
+        self.out_chan = out_chan
+        self.mid_chan = max(in_chan, out_chan) if mid_chan is None else mid_chan
+        self.kernel_size = kernel_size
+        self.stride = stride
+
+        self.norm_chan = in_chan
+        self.idx_conv = 0
+        self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
+
+        layers = [self._get_layer(l) for l in seq]
+
+        self.convs = nn.ModuleList(layers)
+
+    def _get_layer(self, l):
+        if l == 'U':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, 2,
+                                stride=2, resample='U')
+        elif l == 'D':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan, 2,
+                                stride=2, resample='D')
+        elif l == 'C':
+            in_chan, out_chan = self._setup_conv()
+            return ConvStyled3d(self.style_size, in_chan, out_chan,
+                                self.kernel_size, stride=self.stride)
+        elif l == 'B':
+            return BatchNormStyled3d(self.norm_chan)
+        elif l == 'A':
+            return LeakyReLUStyled()
+        else:
+            raise NotImplementedError('layer type {} not supported'.format(l))
+
+    def _setup_conv(self):
+        self.idx_conv += 1
+
+        in_chan = out_chan = self.mid_chan
+        if self.idx_conv == 1:
+            in_chan = self.in_chan
+        if self.idx_conv == self.num_conv:
+            out_chan = self.out_chan
+
+        self.norm_chan = out_chan
+
+        return in_chan, out_chan
+
+    def forward(self, x, s):
+        for l in self.convs:
+            x = l(x, s)
+        return x
+
+
+class ResStyledBlock(ConvStyledBlock):
+    """Residual convolution blocks of the form specified by `seq`.
+    Input, via a skip connection, is added to the residual, followed by an
+    optional activation.
+
+    The skip connection is identity if `out_chan` is omitted, otherwise it
+    uses a size-1 convolution, i.e. one can trigger the latter by setting
+    `out_chan` even if it equals `in_chan`.
+
+    A trailing `'A'` in `seq` can operate either before or after the
+    addition, depending on the boolean value of `last_act`, defaulting to
+    `seq[-1] == 'A'`.
+
+    See `ConvStyledBlock` for `seq` types.
+    """
+    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
+                 seq='CBACBA', last_act=None):
+        if last_act is None:
+            last_act = seq[-1] == 'A'
+        elif last_act and seq[-1] != 'A':
+            warnings.warn(
+                'Disabling last_act without trailing activation in seq',
+                RuntimeWarning,
+            )
+            last_act = False
+
+        if last_act:
+            seq = seq[:-1]
+
+        super().__init__(style_size, in_chan, out_chan=out_chan,
+                         mid_chan=mid_chan, seq=seq)
+
+        if last_act:
+            self.act = LeakyReLUStyled()
+        else:
+            self.act = None
+
+        if out_chan is None:
+            self.skip = None
+        else:
+            self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
+
+        if 'U' in seq or 'D' in seq:
+            raise NotImplementedError('upsample and downsample layers '
+                                      'not supported yet')
+
+    def forward(self, x, s):
+        y = x
+
+        if self.skip is not None:
+            y = self.skip(y, s)
+
+        for l in self.convs:
+            x = l(x, s)
+
+        y = narrow_like(y, x)
+        x += y
+
+        if self.act is not None:
+            x = self.act(x, s)
+
+        return x
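To make the `_setup_conv` bookkeeping concrete: the first conv in a block maps `in_chan` to `mid_chan` (which defaults to `max(in_chan, out_chan)`), the last maps `mid_chan` to `out_chan`, and any conv in between stays at `mid_chan`. A standalone sketch, where `conv_channels` is a hypothetical helper written only for illustration:

```python
def conv_channels(seq, in_chan, out_chan, mid_chan=None):
    """Trace the (in, out) channels of each conv layer in a `seq` block."""
    if mid_chan is None:
        mid_chan = max(in_chan, out_chan)
    num_conv = sum(seq.count(l) for l in 'UDC')
    chans = []
    for i in range(1, num_conv + 1):
        ci = in_chan if i == 1 else mid_chan
        co = out_chan if i == num_conv else mid_chan
        chans.append((ci, co))
    return chans

print(conv_channels('CBACBA', 128, 64))  # [(128, 128), (128, 64)]
print(conv_channels('CACBA', 3, 64))     # [(3, 64), (64, 64)]
```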
+ """ + def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None, + seq='CBACBA', last_act=None): + if last_act is None: + last_act = seq[-1] == 'A' + elif last_act and seq[-1] != 'A': + warnings.warn( + 'Disabling last_act without trailing activation in seq', + RuntimeWarning, + ) + last_act = False + + if last_act: + seq = seq[:-1] + + super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq) + + if last_act: + self.act = LeakyReLUStyled() + else: + self.act = None + + if out_chan is None: + self.skip = None + else: + self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1) + + if 'U' in seq or 'D' in seq: + raise NotImplementedError('upsample and downsample layers ' + 'not supported yet') + + def forward(self, x, s): + y = x + + if self.skip is not None: + y = self.skip(y, s) + + for l in self.convs: + x = l(x, s) + + y = narrow_like(y, x) + x += y + + if self.act is not None: + x = self.act(x, s) + + return x diff --git a/map2map/models/styled_vnet.py b/map2map/models/styled_vnet.py index 9ef8475..d9025c5 100644 --- a/map2map/models/styled_vnet.py +++ b/map2map/models/styled_vnet.py @@ -1,12 +1,12 @@ import torch import torch.nn as nn -from .conv import ConvBlock, ResBlock +from .styled_conv import ConvStyledBlock, ResStyledBlock from .narrow import narrow_by class StyledVNet(nn.Module): - def __init__(self, in_chan, out_chan, bypass=None, **kwargs): + def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs): """V-Net like network with styles See `vnet.VNet`. @@ -15,43 +15,43 @@ class StyledVNet(nn.Module): # activate non-identity skip connection in residual block # by explicitly setting out_chan - self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA') - self.down_l0 = ConvBlock(64, seq='DBA') - self.conv_l1 = ResBlock(64, 64, seq='CBACBA') - self.down_l1 = ConvBlock(64, seq='DBA') + self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA') + self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA') + self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA') + self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA') - self.conv_c = ResBlock(64, 64, seq='CBACBA') + self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA') - self.up_r1 = ConvBlock(64, seq='UBA') - self.conv_r1 = ResBlock(128, 64, seq='CBACBA') - self.up_r0 = ConvBlock(64, seq='UBA') - self.conv_r0 = ResBlock(128, out_chan, seq='CAC') + self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA') + self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA') + self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA') + self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC') self.bypass = in_chan == out_chan - def forward(self, x): + def forward(self, x, s): if self.bypass: x0 = x - y0 = self.conv_l0(x) - x = self.down_l0(y0) + y0 = self.conv_l0(x, s) + x = self.down_l0(y0, s) - y1 = self.conv_l1(x) - x = self.down_l1(y1) + y1 = self.conv_l1(x, s) + x = self.down_l1(y1, s) - x = self.conv_c(x) + x = self.conv_c(x, s) - x = self.up_r1(x) + x = self.up_r1(x, s) y1 = narrow_by(y1, 4) x = torch.cat([y1, x], dim=1) del y1 - x = self.conv_r1(x) + x = self.conv_r1(x, s) - x = self.up_r0(x) + x = self.up_r0(x, s) y0 = narrow_by(y0, 16) x = torch.cat([y0, x], dim=1) del y0 - x = self.conv_r0(x) + x = self.conv_r0(x, s) if self.bypass: x0 = narrow_by(x0, 20) diff --git a/map2map/train.py b/map2map/train.py index 39788b7..04e0048 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -59,6 +59,7 @@ def gpu_worker(local_rank, node, args): 
diff --git a/map2map/train.py b/map2map/train.py
index 39788b7..04e0048 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -59,6 +59,7 @@ def gpu_worker(local_rank, node, args):
     dist_init(rank, args)
 
     train_dataset = FieldDataset(
+        param_pattern=args.train_param_pattern,
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
         in_norms=args.in_norms,
@@ -90,6 +91,7 @@ def gpu_worker(local_rank, node, args):
 
     if args.val:
         val_dataset = FieldDataset(
+            param_pattern=args.val_param_pattern,
            in_patterns=args.val_in_patterns,
            tgt_patterns=args.val_tgt_patterns,
            in_norms=args.in_norms,
@@ -119,10 +121,12 @@ def gpu_worker(local_rank, node, args):
         pin_memory=True,
     )
 
-    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
+    args.param_dim = train_dataset.param_dim
+    args.in_chan = train_dataset.in_chan
+    args.out_chan = train_dataset.tgt_chan
 
     model = import_attr(args.model, models, callback_at=args.callback_at)
-    model = model(sum(args.in_chan), sum(args.out_chan),
+    model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
                   scale_factor=args.scale_factor)
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
@@ -238,14 +242,16 @@ def train(epoch, loader, model, criterion,
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
-    for i, (input, target) in enumerate(loader):
+    for i, (param, input, target) in enumerate(loader):
         batch = epoch * len(loader) + i + 1
 
+        param = param.to(device, non_blocking=True)
         input = input.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
 
-        output = model(input)
+        output = model(input, param)
 
         if batch == 1 and rank == 0:
+            print('param shape :', param.shape)
            print('input shape :', input.shape)
            print('output shape :', output.shape)
            print('target shape :', target.shape)
@@ -330,11 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
     with torch.no_grad():
-        for input, target in loader:
+        for param, input, target in loader:
+            param = param.to(device, non_blocking=True)
             input = input.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
 
-            output = model(input)
+            output = model(input, param)
 
             if (hasattr(model.module, 'scale_factor')
                     and model.module.scale_factor != 1):
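For reference, the new loader protocol seen from the training loop: each batch is a `(param, input, target)` triple, with the parameter vectors stacked along dim 0 by the default collate. The toy dataset below is a hypothetical stand-in for `FieldDataset`, assuming its `__getitem__` returns `(params, input, target)` as the unpacking in `train()` and `validate()` implies; shapes are illustrative:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyFieldDataset(Dataset):  # hypothetical stand-in for FieldDataset
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        params = torch.randn(5)             # from np.loadtxt(param_files[ifile])
        input = torch.randn(3, 32, 32, 32)  # stacked input field channels
        target = torch.randn(3, 32, 32, 32)
        return params, input, target

loader = DataLoader(ToyFieldDataset(), batch_size=2)
for param, input, target in loader:
    print(param.shape)  # torch.Size([2, 5]) -- one parameter row per sample
    break
```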