Add parameter reading
This commit is contained in:
parent
9d3253ac48
commit
4706a902f6
@ -96,10 +96,14 @@ def add_common_args(parser):
|
||||
def add_train_args(parser):
|
||||
add_common_args(parser)
|
||||
|
||||
parser.add_argument('--train-param-pattern', type=str, required=True,
|
||||
help='glob pattern for training data parameters')
|
||||
parser.add_argument('--train-in-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training input data')
|
||||
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training target data')
|
||||
parser.add_argument('--val-param-pattern', type=str,
|
||||
help='glob pattern for validation data parameters')
|
||||
parser.add_argument('--val-in-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation input data')
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
||||
@ -160,6 +164,8 @@ def add_train_args(parser):
|
||||
def add_test_args(parser):
|
||||
add_common_args(parser)
|
||||
|
||||
parser.add_argument('--test-param-pattern', type=str, required=True,
|
||||
help='glob pattern for test data parameters')
|
||||
parser.add_argument('--test-in-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for test input data')
|
||||
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
|
||||
|
@ -41,19 +41,21 @@ class FieldDataset(Dataset):
|
||||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||
the input resolution.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
def __init__(self, param_pattern, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None, callback_at=None,
|
||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
in_pad=0, tgt_pad=0, scale_factor=1):
|
||||
self.param_files = sorted(param_pattern)
|
||||
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
if len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of input and target fields do not match')
|
||||
if len(self.param_files) != len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of param files, input and target fields do not match')
|
||||
self.nfile = len(self.in_files)
|
||||
|
||||
if self.nfile == 0:
|
||||
@ -141,6 +143,7 @@ class FieldDataset(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
ifile, icrop = divmod(idx, self.ncrop)
|
||||
|
||||
params = np.loadtxt(self.param_files[idx])
|
||||
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
||||
|
||||
@ -186,7 +189,7 @@ class FieldDataset(Dataset):
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
return in_fields, tgt_fields
|
||||
return params, in_fields, tgt_fields
|
||||
|
||||
def assemble(self, **fields):
|
||||
"""Assemble cropped fields.
|
||||
|
Loading…
Reference in New Issue
Block a user