Add testing
This commit is contained in:
parent
bcf95275f3
commit
0211eed0ec
@ -21,11 +21,22 @@ def add_common_args(parser):
|
|||||||
parser.add_argument('--out-channels', type=int, required=True,
|
parser.add_argument('--out-channels', type=int, required=True,
|
||||||
help='number of output or target channels')
|
help='number of output or target channels')
|
||||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||||
'of normalization functions from map2map.data.norms')
|
'of normalization functions from data.norms')
|
||||||
parser.add_argument('--criterion', default='MSELoss',
|
parser.add_argument('--criterion', default='MSELoss',
|
||||||
help='model criterion from torch.nn')
|
help='model criterion from torch.nn')
|
||||||
parser.add_argument('--load-state', default='', type=str,
|
parser.add_argument('--load-state', default='', type=str,
|
||||||
help='path to load model, optimizer, rng, etc.')
|
help='path to load model, optimizer, rng, etc.')
|
||||||
|
parser.add_argument('--batches', default=1, type=int,
|
||||||
|
help='mini-batch size, per GPU in training or in total in testing')
|
||||||
|
parser.add_argument('--loader-workers', default=0, type=int,
|
||||||
|
help='number of data loading workers, per GPU in training or '
|
||||||
|
'in total in testing')
|
||||||
|
parser.add_argument('--pad-or-crop', default=0, type=int_tuple,
|
||||||
|
help='pad (>0) or crop (<0) the input data; '
|
||||||
|
'can be a int or a 6-tuple (by a comma-sep. list); '
|
||||||
|
'can be asymmetric to align the data with downsample '
|
||||||
|
'and upsample convolutions; '
|
||||||
|
'padding assumes periodic boundary condition')
|
||||||
|
|
||||||
|
|
||||||
def add_train_args(parser):
|
def add_train_args(parser):
|
||||||
@ -39,12 +50,8 @@ def add_train_args(parser):
|
|||||||
help='comma-sep. list of glob patterns for validation input data')
|
help='comma-sep. list of glob patterns for validation input data')
|
||||||
parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
|
parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
|
||||||
help='comma-sep. list of glob patterns for validation target data')
|
help='comma-sep. list of glob patterns for validation target data')
|
||||||
parser.add_argument('--epochs', default=128, type=int,
|
parser.add_argument('--epochs', default=1024, type=int,
|
||||||
help='total number of epochs to run')
|
help='total number of epochs to run')
|
||||||
parser.add_argument('--batches-per-gpu', default=8, type=int,
|
|
||||||
help='mini-batch size per GPU')
|
|
||||||
parser.add_argument('--loader-workers-per-gpu', default=4, type=int,
|
|
||||||
help='number of data loading workers per GPU')
|
|
||||||
parser.add_argument('--augment', action='store_true',
|
parser.add_argument('--augment', action='store_true',
|
||||||
help='enable training data augmentation')
|
help='enable training data augmentation')
|
||||||
parser.add_argument('--optimizer', default='Adam',
|
parser.add_argument('--optimizer', default='Adam',
|
||||||
@ -74,3 +81,13 @@ def add_test_args(parser):
|
|||||||
|
|
||||||
def str_list(s):
|
def str_list(s):
|
||||||
return s.split(',')
|
return s.split(',')
|
||||||
|
|
||||||
|
|
||||||
|
def int_tuple(t):
    """Parse a comma-separated string into an int or a 6-tuple of ints.

    Intended as an argparse ``type=`` callable (used for ``--pad-or-crop``).

    A single value such as ``'40'`` is returned as a plain ``int``; exactly
    six values such as ``'1,2,3,4,5,6'`` are returned as a tuple of ints.

    Raises:
        ValueError: if any item is not an integer literal, or if the number
            of items is neither 1 nor 6.
    """
    values = tuple(map(int, t.split(',')))
    count = len(values)

    if count == 1:
        # a lone value stays a scalar
        return values[0]
    if count == 6:
        return values

    raise ValueError('pad or crop size must be int or 6-tuple')
|
||||||
|
@ -3,7 +3,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from . import norms
|
from .norms import import_norm
|
||||||
|
|
||||||
|
|
||||||
class FieldDataset(Dataset):
|
class FieldDataset(Dataset):
|
||||||
@ -14,12 +14,15 @@ class FieldDataset(Dataset):
|
|||||||
Likewise `tgt_patterns` is for target fields.
|
Likewise `tgt_patterns` is for target fields.
|
||||||
Input and target samples of all fields are matched by sorting the globbed files.
|
Input and target samples of all fields are matched by sorting the globbed files.
|
||||||
|
|
||||||
|
Input fields can be padded (>0) or cropped (<0) with `pad_or_crop`.
|
||||||
|
Padding assumes periodic boundary condition.
|
||||||
|
|
||||||
Data augmentations are supported for scalar and vector fields.
|
Data augmentations are supported for scalar and vector fields.
|
||||||
|
|
||||||
`normalize` can be a list of callables to normalize each field.
|
`norms` can be a list of callables to normalize each field.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns, augment=False,
|
def __init__(self, in_patterns, tgt_patterns, pad_or_crop=0, augment=False,
|
||||||
normalize=None, **kwargs):
|
norms=None):
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
self.in_files = list(zip(* in_file_lists))
|
self.in_files = list(zip(* in_file_lists))
|
||||||
|
|
||||||
@ -29,23 +32,31 @@ class FieldDataset(Dataset):
|
|||||||
assert len(self.in_files) == len(self.tgt_files), \
|
assert len(self.in_files) == len(self.tgt_files), \
|
||||||
'input and target sample sizes do not match'
|
'input and target sample sizes do not match'
|
||||||
|
|
||||||
|
if isinstance(pad_or_crop, int):
|
||||||
|
pad_or_crop = (pad_or_crop,) * 6
|
||||||
|
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
|
||||||
|
'pad or crop size must be int or 6-tuple'
|
||||||
|
self.pad_or_crop = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
|
||||||
|
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
|
|
||||||
self.normalize = normalize
|
if norms is not None:
|
||||||
if self.normalize is not None:
|
assert len(in_patterns) == len(norms), \
|
||||||
assert len(in_patterns) == len(self.normalize), \
|
|
||||||
'numbers of normalization callables and input fields do not match'
|
'numbers of normalization callables and input fields do not match'
|
||||||
|
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
|
||||||
# self.__dict__.update(kwargs)
|
self.norms = norms
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.in_files)
|
return len(self.in_files)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
in_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
|
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||||
for f in self.in_files[idx]]
|
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||||
tgt_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
|
|
||||||
for f in self.tgt_files[idx]]
|
padcrop(in_fields, self.pad_or_crop) # with numpy
|
||||||
|
|
||||||
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
|
|
||||||
if self.augment:
|
if self.augment:
|
||||||
flip_axes = torch.randint(2, (3,), dtype=torch.bool)
|
flip_axes = torch.randint(2, (3,), dtype=torch.bool)
|
||||||
@ -59,18 +70,8 @@ class FieldDataset(Dataset):
|
|||||||
perm3d(in_fields, perm_axes)
|
perm3d(in_fields, perm_axes)
|
||||||
perm3d(tgt_fields, perm_axes)
|
perm3d(tgt_fields, perm_axes)
|
||||||
|
|
||||||
if self.normalize is not None:
|
if self.norms is not None:
|
||||||
def get_norm(path):
|
for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
|
||||||
path = path.split('.')
|
|
||||||
norm = norms
|
|
||||||
while path:
|
|
||||||
norm = norm.__dict__[path.pop(0)]
|
|
||||||
return norm
|
|
||||||
|
|
||||||
for norm, ifield, tfield in zip(self.normalize, in_fields, tgt_fields):
|
|
||||||
if isinstance(norm, str):
|
|
||||||
norm = get_norm(norm)
|
|
||||||
|
|
||||||
norm(ifield)
|
norm(ifield)
|
||||||
norm(tfield)
|
norm(tfield)
|
||||||
|
|
||||||
@ -80,6 +81,22 @@ class FieldDataset(Dataset):
|
|||||||
return in_fields, tgt_fields
|
return in_fields, tgt_fields
|
||||||
|
|
||||||
|
|
||||||
|
def padcrop(fields, width):
    """Pad or crop every array in `fields`, replacing the list items in place.

    Args:
        fields: list of numpy arrays; each entry is overwritten with its
            padded or cropped version.
        width: array-like of (before, after) pairs, one row per axis.
            `FieldDataset` builds it with shape (4, 2), row 0 being zeros.
            All entries >= 0 means pad: the whole array is passed to
            ``np.pad`` with ``mode='wrap'`` (periodic boundary).
            All entries <= 0 means crop: rows 1-3 give the number of
            elements removed from the start / end of the last three axes;
            row 0 is not used when cropping.

    Raises:
        NotImplementedError: if `width` mixes positive and negative entries
            (raised when the first field is processed).
    """
    width = np.asarray(width)

    # These checks do not depend on the individual field: hoist them.
    is_pad = bool((width >= 0).all())
    is_crop = bool((width <= 0).all())

    # Slices for the last three axes.  The stop must be None, not 0, when
    # nothing is cropped from the end: `x[lo:0]` would be an empty slice.
    crop = (Ellipsis,) + tuple(
        slice(-int(lo), int(hi) if hi < 0 else None)
        for lo, hi in width[1:4]
    )

    for i, x in enumerate(fields):
        if is_pad:
            fields[i] = np.pad(x, width, mode='wrap')
        elif is_crop:
            fields[i] = x[crop]
        else:
            raise NotImplementedError('mixed pad-and-crop not supported')
