Rename param to style
parent 17d8e95870
commit a61990ee45
@@ -96,14 +96,14 @@ def add_common_args(parser):
 def add_train_args(parser):
     add_common_args(parser)
 
-    parser.add_argument('--train-param-pattern', type=str, required=True,
-                        help='glob pattern for training data parameters')
+    parser.add_argument('--train-style-pattern', type=str, required=True,
+                        help='glob pattern for training data styles')
     parser.add_argument('--train-in-patterns', type=str_list, required=True,
                         help='comma-sep. list of glob patterns for training input data')
     parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
                         help='comma-sep. list of glob patterns for training target data')
-    parser.add_argument('--val-param-pattern', type=str,
-                        help='glob pattern for validation data parameters')
+    parser.add_argument('--val-style-pattern', type=str,
+                        help='glob pattern for validation data styles')
     parser.add_argument('--val-in-patterns', type=str_list,
                         help='comma-sep. list of glob patterns for validation input data')
     parser.add_argument('--val-tgt-patterns', type=str_list,
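The `--*-patterns` flags above use a `str_list` converter that is defined elsewhere in this module and not shown in the diff. A plausible sketch of such a converter, labeled as an assumption:

    def str_list(s):
        # Presumed behavior: split a comma-separated CLI value into glob patterns.
        return s.split(',')

    print(str_list('a/*.npy,b/*.npy'))   # ['a/*.npy', 'b/*.npy']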
@@ -164,8 +164,8 @@ def add_train_args(parser):
 def add_test_args(parser):
     add_common_args(parser)
 
-    parser.add_argument('--test-param-pattern', type=str, required=True,
-                        help='glob pattern for test data parameters')
+    parser.add_argument('--test-style-pattern', type=str, required=True,
+                        help='glob pattern for test data styles')
     parser.add_argument('--test-in-patterns', type=str_list, required=True,
                         help='comma-sep. list of glob patterns for test input data')
     parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
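The renames above change the CLI surface: callers must now pass `--train-style-pattern`, `--val-style-pattern`, or `--test-style-pattern` instead of the old `--*-param-pattern` flags. A standalone sketch of the new flag (placeholder path; the real pattern depends on your data layout):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--train-style-pattern', type=str, required=True,
                        help='glob pattern for training data styles')

    args = parser.parse_args(['--train-style-pattern', 'train/*/params.txt'])
    print(args.train_style_pattern)   # argparse exposes the flag as train_style_pattern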
@@ -41,12 +41,12 @@ class FieldDataset(Dataset):
     the input for super-resolution, in which case `crop` and `pad` are sizes of
     the input resolution.
     """
-    def __init__(self, param_pattern, in_patterns, tgt_patterns,
+    def __init__(self, style_pattern, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_shift=None, aug_add=None, aug_mul=None,
                  crop=None, crop_start=None, crop_stop=None, crop_step=None,
                  in_pad=0, tgt_pad=0, scale_factor=1):
-        self.param_files = sorted(glob(param_pattern))
+        self.style_files = sorted(glob(style_pattern))
 
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -54,14 +54,14 @@ class FieldDataset(Dataset):
         tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
         self.tgt_files = list(zip(* tgt_file_lists))
 
-        if len(self.param_files) != len(self.in_files) != len(self.tgt_files):
-            raise ValueError('number of param files, input and target fields do not match')
+        if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
+            raise ValueError('number of style, input, and target files do not match')
         self.nfile = len(self.in_files)
 
         if self.nfile == 0:
             raise FileNotFoundError('file not found for {}'.format(in_patterns))
 
-        self.param_dim = np.loadtxt(self.param_files[0]).shape[0]
+        self.style_size = np.loadtxt(self.style_files[0]).shape[0]
         self.in_chan = [np.load(f, mmap_mode='r').shape[0]
                         for f in self.in_files[0]]
         self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
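The constructor globs one style file per sample and one file per input/target field, so all three globs must line up one-to-one. A sketch of the expected layout, with hypothetical paths:

    from glob import glob

    # Hypothetical layout: one style (parameter) vector plus one input and one
    # target field per simulation directory.
    style_files = sorted(glob('sims/*/params.txt'))
    in_files = sorted(glob('sims/*/lin.npy'))
    tgt_files = sorted(glob('sims/*/nonlin.npy'))
    assert len(style_files) == len(in_files) == len(tgt_files)

As an aside, the chained `!=` in this hunk only raises when both adjacent pairs of counts differ; negating a chained `==`, as in the assert above, would catch every mismatch.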
@@ -144,7 +144,7 @@ class FieldDataset(Dataset):
     def __getitem__(self, idx):
         ifile, icrop = divmod(idx, self.ncrop)
 
-        params = np.loadtxt(self.param_files[ifile])
+        style = np.loadtxt(self.style_files[ifile])
         in_fields = [np.load(f) for f in self.in_files[ifile]]
         tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
 
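Each file contributes `self.ncrop` crops, and `divmod` maps the flat index the `DataLoader` sees to a (file, crop) pair. A worked example with an illustrative crop count:

    ncrop = 8                          # crops per file (illustrative value)
    ifile, icrop = divmod(19, ncrop)   # flat dataset index 19
    assert (ifile, icrop) == (2, 3)    # 19 == 2 * 8 + 3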
@@ -160,7 +160,7 @@ class FieldDataset(Dataset):
                           self.tgt_pad,
                           self.size * self.scale_factor)
 
-        params = torch.from_numpy(params).to(torch.float32)
+        style = torch.from_numpy(style).to(torch.float32)
         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
 
@@ -191,7 +191,7 @@ class FieldDataset(Dataset):
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)
 
-        return params, in_fields, tgt_fields
+        return style, in_fields, tgt_fields
 
     def assemble(self, **fields):
         """Assemble cropped fields.
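Since `__getitem__` now returns `style` first, the default collate function stacks the per-sample style vectors into a leading batch dimension, just like the fields. A toy stand-in for `FieldDataset` showing the batched shapes:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # 4 samples, style vectors of length 3, single-channel 8^3 fields.
    styles = torch.randn(4, 3)
    fields = torch.randn(4, 1, 8, 8, 8)
    loader = DataLoader(TensorDataset(styles, fields, fields), batch_size=2)

    for style, input, target in loader:
        print(style.shape, input.shape)   # torch.Size([2, 3]) torch.Size([2, 1, 8, 8, 8])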
@@ -59,7 +59,7 @@ def gpu_worker(local_rank, node, args):
     dist_init(rank, args)
 
     train_dataset = FieldDataset(
-        param_pattern=args.train_param_pattern,
+        style_pattern=args.train_style_pattern,
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
         in_norms=args.in_norms,
@@ -91,7 +91,7 @@ def gpu_worker(local_rank, node, args):
 
     if args.val:
         val_dataset = FieldDataset(
-            param_pattern=args.val_param_pattern,
+            style_pattern=args.val_style_pattern,
            in_patterns=args.val_in_patterns,
             tgt_patterns=args.val_tgt_patterns,
             in_norms=args.in_norms,
@@ -121,12 +121,12 @@ def gpu_worker(local_rank, node, args):
             pin_memory=True,
         )
 
-    args.param_dim = train_dataset.param_dim
+    args.style_size = train_dataset.style_size
     args.in_chan = train_dataset.in_chan
     args.out_chan = train_dataset.tgt_chan
 
     model = import_attr(args.model, models, callback_at=args.callback_at)
-    model = model(args.param_dim, sum(args.in_chan), sum(args.out_chan),
+    model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
                   scale_factor=args.scale_factor)
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
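`import_attr` loads the model class by name, so any model used here must accept `(style_size, in_chan, out_chan, scale_factor=...)` in its constructor and `(input, style)` in `forward`, as the call sites in this diff show. A minimal stub honoring that contract, for illustration only (not one of the repo's actual models):

    import torch
    from torch import nn

    class ToyStyleModel(nn.Module):
        """Illustrative stub matching the constructor and forward signatures."""
        def __init__(self, style_size, in_chan, out_chan, scale_factor=1):
            super().__init__()
            self.scale_factor = scale_factor          # attribute checked by validate()
            self.embed = nn.Linear(style_size, in_chan)
            self.conv = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)

        def forward(self, x, s):
            # Broadcast the style embedding over the spatial dimensions.
            b = self.embed(s)[:, :, None, None, None]
            return self.conv(x + b)

With the renamed attributes, construction reads `ToyStyleModel(args.style_size, sum(args.in_chan), sum(args.out_chan), scale_factor=args.scale_factor)`.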
@@ -242,16 +242,16 @@ def train(epoch, loader, model, criterion,
 
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
-    for i, (param, input, target) in enumerate(loader):
+    for i, (style, input, target) in enumerate(loader):
         batch = epoch * len(loader) + i + 1
 
-        param = param.to(device, non_blocking=True)
+        style = style.to(device, non_blocking=True)
         input = input.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
 
-        output = model(input, param)
+        output = model(input, style)
         if batch == 1 and rank == 0:
-            print('param shape :', param.shape)
+            print('style shape :', style.shape)
             print('input shape :', input.shape)
             print('output shape :', output.shape)
             print('target shape :', target.shape)
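One detail in this hunk: `non_blocking=True` only overlaps host-to-device copies with compute when the source tensor sits in pinned memory, which is why the loaders above are built with `pin_memory=True`. A standalone sketch:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(2, 3).pin_memory()   # page-locked host memory
        x = x.to('cuda', non_blocking=True)  # copy can overlap with compute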
@@ -336,12 +336,12 @@ def validate(epoch, loader, model, criterion, logger, device, args):
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
 
     with torch.no_grad():
-        for param, input, target in loader:
-            param = param.to(device, non_blocking=True)
+        for style, input, target in loader:
+            style = style.to(device, non_blocking=True)
             input = input.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
 
-            output = model(input, param)
+            output = model(input, style)
 
             if (hasattr(model.module, 'scale_factor')
                     and model.module.scale_factor != 1):