Add grouped cache for data bigger than CPU RAM

Yin Li 2020-05-28 23:01:04 -04:00
parent 2e687da905
commit 5bb2a19933
5 changed files with 103 additions and 24 deletions
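
In short: `FieldDataset.get_fields` loads whole fields from disk and is wrapped in a functools `lru_cache` whose `maxsize` counts files, so repeated crops of the same field reuse one read, and sample order is grouped so that each rank keeps revisiting files that are already cached. A minimal sketch of that caching pattern; the `CachedLoader` class and the `maxsize` value here are illustrative, not from the repo:

from functools import lru_cache

import numpy as np


class CachedLoader:
    """Illustrative only: keep at most `maxsize` whole files in CPU RAM,
    so many crops drawn from the same file share a single disk read."""

    def __init__(self, files, maxsize=2):
        self.files = files
        # wrap the bound method per instance, as FieldDataset does with get_fields
        self.load = lru_cache(maxsize=maxsize)(self.load)

    def load(self, idx):
        # one np.load per file for as long as it stays in the LRU cache
        return np.load(self.files[idx])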

View File

@@ -1,4 +1,5 @@
 import argparse
+import warnings

 from .train import ckpt_link
@@ -53,7 +54,7 @@ def add_common_args(parser):
     parser.add_argument('--load-state', default=ckpt_link, type=str,
            help='path to load the states of model, optimizer, rng, etc. '
                 'Default is the checkpoint. '
-                'Start from scratch if the checkpoint does not exist')
+                'Start from scratch if set empty or the checkpoint is missing')
     parser.add_argument('--load-state-non-strict', action='store_false',
            help='allow incompatible keys when loading model states',
            dest='load_state_strict')
@@ -173,6 +174,17 @@ def set_common_args(args):
     if args.batches > 1:
         args.loader_workers = args.batches

+    if not args.cache and args.cache_maxsize is not None:
+        args.cache_maxsize = None
+        warnings.warn('Resetting cache maxsize given cache is disabled',
+                      RuntimeWarning)
+
+    if (args.cache and args.cache_maxsize is not None
+            and args.cache_maxsize < 1):
+        args.cache = False
+        args.cache_maxsize = None
+        warnings.warn('Disabling cache given cache maxsize < 1',
+                      RuntimeWarning)
+

 def set_train_args(args):
     set_common_args(args)

View File

@@ -1 +1,2 @@
 from .fields import FieldDataset
+from .sampler import GroupedRandomSampler

View File

@@ -11,7 +11,7 @@ from .norms import import_norm
 class FieldDataset(Dataset):
     """Dataset of lists of fields.

-    `in_patterns` is a list of glob patterns for the input fields.
+    `in_patterns` is a list of glob patterns for the input field files.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Likewise `tgt_patterns` is for target fields.
     Input and target fields are matched by sorting the globbed files.
@@ -51,15 +51,9 @@ class FieldDataset(Dataset):
         assert len(self.in_files) == len(self.tgt_files), \
                 'number of input and target fields do not match'
+        self.nfile = len(self.in_files)

-        assert len(self.in_files) > 0, 'file not found'
-
-        if div_data:
-            files = len(self.in_files) // world_size
-            self.in_files = self.in_files[rank * files : (rank + 1) * files]
-            self.tgt_files = self.tgt_files[rank * files : (rank + 1) * files]
-            assert len(self.in_files) > 0, 'files not divisible among ranks'
+        assert self.nfile > 0, 'file not found'

         self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
         self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
@@ -92,7 +86,7 @@
         else:
             self.crop = np.broadcast_to(crop, self.size.shape)
         self.reps = self.size // self.crop
-        self.tot_reps = int(np.prod(self.reps))
+        self.ncrop = int(np.prod(self.reps))

         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))
@@ -104,19 +98,44 @@
         if cache:
             self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)

+        if div_data:
+            self.samples = []
+
+            # first add full fields when num_fields > num_GPU
+            for i in range(rank, self.nfile // world_size * world_size,
+                           world_size):
+                self.samples.append(
+                    range(i * self.ncrop, (i + 1) * self.ncrop)
+                )
+
+            # then split the rest into fractions of fields
+            # drop the last incomplete batch of samples
+            frac_start = self.nfile // world_size * world_size * self.ncrop
+            frac_samples = self.nfile % world_size * self.ncrop // world_size
+            self.samples.append(range(frac_start + rank * frac_samples,
+                                      frac_start + (rank + 1) * frac_samples))
+
+            self.samples = np.concatenate(self.samples)
+        else:
+            self.samples = np.arange(self.nfile * self.ncrop)
+        self.nsample = len(self.samples)
+
+        self.rank = rank
+
     def get_fields(self, idx):
         in_fields = [np.load(f) for f in self.in_files[idx]]
         tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
         return in_fields, tgt_fields

     def __len__(self):
-        return len(self.in_files) * self.tot_reps
+        return self.nsample

     def __getitem__(self, idx):
-        idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
-        start = np.unravel_index(sub_idx, self.reps) * self.crop
-
-        in_fields, tgt_fields = self.get_fields(idx)
+        idx = self.samples[idx]

+        in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
+
+        start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
         in_fields = crop(in_fields, start, self.crop, self.pad)
         tgt_fields = crop(tgt_fields, start * self.scale_factor,
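
For intuition, the sample split above can be run standalone. The numbers below (nfile=5, world_size=2, ncrop=4) are illustrative and not from the repo:

import numpy as np

nfile, world_size, ncrop = 5, 2, 4  # illustrative values

for rank in range(world_size):
    samples = []
    # full fields first: files 0-3 are dealt out two per rank
    for i in range(rank, nfile // world_size * world_size, world_size):
        samples.append(range(i * ncrop, (i + 1) * ncrop))
    # the leftover file 4 is split between ranks (any indivisible tail would be dropped)
    frac_start = nfile // world_size * world_size * ncrop
    frac_samples = nfile % world_size * ncrop // world_size
    samples.append(range(frac_start + rank * frac_samples,
                         frac_start + (rank + 1) * frac_samples))
    print(rank, np.concatenate(samples))
# rank 0: crops 0-3 (file 0), 8-11 (file 2), 16-17 (first half of file 4)
# rank 1: crops 4-7 (file 1), 12-15 (file 3), 18-19 (second half of file 4)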

map2map/data/sampler.py Normal file
View File

@@ -0,0 +1,31 @@
+from itertools import chain
+
+import torch
+from torch.utils.data import Sampler
+
+
+class GroupedRandomSampler(Sampler):
+    """Sample randomly within each group of samples and sequentially from group
+    to group.
+
+    This behaves like a simple random sampler by default
+    """
+    def __init__(self, data_source, group_size=None):
+        self.data_source = data_source
+        self.sample_size = len(data_source)
+
+        if group_size is None:
+            group_size = self.sample_size
+        self.group_size = group_size
+
+    def __iter__(self):
+        starts = range(0, self.sample_size, self.group_size)
+        sizes = [self.group_size] * (len(starts) - 1)
+        sizes.append(self.sample_size - starts[-1])
+
+        return iter(chain(*[
+            (start + torch.randperm(size)).tolist()
+            for start, size in zip(starts, sizes)
+        ]))
+
+    def __len__(self):
+        return self.sample_size
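
A usage sketch for the sampler, assuming the map2map.data import added in this commit; the stand-in TensorDataset and the numbers are illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

from map2map.data import GroupedRandomSampler

# stand-in dataset: 3 files x 4 crops = 12 samples, consecutive crops share a file
dataset = TensorDataset(torch.arange(12))

# a cache holding 2 files -> shuffle only within groups of 2 * 4 = 8 samples,
# so each group touches at most 2 files and the per-rank lru_cache is not thrashed
sampler = GroupedRandomSampler(dataset, group_size=2 * 4)

loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for batch, in loader:
    print(batch)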

View File

@@ -13,7 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

-from .data import FieldDataset
+from .data import FieldDataset, GroupedRandomSampler
 from .data.figures import fig3d
 from . import models
 from .models import (narrow_like,
@@ -48,11 +48,13 @@ def gpu_worker(local_rank, node, args):
     device = torch.device('cuda', local_rank)
     torch.cuda.device(device)

-    torch.manual_seed(args.seed)
-    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
-
     rank = args.gpus_per_node * node + local_rank
+
+    # Need randomness across processes, for sampler, augmentation, noise etc.
+    # Note DDP broadcasts initial model states from rank 0
+    torch.manual_seed(args.seed + rank)
+    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

     dist_init(rank, args)

     train_dataset = FieldDataset(
@@ -72,7 +74,13 @@ def gpu_worker(local_rank, node, args):
         rank=rank,
         world_size=args.world_size,
     )
-    if not args.div_data:
+    if args.div_data:
+        train_sampler = GroupedRandomSampler(
+            train_dataset,
+            group_size=None if args.cache_maxsize is None else
+                       args.cache_maxsize * train_dataset.ncrop,
+        )
+    else:
         try:
             train_sampler = DistributedSampler(train_dataset, shuffle=True)
         except TypeError:
@@ -80,8 +88,8 @@ def gpu_worker(local_rank, node, args):
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
-        shuffle=args.div_data,
-        sampler=None if args.div_data else train_sampler,
+        shuffle=False,
+        sampler=train_sampler,
         num_workers=args.loader_workers,
         pin_memory=True,
     )
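
Concretely, with args.cache_maxsize = 2 and a field that yields ncrop = 64 crops, group_size = 2 * 64 = 128: shuffling stays within blocks of 128 samples, so at any moment the loader draws crops from roughly two files, which is what each rank's lru_cache can hold. With caching enabled but no maxsize, group_size is None and the sampler degenerates to a plain random permutation of the rank's samples. (The numbers here are illustrative.)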
@@ -104,7 +112,9 @@
         rank=rank,
         world_size=args.world_size,
     )
-    if not args.div_data:
+    if args.div_data:
+        val_sampler = None
+    else:
         try:
             val_sampler = DistributedSampler(val_dataset, shuffle=False)
         except TypeError:
@@ -113,7 +123,7 @@
         val_dataset,
         batch_size=args.batches,
         shuffle=False,
-        sampler=None if args.div_data else val_sampler,
+        sampler=val_sampler,
         num_workers=args.loader_workers,
         pin_memory=True,
     )
@@ -278,6 +288,12 @@
             os.symlink(state_file, tmp_link)  # workaround to overwrite
             os.rename(tmp_link, ckpt_link)

+    if args.cache:
+        print('rank {} train data: {}'.format(
+            rank, train_dataset.get_fields.cache_info()))
+        print('rank {} val data: {}'.format(
+            rank, val_dataset.get_fields.cache_info()))
+
     dist.destroy_process_group()