commit 632c73db16
Merge branch 'lag2eul' of github.com:eelregit/map2map into lag2eul
@@ -70,7 +70,7 @@ def add_common_args(parser):
 
     parser.add_argument('--batches', type=int, required=True,
             help='mini-batch size, per GPU in training or in total in testing')
-    parser.add_argument('--loader-workers', default=-2, type=int,
+    parser.add_argument('--loader-workers', default=-8, type=int,
             help='number of subprocesses per data loader. '
             '0 to disable multiprocessing; '
             'negative number to multiply by the batch size')
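
The help text above defines a convention rather than a count: 0 disables worker processes and a negative value is scaled by the per-GPU batch size. A minimal sketch of that convention, assuming a hypothetical helper `resolve_loader_workers` (the actual resolution happens elsewhere in the codebase):

    def resolve_loader_workers(loader_workers, batches):
        """Illustrative only: resolve the --loader-workers convention.

        0 disables multiprocessing; a negative value's magnitude is
        multiplied by the per-GPU batch size.
        """
        if loader_workers < 0:
            return -loader_workers * batches
        return loader_workers

    # With the new default of -8 and --batches 2, each loader gets 16 workers.
    assert resolve_loader_workers(-8, 2) == 16
    assert resolve_loader_workers(0, 2) == 0
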
@@ -123,6 +123,14 @@ def add_train_args(parser):
     parser.add_argument('--seed', default=42, type=int,
             help='seed for initializing training')
 
+    parser.add_argument('--div-data', action='store_true',
+            help='enable data division among GPUs for better page caching. '
+            'Only relevant if there are multiple crops in each field')
+    parser.add_argument('--div-shuffle-dist', default=1, type=float,
+            help='distance to further shuffle within each data division. '
+            'Only relevant if there are multiple crops in each field. '
+            'The order of each sample is randomly displaced by this value. '
+            'Change this to balance cache locality and stochasticity')
     parser.add_argument('--dist-backend', default='nccl', type=str,
             choices=['gloo', 'nccl'], help='distributed backend')
     parser.add_argument('--log-interval', default=100, type=int,
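
For reference, a hedged, stand-alone sketch of how the two new training flags behave once parsed (the parser here is illustrative; the real flags are defined in `add_train_args` above):

    import argparse

    # Illustrative parser carrying only the two flags added in this hunk.
    parser = argparse.ArgumentParser()
    parser.add_argument('--div-data', action='store_true')
    parser.add_argument('--div-shuffle-dist', default=1, type=float)

    args = parser.parse_args(['--div-data', '--div-shuffle-dist', '2'])
    assert args.div_data is True
    assert args.div_shuffle_dist == 2.0
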
@@ -1,2 +1,2 @@
 from .fields import FieldDataset
-from .sampler import GroupedRandomSampler
+from .sampler import DistFieldSampler
@@ -119,14 +119,16 @@ class FieldDataset(Dataset):
                 'only support integer upsampling'
         self.scale_factor = scale_factor
 
+        self.nsample = self.nfile * self.ncrop
+
     def __len__(self):
-        return self.nfile * self.ncrop
+        return self.nsample
 
     def __getitem__(self, idx):
         ifile, icrop = divmod(idx, self.ncrop)
 
-        in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
-        tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]
+        in_fields = [np.load(f) for f in self.in_files[ifile]]
+        tgt_fields = [np.load(f) for f in self.tgt_files[ifile]]
 
         anchor = self.anchors[icrop]
 
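
`FieldDataset` now caches `nsample = nfile * ncrop`, while `__getitem__` keeps mapping a flat index to a (file, crop) pair with `divmod`. A small worked example of that mapping, assuming 3 files with 4 crops each:

    nfile, ncrop = 3, 4
    nsample = nfile * ncrop  # 12 samples in total

    # flat sample index -> (file index, crop index), as in __getitem__ above
    for idx in (0, 5, 11):
        ifile, icrop = divmod(idx, ncrop)
        print(idx, '->', (ifile, icrop))
    # 0 -> (0, 0), 5 -> (1, 1), 11 -> (2, 3)
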
@@ -184,7 +186,6 @@ def crop(fields, anchor, crop, pad, size):
             ind.append(i)
 
         x = x[tuple(ind)]
-        x.setflags(write=True) # workaround numpy bug before 1.18
 
         new_fields.append(x)
 
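
Dropping the `setflags` workaround goes together with the `__getitem__` change above: a plain `np.load` returns a writable in-memory array, whereas the old read-only memmap could leave cropped copies flagged non-writable on old numpy. A small check, as an illustration only (it writes a throwaway `field.npy`):

    import numpy as np

    np.save('field.npy', np.arange(12.).reshape(3, 4))

    y = np.load('field.npy')                        # plain load: in-memory, writable
    box = y[np.ix_(np.arange(2), np.arange(1, 3))]  # fancy indexing copies, like crop()
    assert box.flags.writeable                      # no setflags workaround needed

    z = np.load('field.npy', mmap_mode='r')         # memmap is opened read-only
    assert not z.flags.writeable
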
@@ -1,31 +1,77 @@
-from itertools import chain
-
 import torch
+import torch.distributed as dist
 from torch.utils.data import Sampler
 
 
-class GroupedRandomSampler(Sampler):
-    """Sample randomly within each group of samples and sequentially from group
-    to group.
+class DistFieldSampler(Sampler):
+    """Distributed sampler for field data, useful for multiple crops
 
-    This behaves like a simple random sampler by default
+    Stochastic training on fields with multiple crops puts burden on the IO.
+    A node may load files of the whole field but only need a small part of it.
+    Numpy memmap can load part of the field, but can also be very slow (even
+    slower than reading the whole thing)
+
+    `div_data` enables data file division among GPUs when `shuffle=True`.
+    For field with multiple crops, it helps IO by benefiting from the page
+    cache, but limits stochasticity.
+    Increase `div_shuffle_dist` can mitigate this by shuffling the order of
+    samples within the specified distance.
+
+    When `div_data=False` this sampler behaves similar to `DistributedSampler`,
+    except for the chunky (rather than strided) subsample slicing.
+    Like `DistributedSampler`, `set_epoch()` should be called at the beginning
+    of each epoch during training.
     """
-    def __init__(self, data_source, group_size=None):
-        self.data_source = data_source
-        self.sample_size = len(data_source)
+    def __init__(self, dataset, shuffle,
+                 div_data=False, div_shuffle_dist=0):
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
 
-        if group_size is None:
-            group_size = self.sample_size
-        self.group_size = group_size
+        self.dataset = dataset
+        self.nsample = len(dataset)
+        self.nfile = dataset.nfile
+        self.ncrop = dataset.ncrop
+
+        self.shuffle = shuffle
+        self.div_data = div_data
+        self.div_shuffle_dist = div_shuffle_dist
 
     def __iter__(self):
-        starts = range(0, self.sample_size, self.group_size)
-        sizes = [self.group_size] * (len(starts) - 1)
-        sizes.append(self.sample_size - starts[-1])
-
-        return iter(chain(*[
-            (start + torch.randperm(size)).tolist()
-            for start, size in zip(starts, sizes)
-        ]))
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+
+        if self.shuffle:
+            if self.div_data:
+                # shuffle files
+                ind = torch.randperm(self.nfile, generator=g)
+                ind = ind[:, None] * self.ncrop + torch.arange(self.ncrop)
+                ind = ind.flatten()
+
+                # displace crops with respect to files
+                dis = torch.rand((self.nfile, self.ncrop),
+                                 generator=g) * self.div_shuffle_dist
+                loc = torch.arange(self.nfile)
+                loc = loc[:, None] + dis
+                loc = loc.flatten() % self.nfile # periodic in files
+                dis_ind = loc.argsort()
+
+                # shuffle crops
+                ind = ind[dis_ind].tolist()
+            else:
+                ind = torch.randperm(self.nsample, generator=g).tolist()
+        else:
+            ind = list(range(self.nsample))
+
+        start = self.rank * len(self)
+        stop = start + len(self)
+        ind = ind[start:stop]
+
+        return iter(ind)
 
     def __len__(self):
-        return self.sample_size
+        return self.nsample // self.world_size
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
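
To see what the `div_data` branch produces without a distributed setup, here is a hedged, self-contained sketch of the same index construction as a plain function (the chunky per-rank slicing at the end mimics `__iter__`'s final step):

    import torch

    def div_data_indices(nfile, ncrop, div_shuffle_dist, seed=0):
        # mirrors the div_data branch of DistFieldSampler.__iter__ (sketch)
        g = torch.Generator()
        g.manual_seed(seed)

        # shuffle files, keeping each file's crops contiguous
        ind = torch.randperm(nfile, generator=g)
        ind = ind[:, None] * ncrop + torch.arange(ncrop)
        ind = ind.flatten()

        # displace crops with respect to files, periodic in files
        dis = torch.rand((nfile, ncrop), generator=g) * div_shuffle_dist
        loc = torch.arange(nfile)[:, None] + dis
        loc = loc.flatten() % nfile
        return ind[loc.argsort()].tolist()

    ind = div_data_indices(nfile=4, ncrop=3, div_shuffle_dist=1)

    # chunky (contiguous) per-rank slicing, as opposed to strided
    world_size = 2
    per_rank = len(ind) // world_size
    rank0, rank1 = ind[:per_rank], ind[per_rank:2 * per_rank]
    print(rank0, rank1)
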
@@ -10,11 +10,10 @@ import torch.optim as optim
 import torch.distributed as dist
 from torch.multiprocessing import spawn
 from torch.nn.parallel import DistributedDataParallel
-from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
-from .data import FieldDataset, GroupedRandomSampler
+from .data import FieldDataset, DistFieldSampler
 from .data.figures import plt_slices
 from . import models
 from .models import narrow_cast, resample, Lag2Eul
@@ -72,7 +71,9 @@ def gpu_worker(local_rank, node, args):
         pad=args.pad,
         scale_factor=args.scale_factor,
     )
-    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    train_sampler = DistFieldSampler(train_dataset, shuffle=True,
+                                     div_data=args.div_data,
+                                     div_shuffle_dist=args.div_shuffle_dist)
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
@@ -100,7 +101,9 @@ def gpu_worker(local_rank, node, args):
         pad=args.pad,
         scale_factor=args.scale_factor,
     )
-    val_sampler = DistributedSampler(val_dataset, shuffle=False)
+    val_sampler = DistFieldSampler(val_dataset, shuffle=False,
+                                   div_data=args.div_data,
+                                   div_shuffle_dist=args.div_shuffle_dist)
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
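
Because `DistFieldSampler` reshuffles from `self.epoch`, the training loop is expected to call `set_epoch()` every epoch, as the docstring notes. A hedged, single-process sketch of the wiring (a one-rank gloo group and a dummy dataset stand in for the real fields; the import path assumes the package installed as `map2map`):

    import os

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset

    from map2map.data import DistFieldSampler  # import path as in the diff above

    # one-rank gloo group stands in for the multi-GPU setup
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    # dummy dataset exposing the nfile/ncrop attributes the sampler reads
    dataset = TensorDataset(torch.arange(8.)[:, None])
    dataset.nfile, dataset.ncrop = 4, 2

    sampler = DistFieldSampler(dataset, shuffle=True,
                               div_data=True, div_shuffle_dist=1)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseed the deterministic shuffle each epoch
        for x, in loader:
            pass

    dist.destroy_process_group()
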
@@ -300,6 +303,9 @@ def train(epoch, loader, model, lag2eul, criterion,
                     global_step=batch)
             logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
                     global_step=batch)
+            logger.add_scalar('loss/batch/train/lxe',
+                    lag_loss.item() * eul_loss.item(),
+                    global_step=batch)
 
             logger.add_scalar('grad/lag/first', lag_grads[0],
                     global_step=batch)
@@ -317,6 +323,8 @@ def train(epoch, loader, model, lag2eul, criterion,
                 global_step=epoch+1)
         logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
                 global_step=epoch+1)
+        logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
+                global_step=epoch+1)
 
         logger.add_figure('fig/epoch/train', plt_slices(
             input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
@@ -365,6 +373,8 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
                 global_step=epoch+1)
         logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
                 global_step=epoch+1)
+        logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
+                global_step=epoch+1)
 
         logger.add_figure('fig/epoch/val', plt_slices(
             input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
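
The new `lxe` scalars are simply the product of the lagrangian and eulerian losses: per batch via `lag_loss.item() * eul_loss.item()`, per epoch via `epoch_loss.prod()`. A tiny sketch of that bookkeeping, assuming `epoch_loss` holds the two epoch means as [lag, eul]:

    import torch

    epoch_loss = torch.tensor([0.5, 0.2])  # assumed layout: [lag, eul]

    lag, eul = epoch_loss.tolist()
    lxe = epoch_loss.prod()                # what loss/epoch/*/lxe would log
    assert torch.isclose(lxe, torch.tensor(lag * eul))
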
setup.py (9 changed lines)
@@ -1,5 +1,4 @@
-from setuptools import setup
-from setuptools import find_packages
+from setuptools import setup, find_packages
 
 setup(
     name='map2map',
@@ -8,6 +7,9 @@ setup(
     author='Yin Li et al.',
     author_email='eelregit@gmail.com',
     packages=find_packages(),
+    scripts=[
+        'scripts/m2m.py',
+    ],
     python_requires='>=3.6',
     install_requires=[
         'torch>=1.2',
@@ -17,7 +19,4 @@
     extras_require={
         "visualization": ["tensorboard"],
     },
-    scripts=[
-        'scripts/m2m.py',
-    ]
 )
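
Taken together, the three setup.py hunks merge the setuptools imports and move `scripts` next to `packages`. A sketch of the resulting file, with the unchanged tail of `install_requires` elided:

    from setuptools import setup, find_packages

    setup(
        name='map2map',
        author='Yin Li et al.',
        author_email='eelregit@gmail.com',
        packages=find_packages(),
        scripts=[
            'scripts/m2m.py',
        ],
        python_requires='>=3.6',
        install_requires=[
            'torch>=1.2',
            # ...
        ],
        extras_require={
            "visualization": ["tensorboard"],
        },
    )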