Rename param to style

This commit is contained in:
Yin Li 2021-03-18 10:43:14 -07:00
parent 17d8e95870
commit a61990ee45
3 changed files with 25 additions and 25 deletions

View file

@ -96,14 +96,14 @@ def add_common_args(parser):
def add_train_args(parser):
add_common_args(parser)
parser.add_argument('--train-param-pattern', type=str, required=True,
help='glob pattern for training data parameters')
parser.add_argument('--train-style-pattern', type=str, required=True,
help='glob pattern for training data styles')
parser.add_argument('--train-in-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for training input data')
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for training target data')
parser.add_argument('--val-param-pattern', type=str,
help='glob pattern for validation data parameters')
parser.add_argument('--val-style-pattern', type=str,
help='glob pattern for validation data styles')
parser.add_argument('--val-in-patterns', type=str_list,
help='comma-sep. list of glob patterns for validation input data')
parser.add_argument('--val-tgt-patterns', type=str_list,
@ -164,8 +164,8 @@ def add_train_args(parser):
def add_test_args(parser):
add_common_args(parser)
parser.add_argument('--test-param-pattern', type=str, required=True,
help='glob pattern for test data parameters')
parser.add_argument('--test-style-pattern', type=str, required=True,
help='glob pattern for test data styles')
parser.add_argument('--test-in-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for test input data')
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,

View file

@ -41,12 +41,12 @@ class FieldDataset(Dataset):
the input for super-resolution, in which case `crop` and `pad` are sizes of
the input resolution.
"""
def __init__(self, param_pattern, in_patterns, tgt_patterns,
def __init__(self, style_pattern, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
in_pad=0, tgt_pad=0, scale_factor=1):
self.param_files = sorted(glob(param_pattern))
self.style_files = sorted(glob(style_pattern))
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -54,14 +54,14 @@ class FieldDataset(Dataset):
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
self.tgt_files = list(zip(* tgt_file_lists))
if len(self.param_files) != len(self.in_files) != len(self.tgt_files):
raise ValueError('number of param files, input and target fields do not match')
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
raise ValueError('number of style, input, and target files do not match')
self.nfile = len(self.in_files)
if self.nfile == 0:
raise FileNotFoundError('file not found for {}'.format(in_patterns))
self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
self.style_size = np.loadtxt(self.style_files[0]).shape[0]
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.in_files[0]]
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
@ -144,7 +144,7 @@ class FieldDataset(Dataset):
def __getitem__(self, idx):
ifile, icrop = divmod(idx, self.ncrop)
params = np.loadtxt(self.param_files[ifile])
style = np.loadtxt(self.style_files[ifile])
in_fields = [np.load(f) for f in self.in_files[ifile]]
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
@ -160,7 +160,7 @@ class FieldDataset(Dataset):
self.tgt_pad,
self.size * self.scale_factor)
params = torch.from_numpy(params).to(torch.float32)
style = torch.from_numpy(style).to(torch.float32)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@ -191,7 +191,7 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
return params, in_fields, tgt_fields
return style, in_fields, tgt_fields
def assemble(self, **fields):
"""Assemble cropped fields.

View file

@ -59,7 +59,7 @@ def gpu_worker(local_rank, node, args):
dist_init(rank, args)
train_dataset = FieldDataset(
param_pattern=args.train_param_pattern,
style_pattern=args.train_style_pattern,
in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns,
in_norms=args.in_norms,
@ -91,7 +91,7 @@ def gpu_worker(local_rank, node, args):
if args.val:
val_dataset = FieldDataset(
param_pattern=args.val_param_pattern,
style_pattern=args.val_style_pattern,
in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns,
in_norms=args.in_norms,
@ -121,12 +121,12 @@ def gpu_worker(local_rank, node, args):
pin_memory=True,
)
args.param_dim = train_dataset.param_dim
args.style_size = train_dataset.style_size
args.in_chan = train_dataset.in_chan
args.out_chan = train_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor)
model.to(device)
model = DistributedDataParallel(model, device_ids=[device],
@ -242,16 +242,16 @@ def train(epoch, loader, model, criterion,
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
for i, (param, input, target) in enumerate(loader):
for i, (style, input, target) in enumerate(loader):
batch = epoch * len(loader) + i + 1
param = param.to(device, non_blocking=True)
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(input, param)
output = model(input, style)
if batch == 1 and rank == 0:
print('param shape :', param.shape)
print('style shape :', style.shape)
print('input shape :', input.shape)
print('output shape :', output.shape)
print('target shape :', target.shape)
@ -336,12 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
with torch.no_grad():
for param, input, target in loader:
param = param.to(device, non_blocking=True)
for style, input, target in loader:
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(input, param)
output = model(input, style)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):