Add assemble to FieldDataset to undo cropping
This commit is contained in:
parent
89e8651c26
commit
9d3253ac48
@ -104,6 +104,7 @@ class FieldDataset(Dataset):
|
||||
crop_step = self.crop
|
||||
else:
|
||||
crop_step = np.broadcast_to(crop_step, (self.ndim,))
|
||||
self.crop_step = crop_step
|
||||
|
||||
self.anchors = np.stack(np.mgrid[tuple(
|
||||
slice(crop_start[d], crop_stop[d], crop_step[d])
|
||||
@ -132,6 +133,8 @@ class FieldDataset(Dataset):
|
||||
|
||||
self.nsample = self.nfile * self.ncrop
|
||||
|
||||
self.assembly_line = {}
|
||||
|
||||
def __len__(self):
|
||||
return self.nsample
|
||||
|
||||
@ -185,6 +188,66 @@ class FieldDataset(Dataset):
|
||||
|
||||
return in_fields, tgt_fields
|
||||
|
||||
def assemble(self, **fields):
|
||||
"""Assemble cropped fields.
|
||||
|
||||
Repeat feeding cropped spatially ordered fields as kwargs.
|
||||
After filled by the crops, the whole fields are assembled and returned.
|
||||
Otherwise an empty dictionary is returned.
|
||||
"""
|
||||
if self.scale_factor != 1:
|
||||
raise NotImplementedError
|
||||
|
||||
for k, v in fields.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.numpy()
|
||||
|
||||
assert v.ndim == 2 + self.ndim, 'ndim mismatch'
|
||||
if any(self.crop_step > v.shape[2:]):
|
||||
raise RuntimeError('crop too small to tile')
|
||||
|
||||
v = list(v)
|
||||
if k in self.assembly_line:
|
||||
self.assembly_line[k] += v
|
||||
else:
|
||||
self.assembly_line[k] = v
|
||||
|
||||
del fields
|
||||
|
||||
assembled_fields = {}
|
||||
|
||||
# NOTE anchor positioning assumes sensible target padding
|
||||
# so that outputs are aligned with
|
||||
anchors = self.anchors - self.tgt_pad[:, 0]
|
||||
|
||||
for k, v in self.assembly_line.items():
|
||||
while len(v) >= self.ncrop:
|
||||
assert k not in assembled_fields
|
||||
assembled_fields[k] = np.zeros(
|
||||
v[0].shape[:1] + tuple(self.size), v[0].dtype)
|
||||
|
||||
for patch, anchor in zip(v, anchors):
|
||||
fill(assembled_fields[k], patch, anchor)
|
||||
|
||||
del v[:self.ncrop]
|
||||
|
||||
return assembled_fields
|
||||
|
||||
|
||||
def fill(field, patch, anchor):
|
||||
ndim = len(anchor)
|
||||
|
||||
ind = [slice(None)]
|
||||
for d, (p, a, s) in enumerate(zip(
|
||||
patch.shape[1:], anchor, field.shape[1:])):
|
||||
i = np.arange(a, a + p)
|
||||
i %= s
|
||||
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
||||
ind.append(i)
|
||||
ind = tuple(ind)
|
||||
|
||||
field[ind] = patch
|
||||
|
||||
|
||||
def crop(fields, anchor, crop, pad, size):
|
||||
ndim = len(size)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import sys
|
||||
from pprint import pprint
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@ -55,6 +56,8 @@ def test(args):
|
||||
state['epoch'], args.load_state))
|
||||
del state
|
||||
|
||||
assembled_counts = Counter()
|
||||
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@ -80,5 +83,13 @@ def test(args):
|
||||
norm(target[:, start:stop], undo=True)
|
||||
start = stop
|
||||
|
||||
np.savez('{}.npz'.format(i), input=input.numpy(),
|
||||
output=output.numpy(), target=target.numpy())
|
||||
assembled_fields = test_dataset.assemble(
|
||||
#input=input.numpy(),
|
||||
output=output.numpy(),
|
||||
#target=target.numpy(),
|
||||
)
|
||||
|
||||
if assembled_fields:
|
||||
for k, v in assembled_fields.items():
|
||||
np.save(f'{k}_{assembled_counts[k]}.npy', v)
|
||||
assembled_counts[k] += 1
|
||||
|
Loading…
Reference in New Issue
Block a user