Add assemble to FieldDataset to undo cropping
This commit is contained in:
parent
89e8651c26
commit
9d3253ac48
@ -104,6 +104,7 @@ class FieldDataset(Dataset):
|
|||||||
crop_step = self.crop
|
crop_step = self.crop
|
||||||
else:
|
else:
|
||||||
crop_step = np.broadcast_to(crop_step, (self.ndim,))
|
crop_step = np.broadcast_to(crop_step, (self.ndim,))
|
||||||
|
self.crop_step = crop_step
|
||||||
|
|
||||||
self.anchors = np.stack(np.mgrid[tuple(
|
self.anchors = np.stack(np.mgrid[tuple(
|
||||||
slice(crop_start[d], crop_stop[d], crop_step[d])
|
slice(crop_start[d], crop_stop[d], crop_step[d])
|
||||||
@ -132,6 +133,8 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
self.nsample = self.nfile * self.ncrop
|
self.nsample = self.nfile * self.ncrop
|
||||||
|
|
||||||
|
self.assembly_line = {}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.nsample
|
return self.nsample
|
||||||
|
|
||||||
@ -149,9 +152,9 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
||||||
crop(tgt_fields, anchor * self.scale_factor,
|
crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor,
|
self.crop * self.scale_factor,
|
||||||
self.tgt_pad,
|
self.tgt_pad,
|
||||||
self.size * self.scale_factor)
|
self.size * self.scale_factor)
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
@ -185,6 +188,66 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
return in_fields, tgt_fields
|
return in_fields, tgt_fields
|
||||||
|
|
||||||
|
def assemble(self, **fields):
|
||||||
|
"""Assemble cropped fields.
|
||||||
|
|
||||||
|
Repeat feeding cropped spatially ordered fields as kwargs.
|
||||||
|
After filled by the crops, the whole fields are assembled and returned.
|
||||||
|
Otherwise an empty dictionary is returned.
|
||||||
|
"""
|
||||||
|
if self.scale_factor != 1:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
for k, v in fields.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.numpy()
|
||||||
|
|
||||||
|
assert v.ndim == 2 + self.ndim, 'ndim mismatch'
|
||||||
|
if any(self.crop_step > v.shape[2:]):
|
||||||
|
raise RuntimeError('crop too small to tile')
|
||||||
|
|
||||||
|
v = list(v)
|
||||||
|
if k in self.assembly_line:
|
||||||
|
self.assembly_line[k] += v
|
||||||
|
else:
|
||||||
|
self.assembly_line[k] = v
|
||||||
|
|
||||||
|
del fields
|
||||||
|
|
||||||
|
assembled_fields = {}
|
||||||
|
|
||||||
|
# NOTE anchor positioning assumes sensible target padding
|
||||||
|
# so that outputs are aligned with
|
||||||
|
anchors = self.anchors - self.tgt_pad[:, 0]
|
||||||
|
|
||||||
|
for k, v in self.assembly_line.items():
|
||||||
|
while len(v) >= self.ncrop:
|
||||||
|
assert k not in assembled_fields
|
||||||
|
assembled_fields[k] = np.zeros(
|
||||||
|
v[0].shape[:1] + tuple(self.size), v[0].dtype)
|
||||||
|
|
||||||
|
for patch, anchor in zip(v, anchors):
|
||||||
|
fill(assembled_fields[k], patch, anchor)
|
||||||
|
|
||||||
|
del v[:self.ncrop]
|
||||||
|
|
||||||
|
return assembled_fields
|
||||||
|
|
||||||
|
|
||||||
|
def fill(field, patch, anchor):
|
||||||
|
ndim = len(anchor)
|
||||||
|
|
||||||
|
ind = [slice(None)]
|
||||||
|
for d, (p, a, s) in enumerate(zip(
|
||||||
|
patch.shape[1:], anchor, field.shape[1:])):
|
||||||
|
i = np.arange(a, a + p)
|
||||||
|
i %= s
|
||||||
|
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
||||||
|
ind.append(i)
|
||||||
|
ind = tuple(ind)
|
||||||
|
|
||||||
|
field[ind] = patch
|
||||||
|
|
||||||
|
|
||||||
def crop(fields, anchor, crop, pad, size):
|
def crop(fields, anchor, crop, pad, size):
|
||||||
ndim = len(size)
|
ndim = len(size)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
from collections import Counter
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -55,6 +56,8 @@ def test(args):
|
|||||||
state['epoch'], args.load_state))
|
state['epoch'], args.load_state))
|
||||||
del state
|
del state
|
||||||
|
|
||||||
|
assembled_counts = Counter()
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -80,5 +83,13 @@ def test(args):
|
|||||||
norm(target[:, start:stop], undo=True)
|
norm(target[:, start:stop], undo=True)
|
||||||
start = stop
|
start = stop
|
||||||
|
|
||||||
np.savez('{}.npz'.format(i), input=input.numpy(),
|
assembled_fields = test_dataset.assemble(
|
||||||
output=output.numpy(), target=target.numpy())
|
#input=input.numpy(),
|
||||||
|
output=output.numpy(),
|
||||||
|
#target=target.numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if assembled_fields:
|
||||||
|
for k, v in assembled_fields.items():
|
||||||
|
np.save(f'{k}_{assembled_counts[k]}.npy', v)
|
||||||
|
assembled_counts[k] += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user