Add LRU cache to replace existing cache
This commit is contained in:
parent
4cc2fd51eb
commit
2e687da905
3 changed files with 24 additions and 22 deletions
|
@ -66,7 +66,11 @@ def add_common_args(parser):
|
|||
'Default is the batch size or 0 for batch size 1')
|
||||
|
||||
parser.add_argument('--cache', action='store_true',
|
||||
help='enable caching in field datasets')
|
||||
help='enable LRU cache of input and target fields to reduce I/O')
|
||||
parser.add_argument('--cache-maxsize', type=int,
|
||||
help='maximum pairs of fields in cache, unlimited by default. '
|
||||
'This only applies to training if not None, '
|
||||
'in which case the testing cache maxsize is 1')
|
||||
|
||||
|
||||
def add_train_args(parser):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from glob import glob
|
||||
from functools import lru_cache
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -13,7 +14,7 @@ class FieldDataset(Dataset):
|
|||
`in_patterns` is a list of glob patterns for the input fields.
|
||||
For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
|
||||
Likewise `tgt_patterns` is for target fields.
|
||||
Input and target samples are matched by sorting the globbed files.
|
||||
Input and target fields are matched by sorting the globbed files.
|
||||
|
||||
`in_norms` is a list of of functions to normalize the input fields.
|
||||
Likewise for `tgt_norms`.
|
||||
|
@ -30,14 +31,17 @@ class FieldDataset(Dataset):
|
|||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||
the input resolution.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
`cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
|
||||
pairs (unlimited by default).
|
||||
`div_data` enables data division, to be used with `cache`, so that different
|
||||
fields are cached in different GPU processes.
|
||||
This saves CPU RAM but limits stochasticity.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None,
|
||||
augment=False, aug_add=None, aug_mul=None,
|
||||
crop=None, pad=0, scale_factor=1,
|
||||
cache=False, div_data=False,
|
||||
cache=False, cache_maxsize=None, div_data=False,
|
||||
rank=None, world_size=None):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
@ -46,7 +50,7 @@ class FieldDataset(Dataset):
|
|||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
assert len(self.in_files) == len(self.tgt_files), \
|
||||
'input and target sample sizes do not match'
|
||||
'number of input and target fields do not match'
|
||||
|
||||
assert len(self.in_files) > 0, 'file not found'
|
||||
|
||||
|
@ -97,10 +101,13 @@ class FieldDataset(Dataset):
|
|||
'only support integer upsampling'
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = {}
|
||||
self.tgt_fields = {}
|
||||
if cache:
|
||||
self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
|
||||
|
||||
def get_fields(self, idx):
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
return in_fields, tgt_fields
|
||||
|
||||
def __len__(self):
|
||||
return len(self.in_files) * self.tot_reps
|
||||
|
@ -109,18 +116,7 @@ class FieldDataset(Dataset):
|
|||
idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
|
||||
start = np.unravel_index(sub_idx, self.reps) * self.crop
|
||||
|
||||
if self.cache:
|
||||
try:
|
||||
in_fields = self.in_fields[idx]
|
||||
tgt_fields = self.tgt_fields[idx]
|
||||
except KeyError:
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
self.in_fields[idx] = in_fields
|
||||
self.tgt_fields[idx] = tgt_fields
|
||||
else:
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
in_fields, tgt_fields = self.get_fields(idx)
|
||||
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad)
|
||||
tgt_fields = crop(tgt_fields, start * self.scale_factor,
|
||||
|
|
|
@ -67,6 +67,7 @@ def gpu_worker(local_rank, node, args):
|
|||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
cache=args.cache,
|
||||
cache_maxsize=args.cache_maxsize,
|
||||
div_data=args.div_data,
|
||||
rank=rank,
|
||||
world_size=args.world_size,
|
||||
|
@ -98,6 +99,7 @@ def gpu_worker(local_rank, node, args):
|
|||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
cache=args.cache,
|
||||
cache_maxsize=None if args.cache_maxsize is None else 1,
|
||||
div_data=args.div_data,
|
||||
rank=rank,
|
||||
world_size=args.world_size,
|
||||
|
|
Loading…
Reference in a new issue