chore: Blackify the source code for pretty formatting
commit 7361c3eb9d
parent ec9732df21

map2map/args.py: 397 lines changed
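The diff below is purely mechanical: the source was run through the Black code formatter, which normalizes string quotes to double quotes, reflows long calls to its default 88-character line length, and standardizes blank lines. A minimal sketch of the effect (illustrative only, not part of this commit; assumes the `black` package is installed):

    import black

    # Black rewrites single quotes to double quotes and reflows long lines.
    src = "x = {'mode': 'train', 'lr': 1e-3}\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # prints: x = {"mode": "train", "lr": 1e-3}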
@@ -7,18 +7,16 @@ from .train import ckpt_link


 def get_args():
-    """Parse arguments and set runtime defaults.
-    """
-    parser = argparse.ArgumentParser(
-        description='Transform field(s) to field(s)')
+    """Parse arguments and set runtime defaults."""
+    parser = argparse.ArgumentParser(description="Transform field(s) to field(s)")

-    subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
+    subparsers = parser.add_subparsers(title="modes", dest="mode", required=True)
     train_parser = subparsers.add_parser(
-        'train',
+        "train",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     test_parser = subparsers.add_parser(
-        'test',
+        "test",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
@@ -27,163 +25,293 @@ def get_args():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.mode == 'train':
|
if args.mode == "train":
|
||||||
set_train_args(args)
|
set_train_args(args)
|
||||||
elif args.mode == 'test':
|
elif args.mode == "test":
|
||||||
set_test_args(args)
|
set_test_args(args)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def add_common_args(parser):
|
def add_common_args(parser):
|
||||||
parser.add_argument('--in-norms', type=str_list, help='comma-sep. list '
|
parser.add_argument(
|
||||||
'of input normalization functions')
|
"--in-norms",
|
||||||
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
type=str_list,
|
||||||
'of target normalization functions')
|
help="comma-sep. list " "of input normalization functions",
|
||||||
parser.add_argument('--crop', type=int_tuple,
|
)
|
||||||
help='size to crop the input and target data. Default is the '
|
parser.add_argument(
|
||||||
'field size. Comma-sep. list of 1 or d integers')
|
"--tgt-norms",
|
||||||
parser.add_argument('--crop-start', type=int_tuple,
|
type=str_list,
|
||||||
help='starting point of the first crop. Default is the origin. '
|
help="comma-sep. list " "of target normalization functions",
|
||||||
'Comma-sep. list of 1 or d integers')
|
)
|
||||||
parser.add_argument('--crop-stop', type=int_tuple,
|
parser.add_argument(
|
||||||
help='stopping point of the last crop. Default is the opposite '
|
"--crop",
|
||||||
'corner to the origin. Comma-sep. list of 1 or d integers')
|
type=int_tuple,
|
||||||
parser.add_argument('--crop-step', type=int_tuple,
|
help="size to crop the input and target data. Default is the "
|
||||||
help='spacing between crops. Default is the crop size. '
|
"field size. Comma-sep. list of 1 or d integers",
|
||||||
'Comma-sep. list of 1 or d integers')
|
)
|
||||||
parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple,
|
parser.add_argument(
|
||||||
help='size to pad the input data beyond the crop size, assuming '
|
"--crop-start",
|
||||||
'periodic boundary condition. Comma-sep. list of 1, d, or dx2 '
|
type=int_tuple,
|
||||||
'integers, to pad equally along all axes, symmetrically on each, '
|
help="starting point of the first crop. Default is the origin. "
|
||||||
'or by the specified size on every boundary, respectively')
|
"Comma-sep. list of 1 or d integers",
|
||||||
parser.add_argument('--tgt-pad', default=0, type=int_tuple,
|
)
|
||||||
help='size to pad the target data beyond the crop size, assuming '
|
parser.add_argument(
|
||||||
'periodic boundary condition, useful for super-resolution. '
|
"--crop-stop",
|
||||||
'Comma-sep. list with the same format as --in-pad')
|
type=int_tuple,
|
||||||
parser.add_argument('--scale-factor', default=1, type=int,
|
help="stopping point of the last crop. Default is the opposite "
|
||||||
help='upsampling factor for super-resolution, in which case '
|
"corner to the origin. Comma-sep. list of 1 or d integers",
|
||||||
'crop and pad are sizes of the input resolution')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--crop-step",
|
||||||
|
type=int_tuple,
|
||||||
|
help="spacing between crops. Default is the crop size. "
|
||||||
|
"Comma-sep. list of 1 or d integers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--in-pad",
|
||||||
|
"--pad",
|
||||||
|
default=0,
|
||||||
|
type=int_tuple,
|
||||||
|
help="size to pad the input data beyond the crop size, assuming "
|
||||||
|
"periodic boundary condition. Comma-sep. list of 1, d, or dx2 "
|
||||||
|
"integers, to pad equally along all axes, symmetrically on each, "
|
||||||
|
"or by the specified size on every boundary, respectively",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tgt-pad",
|
||||||
|
default=0,
|
||||||
|
type=int_tuple,
|
||||||
|
help="size to pad the target data beyond the crop size, assuming "
|
||||||
|
"periodic boundary condition, useful for super-resolution. "
|
||||||
|
"Comma-sep. list with the same format as --in-pad",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale-factor",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="upsampling factor for super-resolution, in which case "
|
||||||
|
"crop and pad are sizes of the input resolution",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--model', type=str, required=True,
|
parser.add_argument("--model", type=str, required=True, help="(generator) model")
|
||||||
help='(generator) model')
|
parser.add_argument(
|
||||||
parser.add_argument('--criterion', default='MSELoss', type=str,
|
"--criterion", default="MSELoss", type=str, help="loss function"
|
||||||
help='loss function')
|
)
|
||||||
parser.add_argument('--load-state', default=ckpt_link, type=str,
|
parser.add_argument(
|
||||||
help='path to load the states of model, optimizer, rng, etc. '
|
"--load-state",
|
||||||
'Default is the checkpoint. '
|
default=ckpt_link,
|
||||||
'Start from scratch in case of empty string or missing checkpoint')
|
type=str,
|
||||||
parser.add_argument('--load-state-non-strict', action='store_false',
|
help="path to load the states of model, optimizer, rng, etc. "
|
||||||
help='allow incompatible keys when loading model states',
|
"Default is the checkpoint. "
|
||||||
dest='load_state_strict')
|
"Start from scratch in case of empty string or missing checkpoint",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--load-state-non-strict",
|
||||||
|
action="store_false",
|
||||||
|
help="allow incompatible keys when loading model states",
|
||||||
|
dest="load_state_strict",
|
||||||
|
)
|
||||||
|
|
||||||
# somehow I named it "batches" instead of batch_size at first
|
# somehow I named it "batches" instead of batch_size at first
|
||||||
# "batches" is kept for now for backward compatibility
|
# "batches" is kept for now for backward compatibility
|
||||||
parser.add_argument('--batch-size', '--batches', type=int, required=True,
|
parser.add_argument(
|
||||||
help='mini-batch size, per GPU in training or in total in testing')
|
"--batch-size",
|
||||||
parser.add_argument('--loader-workers', default=8, type=int,
|
"--batches",
|
||||||
help='number of subprocesses per data loader. '
|
type=int,
|
||||||
'0 to disable multiprocessing')
|
required=True,
|
||||||
|
help="mini-batch size, per GPU in training or in total in testing",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--loader-workers",
|
||||||
|
default=8,
|
||||||
|
type=int,
|
||||||
|
help="number of subprocesses per data loader. " "0 to disable multiprocessing",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
|
parser.add_argument(
|
||||||
help='directory of custorm code defining callbacks for models, '
|
"--callback-at",
|
||||||
'norms, criteria, and optimizers. Disabled if not set. '
|
type=lambda s: os.path.abspath(s),
|
||||||
'This is appended to the default locations, '
|
help="directory of custorm code defining callbacks for models, "
|
||||||
'thus has the lowest priority')
|
"norms, criteria, and optimizers. Disabled if not set. "
|
||||||
parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
|
"This is appended to the default locations, "
|
||||||
help='miscellaneous keyword arguments for custom models and '
|
"thus has the lowest priority",
|
||||||
'norms. Be careful with name collisions')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--misc-kwargs",
|
||||||
|
default="{}",
|
||||||
|
type=json.loads,
|
||||||
|
help="miscellaneous keyword arguments for custom models and "
|
||||||
|
"norms. Be careful with name collisions",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_train_args(parser):
|
def add_train_args(parser):
|
||||||
add_common_args(parser)
|
add_common_args(parser)
|
||||||
|
|
||||||
parser.add_argument('--train-style-pattern', type=str, required=True,
|
parser.add_argument(
|
||||||
help='glob pattern for training data styles')
|
"--train-style-pattern",
|
||||||
parser.add_argument('--train-in-patterns', type=str_list, required=True,
|
type=str,
|
||||||
help='comma-sep. list of glob patterns for training input data')
|
required=True,
|
||||||
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
|
help="glob pattern for training data styles",
|
||||||
help='comma-sep. list of glob patterns for training target data')
|
)
|
||||||
parser.add_argument('--val-style-pattern', type=str,
|
parser.add_argument(
|
||||||
help='glob pattern for validation data styles')
|
"--train-in-patterns",
|
||||||
parser.add_argument('--val-in-patterns', type=str_list,
|
type=str_list,
|
||||||
help='comma-sep. list of glob patterns for validation input data')
|
required=True,
|
||||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
help="comma-sep. list of glob patterns for training input data",
|
||||||
help='comma-sep. list of glob patterns for validation target data')
|
)
|
||||||
parser.add_argument('--augment', action='store_true',
|
parser.add_argument(
|
||||||
help='enable data augmentation of axis flipping and permutation')
|
"--train-tgt-patterns",
|
||||||
parser.add_argument('--aug-shift', type=int_tuple,
|
type=str_list,
|
||||||
help='data augmentation by shifting cropping by [0, aug_shift) pixels, '
|
required=True,
|
||||||
'useful for models that treat neighboring pixels differently, '
|
help="comma-sep. list of glob patterns for training target data",
|
||||||
'e.g. with strided convolutions. '
|
)
|
||||||
'Comma-sep. list of 1 or d integers')
|
parser.add_argument(
|
||||||
parser.add_argument('--aug-add', type=float,
|
"--val-style-pattern", type=str, help="glob pattern for validation data styles"
|
||||||
help='additive data augmentation, (normal) std, '
|
)
|
||||||
'same factor for all fields')
|
parser.add_argument(
|
||||||
parser.add_argument('--aug-mul', type=float,
|
"--val-in-patterns",
|
||||||
help='multiplicative data augmentation, (log-normal) std, '
|
type=str_list,
|
||||||
'same factor for all fields')
|
help="comma-sep. list of glob patterns for validation input data",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--val-tgt-patterns",
|
||||||
|
type=str_list,
|
||||||
|
help="comma-sep. list of glob patterns for validation target data",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--augment",
|
||||||
|
action="store_true",
|
||||||
|
help="enable data augmentation of axis flipping and permutation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--aug-shift",
|
||||||
|
type=int_tuple,
|
||||||
|
help="data augmentation by shifting cropping by [0, aug_shift) pixels, "
|
||||||
|
"useful for models that treat neighboring pixels differently, "
|
||||||
|
"e.g. with strided convolutions. "
|
||||||
|
"Comma-sep. list of 1 or d integers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--aug-add",
|
||||||
|
type=float,
|
||||||
|
help="additive data augmentation, (normal) std, " "same factor for all fields",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--aug-mul",
|
||||||
|
type=float,
|
||||||
|
help="multiplicative data augmentation, (log-normal) std, "
|
||||||
|
"same factor for all fields",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument(
|
||||||
help='optimization algorithm')
|
"--optimizer", default="Adam", type=str, help="optimization algorithm"
|
||||||
parser.add_argument('--lr', type=float, required=True,
|
)
|
||||||
help='initial learning rate')
|
parser.add_argument("--lr", type=float, required=True, help="initial learning rate")
|
||||||
parser.add_argument('--optimizer-args', default='{}', type=json.loads,
|
parser.add_argument(
|
||||||
help='optimizer arguments in addition to the learning rate, '
|
"--optimizer-args",
|
||||||
'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
|
default="{}",
|
||||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
type=json.loads,
|
||||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
help="optimizer arguments in addition to the learning rate, "
|
||||||
parser.add_argument('--scheduler-args', default='{"verbose": true}',
|
"e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'",
|
||||||
type=json.loads,
|
)
|
||||||
help='arguments for the ReduceLROnPlateau scheduler')
|
parser.add_argument(
|
||||||
parser.add_argument('--init-weight-std', type=float,
|
"--reduce-lr-on-plateau",
|
||||||
help='weight initialization std')
|
action="store_true",
|
||||||
parser.add_argument('--epochs', default=128, type=int,
|
help="Enable ReduceLROnPlateau learning rate scheduler",
|
||||||
help='total number of epochs to run')
|
)
|
||||||
parser.add_argument('--seed', default=42, type=int,
|
parser.add_argument(
|
||||||
help='seed for initializing training')
|
"--scheduler-args",
|
||||||
|
default='{"verbose": true}',
|
||||||
|
type=json.loads,
|
||||||
|
help="arguments for the ReduceLROnPlateau scheduler",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--init-weight-std", type=float, help="weight initialization std"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epochs", default=128, type=int, help="total number of epochs to run"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", default=42, type=int, help="seed for initializing training"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--div-data', action='store_true',
|
parser.add_argument(
|
||||||
help='enable data division among GPUs for better page caching. '
|
"--div-data",
|
||||||
'Data division is shuffled every epoch. '
|
action="store_true",
|
||||||
'Only relevant if there are multiple crops in each field')
|
help="enable data division among GPUs for better page caching. "
|
||||||
parser.add_argument('--div-shuffle-dist', default=1, type=float,
|
"Data division is shuffled every epoch. "
|
||||||
help='distance to further shuffle cropped samples relative to '
|
"Only relevant if there are multiple crops in each field",
|
||||||
'their fields, to be used with --div-data. '
|
)
|
||||||
'Only relevant if there are multiple crops in each file. '
|
parser.add_argument(
|
||||||
'The order of each sample is randomly displaced by this value. '
|
"--div-shuffle-dist",
|
||||||
'Setting it to 0 turn off this randomization, and setting it to N '
|
default=1,
|
||||||
'limits the shuffling within a distance of N files. '
|
type=float,
|
||||||
'Change this to balance cache locality and stochasticity')
|
help="distance to further shuffle cropped samples relative to "
|
||||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
"their fields, to be used with --div-data. "
|
||||||
choices=['gloo', 'nccl'], help='distributed backend')
|
"Only relevant if there are multiple crops in each file. "
|
||||||
parser.add_argument('--log-interval', default=100, type=int,
|
"The order of each sample is randomly displaced by this value. "
|
||||||
help='interval (batches) between logging training loss')
|
"Setting it to 0 turn off this randomization, and setting it to N "
|
||||||
parser.add_argument('--detect-anomaly', action='store_true',
|
"limits the shuffling within a distance of N files. "
|
||||||
help='enable anomaly detection for the autograd engine')
|
"Change this to balance cache locality and stochasticity",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dist-backend",
|
||||||
|
default="nccl",
|
||||||
|
type=str,
|
||||||
|
choices=["gloo", "nccl"],
|
||||||
|
help="distributed backend",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-interval",
|
||||||
|
default=100,
|
||||||
|
type=int,
|
||||||
|
help="interval (batches) between logging training loss",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--detect-anomaly",
|
||||||
|
action="store_true",
|
||||||
|
help="enable anomaly detection for the autograd engine",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_test_args(parser):
|
def add_test_args(parser):
|
||||||
add_common_args(parser)
|
add_common_args(parser)
|
||||||
|
|
||||||
parser.add_argument('--test-style-pattern', type=str, required=True,
|
parser.add_argument(
|
||||||
help='glob pattern for test data styles')
|
"--test-style-pattern",
|
||||||
parser.add_argument('--test-in-patterns', type=str_list, required=True,
|
type=str,
|
||||||
help='comma-sep. list of glob patterns for test input data')
|
required=True,
|
||||||
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
|
help="glob pattern for test data styles",
|
||||||
help='comma-sep. list of glob patterns for test target data')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-in-patterns",
|
||||||
|
type=str_list,
|
||||||
|
required=True,
|
||||||
|
help="comma-sep. list of glob patterns for test input data",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-tgt-patterns",
|
||||||
|
type=str_list,
|
||||||
|
required=True,
|
||||||
|
help="comma-sep. list of glob patterns for test target data",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--num-threads', type=int,
|
parser.add_argument(
|
||||||
help='number of CPU threads when cuda is unavailable. '
|
"--num-threads",
|
||||||
'Default is the number of CPUs on the node by slurm')
|
type=int,
|
||||||
|
help="number of CPU threads when cuda is unavailable. "
|
||||||
|
"Default is the number of CPUs on the node by slurm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def str_list(s):
|
def str_list(s):
|
||||||
return s.split(',')
|
return s.split(",")
|
||||||
|
|
||||||
|
|
||||||
def int_tuple(s):
|
def int_tuple(s):
|
||||||
t = s.split(',')
|
t = s.split(",")
|
||||||
t = tuple(int(i) for i in t)
|
t = tuple(int(i) for i in t)
|
||||||
if len(t) == 1:
|
if len(t) == 1:
|
||||||
return t[0]
|
return t[0]
|
||||||
@@ -198,8 +326,7 @@ def set_common_args(args):
 def set_train_args(args):
     set_common_args(args)

-    args.val = args.val_in_patterns is not None and \
-        args.val_tgt_patterns is not None
+    args.val = args.val_in_patterns is not None and args.val_tgt_patterns is not None


 def set_test_args(args):
@@ -43,54 +43,72 @@ class FieldDataset(Dataset):
|
|||||||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||||
the input resolution.
|
the input resolution.
|
||||||
"""
|
"""
|
||||||
def __init__(self, style_pattern, in_patterns, tgt_patterns,
|
|
||||||
in_norms=None, tgt_norms=None, callback_at=None,
|
def __init__(
|
||||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
self,
|
||||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
style_pattern,
|
||||||
in_pad=0, tgt_pad=0, scale_factor=1,
|
in_patterns,
|
||||||
**kwargs):
|
tgt_patterns,
|
||||||
|
in_norms=None,
|
||||||
|
tgt_norms=None,
|
||||||
|
callback_at=None,
|
||||||
|
augment=False,
|
||||||
|
aug_shift=None,
|
||||||
|
aug_add=None,
|
||||||
|
aug_mul=None,
|
||||||
|
crop=None,
|
||||||
|
crop_start=None,
|
||||||
|
crop_stop=None,
|
||||||
|
crop_step=None,
|
||||||
|
in_pad=0,
|
||||||
|
tgt_pad=0,
|
||||||
|
scale_factor=1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
self.style_files = sorted(glob(style_pattern))
|
self.style_files = sorted(glob(style_pattern))
|
||||||
|
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
self.in_files = list(zip(* in_file_lists))
|
self.in_files = list(zip(*in_file_lists))
|
||||||
|
|
||||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||||
self.tgt_files = list(zip(* tgt_file_lists))
|
self.tgt_files = list(zip(*tgt_file_lists))
|
||||||
|
|
||||||
# self.style_files = self.style_files[:1]
|
# self.style_files = self.style_files[:1]
|
||||||
# self.in_files = self.in_files[:1]
|
# self.in_files = self.in_files[:1]
|
||||||
# self.tgt_files = self.tgt_files[:1]
|
# self.tgt_files = self.tgt_files[:1]
|
||||||
|
|
||||||
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
||||||
raise ValueError('number of style, input, and target files do not match')
|
raise ValueError("number of style, input, and target files do not match")
|
||||||
self.nfile = len(self.in_files)
|
self.nfile = len(self.in_files)
|
||||||
|
|
||||||
if self.nfile == 0:
|
if self.nfile == 0:
|
||||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
raise FileNotFoundError("file not found for {}".format(in_patterns))
|
||||||
|
|
||||||
self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
|
self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
|
||||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
self.in_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.in_files[0]]
|
||||||
for f in self.in_files[0]]
|
self.tgt_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.tgt_files[0]]
|
||||||
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
|
||||||
for f in self.tgt_files[0]]
|
|
||||||
|
|
||||||
self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
|
self.size = np.load(self.in_files[0][0], mmap_mode="r").shape[1:]
|
||||||
self.size = np.asarray(self.size)
|
self.size = np.asarray(self.size)
|
||||||
self.ndim = len(self.size)
|
self.ndim = len(self.size)
|
||||||
|
|
||||||
if in_norms is not None and len(in_patterns) != len(in_norms):
|
if in_norms is not None and len(in_patterns) != len(in_norms):
|
||||||
raise ValueError('numbers of input normalization functions and fields do not match')
|
raise ValueError(
|
||||||
|
"numbers of input normalization functions and fields do not match"
|
||||||
|
)
|
||||||
self.in_norms = in_norms
|
self.in_norms = in_norms
|
||||||
|
|
||||||
if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
|
if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
|
||||||
raise ValueError('numbers of target normalization functions and fields do not match')
|
raise ValueError(
|
||||||
|
"numbers of target normalization functions and fields do not match"
|
||||||
|
)
|
||||||
self.tgt_norms = tgt_norms
|
self.tgt_norms = tgt_norms
|
||||||
|
|
||||||
self.callback_at = callback_at
|
self.callback_at = callback_at
|
||||||
|
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
if self.ndim == 1 and self.augment:
|
if self.ndim == 1 and self.augment:
|
||||||
raise ValueError('cannot augment 1D fields')
|
raise ValueError("cannot augment 1D fields")
|
||||||
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
|
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
|
||||||
self.aug_add = aug_add
|
self.aug_add = aug_add
|
||||||
self.aug_mul = aug_mul
|
self.aug_mul = aug_mul
|
||||||
@@ -116,10 +134,15 @@ class FieldDataset(Dataset):
|
|||||||
crop_step = np.broadcast_to(crop_step, (self.ndim,))
|
crop_step = np.broadcast_to(crop_step, (self.ndim,))
|
||||||
self.crop_step = crop_step
|
self.crop_step = crop_step
|
||||||
|
|
||||||
self.anchors = np.stack(np.mgrid[tuple(
|
self.anchors = np.stack(
|
||||||
slice(crop_start[d], crop_stop[d], crop_step[d])
|
np.mgrid[
|
||||||
for d in range(self.ndim)
|
tuple(
|
||||||
)], axis=-1).reshape(-1, self.ndim)
|
slice(crop_start[d], crop_stop[d], crop_step[d])
|
||||||
|
for d in range(self.ndim)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
).reshape(-1, self.ndim)
|
||||||
self.ncrop = len(self.anchors)
|
self.ncrop = len(self.anchors)
|
||||||
|
|
||||||
def format_pad(pad, ndim):
|
def format_pad(pad, ndim):
|
||||||
@@ -130,15 +153,16 @@ class FieldDataset(Dataset):
|
|||||||
elif isinstance(pad, tuple) and len(pad) == ndim * 2:
|
elif isinstance(pad, tuple) and len(pad) == ndim * 2:
|
||||||
pad = np.array(pad)
|
pad = np.array(pad)
|
||||||
else:
|
else:
|
||||||
raise ValueError('pad and ndim mismatch')
|
raise ValueError("pad and ndim mismatch")
|
||||||
return pad.reshape(ndim, 2)
|
return pad.reshape(ndim, 2)
|
||||||
|
|
||||||
self.in_pad = format_pad(in_pad, self.ndim)
|
self.in_pad = format_pad(in_pad, self.ndim)
|
||||||
self.tgt_pad = format_pad(tgt_pad, self.ndim)
|
self.tgt_pad = format_pad(tgt_pad, self.ndim)
|
||||||
|
|
||||||
if scale_factor != 1:
|
if scale_factor != 1:
|
||||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
tgt_size = np.load(self.tgt_files[0][0], mmap_mode="r").shape[1:]
|
||||||
if any(self.size * scale_factor != tgt_size):
|
if any(self.size * scale_factor != tgt_size):
|
||||||
raise ValueError('input size x scale factor != target size')
|
raise ValueError("input size x scale factor != target size")
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
self.nsample = self.nfile * self.ncrop
|
self.nsample = self.nfile * self.ncrop
|
||||||
@@ -148,9 +172,7 @@ class FieldDataset(Dataset):
         self.assembly_line = {}

         self.commonpath = os.path.commonpath(
-            file
-            for files in self.in_files[:2] + self.tgt_files[:2]
-            for file in files
+            file for files in self.in_files[:2] + self.tgt_files[:2] for file in files
         )

     def __len__(self):
@@ -180,12 +202,18 @@ class FieldDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
argsort_perm_axes = slice(None)
|
argsort_perm_axes = slice(None)
|
||||||
|
|
||||||
crop(in_fields, anchor,
|
crop(
|
||||||
self.crop[argsort_perm_axes],
|
in_fields,
|
||||||
self.in_pad[argsort_perm_axes])
|
anchor,
|
||||||
crop(tgt_fields, anchor * self.scale_factor,
|
self.crop[argsort_perm_axes],
|
||||||
self.crop[argsort_perm_axes] * self.scale_factor,
|
self.in_pad[argsort_perm_axes],
|
||||||
self.tgt_pad[argsort_perm_axes])
|
)
|
||||||
|
crop(
|
||||||
|
tgt_fields,
|
||||||
|
anchor * self.scale_factor,
|
||||||
|
self.crop[argsort_perm_axes] * self.scale_factor,
|
||||||
|
self.tgt_pad[argsort_perm_axes],
|
||||||
|
)
|
||||||
|
|
||||||
style = torch.from_numpy(style).to(torch.float32)
|
style = torch.from_numpy(style).to(torch.float32)
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
@@ -221,17 +249,19 @@ class FieldDataset(Dataset):
|
|||||||
in_fields = torch.cat(in_fields, dim=0)
|
in_fields = torch.cat(in_fields, dim=0)
|
||||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||||
|
|
||||||
#in_relpath = [os.path.relpath(file, start=self.commonpath)
|
# in_relpath = [os.path.relpath(file, start=self.commonpath)
|
||||||
# for file in self.in_files[ifile]]
|
# for file in self.in_files[ifile]]
|
||||||
tgt_relpath = [os.path.relpath(file, start=self.commonpath)
|
tgt_relpath = [
|
||||||
for file in self.tgt_files[ifile]]
|
os.path.relpath(file, start=self.commonpath)
|
||||||
|
for file in self.tgt_files[ifile]
|
||||||
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'style': style,
|
"style": style,
|
||||||
'input': in_fields,
|
"input": in_fields,
|
||||||
'target': tgt_fields,
|
"target": tgt_fields,
|
||||||
#'input_relpath': in_relpath,
|
#'input_relpath': in_relpath,
|
||||||
'target_relpath': tgt_relpath,
|
"target_relpath": tgt_relpath,
|
||||||
}
|
}
|
||||||
|
|
||||||
def assemble(self, label, chan, patches, paths):
|
def assemble(self, label, chan, patches, paths):
|
||||||
@@ -255,29 +285,29 @@ class FieldDataset(Dataset):
|
|||||||
if isinstance(patches, torch.Tensor):
|
if isinstance(patches, torch.Tensor):
|
||||||
patches = patches.detach().cpu().numpy()
|
patches = patches.detach().cpu().numpy()
|
||||||
|
|
||||||
assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
|
assert patches.ndim == 2 + self.ndim, "ndim mismatch"
|
||||||
if any(self.crop_step > patches.shape[2:]):
|
if any(self.crop_step > patches.shape[2:]):
|
||||||
raise RuntimeError('patch too small to tile')
|
raise RuntimeError("patch too small to tile")
|
||||||
|
|
||||||
# the batched paths are a list of lists with shape (channel, batch)
|
# the batched paths are a list of lists with shape (channel, batch)
|
||||||
# since pytorch default_collate batches list of strings transposedly
|
# since pytorch default_collate batches list of strings transposedly
|
||||||
# therefore we transpose below back to (batch, channel)
|
# therefore we transpose below back to (batch, channel)
|
||||||
assert patches.shape[1] == sum(chan), 'number of channels mismatch'
|
assert patches.shape[1] == sum(chan), "number of channels mismatch"
|
||||||
assert len(paths) == len(chan), 'number of fields mismatch'
|
assert len(paths) == len(chan), "number of fields mismatch"
|
||||||
paths = list(zip(* paths))
|
paths = list(zip(*paths))
|
||||||
assert patches.shape[0] == len(paths), 'batch size mismatch'
|
assert patches.shape[0] == len(paths), "batch size mismatch"
|
||||||
|
|
||||||
patches = list(patches)
|
patches = list(patches)
|
||||||
if label in self.assembly_line:
|
if label in self.assembly_line:
|
||||||
self.assembly_line[label] += patches
|
self.assembly_line[label] += patches
|
||||||
self.assembly_line[label + 'path'] += paths
|
self.assembly_line[label + "path"] += paths
|
||||||
else:
|
else:
|
||||||
self.assembly_line[label] = patches
|
self.assembly_line[label] = patches
|
||||||
self.assembly_line[label + 'path'] = paths
|
self.assembly_line[label + "path"] = paths
|
||||||
|
|
||||||
del patches, paths
|
del patches, paths
|
||||||
patches = self.assembly_line[label]
|
patches = self.assembly_line[label]
|
||||||
paths = self.assembly_line[label + 'path']
|
paths = self.assembly_line[label + "path"]
|
||||||
|
|
||||||
# NOTE anchor positioning assumes sufficient target padding and
|
# NOTE anchor positioning assumes sufficient target padding and
|
||||||
# symmetric narrowing (more on the right if odd) see `models/narrow.py`
|
# symmetric narrowing (more on the right if odd) see `models/narrow.py`
|
||||||
@@ -285,31 +315,26 @@ class FieldDataset(Dataset):
|
|||||||
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
|
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
|
||||||
|
|
||||||
while len(patches) >= self.ncrop:
|
while len(patches) >= self.ncrop:
|
||||||
fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
|
fields = np.zeros(patches[0].shape[:1] + tuple(self.size), patches[0].dtype)
|
||||||
patches[0].dtype)
|
|
||||||
|
|
||||||
for patch, anchor in zip(patches, anchors):
|
for patch, anchor in zip(patches, anchors):
|
||||||
fill(fields, patch, anchor)
|
fill(fields, patch, anchor)
|
||||||
|
|
||||||
for field, path in zip(
|
for field, path in zip(np.split(fields, np.cumsum(chan), axis=0), paths[0]):
|
||||||
np.split(fields, np.cumsum(chan), axis=0),
|
pathlib.Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
|
||||||
paths[0]):
|
|
||||||
pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
|
|
||||||
exist_ok=True)
|
|
||||||
|
|
||||||
path = label.join(os.path.splitext(path))
|
path = label.join(os.path.splitext(path))
|
||||||
np.save(path, field)
|
np.save(path, field)
|
||||||
|
|
||||||
del patches[:self.ncrop], paths[:self.ncrop]
|
del patches[: self.ncrop], paths[: self.ncrop]
|
||||||
|
|
||||||
|
|
||||||
def fill(field, patch, anchor):
|
def fill(field, patch, anchor):
|
||||||
ndim = len(anchor)
|
ndim = len(anchor)
|
||||||
assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch'
|
assert field.ndim == patch.ndim == 1 + ndim, "ndim mismatch"
|
||||||
|
|
||||||
ind = [slice(None)]
|
ind = [slice(None)]
|
||||||
for d, (p, a, s) in enumerate(zip(
|
for d, (p, a, s) in enumerate(zip(patch.shape[1:], anchor, field.shape[1:])):
|
||||||
patch.shape[1:], anchor, field.shape[1:])):
|
|
||||||
i = np.arange(a, a + p)
|
i = np.arange(a, a + p)
|
||||||
i %= s
|
i %= s
|
||||||
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
||||||
@@ -320,10 +345,10 @@ def fill(field, patch, anchor):


 def crop(fields, anchor, crop, pad):
-    assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch'
+    assert all(x.shape == fields[0].shape for x in fields), "shape mismatch"
     size = fields[0].shape[1:]
     ndim = len(size)
-    assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch'
+    assert ndim == len(anchor) == len(crop) == len(pad), "ndim mismatch"

     ind = [slice(None)]
     for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
@@ -342,7 +367,7 @@ def crop(fields, anchor, crop, pad):


 def flip(fields, axes, ndim):
-    assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors'
+    assert ndim > 1, "flipping is ambiguous for 1D scalars/vectors"

     if axes is None:
         axes = torch.randint(2, (ndim,), dtype=torch.bool)
@@ -350,7 +375,7 @@ def flip(fields, axes, ndim):

     for i, x in enumerate(fields):
         if x.shape[0] == ndim:  # flip vector components
-            x[axes] = - x[axes]
+            x[axes] = -x[axes]

         shifted_axes = (1 + axes).tolist()
         x = torch.flip(x, shifted_axes)
@@ -361,7 +386,7 @@ def flip(fields, axes, ndim):


 def perm(fields, axes, ndim):
-    assert ndim > 1, 'permutation is not necessary for 1D fields'
+    assert ndim > 1, "permutation is not necessary for 1D fields"

     if axes is None:
         axes = torch.randperm(ndim)
@@ -10,6 +10,7 @@ def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
     x *= dis_norm

+
 def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
     vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z)  # [km/s]
@@ -20,25 +21,28 @@ def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):


 def D(z, Om=0.31):
-    """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
-    """
+    """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero"""
     OL = 1 - Om
-    a = 1 / (1+z)
-    return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \
-            / hyp2f1(1, 1/3, 11/6, - OL / Om)
+    a = 1 / (1 + z)
+    return (
+        a
+        * hyp2f1(1, 1 / 3, 11 / 6, -OL * a**3 / Om)
+        / hyp2f1(1, 1 / 3, 11 / 6, -OL / Om)
+    )


 def f(z, Om=0.31):
-    """linear growth rate for flat LambdaCDM
-    """
+    """linear growth rate for flat LambdaCDM"""
     OL = 1 - Om
-    a = 1 / (1+z)
+    a = 1 / (1 + z)
     aa3 = OL * a**3 / Om
-    return 1 - 6/11*aa3 * hyp2f1(2, 4/3, 17/6, -aa3) \
-            / hyp2f1(1, 1/3, 11/6, -aa3)
+    return 1 - 6 / 11 * aa3 * hyp2f1(2, 4 / 3, 17 / 6, -aa3) / hyp2f1(
+        1, 1 / 3, 11 / 6, -aa3
+    )


 def H(z, Om=0.31):
-    """Hubble in [h km/s/Mpc] for flat LambdaCDM
-    """
+    """Hubble in [h km/s/Mpc] for flat LambdaCDM"""
     OL = 1 - Om
-    a = 1 / (1+z)
+    a = 1 / (1 + z)
     return 100 * np.sqrt(Om / a**3 + OL)
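The growth function D above is normalized to 1 at redshift zero, as its docstring states. A quick check of that normalization (illustrative sketch only; assumes SciPy provides hyp2f1, which this module already uses):

    from scipy.special import hyp2f1

    def D(z, Om=0.31):
        # same expression as in the diff above
        OL = 1 - Om
        a = 1 / (1 + z)
        return (
            a
            * hyp2f1(1, 1 / 3, 11 / 6, -OL * a**3 / Om)
            / hyp2f1(1, 1 / 3, 11 / 6, -OL / Om)
        )

    print(D(0.0))  # 1.0 by construction
    print(D(1.0))  # smaller at higher redshift, roughly 0.6 for Om = 0.31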
@@ -7,18 +7,21 @@ def exp(x, undo=False, **kwargs):
     else:
         torch.log(x, out=x)

+
 def log(x, eps=1e-8, undo=False, **kwargs):
     if not undo:
         torch.log(x + eps, out=x)
     else:
         torch.exp(x, out=x)

+
 def expm1(x, undo=False, **kwargs):
     if not undo:
         torch.expm1(x, out=x)
     else:
         torch.log1p(x, out=x)

+
 def log1p(x, eps=1e-7, undo=False, **kwargs):
     if not undo:
         torch.log1p(x + eps, out=x)
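These normalization helpers transform tensors in place and undo themselves when called with undo=True. A minimal usage sketch of the log pair defined above (illustrative; assumes the function is imported from this norms module):

    import torch

    x = torch.tensor([1.0, 10.0, 100.0])
    log(x)              # forward: x becomes log(x + eps), modified in place
    log(x, undo=True)   # inverse: exp restores the original values up to eps
    print(x)            # approximately tensor([  1.,  10., 100.])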
@@ -22,8 +22,8 @@ class DistFieldSampler(Sampler):
     Like `DistributedSampler`, `set_epoch()` should be called at the beginning
     of each epoch during training.
     """
-    def __init__(self, dataset, shuffle,
-            div_data=False, div_shuffle_dist=0):
+
+    def __init__(self, dataset, shuffle, div_data=False, div_shuffle_dist=0):
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()

@@ -50,8 +50,10 @@ class DistFieldSampler(Sampler):
             ind = ind.flatten()

             # displace crops with respect to files
-            dis = torch.rand((self.nfile, self.ncrop),
-                             generator=g) * self.div_shuffle_dist
+            dis = (
+                torch.rand((self.nfile, self.ncrop), generator=g)
+                * self.div_shuffle_dist
+            )
             loc = torch.arange(self.nfile)
             loc = loc[:, None] + dis
             loc = loc.flatten() % self.nfile  # periodic in files
@@ -3,6 +3,7 @@ from . import test
|
|||||||
import click
|
import click
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from yaml import CLoader as Loader
|
from yaml import CLoader as Loader
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -12,20 +13,24 @@ import importlib.resources
|
|||||||
import json
|
import json
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
def _load_resource_file(resource_path):
|
def _load_resource_file(resource_path):
|
||||||
# Import the package
|
# Import the package
|
||||||
pkg_files = importlib.resources.files() / resource_path
|
pkg_files = importlib.resources.files() / resource_path
|
||||||
with pkg_files.open() as file:
|
with pkg_files.open() as file:
|
||||||
return file.read() # Read the file and return its content
|
return file.read() # Read the file and return its content
|
||||||
|
|
||||||
|
|
||||||
def _str_list(value):
|
def _str_list(value):
|
||||||
return value.split(',')
|
return value.split(",")
|
||||||
|
|
||||||
|
|
||||||
def _int_tuple(value):
|
def _int_tuple(value):
|
||||||
t = value.split(',')
|
t = value.split(",")
|
||||||
t = tuple(int(i) for i in t)
|
t = tuple(int(i) for i in t)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
||||||
class VariadicType(click.ParamType):
|
class VariadicType(click.ParamType):
|
||||||
"""
|
"""
|
||||||
A custom parameter type for Click command-line interface.
|
A custom parameter type for Click command-line interface.
|
||||||
@@ -61,7 +66,7 @@ class VariadicType(click.ParamType):
|
|||||||
self._typename = typename
|
self._typename = typename
|
||||||
self.name = self._type["type"]
|
self.name = self._type["type"]
|
||||||
if "func" not in self._type:
|
if "func" not in self._type:
|
||||||
self._type["func"] = eval(self._type['type'])
|
self._type["func"] = eval(self._type["type"])
|
||||||
|
|
||||||
def convert(self, value, param, ctx):
|
def convert(self, value, param, ctx):
|
||||||
try:
|
try:
|
||||||
@ -69,24 +74,26 @@ class VariadicType(click.ParamType):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
|
self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
|
||||||
|
|
||||||
|
|
||||||
def _apply_options(options_file, f):
|
def _apply_options(options_file, f):
|
||||||
common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
|
common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
|
||||||
common_args = common_args['arguments']
|
common_args = common_args["arguments"]
|
||||||
|
|
||||||
for arg in common_args:
|
for arg in common_args:
|
||||||
argopt = common_args[arg]
|
argopt = common_args[arg]
|
||||||
if 'type' in argopt:
|
if "type" in argopt:
|
||||||
if type(argopt['type']) == dict and argopt['type']['type'] == 'choice':
|
if type(argopt["type"]) == dict and argopt["type"]["type"] == "choice":
|
||||||
argopt['type'] = click.Choice(argopt['type']['opts'])
|
argopt["type"] = click.Choice(argopt["type"]["opts"])
|
||||||
else:
|
else:
|
||||||
argopt['type'] = VariadicType(argopt['type'])
|
argopt["type"] = VariadicType(argopt["type"])
|
||||||
f = click.option(f'--{arg}', **argopt)(f)
|
f = click.option(f"--{arg}", **argopt)(f)
|
||||||
else:
|
else:
|
||||||
f = click.option(f'--{arg}', **argopt)(f)
|
f = click.option(f"--{arg}", **argopt)(f)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
m2m_options=partial(_apply_options,"common_args.yaml")
|
|
||||||
|
m2m_options = partial(_apply_options, "common_args.yaml")
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@ -94,21 +101,24 @@ m2m_options=partial(_apply_options,"common_args.yaml")
|
|||||||
@click.pass_context
|
@click.pass_context
|
||||||
def main(ctx, config):
|
def main(ctx, config):
|
||||||
if config is not None and os.path.exists(config):
|
if config is not None and os.path.exists(config):
|
||||||
with open(config, 'r') as f:
|
with open(config, "r") as f:
|
||||||
config = yaml.load(f.read(), Loader=Loader)
|
config = yaml.load(f.read(), Loader=Loader)
|
||||||
ctx.default_map = config
|
ctx.default_map = config
|
||||||
|
|
||||||
|
|
||||||
# Make a class that provides access to dict with the attribute mechanism
|
# Make a class that provides access to dict with the attribute mechanism
|
||||||
class DictProxy:
|
class DictProxy:
|
||||||
def __init__(self, d):
|
def __init__(self, d):
|
||||||
self.__dict__ = d
|
self.__dict__ = d
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
@main.command()
|
||||||
@m2m_options
|
@m2m_options
|
||||||
@partial(_apply_options, "train_args.yaml")
|
@partial(_apply_options, "train_args.yaml")
|
||||||
def train(**kwargs):
|
def train(**kwargs):
|
||||||
train.node_worker(DictProxy(kwargs))
|
train.node_worker(DictProxy(kwargs))
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
@main.command()
|
||||||
@m2m_options
|
@m2m_options
|
||||||
@partial(_apply_options, "test_args.yaml")
|
@partial(_apply_options, "test_args.yaml")
|
||||||
|
@@ -5,6 +5,7 @@ def adv_model_wrapper(module):
|
|||||||
"""Wrap an adversary model to also take lists of Tensors as input,
|
"""Wrap an adversary model to also take lists of Tensors as input,
|
||||||
to be concatenated along the batch dimension
|
to be concatenated along the batch dimension
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class _new_module(module):
|
class _new_module(module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if not isinstance(x, torch.Tensor):
|
if not isinstance(x, torch.Tensor):
|
||||||
@@ -22,6 +23,7 @@ def adv_criterion_wrapper(module):
|
|||||||
* expand target shape as that of input
|
* expand target shape as that of input
|
||||||
* return a list of losses, one for each pair of input and target Tensors
|
* return a list of losses, one for each pair of input and target Tensors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class _new_module(module):
|
class _new_module(module):
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
assert isinstance(input, torch.Tensor)
|
assert isinstance(input, torch.Tensor)
|
||||||
@@ -35,8 +37,9 @@ def adv_criterion_wrapper(module):

             target = [t.expand_as(i) for i, t in zip(input, target)]

-            loss = [super(new_module, self).forward(i, t)
-                    for i, t in zip(input, target)]
+            loss = [
+                super(new_module, self).forward(i, t) for i, t in zip(input, target)
+            ]

             return loss
@@ -15,8 +15,10 @@ class ConvBlock(nn.Module):
    'U': upsampling transposed convolution of kernel size 2 and stride 2
    'D': downsampling convolution of kernel size 2 and stride 2
    """
-    def __init__(self, in_chan, out_chan=None, mid_chan=None,
-            kernel_size=3, stride=1, seq='CBA'):
+
+    def __init__(
+        self, in_chan, out_chan=None, mid_chan=None, kernel_size=3, stride=1, seq="CBA"
+    ):
         super().__init__()

         if out_chan is None:
@@ -31,31 +33,30 @@ class ConvBlock(nn.Module):
|
|||||||
|
|
||||||
self.norm_chan = in_chan
|
self.norm_chan = in_chan
|
||||||
self.idx_conv = 0
|
self.idx_conv = 0
|
||||||
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
|
self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
|
||||||
|
|
||||||
layers = [self._get_layer(l) for l in seq]
|
layers = [self._get_layer(l) for l in seq]
|
||||||
|
|
||||||
self.convs = nn.Sequential(*layers)
|
self.convs = nn.Sequential(*layers)
|
||||||
|
|
||||||
def _get_layer(self, l):
|
def _get_layer(self, l):
|
||||||
if l == 'U':
|
if l == "U":
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2)
|
return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2)
|
||||||
elif l == 'D':
|
elif l == "D":
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return nn.Conv3d(in_chan, out_chan, 2, stride=2)
|
return nn.Conv3d(in_chan, out_chan, 2, stride=2)
|
||||||
elif l == 'C':
|
elif l == "C":
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return nn.Conv3d(in_chan, out_chan, self.kernel_size,
|
return nn.Conv3d(in_chan, out_chan, self.kernel_size, stride=self.stride)
|
||||||
stride=self.stride)
|
elif l == "B":
|
||||||
elif l == 'B':
|
|
||||||
return nn.BatchNorm3d(self.norm_chan)
|
return nn.BatchNorm3d(self.norm_chan)
|
||||||
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
# return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
||||||
#return nn.InstanceNorm3d(self.norm_chan)
|
# return nn.InstanceNorm3d(self.norm_chan)
|
||||||
elif l == 'A':
|
elif l == "A":
|
||||||
return nn.LeakyReLU()
|
return nn.LeakyReLU()
|
||||||
else:
|
else:
|
||||||
raise ValueError('layer type {} not supported'.format(l))
|
raise ValueError("layer type {} not supported".format(l))
|
||||||
|
|
||||||
def _setup_conv(self):
|
def _setup_conv(self):
|
||||||
self.idx_conv += 1
|
self.idx_conv += 1
|
||||||
@@ -88,13 +89,22 @@ class ResBlock(ConvBlock):
|
|||||||
|
|
||||||
See `ConvBlock` for `seq` types.
|
See `ConvBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
|
||||||
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_chan,
|
||||||
|
out_chan=None,
|
||||||
|
mid_chan=None,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
seq="CBACBA",
|
||||||
|
last_act=None,
|
||||||
|
):
|
||||||
if last_act is None:
|
if last_act is None:
|
||||||
last_act = seq[-1] == 'A'
|
last_act = seq[-1] == "A"
|
||||||
elif last_act and seq[-1] != 'A':
|
elif last_act and seq[-1] != "A":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'Disabling last_act without trailing activation in seq',
|
"Disabling last_act without trailing activation in seq",
|
||||||
RuntimeWarning,
|
RuntimeWarning,
|
||||||
)
|
)
|
||||||
last_act = False
|
last_act = False
|
||||||
@@ -102,8 +112,14 @@ class ResBlock(ConvBlock):
|
|||||||
if last_act:
|
if last_act:
|
||||||
seq = seq[:-1]
|
seq = seq[:-1]
|
||||||
|
|
||||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
super().__init__(
|
||||||
kernel_size=kernel_size, stride=stride, seq=seq)
|
in_chan,
|
||||||
|
out_chan=out_chan,
|
||||||
|
mid_chan=mid_chan,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
seq=seq,
|
||||||
|
)
|
||||||
|
|
||||||
if last_act:
|
if last_act:
|
||||||
self.act = nn.LeakyReLU()
|
self.act = nn.LeakyReLU()
|
||||||
@@ -115,9 +131,10 @@ class ResBlock(ConvBlock):
|
|||||||
else:
|
else:
|
||||||
self.skip = nn.Conv3d(in_chan, out_chan, 1)
|
self.skip = nn.Conv3d(in_chan, out_chan, 1)
|
||||||
|
|
||||||
if 'U' in seq or 'D' in seq:
|
if "U" in seq or "D" in seq:
|
||||||
raise NotImplementedError('upsample and downsample layers '
|
raise NotImplementedError(
|
||||||
'not supported yet')
|
"upsample and downsample layers " "not supported yet"
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x
|
y = x
|
||||||
|
@@ -2,7 +2,7 @@ import torch.nn as nn


 class DiceLoss(nn.Module):
-    def __init__(self, eps=0.):
+    def __init__(self, eps=0.0):
         super().__init__()
         self.eps = eps

@@ -10,7 +10,7 @@ class DiceLoss(nn.Module):
         return dice_loss(input, target, self.eps)


-def dice_loss(input, target, eps=0.):
+def dice_loss(input, target, eps=0.0):
     input = input.view(-1)
     target = target.view(-1)
@@ -1,10 +1,11 @@
 import torch

+
 class InstanceNoise:
-    """Instance noise, with a linear decaying schedule
-    """
+    """Instance noise, with a linear decaying schedule"""
+
     def __init__(self, init_std, batches):
-        assert init_std >= 0, 'Noise std cannot be negative'
+        assert init_std >= 0, "Noise std cannot be negative"
         self.init_std = init_std
         self._std = init_std
         self.batches = batches
@@ -5,17 +5,18 @@ from ..data.norms.cosmology import D
|
|||||||
|
|
||||||
|
|
||||||
def lag2eul(
|
def lag2eul(
|
||||||
dis,
|
dis,
|
||||||
val=1.0,
|
val=1.0,
|
||||||
eul_scale_factor=1,
|
eul_scale_factor=1,
|
||||||
eul_pad=0,
|
eul_pad=0,
|
||||||
rm_dis_mean=True,
|
rm_dis_mean=True,
|
||||||
periodic=False,
|
periodic=False,
|
||||||
z=0.0,
|
z=0.0,
|
||||||
dis_std=6.0,
|
dis_std=6.0,
|
||||||
boxsize=1000.,
|
boxsize=1000.0,
|
||||||
meshsize=512,
|
meshsize=512,
|
||||||
**kwargs):
|
**kwargs,
|
||||||
|
):
|
||||||
"""Transform fields from Lagrangian description to Eulerian description
|
"""Transform fields from Lagrangian description to Eulerian description
|
||||||
|
|
||||||
Only works for 3d fields, output same mesh size as input.
|
Only works for 3d fields, output same mesh size as input.
|
||||||
@@ -48,20 +49,19 @@ def lag2eul(
|
|||||||
if isinstance(val, (float, torch.Tensor)):
|
if isinstance(val, (float, torch.Tensor)):
|
||||||
val = [val]
|
val = [val]
|
||||||
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
|
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
|
||||||
raise ValueError('dis-val field mismatch')
|
raise ValueError("dis-val field mismatch")
|
||||||
|
|
||||||
if any(d.dim() != 5 for d in dis):
|
if any(d.dim() != 5 for d in dis):
|
||||||
raise NotImplementedError('only support 3d fields for now')
|
raise NotImplementedError("only support 3d fields for now")
|
||||||
if any(d.shape[1] != 3 for d in dis):
|
if any(d.shape[1] != 3 for d in dis):
|
||||||
raise ValueError('only support 3d displacement fields')
|
raise ValueError("only support 3d displacement fields")
|
||||||
|
|
||||||
# common mean displacement of all inputs
|
# common mean displacement of all inputs
|
||||||
# if removed, fewer particles go outside of the box
|
# if removed, fewer particles go outside of the box
|
||||||
# common for all inputs so outputs are comparable in the same coords
|
# common for all inputs so outputs are comparable in the same coords
|
||||||
d_mean = 0
|
d_mean = 0
|
||||||
if rm_dis_mean:
|
if rm_dis_mean:
|
||||||
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
|
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) for d in dis) / len(dis)
|
||||||
for d in dis) / len(dis)
|
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
if len(dis) == 1 and len(val) != 1:
|
if len(dis) == 1 and len(val) != 1:
|
||||||
@@ -85,12 +85,15 @@ def lag2eul(
|
|||||||
pos = (d - d_mean) * dis_norm
|
pos = (d - d_mean) * dis_norm
|
||||||
del d
|
del d
|
||||||
|
|
||||||
pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
|
pos[:, 0] += torch.arange(
|
||||||
dtype=dtype, device=device)[:, None, None]
|
0, DHW[0] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
|
||||||
pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
|
)[:, None, None]
|
||||||
dtype=dtype, device=device)[:, None]
|
pos[:, 1] += torch.arange(
|
||||||
pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
|
0, DHW[1] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
|
||||||
dtype=dtype, device=device)
|
)[:, None]
|
||||||
|
pos[:, 2] += torch.arange(
|
||||||
|
0, DHW[2] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
|
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
|
||||||
|
|
||||||
@@ -118,8 +121,7 @@ def lag2eul(
             if periodic:
                 torch.remainder(tgtpos[n], bounds, out=tgtpos[n])

-            ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
-                   ) * DHW[2] + tgtpos[n, 2]
+            ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]) * DHW[2] + tgtpos[n, 2]
             src = v[n]

             if not periodic:
@@ -3,8 +3,7 @@ import torch.nn as nn


 def narrow_by(a, c):
-    """Narrow a by size c symmetrically on all edges.
-    """
+    """Narrow a by size c symmetrically on all edges."""
     ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
     return a[ind]
@@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
         super().__init__()

         self.convs = nn.Sequential(
-            ConvBlock(in_chan, 32, seq='CA'),
-            ConvBlock(32, 64, seq='CBA'),
-            ConvBlock(64, 128, seq='CBA'),
-            ConvBlock(128, out_chan, seq='C'),
+            ConvBlock(in_chan, 32, seq="CA"),
+            ConvBlock(32, 64, seq="CBA"),
+            ConvBlock(64, 128, seq="CBA"),
+            ConvBlock(128, out_chan, seq="C"),
         )

     def forward(self, x):

@@ -19,23 +19,20 @@ class PatchGAN(nn.Module):


 class PatchGAN42(nn.Module):
-    """PatchGAN similar to the one in pix2pix
-    """
+    """PatchGAN similar to the one in pix2pix"""
+
     def __init__(self, in_chan, out_chan=1, **kwargs):
         super().__init__()

         self.convs = nn.Sequential(
             nn.Conv3d(in_chan, 64, 4, stride=2),
             nn.LeakyReLU(0.2, inplace=True),

             nn.Conv3d(64, 128, 4, stride=2),
             nn.BatchNorm3d(128),
             nn.LeakyReLU(0.2, inplace=True),

             nn.Conv3d(128, 256, 4, stride=2),
             nn.BatchNorm3d(256),
             nn.LeakyReLU(0.2, inplace=True),

             nn.Conv3d(256, out_chan, 1),
         )
@@ -26,8 +26,7 @@ def power(x: torch.Tensor):
     P = P.sum(dim=0)
     del x

-    k = [torch.arange(d, dtype=torch.float32, device=P.device)
-         for d in P.shape]
+    k = [torch.arange(d, dtype=torch.float32, device=P.device) for d in P.shape]
     k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
     k = torch.meshgrid(*k)
     k = torch.stack(k, dim=0)

@@ -49,9 +48,9 @@ def power(x: torch.Tensor):
     del kbin

     # drop k=0 mode and cut at kmax (smallest Nyquist)
-    k = k[1:1+kmax]
-    P = P[1:1+kmax]
-    N = N[1:1+kmax]
+    k = k[1 : 1 + kmax]
+    P = P[1 : 1 + kmax]
+    N = N[1 : 1 + kmax]

     k /= N
     P /= N
@@ -5,12 +5,11 @@ from .narrow import narrow_by


 def resample(x, scale_factor, narrow=True):
-    modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
+    modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
     ndim = x.dim() - 2
     mode = modes[ndim]

-    x = F.interpolate(x, scale_factor=scale_factor,
-                      mode=mode, align_corners=False)
+    x = F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=False)

     if scale_factor > 1 and narrow == True:
         edges = round(scale_factor) // 2

@@ -25,18 +24,20 @@ class Resampler(nn.Module):

     By default discard the inaccurate edges when upsampling.
     """
+
     def __init__(self, ndim, scale_factor, narrow=True):
         super().__init__()

-        modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
+        modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
         self.mode = modes[ndim]

         self.scale_factor = scale_factor
         self.narrow = narrow

     def forward(self, x):
-        x = F.interpolate(x, scale_factor=self.scale_factor,
-                          mode=self.mode, align_corners=False)
+        x = F.interpolate(
+            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False
+        )

         if self.scale_factor > 1 and self.narrow == True:
             edges = round(self.scale_factor) // 2
@@ -4,8 +4,18 @@ from torch.nn.utils import spectral_norm, remove_spectral_norm

 def add_spectral_norm(module):
     for name, child in module.named_children():
-        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
-                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+        if isinstance(
+            child,
+            (
+                nn.Linear,
+                nn.Conv1d,
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.ConvTranspose1d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+            ),
+        ):
             setattr(module, name, spectral_norm(child))
         else:
             add_spectral_norm(child)

@@ -13,8 +23,18 @@ def add_spectral_norm(module):

 def rm_spectral_norm(module):
     for name, child in module.named_children():
-        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
-                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+        if isinstance(
+            child,
+            (
+                nn.Linear,
+                nn.Conv1d,
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.ConvTranspose1d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+            ),
+        ):
             setattr(module, name, remove_spectral_norm(child))
         else:
             rm_spectral_norm(child)
@@ -7,9 +7,17 @@ from .resample import Resampler


 class G(nn.Module):
-    def __init__(self, in_chan, out_chan, scale_factor=16,
-                 chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
-                 **kwargs):
+    def __init__(
+        self,
+        in_chan,
+        out_chan,
+        scale_factor=16,
+        chan_base=512,
+        chan_min=64,
+        chan_max=512,
+        cat_noise=False,
+        **kwargs,
+    ):
         super().__init__()

         self.scale_factor = scale_factor

@@ -30,15 +38,14 @@ class G(nn.Module):

         self.blocks = nn.ModuleList()
         for b in range(num_blocks):
-            prev_chan, next_chan = chan(b), chan(b+1)
-            self.blocks.append(
-                HBlock(prev_chan, next_chan, out_chan, cat_noise))
+            prev_chan, next_chan = chan(b), chan(b + 1)
+            self.blocks.append(HBlock(prev_chan, next_chan, out_chan, cat_noise))

     def forward(self, x):
         y = x  # direct upsampling from the input
         x = self.block0(x)

-        #y = None  # no direct upsampling from the input
+        # y = None  # no direct upsampling from the input
         for block in self.blocks:
             x, y = block(x, y)

@@ -71,6 +78,7 @@ class HBlock(nn.Module):
     -----
     next_size = 2 * prev_size - 6
     """
+
     def __init__(self, prev_chan, next_chan, out_chan, cat_noise):
         super().__init__()

@@ -81,7 +89,6 @@ class HBlock(nn.Module):
             self.upsample,
             nn.Conv3d(prev_chan + int(cat_noise), next_chan, 3),
             nn.LeakyReLU(0.2, True),

             AddNoise(cat_noise, chan=next_chan),
             nn.Conv3d(next_chan + int(cat_noise), next_chan, 3),
             nn.LeakyReLU(0.2, True),

@@ -114,6 +121,7 @@ class AddNoise(nn.Module):
     The number of channels `chan` should be 1 (StyleGAN2)
     or that of the input (StyleGAN).
     """
+
     def __init__(self, cat, chan=1):
         super().__init__()

@@ -137,9 +145,16 @@ class AddNoise(nn.Module):


 class D(nn.Module):
-    def __init__(self, in_chan, out_chan, scale_factor=16,
-                 chan_base=512, chan_min=64, chan_max=512,
-                 **kwargs):
+    def __init__(
+        self,
+        in_chan,
+        out_chan,
+        scale_factor=16,
+        chan_base=512,
+        chan_min=64,
+        chan_max=512,
+        **kwargs,
+    ):
         super().__init__()

         self.scale_factor = scale_factor

@@ -163,7 +178,7 @@ class D(nn.Module):

         self.blocks = nn.ModuleList()
         for b in reversed(range(num_blocks)):
-            prev_chan, next_chan = chan(b+1), chan(b)
+            prev_chan, next_chan = chan(b + 1), chan(b)
             self.blocks.append(ResBlock(prev_chan, next_chan))

         self.block9 = nn.Sequential(

@@ -192,13 +207,13 @@ class ResBlock(nn.Module):
     -----
     next_size = (prev_size - 4) // 2
     """

     def __init__(self, prev_chan, next_chan):
         super().__init__()

         self.conv = nn.Sequential(
             nn.Conv3d(prev_chan, prev_chan, 3),
             nn.LeakyReLU(0.2, True),

             nn.Conv3d(prev_chan, next_chan, 3),
             nn.LeakyReLU(0.2, True),
         )
@@ -9,6 +9,7 @@ class PixelNorm(nn.Module):

     See ProGAN, StyleGAN.
     """
+
     def __init__(self):
         super().__init__()

@@ -23,6 +24,7 @@ class LinearElr(nn.Module):

     Useful at all if not for regularization(1706.05350)?
     """
+
     def __init__(self, in_size, out_size, bias=True, act=None):
         super().__init__()

@@ -32,7 +34,7 @@ class LinearElr(nn.Module):
         if bias:
             self.bias = nn.Parameter(torch.zeros(out_size))
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

         self.act = act

@@ -52,20 +54,20 @@ class ConvElr3d(nn.Module):

     Useful at all if not for regularization(1706.05350)?
     """
-    def __init__(self, in_chan, out_chan, kernel_size,
-                 stride=1, padding=0, bias=True):
+
+    def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, bias=True):
         super().__init__()

         self.weight = nn.Parameter(
             torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
         )
-        fan_in = in_chan * kernel_size ** 3
+        fan_in = in_chan * kernel_size**3
         self.wnorm = 1 / math.sqrt(fan_in)

         if bias:
             self.bias = nn.Parameter(torch.zeros(out_chan))
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

         self.stride = stride
         self.padding = padding

@@ -87,38 +89,49 @@ class ConvStyled3d(nn.Module):

     Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
     """
-    def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
-                 bias=True, resample=None):
+
+    def __init__(
+        self,
+        style_size,
+        in_chan,
+        out_chan,
+        kernel_size=3,
+        stride=1,
+        bias=True,
+        resample=None,
+    ):
         super().__init__()

         self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
-        nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
-                mode='fan_in', nonlinearity='leaky_relu')
+        nn.init.kaiming_uniform_(
+            self.style_weight, a=math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"
+        )
         self.style_bias = nn.Parameter(torch.ones(in_chan))  # NOTE: init to 1
         if resample is None:
             K3 = (kernel_size,) * 3
             self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
             self.stride = stride
             self.conv = F.conv3d
-        elif resample == 'U':
+        elif resample == "U":
             K3 = (2,) * 3
             # NOTE not clear to me why convtranspose have channels swapped
             self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
             self.stride = 2
             self.conv = F.conv_transpose3d
-        elif resample == 'D':
+        elif resample == "D":
             K3 = (2,) * 3
             self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
             self.stride = 2
             self.conv = F.conv3d
         else:
-            raise ValueError('resample type {} not supported'.format(resample))
+            raise ValueError("resample type {} not supported".format(resample))
         self.resample = resample

         nn.init.kaiming_uniform_(
-            self.weight, a=math.sqrt(5),
-            mode='fan_in',  # effectively 'fan_out' for 'D'
-            nonlinearity='leaky_relu',
+            self.weight,
+            a=math.sqrt(5),
+            mode="fan_in",  # effectively 'fan_out' for 'D'
+            nonlinearity="leaky_relu",
         )

         if bias:

@@ -127,12 +140,12 @@ class ConvStyled3d(nn.Module):
             bound = 1 / math.sqrt(fan_in)
             nn.init.uniform_(self.bias, -bound, bound)
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

     def forward(self, x, s, eps=1e-8):
         N, Cin, *DHWin = x.shape
         C0, C1, *K3 = self.weight.shape
-        if self.resample == 'U':
+        if self.resample == "U":
             Cin, Cout = C0, C1
         else:
             Cout, Cin = C0, C1

@@ -140,14 +153,14 @@ class ConvStyled3d(nn.Module):
         s = F.linear(s, self.style_weight, bias=self.style_bias)

         # modulation
-        if self.resample == 'U':
+        if self.resample == "U":
             s = s.reshape(N, Cin, 1, 1, 1, 1)
         else:
             s = s.reshape(N, 1, Cin, 1, 1, 1)
         w = self.weight * s

         # demodulation
-        if self.resample == 'U':
+        if self.resample == "U":
             fan_in_dim = (1, 3, 4, 5)
         else:
             fan_in_dim = (2, 3, 4, 5)

@@ -161,18 +174,22 @@ class ConvStyled3d(nn.Module):

         return x

-class BatchNormStyled3d(nn.BatchNorm3d) :
-    """ Trivially does standard batch normalization, but accepts second argument
+
+class BatchNormStyled3d(nn.BatchNorm3d):
+    """Trivially does standard batch normalization, but accepts second argument

     for style array that is not used
     """
+
     def forward(self, x, s):
         return super().forward(x)


 class LeakyReLUStyled(nn.LeakyReLU):
-    """ Trivially evaluates standard leaky ReLU, but accepts second argument
+    """Trivially evaluates standard leaky ReLU, but accepts second argument

     for sytle array that is not used
     """
+
     def forward(self, x, s):
         return super().forward(x)
@ -16,8 +16,17 @@ class ConvStyledBlock(nn.Module):
|
|||||||
'U': upsampling transposed convolution of kernel size 2 and stride 2
|
'U': upsampling transposed convolution of kernel size 2 and stride 2
|
||||||
'D': downsampling convolution of kernel size 2 and stride 2
|
'D': downsampling convolution of kernel size 2 and stride 2
|
||||||
"""
|
"""
|
||||||
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
|
||||||
kernel_size=3, stride=1, seq='CBA'):
|
def __init__(
|
||||||
|
self,
|
||||||
|
style_size,
|
||||||
|
in_chan,
|
||||||
|
out_chan=None,
|
||||||
|
mid_chan=None,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
seq="CBA",
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if out_chan is None:
|
if out_chan is None:
|
||||||
@ -33,31 +42,34 @@ class ConvStyledBlock(nn.Module):
|
|||||||
|
|
||||||
self.norm_chan = in_chan
|
self.norm_chan = in_chan
|
||||||
self.idx_conv = 0
|
self.idx_conv = 0
|
||||||
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
|
self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
|
||||||
|
|
||||||
layers = [self._get_layer(l) for l in seq]
|
layers = [self._get_layer(l) for l in seq]
|
||||||
|
|
||||||
self.convs = nn.ModuleList(layers)
|
self.convs = nn.ModuleList(layers)
|
||||||
|
|
||||||
def _get_layer(self, l):
|
def _get_layer(self, l):
|
||||||
if l == 'U':
|
if l == "U":
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
|
return ConvStyled3d(
|
||||||
resample = 'U')
|
self.style_size, in_chan, out_chan, 2, stride=2, resample="U"
|
||||||
elif l == 'D':
|
)
|
||||||
|
elif l == "D":
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
|
return ConvStyled3d(
|
||||||
resample = 'D')
|
self.style_size, in_chan, out_chan, 2, stride=2, resample="D"
|
||||||
elif l == 'C':
|
)
|
||||||
|
elif l == "C":
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size,
|
return ConvStyled3d(
|
||||||
stride=self.stride)
|
self.style_size, in_chan, out_chan, self.kernel_size, stride=self.stride
|
||||||
elif l == 'B':
|
)
|
||||||
|
elif l == "B":
|
||||||
return BatchNormStyled3d(self.norm_chan)
|
return BatchNormStyled3d(self.norm_chan)
|
||||||
elif l == 'A':
|
elif l == "A":
|
||||||
return LeakyReLUStyled()
|
return LeakyReLUStyled()
|
||||||
else:
|
else:
|
||||||
raise ValueError('layer type {} not supported'.format(l))
|
raise ValueError("layer type {} not supported".format(l))
|
||||||
|
|
||||||
def _setup_conv(self):
|
def _setup_conv(self):
|
||||||
self.idx_conv += 1
|
self.idx_conv += 1
|
||||||
@ -92,13 +104,23 @@ class ResStyledBlock(ConvStyledBlock):
|
|||||||
|
|
||||||
See `ConvStyledBlock` for `seq` types.
|
See `ConvStyledBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
|
||||||
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
style_size,
|
||||||
|
in_chan,
|
||||||
|
out_chan=None,
|
||||||
|
mid_chan=None,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
seq="CBACBA",
|
||||||
|
last_act=None,
|
||||||
|
):
|
||||||
if last_act is None:
|
if last_act is None:
|
||||||
last_act = seq[-1] == 'A'
|
last_act = seq[-1] == "A"
|
||||||
elif last_act and seq[-1] != 'A':
|
elif last_act and seq[-1] != "A":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'Disabling last_act without trailing activation in seq',
|
"Disabling last_act without trailing activation in seq",
|
||||||
RuntimeWarning,
|
RuntimeWarning,
|
||||||
)
|
)
|
||||||
last_act = False
|
last_act = False
|
||||||
@ -106,8 +128,15 @@ class ResStyledBlock(ConvStyledBlock):
|
|||||||
if last_act:
|
if last_act:
|
||||||
seq = seq[:-1]
|
seq = seq[:-1]
|
||||||
|
|
||||||
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
super().__init__(
|
||||||
kernel_size=kernel_size, stride=stride, seq=seq)
|
style_size,
|
||||||
|
in_chan,
|
||||||
|
out_chan=out_chan,
|
||||||
|
mid_chan=mid_chan,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
seq=seq,
|
||||||
|
)
|
||||||
|
|
||||||
if last_act:
|
if last_act:
|
||||||
self.act = LeakyReLUStyled()
|
self.act = LeakyReLUStyled()
|
||||||
@ -119,9 +148,10 @@ class ResStyledBlock(ConvStyledBlock):
|
|||||||
else:
|
else:
|
||||||
self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
|
self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
|
||||||
|
|
||||||
if 'U' in seq or 'D' in seq:
|
if "U" in seq or "D" in seq:
|
||||||
raise NotImplementedError('upsample and downsample layers '
|
raise NotImplementedError(
|
||||||
'not supported yet')
|
"upsample and downsample layers " "not supported yet"
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, s):
|
def forward(self, x, s):
|
||||||
y = x
|
y = x
|
||||||
|
@ -15,17 +15,17 @@ class StyledVNet(nn.Module):
|
|||||||
|
|
||||||
# activate non-identity skip connection in residual block
|
# activate non-identity skip connection in residual block
|
||||||
# by explicitly setting out_chan
|
# by explicitly setting out_chan
|
||||||
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA')
|
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq="CACBA")
|
||||||
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
|
self.down_l0 = ConvStyledBlock(style_size, 64, seq="DBA")
|
||||||
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
|
||||||
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
|
self.down_l1 = ConvStyledBlock(style_size, 64, seq="DBA")
|
||||||
|
|
||||||
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
self.conv_c = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
|
||||||
|
|
||||||
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
|
self.up_r1 = ConvStyledBlock(style_size, 64, seq="UBA")
|
||||||
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
|
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq="CBACBA")
|
||||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
self.up_r0 = ConvStyledBlock(style_size, 64, seq="UBA")
|
||||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq="CAC")
|
||||||
|
|
||||||
if bypass is None:
|
if bypass is None:
|
||||||
self.bypass = in_chan == out_chan
|
self.bypass = in_chan == out_chan
|
||||||
|
@ -19,17 +19,17 @@ class UNet(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA')
|
self.conv_l0 = ConvBlock(in_chan, 64, seq="CACBA")
|
||||||
self.down_l0 = ConvBlock(64, seq='DBA')
|
self.down_l0 = ConvBlock(64, seq="DBA")
|
||||||
self.conv_l1 = ConvBlock(64, seq='CBACBA')
|
self.conv_l1 = ConvBlock(64, seq="CBACBA")
|
||||||
self.down_l1 = ConvBlock(64, seq='DBA')
|
self.down_l1 = ConvBlock(64, seq="DBA")
|
||||||
|
|
||||||
self.conv_c = ConvBlock(64, seq='CBACBA')
|
self.conv_c = ConvBlock(64, seq="CBACBA")
|
||||||
|
|
||||||
self.up_r1 = ConvBlock(64, seq='UBA')
|
self.up_r1 = ConvBlock(64, seq="UBA")
|
||||||
self.conv_r1 = ConvBlock(128, 64, seq='CBACBA')
|
self.conv_r1 = ConvBlock(128, 64, seq="CBACBA")
|
||||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
self.up_r0 = ConvBlock(64, seq="UBA")
|
||||||
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
|
self.conv_r0 = ConvBlock(128, out_chan, seq="CAC")
|
||||||
|
|
||||||
self.bypass = in_chan == out_chan
|
self.bypass = in_chan == out_chan
|
||||||
|
|
||||||
|
@ -23,17 +23,17 @@ class VNet(nn.Module):
|
|||||||
|
|
||||||
# activate non-identity skip connection in residual block
|
# activate non-identity skip connection in residual block
|
||||||
# by explicitly setting out_chan
|
# by explicitly setting out_chan
|
||||||
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
|
self.conv_l0 = ResBlock(in_chan, 64, seq="CACBA")
|
||||||
self.down_l0 = ConvBlock(64, seq='DBA')
|
self.down_l0 = ConvBlock(64, seq="DBA")
|
||||||
self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
|
self.conv_l1 = ResBlock(64, 64, seq="CBACBA")
|
||||||
self.down_l1 = ConvBlock(64, seq='DBA')
|
self.down_l1 = ConvBlock(64, seq="DBA")
|
||||||
|
|
||||||
self.conv_c = ResBlock(64, 64, seq='CBACBA')
|
self.conv_c = ResBlock(64, 64, seq="CBACBA")
|
||||||
|
|
||||||
self.up_r1 = ConvBlock(64, seq='UBA')
|
self.up_r1 = ConvBlock(64, seq="UBA")
|
||||||
self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
|
self.conv_r1 = ResBlock(128, 64, seq="CBACBA")
|
||||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
self.up_r0 = ConvBlock(64, seq="UBA")
|
||||||
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
self.conv_r0 = ResBlock(128, out_chan, seq="CAC")
|
||||||
|
|
||||||
if bypass is None:
|
if bypass is None:
|
||||||
self.bypass = in_chan == out_chan
|
self.bypass = in_chan == out_chan
|
||||||
|
@ -16,21 +16,21 @@ from .utils import import_attr, load_model_state_dict
|
|||||||
def test(args):
|
def test(args):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
warnings.warn('Not parallelized but given more than 1 GPUs')
|
warnings.warn("Not parallelized but given more than 1 GPUs")
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
device = torch.device('cuda', 0)
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
else: # CPU multithreading
|
else: # CPU multithreading
|
||||||
device = torch.device('cpu')
|
device = torch.device("cpu")
|
||||||
|
|
||||||
if args.num_threads is None:
|
if args.num_threads is None:
|
||||||
args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE'])
|
args.num_threads = int(os.environ["SLURM_CPUS_ON_NODE"])
|
||||||
|
|
||||||
torch.set_num_threads(args.num_threads)
|
torch.set_num_threads(args.num_threads)
|
||||||
|
|
||||||
print('pytorch {}'.format(torch.__version__))
|
print("pytorch {}".format(torch.__version__))
|
||||||
pprint(vars(args))
|
pprint(vars(args))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
@ -67,26 +67,33 @@ def test(args):
|
|||||||
out_chan = test_dataset.tgt_chan
|
out_chan = test_dataset.tgt_chan
|
||||||
|
|
||||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||||
model = model(style_size, sum(in_chan), sum(out_chan),
|
model = model(
|
||||||
scale_factor=args.scale_factor, **args.misc_kwargs)
|
style_size,
|
||||||
|
sum(in_chan),
|
||||||
|
sum(out_chan),
|
||||||
|
scale_factor=args.scale_factor,
|
||||||
|
**args.misc_kwargs,
|
||||||
|
)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
criterion = import_attr(args.criterion, torch.nn, models,
|
criterion = import_attr(
|
||||||
callback_at=args.callback_at)
|
args.criterion, torch.nn, models, callback_at=args.callback_at
|
||||||
|
)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(device)
|
criterion.to(device)
|
||||||
|
|
||||||
state = torch.load(args.load_state, map_location=device)
|
state = torch.load(args.load_state, map_location=device)
|
||||||
load_model_state_dict(model, state['model'], strict=args.load_state_strict)
|
load_model_state_dict(model, state["model"], strict=args.load_state_strict)
|
||||||
print('model state at epoch {} loaded from {}'.format(
|
print(
|
||||||
state['epoch'], args.load_state))
|
"model state at epoch {} loaded from {}".format(state["epoch"], args.load_state)
|
||||||
|
)
|
||||||
del state
|
del state
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i, data in enumerate(test_loader):
|
for i, data in enumerate(test_loader):
|
||||||
style, input, target = data['style'], data['input'], data['target']
|
style, input, target = data["style"], data["input"], data["target"]
|
||||||
|
|
||||||
style = style.to(device, non_blocking=True)
|
style = style.to(device, non_blocking=True)
|
||||||
input = input.to(device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
@ -94,21 +101,21 @@ def test(args):
|
|||||||
|
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
if i < 5:
|
if i < 5:
|
||||||
print('##### sample :', i)
|
print("##### sample :", i)
|
||||||
print('style shape :', style.shape)
|
print("style shape :", style.shape)
|
||||||
print('input shape :', input.shape)
|
print("input shape :", input.shape)
|
||||||
print('output shape :', output.shape)
|
print("output shape :", output.shape)
|
||||||
print('target shape :', target.shape)
|
print("target shape :", target.shape)
|
||||||
|
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
if i < 5:
|
if i < 5:
|
||||||
print('narrowed shape :', output.shape, flush=True)
|
print("narrowed shape :", output.shape, flush=True)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
|
||||||
print('sample {} loss: {}'.format(i, loss.item()))
|
print("sample {} loss: {}".format(i, loss.item()))
|
||||||
|
|
||||||
#if args.in_norms is not None:
|
# if args.in_norms is not None:
|
||||||
# start = 0
|
# start = 0
|
||||||
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||||
# norm = import_attr(norm, norms, callback_at=args.callback_at)
|
# norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||||
@ -119,12 +126,11 @@ def test(args):
|
|||||||
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
||||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||||
norm(output[:, start:stop], undo=True, **args.misc_kwargs)
|
norm(output[:, start:stop], undo=True, **args.misc_kwargs)
|
||||||
#norm(target[:, start:stop], undo=True, **args.misc_kwargs)
|
# norm(target[:, start:stop], undo=True, **args.misc_kwargs)
|
||||||
start = stop
|
start = stop
|
||||||
|
|
||||||
#test_dataset.assemble('_in', in_chan, input,
|
# test_dataset.assemble('_in', in_chan, input,
|
||||||
# data['input_relpath'])
|
# data['input_relpath'])
|
||||||
test_dataset.assemble('_out', out_chan, output,
|
test_dataset.assemble("_out", out_chan, output, data["target_relpath"])
|
||||||
data['target_relpath'])
|
# test_dataset.assemble('_tgt', out_chan, target,
|
||||||
#test_dataset.assemble('_tgt', out_chan, target,
|
|
||||||
# data['target_relpath'])
|
# data['target_relpath'])
|
||||||
|
269
map2map/train.py
@ -18,34 +18,34 @@ from .models import narrow_cast, resample
|
|||||||
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
|
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
|
||||||
|
|
||||||
|
|
||||||
ckpt_link = 'checkpoint.pt'
|
ckpt_link = "checkpoint.pt"
|
||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
if 'SLURM_STEP_NUM_NODES' in os.environ:
|
if "SLURM_STEP_NUM_NODES" in os.environ:
|
||||||
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
|
args.nodes = int(os.environ["SLURM_STEP_NUM_NODES"])
|
||||||
elif 'SLURM_JOB_NUM_NODES' in os.environ:
|
elif "SLURM_JOB_NUM_NODES" in os.environ:
|
||||||
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
|
args.nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
|
||||||
else:
|
else:
|
||||||
raise KeyError('missing node counts in slurm env')
|
raise KeyError("missing node counts in slurm env")
|
||||||
args.gpus_per_node = torch.cuda.device_count()
|
args.gpus_per_node = torch.cuda.device_count()
|
||||||
args.world_size = args.nodes * args.gpus_per_node
|
args.world_size = args.nodes * args.gpus_per_node
|
||||||
|
|
||||||
node = int(os.environ['SLURM_NODEID'])
|
node = int(os.environ["SLURM_NODEID"])
|
||||||
|
|
||||||
if args.gpus_per_node < 1:
|
if args.gpus_per_node < 1:
|
||||||
raise RuntimeError('GPU not found on node {}'.format(node))
|
raise RuntimeError("GPU not found on node {}".format(node))
|
||||||
|
|
||||||
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
|
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
|
||||||
|
|
||||||
|
|
||||||
def gpu_worker(local_rank, node, args):
|
def gpu_worker(local_rank, node, args):
|
||||||
#device = torch.device('cuda', local_rank)
|
# device = torch.device('cuda', local_rank)
|
||||||
#torch.cuda.device(device) # env var recommended over this
|
# torch.cuda.device(device) # env var recommended over this
|
||||||
|
|
||||||
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
|
||||||
device = torch.device('cuda', 0)
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
rank = args.gpus_per_node * node + local_rank
|
rank = args.gpus_per_node * node + local_rank
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
# Note DDP broadcasts initial model states from rank 0
|
# Note DDP broadcasts initial model states from rank 0
|
||||||
torch.manual_seed(args.seed + rank)
|
torch.manual_seed(args.seed + rank)
|
||||||
# good practice to disable cudnn.benchmark if enabling cudnn.deterministic
|
# good practice to disable cudnn.benchmark if enabling cudnn.deterministic
|
||||||
#torch.backends.cudnn.deterministic = True
|
# torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
dist_init(rank, args)
|
dist_init(rank, args)
|
||||||
|
|
||||||
@ -77,9 +77,12 @@ def gpu_worker(local_rank, node, args):
|
|||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
**args.misc_kwargs,
|
**args.misc_kwargs,
|
||||||
)
|
)
|
||||||
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
train_sampler = DistFieldSampler(
|
||||||
div_data=args.div_data,
|
train_dataset,
|
||||||
div_shuffle_dist=args.div_shuffle_dist)
|
shuffle=True,
|
||||||
|
div_data=args.div_data,
|
||||||
|
div_shuffle_dist=args.div_shuffle_dist,
|
||||||
|
)
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -110,9 +113,12 @@ def gpu_worker(local_rank, node, args):
|
|||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
**args.misc_kwargs,
|
**args.misc_kwargs,
|
||||||
)
|
)
|
||||||
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
val_sampler = DistFieldSampler(
|
||||||
div_data=args.div_data,
|
val_dataset,
|
||||||
div_shuffle_dist=args.div_shuffle_dist)
|
shuffle=False,
|
||||||
|
div_data=args.div_data,
|
||||||
|
div_shuffle_dist=args.div_shuffle_dist,
|
||||||
|
)
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -127,14 +133,19 @@ def gpu_worker(local_rank, node, args):
|
|||||||
args.out_chan = train_dataset.tgt_chan
|
args.out_chan = train_dataset.tgt_chan
|
||||||
|
|
||||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||||
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
|
model = model(
|
||||||
scale_factor=args.scale_factor, **args.misc_kwargs)
|
args.style_size,
|
||||||
|
sum(args.in_chan),
|
||||||
|
sum(args.out_chan),
|
||||||
|
scale_factor=args.scale_factor,
|
||||||
|
**args.misc_kwargs,
|
||||||
|
)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model = DistributedDataParallel(model, device_ids=[device],
|
model = DistributedDataParallel(
|
||||||
process_group=dist.new_group())
|
model, device_ids=[device], process_group=dist.new_group()
|
||||||
|
)
|
||||||
|
|
||||||
criterion = import_attr(args.criterion, nn, models,
|
criterion = import_attr(args.criterion, nn, models, callback_at=args.callback_at)
|
||||||
callback_at=args.callback_at)
|
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(device)
|
criterion.to(device)
|
||||||
|
|
||||||
@ -144,11 +155,13 @@ def gpu_worker(local_rank, node, args):
|
|||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
**args.optimizer_args,
|
**args.optimizer_args,
|
||||||
)
|
)
|
||||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **args.scheduler_args)
|
||||||
optimizer, **args.scheduler_args)
|
|
||||||
|
|
||||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
if (
|
||||||
or not args.load_state):
|
args.load_state == ckpt_link
|
||||||
|
and not os.path.isfile(ckpt_link)
|
||||||
|
or not args.load_state
|
||||||
|
):
|
||||||
if args.init_weight_std is not None:
|
if args.init_weight_std is not None:
|
||||||
model.apply(init_weights)
|
model.apply(init_weights)
|
||||||
|
|
||||||
@ -159,23 +172,28 @@ def gpu_worker(local_rank, node, args):
|
|||||||
else:
|
else:
|
||||||
state = torch.load(args.load_state, map_location=device)
|
state = torch.load(args.load_state, map_location=device)
|
||||||
|
|
||||||
start_epoch = state['epoch']
|
start_epoch = state["epoch"]
|
||||||
|
|
||||||
load_model_state_dict(model.module, state['model'],
|
load_model_state_dict(
|
||||||
strict=args.load_state_strict)
|
model.module, state["model"], strict=args.load_state_strict
|
||||||
|
)
|
||||||
|
|
||||||
if 'optimizer' in state:
|
if "optimizer" in state:
|
||||||
optimizer.load_state_dict(state['optimizer'])
|
optimizer.load_state_dict(state["optimizer"])
|
||||||
if 'scheduler' in state:
|
if "scheduler" in state:
|
||||||
scheduler.load_state_dict(state['scheduler'])
|
scheduler.load_state_dict(state["scheduler"])
|
||||||
|
|
||||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state["rng"].cpu()) # move rng state back
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
min_loss = state['min_loss']
|
min_loss = state["min_loss"]
|
||||||
|
|
||||||
print('state at epoch {} loaded from {}'.format(
|
print(
|
||||||
state['epoch'], args.load_state), flush=True)
|
"state at epoch {} loaded from {}".format(
|
||||||
|
state["epoch"], args.load_state
|
||||||
|
),
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
del state
|
del state
|
||||||
|
|
||||||
@ -189,21 +207,31 @@ def gpu_worker(local_rank, node, args):
|
|||||||
logger = SummaryWriter()
|
logger = SummaryWriter()
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print('pytorch {}'.format(torch.__version__))
|
print("pytorch {}".format(torch.__version__))
|
||||||
pprint(vars(args))
|
pprint(vars(args))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
train_loss = train(epoch, train_loader, model, criterion,
|
train_loss = train(
|
||||||
optimizer, scheduler, logger, device, args)
|
epoch,
|
||||||
|
train_loader,
|
||||||
|
model,
|
||||||
|
criterion,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
logger,
|
||||||
|
device,
|
||||||
|
args,
|
||||||
|
)
|
||||||
epoch_loss = train_loss
|
epoch_loss = train_loss
|
||||||
|
|
||||||
if args.val:
|
if args.val:
|
||||||
val_loss = validate(epoch, val_loader, model, criterion,
|
val_loss = validate(
|
||||||
logger, device, args)
|
epoch, val_loader, model, criterion, logger, device, args
|
||||||
#epoch_loss = val_loss
|
)
|
||||||
|
# epoch_loss = val_loss
|
||||||
|
|
||||||
if args.reduce_lr_on_plateau:
|
if args.reduce_lr_on_plateau:
|
||||||
scheduler.step(epoch_loss[2])
|
scheduler.step(epoch_loss[2])
|
||||||
@ -215,27 +243,26 @@ def gpu_worker(local_rank, node, args):
|
|||||||
min_loss = epoch_loss[2]
|
min_loss = epoch_loss[2]
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
'epoch': epoch + 1,
|
"epoch": epoch + 1,
|
||||||
'model': model.module.state_dict(),
|
"model": model.module.state_dict(),
|
||||||
'optimizer': optimizer.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
'scheduler': scheduler.state_dict(),
|
"scheduler": scheduler.state_dict(),
|
||||||
'rng': torch.get_rng_state(),
|
"rng": torch.get_rng_state(),
|
||||||
'min_loss': min_loss,
|
"min_loss": min_loss,
|
||||||
}
|
}
|
||||||
|
|
||||||
state_file = 'state_{}.pt'.format(epoch + 1)
|
state_file = "state_{}.pt".format(epoch + 1)
|
||||||
torch.save(state, state_file)
|
torch.save(state, state_file)
|
||||||
del state
|
del state
|
||||||
|
|
||||||
tmp_link = '{}.pt'.format(time.time())
|
tmp_link = "{}.pt".format(time.time())
|
||||||
os.symlink(state_file, tmp_link) # workaround to overwrite
|
os.symlink(state_file, tmp_link) # workaround to overwrite
|
||||||
os.rename(tmp_link, ckpt_link)
|
os.rename(tmp_link, ckpt_link)
|
||||||
|
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, loader, model, criterion,
|
def train(epoch, loader, model, criterion, optimizer, scheduler, logger, device, args):
|
||||||
optimizer, scheduler, logger, device, args):
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
@ -246,7 +273,7 @@ def train(epoch, loader, model, criterion,
|
|||||||
for i, data in enumerate(loader):
|
for i, data in enumerate(loader):
|
||||||
batch = epoch * len(loader) + i + 1
|
batch = epoch * len(loader) + i + 1
|
||||||
|
|
||||||
style, input, target = data['style'], data['input'], data['target']
|
style, input, target = data["style"], data["input"], data["target"]
|
||||||
|
|
||||||
style = style.to(device, non_blocking=True)
|
style = style.to(device, non_blocking=True)
|
||||||
input = input.to(device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
@ -254,23 +281,21 @@ def train(epoch, loader, model, criterion,
|
|||||||
|
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
if batch <= 5 and rank == 0:
|
if batch <= 5 and rank == 0:
|
||||||
print('##### batch :', batch)
|
print("##### batch :", batch)
|
||||||
print('style shape :', style.shape)
|
print("style shape :", style.shape)
|
||||||
print('input shape :', input.shape)
|
print("input shape :", input.shape)
|
||||||
print('output shape :', output.shape)
|
print("output shape :", output.shape)
|
||||||
print('target shape :', target.shape)
|
print("target shape :", target.shape)
|
||||||
|
|
||||||
if (hasattr(model.module, 'scale_factor')
|
if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
|
||||||
and model.module.scale_factor != 1):
|
|
||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
if batch <= 5 and rank == 0:
|
if batch <= 5 and rank == 0:
|
||||||
print('narrowed shape :', output.shape)
|
print("narrowed shape :", output.shape)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.detach()
|
epoch_loss[0] += loss.detach()
|
||||||
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@ -280,43 +305,45 @@ def train(epoch, loader, model, criterion,
|
|||||||
dist.all_reduce(loss)
|
dist.all_reduce(loss)
|
||||||
loss /= world_size
|
loss /= world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/batch/train', loss.item(),
|
logger.add_scalar("loss/batch/train", loss.item(), global_step=batch)
|
||||||
global_step=batch)
|
|
||||||
|
|
||||||
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
logger.add_scalar("grad/first", grads[0], global_step=batch)
|
||||||
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
logger.add_scalar("grad/last", grads[-1], global_step=batch)
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
logger.add_scalar("loss/epoch/train", epoch_loss[0], global_step=epoch + 1)
|
||||||
global_step=epoch+1)
|
|
||||||
|
|
||||||
skip_chan = 0
|
skip_chan = 0
|
||||||
fig = plt_slices(
|
fig = plt_slices(
|
||||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
input[-1],
|
||||||
|
output[-1, skip_chan:],
|
||||||
|
target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
title=["in", "out", "tgt", "out - tgt"],
|
||||||
**args.misc_kwargs,
|
**args.misc_kwargs,
|
||||||
)
|
)
|
||||||
logger.add_figure('fig/train', fig, global_step=epoch+1)
|
logger.add_figure("fig/train", fig, global_step=epoch + 1)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
fig = plt_power(
|
fig = plt_power(
|
||||||
input, output[:, skip_chan:], target[:, skip_chan:],
|
input,
|
||||||
label=['in', 'out', 'tgt'],
|
output[:, skip_chan:],
|
||||||
|
target[:, skip_chan:],
|
||||||
|
label=["in", "out", "tgt"],
|
||||||
**args.misc_kwargs,
|
**args.misc_kwargs,
|
||||||
)
|
)
|
||||||
logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
logger.add_figure("fig/train/power/lag", fig, global_step=epoch + 1)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
#fig = plt_power(1.0,
|
# fig = plt_power(1.0,
|
||||||
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
|
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
|
||||||
# label=['in', 'out', 'tgt'],
|
# label=['in', 'out', 'tgt'],
|
||||||
# **args.misc_kwargs,
|
# **args.misc_kwargs,
|
||||||
#)
|
# )
|
||||||
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
# logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||||
#fig.clf()
|
# fig.clf()
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
@ -331,7 +358,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data in loader:
|
for data in loader:
|
||||||
style, input, target = data['style'], data['input'], data['target']
|
style, input, target = data["style"], data["input"], data["target"]
|
||||||
|
|
||||||
style = style.to(device, non_blocking=True)
|
style = style.to(device, non_blocking=True)
|
||||||
input = input.to(device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
@ -339,8 +366,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
|
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
|
|
||||||
if (hasattr(model.module, 'scale_factor')
|
if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
|
||||||
and model.module.scale_factor != 1):
|
|
||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
@ -350,40 +376,43 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
logger.add_scalar("loss/epoch/val", epoch_loss[0], global_step=epoch + 1)
|
||||||
global_step=epoch+1)
|
|
||||||
skip_chan = 0
|
skip_chan = 0
|
||||||
|
|
||||||
fig = plt_slices(
|
fig = plt_slices(
|
||||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
input[-1],
|
||||||
|
output[-1, skip_chan:],
|
||||||
|
target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
title=["in", "out", "tgt", "out - tgt"],
|
||||||
**args.misc_kwargs,
|
**args.misc_kwargs,
|
||||||
)
|
)
|
||||||
logger.add_figure('fig/val', fig, global_step=epoch+1)
|
logger.add_figure("fig/val", fig, global_step=epoch + 1)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
fig = plt_power(
|
fig = plt_power(
|
||||||
input, output[:, skip_chan:], target[:, skip_chan:],
|
input,
|
||||||
label=['in', 'out', 'tgt'],
|
output[:, skip_chan:],
|
||||||
|
target[:, skip_chan:],
|
||||||
|
label=["in", "out", "tgt"],
|
||||||
**args.misc_kwargs,
|
**args.misc_kwargs,
|
||||||
)
|
)
|
||||||
logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
logger.add_figure("fig/val/power/lag", fig, global_step=epoch + 1)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
#fig = plt_power(1.0,
|
# fig = plt_power(1.0,
|
||||||
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
|
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
|
||||||
# label=['in', 'out', 'tgt'],
|
# label=['in', 'out', 'tgt'],
|
||||||
# **args.misc_kwargs,
|
# **args.misc_kwargs,
|
||||||
#)
|
# )
|
||||||
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
# logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||||
#fig.clf()
|
# fig.clf()
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
def dist_init(rank, args):
|
def dist_init(rank, args):
|
||||||
dist_file = 'dist_addr'
|
dist_file = "dist_addr"
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
addr = socket.gethostname()
|
addr = socket.gethostname()
|
||||||
@ -393,15 +422,15 @@ def dist_init(rank, args):
|
|||||||
s.bind((addr, 0))
|
s.bind((addr, 0))
|
||||||
_, port = s.getsockname()
|
_, port = s.getsockname()
|
||||||
|
|
||||||
args.dist_addr = 'tcp://{}:{}'.format(addr, port)
|
args.dist_addr = "tcp://{}:{}".format(addr, port)
|
||||||
|
|
||||||
with open(dist_file, mode='w') as f:
|
with open(dist_file, mode="w") as f:
|
||||||
f.write(args.dist_addr)
|
f.write(args.dist_addr)
|
||||||
else:
|
else:
|
||||||
while not os.path.exists(dist_file):
|
while not os.path.exists(dist_file):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
with open(dist_file, mode='r') as f:
|
with open(dist_file, mode="r") as f:
|
||||||
args.dist_addr = f.read()
|
args.dist_addr = f.read()
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
@ -417,19 +446,40 @@ def dist_init(rank, args):
|
|||||||
|
|
||||||
|
|
||||||
def init_weights(m, args):
|
def init_weights(m, args):
|
||||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
if isinstance(
|
||||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
m,
|
||||||
|
(
|
||||||
|
nn.Linear,
|
||||||
|
nn.Conv1d,
|
||||||
|
nn.Conv2d,
|
||||||
|
nn.Conv3d,
|
||||||
|
nn.ConvTranspose1d,
|
||||||
|
nn.ConvTranspose2d,
|
||||||
|
nn.ConvTranspose3d,
|
||||||
|
),
|
||||||
|
):
|
||||||
m.weight.data.normal_(0.0, args.init_weight_std)
|
m.weight.data.normal_(0.0, args.init_weight_std)
|
||||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
elif isinstance(
|
||||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
m,
|
||||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
(
|
||||||
|
nn.BatchNorm1d,
|
||||||
|
nn.BatchNorm2d,
|
||||||
|
nn.BatchNorm3d,
|
||||||
|
nn.SyncBatchNorm,
|
||||||
|
nn.LayerNorm,
|
||||||
|
nn.GroupNorm,
|
||||||
|
nn.InstanceNorm1d,
|
||||||
|
nn.InstanceNorm2d,
|
||||||
|
nn.InstanceNorm3d,
|
||||||
|
),
|
||||||
|
):
|
||||||
if m.affine:
|
if m.affine:
|
||||||
# NOTE: dispersion from DCGAN, why?
|
# NOTE: dispersion from DCGAN, why?
|
||||||
m.weight.data.normal_(1.0, args.init_weight_std)
|
m.weight.data.normal_(1.0, args.init_weight_std)
|
||||||
m.bias.data.fill_(0)
|
m.bias.data.fill_(0)
|
||||||
|
|
||||||
|
|
||||||
def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False):
|
def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False):
|
||||||
for param in module.parameters():
|
for param in module.parameters():
|
||||||
param.requires_grad = requires_grad
|
param.requires_grad = requires_grad
|
||||||
|
|
||||||
@@ -444,8 +494,7 @@ def get_grads(model: torch.nn.Module):
     Returns:
         A list containing the norms of the gradients of the first and the last layer weights.
     """
-    grads = list(p.grad for n, p in model.named_parameters()
-                 if '.weight' in n)
+    grads = list(p.grad for n, p in model.named_parameters() if ".weight" in n)
     grads = [grads[0], grads[-1]]
     grads = [g.detach().norm() for g in grads]
     return grads
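
get_grads reduces the gradients of the first and last weight tensors to two norms, a cheap per-step check for vanishing or exploding gradients. A hedged sketch of logging them, assuming get_grads is in scope; the toy model, data, and TensorBoard tags are illustrative:

    import torch
    from torch import nn
    from torch.utils.tensorboard import SummaryWriter

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    logger = SummaryWriter("runs/example")  # hypothetical log directory

    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients must exist before probing them

    first_norm, last_norm = get_grads(model)  # norms of first/last .weight grads
    logger.add_scalar("grad/first", first_norm, global_step=0)
    logger.add_scalar("grad/last", last_norm, global_step=0)
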
@ -2,11 +2,13 @@ from math import log2, log10, ceil
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Agg')
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
||||||
from matplotlib.cm import ScalarMappable
|
from matplotlib.cm import ScalarMappable
|
||||||
plt.rc('text', usetex=False)
|
|
||||||
|
plt.rc("text", usetex=False)
|
||||||
|
|
||||||
from ..models import lag2eul, power
|
from ..models import lag2eul, power
|
||||||
|
|
||||||
@@ -21,7 +23,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
     Each field should have a channel dimension followed by spatial dimensions,
     i.e. no batch dimension.
     """
-    plt.close('all')
+    plt.close("all")

     assert all(isinstance(field, torch.Tensor) for field in fields)

@@ -38,11 +40,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
     im_size = 2
     cbar_height = 0.2
     fig, axes = plt.subplots(
-        nc + 1, nf,
+        nc + 1,
+        nf,
         squeeze=False,
         figsize=(nf * im_size, nc * im_size + cbar_height),
         dpi=100,
-        gridspec_kw={'height_ratios': nc * [im_size] + [cbar_height]}
+        gridspec_kw={"height_ratios": nc * [im_size] + [cbar_height]},
     )

     for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)):
@@ -51,11 +54,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):

         if cmap_col is None:
             if all_non_neg:
-                cmap_col = 'inferno'
+                cmap_col = "inferno"
             elif all_non_pos:
-                cmap_col = 'inferno_r'
+                cmap_col = "inferno_r"
             else:
-                cmap_col = 'RdBu_r'
+                cmap_col = "RdBu_r"

         if norm_col is None:
             l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
@@ -70,9 +73,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
                 if l1 < 0.1 * l2 or h2 == 0:
                     norm_col = Normalize(vmin=-quantize(-l2), vmax=0)
                 else:
-                    norm_col = SymLogNorm(linthresh=quantize(-h2),
-                                          vmin=-quantize(-l2),
-                                          vmax=-quantize(-h2))
+                    norm_col = SymLogNorm(
+                        linthresh=quantize(-h2),
+                        vmin=-quantize(-l2),
+                        vmax=-quantize(-h2),
+                    )
             else:
                 vlim = quantize(max(-l2, h2))
                 if w1 > 0.1 * w2 or l1 * h1 >= 0:
@@ -80,8 +85,13 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
                 else:
                     linthresh = quantize(min(-l1, h1))
                     linscale = np.log10(vlim / linthresh)
-                    norm_col = SymLogNorm(linthresh=linthresh, linscale=linscale,
-                                          vmin=-vlim, vmax=vlim, base=10)
+                    norm_col = SymLogNorm(
+                        linthresh=linthresh,
+                        linscale=linscale,
+                        vmin=-vlim,
+                        vmax=vlim,
+                        base=10,
+                    )

         for c in range(field.shape[0]):
             s = (c,) + tuple(d // 2 for d in field.shape[1:-2])
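
The two hunks above only re-wrap long calls; plt_slices still derives each column's colour normalization from the 2.5/16/84/97.5 percentiles of the field, switching to a symmetric log scale when the tails are much wider than the core. A simplified, self-contained sketch of that idea; the pick_norm helper and its thresholds are illustrative and deliberately omit the quantize rounding used in the real code:

    import numpy as np
    from matplotlib.colors import Normalize, SymLogNorm


    def pick_norm(field):
        """Choose a colour normalization from percentiles (simplified)."""
        l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
        vlim = max(abs(l2), abs(h2))
        if l1 * h1 >= 0 or vlim == 0:
            # effectively one-signed or flat: a linear scale is enough
            return Normalize(vmin=-vlim, vmax=vlim)
        linthresh = min(abs(l1), abs(h1))  # linear core around zero
        linscale = max(np.log10(vlim / linthresh), 1.0)  # at least one decade
        return SymLogNorm(linthresh=linthresh, linscale=linscale,
                          vmin=-vlim, vmax=vlim, base=10)


    field = np.random.randn(64, 64) * np.random.lognormal(size=(64, 64))
    norm = pick_norm(field)  # pass to pcolormesh(..., norm=norm)
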
@@ -101,7 +111,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):

             axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col)

-            axes[c, f].set_aspect('equal')
+            axes[c, f].set_aspect("equal")

             axes[c, f].set_xticks([])
             axes[c, f].set_yticks([])
@@ -110,12 +120,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
                 axes[c, f].set_title(title[f])

         for c in range(field.shape[0], nc):
-            axes[c, f].axis('off')
+            axes[c, f].axis("off")

         fig.colorbar(
             ScalarMappable(norm=norm_col, cmap=cmap_col),
             cax=axes[-1, f],
-            orientation='horizontal',
+            orientation="horizontal",
         )

     fig.tight_layout()
@@ -133,7 +143,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):

     See `map2map.models.power`.
     """
-    plt.close('all')
+    plt.close("all")

     if label is not None:
         assert len(label) == len(fields) or len(label) == len(dis)
@ -159,8 +169,8 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
|
|||||||
axes.loglog(k, P, label=l, alpha=0.7)
|
axes.loglog(k, P, label=l, alpha=0.7)
|
||||||
|
|
||||||
axes.legend()
|
axes.legend()
|
||||||
axes.set_xlabel('unnormalized wavenumber')
|
axes.set_xlabel("unnormalized wavenumber")
|
||||||
axes.set_ylabel('unnormalized power')
|
axes.set_ylabel("unnormalized power")
|
||||||
|
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
|
|
||||||
|
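
plt_slices and plt_power are the validation-time diagnostics that end up in TensorBoard, as in the logger.add_figure call commented out near the top of this diff. A hedged usage sketch; the tensor shapes, the tag name, and the assumption that plt_slices returns its figure (as plt_power appears to) are illustrative:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    # Hypothetical fields: (channel, x, y, z), no batch dimension,
    # as the plt_slices docstring requires.
    input = torch.randn(3, 32, 32, 32)
    output = torch.randn(3, 32, 32, 32)
    target = torch.randn(3, 32, 32, 32)

    logger = SummaryWriter("runs/example")
    fig = plt_slices(input, output, target, title=["in", "out", "tgt"])
    logger.add_figure("fig/val/slices", fig, global_step=1)
    fig.clf()
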
@@ -21,7 +21,7 @@ def import_attr(name, *pkgs, callback_at=None):
     first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally
     from 'path/to/cb_dir/mod.py'.
     """
-    if name.count('.') == 0:
+    if name.count(".") == 0:
         attr = name

         errors = []
@ -34,23 +34,22 @@ def import_attr(name, *pkgs, callback_at=None):
|
|||||||
|
|
||||||
raise Exception(errors)
|
raise Exception(errors)
|
||||||
else:
|
else:
|
||||||
mod, attr = name.rsplit('.', 1)
|
mod, attr = name.rsplit(".", 1)
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
for pkg in pkgs:
|
for pkg in pkgs:
|
||||||
try:
|
try:
|
||||||
return getattr(
|
return getattr(importlib.import_module(pkg.__name__ + "." + mod), attr)
|
||||||
importlib.import_module(pkg.__name__ + '.' + mod), attr)
|
|
||||||
except (ModuleNotFoundError, AttributeError) as e:
|
except (ModuleNotFoundError, AttributeError) as e:
|
||||||
errors.append(e)
|
errors.append(e)
|
||||||
|
|
||||||
if callback_at is None:
|
if callback_at is None:
|
||||||
raise Exception(errors)
|
raise Exception(errors)
|
||||||
|
|
||||||
callback_at = os.path.join(callback_at, mod + '.py')
|
callback_at = os.path.join(callback_at, mod + ".py")
|
||||||
if not os.path.isfile(callback_at):
|
if not os.path.isfile(callback_at):
|
||||||
raise FileNotFoundError('callback file not found')
|
raise FileNotFoundError("callback file not found")
|
||||||
|
|
||||||
if mod in sys.modules:
|
if mod in sys.modules:
|
||||||
return getattr(sys.modules[mod], attr)
|
return getattr(sys.modules[mod], attr)
|
||||||
|
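
import_attr resolves a dotted name against each given package in turn and only then against a user callback directory, which is how user-supplied models, norms, and criteria are typically plugged in. A hedged sketch of the package path, assuming import_attr is in scope; the callback-directory call is shown only as a comment because that file is not part of this commit:

    from torch import nn  # any imported package can serve as a search root

    # Dotted name: import_attr imports <pkg>.<module> and returns the
    # attribute, so this resolves to torch.nn.modules.Linear.
    linear_cls = import_attr("modules.Linear", nn)

    # With callback_at, a failed package lookup falls back to
    # <callback_at>/<module>.py, e.g. a user-defined model:
    # net_cls = import_attr("my_model.MyNet", nn, callback_at="path/to/callbacks")
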
@@ -7,9 +7,13 @@ def load_model_state_dict(module, state_dict, strict=True):
     bad_keys = module.load_state_dict(state_dict, strict)

     if len(bad_keys.missing_keys) > 0:
-        warnings.warn('Missing keys in state_dict:\n{}'.format(
-            pformat(bad_keys.missing_keys)))
+        warnings.warn(
+            "Missing keys in state_dict:\n{}".format(pformat(bad_keys.missing_keys))
+        )
     if len(bad_keys.unexpected_keys) > 0:
-        warnings.warn('Unexpected keys in state_dict:\n{}'.format(
-            pformat(bad_keys.unexpected_keys)))
+        warnings.warn(
+            "Unexpected keys in state_dict:\n{}".format(
+                pformat(bad_keys.unexpected_keys)
+            )
+        )
     sys.stderr.flush()