Rename param to style
This commit is contained in:
parent
17d8e95870
commit
a61990ee45
@ -96,14 +96,14 @@ def add_common_args(parser):
|
||||
def add_train_args(parser):
|
||||
add_common_args(parser)
|
||||
|
||||
parser.add_argument('--train-param-pattern', type=str, required=True,
|
||||
help='glob pattern for training data parameters')
|
||||
parser.add_argument('--train-style-pattern', type=str, required=True,
|
||||
help='glob pattern for training data styles')
|
||||
parser.add_argument('--train-in-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training input data')
|
||||
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training target data')
|
||||
parser.add_argument('--val-param-pattern', type=str,
|
||||
help='glob pattern for validation data parameters')
|
||||
parser.add_argument('--val-style-pattern', type=str,
|
||||
help='glob pattern for validation data styles')
|
||||
parser.add_argument('--val-in-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation input data')
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
||||
@ -164,8 +164,8 @@ def add_train_args(parser):
|
||||
def add_test_args(parser):
|
||||
add_common_args(parser)
|
||||
|
||||
parser.add_argument('--test-param-pattern', type=str, required=True,
|
||||
help='glob pattern for test data parameters')
|
||||
parser.add_argument('--test-style-pattern', type=str, required=True,
|
||||
help='glob pattern for test data styles')
|
||||
parser.add_argument('--test-in-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for test input data')
|
||||
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
|
||||
|
@ -41,12 +41,12 @@ class FieldDataset(Dataset):
|
||||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||
the input resolution.
|
||||
"""
|
||||
def __init__(self, param_pattern, in_patterns, tgt_patterns,
|
||||
def __init__(self, style_pattern, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None, callback_at=None,
|
||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
in_pad=0, tgt_pad=0, scale_factor=1):
|
||||
self.param_files = sorted(glob(param_pattern))
|
||||
self.style_files = sorted(glob(style_pattern))
|
||||
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
@ -54,14 +54,14 @@ class FieldDataset(Dataset):
|
||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
if len(self.param_files) != len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of param files, input and target fields do not match')
|
||||
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of style, input, and target files do not match')
|
||||
self.nfile = len(self.in_files)
|
||||
|
||||
if self.nfile == 0:
|
||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
||||
|
||||
self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
|
||||
self.style_size = np.loadtxt(self.style_files[0]).shape[0]
|
||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
for f in self.in_files[0]]
|
||||
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
@ -144,7 +144,7 @@ class FieldDataset(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
ifile, icrop = divmod(idx, self.ncrop)
|
||||
|
||||
params = np.loadtxt(self.param_files[ifile])
|
||||
style = np.loadtxt(self.style_files[ifile])
|
||||
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
||||
|
||||
@ -160,7 +160,7 @@ class FieldDataset(Dataset):
|
||||
self.tgt_pad,
|
||||
self.size * self.scale_factor)
|
||||
|
||||
params = torch.from_numpy(params).to(torch.float32)
|
||||
style = torch.from_numpy(style).to(torch.float32)
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||
|
||||
@ -191,7 +191,7 @@ class FieldDataset(Dataset):
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
return params, in_fields, tgt_fields
|
||||
return style, in_fields, tgt_fields
|
||||
|
||||
def assemble(self, **fields):
|
||||
"""Assemble cropped fields.
|
||||
|
@ -59,7 +59,7 @@ def gpu_worker(local_rank, node, args):
|
||||
dist_init(rank, args)
|
||||
|
||||
train_dataset = FieldDataset(
|
||||
param_pattern=args.train_param_pattern,
|
||||
style_pattern=args.train_style_pattern,
|
||||
in_patterns=args.train_in_patterns,
|
||||
tgt_patterns=args.train_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
@ -91,7 +91,7 @@ def gpu_worker(local_rank, node, args):
|
||||
|
||||
if args.val:
|
||||
val_dataset = FieldDataset(
|
||||
param_pattern=args.val_param_pattern,
|
||||
style_pattern=args.val_style_pattern,
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
@ -121,12 +121,12 @@ def gpu_worker(local_rank, node, args):
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
args.param_dim = train_dataset.param_dim
|
||||
args.style_size = train_dataset.style_size
|
||||
args.in_chan = train_dataset.in_chan
|
||||
args.out_chan = train_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
|
||||
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
|
||||
scale_factor=args.scale_factor)
|
||||
model.to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
@ -242,16 +242,16 @@ def train(epoch, loader, model, criterion,
|
||||
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
for i, (param, input, target) in enumerate(loader):
|
||||
for i, (style, input, target) in enumerate(loader):
|
||||
batch = epoch * len(loader) + i + 1
|
||||
|
||||
param = param.to(device, non_blocking=True)
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input, param)
|
||||
output = model(input, style)
|
||||
if batch == 1 and rank == 0:
|
||||
print('param shape :', param.shape)
|
||||
print('style shape :', style.shape)
|
||||
print('input shape :', input.shape)
|
||||
print('output shape :', output.shape)
|
||||
print('target shape :', target.shape)
|
||||
@ -336,12 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
for param, input, target in loader:
|
||||
param = param.to(device, non_blocking=True)
|
||||
for style, input, target in loader:
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input, param)
|
||||
output = model(input, style)
|
||||
|
||||
if (hasattr(model.module, 'scale_factor')
|
||||
and model.module.scale_factor != 1):
|
||||
|
Loading…
Reference in New Issue
Block a user