Add training
parent 6015dd6b31
commit 88bfd11594
0 map2map/__init__.py Normal file
76 map2map/args.py Normal file
@@ -0,0 +1,76 @@
from argparse import ArgumentParser


def get_args():
    parser = ArgumentParser(description='Transform field(s) to field(s)')

    subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
    train_parser = subparsers.add_parser('train')
    test_parser = subparsers.add_parser('test')

    add_train_args(train_parser)
    add_test_args(test_parser)

    args = parser.parse_args()

    return args


def add_common_args(parser):
    parser.add_argument('--in-channels', type=int, required=True,
            help='number of input channels')
    parser.add_argument('--out-channels', type=int, required=True,
            help='number of output or target channels')
    parser.add_argument('--norms', type=str_list, help='comma-sep. list '
            'of normalization functions from map2map.data.norms')
    parser.add_argument('--criterion', default='MSELoss',
            help='model criterion from torch.nn')
    parser.add_argument('--load-state', default='', type=str,
            help='path to load model, optimizer, rng, etc.')


def add_train_args(parser):
    add_common_args(parser)

    parser.add_argument('--train-in-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for training input data')
    parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for training target data')
    parser.add_argument('--val-in-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for validation input data')
    parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for validation target data')
    parser.add_argument('--epochs', default=128, type=int,
            help='total number of epochs to run')
    parser.add_argument('--batches-per-gpu', default=8, type=int,
            help='mini-batch size per GPU')
    parser.add_argument('--loader-workers-per-gpu', default=4, type=int,
            help='number of data loading workers per GPU')
    parser.add_argument('--augment', action='store_true',
            help='enable training data augmentation')
    parser.add_argument('--optimizer', default='Adam',
            help='optimizer from torch.optim')
    parser.add_argument('--lr', default=0.001, type=float,
            help='initial learning rate')
    # parser.add_argument('--momentum', default=0.9, type=float,
    #         help='momentum')
    # parser.add_argument('--weight-decay', default=1e-4, type=float,
    #         help='weight decay')
    parser.add_argument('--dist-backend', default='nccl', type=str,
            choices=['gloo', 'nccl'], help='distributed backend')
    parser.add_argument('--seed', default=42, type=int,
            help='seed for initializing training')
    parser.add_argument('--log-interval', default=20, type=int,
            help='interval between logging training loss')


def add_test_args(parser):
    add_common_args(parser)

    parser.add_argument('--test-in-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for test input data')
    parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for test target data')


def str_list(s):
    return s.split(',')
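As a quick sanity check of the interface, here is a minimal sketch of parsing a fake command line; all paths below are placeholders, not data from this commit.

import sys
from map2map.args import get_args

# get_args() reads sys.argv directly, so fake a command line first;
# the patterns are placeholders
sys.argv = ['m2m.py', 'train',
            '--in-channels', '3', '--out-channels', '3',
            '--norms', 'cosmology.dis',
            '--train-in-patterns', '/train/in_*.npy',
            '--train-tgt-patterns', '/train/tgt_*.npy',
            '--val-in-patterns', '/val/in_*.npy',
            '--val-tgt-patterns', '/val/tgt_*.npy']

args = get_args()
print(args.mode)   # 'train'
print(args.norms)  # ['cosmology.dis'], comma-split by str_list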
1 map2map/data/__init__.py Normal file
@@ -0,0 +1 @@
from .fields import FieldDataset
101 map2map/data/fields.py Normal file
@@ -0,0 +1,101 @@
from glob import glob
import numpy as np
import torch
from torch.utils.data import Dataset

from . import norms


class FieldDataset(Dataset):
    """Dataset of lists of fields.

    `in_patterns` is a list of glob patterns for the input fields.
    For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
    Likewise `tgt_patterns` is for target fields.
    Input and target samples of all fields are matched by sorting the globbed files.

    Data augmentations are supported for scalar and vector fields.

    `normalize` can be a list of callables to normalize each field.
    """
    def __init__(self, in_patterns, tgt_patterns, augment=False,
                 normalize=None, **kwargs):
        in_file_lists = [sorted(glob(p)) for p in in_patterns]
        self.in_files = list(zip(*in_file_lists))

        tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
        self.tgt_files = list(zip(*tgt_file_lists))

        assert len(self.in_files) == len(self.tgt_files), \
                'input and target sample sizes do not match'

        self.augment = augment

        self.normalize = normalize
        if self.normalize is not None:
            assert len(in_patterns) == len(self.normalize), \
                    'numbers of normalization callables and input fields do not match'

        # self.__dict__.update(kwargs)

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, idx):
        in_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
                     for f in self.in_files[idx]]
        tgt_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
                      for f in self.tgt_files[idx]]

        if self.augment:
            flip_axes = torch.randint(2, (3,), dtype=torch.bool)
            flip_axes = torch.arange(3)[flip_axes]

            flip3d(in_fields, flip_axes)
            flip3d(tgt_fields, flip_axes)

            perm_axes = torch.randperm(3)

            perm3d(in_fields, perm_axes)
            perm3d(tgt_fields, perm_axes)

        if self.normalize is not None:
            def get_norm(path):
                path = path.split('.')
                norm = norms
                while path:
                    norm = norm.__dict__[path.pop(0)]
                return norm

            for norm, ifield, tfield in zip(self.normalize, in_fields, tgt_fields):
                if isinstance(norm, str):
                    norm = get_norm(norm)

                norm(ifield)
                norm(tfield)

        in_fields = torch.cat(in_fields, dim=0)
        tgt_fields = torch.cat(tgt_fields, dim=0)

        return in_fields, tgt_fields


def flip3d(fields, axes):
    for i, x in enumerate(fields):
        if x.size(0) == 3:  # flip vector components
            x[axes] = -x[axes]

        # shift by 1 to skip the channel dimension; use a separate name
        # so `axes` stays a tensor for the next field in the list
        spatial_axes = (1 + axes).tolist()
        x = torch.flip(x, spatial_axes)

        fields[i] = x


def perm3d(fields, axes):
    for i, x in enumerate(fields):
        if x.size(0) == 3:  # permute vector components
            x = x[axes]

        # prepend the channel dimension, keeping `axes` a tensor
        spatial_axes = [0] + (1 + axes).tolist()
        x = x.permute(spatial_axes)

        fields[i] = x
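A minimal sketch of the augmentation helpers above, on a made-up 3-component vector field; shapes and values are for illustration only.

import torch
from map2map.data.fields import flip3d, perm3d

# one fake vector field of shape (3, 2, 2, 2): 3 components on a 2^3 mesh
fields = [torch.arange(24.).reshape(3, 2, 2, 2)]

flip3d(fields, torch.tensor([0]))        # reverse spatial axis x and negate
                                         # vector component 0 along with it
perm3d(fields, torch.tensor([2, 0, 1]))  # permute axes x,y,z -> z,x,y and the
                                         # vector components to match
print(fields[0].shape)                   # torch.Size([3, 2, 2, 2])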
1 map2map/data/norms/__init__.py Normal file
@@ -0,0 +1 @@
from . import cosmology
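Strings passed via `--norms`, e.g. `cosmology.dis`, are resolved against this package by `get_norm` in `fields.py`; a minimal sketch of the same dotted lookup:

from map2map.data import norms

norm = norms
for name in 'cosmology.dis'.split('.'):
    norm = norm.__dict__[name]

# norm is now cosmology.dis, which rescales its input tensor in place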
56 map2map/data/norms/cosmology.py Normal file
@@ -0,0 +1,56 @@
import numpy as np
from scipy.special import hyp2f1


def dis(x, undo=False):
    z = 0  # FIXME
    dis_norm = 6 * D(z)  # [Mpc/h]

    if not undo:
        dis_norm = 1 / dis_norm

    x *= dis_norm


def vel(x, undo=False):
    z = 0  # FIXME
    vel_norm = 6 * D(z) * H(z) * f(z) / (1 + z)  # [km/s]

    if not undo:
        vel_norm = 1 / vel_norm

    x *= vel_norm


def den(x, undo=False):
    raise NotImplementedError
    z = 0  # FIXME
    den_norm = 0  # FIXME

    if not undo:
        den_norm = 1 / den_norm

    x *= den_norm


def D(z, Om=0.31):
    """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
    """
    OL = 1 - Om
    a = 1 / (1+z)
    return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \
             / hyp2f1(1, 1/3, 11/6, - OL / Om)


def f(z, Om=0.31):
    """linear growth rate for flat LambdaCDM
    """
    OL = 1 - Om
    a = 1 / (1+z)
    aa3 = OL * a**3 / Om
    return 1 - 6/11*aa3 * hyp2f1(2, 4/3, 17/6, -aa3) \
                        / hyp2f1(1, 1/3, 11/6, -aa3)


def H(z, Om=0.31):
    """Hubble in [h km/s/Mpc] for flat LambdaCDM
    """
    OL = 1 - Om
    a = 1 / (1+z)
    return 100 * np.sqrt(Om / a**3 + OL)
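The growth functions can be sanity-checked at redshift zero; f(0) should sit near the familiar Om**0.55 approximation.

from map2map.data.norms.cosmology import D, f, H

print(D(0))  # 1.0 by normalization
print(H(0))  # 100.0, i.e. H0 in h km/s/Mpc
print(f(0))  # ~0.52 for Om=0.31, close to Om**0.55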
13 map2map/main.py Normal file
@@ -0,0 +1,13 @@
from .args import get_args
from . import train
from . import test


def main():
    args = get_args()

    if args.mode == 'train':
        train.node_worker(args)
    elif args.mode == 'test':
        pass  # test mode is not implemented yet
2 map2map/models/__init__.py Normal file
@@ -0,0 +1,2 @@
from .unet import UNet
from .conv import narrow_like
68 map2map/models/conv.py Normal file
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Convolution blocks of the form specified by `seq`.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBAC'):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        if mid_channels is None:
            mid_channels = max(in_channels, out_channels)
        self.mid_channels = mid_channels

        self.bn_channels = in_channels
        self.idx_conv = 0
        self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])

        layers = [self._get_layer(l) for l in seq]

        self.convs = nn.Sequential(*layers)

    def _get_layer(self, l):
        if l == 'U':
            in_channels, out_channels = self._setup_conv()
            return nn.ConvTranspose3d(in_channels, out_channels, 2, stride=2)
        elif l == 'D':
            in_channels, out_channels = self._setup_conv()
            return nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif l == 'C':
            in_channels, out_channels = self._setup_conv()
            return nn.Conv3d(in_channels, out_channels, 3)
        elif l == 'B':
            return nn.BatchNorm3d(self.bn_channels)
        elif l == 'A':
            return nn.LeakyReLU(inplace=True)
        else:
            raise NotImplementedError('layer type {} not supported'.format(l))

    def _setup_conv(self):
        self.idx_conv += 1

        in_channels = out_channels = self.mid_channels
        if self.idx_conv == 1:
            in_channels = self.in_channels
        if self.idx_conv == self.num_conv:
            out_channels = self.out_channels

        self.bn_channels = out_channels

        return in_channels, out_channels

    def forward(self, x):
        return self.convs(x)


def narrow_like(a, b):
    """Narrow a to be like b.

    Try to be symmetric but cut more on the right for odd difference,
    consistent with the downsampling.
    """
    for dim in range(2, 5):
        width = a.size(dim) - b.size(dim)
        half_width = width // 2
        a = a.narrow(dim, half_width, a.size(dim) - width)
    return a
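A minimal shape sketch for ConvBlock and narrow_like, with made-up sizes; the default 'CBAC' sequence applies two valid 3^3 convolutions, each trimming 2 voxels per dimension.

import torch
from map2map.models.conv import ConvBlock, narrow_like

block = ConvBlock(3, 64)               # default seq 'CBAC'
x = torch.zeros(2, 3, 16, 16, 16)
print(block(x).shape)                  # torch.Size([2, 64, 12, 12, 12])

a = torch.zeros(1, 1, 10, 10, 10)
b = torch.zeros(1, 1, 7, 7, 7)
print(narrow_like(a, b).shape)         # torch.Size([1, 1, 7, 7, 7]); the odd
                                       # difference of 3 cuts 1 voxel on the
                                       # left and 2 on the right of each dim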
53 map2map/models/unet.py Normal file
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn

from .conv import ConvBlock, narrow_like


class DownBlock(ConvBlock):
    def __init__(self, in_channels, out_channels, seq='BADBA'):
        super().__init__(in_channels, out_channels, seq=seq)


class UpBlock(ConvBlock):
    def __init__(self, in_channels, out_channels, seq='BAUBA'):
        super().__init__(in_channels, out_channels, seq=seq)


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
        self.down_0l = DownBlock(64, 64)
        self.conv_1l = ConvBlock(64, 64)
        self.down_1l = DownBlock(64, 64)

        self.conv_2c = ConvBlock(64, 64)

        self.up_1r = UpBlock(64, 64)
        self.conv_1r = ConvBlock(128, 64)
        self.up_0r = UpBlock(64, 64)
        self.conv_0r = ConvBlock(128, out_channels, seq='CAC')

    def forward(self, x):
        y0 = self.conv_0l(x)
        x = self.down_0l(y0)

        y1 = self.conv_1l(x)
        x = self.down_1l(y1)

        x = self.conv_2c(x)

        x = self.up_1r(x)
        y1 = narrow_like(y1, x)
        x = torch.cat([y1, x], dim=1)
        del y1
        x = self.conv_1r(x)

        x = self.up_0r(x)
        y0 = narrow_like(y0, x)
        x = torch.cat([y0, x], dim=1)
        del y0
        x = self.conv_0r(x)

        return x
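Because every convolution is unpadded, the output field is smaller than the input; a minimal CPU sketch with a made-up 64^3 input.

import torch
from map2map.models import UNet

model = UNet(3, 3).eval()  # eval mode: BatchNorm uses running stats
x = torch.zeros(1, 3, 64, 64, 64)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 24, 24, 24]): the 64^3 field shrinks to
                # 24^3, which is why targets are narrowed before the loss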
8 map2map/test.py Normal file
@@ -0,0 +1,8 @@
import os

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from .data import FieldDataset
from .models import UNet, narrow_like
194 map2map/train.py Normal file
@@ -0,0 +1,194 @@
import os
import shutil

import torch
from torch.multiprocessing import spawn
from torch.distributed import init_process_group, destroy_process_group, all_reduce
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from .data import FieldDataset
from .models import UNet, narrow_like


def node_worker(args):
    torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

    args.gpus_per_node = torch.cuda.device_count()
    args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
    args.world_size = args.gpus_per_node * args.nodes

    node = int(os.environ['SLURM_NODEID'])
    if node == 0:
        print(args)
    args.node = node

    spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node)


def gpu_worker(local_rank, args):
    args.device = torch.device('cuda', local_rank)
    torch.cuda.set_device(args.device)  # bind this process to its GPU

    args.rank = args.gpus_per_node * args.node + local_rank

    init_process_group(
        backend=args.dist_backend,
        init_method='env://',
        world_size=args.world_size,
        rank=args.rank,
    )

    train_dataset = FieldDataset(
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        augment=args.augment,
        normalize=args.norms,
    )
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batches_per_gpu,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.loader_workers_per_gpu,
        pin_memory=True,
    )

    val_dataset = FieldDataset(
        in_patterns=args.val_in_patterns,
        tgt_patterns=args.val_tgt_patterns,
        augment=False,
        normalize=args.norms,
    )
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batches_per_gpu,
        shuffle=False,
        sampler=val_sampler,
        num_workers=args.loader_workers_per_gpu,
        pin_memory=True,
    )

    model = UNet(args.in_channels, args.out_channels)
    model.to(args.device)
    model = DistributedDataParallel(model, device_ids=[args.device])

    criterion = torch.nn.__dict__[args.criterion]()
    criterion.to(args.device)

    optimizer = torch.optim.__dict__[args.optimizer](
        model.parameters(),
        lr=args.lr,
        #momentum=args.momentum,
        #weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    if args.load_state:
        checkpoint = torch.load(args.load_state, map_location=args.device)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        torch.set_rng_state(checkpoint['rng'].cpu())  # move rng state back
        if args.rank == 0:
            min_loss = checkpoint['min_loss']
            print('checkpoint of epoch {} loaded from {}'.format(
                checkpoint['epoch'], args.load_state))
        del checkpoint
    else:
        args.start_epoch = 0
        if args.rank == 0:
            min_loss = None

    torch.backends.cudnn.benchmark = True  # NOTE: test perf

    if args.rank == 0:
        args.logger = SummaryWriter()
        hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
                else str(v) for k, v in vars(args).items()}
        args.logger.add_hparams(hparam_dict=hparam, metric_dict={})

    for epoch in range(args.start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        train(epoch, train_loader, model, criterion, optimizer, args)

        val_loss = validate(epoch, val_loader, model, criterion, args)

        scheduler.step(val_loss)

        if args.rank == 0:
            args.logger.close()

            checkpoint = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'rng': torch.get_rng_state(),
                'min_loss': min_loss,
            }
            filename = 'checkpoint.pth'
            torch.save(checkpoint, filename)
            del checkpoint

            if min_loss is None or val_loss < min_loss:
                min_loss = val_loss
                shutil.copyfile(filename, 'best_model.pth')

    destroy_process_group()


def train(epoch, loader, model, criterion, optimizer, args):
    model.train()

    for i, (input, target) in enumerate(loader):
        input = input.to(args.device, non_blocking=True)
        target = target.to(args.device, non_blocking=True)

        output = model(input)
        target = narrow_like(target, output)  # valid convs shrink the output

        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch = epoch * len(loader) + i
        if batch % args.log_interval == 0:
            all_reduce(loss)
            loss /= args.world_size
            if args.rank == 0:
                args.logger.add_scalar('loss/train', loss.item(), global_step=batch)

    # f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')


def validate(epoch, loader, model, criterion, args):
    model.eval()

    loss = 0

    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            input = input.to(args.device, non_blocking=True)
            target = target.to(args.device, non_blocking=True)

            output = model(input)
            target = narrow_like(target, output)

            loss += criterion(output, target)

    all_reduce(loss)
    loss /= len(loader) * args.world_size
    if args.rank == 0:
        args.logger.add_scalar('loss/val', loss.item(), global_step=epoch)

    # f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')

    return loss.item()
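A minimal sketch of loading the resulting checkpoint offline; note the model state was saved through DistributedDataParallel, so its keys carry a 'module.' prefix.

import torch
from map2map.models import UNet

ckpt = torch.load('checkpoint.pth', map_location='cpu')
print(ckpt['epoch'], ckpt['min_loss'])

# strip the DistributedDataParallel 'module.' prefix to load a bare UNet;
# the channel numbers must match the trained configuration
state = {k.replace('module.', '', 1): v for k, v in ckpt['model'].items()}
model = UNet(3, 3)
model.load_state_dict(state)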
0 map2map/utils/__init__.py Normal file
53 scripts/dis2dis.slurm Normal file
@@ -0,0 +1,53 @@
#!/bin/bash

#SBATCH --job-name=dis2dis
#SBATCH --dependency=singleton
#SBATCH --output=%x-%j.out
#SBATCH --error=%x-%j.err

#SBATCH --partition=gpu
#SBATCH --gres=gpu:v100-32gb:4

#SBATCH --exclusive
#SBATCH --nodes=2
#SBATCH --mem=0
#SBATCH --time=2-00:00:00


hostname; pwd; date


module load gcc openmpi2
module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1

source $HOME/anaconda3/bin/activate torch


export MASTER_ADDR=$HOSTNAME
export MASTER_PORT=8888


data_root_dir="/mnt/ceph/users/yinli/Quijote"

in_dir="linear"
tgt_dir="nonlin"

train_dirs="*[1-9]"
val_dirs="*[1-9]0"

files="dis/128x???.npy"
in_files="$files"
tgt_files="$files"


srun m2m.py train \
    --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files" \
    --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
    --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
    --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
    --in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
    --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4
    # --load-state checkpoint.pth


date
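The train/validation split above is done entirely by the directory globs: realizations ending in a nonzero digit train, and every tenth one (ending in 0) validates. A tiny sketch of the matching, with made-up directory names:

from fnmatch import fnmatch

dirs = ['sim_8', 'sim_9', 'sim_10', 'sim_11', 'sim_20']
print([d for d in dirs if fnmatch(d, '*[1-9]')])   # ['sim_8', 'sim_9', 'sim_11']
print([d for d in dirs if fnmatch(d, '*[1-9]0')])  # ['sim_10', 'sim_20']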
5 scripts/m2m.py Normal file
@@ -0,0 +1,5 @@
#!/usr/bin/env python3
from map2map.main import main


if __name__ == '__main__':
    main()
53 scripts/vel2vel.slurm Normal file
@@ -0,0 +1,53 @@
#!/bin/bash

#SBATCH --job-name=vel2vel
#SBATCH --dependency=singleton
#SBATCH --output=%x-%j.out
#SBATCH --error=%x-%j.err

#SBATCH --partition=gpu
#SBATCH --gres=gpu:v100-32gb:4

#SBATCH --exclusive
#SBATCH --nodes=2
#SBATCH --mem=0
#SBATCH --time=2-00:00:00


hostname; pwd; date


module load gcc openmpi2
module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1

source $HOME/anaconda3/bin/activate torch


export MASTER_ADDR=$HOSTNAME
export MASTER_PORT=8888


data_root_dir="/mnt/ceph/users/yinli/Quijote"

in_dir="linear"
tgt_dir="nonlin"

train_dirs="*[1-9]"
val_dirs="*[1-9]0"

files="vel/128x???.npy"
in_files="$files"
tgt_files="$files"


srun m2m.py train \
    --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files" \
    --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
    --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
    --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
    --in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
    --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4
    # --load-state checkpoint.pth


date
20 setup.py Normal file
@@ -0,0 +1,20 @@
from setuptools import setup
from setuptools import find_packages

setup(
    name='map2map',
    version='0.0',
    description='Neural network emulators to transform field data',
    author='Yin Li',
    author_email='eelregit@gmail.com',
    packages=find_packages(),
    install_requires=[
        'torch',
        'numpy',
        'scipy',
        'tensorboard',
    ],
    scripts=[
        'scripts/m2m.py',
    ],
)