Add data division, good with data caching

This commit is contained in:
Yin Li 2019-12-18 17:06:16 -05:00
parent de24f8d585
commit 01b0c8b514
6 changed files with 64 additions and 52 deletions

View File

@ -18,22 +18,25 @@ def get_args():
def add_common_args(parser):
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
'of normalization functions from data.norms')
parser.add_argument('--crop', type=int,
help='size to crop the input and target data')
parser.add_argument('--pad', default=0, type=int,
help='pad the input data assuming periodic boundary condition')
parser.add_argument('--model', required=True, help='model from models')
parser.add_argument('--criterion', default='MSELoss',
help='model criterion from torch.nn')
parser.add_argument('--load-state', default='', type=str,
help='path to load model, optimizer, rng, etc.')
parser.add_argument('--batches', default=1, type=int,
help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', default=0, type=int,
help='number of data loading workers, per GPU in training or '
'in total in testing')
parser.add_argument('--cache', action='store_true',
help='enable caching in field datasets')
parser.add_argument('--crop', type=int,
help='size to crop the input and target data')
parser.add_argument('--pad', default=0, type=int,
help='pad the input data assuming periodic boundary condition')
def add_train_args(parser):
@ -47,10 +50,11 @@ def add_train_args(parser):
help='comma-sep. list of glob patterns for validation input data')
parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--epochs', default=1024, type=int,
help='total number of epochs to run')
parser.add_argument('--augment', action='store_true',
help='enable training data augmentation')
parser.add_argument('--epochs', default=128, type=int,
help='total number of epochs to run')
parser.add_argument('--optimizer', default='Adam',
help='optimizer from torch.optim')
parser.add_argument('--lr', default=0.001, type=float,
@ -59,10 +63,13 @@ def add_train_args(parser):
# help='momentum')
parser.add_argument('--weight-decay', default=0., type=float,
help='weight decay')
parser.add_argument('--dist-backend', default='nccl', type=str,
choices=['gloo', 'nccl'], help='distributed backend')
parser.add_argument('--seed', type=int,
help='seed for initializing training')
parser.add_argument('--div-data', action='store_true',
help='enable data division among GPUs, useful with cache')
parser.add_argument('--dist-backend', default='nccl', type=str,
choices=['gloo', 'nccl'], help='distributed backend')
parser.add_argument('--log-interval', default=20, type=int,
help='interval between logging training loss')

View File

@ -14,15 +14,20 @@ class FieldDataset(Dataset):
Likewise `tgt_patterns` is for target fields.
Input and target samples of all fields are matched by sorting the globbed files.
Input and target fields can be cached, and they can be cropped.
Input fields can be padded assuming periodic boundary condition.
`norms` can be a list of callables to normalize each field.
Data augmentations are supported for scalar and vector fields.
`norms` can be a list of callables to normalize each field.
Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition.
`cache` enables data caching.
`div_data` enables data division, useful when combined with caching.
"""
def __init__(self, in_patterns, tgt_patterns, cache=False, crop=None, pad=0,
augment=False, norms=None):
def __init__(self, in_patterns, tgt_patterns,
norms=None, augment=False, crop=None, pad=0,
cache=False, div_data=False, rank=None, world_size=None,
**kwargs):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -32,6 +37,11 @@ class FieldDataset(Dataset):
assert len(self.in_files) == len(self.tgt_files), \
'input and target sample sizes do not match'
if div_data:
files = len(self.in_files) // world_size
self.in_files = self.in_files[rank * files : (rank + 1) * files]
self.tgt_files = self.tgt_files[rank * files : (rank + 1) * files]
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
@ -39,13 +49,15 @@ class FieldDataset(Dataset):
self.size = np.asarray(self.size)
self.ndim = len(self.size)
self.cache = cache
if self.cache:
self.in_fields = []
self.tgt_fields = []
for idx in range(len(self.in_files)):
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
if norms is not None: # FIXME: in_norms, tgt_norms
assert len(in_patterns) == len(norms), \
'numbers of normalization callables and input fields do not match'
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
self.norms = norms
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
if crop is None:
self.crop = self.size
@ -58,15 +70,13 @@ class FieldDataset(Dataset):
assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2))
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
if norms is not None: # FIXME: in_norms, tgt_norms
assert len(in_patterns) == len(norms), \
'numbers of normalization callables and input fields do not match'
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
self.norms = norms
self.cache = cache
if self.cache:
self.in_fields = []
self.tgt_fields = []
for idx in range(len(self.in_files)):
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
def __len__(self):
return len(self.in_files) * self.tot_reps

View File

@ -13,11 +13,8 @@ def test(args):
test_dataset = FieldDataset(
in_patterns=args.test_in_patterns,
tgt_patterns=args.test_tgt_patterns,
cache=args.cache,
crop=args.crop,
pad=args.pad,
augment=False,
norms=args.norms,
**vars(args),
)
test_loader = DataLoader(
test_dataset,

View File

@ -48,19 +48,16 @@ def gpu_worker(local_rank, args):
train_dataset = FieldDataset(
in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns,
cache=args.cache,
crop=args.crop,
pad=args.pad,
augment=args.augment,
norms=args.norms,
**vars(args),
)
if not args.div_data:
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(
train_dataset,
batch_size=args.batches,
shuffle=False,
sampler=train_sampler,
shuffle=args.div_data,
sampler=None if args.div_data else train_sampler,
num_workers=args.loader_workers,
pin_memory=True
)
@ -68,19 +65,17 @@ def gpu_worker(local_rank, args):
val_dataset = FieldDataset(
in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns,
cache=args.cache,
crop=args.crop,
pad=args.pad,
augment=False,
norms=args.norms,
**{k:v for k, v in vars(args).items() if k != 'augment'},
)
if not args.div_data:
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
val_sampler = DistributedSampler(val_dataset)
val_loader = DataLoader(
val_dataset,
batch_size=args.batches,
shuffle=False,
sampler=val_sampler,
sampler=None if args.div_data else val_sampler,
num_workers=args.loader_workers,
pin_memory=True
)
@ -129,6 +124,7 @@ def gpu_worker(local_rank, args):
#args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
for epoch in range(args.start_epoch, args.epochs):
if not args.div_data:
train_sampler.set_epoch(epoch)
train(epoch, train_loader, model, criterion, optimizer, scheduler, args)

View File

@ -42,8 +42,9 @@ srun m2m.py train \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--norms cosmology.dis --augment --crop 100 --pad 42 \
--model VNet \
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
--cache --div-data
# --load-state checkpoint.pth \
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0
date

View File

@ -42,8 +42,9 @@ srun m2m.py train \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--norms cosmology.vel --augment --crop 100 --pad 42 \
--model VNet \
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
--cache --div-data
# --load-state checkpoint.pth \
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0
date