Add grouped cache for data bigger than CPU RAM
commit 5bb2a19933
parent 2e687da905
@@ -1,4 +1,5 @@
 import argparse
+import warnings
 
 from .train import ckpt_link
 
@@ -53,7 +54,7 @@ def add_common_args(parser):
     parser.add_argument('--load-state', default=ckpt_link, type=str,
             help='path to load the states of model, optimizer, rng, etc. '
             'Default is the checkpoint. '
-            'Start from scratch if the checkpoint does not exist')
+            'Start from scratch if set empty or the checkpoint is missing')
     parser.add_argument('--load-state-non-strict', action='store_false',
             help='allow incompatible keys when loading model states',
             dest='load_state_strict')
@@ -173,6 +174,17 @@ def set_common_args(args):
     if args.batches > 1:
        args.loader_workers = args.batches
 
+    if not args.cache and args.cache_maxsize is not None:
+        args.cache_maxsize = None
+        warnings.warn('Resetting cache maxsize given cache is disabled',
+                RuntimeWarning)
+    if (args.cache and args.cache_maxsize is not None
+            and args.cache_maxsize < 1):
+        args.cache = False
+        args.cache_maxsize = None
+        warnings.warn('Disabling cache given cache maxsize < 1',
+                RuntimeWarning)
+
 
 def set_train_args(args):
     set_common_args(args)
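The two new checks normalize the cache flags before training starts: a maxsize without caching is dropped, and a non-positive maxsize disables caching altogether. A minimal sketch of that behaviour (normalize_cache_args is a hypothetical helper mirroring the hunk above, not part of the commit):

    import warnings
    from argparse import Namespace

    def normalize_cache_args(args):
        # mirrors the logic added to set_common_args above
        if not args.cache and args.cache_maxsize is not None:
            args.cache_maxsize = None
            warnings.warn('Resetting cache maxsize given cache is disabled',
                    RuntimeWarning)
        if (args.cache and args.cache_maxsize is not None
                and args.cache_maxsize < 1):
            args.cache = False
            args.cache_maxsize = None
            warnings.warn('Disabling cache given cache maxsize < 1',
                    RuntimeWarning)
        return args

    print(normalize_cache_args(Namespace(cache=False, cache_maxsize=8)))
    # Namespace(cache=False, cache_maxsize=None), plus a RuntimeWarning
    print(normalize_cache_args(Namespace(cache=True, cache_maxsize=0)))
    # Namespace(cache=False, cache_maxsize=None), plus a RuntimeWarning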
@@ -1 +1,2 @@
 from .fields import FieldDataset
+from .sampler import GroupedRandomSampler
@@ -11,7 +11,7 @@ from .norms import import_norm
 class FieldDataset(Dataset):
     """Dataset of lists of fields.
 
-    `in_patterns` is a list of glob patterns for the input fields.
+    `in_patterns` is a list of glob patterns for the input field files.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Likewise `tgt_patterns` is for target fields.
     Input and target fields are matched by sorting the globbed files.
@@ -51,15 +51,9 @@ class FieldDataset(Dataset):
 
         assert len(self.in_files) == len(self.tgt_files), \
                 'number of input and target fields do not match'
+        self.nfile = len(self.in_files)
 
-        assert len(self.in_files) > 0, 'file not found'
-
-        if div_data:
-            files = len(self.in_files) // world_size
-            self.in_files = self.in_files[rank * files : (rank + 1) * files]
-            self.tgt_files = self.tgt_files[rank * files : (rank + 1) * files]
-
-            assert len(self.in_files) > 0, 'files not divisible among ranks'
+        assert self.nfile > 0, 'file not found'
 
         self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
         self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
@@ -92,7 +86,7 @@ class FieldDataset(Dataset):
         else:
             self.crop = np.broadcast_to(crop, self.size.shape)
         self.reps = self.size // self.crop
-        self.tot_reps = int(np.prod(self.reps))
+        self.ncrop = int(np.prod(self.reps))
 
         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))
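For orientation, ncrop counts how many crops tile one field. A sketch with assumed field and crop shapes (the numbers are illustrative only):

    import numpy as np

    size = np.array([128, 128, 128])   # assumed full field shape
    crop = np.array([32, 32, 32])      # assumed crop shape
    reps = size // crop                # crops along each dimension
    ncrop = int(np.prod(reps))
    print(reps, ncrop)                 # [4 4 4] 64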
@@ -104,19 +98,44 @@ class FieldDataset(Dataset):
         if cache:
             self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
 
+        if div_data:
+            self.samples = []
+
+            # first add full fields when num_fields > num_GPU
+            for i in range(rank, self.nfile // world_size * world_size,
+                    world_size):
+                self.samples.append(
+                    range(i * self.ncrop, (i + 1) * self.ncrop)
+                )
+
+            # then split the rest into fractions of fields
+            # drop the last incomplete batch of samples
+            frac_start = self.nfile // world_size * world_size * self.ncrop
+            frac_samples = self.nfile % world_size * self.ncrop // world_size
+            self.samples.append(range(frac_start + rank * frac_samples,
+                                      frac_start + (rank + 1) * frac_samples))
+
+            self.samples = np.concatenate(self.samples)
+        else:
+            self.samples = np.arange(self.nfile * self.ncrop)
+        self.nsample = len(self.samples)
+
+        self.rank = rank
+
     def get_fields(self, idx):
         in_fields = [np.load(f) for f in self.in_files[idx]]
         tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
         return in_fields, tgt_fields
 
     def __len__(self):
-        return len(self.in_files) * self.tot_reps
+        return self.nsample
 
     def __getitem__(self, idx):
-        idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
-        start = np.unravel_index(sub_idx, self.reps) * self.crop
+        idx = self.samples[idx]
 
-        in_fields, tgt_fields = self.get_fields(idx)
+        in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
+
+        start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
 
         in_fields = crop(in_fields, start, self.crop, self.pad)
         tgt_fields = crop(tgt_fields, start * self.scale_factor,
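The reworked __getitem__ maps a flat, rank-local sample index to a field and a crop inside it. A sketch of that mapping, with assumed values for reps, crop and the index:

    import numpy as np

    reps = np.array([2, 2, 2])         # assumed crops per dimension
    crop = np.array([32, 32, 32])      # assumed crop shape
    ncrop = int(np.prod(reps))         # 8 crops per field

    idx = 13                           # a flat index drawn from self.samples
    ifile, icrop = idx // ncrop, idx % ncrop
    start = np.unravel_index(icrop, reps) * crop
    print(ifile, icrop, start)         # 1 5 [32  0 32]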
map2map/data/sampler.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+from itertools import chain
+import torch
+from torch.utils.data import Sampler
+
+
+class GroupedRandomSampler(Sampler):
+    """Sample randomly within each group of samples and sequentially from group
+    to group.
+
+    This behaves like a simple random sampler by default
+    """
+    def __init__(self, data_source, group_size=None):
+        self.data_source = data_source
+        self.sample_size = len(data_source)
+
+        if group_size is None:
+            group_size = self.sample_size
+        self.group_size = group_size
+
+    def __iter__(self):
+        starts = range(0, self.sample_size, self.group_size)
+        sizes = [self.group_size] * (len(starts) - 1)
+        sizes.append(self.sample_size - starts[-1])
+
+        return iter(chain(*[
+            (start + torch.randperm(size)).tolist()
+            for start, size in zip(starts, sizes)
+        ]))
+
+    def __len__(self):
+        return self.sample_size
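A usage sketch for the new sampler, assuming the export added to map2map.data above; the sample count and group size are made up. Indices are shuffled only within each group, so nearby batches keep requesting the same few fields:

    from map2map.data import GroupedRandomSampler

    sampler = GroupedRandomSampler(list(range(6)), group_size=3)
    print(list(sampler))   # e.g. [2, 0, 1, 5, 3, 4]: 0-2 shuffled first, then 3-5
    print(len(sampler))    # 6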
@@ -13,7 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
-from .data import FieldDataset
+from .data import FieldDataset, GroupedRandomSampler
 from .data.figures import fig3d
 from . import models
 from .models import (narrow_like,
@@ -48,11 +48,13 @@ def gpu_worker(local_rank, node, args):
     device = torch.device('cuda', local_rank)
     torch.cuda.device(device)
 
-    torch.manual_seed(args.seed)
-    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
-
     rank = args.gpus_per_node * node + local_rank
 
+    # Need randomness across processes, for sampler, augmentation, noise etc.
+    # Note DDP broadcasts initial model states from rank 0
+    torch.manual_seed(args.seed + rank)
+    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
+
     dist_init(rank, args)
 
     train_dataset = FieldDataset(
@@ -72,7 +74,13 @@ def gpu_worker(local_rank, node, args):
         rank=rank,
         world_size=args.world_size,
     )
-    if not args.div_data:
+    if args.div_data:
+        train_sampler = GroupedRandomSampler(
+            train_dataset,
+            group_size=None if args.cache_maxsize is None else
+                       args.cache_maxsize * train_dataset.ncrop,
+        )
+    else:
         try:
             train_sampler = DistributedSampler(train_dataset, shuffle=True)
         except TypeError:
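With div_data each rank already owns its own slice of samples, so the grouped sampler replaces DistributedSampler here. Tying group_size to cache_maxsize * ncrop means the samples shuffled together come from roughly cache_maxsize distinct full fields, which is what the per-process LRU cache can hold. A sketch of the arithmetic (the numbers are assumed):

    cache_maxsize = 4                    # fields the per-process LRU cache can hold
    ncrop = 8                            # crops per field (train_dataset.ncrop)
    group_size = cache_maxsize * ncrop   # 32 samples shuffled together, drawn
                                         # from about 4 distinct fields before
                                         # the sampler moves to the next group
    print(group_size)                    # 32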
@@ -80,8 +88,8 @@ def gpu_worker(local_rank, node, args):
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
-        shuffle=args.div_data,
-        sampler=None if args.div_data else train_sampler,
+        shuffle=False,
+        sampler=train_sampler,
         num_workers=args.loader_workers,
         pin_memory=True,
     )
@@ -104,7 +112,9 @@ def gpu_worker(local_rank, node, args):
         rank=rank,
         world_size=args.world_size,
     )
-    if not args.div_data:
+    if args.div_data:
+        val_sampler = None
+    else:
         try:
             val_sampler = DistributedSampler(val_dataset, shuffle=False)
         except TypeError:
@@ -113,7 +123,7 @@ def gpu_worker(local_rank, node, args):
         val_dataset,
         batch_size=args.batches,
         shuffle=False,
-        sampler=None if args.div_data else val_sampler,
+        sampler=val_sampler,
         num_workers=args.loader_workers,
         pin_memory=True,
     )
@@ -278,6 +288,12 @@ def gpu_worker(local_rank, node, args):
             os.symlink(state_file, tmp_link)  # workaround to overwrite
             os.rename(tmp_link, ckpt_link)
 
+    if args.cache:
+        print('rank {} train data: {}'.format(
+            rank, train_dataset.get_fields.cache_info()))
+        print('rank {} val data: {}'.format(
+            rank, val_dataset.get_fields.cache_info()))
+
     dist.destroy_process_group()
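The counters printed here come from functools.lru_cache. A small sketch of the failure mode the grouped sampling avoids: drawing from more fields than maxsize in a plain random order turns every lookup into a miss (load_field is a hypothetical stand-in for get_fields):

    from functools import lru_cache

    @lru_cache(maxsize=2)
    def load_field(i):              # stand-in for FieldDataset.get_fields
        return i

    for i in [0, 1, 2, 0, 1, 2]:    # cycling over 3 fields with room for only 2
        load_field(i)
    print(load_field.cache_info())  # CacheInfo(hits=0, misses=6, maxsize=2, currsize=2)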