Improve FieldDataset.assemble

parent 0410435a8a
commit abb16fe26a
@@ -1,3 +1,5 @@
+import os
+import pathlib
 from glob import glob
 import numpy as np
 import torch
@@ -138,6 +140,12 @@ class FieldDataset(Dataset):
 
         self.assembly_line = {}
 
+        self.commonpath = os.path.commonpath(
+            file
+            for files in self.in_files[:2] + self.tgt_files[:2]
+            for file in files
+        )
+
     def __len__(self):
         return self.nsample
 
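Note: `os.path.commonpath` returns the longest common sub-path of its arguments, so sampling the first two input and target file lists above is enough to locate the shared data root that relative paths are later computed against. A minimal sketch with hypothetical file names:

    import os

    files = ['data/run0/scalar.npy', 'data/run1/vector.npy']
    print(os.path.commonpath(files))                # data
    print(os.path.relpath(files[0], start='data'))  # run0/scalar.npy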
@@ -191,52 +199,86 @@ class FieldDataset(Dataset):
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)
 
-        return style, in_fields, tgt_fields
+        #in_relpath = [os.path.relpath(file, start=self.commonpath)
+        #              for file in self.in_files[ifile]]
+        tgt_relpath = [os.path.relpath(file, start=self.commonpath)
+                       for file in self.tgt_files[ifile]]
 
-    def assemble(self, **fields):
-        """Assemble cropped fields.
+        return {
+            'style': style,
+            'input': in_fields,
+            'target': tgt_fields,
+            #'input_relpath': in_relpath,
+            'target_relpath': tgt_relpath,
+        }
 
-        Repeat feeding cropped spatially ordered fields as kwargs.
-        After filled by the crops, the whole fields are assembled and returned.
-        Otherwise an empty dictionary is returned.
+    def assemble(self, label, chan, patches, paths):
+        """Assemble and write whole fields from patches, each being the end
+        result from a cropped field by `__getitem__`.
+
+        Repeat feeding spatially ordered field patches.
+        After filled, the whole fields are assembled and saved to relative
+        paths specified by `paths` and `label`.
+        `chan` is used to split the channels to undo `cat` in `__getitem__`.
+
+        As an example, patches of shape `(1, 4, X, Y, Z)`, `label='_out'`
+        and `chan=[1, 3]`, with `paths=[['d/scalar.npy'], ['d/vector.npy']]`
+        will write to `'d/scalar_out.npy'` and `'d/vector_out.npy'`.
+
+        Note that `paths` assumes transposed shape due to pytorch auto batching
         """
         if self.scale_factor != 1:
             raise NotImplementedError
 
-        for k, v in fields.items():
-            if isinstance(v, torch.Tensor):
-                v = v.numpy()
+        if isinstance(patches, torch.Tensor):
+            patches = patches.detach().cpu().numpy()
 
-            assert v.ndim == 2 + self.ndim, 'ndim mismatch'
-            if any(self.crop_step > v.shape[2:]):
-                raise RuntimeError('crop too small to tile')
+        assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
+        if any(self.crop_step > patches.shape[2:]):
+            raise RuntimeError('patch too small to tile')
 
-            v = list(v)
-            if k in self.assembly_line:
-                self.assembly_line[k] += v
-            else:
-                self.assembly_line[k] = v
+        # the batched paths are a list of lists with shape (channel, batch)
+        # since pytorch default_collate batches list of strings transposedly
+        # therefore we transpose below back to (batch, channel)
+        assert patches.shape[1] == sum(chan), 'number of channels mismatch'
+        assert len(paths) == len(chan), 'number of fields mismatch'
+        paths = list(zip(* paths))
+        assert patches.shape[0] == len(paths), 'batch size mismatch'
+
+        patches = list(patches)
+        if label in self.assembly_line:
+            self.assembly_line[label] += patches
+            self.assembly_line[label + 'path'] += paths
+        else:
+            self.assembly_line[label] = patches
+            self.assembly_line[label + 'path'] = paths
 
-        del fields
+        del patches, paths
+        patches = self.assembly_line[label]
+        paths = self.assembly_line[label + 'path']
 
-        assembled_fields = {}
+        # NOTE anchor positioning assumes sufficient target padding and
+        # symmetric narrowing (more on the right if odd) see `models/narrow.py`
+        narrow = self.crop + self.tgt_pad.sum(axis=1) - patches[0].shape[1:]
+        anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
 
-        # NOTE anchor positioning assumes sensible target padding
-        # so that outputs are aligned with
-        anchors = self.anchors - self.tgt_pad[:, 0]
+        while len(patches) >= self.ncrop:
+            fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
+                              patches[0].dtype)
 
-        for k, v in self.assembly_line.items():
-            while len(v) >= self.ncrop:
-                assert k not in assembled_fields
-                assembled_fields[k] = np.zeros(
-                    v[0].shape[:1] + tuple(self.size), v[0].dtype)
+            for patch, anchor in zip(patches, anchors):
+                fill(fields, patch, anchor)
 
-                for patch, anchor in zip(v, anchors):
-                    fill(assembled_fields[k], patch, anchor)
+            for field, path in zip(
+                    np.split(fields, np.cumsum(chan), axis=0),
+                    paths[0]):
+                pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
+                                                          exist_ok=True)
 
-                del v[:self.ncrop]
+                path = label.join(os.path.splitext(path))
+                np.save(path, field)
 
-        return assembled_fields
+            del patches[:self.ncrop], paths[:self.ncrop]
 
 
 def fill(field, patch, anchor):
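Note: tying the docstring example together, a minimal driver sketch of the new interface; `dataset` is an assumed `FieldDataset` instance, and the shapes, channels, and paths are the hypothetical ones from the docstring:

    import numpy as np

    # one batch: 1 sample, 1 + 3 channels, a cropped 3D field
    patches = np.zeros((1, 4, 32, 32, 32), np.float32)

    # batched paths arrive in (channel, batch) layout from default_collate
    paths = [['d/scalar.npy'], ['d/vector.npy']]

    # once dataset.ncrop patches have accumulated, the whole fields are
    # split by chan and saved to 'd/scalar_out.npy' and 'd/vector_out.npy'
    dataset.assemble('_out', [1, 3], patches, paths)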
@@ -1,6 +1,5 @@
 import sys
 from pprint import pprint
-from collections import Counter
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
@@ -59,12 +58,12 @@ def test(args):
             state['epoch'], args.load_state))
         del state
 
-    assembled_counts = Counter()
-
     model.eval()
 
    with torch.no_grad():
-        for i, (style, input, target) in enumerate(test_loader):
+        for i, data in enumerate(test_loader):
+            style, input, target = data['style'], data['input'], data['target']
+
             output = model(input, style)
             input, output, target = narrow_cast(input, output, target)
 
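Note: because `__getitem__` now returns a dict, pytorch's default collate batches each key separately, which is why the loops unpack `data['style']`, `data['input']` and `data['target']`. A self-contained sketch with hypothetical shapes:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class Toy(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, i):
            return {'style': torch.zeros(2), 'input': torch.zeros(1, 8),
                    'target': torch.zeros(1, 8)}

    data = next(iter(DataLoader(Toy(), batch_size=2)))
    print(data['style'].shape)  # torch.Size([2, 2])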
@@ -72,27 +71,23 @@ def test(args):
 
             print('sample {} loss: {}'.format(i, loss.item()))
 
-            if args.in_norms is not None:
-                start = 0
-                for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
-                    norm = import_attr(norm, norms, callback_at=args.callback_at)
-                    norm(input[:, start:stop], undo=True)
-                    start = stop
+            #if args.in_norms is not None:
+            #    start = 0
+            #    for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
+            #        norm = import_attr(norm, norms, callback_at=args.callback_at)
+            #        norm(input[:, start:stop], undo=True)
+            #        start = stop
             if args.tgt_norms is not None:
                 start = 0
                 for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
                     norm = import_attr(norm, norms, callback_at=args.callback_at)
                     norm(output[:, start:stop], undo=True)
-                    norm(target[:, start:stop], undo=True)
+                    #norm(target[:, start:stop], undo=True)
                     start = stop
 
-            assembled_fields = test_dataset.assemble(
-                #input=input.numpy(),
-                output=output.numpy(),
-                #target=target.numpy(),
-            )
-
-            if assembled_fields:
-                for k, v in assembled_fields.items():
-                    np.save(f'{k}_{assembled_counts[k]}.npy', v)
-                    assembled_counts[k] += 1
+            #test_dataset.assemble('_in', in_chan, input,
+            #                      data['input_relpath'])
+            test_dataset.assemble('_out', out_chan, output,
+                                  data['target_relpath'])
+            #test_dataset.assemble('_tgt', out_chan, target,
+            #                      data['target_relpath'])
 
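Note: the `(channel, batch)` layout that `assemble` undoes with `list(zip(* paths))` comes from how pytorch collates lists of strings. A quick demonstration, assuming `default_collate` is importable as a public name (torch >= 1.11):

    from torch.utils.data import default_collate

    batch = [{'target_relpath': ['a/scalar.npy', 'a/vector.npy']},
             {'target_relpath': ['b/scalar.npy', 'b/vector.npy']}]
    paths = default_collate(batch)['target_relpath']
    print(paths)              # [('a/scalar.npy', 'b/scalar.npy'),
                              #  ('a/vector.npy', 'b/vector.npy')]
    print(list(zip(*paths)))  # back to (batch, channel)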
@@ -242,9 +242,11 @@ def train(epoch, loader, model, criterion,
 
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
-    for i, (style, input, target) in enumerate(loader):
+    for i, data in enumerate(loader):
         batch = epoch * len(loader) + i + 1
 
+        style, input, target = data['style'], data['input'], data['target']
+
         style = style.to(device, non_blocking=True)
         input = input.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
@@ -336,7 +338,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
     with torch.no_grad():
-        for style, input, target in loader:
+        for data in loader:
+            style, input, target = data['style'], data['input'], data['target']
+
             style = style.to(device, non_blocking=True)
             input = input.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)