Add parameter reading

This commit is contained in:
Yin Li 2020-07-13 13:09:41 -04:00
parent 9d3253ac48
commit 4706a902f6
2 changed files with 13 additions and 4 deletions

View file

@ -96,10 +96,14 @@ def add_common_args(parser):
def add_train_args(parser):
add_common_args(parser)
parser.add_argument('--train-param-pattern', type=str, required=True,
help='glob pattern for training data parameters')
parser.add_argument('--train-in-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for training input data')
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for training target data')
parser.add_argument('--val-param-pattern', type=str,
help='glob pattern for validation data parameters')
parser.add_argument('--val-in-patterns', type=str_list,
help='comma-sep. list of glob patterns for validation input data')
parser.add_argument('--val-tgt-patterns', type=str_list,
@ -160,6 +164,8 @@ def add_train_args(parser):
def add_test_args(parser):
add_common_args(parser)
parser.add_argument('--test-param-pattern', type=str, required=True,
help='glob pattern for test data parameters')
parser.add_argument('--test-in-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for test input data')
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,

View file

@ -41,19 +41,21 @@ class FieldDataset(Dataset):
the input for super-resolution, in which case `crop` and `pad` are sizes of
the input resolution.
"""
def __init__(self, in_patterns, tgt_patterns,
def __init__(self, param_pattern, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
in_pad=0, tgt_pad=0, scale_factor=1):
self.param_files = sorted(param_pattern)
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
self.tgt_files = list(zip(* tgt_file_lists))
if len(self.in_files) != len(self.tgt_files):
raise ValueError('number of input and target fields do not match')
if len(self.param_files) != len(self.in_files) != len(self.tgt_files):
raise ValueError('number of param files, input and target fields do not match')
self.nfile = len(self.in_files)
if self.nfile == 0:
@ -141,6 +143,7 @@ class FieldDataset(Dataset):
def __getitem__(self, idx):
ifile, icrop = divmod(idx, self.ncrop)
params = np.loadtxt(self.param_files[idx])
in_fields = [np.load(f) for f in self.in_files[ifile]]
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
@ -186,7 +189,7 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
return in_fields, tgt_fields
return params, in_fields, tgt_fields
def assemble(self, **fields):
"""Assemble cropped fields.