diff --git a/map2map/args.py b/map2map/args.py index f1c1c9b..bb22e46 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -96,10 +96,14 @@ def add_common_args(parser): def add_train_args(parser): add_common_args(parser) + parser.add_argument('--train-param-pattern', type=str, required=True, + help='glob pattern for training data parameters') parser.add_argument('--train-in-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for training input data') parser.add_argument('--train-tgt-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for training target data') + parser.add_argument('--val-param-pattern', type=str, + help='glob pattern for validation data parameters') parser.add_argument('--val-in-patterns', type=str_list, help='comma-sep. list of glob patterns for validation input data') parser.add_argument('--val-tgt-patterns', type=str_list, @@ -160,6 +164,8 @@ def add_train_args(parser): def add_test_args(parser): add_common_args(parser) + parser.add_argument('--test-param-pattern', type=str, required=True, + help='glob pattern for test data parameters') parser.add_argument('--test-in-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for test input data') parser.add_argument('--test-tgt-patterns', type=str_list, required=True, diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 7468e5b..be95998 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -41,19 +41,21 @@ class FieldDataset(Dataset): the input for super-resolution, in which case `crop` and `pad` are sizes of the input resolution. """ - def __init__(self, in_patterns, tgt_patterns, + def __init__(self, param_pattern, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, callback_at=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None, crop=None, crop_start=None, crop_stop=None, crop_step=None, in_pad=0, tgt_pad=0, scale_factor=1): + self.param_files = sorted(param_pattern) + in_file_lists = [sorted(glob(p)) for p in in_patterns] self.in_files = list(zip(* in_file_lists)) tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] self.tgt_files = list(zip(* tgt_file_lists)) - if len(self.in_files) != len(self.tgt_files): - raise ValueError('number of input and target fields do not match') + if len(self.param_files) != len(self.in_files) != len(self.tgt_files): + raise ValueError('number of param files, input and target fields do not match') self.nfile = len(self.in_files) if self.nfile == 0: @@ -141,6 +143,7 @@ class FieldDataset(Dataset): def __getitem__(self, idx): ifile, icrop = divmod(idx, self.ncrop) + params = np.loadtxt(self.param_files[idx]) in_fields = [np.load(f) for f in self.in_files[ifile]] tgt_fields = [np.load(f) for f in self.tgt_files[ifile]] @@ -186,7 +189,7 @@ class FieldDataset(Dataset): in_fields = torch.cat(in_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0) - return in_fields, tgt_fields + return params, in_fields, tgt_fields def assemble(self, **fields): """Assemble cropped fields.