Add LRU cache to replace existing cache

This commit is contained in:
Yin Li 2020-05-16 17:57:01 -04:00
parent 4cc2fd51eb
commit 2e687da905
3 changed files with 24 additions and 22 deletions

View file

@ -66,7 +66,11 @@ def add_common_args(parser):
'Default is the batch size or 0 for batch size 1')
parser.add_argument('--cache', action='store_true',
help='enable caching in field datasets')
help='enable LRU cache of input and target fields to reduce I/O')
parser.add_argument('--cache-maxsize', type=int,
help='maximum pairs of fields in cache, unlimited by default. '
'This only applies to training if not None, '
'in which case the testing cache maxsize is 1')
def add_train_args(parser):

View file

@ -1,4 +1,5 @@
from glob import glob
from functools import lru_cache
import numpy as np
import torch
import torch.nn.functional as F
@ -13,7 +14,7 @@ class FieldDataset(Dataset):
`in_patterns` is a list of glob patterns for the input fields.
For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
Likewise `tgt_patterns` is for target fields.
Input and target samples are matched by sorting the globbed files.
Input and target fields are matched by sorting the globbed files.
`in_norms` is a list of of functions to normalize the input fields.
Likewise for `tgt_norms`.
@ -30,14 +31,17 @@ class FieldDataset(Dataset):
the input for super-resolution, in which case `crop` and `pad` are sizes of
the input resolution.
`cache` enables data caching.
`div_data` enables data division, useful when combined with caching.
`cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
pairs (unlimited by default).
`div_data` enables data division, to be used with `cache`, so that different
fields are cached in different GPU processes.
This saves CPU RAM but limits stochasticity.
"""
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None,
augment=False, aug_add=None, aug_mul=None,
crop=None, pad=0, scale_factor=1,
cache=False, div_data=False,
cache=False, cache_maxsize=None, div_data=False,
rank=None, world_size=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -46,7 +50,7 @@ class FieldDataset(Dataset):
self.tgt_files = list(zip(* tgt_file_lists))
assert len(self.in_files) == len(self.tgt_files), \
'input and target sample sizes do not match'
'number of input and target fields do not match'
assert len(self.in_files) > 0, 'file not found'
@ -97,10 +101,13 @@ class FieldDataset(Dataset):
'only support integer upsampling'
self.scale_factor = scale_factor
self.cache = cache
if self.cache:
self.in_fields = {}
self.tgt_fields = {}
if cache:
self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
def get_fields(self, idx):
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
return in_fields, tgt_fields
def __len__(self):
return len(self.in_files) * self.tot_reps
@ -109,18 +116,7 @@ class FieldDataset(Dataset):
idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
start = np.unravel_index(sub_idx, self.reps) * self.crop
if self.cache:
try:
in_fields = self.in_fields[idx]
tgt_fields = self.tgt_fields[idx]
except KeyError:
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
self.in_fields[idx] = in_fields
self.tgt_fields[idx] = tgt_fields
else:
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
in_fields, tgt_fields = self.get_fields(idx)
in_fields = crop(in_fields, start, self.crop, self.pad)
tgt_fields = crop(tgt_fields, start * self.scale_factor,

View file

@ -67,6 +67,7 @@ def gpu_worker(local_rank, node, args):
pad=args.pad,
scale_factor=args.scale_factor,
cache=args.cache,
cache_maxsize=args.cache_maxsize,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
@ -98,6 +99,7 @@ def gpu_worker(local_rank, node, args):
pad=args.pad,
scale_factor=args.scale_factor,
cache=args.cache,
cache_maxsize=None if args.cache_maxsize is None else 1,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,