Add DistFieldSampler to benefit from page cache

This commit is contained in:
Yin Li 2020-07-21 21:13:52 -07:00
parent ec46f41ba5
commit 9f54e02c3a
6 changed files with 91 additions and 34 deletions

View File

@ -70,7 +70,7 @@ def add_common_args(parser):
parser.add_argument('--batches', type=int, required=True,
help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', default=-2, type=int,
parser.add_argument('--loader-workers', default=-8, type=int,
help='number of subprocesses per data loader. '
'0 to disable multiprocessing; '
'negative number to multiply by the batch size')
@ -123,6 +123,14 @@ def add_train_args(parser):
parser.add_argument('--seed', default=42, type=int,
help='seed for initializing training')
parser.add_argument('--div-data', action='store_true',
help='enable data division among GPUs for better page caching. '
'Only relevant if there are multiple crops in each field')
parser.add_argument('--div-shuffle-dist', default=1, type=float,
help='distance to further shuffle within each data division. '
'Only relevant if there are multiple crops in each field. '
'The order of each sample is randomly displaced by this value. '
'Change this to balance cache locality and stochasticity')
parser.add_argument('--dist-backend', default='nccl', type=str,
choices=['gloo', 'nccl'], help='distributed backend')
parser.add_argument('--log-interval', default=100, type=int,

View File

@ -1,2 +1,2 @@
from .fields import FieldDataset
from .sampler import GroupedRandomSampler
from .sampler import DistFieldSampler

View File

@ -119,14 +119,16 @@ class FieldDataset(Dataset):
'only support integer upsampling'
self.scale_factor = scale_factor
self.nsample = self.nfile * self.ncrop
def __len__(self):
return self.nfile * self.ncrop
return self.nsample
def __getitem__(self, idx):
ifile, icrop = divmod(idx, self.ncrop)
in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]
in_fields = [np.load(f) for f in self.in_files[ifile]]
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
anchor = self.anchors[icrop]
@ -184,7 +186,6 @@ def crop(fields, anchor, crop, pad, size):
ind.append(i)
x = x[tuple(ind)]
x.setflags(write=True) # workaround numpy bug before 1.18
new_fields.append(x)

View File

@ -1,31 +1,77 @@
from itertools import chain
import torch
import torch.distributed as dist
from torch.utils.data import Sampler
class GroupedRandomSampler(Sampler):
"""Sample randomly within each group of samples and sequentially from group
to group.
class DistFieldSampler(Sampler):
"""Distributed sampler for field data, useful for multiple crops
This behaves like a simple random sampler by default
Stochastic training on fields with multiple crops puts burden on the IO.
A node may load files of the whole field but only need a small part of it.
Numpy memmap can load part of the field, but can also be very slow (even
slower than reading the whole thing)
`div_data` enables data file division among GPUs when `shuffle=True`.
For field with multiple crops, it helps IO by benefiting from the page
cache, but limits stochasticity.
Increase `div_shuffle_dist` can mitigate this by shuffling the order of
samples within the specified distance.
When `div_data=False` this sampler behaves similar to `DistributedSampler`,
except for the chunky (rather than strided) subsample slicing.
Like `DistributedSampler`, `set_epoch()` should be called at the beginning
of each epoch during training.
"""
def __init__(self, data_source, group_size=None):
self.data_source = data_source
self.sample_size = len(data_source)
def __init__(self, dataset, shuffle,
div_data=False, div_shuffle_dist=0):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
if group_size is None:
group_size = self.sample_size
self.group_size = group_size
self.dataset = dataset
self.nsample = len(dataset)
self.nfile = dataset.nfile
self.ncrop = dataset.ncrop
self.shuffle = shuffle
self.div_data = div_data
self.div_shuffle_dist = div_shuffle_dist
def __iter__(self):
starts = range(0, self.sample_size, self.group_size)
sizes = [self.group_size] * (len(starts) - 1)
sizes.append(self.sample_size - starts[-1])
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
return iter(chain(*[
(start + torch.randperm(size)).tolist()
for start, size in zip(starts, sizes)
]))
if self.shuffle:
if self.div_data:
# shuffle files
ind = torch.randperm(self.nfile, generator=g)
ind = ind[:, None] * self.ncrop + torch.arange(self.ncrop)
ind = ind.flatten()
# displace crops with respect to files
dis = torch.rand((self.nfile, self.ncrop),
generator=g) * self.div_shuffle_dist
loc = torch.arange(self.nfile)
loc = loc[:, None] + dis
loc = loc.flatten() % self.nfile # periodic in files
dis_ind = loc.argsort()
# shuffle crops
ind = ind[dis_ind].tolist()
else:
ind = torch.randperm(self.nsample, generator=g).tolist()
else:
ind = list(range(self.nsample))
start = self.rank * len(self)
stop = start + len(self)
ind = ind[start:stop]
return iter(ind)
def __len__(self):
return self.sample_size
return self.nsample // self.world_size
def set_epoch(self, epoch):
self.epoch = epoch

View File

@ -10,11 +10,10 @@ import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler
from .data import FieldDataset, DistFieldSampler
from .data.figures import plt_slices
from . import models
from .models import narrow_cast, resample, Lag2Eul
@ -72,7 +71,9 @@ def gpu_worker(local_rank, node, args):
pad=args.pad,
scale_factor=args.scale_factor,
)
train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist)
train_loader = DataLoader(
train_dataset,
batch_size=args.batches,
@ -100,7 +101,9 @@ def gpu_worker(local_rank, node, args):
pad=args.pad,
scale_factor=args.scale_factor,
)
val_sampler = DistributedSampler(val_dataset, shuffle=False)
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist)
val_loader = DataLoader(
val_dataset,
batch_size=args.batches,

View File

@ -1,5 +1,4 @@
from setuptools import setup
from setuptools import find_packages
from setuptools import setup, find_packages
setup(
name='map2map',
@ -8,6 +7,9 @@ setup(
author='Yin Li et al.',
author_email='eelregit@gmail.com',
packages=find_packages(),
scripts=[
'scripts/m2m.py',
],
python_requires='>=3.6',
install_requires=[
'torch>=1.2',
@ -17,7 +19,4 @@ setup(
extras_require={
"visualization": ["tensorboard"],
},
scripts=[
'scripts/m2m.py',
]
)