From abb16fe26a51ec16466c993d1f02821622b84361 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Fri, 26 Mar 2021 11:56:43 -0400 Subject: [PATCH] Improve FieldDataset.assemble --- map2map/data/fields.py | 110 ++++++++++++++++++++++++++++------------- map2map/test.py | 37 ++++++-------- map2map/train.py | 8 ++- 3 files changed, 98 insertions(+), 57 deletions(-) diff --git a/map2map/data/fields.py b/map2map/data/fields.py index b245a38..d7224a8 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -1,3 +1,5 @@ +import os +import pathlib from glob import glob import numpy as np import torch @@ -138,6 +140,12 @@ class FieldDataset(Dataset): self.assembly_line = {} + self.commonpath = os.path.commonpath( + file + for files in self.in_files[:2] + self.tgt_files[:2] + for file in files + ) + def __len__(self): return self.nsample @@ -156,9 +164,9 @@ class FieldDataset(Dataset): crop(in_fields, anchor, self.crop, self.in_pad, self.size) crop(tgt_fields, anchor * self.scale_factor, - self.crop * self.scale_factor, - self.tgt_pad, - self.size * self.scale_factor) + self.crop * self.scale_factor, + self.tgt_pad, + self.size * self.scale_factor) style = torch.from_numpy(style).to(torch.float32) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] @@ -191,52 +199,86 @@ class FieldDataset(Dataset): in_fields = torch.cat(in_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0) - return style, in_fields, tgt_fields + #in_relpath = [os.path.relpath(file, start=self.commonpath) + # for file in self.in_files[ifile]] + tgt_relpath = [os.path.relpath(file, start=self.commonpath) + for file in self.tgt_files[ifile]] - def assemble(self, **fields): - """Assemble cropped fields. + return { + 'style': style, + 'input': in_fields, + 'target': tgt_fields, + #'input_relpath': in_relpath, + 'target_relpath': tgt_relpath, + } - Repeat feeding cropped spatially ordered fields as kwargs. - After filled by the crops, the whole fields are assembled and returned. - Otherwise an empty dictionary is returned. + def assemble(self, label, chan, patches, paths): + """Assemble and write whole fields from patches, each being the end + result from a cropped field by `__getitem__`. + + Repeat feeding spatially ordered field patches. + After filled, the whole fields are assembled and saved to relative + paths specified by `paths` and `label`. + `chan` is used to split the channels to undo `cat` in `__getitem__`. + + As an example, patches of shape `(1, 4, X, Y, Z)`, `label='_out'` + and `chan=[1, 3]`, with `paths=[['d/scalar.npy'], ['d/vector.npy']]` + will write to `'d/scalar_out.npy'` and `'d/vector_out.npy'`. + + Note that `paths` assumes transposed shape due to pytorch auto batching """ if self.scale_factor != 1: raise NotImplementedError - for k, v in fields.items(): - if isinstance(v, torch.Tensor): - v = v.numpy() + if isinstance(patches, torch.Tensor): + patches = patches.detach().cpu().numpy() - assert v.ndim == 2 + self.ndim, 'ndim mismatch' - if any(self.crop_step > v.shape[2:]): - raise RuntimeError('crop too small to tile') + assert patches.ndim == 2 + self.ndim, 'ndim mismatch' + if any(self.crop_step > patches.shape[2:]): + raise RuntimeError('patch too small to tile') - v = list(v) - if k in self.assembly_line: - self.assembly_line[k] += v - else: - self.assembly_line[k] = v + # the batched paths are a list of lists with shape (channel, batch) + # since pytorch default_collate batches list of strings transposedly + # therefore we transpose below back to (batch, channel) + assert patches.shape[1] == sum(chan), 'number of channels mismatch' + assert len(paths) == len(chan), 'number of fields mismatch' + paths = list(zip(* paths)) + assert patches.shape[0] == len(paths), 'batch size mismatch' - del fields + patches = list(patches) + if label in self.assembly_line: + self.assembly_line[label] += patches + self.assembly_line[label + 'path'] += paths + else: + self.assembly_line[label] = patches + self.assembly_line[label + 'path'] = paths - assembled_fields = {} + del patches, paths + patches = self.assembly_line[label] + paths = self.assembly_line[label + 'path'] - # NOTE anchor positioning assumes sensible target padding - # so that outputs are aligned with - anchors = self.anchors - self.tgt_pad[:, 0] + # NOTE anchor positioning assumes sufficient target padding and + # symmetric narrowing (more on the right if odd) see `models/narrow.py` + narrow = self.crop + self.tgt_pad.sum(axis=1) - patches[0].shape[1:] + anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2 - for k, v in self.assembly_line.items(): - while len(v) >= self.ncrop: - assert k not in assembled_fields - assembled_fields[k] = np.zeros( - v[0].shape[:1] + tuple(self.size), v[0].dtype) + while len(patches) >= self.ncrop: + fields = np.zeros(patches[0].shape[:1] + tuple(self.size), + patches[0].dtype) - for patch, anchor in zip(v, anchors): - fill(assembled_fields[k], patch, anchor) + for patch, anchor in zip(patches, anchors): + fill(fields, patch, anchor) - del v[:self.ncrop] + for field, path in zip( + np.split(fields, np.cumsum(chan), axis=0), + paths[0]): + pathlib.Path(os.path.dirname(path)).mkdir(parents=True, + exist_ok=True) - return assembled_fields + path = label.join(os.path.splitext(path)) + np.save(path, field) + + del patches[:self.ncrop], paths[:self.ncrop] def fill(field, patch, anchor): diff --git a/map2map/test.py b/map2map/test.py index a5d1d42..314b14e 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -1,6 +1,5 @@ import sys from pprint import pprint -from collections import Counter import numpy as np import torch from torch.utils.data import DataLoader @@ -59,12 +58,12 @@ def test(args): state['epoch'], args.load_state)) del state - assembled_counts = Counter() - model.eval() with torch.no_grad(): - for i, (style, input, target) in enumerate(test_loader): + for i, data in enumerate(test_loader): + style, input, target = data['style'], data['input'], data['target'] + output = model(input, style) input, output, target = narrow_cast(input, output, target) @@ -72,27 +71,23 @@ def test(args): print('sample {} loss: {}'.format(i, loss.item())) - if args.in_norms is not None: - start = 0 - for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): - norm = import_attr(norm, norms, callback_at=args.callback_at) - norm(input[:, start:stop], undo=True) - start = stop + #if args.in_norms is not None: + # start = 0 + # for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): + # norm = import_attr(norm, norms, callback_at=args.callback_at) + # norm(input[:, start:stop], undo=True) + # start = stop if args.tgt_norms is not None: start = 0 for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): norm = import_attr(norm, norms, callback_at=args.callback_at) norm(output[:, start:stop], undo=True) - norm(target[:, start:stop], undo=True) + #norm(target[:, start:stop], undo=True) start = stop - assembled_fields = test_dataset.assemble( - #input=input.numpy(), - output=output.numpy(), - #target=target.numpy(), - ) - - if assembled_fields: - for k, v in assembled_fields.items(): - np.save(f'{k}_{assembled_counts[k]}.npy', v) - assembled_counts[k] += 1 + #test_dataset.assemble('_in', in_chan, input, + # data['input_relpath']) + test_dataset.assemble('_out', out_chan, output, + data['target_relpath']) + #test_dataset.assemble('_tgt', out_chan, target, + # data['target_relpath']) diff --git a/map2map/train.py b/map2map/train.py index 0cee826..901a84a 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -242,9 +242,11 @@ def train(epoch, loader, model, criterion, epoch_loss = torch.zeros(3, dtype=torch.float64, device=device) - for i, (style, input, target) in enumerate(loader): + for i, data in enumerate(loader): batch = epoch * len(loader) + i + 1 + style, input, target = data['style'], data['input'], data['target'] + style = style.to(device, non_blocking=True) input = input.to(device, non_blocking=True) target = target.to(device, non_blocking=True) @@ -336,7 +338,9 @@ def validate(epoch, loader, model, criterion, logger, device, args): epoch_loss = torch.zeros(3, dtype=torch.float64, device=device) with torch.no_grad(): - for style, input, target in loader: + for data in loader: + style, input, target = data['style'], data['input'], data['target'] + style = style.to(device, non_blocking=True) input = input.to(device, non_blocking=True) target = target.to(device, non_blocking=True)