|
||||||
|
|
||||||
|
|
||||||
def flip3d(fields, axes):
|
def flip3d(fields, axes):
|
||||||
for i, x in enumerate(fields):
|
for i, x in enumerate(fields):
|
||||||
if x.size(0) == 3: # flip vector components
|
if x.size(0) == 3: # flip vector components
|
||||||
@ -90,6 +107,7 @@ def flip3d(fields, axes):
|
|||||||
|
|
||||||
fields[i] = x
|
fields[i] = x
|
||||||
|
|
||||||
|
|
||||||
def perm3d(fields, axes):
|
def perm3d(fields, axes):
|
||||||
for i, x in enumerate(fields):
|
for i, x in enumerate(fields):
|
||||||
if x.size(0) == 3: # permutate vector components
|
if x.size(0) == 3: # permutate vector components
|
||||||
|
@ -1 +1,10 @@
|
|||||||
|
from importlib import import_module
|
||||||
|
|
||||||
from . import cosmology
|
from . import cosmology
|
||||||
|
|
||||||
|
|
||||||
|
def import_norm(path):
    """Look up a normalization callable from a dotted path.

    `path` is split at its last dot into a module part and an attribute
    name, e.g. ``'cosmology.dis'`` -> module ``cosmology``, attribute
    ``dis``.  The module is imported relative to ``__name__`` and the named
    attribute is returned.

    Raises:
        ValueError: if `path` contains no dot (nothing to unpack).
    """
    module_path, attr_name = path.rsplit('.', 1)
    module = import_module('.' + module_path, package=__name__)
    return getattr(module, attr_name)
