Improve FieldDataset.assemble

This commit is contained in:
Yin Li 2021-03-26 11:56:43 -04:00
parent 0410435a8a
commit abb16fe26a
3 changed files with 98 additions and 57 deletions

View File

@ -1,3 +1,5 @@
import os
import pathlib
from glob import glob
import numpy as np
import torch
@ -138,6 +140,12 @@ class FieldDataset(Dataset):
self.assembly_line = {}
self.commonpath = os.path.commonpath(
file
for files in self.in_files[:2] + self.tgt_files[:2]
for file in files
)
def __len__(self):
return self.nsample
@ -191,52 +199,86 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
return style, in_fields, tgt_fields
#in_relpath = [os.path.relpath(file, start=self.commonpath)
# for file in self.in_files[ifile]]
tgt_relpath = [os.path.relpath(file, start=self.commonpath)
for file in self.tgt_files[ifile]]
def assemble(self, **fields):
"""Assemble cropped fields.
return {
'style': style,
'input': in_fields,
'target': tgt_fields,
#'input_relpath': in_relpath,
'target_relpath': tgt_relpath,
}
Repeat feeding cropped spatially ordered fields as kwargs.
After filled by the crops, the whole fields are assembled and returned.
Otherwise an empty dictionary is returned.
def assemble(self, label, chan, patches, paths):
"""Assemble and write whole fields from patches, each being the end
result from a cropped field by `__getitem__`.
Repeat feeding spatially ordered field patches.
After filled, the whole fields are assembled and saved to relative
paths specified by `paths` and `label`.
`chan` is used to split the channels to undo `cat` in `__getitem__`.
As an example, patches of shape `(1, 4, X, Y, Z)`, `label='_out'`
and `chan=[1, 3]`, with `paths=[['d/scalar.npy'], ['d/vector.npy']]`
will write to `'d/scalar_out.npy'` and `'d/vector_out.npy'`.
Note that `paths` assumes transposed shape due to pytorch auto batching
"""
if self.scale_factor != 1:
raise NotImplementedError
for k, v in fields.items():
if isinstance(v, torch.Tensor):
v = v.numpy()
if isinstance(patches, torch.Tensor):
patches = patches.detach().cpu().numpy()
assert v.ndim == 2 + self.ndim, 'ndim mismatch'
if any(self.crop_step > v.shape[2:]):
raise RuntimeError('crop too small to tile')
assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
if any(self.crop_step > patches.shape[2:]):
raise RuntimeError('patch too small to tile')
v = list(v)
if k in self.assembly_line:
self.assembly_line[k] += v
# the batched paths are a list of lists with shape (channel, batch)
# since pytorch default_collate batches list of strings transposedly
# therefore we transpose below back to (batch, channel)
assert patches.shape[1] == sum(chan), 'number of channels mismatch'
assert len(paths) == len(chan), 'number of fields mismatch'
paths = list(zip(* paths))
assert patches.shape[0] == len(paths), 'batch size mismatch'
patches = list(patches)
if label in self.assembly_line:
self.assembly_line[label] += patches
self.assembly_line[label + 'path'] += paths
else:
self.assembly_line[k] = v
self.assembly_line[label] = patches
self.assembly_line[label + 'path'] = paths
del fields
del patches, paths
patches = self.assembly_line[label]
paths = self.assembly_line[label + 'path']
assembled_fields = {}
# NOTE anchor positioning assumes sufficient target padding and
# symmetric narrowing (more on the right if odd) see `models/narrow.py`
narrow = self.crop + self.tgt_pad.sum(axis=1) - patches[0].shape[1:]
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
# NOTE anchor positioning assumes sensible target padding
# so that outputs are aligned with
anchors = self.anchors - self.tgt_pad[:, 0]
while len(patches) >= self.ncrop:
fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
patches[0].dtype)
for k, v in self.assembly_line.items():
while len(v) >= self.ncrop:
assert k not in assembled_fields
assembled_fields[k] = np.zeros(
v[0].shape[:1] + tuple(self.size), v[0].dtype)
for patch, anchor in zip(patches, anchors):
fill(fields, patch, anchor)
for patch, anchor in zip(v, anchors):
fill(assembled_fields[k], patch, anchor)
for field, path in zip(
np.split(fields, np.cumsum(chan), axis=0),
paths[0]):
pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
exist_ok=True)
del v[:self.ncrop]
path = label.join(os.path.splitext(path))
np.save(path, field)
return assembled_fields
del patches[:self.ncrop], paths[:self.ncrop]
def fill(field, patch, anchor):

View File

@ -1,6 +1,5 @@
import sys
from pprint import pprint
from collections import Counter
import numpy as np
import torch
from torch.utils.data import DataLoader
@ -59,12 +58,12 @@ def test(args):
state['epoch'], args.load_state))
del state
assembled_counts = Counter()
model.eval()
with torch.no_grad():
for i, (style, input, target) in enumerate(test_loader):
for i, data in enumerate(test_loader):
style, input, target = data['style'], data['input'], data['target']
output = model(input, style)
input, output, target = narrow_cast(input, output, target)
@ -72,27 +71,23 @@ def test(args):
print('sample {} loss: {}'.format(i, loss.item()))
if args.in_norms is not None:
start = 0
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(input[:, start:stop], undo=True)
start = stop
#if args.in_norms is not None:
# start = 0
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
# norm = import_attr(norm, norms, callback_at=args.callback_at)
# norm(input[:, start:stop], undo=True)
# start = stop
if args.tgt_norms is not None:
start = 0
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True)
norm(target[:, start:stop], undo=True)
#norm(target[:, start:stop], undo=True)
start = stop
assembled_fields = test_dataset.assemble(
#input=input.numpy(),
output=output.numpy(),
#target=target.numpy(),
)
if assembled_fields:
for k, v in assembled_fields.items():
np.save(f'{k}_{assembled_counts[k]}.npy', v)
assembled_counts[k] += 1
#test_dataset.assemble('_in', in_chan, input,
# data['input_relpath'])
test_dataset.assemble('_out', out_chan, output,
data['target_relpath'])
#test_dataset.assemble('_tgt', out_chan, target,
# data['target_relpath'])

View File

@ -242,9 +242,11 @@ def train(epoch, loader, model, criterion,
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
for i, (style, input, target) in enumerate(loader):
for i, data in enumerate(loader):
batch = epoch * len(loader) + i + 1
style, input, target = data['style'], data['input'], data['target']
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
@ -336,7 +338,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
with torch.no_grad():
for style, input, target in loader:
for data in loader:
style, input, target = data['style'], data['input'], data['target']
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)