Add memmap to numpy data loading
parent eba76bf90d
commit 818ed6923d
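The diff below swaps the LRU-cache path for memory-mapped numpy loading: np.load(..., mmap_mode='r') maps a .npy file and defers reading, so only the region that is finally copied gets pulled from disk. A minimal sketch of that behavior, with a hypothetical file name and shape that are not from this repository:

    import numpy as np

    # Hypothetical .npy field of shape (channels, x, y, z); name and size are illustrative.
    field = np.load('dis.npy', mmap_mode='r')      # maps the file; only the header is parsed
    print(field.shape, field.dtype)                # metadata access is cheap, no bulk read
    crop = np.array(field[:, :128, :128, :128])    # basic slicing stays on the memmap;
                                                   # np.array() copies just this region into RAM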
@@ -63,7 +63,7 @@ def add_common_args(parser):
     parser.add_argument('--load-state', default=ckpt_link, type=str,
             help='path to load the states of model, optimizer, rng, etc. '
             'Default is the checkpoint. '
-            'Start from scratch if set empty or the checkpoint is missing')
+            'Start from scratch in case of empty string or missing checkpoint')
     parser.add_argument('--load-state-non-strict', action='store_false',
             help='allow incompatible keys when loading model states',
             dest='load_state_strict')
@@ -75,12 +75,6 @@ def add_common_args(parser):
             'in total in testing. Default is 0 for single batch, '
             'otherwise same as the batch size')

-    parser.add_argument('--cache', action='store_true',
-            help='enable LRU cache of input and target fields to reduce I/O')
-    parser.add_argument('--cache-maxsize', type=int,
-            help='maximum pairs of fields in cache, unlimited by default. '
-            'This only applies to training if not None, '
-            'in which case the testing cache maxsize is 1')
     parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
             help='directory of custorm code defining callbacks for models, '
             'norms, criteria, and optimizers. Disabled if not set. '
@@ -153,8 +147,6 @@ def add_train_args(parser):
     parser.add_argument('--seed', default=42, type=int,
             help='seed for initializing training')

-    parser.add_argument('--div-data', action='store_true',
-            help='enable data division among GPUs, useful with cache')
     parser.add_argument('--dist-backend', default='nccl', type=str,
             choices=['gloo', 'nccl'], help='distributed backend')
     parser.add_argument('--log-interval', default=100, type=int,
@@ -190,17 +182,6 @@ def set_common_args(args):
     if args.batches > 1:
         args.loader_workers = args.batches

-    if not args.cache and args.cache_maxsize is not None:
-        args.cache_maxsize = None
-        warnings.warn('Resetting cache maxsize given cache is disabled',
-                RuntimeWarning)
-    if (args.cache and args.cache_maxsize is not None
-            and args.cache_maxsize < 1):
-        args.cache = False
-        args.cache_maxsize = None
-        warnings.warn('Disabling cache given cache maxsize < 1',
-                RuntimeWarning)
-

 def set_train_args(args):
     set_common_args(args)
@@ -1,5 +1,4 @@
 from glob import glob
-from functools import lru_cache
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -39,20 +38,12 @@ class FieldDataset(Dataset):
     Setting integer `scale_factor` greater than 1 will crop target bigger than
     the input for super-resolution, in which case `crop` and `pad` are sizes of
     the input resolution.
-
-    `cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
-    pairs (unlimited by default).
-    `div_data` enables data division, to be used with `cache`, so that different
-    fields are cached in different GPU processes.
-    This saves CPU RAM but limits stochasticity.
     """
     def __init__(self, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_shift=None, aug_add=None, aug_mul=None,
                  crop=None, crop_start=None, crop_stop=None, crop_step=None,
-                 pad=0, scale_factor=1,
-                 cache=False, cache_maxsize=None, div_data=False,
-                 rank=None, world_size=None):
+                 pad=0, scale_factor=1):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))

@@ -65,10 +56,12 @@ class FieldDataset(Dataset):

         assert self.nfile > 0, 'file not found for {}'.format(in_patterns)

-        self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
-        self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
+        self.in_chan = [np.load(f, mmap_mode='r').shape[0]
+                        for f in self.in_files[0]]
+        self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
+                         for f in self.tgt_files[0]]

-        self.size = np.load(self.in_files[0][0]).shape[1:]
+        self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
         self.size = np.asarray(self.size)
         self.ndim = len(self.size)

@@ -126,47 +119,16 @@ class FieldDataset(Dataset):
                 'only support integer upsampling'
         self.scale_factor = scale_factor

-        if cache:
-            self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
-
-        if div_data:
-            self.samples = []
-
-            # first add full fields when num_fields > num_GPU
-            for i in range(rank, self.nfile // world_size * world_size,
-                           world_size):
-                self.samples.extend(list(
-                    range(i * self.ncrop, (i + 1) * self.ncrop)
-                ))
-
-            # then split the rest into fractions of fields
-            # drop the last incomplete batch of samples
-            frac_start = self.nfile // world_size * world_size * self.ncrop
-            frac_samples = self.nfile % world_size * self.ncrop // world_size
-            self.samples.extend(list(
-                range(frac_start + rank * frac_samples,
-                      frac_start + (rank + 1) * frac_samples)
-            ))
-        else:
-            self.samples = list(range(self.nfile * self.ncrop))
-        self.nsample = len(self.samples)
-
-        self.rank = rank
-
-    def get_fields(self, idx):
-        in_fields = [np.load(f) for f in self.in_files[idx]]
-        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
-        return in_fields, tgt_fields
-
     def __len__(self):
-        return self.nsample
+        return self.nfile * self.ncrop

     def __getitem__(self, idx):
-        idx = self.samples[idx]
+        ifile, icrop = divmod(idx, self.ncrop)

-        in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
+        in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
+        tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]

-        anchor = self.anchors[idx % self.ncrop]
+        anchor = self.anchors[icrop]

         for d, shift in enumerate(self.aug_shift):
             if shift is not None:
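The __getitem__ above now derives the file and crop indices from the flat sample index with divmod, instead of the precomputed samples list. A tiny sketch of that mapping, with illustrative numbers:

    # Flat-index bookkeeping used above; numbers are illustrative.
    nfile, ncrop = 3, 4                    # e.g. 3 field files, 4 crops per file
    for idx in range(nfile * ncrop):       # what __len__ now reports
        ifile, icrop = divmod(idx, ncrop)  # idx 0-3 -> file 0, idx 4-7 -> file 1, ...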
@@ -221,7 +183,8 @@ def crop(fields, anchor, crop, pad, size):
             i = i.reshape((-1,) + (1,) * (ndim - d - 1))
             ind.append(i)

-        x = x[ind]
+        x = x[tuple(ind)]
+        x.setflags(write=True)  # workaround numpy bug before 1.18

         new_fields.append(x)

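In crop(), indexing with a tuple of index arrays triggers numpy advanced indexing, which copies the selected region out of the read-only memmap; per the in-diff comment, numpy releases before 1.18 could leave that copy flagged read-only, hence the setflags(write=True) workaround. A hedged sketch with a made-up file and an assumed cubic (channels, nx, ny, nz) field:

    import numpy as np

    x = np.load('dis.npy', mmap_mode='r')          # hypothetical file, read-only memmap
    i = np.arange(120, 136) % x.shape[1]           # wrap indices around the (assumed cubic) box
    patch = x[(slice(None), i.reshape(-1, 1, 1),
               i.reshape(-1, 1), i)]               # advanced indexing copies the crop out
    patch.setflags(write=True)                     # harmless on recent numpy; works around
                                                   # old releases marking the copy read-only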
@@ -30,7 +30,6 @@ def test(args):
         crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        cache=args.cache,
     )
     test_loader = DataLoader(
         test_dataset,
@@ -74,23 +74,8 @@ def gpu_worker(local_rank, node, args):
         crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        cache=args.cache,
-        cache_maxsize=args.cache_maxsize,
-        div_data=args.div_data,
-        rank=rank,
-        world_size=args.world_size,
     )
-    if args.div_data:
-        train_sampler = GroupedRandomSampler(
-            train_dataset,
-            group_size=None if args.cache_maxsize is None else
-                       args.cache_maxsize * train_dataset.ncrop,
-        )
-    else:
-        try:
-            train_sampler = DistributedSampler(train_dataset, shuffle=True)
-        except TypeError:
-            train_sampler = DistributedSampler(train_dataset)  # old pytorch
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
@@ -117,19 +102,8 @@ def gpu_worker(local_rank, node, args):
         crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        cache=args.cache,
-        cache_maxsize=None if args.cache_maxsize is None else 1,
-        div_data=args.div_data,
-        rank=rank,
-        world_size=args.world_size,
     )
-    if args.div_data:
-        val_sampler = None
-    else:
-        try:
-            val_sampler = DistributedSampler(val_dataset, shuffle=False)
-        except TypeError:
-            val_sampler = DistributedSampler(val_dataset)  # old pytorch
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
@@ -252,7 +226,6 @@ def gpu_worker(local_rank, node, args):
             args.instance_noise_batches)

     for epoch in range(start_epoch, args.epochs):
-        if not args.div_data:
-            train_sampler.set_epoch(epoch)
+        train_sampler.set_epoch(epoch)

         train_loss = train(epoch, train_loader,
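With div_data removed, every rank uses a DistributedSampler, so set_epoch() is now called unconditionally to reshuffle each epoch. A minimal usage sketch, not the repo's actual setup; num_replicas and rank are passed explicitly here only so the sketch runs without init_process_group:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8.0))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # reseed the shuffle so each epoch draws a new order
        for (batch,) in loader:
            pass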
@@ -273,10 +246,7 @@ def gpu_worker(local_rank, node, args):
             adv_scheduler.step(epoch_loss[0])

         if rank == 0:
-            try:
-                logger.flush()
-            except AttributeError:
-                logger.close()  # old pytorch
+            logger.flush()

             if ((min_loss is None or epoch_loss[0] < min_loss[0])
                     and epoch >= args.adv_start):
@@ -299,12 +269,6 @@ def gpu_worker(local_rank, node, args):
                 os.symlink(state_file, tmp_link)  # workaround to overwrite
                 os.rename(tmp_link, ckpt_link)

-        if args.cache:
-            print('rank {} train data: {}'.format(
-                rank, train_dataset.get_fields.cache_info()))
-            print('rank {} val data: {}'.format(
-                rank, val_dataset.get_fields.cache_info()))
-
     dist.destroy_process_group()


@@ -38,8 +38,7 @@ srun m2m.py train \
     --in-norms cosmology.dis --tgt-norms torch.log1p --augment --crop 128 --pad 20 \
     --model UNet \
     --lr 0.0001 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
+    --epochs 1024 --seed $RANDOM


 date
@@ -38,8 +38,7 @@ m2m.py test \
     --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
     --model VNet \
     --load-state best_model.pt \
-    --batches 1 --loader-workers 0 \
-    --cache
+    --batches 1 --loader-workers 0


 date
@@ -39,8 +39,7 @@ srun m2m.py train \
     --in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
     --model VNet --adv-model UNet --cgan \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
+    --epochs 1024 --seed $RANDOM


 date
@@ -39,10 +39,7 @@ srun m2m.py train \
     --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
     --model VNet --adv-model PatchGAN --cgan \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
-    # --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
-    # --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
+    --epochs 1024 --seed $RANDOM


 date
@@ -38,8 +38,7 @@ m2m.py test \
     --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
     --model VNet \
     --load-state best_model.pt \
-    --batches 1 --loader-workers 0 \
-    --cache
+    --batches 1 --loader-workers 0


 date
@@ -39,8 +39,7 @@ srun m2m.py train \
    --in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
    --model VNet --adv-model UNet --cgan \
    --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
-   --epochs 1024 --seed $RANDOM \
-   --cache --div-data
+   --epochs 1024 --seed $RANDOM


 date