|
||||||
|
@ -10,4 +10,4 @@ def main():
|
|||||||
if args.mode == 'train':
|
if args.mode == 'train':
|
||||||
train.node_worker(args)
|
train.node_worker(args)
|
||||||
elif args.mode == 'test':
|
elif args.mode == 'test':
|
||||||
pass
|
test.test(args)
|
||||||
|
@ -4,15 +4,6 @@ import torch.nn as nn
|
|||||||
from .conv import ConvBlock, ResBlock, narrow_like
|
from .conv import ConvBlock, ResBlock, narrow_like
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(ConvBlock):
|
|
||||||
def __init__(self, in_channels, out_channels, seq='BADBA'):
|
|
||||||
super().__init__(in_channels, out_channels, seq=seq)
|
|
||||||
|
|
||||||
class UpBlock(ConvBlock):
|
|
||||||
def __init__(self, in_channels, out_channels, seq='BAUBA'):
|
|
||||||
super().__init__(in_channels, out_channels, seq=seq)
|
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1,8 +1,58 @@
|
|||||||
import os
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from .models import UNet, narrow_like
|
from .models import UNet, narrow_like
|
||||||
|
|
||||||
|
|
||||||
|
def test(args):
    """Run a trained UNet over the test set, print losses and save outputs.

    Builds a `FieldDataset` and `DataLoader` from `args`, restores the model
    weights from ``args.load_state`` onto the CPU, then for every batch
    prints the criterion loss and writes ``<batch index>.npz`` (keys
    ``input``, ``output``, ``target``) into the current working directory.

    Attributes read from `args`: ``test_in_patterns``, ``test_tgt_patterns``,
    ``norms``, ``pad_or_crop``, ``batches``, ``loader_workers``,
    ``in_channels``, ``out_channels``, ``criterion``, ``load_state``.
    """
    from collections import OrderedDict

    test_dataset = FieldDataset(
        in_patterns=args.test_in_patterns,
        tgt_patterns=args.test_tgt_patterns,
        augment=False,
        norms=args.norms,
        pad_or_crop=args.pad_or_crop,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batches,
        shuffle=False,
        num_workers=args.loader_workers,
    )

    model = UNet(args.in_channels, args.out_channels)
    criterion = torch.nn.__dict__[args.criterion]()

    device = torch.device('cpu')
    state = torch.load(args.load_state, map_location=device)

    # Checkpoint keys may carry a leading 'module.' (presumably from a
    # parallel wrapper used in training -- TODO confirm).  Strip it only when
    # it is a real prefix, so that keys merely containing 'module.' (e.g.
    # 'submodule.weight') are left intact.
    prefix = 'module.'
    model_state = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state['model'].items()
    )
    model.load_state_dict(model_state)
    print('model state at epoch {} loaded from {}'.format(
        state['epoch'], args.load_state))
    del state

    model.eval()

    # --pad-or-crop is either an int or a 6-tuple; comparing a tuple with 0
    # would raise TypeError, so normalize to a tuple first.
    pad_or_crop = args.pad_or_crop
    if isinstance(pad_or_crop, int):
        widths = (pad_or_crop,)
    else:
        widths = tuple(pad_or_crop)
    # FIXME: only "every width > 0" is treated as padded input; tuples that
    # mix zero and positive widths fall through to narrowing the target,
    # which has not been verified against narrow_like's behaviour.
    input_is_padded = all(w > 0 for w in widths)

    with torch.no_grad():
        for i, (in_batch, target) in enumerate(test_loader):
            output = model(in_batch)
            if input_is_padded:
                output = narrow_like(output, target)
            else:
                target = narrow_like(target, output)

            loss = criterion(output, target)

            print('sample {} loss: {}'.format(i, loss))

            if args.norms is not None:
                norm = test_dataset.norms[0]  # FIXME
                # only the model output is de-normalized here
                norm(output, undo=True)

            np.savez('{}.npz'.format(i), input=in_batch.numpy(),
                     output=output.numpy(), target=target.numpy())
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.multiprocessing import spawn
|
from torch.multiprocessing import spawn
|
||||||
from torch.distributed import init_process_group, destroy_process_group, all_reduce
|
from torch.distributed import init_process_group, destroy_process_group, all_reduce
|
||||||
@ -46,15 +45,16 @@ def gpu_worker(local_rank, args):
|
|||||||
in_patterns=args.train_in_patterns,
|
in_patterns=args.train_in_patterns,
|
||||||
tgt_patterns=args.train_tgt_patterns,
|
tgt_patterns=args.train_tgt_patterns,
|
||||||
augment=args.augment,
|
augment=args.augment,
|
||||||
normalize=args.norms,
|
norms=args.norms,
|
||||||
|
pad_or_crop=args.pad_or_crop,
|
||||||
)
|
)
|
||||||
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batches_per_gpu,
|
batch_size=args.batches,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
num_workers=args.loader_workers_per_gpu,
|
num_workers=args.loader_workers,
|
||||||
pin_memory=True
|
pin_memory=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -62,15 +62,16 @@ def gpu_worker(local_rank, args):
|
|||||||
in_patterns=args.val_in_patterns,
|
in_patterns=args.val_in_patterns,
|
||||||
tgt_patterns=args.val_tgt_patterns,
|
tgt_patterns=args.val_tgt_patterns,
|
||||||
augment=False,
|
augment=False,
|
||||||
normalize=args.norms,
|
norms=args.norms,
|
||||||
|
pad_or_crop=args.pad_or_crop,
|
||||||
)
|
)
|
||||||
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=args.batches_per_gpu,
|
batch_size=args.batches,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
sampler=val_sampler,
|
sampler=val_sampler,
|
||||||
num_workers=args.loader_workers_per_gpu,
|
num_workers=args.loader_workers,
|
||||||
pin_memory=True
|
pin_memory=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,17 +91,17 @@ def gpu_worker(local_rank, args):
|
|||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||||
|
|
||||||
if args.load_state:
|
if args.load_state:
|
||||||
checkpoint = torch.load(args.load_state, map_location=args.device)
|
state = torch.load(args.load_state, map_location=args.device)
|
||||||
args.start_epoch = checkpoint['epoch']
|
args.start_epoch = state['epoch']
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(state['model'])
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(state['optimizer'])
|
||||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
scheduler.load_state_dict(state['scheduler'])
|
||||||
torch.set_rng_state(checkpoint['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
min_loss = checkpoint['min_loss']
|
min_loss = state['min_loss']
|
||||||
print('checkpoint of epoch {} loaded from {}'.format(
|
print('checkpoint at epoch {} loaded from {}'.format(
|
||||||
checkpoint['epoch'], args.load_state))
|
state['epoch'], args.load_state))
|
||||||
del checkpoint
|
del state
|
||||||
else:
|
else:
|
||||||
args.start_epoch = 0
|
args.start_epoch = 0
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
@ -125,7 +126,7 @@ def gpu_worker(local_rank, args):
|
|||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
args.logger.close()
|
args.logger.close()
|
||||||
|
|
||||||
checkpoint = {
|
state = {
|
||||||
'epoch': epoch + 1,
|
'epoch': epoch + 1,
|
||||||
'model': model.state_dict(),
|
'model': model.state_dict(),
|
||||||
'optimizer' : optimizer.state_dict(),
|
'optimizer' : optimizer.state_dict(),
|
||||||
@ -134,8 +135,8 @@ def gpu_worker(local_rank, args):
|
|||||||
'min_loss': min_loss,
|
'min_loss': min_loss,
|
||||||
}
|
}
|
||||||
filename='checkpoint.pth'
|
filename='checkpoint.pth'
|
||||||
torch.save(checkpoint, filename)
|
torch.save(state, filename)
|
||||||
del checkpoint
|
del state
|
||||||
|
|
||||||
if min_loss is None or val_loss < min_loss:
|
if min_loss is None or val_loss < min_loss:
|
||||||
min_loss = val_loss
|
min_loss = val_loss
|
||||||
@ -152,7 +153,7 @@ def train(epoch, loader, model, criterion, optimizer, args):
|
|||||||
target = target.to(args.device, non_blocking=True)
|
target = target.to(args.device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
target = narrow_like(target, output)
|
target = narrow_like(target, output) # FIXME pad
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
|
||||||
@ -167,7 +168,6 @@ def train(epoch, loader, model, criterion, optimizer, args):
|
|||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
|
args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
|
||||||
|
|
||||||
# f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')
|
|
||||||
|
|
||||||
def validate(epoch, loader, model, criterion, args):
|
def validate(epoch, loader, model, criterion, args):
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -180,7 +180,7 @@ def validate(epoch, loader, model, criterion, args):
|
|||||||
target = target.to(args.device, non_blocking=True)
|
target = target.to(args.device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
target = narrow_like(target, output)
|
target = narrow_like(target, output) # FIXME pad
|
||||||
|
|
||||||
loss += criterion(output, target)
|
loss += criterion(output, target)
|
||||||
|
|
||||||
@ -189,6 +189,4 @@ def validate(epoch, loader, model, criterion, args):
|
|||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)
|
args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)
|
||||||
|
|
||||||
# f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')
|
|
||||||
|
|
||||||
return loss.item()
|
return loss.item()
|
||||||
|
48
scripts/dis2dis-test.slurm
Normal file
48
scripts/dis2dis-test.slurm
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#SBATCH --job-name=dis2dis-test
|
||||||
|
#SBATCH --output=%x-%j.out
|
||||||
|
#SBATCH --error=%x-%j.err
|
||||||
|
|
||||||
|
#SBATCH --partition=ccm
|
||||||
|
|
||||||
|
#SBATCH --exclusive
|
||||||
|
#SBATCH --nodes=1
|
||||||
|
#SBATCH --mem=0
|
||||||
|
#SBATCH --time=1-00:00:00
|
||||||
|
|
||||||
|
|
||||||
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
|
module load gcc openmpi2
|
||||||
|
module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
|
||||||
|
|
||||||
|
source $HOME/anaconda3/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
|
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
|
||||||
|
echo OMP_NUM_THREADS = $OMP_NUM_THREADS
|
||||||
|
|
||||||
|
|
||||||
|
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||||
|
|
||||||
|
in_dir="linear"
|
||||||
|
tgt_dir="nonlin"
|
||||||
|
|
||||||
|
test_dirs="0" # FIXME
|
||||||
|
|
||||||
|
files="dis/128x???.npy"
|
||||||
|
in_files="$files"
|
||||||
|
tgt_files="$files"
|
||||||
|
|
||||||
|
|
||||||
|
srun m2m.py test \
|
||||||
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
|
--in-channels 3 --out-channels 3 --norms cosmology.dis \
|
||||||
|
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||||
|
--load-state best_model.pth
|
||||||
|
|
||||||
|
|
||||||
|
date
|
@ -11,7 +11,7 @@
|
|||||||
#SBATCH --exclusive
|
#SBATCH --exclusive
|
||||||
#SBATCH --nodes=2
|
#SBATCH --nodes=2
|
||||||
#SBATCH --mem=0
|
#SBATCH --mem=0
|
||||||
#SBATCH --time=2-00:00:00
|
#SBATCH --time=7-00:00:00
|
||||||
|
|
||||||
|
|
||||||
hostname; pwd; date
|
hostname; pwd; date
|
||||||
@ -46,7 +46,7 @@ srun m2m.py train \
|
|||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
|
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
|
||||||
--epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
|
||||||
|
48
scripts/vel2vel-test.slurm
Normal file
48
scripts/vel2vel-test.slurm
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#SBATCH --job-name=vel2vel-test
|
||||||
|
#SBATCH --output=%x-%j.out
|
||||||
|
#SBATCH --error=%x-%j.err
|
||||||
|
|
||||||
|
#SBATCH --partition=ccm
|
||||||
|
|
||||||
|
#SBATCH --exclusive
|
||||||
|
#SBATCH --nodes=1
|
||||||
|
#SBATCH --mem=0
|
||||||
|
#SBATCH --time=1-00:00:00
|
||||||
|
|
||||||
|
|
||||||
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
|
module load gcc openmpi2
|
||||||
|
module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
|
||||||
|
|
||||||
|
source $HOME/anaconda3/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
|
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
|
||||||
|
echo OMP_NUM_THREADS = $OMP_NUM_THREADS
|
||||||
|
|
||||||
|
|
||||||
|
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||||
|
|
||||||
|
in_dir="linear"
|
||||||
|
tgt_dir="nonlin"
|
||||||
|
|
||||||
|
test_dirs="0" # FIXME
|
||||||
|
|
||||||
|
files="vel/128x???.npy"
|
||||||
|
in_files="$files"
|
||||||
|
tgt_files="$files"
|
||||||
|
|
||||||
|
|
||||||
|
srun m2m.py test \
|
||||||
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
|
--in-channels 3 --out-channels 3 --norms cosmology.vel \
|
||||||
|
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||||
|
--load-state best_model.pth
|
||||||
|
|
||||||
|
|
||||||
|
date
|
@ -11,7 +11,7 @@
|
|||||||
#SBATCH --exclusive
|
#SBATCH --exclusive
|
||||||
#SBATCH --nodes=2
|
#SBATCH --nodes=2
|
||||||
#SBATCH --mem=0
|
#SBATCH --mem=0
|
||||||
#SBATCH --time=2-00:00:00
|
#SBATCH --time=7-00:00:00
|
||||||
|
|
||||||
|
|
||||||
hostname; pwd; date
|
hostname; pwd; date
|
||||||
@ -46,7 +46,7 @@ srun m2m.py train \
|
|||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
|
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
|
||||||
--epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user