From a61990ee459978144e5216f0ffa31223735dba0c Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 18 Mar 2021 10:43:14 -0700 Subject: [PATCH] Rename param to style --- map2map/args.py | 12 ++++++------ map2map/data/fields.py | 16 ++++++++-------- map2map/train.py | 22 +++++++++++----------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/map2map/args.py b/map2map/args.py index bb22e46..c80c5e0 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -96,14 +96,14 @@ def add_common_args(parser): def add_train_args(parser): add_common_args(parser) - parser.add_argument('--train-param-pattern', type=str, required=True, - help='glob pattern for training data parameters') + parser.add_argument('--train-style-pattern', type=str, required=True, + help='glob pattern for training data styles') parser.add_argument('--train-in-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for training input data') parser.add_argument('--train-tgt-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for training target data') - parser.add_argument('--val-param-pattern', type=str, - help='glob pattern for validation data parameters') + parser.add_argument('--val-style-pattern', type=str, + help='glob pattern for validation data styles') parser.add_argument('--val-in-patterns', type=str_list, help='comma-sep. list of glob patterns for validation input data') parser.add_argument('--val-tgt-patterns', type=str_list, @@ -164,8 +164,8 @@ def add_train_args(parser): def add_test_args(parser): add_common_args(parser) - parser.add_argument('--test-param-pattern', type=str, required=True, - help='glob pattern for test data parameters') + parser.add_argument('--test-style-pattern', type=str, required=True, + help='glob pattern for test data styles') parser.add_argument('--test-in-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for test input data') parser.add_argument('--test-tgt-patterns', type=str_list, required=True, diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 42973c0..b245a38 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -41,12 +41,12 @@ class FieldDataset(Dataset): the input for super-resolution, in which case `crop` and `pad` are sizes of the input resolution. """ - def __init__(self, param_pattern, in_patterns, tgt_patterns, + def __init__(self, style_pattern, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, callback_at=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None, crop=None, crop_start=None, crop_stop=None, crop_step=None, in_pad=0, tgt_pad=0, scale_factor=1): - self.param_files = sorted(glob(param_pattern)) + self.style_files = sorted(glob(style_pattern)) in_file_lists = [sorted(glob(p)) for p in in_patterns] self.in_files = list(zip(* in_file_lists)) @@ -54,14 +54,14 @@ class FieldDataset(Dataset): tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] self.tgt_files = list(zip(* tgt_file_lists)) - if len(self.param_files) != len(self.in_files) != len(self.tgt_files): - raise ValueError('number of param files, input and target fields do not match') + if len(self.style_files) != len(self.in_files) != len(self.tgt_files): + raise ValueError('number of style, input, and target files do not match') self.nfile = len(self.in_files) if self.nfile == 0: raise FileNotFoundError('file not found for {}'.format(in_patterns)) - self.param_dim = np.loadtxt(self.param_files[0]).shape[0] + self.style_size = np.loadtxt(self.style_files[0]).shape[0] self.in_chan = [np.load(f, mmap_mode='r').shape[0] for f in self.in_files[0]] self.tgt_chan = [np.load(f, mmap_mode='r').shape[0] @@ -144,7 +144,7 @@ class FieldDataset(Dataset): def __getitem__(self, idx): ifile, icrop = divmod(idx, self.ncrop) - params = np.loadtxt(self.param_files[ifile]) + style = np.loadtxt(self.style_files[ifile]) in_fields = [np.load(f) for f in self.in_files[ifile]] tgt_fields = [np.load(f) for f in self.tgt_files[ifile]] @@ -160,7 +160,7 @@ class FieldDataset(Dataset): self.tgt_pad, self.size * self.scale_factor) - params = torch.from_numpy(params).to(torch.float32) + style = torch.from_numpy(style).to(torch.float32) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] @@ -191,7 +191,7 @@ class FieldDataset(Dataset): in_fields = torch.cat(in_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0) - return params, in_fields, tgt_fields + return style, in_fields, tgt_fields def assemble(self, **fields): """Assemble cropped fields. diff --git a/map2map/train.py b/map2map/train.py index 04e0048..0cee826 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -59,7 +59,7 @@ def gpu_worker(local_rank, node, args): dist_init(rank, args) train_dataset = FieldDataset( - param_pattern=args.train_param_pattern, + style_pattern=args.train_style_pattern, in_patterns=args.train_in_patterns, tgt_patterns=args.train_tgt_patterns, in_norms=args.in_norms, @@ -91,7 +91,7 @@ def gpu_worker(local_rank, node, args): if args.val: val_dataset = FieldDataset( - param_pattern=args.val_param_pattern, + style_pattern=args.val_style_pattern, in_patterns=args.val_in_patterns, tgt_patterns=args.val_tgt_patterns, in_norms=args.in_norms, @@ -121,12 +121,12 @@ def gpu_worker(local_rank, node, args): pin_memory=True, ) - args.param_dim = train_dataset.param_dim + args.style_size = train_dataset.style_size args.in_chan = train_dataset.in_chan args.out_chan = train_dataset.tgt_chan model = import_attr(args.model, models, callback_at=args.callback_at) - model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan), + model = model(args.style_size, sum(args.in_chan), sum(args.out_chan), scale_factor=args.scale_factor) model.to(device) model = DistributedDataParallel(model, device_ids=[device], @@ -242,16 +242,16 @@ def train(epoch, loader, model, criterion, epoch_loss = torch.zeros(3, dtype=torch.float64, device=device) - for i, (param, input, target) in enumerate(loader): + for i, (style, input, target) in enumerate(loader): batch = epoch * len(loader) + i + 1 - param = param.to(device, non_blocking=True) + style = style.to(device, non_blocking=True) input = input.to(device, non_blocking=True) target = target.to(device, non_blocking=True) - output = model(input, param) + output = model(input, style) if batch == 1 and rank == 0: - print('param shape :', param.shape) + print('style shape :', style.shape) print('input shape :', input.shape) print('output shape :', output.shape) print('target shape :', target.shape) @@ -336,12 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args): epoch_loss = torch.zeros(3, dtype=torch.float64, device=device) with torch.no_grad(): - for param, input, target in loader: - param = param.to(device, non_blocking=True) + for style, input, target in loader: + style = style.to(device, non_blocking=True) input = input.to(device, non_blocking=True) target = target.to(device, non_blocking=True) - output = model(input, param) + output = model(input, style) if (hasattr(model.module, 'scale_factor') and model.module.scale_factor != 1):