Improve FieldDataset.assemble
This commit is contained in:
parent
0410435a8a
commit
abb16fe26a
@ -1,3 +1,5 @@
|
||||
import os
|
||||
import pathlib
|
||||
from glob import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -138,6 +140,12 @@ class FieldDataset(Dataset):
|
||||
|
||||
self.assembly_line = {}
|
||||
|
||||
self.commonpath = os.path.commonpath(
|
||||
file
|
||||
for files in self.in_files[:2] + self.tgt_files[:2]
|
||||
for file in files
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.nsample
|
||||
|
||||
@ -156,9 +164,9 @@ class FieldDataset(Dataset):
|
||||
|
||||
crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
||||
crop(tgt_fields, anchor * self.scale_factor,
|
||||
self.crop * self.scale_factor,
|
||||
self.tgt_pad,
|
||||
self.size * self.scale_factor)
|
||||
self.crop * self.scale_factor,
|
||||
self.tgt_pad,
|
||||
self.size * self.scale_factor)
|
||||
|
||||
style = torch.from_numpy(style).to(torch.float32)
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
@ -191,52 +199,86 @@ class FieldDataset(Dataset):
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
return style, in_fields, tgt_fields
|
||||
#in_relpath = [os.path.relpath(file, start=self.commonpath)
|
||||
# for file in self.in_files[ifile]]
|
||||
tgt_relpath = [os.path.relpath(file, start=self.commonpath)
|
||||
for file in self.tgt_files[ifile]]
|
||||
|
||||
def assemble(self, **fields):
|
||||
"""Assemble cropped fields.
|
||||
return {
|
||||
'style': style,
|
||||
'input': in_fields,
|
||||
'target': tgt_fields,
|
||||
#'input_relpath': in_relpath,
|
||||
'target_relpath': tgt_relpath,
|
||||
}
|
||||
|
||||
Repeat feeding cropped spatially ordered fields as kwargs.
|
||||
After filled by the crops, the whole fields are assembled and returned.
|
||||
Otherwise an empty dictionary is returned.
|
||||
def assemble(self, label, chan, patches, paths):
|
||||
"""Assemble and write whole fields from patches, each being the end
|
||||
result from a cropped field by `__getitem__`.
|
||||
|
||||
Repeat feeding spatially ordered field patches.
|
||||
Once filled, the whole fields are assembled and saved to relative
|
||||
paths specified by `paths` and `label`.
|
||||
`chan` is used to split the channels to undo `cat` in `__getitem__`.
|
||||
|
||||
As an example, patches of shape `(1, 4, X, Y, Z)`, `label='_out'`
|
||||
and `chan=[1, 3]`, with `paths=[['d/scalar.npy'], ['d/vector.npy']]`
|
||||
will write to `'d/scalar_out.npy'` and `'d/vector_out.npy'`.
|
||||
|
||||
Note that `paths` is expected in transposed (channel, batch) shape due to pytorch auto batching.
|
||||
"""
|
||||
if self.scale_factor != 1:
|
||||
raise NotImplementedError
|
||||
|
||||
for k, v in fields.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.numpy()
|
||||
if isinstance(patches, torch.Tensor):
|
||||
patches = patches.detach().cpu().numpy()
|
||||
|
||||
assert v.ndim == 2 + self.ndim, 'ndim mismatch'
|
||||
if any(self.crop_step > v.shape[2:]):
|
||||
raise RuntimeError('crop too small to tile')
|
||||
assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
|
||||
if any(self.crop_step > patches.shape[2:]):
|
||||
raise RuntimeError('patch too small to tile')
|
||||
|
||||
v = list(v)
|
||||
if k in self.assembly_line:
|
||||
self.assembly_line[k] += v
|
||||
else:
|
||||
self.assembly_line[k] = v
|
||||
# the batched paths are a list of lists with shape (channel, batch)
|
||||
# since pytorch default_collate batches lists of strings in transposed order
|
||||
# therefore we transpose below back to (batch, channel)
|
||||
assert patches.shape[1] == sum(chan), 'number of channels mismatch'
|
||||
assert len(paths) == len(chan), 'number of fields mismatch'
|
||||
paths = list(zip(* paths))
|
||||
assert patches.shape[0] == len(paths), 'batch size mismatch'
|
||||
|
||||
del fields
|
||||
patches = list(patches)
|
||||
if label in self.assembly_line:
|
||||
self.assembly_line[label] += patches
|
||||
self.assembly_line[label + 'path'] += paths
|
||||
else:
|
||||
self.assembly_line[label] = patches
|
||||
self.assembly_line[label + 'path'] = paths
|
||||
|
||||
assembled_fields = {}
|
||||
del patches, paths
|
||||
patches = self.assembly_line[label]
|
||||
paths = self.assembly_line[label + 'path']
|
||||
|
||||
# NOTE anchor positioning assumes sensible target padding
|
||||
# so that outputs are aligned with
|
||||
anchors = self.anchors - self.tgt_pad[:, 0]
|
||||
# NOTE anchor positioning assumes sufficient target padding and
|
||||
# symmetric narrowing (more on the right if odd) see `models/narrow.py`
|
||||
narrow = self.crop + self.tgt_pad.sum(axis=1) - patches[0].shape[1:]
|
||||
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
|
||||
|
||||
for k, v in self.assembly_line.items():
|
||||
while len(v) >= self.ncrop:
|
||||
assert k not in assembled_fields
|
||||
assembled_fields[k] = np.zeros(
|
||||
v[0].shape[:1] + tuple(self.size), v[0].dtype)
|
||||
while len(patches) >= self.ncrop:
|
||||
fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
|
||||
patches[0].dtype)
|
||||
|
||||
for patch, anchor in zip(v, anchors):
|
||||
fill(assembled_fields[k], patch, anchor)
|
||||
for patch, anchor in zip(patches, anchors):
|
||||
fill(fields, patch, anchor)
|
||||
|
||||
del v[:self.ncrop]
|
||||
for field, path in zip(
|
||||
np.split(fields, np.cumsum(chan), axis=0),
|
||||
paths[0]):
|
||||
pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
|
||||
exist_ok=True)
|
||||
|
||||
return assembled_fields
|
||||
path = label.join(os.path.splitext(path))
|
||||
np.save(path, field)
|
||||
|
||||
del patches[:self.ncrop], paths[:self.ncrop]
|
||||
|
||||
|
||||
def fill(field, patch, anchor):
|
||||
|
@ -1,6 +1,5 @@
|
||||
import sys
|
||||
from pprint import pprint
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@ -59,12 +58,12 @@ def test(args):
|
||||
state['epoch'], args.load_state))
|
||||
del state
|
||||
|
||||
assembled_counts = Counter()
|
||||
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for i, (style, input, target) in enumerate(test_loader):
|
||||
for i, data in enumerate(test_loader):
|
||||
style, input, target = data['style'], data['input'], data['target']
|
||||
|
||||
output = model(input, style)
|
||||
input, output, target = narrow_cast(input, output, target)
|
||||
|
||||
@ -72,27 +71,23 @@ def test(args):
|
||||
|
||||
print('sample {} loss: {}'.format(i, loss.item()))
|
||||
|
||||
if args.in_norms is not None:
|
||||
start = 0
|
||||
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
norm(input[:, start:stop], undo=True)
|
||||
start = stop
|
||||
#if args.in_norms is not None:
|
||||
# start = 0
|
||||
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||
# norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
# norm(input[:, start:stop], undo=True)
|
||||
# start = stop
|
||||
if args.tgt_norms is not None:
|
||||
start = 0
|
||||
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
norm(output[:, start:stop], undo=True)
|
||||
norm(target[:, start:stop], undo=True)
|
||||
#norm(target[:, start:stop], undo=True)
|
||||
start = stop
|
||||
|
||||
assembled_fields = test_dataset.assemble(
|
||||
#input=input.numpy(),
|
||||
output=output.numpy(),
|
||||
#target=target.numpy(),
|
||||
)
|
||||
|
||||
if assembled_fields:
|
||||
for k, v in assembled_fields.items():
|
||||
np.save(f'{k}_{assembled_counts[k]}.npy', v)
|
||||
assembled_counts[k] += 1
|
||||
#test_dataset.assemble('_in', in_chan, input,
|
||||
# data['input_relpath'])
|
||||
test_dataset.assemble('_out', out_chan, output,
|
||||
data['target_relpath'])
|
||||
#test_dataset.assemble('_tgt', out_chan, target,
|
||||
# data['target_relpath'])
|
||||
|
@ -242,9 +242,11 @@ def train(epoch, loader, model, criterion,
|
||||
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
for i, (style, input, target) in enumerate(loader):
|
||||
for i, data in enumerate(loader):
|
||||
batch = epoch * len(loader) + i + 1
|
||||
|
||||
style, input, target = data['style'], data['input'], data['target']
|
||||
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
@ -336,7 +338,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
for style, input, target in loader:
|
||||
for data in loader:
|
||||
style, input, target = data['style'], data['input'], data['target']
|
||||
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
Loading…
Reference in New Issue
Block a user