Add assemble to FieldDataset to undo cropping

This commit is contained in:
Yin Li 2021-03-24 12:29:46 -04:00
parent 89e8651c26
commit 9d3253ac48
2 changed files with 79 additions and 5 deletions

View File

@ -104,6 +104,7 @@ class FieldDataset(Dataset):
crop_step = self.crop crop_step = self.crop
else: else:
crop_step = np.broadcast_to(crop_step, (self.ndim,)) crop_step = np.broadcast_to(crop_step, (self.ndim,))
self.crop_step = crop_step
self.anchors = np.stack(np.mgrid[tuple( self.anchors = np.stack(np.mgrid[tuple(
slice(crop_start[d], crop_stop[d], crop_step[d]) slice(crop_start[d], crop_stop[d], crop_step[d])
@ -132,6 +133,8 @@ class FieldDataset(Dataset):
self.nsample = self.nfile * self.ncrop self.nsample = self.nfile * self.ncrop
self.assembly_line = {}
def __len__(self): def __len__(self):
return self.nsample return self.nsample
@ -185,6 +188,66 @@ class FieldDataset(Dataset):
return in_fields, tgt_fields return in_fields, tgt_fields
def assemble(self, **fields):
"""Assemble cropped fields.
Repeat feeding cropped spatially ordered fields as kwargs.
After filled by the crops, the whole fields are assembled and returned.
Otherwise an empty dictionary is returned.
"""
if self.scale_factor != 1:
raise NotImplementedError
for k, v in fields.items():
if isinstance(v, torch.Tensor):
v = v.numpy()
assert v.ndim == 2 + self.ndim, 'ndim mismatch'
if any(self.crop_step > v.shape[2:]):
raise RuntimeError('crop too small to tile')
v = list(v)
if k in self.assembly_line:
self.assembly_line[k] += v
else:
self.assembly_line[k] = v
del fields
assembled_fields = {}
# NOTE anchor positioning assumes sensible target padding
# so that outputs are aligned with
anchors = self.anchors - self.tgt_pad[:, 0]
for k, v in self.assembly_line.items():
while len(v) >= self.ncrop:
assert k not in assembled_fields
assembled_fields[k] = np.zeros(
v[0].shape[:1] + tuple(self.size), v[0].dtype)
for patch, anchor in zip(v, anchors):
fill(assembled_fields[k], patch, anchor)
del v[:self.ncrop]
return assembled_fields
def fill(field, patch, anchor):
ndim = len(anchor)
ind = [slice(None)]
for d, (p, a, s) in enumerate(zip(
patch.shape[1:], anchor, field.shape[1:])):
i = np.arange(a, a + p)
i %= s
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
ind.append(i)
ind = tuple(ind)
field[ind] = patch
def crop(fields, anchor, crop, pad, size): def crop(fields, anchor, crop, pad, size):
ndim = len(size) ndim = len(size)

View File

@ -1,5 +1,6 @@
import sys import sys
from pprint import pprint from pprint import pprint
from collections import Counter
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -55,6 +56,8 @@ def test(args):
state['epoch'], args.load_state)) state['epoch'], args.load_state))
del state del state
assembled_counts = Counter()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -80,5 +83,13 @@ def test(args):
norm(target[:, start:stop], undo=True) norm(target[:, start:stop], undo=True)
start = stop start = stop
np.savez('{}.npz'.format(i), input=input.numpy(), assembled_fields = test_dataset.assemble(
output=output.numpy(), target=target.numpy()) #input=input.numpy(),
output=output.numpy(),
#target=target.numpy(),
)
if assembled_fields:
for k, v in assembled_fields.items():
np.save(f'{k}_{assembled_counts[k]}.npy', v)
assembled_counts[k] += 1