Merge branch 'lag2eul' of github.com:eelregit/map2map into lag2eul
This commit is contained in:
commit
632c73db16
@ -70,7 +70,7 @@ def add_common_args(parser):
|
||||
|
||||
parser.add_argument('--batches', type=int, required=True,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
parser.add_argument('--loader-workers', default=-2, type=int,
|
||||
parser.add_argument('--loader-workers', default=-8, type=int,
|
||||
help='number of subprocesses per data loader. '
|
||||
'0 to disable multiprocessing; '
|
||||
'negative number to multiply by the batch size')
|
||||
@ -123,6 +123,14 @@ def add_train_args(parser):
|
||||
parser.add_argument('--seed', default=42, type=int,
|
||||
help='seed for initializing training')
|
||||
|
||||
parser.add_argument('--div-data', action='store_true',
|
||||
help='enable data division among GPUs for better page caching. '
|
||||
'Only relevant if there are multiple crops in each field')
|
||||
parser.add_argument('--div-shuffle-dist', default=1, type=float,
|
||||
help='distance to further shuffle within each data division. '
|
||||
'Only relevant if there are multiple crops in each field. '
|
||||
'The order of each sample is randomly displaced by this value. '
|
||||
'Change this to balance cache locality and stochasticity')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
choices=['gloo', 'nccl'], help='distributed backend')
|
||||
parser.add_argument('--log-interval', default=100, type=int,
|
||||
|
@ -1,2 +1,2 @@
|
||||
from .fields import FieldDataset
|
||||
from .sampler import GroupedRandomSampler
|
||||
from .sampler import DistFieldSampler
|
||||
|
@ -119,14 +119,16 @@ class FieldDataset(Dataset):
|
||||
'only support integer upsampling'
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
self.nsample = self.nfile * self.ncrop
|
||||
|
||||
def __len__(self):
|
||||
return self.nfile * self.ncrop
|
||||
return self.nsample
|
||||
|
||||
def __getitem__(self, idx):
|
||||
ifile, icrop = divmod(idx, self.ncrop)
|
||||
|
||||
in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
|
||||
tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]
|
||||
in_fields = [np.load(f) for f in self.in_files[ifile]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
|
||||
|
||||
anchor = self.anchors[icrop]
|
||||
|
||||
@ -184,7 +186,6 @@ def crop(fields, anchor, crop, pad, size):
|
||||
ind.append(i)
|
||||
|
||||
x = x[tuple(ind)]
|
||||
x.setflags(write=True) # workaround numpy bug before 1.18
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
|
@ -1,31 +1,77 @@
|
||||
from itertools import chain
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
class GroupedRandomSampler(Sampler):
|
||||
"""Sample randomly within each group of samples and sequentially from group
|
||||
to group.
|
||||
class DistFieldSampler(Sampler):
|
||||
"""Distributed sampler for field data, useful for multiple crops
|
||||
|
||||
This behaves like a simple random sampler by default
|
||||
Stochastic training on fields with multiple crops puts burden on the IO.
|
||||
A node may load files of the whole field but only need a small part of it.
|
||||
Numpy memmap can load part of the field, but can also be very slow (even
|
||||
slower than reading the whole thing)
|
||||
|
||||
`div_data` enables data file division among GPUs when `shuffle=True`.
|
||||
For field with multiple crops, it helps IO by benefiting from the page
|
||||
cache, but limits stochasticity.
|
||||
Increase `div_shuffle_dist` can mitigate this by shuffling the order of
|
||||
samples within the specified distance.
|
||||
|
||||
When `div_data=False` this sampler behaves similar to `DistributedSampler`,
|
||||
except for the chunky (rather than strided) subsample slicing.
|
||||
Like `DistributedSampler`, `set_epoch()` should be called at the beginning
|
||||
of each epoch during training.
|
||||
"""
|
||||
def __init__(self, data_source, group_size=None):
|
||||
self.data_source = data_source
|
||||
self.sample_size = len(data_source)
|
||||
def __init__(self, dataset, shuffle,
|
||||
div_data=False, div_shuffle_dist=0):
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
|
||||
if group_size is None:
|
||||
group_size = self.sample_size
|
||||
self.group_size = group_size
|
||||
self.dataset = dataset
|
||||
self.nsample = len(dataset)
|
||||
self.nfile = dataset.nfile
|
||||
self.ncrop = dataset.ncrop
|
||||
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.div_data = div_data
|
||||
self.div_shuffle_dist = div_shuffle_dist
|
||||
|
||||
def __iter__(self):
|
||||
starts = range(0, self.sample_size, self.group_size)
|
||||
sizes = [self.group_size] * (len(starts) - 1)
|
||||
sizes.append(self.sample_size - starts[-1])
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
|
||||
return iter(chain(*[
|
||||
(start + torch.randperm(size)).tolist()
|
||||
for start, size in zip(starts, sizes)
|
||||
]))
|
||||
if self.shuffle:
|
||||
if self.div_data:
|
||||
# shuffle files
|
||||
ind = torch.randperm(self.nfile, generator=g)
|
||||
ind = ind[:, None] * self.ncrop + torch.arange(self.ncrop)
|
||||
ind = ind.flatten()
|
||||
|
||||
# displace crops with respect to files
|
||||
dis = torch.rand((self.nfile, self.ncrop),
|
||||
generator=g) * self.div_shuffle_dist
|
||||
loc = torch.arange(self.nfile)
|
||||
loc = loc[:, None] + dis
|
||||
loc = loc.flatten() % self.nfile # periodic in files
|
||||
dis_ind = loc.argsort()
|
||||
|
||||
# shuffle crops
|
||||
ind = ind[dis_ind].tolist()
|
||||
else:
|
||||
ind = torch.randperm(self.nsample, generator=g).tolist()
|
||||
else:
|
||||
ind = list(range(self.nsample))
|
||||
|
||||
start = self.rank * len(self)
|
||||
stop = start + len(self)
|
||||
ind = ind[start:stop]
|
||||
|
||||
return iter(ind)
|
||||
|
||||
def __len__(self):
|
||||
return self.sample_size
|
||||
return self.nsample // self.world_size
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
@ -10,11 +10,10 @@ import torch.optim as optim
|
||||
import torch.distributed as dist
|
||||
from torch.multiprocessing import spawn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .data import FieldDataset, GroupedRandomSampler
|
||||
from .data import FieldDataset, DistFieldSampler
|
||||
from .data.figures import plt_slices
|
||||
from . import models
|
||||
from .models import narrow_cast, resample, Lag2Eul
|
||||
@ -72,7 +71,9 @@ def gpu_worker(local_rank, node, args):
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
)
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
||||
div_data=args.div_data,
|
||||
div_shuffle_dist=args.div_shuffle_dist)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batches,
|
||||
@ -100,7 +101,9 @@ def gpu_worker(local_rank, node, args):
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
)
|
||||
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
||||
div_data=args.div_data,
|
||||
div_shuffle_dist=args.div_shuffle_dist)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
@ -300,6 +303,9 @@ def train(epoch, loader, model, lag2eul, criterion,
|
||||
global_step=batch)
|
||||
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
|
||||
global_step=batch)
|
||||
logger.add_scalar('loss/batch/train/lxe',
|
||||
lag_loss.item() * eul_loss.item(),
|
||||
global_step=batch)
|
||||
|
||||
logger.add_scalar('grad/lag/first', lag_grads[0],
|
||||
global_step=batch)
|
||||
@ -317,6 +323,8 @@ def train(epoch, loader, model, lag2eul, criterion,
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
|
||||
global_step=epoch+1)
|
||||
|
||||
logger.add_figure('fig/epoch/train', plt_slices(
|
||||
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||
@ -365,6 +373,8 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
|
||||
global_step=epoch+1)
|
||||
|
||||
logger.add_figure('fig/epoch/val', plt_slices(
|
||||
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||
|
9
setup.py
9
setup.py
@ -1,5 +1,4 @@
|
||||
from setuptools import setup
|
||||
from setuptools import find_packages
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='map2map',
|
||||
@ -8,6 +7,9 @@ setup(
|
||||
author='Yin Li et al.',
|
||||
author_email='eelregit@gmail.com',
|
||||
packages=find_packages(),
|
||||
scripts=[
|
||||
'scripts/m2m.py',
|
||||
],
|
||||
python_requires='>=3.6',
|
||||
install_requires=[
|
||||
'torch>=1.2',
|
||||
@ -17,7 +19,4 @@ setup(
|
||||
extras_require={
|
||||
"visualization": ["tensorboard"],
|
||||
},
|
||||
scripts=[
|
||||
'scripts/m2m.py',
|
||||
]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user