Add data division, good with data caching
This commit is contained in:
parent
de24f8d585
commit
01b0c8b514
6 changed files with 64 additions and 52 deletions
|
@ -18,22 +18,25 @@ def get_args():
|
|||
def add_common_args(parser):
|
||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||
'of normalization functions from data.norms')
|
||||
parser.add_argument('--crop', type=int,
|
||||
help='size to crop the input and target data')
|
||||
parser.add_argument('--pad', default=0, type=int,
|
||||
help='pad the input data assuming periodic boundary condition')
|
||||
|
||||
parser.add_argument('--model', required=True, help='model from models')
|
||||
parser.add_argument('--criterion', default='MSELoss',
|
||||
help='model criterion from torch.nn')
|
||||
parser.add_argument('--load-state', default='', type=str,
|
||||
help='path to load model, optimizer, rng, etc.')
|
||||
|
||||
parser.add_argument('--batches', default=1, type=int,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
parser.add_argument('--loader-workers', default=0, type=int,
|
||||
help='number of data loading workers, per GPU in training or '
|
||||
'in total in testing')
|
||||
|
||||
parser.add_argument('--cache', action='store_true',
|
||||
help='enable caching in field datasets')
|
||||
parser.add_argument('--crop', type=int,
|
||||
help='size to crop the input and target data')
|
||||
parser.add_argument('--pad', default=0, type=int,
|
||||
help='pad the input data assuming periodic boundary condition')
|
||||
|
||||
|
||||
def add_train_args(parser):
|
||||
|
@ -47,10 +50,11 @@ def add_train_args(parser):
|
|||
help='comma-sep. list of glob patterns for validation input data')
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for validation target data')
|
||||
parser.add_argument('--epochs', default=1024, type=int,
|
||||
help='total number of epochs to run')
|
||||
parser.add_argument('--augment', action='store_true',
|
||||
help='enable training data augmentation')
|
||||
|
||||
parser.add_argument('--epochs', default=128, type=int,
|
||||
help='total number of epochs to run')
|
||||
parser.add_argument('--optimizer', default='Adam',
|
||||
help='optimizer from torch.optim')
|
||||
parser.add_argument('--lr', default=0.001, type=float,
|
||||
|
@ -59,10 +63,13 @@ def add_train_args(parser):
|
|||
# help='momentum')
|
||||
parser.add_argument('--weight-decay', default=0., type=float,
|
||||
help='weight decay')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
choices=['gloo', 'nccl'], help='distributed backend')
|
||||
parser.add_argument('--seed', type=int,
|
||||
help='seed for initializing training')
|
||||
|
||||
parser.add_argument('--div-data', action='store_true',
|
||||
help='enable data division among GPUs, useful with cache')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
choices=['gloo', 'nccl'], help='distributed backend')
|
||||
parser.add_argument('--log-interval', default=20, type=int,
|
||||
help='interval between logging training loss')
|
||||
|
||||
|
|
|
@ -14,15 +14,20 @@ class FieldDataset(Dataset):
|
|||
Likewise `tgt_patterns` is for target fields.
|
||||
Input and target samples of all fields are matched by sorting the globbed files.
|
||||
|
||||
Input and target fields can be cached, and they can be cropped.
|
||||
Input fields can be padded assuming periodic boundary condition.
|
||||
`norms` can be a list of callables to normalize each field.
|
||||
|
||||
Data augmentations are supported for scalar and vector fields.
|
||||
|
||||
`norms` can be a list of callables to normalize each field.
|
||||
Input and target fields can be cropped.
|
||||
Input fields can be padded assuming periodic boundary condition.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns, cache=False, crop=None, pad=0,
|
||||
augment=False, norms=None):
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
norms=None, augment=False, crop=None, pad=0,
|
||||
cache=False, div_data=False, rank=None, world_size=None,
|
||||
**kwargs):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
||||
|
@ -32,6 +37,11 @@ class FieldDataset(Dataset):
|
|||
assert len(self.in_files) == len(self.tgt_files), \
|
||||
'input and target sample sizes do not match'
|
||||
|
||||
if div_data:
|
||||
files = len(self.in_files) // world_size
|
||||
self.in_files = self.in_files[rank * files : (rank + 1) * files]
|
||||
self.tgt_files = self.tgt_files[rank * files : (rank + 1) * files]
|
||||
|
||||
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
|
||||
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
|
||||
|
||||
|
@ -39,13 +49,15 @@ class FieldDataset(Dataset):
|
|||
self.size = np.asarray(self.size)
|
||||
self.ndim = len(self.size)
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = []
|
||||
self.tgt_fields = []
|
||||
for idx in range(len(self.in_files)):
|
||||
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
|
||||
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
|
||||
if norms is not None: # FIXME: in_norms, tgt_norms
|
||||
assert len(in_patterns) == len(norms), \
|
||||
'numbers of normalization callables and input fields do not match'
|
||||
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
|
||||
self.norms = norms
|
||||
|
||||
self.augment = augment
|
||||
if self.ndim == 1 and self.augment:
|
||||
raise ValueError('cannot augment 1D fields')
|
||||
|
||||
if crop is None:
|
||||
self.crop = self.size
|
||||
|
@ -58,15 +70,13 @@ class FieldDataset(Dataset):
|
|||
assert isinstance(pad, int), 'only support symmetric padding for now'
|
||||
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
||||
|
||||
self.augment = augment
|
||||
if self.ndim == 1 and self.augment:
|
||||
raise ValueError('cannot augment 1D fields')
|
||||
|
||||
if norms is not None: # FIXME: in_norms, tgt_norms
|
||||
assert len(in_patterns) == len(norms), \
|
||||
'numbers of normalization callables and input fields do not match'
|
||||
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
|
||||
self.norms = norms
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = []
|
||||
self.tgt_fields = []
|
||||
for idx in range(len(self.in_files)):
|
||||
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
|
||||
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.in_files) * self.tot_reps
|
||||
|
|
|
@ -13,11 +13,8 @@ def test(args):
|
|||
test_dataset = FieldDataset(
|
||||
in_patterns=args.test_in_patterns,
|
||||
tgt_patterns=args.test_tgt_patterns,
|
||||
cache=args.cache,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
augment=False,
|
||||
norms=args.norms,
|
||||
**vars(args),
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
|
|
|
@ -48,19 +48,16 @@ def gpu_worker(local_rank, args):
|
|||
train_dataset = FieldDataset(
|
||||
in_patterns=args.train_in_patterns,
|
||||
tgt_patterns=args.train_tgt_patterns,
|
||||
cache=args.cache,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
augment=args.augment,
|
||||
norms=args.norms,
|
||||
**vars(args),
|
||||
)
|
||||
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
if not args.div_data:
|
||||
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batches,
|
||||
shuffle=False,
|
||||
sampler=train_sampler,
|
||||
shuffle=args.div_data,
|
||||
sampler=None if args.div_data else train_sampler,
|
||||
num_workers=args.loader_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
|
@ -68,19 +65,17 @@ def gpu_worker(local_rank, args):
|
|||
val_dataset = FieldDataset(
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
cache=args.cache,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
augment=False,
|
||||
norms=args.norms,
|
||||
**{k:v for k, v in vars(args).items() if k != 'augment'},
|
||||
)
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistributedSampler(val_dataset)
|
||||
if not args.div_data:
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistributedSampler(val_dataset)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
shuffle=False,
|
||||
sampler=val_sampler,
|
||||
sampler=None if args.div_data else val_sampler,
|
||||
num_workers=args.loader_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
|
@ -129,7 +124,8 @@ def gpu_worker(local_rank, args):
|
|||
#args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
train_sampler.set_epoch(epoch)
|
||||
if not args.div_data:
|
||||
train_sampler.set_epoch(epoch)
|
||||
train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
|
||||
|
||||
val_loss = validate(epoch, val_loader, model, criterion, args)
|
||||
|
|
|
@ -42,8 +42,9 @@ srun m2m.py train \
|
|||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--norms cosmology.dis --augment --crop 100 --pad 42 \
|
||||
--model VNet \
|
||||
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
|
||||
--cache --div-data
|
||||
# --load-state checkpoint.pth \
|
||||
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0
|
||||
|
||||
|
||||
date
|
||||
|
|
|
@ -42,8 +42,9 @@ srun m2m.py train \
|
|||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--norms cosmology.vel --augment --crop 100 --pad 42 \
|
||||
--model VNet \
|
||||
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0 \
|
||||
--cache --div-data
|
||||
# --load-state checkpoint.pth \
|
||||
--epochs 128 --lr 0.001 --batches 1 --loader-workers 0
|
||||
|
||||
|
||||
date
|
||||
|
|
Loading…
Reference in a new issue