Improve FieldDataset.assemble

This commit is contained in:
Yin Li 2021-03-26 11:56:43 -04:00
parent 0410435a8a
commit abb16fe26a
3 changed files with 98 additions and 57 deletions

View File

@ -1,3 +1,5 @@
import os
import pathlib
from glob import glob from glob import glob
import numpy as np import numpy as np
import torch import torch
@ -138,6 +140,12 @@ class FieldDataset(Dataset):
self.assembly_line = {} self.assembly_line = {}
self.commonpath = os.path.commonpath(
file
for files in self.in_files[:2] + self.tgt_files[:2]
for file in files
)
def __len__(self): def __len__(self):
return self.nsample return self.nsample
@ -156,9 +164,9 @@ class FieldDataset(Dataset):
crop(in_fields, anchor, self.crop, self.in_pad, self.size) crop(in_fields, anchor, self.crop, self.in_pad, self.size)
crop(tgt_fields, anchor * self.scale_factor, crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.crop * self.scale_factor,
self.tgt_pad, self.tgt_pad,
self.size * self.scale_factor) self.size * self.scale_factor)
style = torch.from_numpy(style).to(torch.float32) style = torch.from_numpy(style).to(torch.float32)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@ -191,52 +199,86 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0) in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0)
return style, in_fields, tgt_fields #in_relpath = [os.path.relpath(file, start=self.commonpath)
# for file in self.in_files[ifile]]
tgt_relpath = [os.path.relpath(file, start=self.commonpath)
for file in self.tgt_files[ifile]]
def assemble(self, **fields): return {
"""Assemble cropped fields. 'style': style,
'input': in_fields,
'target': tgt_fields,
#'input_relpath': in_relpath,
'target_relpath': tgt_relpath,
}
Repeat feeding cropped spatially ordered fields as kwargs. def assemble(self, label, chan, patches, paths):
After filled by the crops, the whole fields are assembled and returned. """Assemble and write whole fields from patches, each being the end
Otherwise an empty dictionary is returned. result from a cropped field by `__getitem__`.
Repeat feeding spatially ordered field patches.
After filled, the whole fields are assembled and saved to relative
paths specified by `paths` and `label`.
`chan` is used to split the channels to undo `cat` in `__getitem__`.
As an example, patches of shape `(1, 4, X, Y, Z)`, `label='_out'`
and `chan=[1, 3]`, with `paths=[['d/scalar.npy'], ['d/vector.npy']]`
will write to `'d/scalar_out.npy'` and `'d/vector_out.npy'`.
Note that `paths` assumes transposed shape due to pytorch auto batching
""" """
if self.scale_factor != 1: if self.scale_factor != 1:
raise NotImplementedError raise NotImplementedError
for k, v in fields.items(): if isinstance(patches, torch.Tensor):
if isinstance(v, torch.Tensor): patches = patches.detach().cpu().numpy()
v = v.numpy()
assert v.ndim == 2 + self.ndim, 'ndim mismatch' assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
if any(self.crop_step > v.shape[2:]): if any(self.crop_step > patches.shape[2:]):
raise RuntimeError('crop too small to tile') raise RuntimeError('patch too small to tile')
v = list(v) # the batched paths are a list of lists with shape (channel, batch)
if k in self.assembly_line: # since pytorch default_collate batches list of strings transposedly
self.assembly_line[k] += v # therefore we transpose below back to (batch, channel)
else: assert patches.shape[1] == sum(chan), 'number of channels mismatch'
self.assembly_line[k] = v assert len(paths) == len(chan), 'number of fields mismatch'
paths = list(zip(* paths))
assert patches.shape[0] == len(paths), 'batch size mismatch'
del fields patches = list(patches)
if label in self.assembly_line:
self.assembly_line[label] += patches
self.assembly_line[label + 'path'] += paths
else:
self.assembly_line[label] = patches
self.assembly_line[label + 'path'] = paths
assembled_fields = {} del patches, paths
patches = self.assembly_line[label]
paths = self.assembly_line[label + 'path']
# NOTE anchor positioning assumes sensible target padding # NOTE anchor positioning assumes sufficient target padding and
# so that outputs are aligned with # symmetric narrowing (more on the right if odd) see `models/narrow.py`
anchors = self.anchors - self.tgt_pad[:, 0] narrow = self.crop + self.tgt_pad.sum(axis=1) - patches[0].shape[1:]
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
for k, v in self.assembly_line.items(): while len(patches) >= self.ncrop:
while len(v) >= self.ncrop: fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
assert k not in assembled_fields patches[0].dtype)
assembled_fields[k] = np.zeros(
v[0].shape[:1] + tuple(self.size), v[0].dtype)
for patch, anchor in zip(v, anchors): for patch, anchor in zip(patches, anchors):
fill(assembled_fields[k], patch, anchor) fill(fields, patch, anchor)
del v[:self.ncrop] for field, path in zip(
np.split(fields, np.cumsum(chan), axis=0),
paths[0]):
pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
exist_ok=True)
return assembled_fields path = label.join(os.path.splitext(path))
np.save(path, field)
del patches[:self.ncrop], paths[:self.ncrop]
def fill(field, patch, anchor): def fill(field, patch, anchor):

View File

@ -1,6 +1,5 @@
import sys import sys
from pprint import pprint from pprint import pprint
from collections import Counter
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -59,12 +58,12 @@ def test(args):
state['epoch'], args.load_state)) state['epoch'], args.load_state))
del state del state
assembled_counts = Counter()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
for i, (style, input, target) in enumerate(test_loader): for i, data in enumerate(test_loader):
style, input, target = data['style'], data['input'], data['target']
output = model(input, style) output = model(input, style)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
@ -72,27 +71,23 @@ def test(args):
print('sample {} loss: {}'.format(i, loss.item())) print('sample {} loss: {}'.format(i, loss.item()))
if args.in_norms is not None: #if args.in_norms is not None:
start = 0 # start = 0
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): # for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at) # norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(input[:, start:stop], undo=True) # norm(input[:, start:stop], undo=True)
start = stop # start = stop
if args.tgt_norms is not None: if args.tgt_norms is not None:
start = 0 start = 0
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at) norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True) norm(output[:, start:stop], undo=True)
norm(target[:, start:stop], undo=True) #norm(target[:, start:stop], undo=True)
start = stop start = stop
assembled_fields = test_dataset.assemble( #test_dataset.assemble('_in', in_chan, input,
#input=input.numpy(), # data['input_relpath'])
output=output.numpy(), test_dataset.assemble('_out', out_chan, output,
#target=target.numpy(), data['target_relpath'])
) #test_dataset.assemble('_tgt', out_chan, target,
# data['target_relpath'])
if assembled_fields:
for k, v in assembled_fields.items():
np.save(f'{k}_{assembled_counts[k]}.npy', v)
assembled_counts[k] += 1

View File

@ -242,9 +242,11 @@ def train(epoch, loader, model, criterion,
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device) epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
for i, (style, input, target) in enumerate(loader): for i, data in enumerate(loader):
batch = epoch * len(loader) + i + 1 batch = epoch * len(loader) + i + 1
style, input, target = data['style'], data['input'], data['target']
style = style.to(device, non_blocking=True) style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True) input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True) target = target.to(device, non_blocking=True)
@ -336,7 +338,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device) epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
with torch.no_grad(): with torch.no_grad():
for style, input, target in loader: for data in loader:
style, input, target = data['style'], data['input'], data['target']
style = style.to(device, non_blocking=True) style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True) input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True) target = target.to(device, non_blocking=True)