Add LRU cache to replace existing cache

parent 4cc2fd51eb
commit 2e687da905
@@ -66,7 +66,11 @@ def add_common_args(parser):
             'Default is the batch size or 0 for batch size 1')

     parser.add_argument('--cache', action='store_true',
-            help='enable caching in field datasets')
+            help='enable LRU cache of input and target fields to reduce I/O')
+    parser.add_argument('--cache-maxsize', type=int,
+            help='maximum pairs of fields in cache, unlimited by default. '
+                 'This only applies to training if not None, '
+                 'in which case the testing cache maxsize is 1')


 def add_train_args(parser):
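For reference, the two new flags can be exercised in isolation. A minimal standalone sketch, assuming only the argument definitions from the hunk above (the test command-line values are made up):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--cache', action='store_true',
        help='enable LRU cache of input and target fields to reduce I/O')
parser.add_argument('--cache-maxsize', type=int,
        help='maximum pairs of fields in cache, unlimited by default. '
             'This only applies to training if not None, '
             'in which case the testing cache maxsize is 1')

# hypothetical command line: --cache --cache-maxsize 16
args = parser.parse_args(['--cache', '--cache-maxsize', '16'])
print(args.cache, args.cache_maxsize)  # True 16; argparse maps '--cache-maxsize' to args.cache_maxsize

Leaving out `--cache-maxsize` gives `args.cache_maxsize = None`, which matches the "unlimited by default" behavior in the help text.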
@@ -1,4 +1,5 @@
 from glob import glob
+from functools import lru_cache
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -13,7 +14,7 @@ class FieldDataset(Dataset):
     `in_patterns` is a list of glob patterns for the input fields.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Likewise `tgt_patterns` is for target fields.
-    Input and target samples are matched by sorting the globbed files.
+    Input and target fields are matched by sorting the globbed files.

     `in_norms` is a list of of functions to normalize the input fields.
     Likewise for `tgt_norms`.
@@ -30,14 +31,17 @@ class FieldDataset(Dataset):
     the input for super-resolution, in which case `crop` and `pad` are sizes of
     the input resolution.

-    `cache` enables data caching.
-    `div_data` enables data division, useful when combined with caching.
+    `cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
+    pairs (unlimited by default).
+    `div_data` enables data division, to be used with `cache`, so that different
+    fields are cached in different GPU processes.
+    This saves CPU RAM but limits stochasticity.
     """
     def __init__(self, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None,
                  augment=False, aug_add=None, aug_mul=None,
                  crop=None, pad=0, scale_factor=1,
-                 cache=False, div_data=False,
+                 cache=False, cache_maxsize=None, div_data=False,
                  rank=None, world_size=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -46,7 +50,7 @@ class FieldDataset(Dataset):
         self.tgt_files = list(zip(* tgt_file_lists))

         assert len(self.in_files) == len(self.tgt_files), \
-                'input and target sample sizes do not match'
+                'number of input and target fields do not match'

         assert len(self.in_files) > 0, 'file not found'

@@ -97,10 +101,13 @@ class FieldDataset(Dataset):
                 'only support integer upsampling'
         self.scale_factor = scale_factor

-        self.cache = cache
-        if self.cache:
-            self.in_fields = {}
-            self.tgt_fields = {}
+        if cache:
+            self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
+
+    def get_fields(self, idx):
+        in_fields = [np.load(f) for f in self.in_files[idx]]
+        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+        return in_fields, tgt_fields

     def __len__(self):
         return len(self.in_files) * self.tot_reps
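The `if cache:` branch above works by wrapping a bound method with `functools.lru_cache` per instance. A minimal sketch of that pattern, assuming a stripped-down class (everything except `get_fields`, `cache`, and `cache_maxsize` is a placeholder):

from functools import lru_cache

class ToyDataset:
    def __init__(self, cache=False, cache_maxsize=None):
        if cache:
            # self.get_fields is the bound method; the wrapped version is stored
            # as an instance attribute, so each dataset gets its own cache
            self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)

    def get_fields(self, idx):
        print(f'loading sample {idx} from disk')  # stands in for the np.load calls
        return idx

ds = ToyDataset(cache=True, cache_maxsize=2)
ds.get_fields(0)                    # miss: prints the loading message
ds.get_fields(0)                    # hit: served from the LRU cache, no print
print(ds.get_fields.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)

With `maxsize=None` the cache grows without bound, which is what the "unlimited by default" help text describes.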
@@ -109,18 +116,7 @@ class FieldDataset(Dataset):
         idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
         start = np.unravel_index(sub_idx, self.reps) * self.crop

-        if self.cache:
-            try:
-                in_fields = self.in_fields[idx]
-                tgt_fields = self.tgt_fields[idx]
-            except KeyError:
-                in_fields = [np.load(f) for f in self.in_files[idx]]
-                tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
-                self.in_fields[idx] = in_fields
-                self.tgt_fields[idx] = tgt_fields
-        else:
-            in_fields = [np.load(f) for f in self.in_files[idx]]
-            tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+        in_fields, tgt_fields = self.get_fields(idx)

         in_fields = crop(in_fields, start, self.crop, self.pad)
         tgt_fields = crop(tgt_fields, start * self.scale_factor,
@@ -67,6 +67,7 @@ def gpu_worker(local_rank, node, args):
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,
+        cache_maxsize=args.cache_maxsize,
         div_data=args.div_data,
         rank=rank,
         world_size=args.world_size,
@@ -98,6 +99,7 @@ def gpu_worker(local_rank, node, args):
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,
+        cache_maxsize=None if args.cache_maxsize is None else 1,
         div_data=args.div_data,
         rank=rank,
         world_size=args.world_size,
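The two call sites differ only in `cache_maxsize`: training passes the user's value through, while evaluation caps the cache at one pair of fields. A tiny illustration of that expression, with `Namespace` standing in for the parsed arguments:

from argparse import Namespace

for value in (None, 16):
    args = Namespace(cache_maxsize=value)
    train_maxsize = args.cache_maxsize                        # as given on the command line
    test_maxsize = None if args.cache_maxsize is None else 1  # at most one cached pair
    print(train_maxsize, test_maxsize)                        # None None, then 16 1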