From 818ed6923db4bbd99b271aa42ce0cebc5c156e98 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Tue, 14 Jul 2020 17:05:30 -0400 Subject: [PATCH] Add memmap to numpy data loading --- map2map/args.py | 21 +------------ map2map/data/fields.py | 63 ++++++++------------------------------ map2map/test.py | 1 - map2map/train.py | 44 +++----------------------- scripts/dis2den.slurm | 3 +- scripts/dis2dis-test.slurm | 3 +- scripts/dis2dis.slurm | 3 +- scripts/srsgan.slurm | 5 +-- scripts/vel2vel-test.slurm | 3 +- scripts/vel2vel.slurm | 3 +- setup.py | 2 +- 11 files changed, 25 insertions(+), 126 deletions(-) diff --git a/map2map/args.py b/map2map/args.py index f85fa0e..4b36dac 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -63,7 +63,7 @@ def add_common_args(parser): parser.add_argument('--load-state', default=ckpt_link, type=str, help='path to load the states of model, optimizer, rng, etc. ' 'Default is the checkpoint. ' - 'Start from scratch if set empty or the checkpoint is missing') + 'Start from scratch in case of empty string or missing checkpoint') parser.add_argument('--load-state-non-strict', action='store_false', help='allow incompatible keys when loading model states', dest='load_state_strict') @@ -75,12 +75,6 @@ def add_common_args(parser): 'in total in testing. Default is 0 for single batch, ' 'otherwise same as the batch size') - parser.add_argument('--cache', action='store_true', - help='enable LRU cache of input and target fields to reduce I/O') - parser.add_argument('--cache-maxsize', type=int, - help='maximum pairs of fields in cache, unlimited by default. ' - 'This only applies to training if not None, ' - 'in which case the testing cache maxsize is 1') parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s), help='directory of custorm code defining callbacks for models, ' 'norms, criteria, and optimizers. Disabled if not set. ' @@ -153,8 +147,6 @@ def add_train_args(parser): parser.add_argument('--seed', default=42, type=int, help='seed for initializing training') - parser.add_argument('--div-data', action='store_true', - help='enable data division among GPUs, useful with cache') parser.add_argument('--dist-backend', default='nccl', type=str, choices=['gloo', 'nccl'], help='distributed backend') parser.add_argument('--log-interval', default=100, type=int, @@ -190,17 +182,6 @@ def set_common_args(args): if args.batches > 1: args.loader_workers = args.batches - if not args.cache and args.cache_maxsize is not None: - args.cache_maxsize = None - warnings.warn('Resetting cache maxsize given cache is disabled', - RuntimeWarning) - if (args.cache and args.cache_maxsize is not None - and args.cache_maxsize < 1): - args.cache = False - args.cache_maxsize = None - warnings.warn('Disabling cache given cache maxsize < 1', - RuntimeWarning) - def set_train_args(args): set_common_args(args) diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 873f11e..dd3621f 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -1,5 +1,4 @@ from glob import glob -from functools import lru_cache import numpy as np import torch import torch.nn.functional as F @@ -39,20 +38,12 @@ class FieldDataset(Dataset): Setting integer `scale_factor` greater than 1 will crop target bigger than the input for super-resolution, in which case `crop` and `pad` are sizes of the input resolution. - - `cache` enables LRU cache of the input and target fields, up to `cache_maxsize` - pairs (unlimited by default). 
- `div_data` enables data division, to be used with `cache`, so that different - fields are cached in different GPU processes. - This saves CPU RAM but limits stochasticity. """ def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, callback_at=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None, crop=None, crop_start=None, crop_stop=None, crop_step=None, - pad=0, scale_factor=1, - cache=False, cache_maxsize=None, div_data=False, - rank=None, world_size=None): + pad=0, scale_factor=1): in_file_lists = [sorted(glob(p)) for p in in_patterns] self.in_files = list(zip(* in_file_lists)) @@ -65,10 +56,12 @@ class FieldDataset(Dataset): assert self.nfile > 0, 'file not found for {}'.format(in_patterns) - self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]] - self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]] + self.in_chan = [np.load(f, mmap_mode='r').shape[0] + for f in self.in_files[0]] + self.tgt_chan = [np.load(f, mmap_mode='r').shape[0] + for f in self.tgt_files[0]] - self.size = np.load(self.in_files[0][0]).shape[1:] + self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:] self.size = np.asarray(self.size) self.ndim = len(self.size) @@ -126,47 +119,16 @@ class FieldDataset(Dataset): 'only support integer upsampling' self.scale_factor = scale_factor - if cache: - self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields) - - if div_data: - self.samples = [] - - # first add full fields when num_fields > num_GPU - for i in range(rank, self.nfile // world_size * world_size, - world_size): - self.samples.extend(list( - range(i * self.ncrop, (i + 1) * self.ncrop) - )) - - # then split the rest into fractions of fields - # drop the last incomplete batch of samples - frac_start = self.nfile // world_size * world_size * self.ncrop - frac_samples = self.nfile % world_size * self.ncrop // world_size - self.samples.extend(list( - range(frac_start + rank * frac_samples, - frac_start + (rank + 1) * frac_samples) - )) - else: - self.samples = list(range(self.nfile * self.ncrop)) - self.nsample = len(self.samples) - - self.rank = rank - - def get_fields(self, idx): - in_fields = [np.load(f) for f in self.in_files[idx]] - tgt_fields = [np.load(f) for f in self.tgt_files[idx]] - return in_fields, tgt_fields - def __len__(self): - return self.nsample + return self.nfile * self.ncrop def __getitem__(self, idx): - idx = self.samples[idx] + ifile, icrop = divmod(idx, self.ncrop) - in_fields, tgt_fields = self.get_fields(idx // self.ncrop) + in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]] + tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]] - anchor = self.anchors[idx % self.ncrop] + anchor = self.anchors[icrop] for d, shift in enumerate(self.aug_shift): if shift is not None: @@ -221,7 +183,8 @@ def crop(fields, anchor, crop, pad, size): i = i.reshape((-1,) + (1,) * (ndim - d - 1)) ind.append(i) - x = x[ind] + x = x[tuple(ind)] + x.setflags(write=True) # workaround numpy bug before 1.18 new_fields.append(x) diff --git a/map2map/test.py b/map2map/test.py index 9466149..dd3f30f 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -30,7 +30,6 @@ def test(args): crop_step=args.crop_step, pad=args.pad, scale_factor=args.scale_factor, - cache=args.cache, ) test_loader = DataLoader( test_dataset, diff --git a/map2map/train.py b/map2map/train.py index e9de49b..c779466 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -74,23 +74,8 @@ def gpu_worker(local_rank, node, args): 
crop_step=args.crop_step, pad=args.pad, scale_factor=args.scale_factor, - cache=args.cache, - cache_maxsize=args.cache_maxsize, - div_data=args.div_data, - rank=rank, - world_size=args.world_size, ) - if args.div_data: - train_sampler = GroupedRandomSampler( - train_dataset, - group_size=None if args.cache_maxsize is None else - args.cache_maxsize * train_dataset.ncrop, - ) - else: - try: - train_sampler = DistributedSampler(train_dataset, shuffle=True) - except TypeError: - train_sampler = DistributedSampler(train_dataset) # old pytorch + train_sampler = DistributedSampler(train_dataset, shuffle=True) train_loader = DataLoader( train_dataset, batch_size=args.batches, @@ -117,19 +102,8 @@ def gpu_worker(local_rank, node, args): crop_step=args.crop_step, pad=args.pad, scale_factor=args.scale_factor, - cache=args.cache, - cache_maxsize=None if args.cache_maxsize is None else 1, - div_data=args.div_data, - rank=rank, - world_size=args.world_size, ) - if args.div_data: - val_sampler = None - else: - try: - val_sampler = DistributedSampler(val_dataset, shuffle=False) - except TypeError: - val_sampler = DistributedSampler(val_dataset) # old pytorch + val_sampler = DistributedSampler(val_dataset, shuffle=False) val_loader = DataLoader( val_dataset, batch_size=args.batches, @@ -252,8 +226,7 @@ def gpu_worker(local_rank, node, args): args.instance_noise_batches) for epoch in range(start_epoch, args.epochs): - if not args.div_data: - train_sampler.set_epoch(epoch) + train_sampler.set_epoch(epoch) train_loss = train(epoch, train_loader, model, criterion, optimizer, scheduler, @@ -273,10 +246,7 @@ def gpu_worker(local_rank, node, args): adv_scheduler.step(epoch_loss[0]) if rank == 0: - try: - logger.flush() - except AttributeError: - logger.close() # old pytorch + logger.flush() if ((min_loss is None or epoch_loss[0] < min_loss[0]) and epoch >= args.adv_start): @@ -299,12 +269,6 @@ def gpu_worker(local_rank, node, args): os.symlink(state_file, tmp_link) # workaround to overwrite os.rename(tmp_link, ckpt_link) - if args.cache: - print('rank {} train data: {}'.format( - rank, train_dataset.get_fields.cache_info())) - print('rank {} val data: {}'.format( - rank, val_dataset.get_fields.cache_info())) - dist.destroy_process_group() diff --git a/scripts/dis2den.slurm b/scripts/dis2den.slurm index 3a087fe..4edd18c 100644 --- a/scripts/dis2den.slurm +++ b/scripts/dis2den.slurm @@ -38,8 +38,7 @@ srun m2m.py train \ --in-norms cosmology.dis --tgt-norms torch.log1p --augment --crop 128 --pad 20 \ --model UNet \ --lr 0.0001 --batches 1 --loader-workers 0 \ - --epochs 1024 --seed $RANDOM \ - --cache --div-data + --epochs 1024 --seed $RANDOM date diff --git a/scripts/dis2dis-test.slurm b/scripts/dis2dis-test.slurm index 0e035c6..8ca2370 100644 --- a/scripts/dis2dis-test.slurm +++ b/scripts/dis2dis-test.slurm @@ -38,8 +38,7 @@ m2m.py test \ --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \ --model VNet \ --load-state best_model.pt \ - --batches 1 --loader-workers 0 \ - --cache + --batches 1 --loader-workers 0 date diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm index 4adb6c1..ed39a15 100644 --- a/scripts/dis2dis.slurm +++ b/scripts/dis2dis.slurm @@ -39,8 +39,7 @@ srun m2m.py train \ --in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \ --model VNet --adv-model UNet --cgan \ --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ - --epochs 1024 --seed $RANDOM \ - --cache --div-data + --epochs 1024 --seed $RANDOM date diff --git 
a/scripts/srsgan.slurm b/scripts/srsgan.slurm index c4332e9..5de73cc 100644 --- a/scripts/srsgan.slurm +++ b/scripts/srsgan.slurm @@ -39,10 +39,7 @@ srun m2m.py train \ --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \ --model VNet --adv-model PatchGAN --cgan \ --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ - --epochs 1024 --seed $RANDOM \ - --cache --div-data -# --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \ -# --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \ + --epochs 1024 --seed $RANDOM date diff --git a/scripts/vel2vel-test.slurm b/scripts/vel2vel-test.slurm index 7c9c3e9..8f1a5b7 100644 --- a/scripts/vel2vel-test.slurm +++ b/scripts/vel2vel-test.slurm @@ -38,8 +38,7 @@ m2m.py test \ --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \ --model VNet \ --load-state best_model.pt \ - --batches 1 --loader-workers 0 \ - --cache + --batches 1 --loader-workers 0 date diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm index e0eedb3..ac046fd 100644 --- a/scripts/vel2vel.slurm +++ b/scripts/vel2vel.slurm @@ -39,8 +39,7 @@ srun m2m.py train \ --in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \ --model VNet --adv-model UNet --cgan \ --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ - --epochs 1024 --seed $RANDOM \ - --cache --div-data + --epochs 1024 --seed $RANDOM date diff --git a/setup.py b/setup.py index f144c25..8e3b4d3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( packages=find_packages(), python_requires='>=3.6', install_requires=[ - 'torch', + 'torch>=1.2', 'numpy', 'scipy', ],
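
For readers unfamiliar with the pattern this commit switches to, the sketch below (not part of the patch) illustrates memory-mapped numpy loading: np.load(..., mmap_mode='r') returns a read-only numpy.memmap, so FieldDataset can query shapes and slice crops without reading whole fields into RAM, which is what makes the removed LRU cache and div_data machinery unnecessary. The file name 'field.npy', the array shape, and the crop bounds are illustrative assumptions, not values from the repository.

import numpy as np

# Write a small example field to disk: (channels, x, y, z), like the *.npy
# fields the dataset globs.  The shape is chosen only for illustration.
np.save('field.npy', np.random.randn(3, 64, 64, 64).astype(np.float32))

x = np.load('field.npy', mmap_mode='r')  # lazy, read-only memory map
print(x.shape[0], x.shape[1:])           # channel count and spatial size,
                                         # analogous to in_chan and size above

# Basic slicing on a memmap returns a view; only the touched region is
# actually read from disk when the data are used.
crop = x[:, 0:32, 0:32, 0:32]

# Copy the crop into an ordinary writable array before torch.from_numpy,
# which warns on read-only inputs.  The patched crop() achieves the same
# effect with advanced indexing plus setflags(write=True), the latter as a
# workaround for a numpy bug fixed in 1.18.
crop = np.ascontiguousarray(crop)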