diff --git a/map2map/data/fields.py b/map2map/data/fields.py
index e80a0f9..7468e5b 100644
--- a/map2map/data/fields.py
+++ b/map2map/data/fields.py
@@ -104,6 +104,7 @@ class FieldDataset(Dataset):
             crop_step = self.crop
         else:
             crop_step = np.broadcast_to(crop_step, (self.ndim,))
+        self.crop_step = crop_step
 
         self.anchors = np.stack(np.mgrid[tuple(
             slice(crop_start[d], crop_stop[d], crop_step[d])
@@ -132,6 +133,8 @@ class FieldDataset(Dataset):
 
         self.nsample = self.nfile * self.ncrop
 
+        self.assembly_line = {}
+
     def __len__(self):
         return self.nsample
 
@@ -149,9 +152,9 @@ class FieldDataset(Dataset):
         crop(in_fields, anchor, self.crop, self.in_pad, self.size)
         crop(tgt_fields, anchor * self.scale_factor,
-                self.crop * self.scale_factor,
-                self.tgt_pad,
-                self.size * self.scale_factor)
+             self.crop * self.scale_factor,
+             self.tgt_pad,
+             self.size * self.scale_factor)
 
         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
 
@@ -185,6 +188,66 @@ class FieldDataset(Dataset):
 
         return in_fields, tgt_fields
 
+    def assemble(self, **fields):
+        """Assemble cropped fields.
+
+        Repeatedly feed cropped, spatially ordered fields as kwargs.
+        Once filled by the crops, the whole fields are assembled and returned.
+        Otherwise an empty dictionary is returned.
+        """
+        if self.scale_factor != 1:
+            raise NotImplementedError
+
+        for k, v in fields.items():
+            if isinstance(v, torch.Tensor):
+                v = v.numpy()
+
+            assert v.ndim == 2 + self.ndim, 'ndim mismatch'
+            if any(self.crop_step > v.shape[2:]):
+                raise RuntimeError('crop too small to tile')
+
+            v = list(v)
+            if k in self.assembly_line:
+                self.assembly_line[k] += v
+            else:
+                self.assembly_line[k] = v
+
+        del fields
+
+        assembled_fields = {}
+
+        # NOTE anchor positioning assumes sensible target padding
+        # so that outputs are aligned with the inputs
+        anchors = self.anchors - self.tgt_pad[:, 0]
+
+        for k, v in self.assembly_line.items():
+            while len(v) >= self.ncrop:
+                assert k not in assembled_fields
+                assembled_fields[k] = np.zeros(
+                    v[0].shape[:1] + tuple(self.size), v[0].dtype)
+
+                for patch, anchor in zip(v, anchors):
+                    fill(assembled_fields[k], patch, anchor)
+
+                del v[:self.ncrop]
+
+        return assembled_fields
+
+
+def fill(field, patch, anchor):
+    ndim = len(anchor)
+
+    ind = [slice(None)]
+    for d, (p, a, s) in enumerate(zip(
+            patch.shape[1:], anchor, field.shape[1:])):
+        i = np.arange(a, a + p)
+        i %= s
+        i = i.reshape((-1,) + (1,) * (ndim - d - 1))
+        ind.append(i)
+    ind = tuple(ind)
+
+    field[ind] = patch
+
 
 def crop(fields, anchor, crop, pad, size):
     ndim = len(size)
diff --git a/map2map/test.py b/map2map/test.py
index 0eb5149..cb32a2b 100644
--- a/map2map/test.py
+++ b/map2map/test.py
@@ -1,5 +1,6 @@
 import sys
 from pprint import pprint
+from collections import Counter
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
@@ -55,6 +56,8 @@ def test(args):
             state['epoch'], args.load_state))
         del state
 
+    assembled_counts = Counter()
+
     model.eval()
 
     with torch.no_grad():
@@ -80,5 +83,13 @@
                     norm(target[:, start:stop], undo=True)
                 start = stop
 
-            np.savez('{}.npz'.format(i), input=input.numpy(),
-                    output=output.numpy(), target=target.numpy())
+            assembled_fields = test_dataset.assemble(
+                #input=input.numpy(),
+                output=output.numpy(),
+                #target=target.numpy(),
+            )
+
+            if assembled_fields:
+                for k, v in assembled_fields.items():
+                    np.save(f'{k}_{assembled_counts[k]}.npy', v)
+                    assembled_counts[k] += 1
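
A minimal sketch (not part of the patch) of the periodic placement the new fill() helper performs. It copies fill() from the hunk above, with an added comment, and applies it to a toy 8x8 field, so a patch anchored near the edge wraps around to the opposite side:

    import numpy as np

    def fill(field, patch, anchor):
        ndim = len(anchor)

        ind = [slice(None)]
        for d, (p, a, s) in enumerate(zip(
                patch.shape[1:], anchor, field.shape[1:])):
            i = np.arange(a, a + p)
            i %= s  # periodic boundary: wrap indices past the field edge
            i = i.reshape((-1,) + (1,) * (ndim - d - 1))
            ind.append(i)
        ind = tuple(ind)

        field[ind] = patch

    field = np.zeros((1, 8, 8))        # 1 channel, 8x8 field
    patch = np.ones((1, 4, 4))         # 4x4 patch
    fill(field, patch, anchor=(6, 6))  # extends past the edge and wraps
    print(field[0, 6:, 6:])            # filled 2x2 corner at the anchor ...
    print(field[0, :2, :2])            # ... and the wrapped 2x2 corner

The modulo also tolerates anchors shifted negative by the target padding, which is what assemble() produces via anchors = self.anchors - self.tgt_pad[:, 0].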
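
For context on the ncrop ordering that assemble() relies on, here is a small sketch of how the anchor grid tiles the field with stride crop_step. The mgrid construction follows the context lines visible in the first hunk, while the stack/reshape details and the concrete numbers are assumptions for illustration:

    import numpy as np

    size = np.array([8, 8])      # full field size (assumed toy values)
    crop = np.array([4, 4])      # crop (patch) size
    crop_step = crop             # stride between crop anchors
    crop_start = np.zeros(len(size), dtype=int)
    crop_stop = size
    ndim = len(size)

    # spatially ordered anchors, one per crop, tiling the field
    anchors = np.stack(np.mgrid[tuple(
        slice(crop_start[d], crop_stop[d], crop_step[d])
        for d in range(ndim)
    )], axis=-1).reshape(-1, ndim)

    ncrop = len(anchors)
    print(ncrop)    # 4 crops cover the 8x8 field
    print(anchors)  # [[0 0] [0 4] [4 0] [4 4]], the order assemble() expects

assemble() buffers incoming patches per keyword in assembly_line and only assembles once a full set of ncrop patches has arrived, which is why the test loader has to feed crops in their original, unshuffled order.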