Add DistFieldSampler to benefit from page cache

Yin Li 2020-07-21 21:13:52 -07:00
parent ec46f41ba5
commit 98cdd4795c
6 changed files with 91 additions and 34 deletions

View File

@@ -70,7 +70,7 @@ def add_common_args(parser):
     parser.add_argument('--batches', type=int, required=True,
             help='mini-batch size, per GPU in training or in total in testing')
-    parser.add_argument('--loader-workers', default=-2, type=int,
+    parser.add_argument('--loader-workers', default=-8, type=int,
             help='number of subprocesses per data loader. '
                  '0 to disable multiprocessing; '
                  'negative number to multiply by the batch size')

@@ -123,6 +123,14 @@ def add_train_args(parser):
     parser.add_argument('--seed', default=42, type=int,
             help='seed for initializing training')
+    parser.add_argument('--div-data', action='store_true',
+            help='enable data division among GPUs for better page caching. '
+                 'Only relevant if there are multiple crops in each field')
+    parser.add_argument('--div-shuffle-dist', default=1, type=float,
+            help='distance to further shuffle within each data division. '
+                 'Only relevant if there are multiple crops in each field. '
+                 'The order of each sample is randomly displaced by this value. '
+                 'Change this to balance cache locality and stochasticity')
     parser.add_argument('--dist-backend', default='nccl', type=str,
             choices=['gloo', 'nccl'], help='distributed backend')
     parser.add_argument('--log-interval', default=100, type=int,
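A note on the changed --loader-workers default: per the help text above, a negative value is a multiplier on the per-GPU batch size rather than a literal worker count. A minimal sketch of that convention (the helper name is hypothetical; the actual conversion happens elsewhere in map2map and is not part of this diff):

def resolve_loader_workers(loader_workers, batches):
    """Hypothetical helper: resolve the --loader-workers setting to a worker count."""
    if loader_workers >= 0:
        return loader_workers             # explicit count; 0 disables multiprocessing
    return -loader_workers * batches      # negative value multiplies the batch size

assert resolve_loader_workers(0, 4) == 0       # multiprocessing disabled
assert resolve_loader_workers(-8, 2) == 16     # new default -8 with --batches 2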

View File

@@ -1,2 +1,2 @@
 from .fields import FieldDataset
-from .sampler import GroupedRandomSampler
+from .sampler import DistFieldSampler

View File

@ -119,14 +119,16 @@ class FieldDataset(Dataset):
'only support integer upsampling' 'only support integer upsampling'
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.nsample = self.nfile * self.ncrop
def __len__(self): def __len__(self):
return self.nfile * self.ncrop return self.nsample
def __getitem__(self, idx): def __getitem__(self, idx):
ifile, icrop = divmod(idx, self.ncrop) ifile, icrop = divmod(idx, self.ncrop)
in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]] in_fields = [np.load(f) for f in self.in_files[ifile]]
tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]] tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
anchor = self.anchors[icrop] anchor = self.anchors[icrop]
@ -184,7 +186,6 @@ def crop(fields, anchor, crop, pad, size):
ind.append(i) ind.append(i)
x = x[tuple(ind)] x = x[tuple(ind)]
x.setflags(write=True) # workaround numpy bug before 1.18
new_fields.append(x) new_fields.append(x)
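Dropping mmap_mode='r' here is presumably also why the setflags workaround in crop() can go: crops taken from a read-only memmap are not writeable, while crops from a fully loaded array are. A small illustration of the difference (the file name is just for the demo, not from this commit):

import numpy as np

# Illustration only: writability of crops under memory-mapped vs full loading.
np.save('field.npy', np.zeros((4, 4), dtype=np.float32))

mapped_crop = np.load('field.npy', mmap_mode='r')[:2, :2]
print(mapped_crop.flags.writeable)   # False: a view of a read-only memmap

loaded_crop = np.load('field.npy')[:2, :2]
print(loaded_crop.flags.writeable)   # True: an ordinary in-memory array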

View File

@@ -1,31 +1,77 @@
-from itertools import chain
-
 import torch
+import torch.distributed as dist
 from torch.utils.data import Sampler


-class GroupedRandomSampler(Sampler):
-    """Sample randomly within each group of samples and sequentially from group
-    to group.
-
-    This behaves like a simple random sampler by default
+class DistFieldSampler(Sampler):
+    """Distributed sampler for field data, useful for multiple crops
+
+    Stochastic training on fields with multiple crops puts burden on the IO.
+    A node may load files of the whole field but only need a small part of it.
+    Numpy memmap can load part of the field, but can also be very slow (even
+    slower than reading the whole thing)
+
+    `div_data` enables data file division among GPUs when `shuffle=True`.
+    For field with multiple crops, it helps IO by benefiting from the page
+    cache, but limits stochasticity.
+    Increase `div_shuffle_dist` can mitigate this by shuffling the order of
+    samples within the specified distance.
+
+    When `div_data=False` this sampler behaves similar to `DistributedSampler`,
+    except for the chunky (rather than strided) subsample slicing.
+
+    Like `DistributedSampler`, `set_epoch()` should be called at the beginning
+    of each epoch during training.
     """
-    def __init__(self, data_source, group_size=None):
-        self.data_source = data_source
-        self.sample_size = len(data_source)
+    def __init__(self, dataset, shuffle,
+                 div_data=False, div_shuffle_dist=0):
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()

-        if group_size is None:
-            group_size = self.sample_size
-        self.group_size = group_size
+        self.dataset = dataset
+        self.nsample = len(dataset)
+        self.nfile = dataset.nfile
+        self.ncrop = dataset.ncrop
+
+        self.shuffle = shuffle
+        self.div_data = div_data
+        self.div_shuffle_dist = div_shuffle_dist

     def __iter__(self):
-        starts = range(0, self.sample_size, self.group_size)
-        sizes = [self.group_size] * (len(starts) - 1)
-        sizes.append(self.sample_size - starts[-1])
+        if self.shuffle:
+            # deterministically shuffle based on epoch
+            g = torch.Generator()
+            g.manual_seed(self.epoch)

-        return iter(chain(*[
-            (start + torch.randperm(size)).tolist()
-            for start, size in zip(starts, sizes)
-        ]))
+            if self.div_data:
+                # shuffle files
+                ind = torch.randperm(self.nfile, generator=g)
+                ind = ind[:, None] * self.ncrop + torch.arange(self.ncrop)
+                ind = ind.flatten()
+
+                # displace crops with respect to files
+                dis = torch.rand((self.nfile, self.ncrop),
+                                 generator=g) * self.div_shuffle_dist
+                loc = torch.arange(self.nfile)
+                loc = loc[:, None] + dis
+                loc = loc.flatten() % self.nfile  # periodic in files
+                dis_ind = loc.argsort()
+
+                # shuffle crops
+                ind = ind[dis_ind].tolist()
+            else:
+                ind = torch.randperm(self.nsample, generator=g).tolist()
+        else:
+            ind = list(range(self.nsample))
+
+        start = self.rank * len(self)
+        stop = start + len(self)
+        ind = ind[start:stop]
+
+        return iter(ind)

     def __len__(self):
-        return self.sample_size
+        return self.nsample // self.world_size
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
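The docstring describes the div_data trade-off in words; the standalone sketch below (not part of the commit, toy sizes assumed) replays the same index arithmetic as the div_data branch, so the effect of div_shuffle_dist can be inspected without an initialized process group:

import torch

# Toy replay of the div_data shuffle above: 4 files, 3 crops per file,
# div_shuffle_dist=1.0 (all values hypothetical).
nfile, ncrop, div_shuffle_dist = 4, 3, 1.0
g = torch.Generator()
g.manual_seed(0)

# shuffle files, then expand each file into its crop indices
ind = torch.randperm(nfile, generator=g)
ind = (ind[:, None] * ncrop + torch.arange(ncrop)).flatten()

# displace each crop's position by up to div_shuffle_dist file slots
dis = torch.rand((nfile, ncrop), generator=g) * div_shuffle_dist
loc = (torch.arange(nfile)[:, None] + dis).flatten() % nfile  # periodic in files
ind = ind[loc.argsort()]

print(ind.tolist())  # crops stay near their file's slot; a larger distance mixes farther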

View File

@@ -10,11 +10,10 @@ import torch.optim as optim
 import torch.distributed as dist
 from torch.multiprocessing import spawn
 from torch.nn.parallel import DistributedDataParallel
-from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

-from .data import FieldDataset, GroupedRandomSampler
+from .data import FieldDataset, DistFieldSampler
 from .data.figures import plt_slices
 from . import models
 from .models import narrow_cast, resample, Lag2Eul

@@ -72,7 +71,9 @@ def gpu_worker(local_rank, node, args):
         pad=args.pad,
         scale_factor=args.scale_factor,
     )
-    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    train_sampler = DistFieldSampler(train_dataset, shuffle=True,
+                                     div_data=args.div_data,
+                                     div_shuffle_dist=args.div_shuffle_dist)
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,

@@ -100,7 +101,9 @@ def gpu_worker(local_rank, node, args):
         pad=args.pad,
         scale_factor=args.scale_factor,
     )
-    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+    val_sampler = DistFieldSampler(val_dataset, shuffle=False,
+                                   div_data=args.div_data,
+                                   div_shuffle_dist=args.div_shuffle_dist)
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
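One thing this diff does not show is the epoch loop; as the sampler's docstring notes, set_epoch() still has to be called at the start of every epoch, exactly as with DistributedSampler. Roughly, continuing from the names constructed above (args.epochs, the unpacked batch names, and the loop body are assumptions, not code from this commit):

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)                 # reseed the deterministic per-epoch shuffle
    for in_fields, tgt_fields in train_loader:     # crop pairs yielded by FieldDataset
        ...                                        # forward, loss, backward, optimizer step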

View File

@@ -1,5 +1,4 @@
-from setuptools import setup
-from setuptools import find_packages
+from setuptools import setup, find_packages

 setup(
     name='map2map',

@@ -8,6 +7,9 @@ setup(
     author='Yin Li et al.',
     author_email='eelregit@gmail.com',
     packages=find_packages(),
+    scripts=[
+        'scripts/m2m.py',
+    ],
     python_requires='>=3.6',
     install_requires=[
         'torch>=1.2',

@@ -17,7 +19,4 @@ setup(
     extras_require={
         "visualization": ["tensorboard"],
     },
-    scripts=[
-        'scripts/m2m.py',
-    ]
 )