Add data division, good with data caching

Yin Li 2019-12-18 17:06:16 -05:00
parent de24f8d585
commit 01b0c8b514
6 changed files with 64 additions and 52 deletions

View File

@@ -18,22 +18,25 @@ def get_args():
 def add_common_args(parser):
     parser.add_argument('--norms', type=str_list, help='comma-sep. list '
             'of normalization functions from data.norms')
-    parser.add_argument('--crop', type=int,
-            help='size to crop the input and target data')
-    parser.add_argument('--pad', default=0, type=int,
-            help='pad the input data assuming periodic boundary condition')
     parser.add_argument('--model', required=True, help='model from models')
     parser.add_argument('--criterion', default='MSELoss',
             help='model criterion from torch.nn')
     parser.add_argument('--load-state', default='', type=str,
             help='path to load model, optimizer, rng, etc.')
     parser.add_argument('--batches', default=1, type=int,
             help='mini-batch size, per GPU in training or in total in testing')
     parser.add_argument('--loader-workers', default=0, type=int,
             help='number of data loading workers, per GPU in training or '
             'in total in testing')
     parser.add_argument('--cache', action='store_true',
             help='enable caching in field datasets')
+    parser.add_argument('--crop', type=int,
+            help='size to crop the input and target data')
+    parser.add_argument('--pad', default=0, type=int,
+            help='pad the input data assuming periodic boundary condition')

 def add_train_args(parser):
@@ -47,10 +50,11 @@ def add_train_args(parser):
             help='comma-sep. list of glob patterns for validation input data')
     parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
             help='comma-sep. list of glob patterns for validation target data')
-    parser.add_argument('--epochs', default=1024, type=int,
-            help='total number of epochs to run')
     parser.add_argument('--augment', action='store_true',
             help='enable training data augmentation')
+    parser.add_argument('--epochs', default=128, type=int,
+            help='total number of epochs to run')
     parser.add_argument('--optimizer', default='Adam',
             help='optimizer from torch.optim')
     parser.add_argument('--lr', default=0.001, type=float,
@@ -59,10 +63,13 @@ def add_train_args(parser):
     # help='momentum')
     parser.add_argument('--weight-decay', default=0., type=float,
             help='weight decay')
-    parser.add_argument('--dist-backend', default='nccl', type=str,
-            choices=['gloo', 'nccl'], help='distributed backend')
     parser.add_argument('--seed', type=int,
             help='seed for initializing training')
+    parser.add_argument('--div-data', action='store_true',
+            help='enable data division among GPUs, useful with cache')
+    parser.add_argument('--dist-backend', default='nccl', type=str,
+            choices=['gloo', 'nccl'], help='distributed backend')
     parser.add_argument('--log-interval', default=20, type=int,
             help='interval between logging training loss')
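Note: the new --div-data flag is meant to be combined with --cache. With caching alone every GPU process loads and keeps the full dataset in host memory; with data division each rank only caches its own 1/world_size share. A minimal sketch of that accounting (the helper name and numbers are illustrative, not part of the repository):

def cached_samples_per_rank(num_samples, world_size, div_data):
    """Samples each GPU process loads (and caches) in host memory."""
    if div_data:
        return num_samples // world_size  # each rank keeps only its shard
    return num_samples                    # every rank caches the whole dataset

assert cached_samples_per_rank(1024, 8, div_data=True) == 128
assert cached_samples_per_rank(1024, 8, div_data=False) == 1024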

View File

@@ -14,15 +14,20 @@ class FieldDataset(Dataset):
     Likewise `tgt_patterns` is for target fields.
     Input and target samples of all fields are matched by sorting the globbed files.
-    Input and target fields can be cached, and they can be cropped.
-    Input fields can be padded assuming periodic boundary condition.
+    `norms` can be a list of callables to normalize each field.
     Data augmentations are supported for scalar and vector fields.
-    `norms` can be a list of callables to normalize each field.
+    Input and target fields can be cropped.
+    Input fields can be padded assuming periodic boundary condition.
+    `cache` enables data caching.
+    `div_data` enables data division, useful when combined with caching.
     """

-    def __init__(self, in_patterns, tgt_patterns, cache=False, crop=None, pad=0,
-            augment=False, norms=None):
+    def __init__(self, in_patterns, tgt_patterns,
+            norms=None, augment=False, crop=None, pad=0,
+            cache=False, div_data=False, rank=None, world_size=None,
+            **kwargs):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -32,6 +37,11 @@ class FieldDataset(Dataset):
         assert len(self.in_files) == len(self.tgt_files), \
                 'input and target sample sizes do not match'

+        if div_data:
+            files = len(self.in_files) // world_size
+            self.in_files = self.in_files[rank * files : (rank + 1) * files]
+            self.tgt_files = self.tgt_files[rank * files : (rank + 1) * files]
+
         self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
         self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
@@ -39,13 +49,15 @@ class FieldDataset(Dataset):
         self.size = np.asarray(self.size)
         self.ndim = len(self.size)

-        self.cache = cache
-        if self.cache:
-            self.in_fields = []
-            self.tgt_fields = []
-            for idx in range(len(self.in_files)):
-                self.in_fields.append([np.load(f) for f in self.in_files[idx]])
-                self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
+        if norms is not None:  # FIXME: in_norms, tgt_norms
+            assert len(in_patterns) == len(norms), \
+                    'numbers of normalization callables and input fields do not match'
+            norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
+        self.norms = norms
+
+        self.augment = augment
+        if self.ndim == 1 and self.augment:
+            raise ValueError('cannot augment 1D fields')

         if crop is None:
             self.crop = self.size
@@ -58,15 +70,13 @@ class FieldDataset(Dataset):
         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))

-        self.augment = augment
-        if self.ndim == 1 and self.augment:
-            raise ValueError('cannot augment 1D fields')
-
-        if norms is not None:  # FIXME: in_norms, tgt_norms
-            assert len(in_patterns) == len(norms), \
-                    'numbers of normalization callables and input fields do not match'
-            norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
-        self.norms = norms
+        self.cache = cache
+        if self.cache:
+            self.in_fields = []
+            self.tgt_fields = []
+            for idx in range(len(self.in_files)):
+                self.in_fields.append([np.load(f) for f in self.in_files[idx]])
+                self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])

     def __len__(self):
         return len(self.in_files) * self.tot_reps
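The division itself is a plain contiguous slice of the sorted file lists, so ranks get disjoint, equally sized shards; the trailing len(files) % world_size samples are simply dropped. A standalone sketch of the same slicing (function name and file names are illustrative only):

def divide_files(files, rank, world_size):
    """Return this rank's contiguous shard of a sorted file list.

    Mirrors the div_data slicing in FieldDataset.__init__; the last
    len(files) % world_size entries are not assigned to any rank.
    """
    per_rank = len(files) // world_size
    return files[rank * per_rank : (rank + 1) * per_rank]

files = ['field_{:03d}.npy'.format(i) for i in range(10)]
print(divide_files(files, rank=1, world_size=4))  # ['field_002.npy', 'field_003.npy']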

View File

@@ -13,11 +13,8 @@ def test(args):
     test_dataset = FieldDataset(
         in_patterns=args.test_in_patterns,
         tgt_patterns=args.test_tgt_patterns,
-        cache=args.cache,
-        crop=args.crop,
-        pad=args.pad,
         augment=False,
-        norms=args.norms,
+        **vars(args),
     )
     test_loader = DataLoader(
         test_dataset,
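Forwarding **vars(args) works because FieldDataset.__init__ now ends in **kwargs, which silently absorbs every argparse entry the dataset does not understand (test_in_patterns, model, batches, ...). A minimal sketch of the pattern, with hypothetical names:

from argparse import Namespace

class ToyDataset:
    def __init__(self, in_patterns, tgt_patterns, crop=None, pad=0,
                 cache=False, **kwargs):
        # entries the dataset does not recognize (model, batches, ...)
        # land in kwargs and are ignored
        self.crop, self.pad, self.cache = crop, pad, cache

args = Namespace(test_in_patterns='a*.npy', crop=100, pad=42, cache=True,
                 model='VNet', batches=1)
ds = ToyDataset(in_patterns=args.test_in_patterns, tgt_patterns='b*.npy',
                augment=False, **vars(args))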

View File

@@ -48,19 +48,16 @@ def gpu_worker(local_rank, args):
     train_dataset = FieldDataset(
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
-        cache=args.cache,
-        crop=args.crop,
-        pad=args.pad,
-        augment=args.augment,
-        norms=args.norms,
+        **vars(args),
     )
-    #train_sampler = DistributedSampler(train_dataset, shuffle=True)
-    train_sampler = DistributedSampler(train_dataset)
+    if not args.div_data:
+        #train_sampler = DistributedSampler(train_dataset, shuffle=True)
+        train_sampler = DistributedSampler(train_dataset)
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
-        shuffle=False,
-        sampler=train_sampler,
+        shuffle=args.div_data,
+        sampler=None if args.div_data else train_sampler,
         num_workers=args.loader_workers,
         pin_memory=True
     )
@@ -68,19 +65,17 @@ def gpu_worker(local_rank, args):
     val_dataset = FieldDataset(
         in_patterns=args.val_in_patterns,
         tgt_patterns=args.val_tgt_patterns,
-        cache=args.cache,
-        crop=args.crop,
-        pad=args.pad,
         augment=False,
-        norms=args.norms,
+        **{k:v for k, v in vars(args).items() if k != 'augment'},
     )
-    #val_sampler = DistributedSampler(val_dataset, shuffle=False)
-    val_sampler = DistributedSampler(val_dataset)
+    if not args.div_data:
+        #val_sampler = DistributedSampler(val_dataset, shuffle=False)
+        val_sampler = DistributedSampler(val_dataset)
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
         shuffle=False,
-        sampler=val_sampler,
+        sampler=None if args.div_data else val_sampler,
         num_workers=args.loader_workers,
         pin_memory=True
     )
@@ -129,7 +124,8 @@ def gpu_worker(local_rank, args):
     #args.logger.add_hparams(hparam_dict=hparam, metric_dict={})

     for epoch in range(args.start_epoch, args.epochs):
-        train_sampler.set_epoch(epoch)
+        if not args.div_data:
+            train_sampler.set_epoch(epoch)

         train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
         val_loss = validate(epoch, val_loader, model, criterion, args)
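With --div-data each rank already owns a disjoint shard, so the DistributedSampler is skipped, the training loader shuffles its local shard (shuffle=args.div_data), and there is no sampler whose set_epoch must be called each epoch. The validation dataset also strips 'augment' from vars(args) because it passes augment=False explicitly and a duplicate keyword would raise TypeError. A hedged sketch of the two loader configurations (TinyDataset and the explicit rank/world_size are placeholders, not the repository's code):

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class TinyDataset(Dataset):
    def __len__(self):
        return 8
    def __getitem__(self, i):
        return torch.tensor(i)

def make_train_loader(dataset, div_data, rank, world_size, batch_size=1):
    if div_data:
        # dataset already holds only this rank's shard: shuffle it locally
        return DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # full dataset on every rank: let DistributedSampler partition it;
    # the caller must then call loader.sampler.set_epoch(epoch) each epoch
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

loader = make_train_loader(TinyDataset(), div_data=False, rank=0, world_size=2)
print([batch.item() for batch in loader])  # this rank's half of the indices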

View File

@@ -42,8 +42,9 @@ srun m2m.py train \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --norms cosmology.dis --augment --crop 100 --pad 42 \
     --model VNet \
+    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
+    --cache --div-data
 # --load-state checkpoint.pth \
-    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0

 date

View File

@@ -42,8 +42,9 @@ srun m2m.py train \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --norms cosmology.vel --augment --crop 100 --pad 42 \
     --model VNet \
+    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
+    --cache --div-data
 # --load-state checkpoint.pth \
-    --epochs 128 --lr 0.001 --batches 1 --loader-workers 0

 date