chore: Blackify the source code for pretty formatting
This commit is contained in:
parent
ec9732df21
commit
7361c3eb9d
27 changed files with 908 additions and 568 deletions
397
map2map/args.py
397
map2map/args.py
|
@ -7,18 +7,16 @@ from .train import ckpt_link
|
|||
|
||||
|
||||
def get_args():
|
||||
"""Parse arguments and set runtime defaults.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Transform field(s) to field(s)')
|
||||
"""Parse arguments and set runtime defaults."""
|
||||
parser = argparse.ArgumentParser(description="Transform field(s) to field(s)")
|
||||
|
||||
subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
|
||||
subparsers = parser.add_subparsers(title="modes", dest="mode", required=True)
|
||||
train_parser = subparsers.add_parser(
|
||||
'train',
|
||||
"train",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
test_parser = subparsers.add_parser(
|
||||
'test',
|
||||
"test",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
|
@ -27,163 +25,293 @@ def get_args():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == 'train':
|
||||
if args.mode == "train":
|
||||
set_train_args(args)
|
||||
elif args.mode == 'test':
|
||||
elif args.mode == "test":
|
||||
set_test_args(args)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def add_common_args(parser):
|
||||
parser.add_argument('--in-norms', type=str_list, help='comma-sep. list '
|
||||
'of input normalization functions')
|
||||
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
||||
'of target normalization functions')
|
||||
parser.add_argument('--crop', type=int_tuple,
|
||||
help='size to crop the input and target data. Default is the '
|
||||
'field size. Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--crop-start', type=int_tuple,
|
||||
help='starting point of the first crop. Default is the origin. '
|
||||
'Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--crop-stop', type=int_tuple,
|
||||
help='stopping point of the last crop. Default is the opposite '
|
||||
'corner to the origin. Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--crop-step', type=int_tuple,
|
||||
help='spacing between crops. Default is the crop size. '
|
||||
'Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple,
|
||||
help='size to pad the input data beyond the crop size, assuming '
|
||||
'periodic boundary condition. Comma-sep. list of 1, d, or dx2 '
|
||||
'integers, to pad equally along all axes, symmetrically on each, '
|
||||
'or by the specified size on every boundary, respectively')
|
||||
parser.add_argument('--tgt-pad', default=0, type=int_tuple,
|
||||
help='size to pad the target data beyond the crop size, assuming '
|
||||
'periodic boundary condition, useful for super-resolution. '
|
||||
'Comma-sep. list with the same format as --in-pad')
|
||||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='upsampling factor for super-resolution, in which case '
|
||||
'crop and pad are sizes of the input resolution')
|
||||
parser.add_argument(
|
||||
"--in-norms",
|
||||
type=str_list,
|
||||
help="comma-sep. list " "of input normalization functions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tgt-norms",
|
||||
type=str_list,
|
||||
help="comma-sep. list " "of target normalization functions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop",
|
||||
type=int_tuple,
|
||||
help="size to crop the input and target data. Default is the "
|
||||
"field size. Comma-sep. list of 1 or d integers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-start",
|
||||
type=int_tuple,
|
||||
help="starting point of the first crop. Default is the origin. "
|
||||
"Comma-sep. list of 1 or d integers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-stop",
|
||||
type=int_tuple,
|
||||
help="stopping point of the last crop. Default is the opposite "
|
||||
"corner to the origin. Comma-sep. list of 1 or d integers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-step",
|
||||
type=int_tuple,
|
||||
help="spacing between crops. Default is the crop size. "
|
||||
"Comma-sep. list of 1 or d integers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--in-pad",
|
||||
"--pad",
|
||||
default=0,
|
||||
type=int_tuple,
|
||||
help="size to pad the input data beyond the crop size, assuming "
|
||||
"periodic boundary condition. Comma-sep. list of 1, d, or dx2 "
|
||||
"integers, to pad equally along all axes, symmetrically on each, "
|
||||
"or by the specified size on every boundary, respectively",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tgt-pad",
|
||||
default=0,
|
||||
type=int_tuple,
|
||||
help="size to pad the target data beyond the crop size, assuming "
|
||||
"periodic boundary condition, useful for super-resolution. "
|
||||
"Comma-sep. list with the same format as --in-pad",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale-factor",
|
||||
default=1,
|
||||
type=int,
|
||||
help="upsampling factor for super-resolution, in which case "
|
||||
"crop and pad are sizes of the input resolution",
|
||||
)
|
||||
|
||||
parser.add_argument('--model', type=str, required=True,
|
||||
help='(generator) model')
|
||||
parser.add_argument('--criterion', default='MSELoss', type=str,
|
||||
help='loss function')
|
||||
parser.add_argument('--load-state', default=ckpt_link, type=str,
|
||||
help='path to load the states of model, optimizer, rng, etc. '
|
||||
'Default is the checkpoint. '
|
||||
'Start from scratch in case of empty string or missing checkpoint')
|
||||
parser.add_argument('--load-state-non-strict', action='store_false',
|
||||
help='allow incompatible keys when loading model states',
|
||||
dest='load_state_strict')
|
||||
parser.add_argument("--model", type=str, required=True, help="(generator) model")
|
||||
parser.add_argument(
|
||||
"--criterion", default="MSELoss", type=str, help="loss function"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-state",
|
||||
default=ckpt_link,
|
||||
type=str,
|
||||
help="path to load the states of model, optimizer, rng, etc. "
|
||||
"Default is the checkpoint. "
|
||||
"Start from scratch in case of empty string or missing checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-state-non-strict",
|
||||
action="store_false",
|
||||
help="allow incompatible keys when loading model states",
|
||||
dest="load_state_strict",
|
||||
)
|
||||
|
||||
# somehow I named it "batches" instead of batch_size at first
|
||||
# "batches" is kept for now for backward compatibility
|
||||
parser.add_argument('--batch-size', '--batches', type=int, required=True,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
parser.add_argument('--loader-workers', default=8, type=int,
|
||||
help='number of subprocesses per data loader. '
|
||||
'0 to disable multiprocessing')
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
"--batches",
|
||||
type=int,
|
||||
required=True,
|
||||
help="mini-batch size, per GPU in training or in total in testing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loader-workers",
|
||||
default=8,
|
||||
type=int,
|
||||
help="number of subprocesses per data loader. " "0 to disable multiprocessing",
|
||||
)
|
||||
|
||||
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
|
||||
help='directory of custorm code defining callbacks for models, '
|
||||
'norms, criteria, and optimizers. Disabled if not set. '
|
||||
'This is appended to the default locations, '
|
||||
'thus has the lowest priority')
|
||||
parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
|
||||
help='miscellaneous keyword arguments for custom models and '
|
||||
'norms. Be careful with name collisions')
|
||||
parser.add_argument(
|
||||
"--callback-at",
|
||||
type=lambda s: os.path.abspath(s),
|
||||
help="directory of custorm code defining callbacks for models, "
|
||||
"norms, criteria, and optimizers. Disabled if not set. "
|
||||
"This is appended to the default locations, "
|
||||
"thus has the lowest priority",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--misc-kwargs",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="miscellaneous keyword arguments for custom models and "
|
||||
"norms. Be careful with name collisions",
|
||||
)
|
||||
|
||||
|
||||
def add_train_args(parser):
|
||||
add_common_args(parser)
|
||||
|
||||
parser.add_argument('--train-style-pattern', type=str, required=True,
|
||||
help='glob pattern for training data styles')
|
||||
parser.add_argument('--train-in-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training input data')
|
||||
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training target data')
|
||||
parser.add_argument('--val-style-pattern', type=str,
|
||||
help='glob pattern for validation data styles')
|
||||
parser.add_argument('--val-in-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation input data')
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation target data')
|
||||
parser.add_argument('--augment', action='store_true',
|
||||
help='enable data augmentation of axis flipping and permutation')
|
||||
parser.add_argument('--aug-shift', type=int_tuple,
|
||||
help='data augmentation by shifting cropping by [0, aug_shift) pixels, '
|
||||
'useful for models that treat neighboring pixels differently, '
|
||||
'e.g. with strided convolutions. '
|
||||
'Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--aug-add', type=float,
|
||||
help='additive data augmentation, (normal) std, '
|
||||
'same factor for all fields')
|
||||
parser.add_argument('--aug-mul', type=float,
|
||||
help='multiplicative data augmentation, (log-normal) std, '
|
||||
'same factor for all fields')
|
||||
parser.add_argument(
|
||||
"--train-style-pattern",
|
||||
type=str,
|
||||
required=True,
|
||||
help="glob pattern for training data styles",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-in-patterns",
|
||||
type=str_list,
|
||||
required=True,
|
||||
help="comma-sep. list of glob patterns for training input data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-tgt-patterns",
|
||||
type=str_list,
|
||||
required=True,
|
||||
help="comma-sep. list of glob patterns for training target data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val-style-pattern", type=str, help="glob pattern for validation data styles"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val-in-patterns",
|
||||
type=str_list,
|
||||
help="comma-sep. list of glob patterns for validation input data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val-tgt-patterns",
|
||||
type=str_list,
|
||||
help="comma-sep. list of glob patterns for validation target data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--augment",
|
||||
action="store_true",
|
||||
help="enable data augmentation of axis flipping and permutation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aug-shift",
|
||||
type=int_tuple,
|
||||
help="data augmentation by shifting cropping by [0, aug_shift) pixels, "
|
||||
"useful for models that treat neighboring pixels differently, "
|
||||
"e.g. with strided convolutions. "
|
||||
"Comma-sep. list of 1 or d integers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aug-add",
|
||||
type=float,
|
||||
help="additive data augmentation, (normal) std, " "same factor for all fields",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aug-mul",
|
||||
type=float,
|
||||
help="multiplicative data augmentation, (log-normal) std, "
|
||||
"same factor for all fields",
|
||||
)
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimization algorithm')
|
||||
parser.add_argument('--lr', type=float, required=True,
|
||||
help='initial learning rate')
|
||||
parser.add_argument('--optimizer-args', default='{}', type=json.loads,
|
||||
help='optimizer arguments in addition to the learning rate, '
|
||||
'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
|
||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
||||
parser.add_argument('--scheduler-args', default='{"verbose": true}',
|
||||
type=json.loads,
|
||||
help='arguments for the ReduceLROnPlateau scheduler')
|
||||
parser.add_argument('--init-weight-std', type=float,
|
||||
help='weight initialization std')
|
||||
parser.add_argument('--epochs', default=128, type=int,
|
||||
help='total number of epochs to run')
|
||||
parser.add_argument('--seed', default=42, type=int,
|
||||
help='seed for initializing training')
|
||||
parser.add_argument(
|
||||
"--optimizer", default="Adam", type=str, help="optimization algorithm"
|
||||
)
|
||||
parser.add_argument("--lr", type=float, required=True, help="initial learning rate")
|
||||
parser.add_argument(
|
||||
"--optimizer-args",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="optimizer arguments in addition to the learning rate, "
|
||||
"e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reduce-lr-on-plateau",
|
||||
action="store_true",
|
||||
help="Enable ReduceLROnPlateau learning rate scheduler",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler-args",
|
||||
default='{"verbose": true}',
|
||||
type=json.loads,
|
||||
help="arguments for the ReduceLROnPlateau scheduler",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-weight-std", type=float, help="weight initialization std"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs", default=128, type=int, help="total number of epochs to run"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", default=42, type=int, help="seed for initializing training"
|
||||
)
|
||||
|
||||
parser.add_argument('--div-data', action='store_true',
|
||||
help='enable data division among GPUs for better page caching. '
|
||||
'Data division is shuffled every epoch. '
|
||||
'Only relevant if there are multiple crops in each field')
|
||||
parser.add_argument('--div-shuffle-dist', default=1, type=float,
|
||||
help='distance to further shuffle cropped samples relative to '
|
||||
'their fields, to be used with --div-data. '
|
||||
'Only relevant if there are multiple crops in each file. '
|
||||
'The order of each sample is randomly displaced by this value. '
|
||||
'Setting it to 0 turn off this randomization, and setting it to N '
|
||||
'limits the shuffling within a distance of N files. '
|
||||
'Change this to balance cache locality and stochasticity')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
choices=['gloo', 'nccl'], help='distributed backend')
|
||||
parser.add_argument('--log-interval', default=100, type=int,
|
||||
help='interval (batches) between logging training loss')
|
||||
parser.add_argument('--detect-anomaly', action='store_true',
|
||||
help='enable anomaly detection for the autograd engine')
|
||||
parser.add_argument(
|
||||
"--div-data",
|
||||
action="store_true",
|
||||
help="enable data division among GPUs for better page caching. "
|
||||
"Data division is shuffled every epoch. "
|
||||
"Only relevant if there are multiple crops in each field",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--div-shuffle-dist",
|
||||
default=1,
|
||||
type=float,
|
||||
help="distance to further shuffle cropped samples relative to "
|
||||
"their fields, to be used with --div-data. "
|
||||
"Only relevant if there are multiple crops in each file. "
|
||||
"The order of each sample is randomly displaced by this value. "
|
||||
"Setting it to 0 turn off this randomization, and setting it to N "
|
||||
"limits the shuffling within a distance of N files. "
|
||||
"Change this to balance cache locality and stochasticity",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dist-backend",
|
||||
default="nccl",
|
||||
type=str,
|
||||
choices=["gloo", "nccl"],
|
||||
help="distributed backend",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
default=100,
|
||||
type=int,
|
||||
help="interval (batches) between logging training loss",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--detect-anomaly",
|
||||
action="store_true",
|
||||
help="enable anomaly detection for the autograd engine",
|
||||
)
|
||||
|
||||
|
||||
def add_test_args(parser):
|
||||
add_common_args(parser)
|
||||
|
||||
parser.add_argument('--test-style-pattern', type=str, required=True,
|
||||
help='glob pattern for test data styles')
|
||||
parser.add_argument('--test-in-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for test input data')
|
||||
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for test target data')
|
||||
parser.add_argument(
|
||||
"--test-style-pattern",
|
||||
type=str,
|
||||
required=True,
|
||||
help="glob pattern for test data styles",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-in-patterns",
|
||||
type=str_list,
|
||||
required=True,
|
||||
help="comma-sep. list of glob patterns for test input data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-tgt-patterns",
|
||||
type=str_list,
|
||||
required=True,
|
||||
help="comma-sep. list of glob patterns for test target data",
|
||||
)
|
||||
|
||||
parser.add_argument('--num-threads', type=int,
|
||||
help='number of CPU threads when cuda is unavailable. '
|
||||
'Default is the number of CPUs on the node by slurm')
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
help="number of CPU threads when cuda is unavailable. "
|
||||
"Default is the number of CPUs on the node by slurm",
|
||||
)
|
||||
|
||||
|
||||
def str_list(s):
|
||||
return s.split(',')
|
||||
return s.split(",")
|
||||
|
||||
|
||||
def int_tuple(s):
|
||||
t = s.split(',')
|
||||
t = s.split(",")
|
||||
t = tuple(int(i) for i in t)
|
||||
if len(t) == 1:
|
||||
return t[0]
|
||||
|
@ -198,8 +326,7 @@ def set_common_args(args):
|
|||
def set_train_args(args):
|
||||
set_common_args(args)
|
||||
|
||||
args.val = args.val_in_patterns is not None and \
|
||||
args.val_tgt_patterns is not None
|
||||
args.val = args.val_in_patterns is not None and args.val_tgt_patterns is not None
|
||||
|
||||
|
||||
def set_test_args(args):
|
||||
|
|
|
@ -43,54 +43,72 @@ class FieldDataset(Dataset):
|
|||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||
the input resolution.
|
||||
"""
|
||||
def __init__(self, style_pattern, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None, callback_at=None,
|
||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
in_pad=0, tgt_pad=0, scale_factor=1,
|
||||
**kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
style_pattern,
|
||||
in_patterns,
|
||||
tgt_patterns,
|
||||
in_norms=None,
|
||||
tgt_norms=None,
|
||||
callback_at=None,
|
||||
augment=False,
|
||||
aug_shift=None,
|
||||
aug_add=None,
|
||||
aug_mul=None,
|
||||
crop=None,
|
||||
crop_start=None,
|
||||
crop_stop=None,
|
||||
crop_step=None,
|
||||
in_pad=0,
|
||||
tgt_pad=0,
|
||||
scale_factor=1,
|
||||
**kwargs,
|
||||
):
|
||||
self.style_files = sorted(glob(style_pattern))
|
||||
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
self.in_files = list(zip(*in_file_lists))
|
||||
|
||||
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
|
||||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
self.tgt_files = list(zip(*tgt_file_lists))
|
||||
|
||||
# self.style_files = self.style_files[:1]
|
||||
# self.in_files = self.in_files[:1]
|
||||
# self.tgt_files = self.tgt_files[:1]
|
||||
|
||||
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
|
||||
raise ValueError('number of style, input, and target files do not match')
|
||||
raise ValueError("number of style, input, and target files do not match")
|
||||
self.nfile = len(self.in_files)
|
||||
|
||||
if self.nfile == 0:
|
||||
raise FileNotFoundError('file not found for {}'.format(in_patterns))
|
||||
raise FileNotFoundError("file not found for {}".format(in_patterns))
|
||||
|
||||
self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
|
||||
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
for f in self.in_files[0]]
|
||||
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
|
||||
for f in self.tgt_files[0]]
|
||||
self.in_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.in_files[0]]
|
||||
self.tgt_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.tgt_files[0]]
|
||||
|
||||
self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
|
||||
self.size = np.load(self.in_files[0][0], mmap_mode="r").shape[1:]
|
||||
self.size = np.asarray(self.size)
|
||||
self.ndim = len(self.size)
|
||||
|
||||
if in_norms is not None and len(in_patterns) != len(in_norms):
|
||||
raise ValueError('numbers of input normalization functions and fields do not match')
|
||||
raise ValueError(
|
||||
"numbers of input normalization functions and fields do not match"
|
||||
)
|
||||
self.in_norms = in_norms
|
||||
|
||||
if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
|
||||
raise ValueError('numbers of target normalization functions and fields do not match')
|
||||
raise ValueError(
|
||||
"numbers of target normalization functions and fields do not match"
|
||||
)
|
||||
self.tgt_norms = tgt_norms
|
||||
|
||||
self.callback_at = callback_at
|
||||
|
||||
self.augment = augment
|
||||
if self.ndim == 1 and self.augment:
|
||||
raise ValueError('cannot augment 1D fields')
|
||||
raise ValueError("cannot augment 1D fields")
|
||||
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
|
||||
self.aug_add = aug_add
|
||||
self.aug_mul = aug_mul
|
||||
|
@ -116,10 +134,15 @@ class FieldDataset(Dataset):
|
|||
crop_step = np.broadcast_to(crop_step, (self.ndim,))
|
||||
self.crop_step = crop_step
|
||||
|
||||
self.anchors = np.stack(np.mgrid[tuple(
|
||||
slice(crop_start[d], crop_stop[d], crop_step[d])
|
||||
for d in range(self.ndim)
|
||||
)], axis=-1).reshape(-1, self.ndim)
|
||||
self.anchors = np.stack(
|
||||
np.mgrid[
|
||||
tuple(
|
||||
slice(crop_start[d], crop_stop[d], crop_step[d])
|
||||
for d in range(self.ndim)
|
||||
)
|
||||
],
|
||||
axis=-1,
|
||||
).reshape(-1, self.ndim)
|
||||
self.ncrop = len(self.anchors)
|
||||
|
||||
def format_pad(pad, ndim):
|
||||
|
@ -130,15 +153,16 @@ class FieldDataset(Dataset):
|
|||
elif isinstance(pad, tuple) and len(pad) == ndim * 2:
|
||||
pad = np.array(pad)
|
||||
else:
|
||||
raise ValueError('pad and ndim mismatch')
|
||||
raise ValueError("pad and ndim mismatch")
|
||||
return pad.reshape(ndim, 2)
|
||||
|
||||
self.in_pad = format_pad(in_pad, self.ndim)
|
||||
self.tgt_pad = format_pad(tgt_pad, self.ndim)
|
||||
|
||||
if scale_factor != 1:
|
||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode="r").shape[1:]
|
||||
if any(self.size * scale_factor != tgt_size):
|
||||
raise ValueError('input size x scale factor != target size')
|
||||
raise ValueError("input size x scale factor != target size")
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
self.nsample = self.nfile * self.ncrop
|
||||
|
@ -148,9 +172,7 @@ class FieldDataset(Dataset):
|
|||
self.assembly_line = {}
|
||||
|
||||
self.commonpath = os.path.commonpath(
|
||||
file
|
||||
for files in self.in_files[:2] + self.tgt_files[:2]
|
||||
for file in files
|
||||
file for files in self.in_files[:2] + self.tgt_files[:2] for file in files
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -180,12 +202,18 @@ class FieldDataset(Dataset):
|
|||
else:
|
||||
argsort_perm_axes = slice(None)
|
||||
|
||||
crop(in_fields, anchor,
|
||||
self.crop[argsort_perm_axes],
|
||||
self.in_pad[argsort_perm_axes])
|
||||
crop(tgt_fields, anchor * self.scale_factor,
|
||||
self.crop[argsort_perm_axes] * self.scale_factor,
|
||||
self.tgt_pad[argsort_perm_axes])
|
||||
crop(
|
||||
in_fields,
|
||||
anchor,
|
||||
self.crop[argsort_perm_axes],
|
||||
self.in_pad[argsort_perm_axes],
|
||||
)
|
||||
crop(
|
||||
tgt_fields,
|
||||
anchor * self.scale_factor,
|
||||
self.crop[argsort_perm_axes] * self.scale_factor,
|
||||
self.tgt_pad[argsort_perm_axes],
|
||||
)
|
||||
|
||||
style = torch.from_numpy(style).to(torch.float32)
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
|
@ -221,17 +249,19 @@ class FieldDataset(Dataset):
|
|||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
#in_relpath = [os.path.relpath(file, start=self.commonpath)
|
||||
# in_relpath = [os.path.relpath(file, start=self.commonpath)
|
||||
# for file in self.in_files[ifile]]
|
||||
tgt_relpath = [os.path.relpath(file, start=self.commonpath)
|
||||
for file in self.tgt_files[ifile]]
|
||||
tgt_relpath = [
|
||||
os.path.relpath(file, start=self.commonpath)
|
||||
for file in self.tgt_files[ifile]
|
||||
]
|
||||
|
||||
return {
|
||||
'style': style,
|
||||
'input': in_fields,
|
||||
'target': tgt_fields,
|
||||
"style": style,
|
||||
"input": in_fields,
|
||||
"target": tgt_fields,
|
||||
#'input_relpath': in_relpath,
|
||||
'target_relpath': tgt_relpath,
|
||||
"target_relpath": tgt_relpath,
|
||||
}
|
||||
|
||||
def assemble(self, label, chan, patches, paths):
|
||||
|
@ -255,29 +285,29 @@ class FieldDataset(Dataset):
|
|||
if isinstance(patches, torch.Tensor):
|
||||
patches = patches.detach().cpu().numpy()
|
||||
|
||||
assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
|
||||
assert patches.ndim == 2 + self.ndim, "ndim mismatch"
|
||||
if any(self.crop_step > patches.shape[2:]):
|
||||
raise RuntimeError('patch too small to tile')
|
||||
raise RuntimeError("patch too small to tile")
|
||||
|
||||
# the batched paths are a list of lists with shape (channel, batch)
|
||||
# since pytorch default_collate batches list of strings transposedly
|
||||
# therefore we transpose below back to (batch, channel)
|
||||
assert patches.shape[1] == sum(chan), 'number of channels mismatch'
|
||||
assert len(paths) == len(chan), 'number of fields mismatch'
|
||||
paths = list(zip(* paths))
|
||||
assert patches.shape[0] == len(paths), 'batch size mismatch'
|
||||
assert patches.shape[1] == sum(chan), "number of channels mismatch"
|
||||
assert len(paths) == len(chan), "number of fields mismatch"
|
||||
paths = list(zip(*paths))
|
||||
assert patches.shape[0] == len(paths), "batch size mismatch"
|
||||
|
||||
patches = list(patches)
|
||||
if label in self.assembly_line:
|
||||
self.assembly_line[label] += patches
|
||||
self.assembly_line[label + 'path'] += paths
|
||||
self.assembly_line[label + "path"] += paths
|
||||
else:
|
||||
self.assembly_line[label] = patches
|
||||
self.assembly_line[label + 'path'] = paths
|
||||
self.assembly_line[label + "path"] = paths
|
||||
|
||||
del patches, paths
|
||||
patches = self.assembly_line[label]
|
||||
paths = self.assembly_line[label + 'path']
|
||||
paths = self.assembly_line[label + "path"]
|
||||
|
||||
# NOTE anchor positioning assumes sufficient target padding and
|
||||
# symmetric narrowing (more on the right if odd) see `models/narrow.py`
|
||||
|
@ -285,31 +315,26 @@ class FieldDataset(Dataset):
|
|||
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
|
||||
|
||||
while len(patches) >= self.ncrop:
|
||||
fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
|
||||
patches[0].dtype)
|
||||
fields = np.zeros(patches[0].shape[:1] + tuple(self.size), patches[0].dtype)
|
||||
|
||||
for patch, anchor in zip(patches, anchors):
|
||||
fill(fields, patch, anchor)
|
||||
|
||||
for field, path in zip(
|
||||
np.split(fields, np.cumsum(chan), axis=0),
|
||||
paths[0]):
|
||||
pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
|
||||
exist_ok=True)
|
||||
for field, path in zip(np.split(fields, np.cumsum(chan), axis=0), paths[0]):
|
||||
pathlib.Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
path = label.join(os.path.splitext(path))
|
||||
np.save(path, field)
|
||||
|
||||
del patches[:self.ncrop], paths[:self.ncrop]
|
||||
del patches[: self.ncrop], paths[: self.ncrop]
|
||||
|
||||
|
||||
def fill(field, patch, anchor):
|
||||
ndim = len(anchor)
|
||||
assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch'
|
||||
assert field.ndim == patch.ndim == 1 + ndim, "ndim mismatch"
|
||||
|
||||
ind = [slice(None)]
|
||||
for d, (p, a, s) in enumerate(zip(
|
||||
patch.shape[1:], anchor, field.shape[1:])):
|
||||
for d, (p, a, s) in enumerate(zip(patch.shape[1:], anchor, field.shape[1:])):
|
||||
i = np.arange(a, a + p)
|
||||
i %= s
|
||||
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
||||
|
@ -320,10 +345,10 @@ def fill(field, patch, anchor):
|
|||
|
||||
|
||||
def crop(fields, anchor, crop, pad):
|
||||
assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch'
|
||||
assert all(x.shape == fields[0].shape for x in fields), "shape mismatch"
|
||||
size = fields[0].shape[1:]
|
||||
ndim = len(size)
|
||||
assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch'
|
||||
assert ndim == len(anchor) == len(crop) == len(pad), "ndim mismatch"
|
||||
|
||||
ind = [slice(None)]
|
||||
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
|
||||
|
@ -342,7 +367,7 @@ def crop(fields, anchor, crop, pad):
|
|||
|
||||
|
||||
def flip(fields, axes, ndim):
|
||||
assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors'
|
||||
assert ndim > 1, "flipping is ambiguous for 1D scalars/vectors"
|
||||
|
||||
if axes is None:
|
||||
axes = torch.randint(2, (ndim,), dtype=torch.bool)
|
||||
|
@ -350,7 +375,7 @@ def flip(fields, axes, ndim):
|
|||
|
||||
for i, x in enumerate(fields):
|
||||
if x.shape[0] == ndim: # flip vector components
|
||||
x[axes] = - x[axes]
|
||||
x[axes] = -x[axes]
|
||||
|
||||
shifted_axes = (1 + axes).tolist()
|
||||
x = torch.flip(x, shifted_axes)
|
||||
|
@ -361,7 +386,7 @@ def flip(fields, axes, ndim):
|
|||
|
||||
|
||||
def perm(fields, axes, ndim):
|
||||
assert ndim > 1, 'permutation is not necessary for 1D fields'
|
||||
assert ndim > 1, "permutation is not necessary for 1D fields"
|
||||
|
||||
if axes is None:
|
||||
axes = torch.randperm(ndim)
|
||||
|
|
|
@ -10,6 +10,7 @@ def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
|
|||
|
||||
x *= dis_norm
|
||||
|
||||
|
||||
def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
|
||||
vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z) # [km/s]
|
||||
|
||||
|
@ -20,25 +21,28 @@ def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
|
|||
|
||||
|
||||
def D(z, Om=0.31):
|
||||
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
|
||||
"""
|
||||
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero"""
|
||||
OL = 1 - Om
|
||||
a = 1 / (1+z)
|
||||
return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \
|
||||
/ hyp2f1(1, 1/3, 11/6, - OL / Om)
|
||||
a = 1 / (1 + z)
|
||||
return (
|
||||
a
|
||||
* hyp2f1(1, 1 / 3, 11 / 6, -OL * a**3 / Om)
|
||||
/ hyp2f1(1, 1 / 3, 11 / 6, -OL / Om)
|
||||
)
|
||||
|
||||
|
||||
def f(z, Om=0.31):
|
||||
"""linear growth rate for flat LambdaCDM
|
||||
"""
|
||||
"""linear growth rate for flat LambdaCDM"""
|
||||
OL = 1 - Om
|
||||
a = 1 / (1+z)
|
||||
a = 1 / (1 + z)
|
||||
aa3 = OL * a**3 / Om
|
||||
return 1 - 6/11*aa3 * hyp2f1(2, 4/3, 17/6, -aa3) \
|
||||
/ hyp2f1(1, 1/3, 11/6, -aa3)
|
||||
return 1 - 6 / 11 * aa3 * hyp2f1(2, 4 / 3, 17 / 6, -aa3) / hyp2f1(
|
||||
1, 1 / 3, 11 / 6, -aa3
|
||||
)
|
||||
|
||||
|
||||
def H(z, Om=0.31):
|
||||
"""Hubble in [h km/s/Mpc] for flat LambdaCDM
|
||||
"""
|
||||
"""Hubble in [h km/s/Mpc] for flat LambdaCDM"""
|
||||
OL = 1 - Om
|
||||
a = 1 / (1+z)
|
||||
a = 1 / (1 + z)
|
||||
return 100 * np.sqrt(Om / a**3 + OL)
|
||||
|
|
|
@ -7,18 +7,21 @@ def exp(x, undo=False, **kwargs):
|
|||
else:
|
||||
torch.log(x, out=x)
|
||||
|
||||
|
||||
def log(x, eps=1e-8, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.log(x + eps, out=x)
|
||||
else:
|
||||
torch.exp(x, out=x)
|
||||
|
||||
|
||||
def expm1(x, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.expm1(x, out=x)
|
||||
else:
|
||||
torch.log1p(x, out=x)
|
||||
|
||||
|
||||
def log1p(x, eps=1e-7, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.log1p(x + eps, out=x)
|
||||
|
|
|
@ -22,8 +22,8 @@ class DistFieldSampler(Sampler):
|
|||
Like `DistributedSampler`, `set_epoch()` should be called at the beginning
|
||||
of each epoch during training.
|
||||
"""
|
||||
def __init__(self, dataset, shuffle,
|
||||
div_data=False, div_shuffle_dist=0):
|
||||
|
||||
def __init__(self, dataset, shuffle, div_data=False, div_shuffle_dist=0):
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
|
||||
|
@ -50,8 +50,10 @@ class DistFieldSampler(Sampler):
|
|||
ind = ind.flatten()
|
||||
|
||||
# displace crops with respect to files
|
||||
dis = torch.rand((self.nfile, self.ncrop),
|
||||
generator=g) * self.div_shuffle_dist
|
||||
dis = (
|
||||
torch.rand((self.nfile, self.ncrop), generator=g)
|
||||
* self.div_shuffle_dist
|
||||
)
|
||||
loc = torch.arange(self.nfile)
|
||||
loc = loc[:, None] + dis
|
||||
loc = loc.flatten() % self.nfile # periodic in files
|
||||
|
|
|
@ -3,6 +3,7 @@ from . import test
|
|||
import click
|
||||
import os
|
||||
import yaml
|
||||
|
||||
try:
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
|
@ -12,20 +13,24 @@ import importlib.resources
|
|||
import json
|
||||
from functools import partial
|
||||
|
||||
|
||||
def _load_resource_file(resource_path):
|
||||
# Import the package
|
||||
pkg_files = importlib.resources.files() / resource_path
|
||||
with pkg_files.open() as file:
|
||||
return file.read() # Read the file and return its content
|
||||
return file.read() # Read the file and return its content
|
||||
|
||||
|
||||
def _str_list(value):
|
||||
return value.split(',')
|
||||
return value.split(",")
|
||||
|
||||
|
||||
def _int_tuple(value):
|
||||
t = value.split(',')
|
||||
t = value.split(",")
|
||||
t = tuple(int(i) for i in t)
|
||||
return t
|
||||
|
||||
|
||||
class VariadicType(click.ParamType):
|
||||
"""
|
||||
A custom parameter type for Click command-line interface.
|
||||
|
@ -54,14 +59,14 @@ class VariadicType(click.ParamType):
|
|||
if typename in self._mapper:
|
||||
self._type = self._mapper[typename]
|
||||
elif type(typename) == dict:
|
||||
self._type = self._mapper[typename["type"]]
|
||||
self._type = self._mapper[typename["type"]]
|
||||
self.args = typename["opts"]
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {typename}")
|
||||
raise ValueError(f"Unknown type: {typename}")
|
||||
self._typename = typename
|
||||
self.name = self._type["type"]
|
||||
if "func" not in self._type:
|
||||
self._type["func"] = eval(self._type['type'])
|
||||
self._type["func"] = eval(self._type["type"])
|
||||
|
||||
def convert(self, value, param, ctx):
|
||||
try:
|
||||
|
@ -69,24 +74,26 @@ class VariadicType(click.ParamType):
|
|||
except Exception as e:
|
||||
self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
|
||||
|
||||
|
||||
def _apply_options(options_file, f):
|
||||
common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
|
||||
common_args = common_args['arguments']
|
||||
common_args = common_args["arguments"]
|
||||
|
||||
for arg in common_args:
|
||||
argopt = common_args[arg]
|
||||
if 'type' in argopt:
|
||||
if type(argopt['type']) == dict and argopt['type']['type'] == 'choice':
|
||||
argopt['type'] = click.Choice(argopt['type']['opts'])
|
||||
if "type" in argopt:
|
||||
if type(argopt["type"]) == dict and argopt["type"]["type"] == "choice":
|
||||
argopt["type"] = click.Choice(argopt["type"]["opts"])
|
||||
else:
|
||||
argopt['type'] = VariadicType(argopt['type'])
|
||||
f = click.option(f'--{arg}', **argopt)(f)
|
||||
argopt["type"] = VariadicType(argopt["type"])
|
||||
f = click.option(f"--{arg}", **argopt)(f)
|
||||
else:
|
||||
f = click.option(f'--{arg}', **argopt)(f)
|
||||
f = click.option(f"--{arg}", **argopt)(f)
|
||||
|
||||
return f
|
||||
|
||||
m2m_options=partial(_apply_options,"common_args.yaml")
|
||||
|
||||
m2m_options = partial(_apply_options, "common_args.yaml")
|
||||
|
||||
|
||||
@click.group()
|
||||
|
@ -94,23 +101,26 @@ m2m_options=partial(_apply_options,"common_args.yaml")
|
|||
@click.pass_context
|
||||
def main(ctx, config):
|
||||
if config is not None and os.path.exists(config):
|
||||
with open(config, 'r') as f:
|
||||
with open(config, "r") as f:
|
||||
config = yaml.load(f.read(), Loader=Loader)
|
||||
ctx.default_map = config
|
||||
|
||||
|
||||
# Make a class that provides access to dict with the attribute mechanism
|
||||
class DictProxy:
|
||||
def __init__(self, d):
|
||||
self.__dict__ = d
|
||||
|
||||
|
||||
@main.command()
|
||||
@m2m_options
|
||||
@partial(_apply_options, "train_args.yaml")
|
||||
def train(**kwargs):
|
||||
train.node_worker(DictProxy(kwargs))
|
||||
|
||||
|
||||
@main.command()
|
||||
@m2m_options
|
||||
@partial(_apply_options, "test_args.yaml")
|
||||
def test(**kwargs):
|
||||
test.test(DictProxy(kwargs))
|
||||
test.test(DictProxy(kwargs))
|
||||
|
|
|
@ -5,6 +5,7 @@ def adv_model_wrapper(module):
|
|||
"""Wrap an adversary model to also take lists of Tensors as input,
|
||||
to be concatenated along the batch dimension
|
||||
"""
|
||||
|
||||
class _new_module(module):
|
||||
def forward(self, x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
|
@ -22,6 +23,7 @@ def adv_criterion_wrapper(module):
|
|||
* expand target shape as that of input
|
||||
* return a list of losses, one for each pair of input and target Tensors
|
||||
"""
|
||||
|
||||
class _new_module(module):
|
||||
def forward(self, input, target):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
|
@ -35,8 +37,9 @@ def adv_criterion_wrapper(module):
|
|||
|
||||
target = [t.expand_as(i) for i, t in zip(input, target)]
|
||||
|
||||
loss = [super(new_module, self).forward(i, t)
|
||||
for i, t in zip(input, target)]
|
||||
loss = [
|
||||
super(new_module, self).forward(i, t) for i, t in zip(input, target)
|
||||
]
|
||||
|
||||
return loss
|
||||
|
||||
|
|
|
@ -15,8 +15,10 @@ class ConvBlock(nn.Module):
|
|||
'U': upsampling transposed convolution of kernel size 2 and stride 2
|
||||
'D': downsampling convolution of kernel size 2 and stride 2
|
||||
"""
|
||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||
kernel_size=3, stride=1, seq='CBA'):
|
||||
|
||||
def __init__(
|
||||
self, in_chan, out_chan=None, mid_chan=None, kernel_size=3, stride=1, seq="CBA"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if out_chan is None:
|
||||
|
@ -31,31 +33,30 @@ class ConvBlock(nn.Module):
|
|||
|
||||
self.norm_chan = in_chan
|
||||
self.idx_conv = 0
|
||||
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
|
||||
self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
|
||||
|
||||
layers = [self._get_layer(l) for l in seq]
|
||||
|
||||
self.convs = nn.Sequential(*layers)
|
||||
|
||||
def _get_layer(self, l):
|
||||
if l == 'U':
|
||||
if l == "U":
|
||||
in_chan, out_chan = self._setup_conv()
|
||||
return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2)
|
||||
elif l == 'D':
|
||||
elif l == "D":
|
||||
in_chan, out_chan = self._setup_conv()
|
||||
return nn.Conv3d(in_chan, out_chan, 2, stride=2)
|
||||
elif l == 'C':
|
||||
elif l == "C":
|
||||
in_chan, out_chan = self._setup_conv()
|
||||
return nn.Conv3d(in_chan, out_chan, self.kernel_size,
|
||||
stride=self.stride)
|
||||
elif l == 'B':
|
||||
return nn.Conv3d(in_chan, out_chan, self.kernel_size, stride=self.stride)
|
||||
elif l == "B":
|
||||
return nn.BatchNorm3d(self.norm_chan)
|
||||
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
||||
#return nn.InstanceNorm3d(self.norm_chan)
|
||||
elif l == 'A':
|
||||
# return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
||||
# return nn.InstanceNorm3d(self.norm_chan)
|
||||
elif l == "A":
|
||||
return nn.LeakyReLU()
|
||||
else:
|
||||
raise ValueError('layer type {} not supported'.format(l))
|
||||
raise ValueError("layer type {} not supported".format(l))
|
||||
|
||||
def _setup_conv(self):
|
||||
self.idx_conv += 1
|
||||
|
@ -88,13 +89,22 @@ class ResBlock(ConvBlock):
|
|||
|
||||
See `ConvBlock` for `seq` types.
|
||||
"""
|
||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chan,
|
||||
out_chan=None,
|
||||
mid_chan=None,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
seq="CBACBA",
|
||||
last_act=None,
|
||||
):
|
||||
if last_act is None:
|
||||
last_act = seq[-1] == 'A'
|
||||
elif last_act and seq[-1] != 'A':
|
||||
last_act = seq[-1] == "A"
|
||||
elif last_act and seq[-1] != "A":
|
||||
warnings.warn(
|
||||
'Disabling last_act without trailing activation in seq',
|
||||
"Disabling last_act without trailing activation in seq",
|
||||
RuntimeWarning,
|
||||
)
|
||||
last_act = False
|
||||
|
@ -102,8 +112,14 @@ class ResBlock(ConvBlock):
|
|||
if last_act:
|
||||
seq = seq[:-1]
|
||||
|
||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
||||
kernel_size=kernel_size, stride=stride, seq=seq)
|
||||
super().__init__(
|
||||
in_chan,
|
||||
out_chan=out_chan,
|
||||
mid_chan=mid_chan,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
seq=seq,
|
||||
)
|
||||
|
||||
if last_act:
|
||||
self.act = nn.LeakyReLU()
|
||||
|
@ -115,9 +131,10 @@ class ResBlock(ConvBlock):
|
|||
else:
|
||||
self.skip = nn.Conv3d(in_chan, out_chan, 1)
|
||||
|
||||
if 'U' in seq or 'D' in seq:
|
||||
raise NotImplementedError('upsample and downsample layers '
|
||||
'not supported yet')
|
||||
if "U" in seq or "D" in seq:
|
||||
raise NotImplementedError(
|
||||
"upsample and downsample layers " "not supported yet"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch.nn as nn
|
|||
|
||||
|
||||
class DiceLoss(nn.Module):
|
||||
def __init__(self, eps=0.):
|
||||
def __init__(self, eps=0.0):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
|
@ -10,7 +10,7 @@ class DiceLoss(nn.Module):
|
|||
return dice_loss(input, target, self.eps)
|
||||
|
||||
|
||||
def dice_loss(input, target, eps=0.):
|
||||
def dice_loss(input, target, eps=0.0):
|
||||
input = input.view(-1)
|
||||
target = target.view(-1)
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import torch
|
||||
|
||||
|
||||
class InstanceNoise:
|
||||
"""Instance noise, with a linear decaying schedule
|
||||
"""
|
||||
"""Instance noise, with a linear decaying schedule"""
|
||||
|
||||
def __init__(self, init_std, batches):
|
||||
assert init_std >= 0, 'Noise std cannot be negative'
|
||||
assert init_std >= 0, "Noise std cannot be negative"
|
||||
self.init_std = init_std
|
||||
self._std = init_std
|
||||
self.batches = batches
|
||||
|
|
|
@ -5,17 +5,18 @@ from ..data.norms.cosmology import D
|
|||
|
||||
|
||||
def lag2eul(
|
||||
dis,
|
||||
val=1.0,
|
||||
eul_scale_factor=1,
|
||||
eul_pad=0,
|
||||
rm_dis_mean=True,
|
||||
periodic=False,
|
||||
z=0.0,
|
||||
dis_std=6.0,
|
||||
boxsize=1000.,
|
||||
meshsize=512,
|
||||
**kwargs):
|
||||
dis,
|
||||
val=1.0,
|
||||
eul_scale_factor=1,
|
||||
eul_pad=0,
|
||||
rm_dis_mean=True,
|
||||
periodic=False,
|
||||
z=0.0,
|
||||
dis_std=6.0,
|
||||
boxsize=1000.0,
|
||||
meshsize=512,
|
||||
**kwargs,
|
||||
):
|
||||
"""Transform fields from Lagrangian description to Eulerian description
|
||||
|
||||
Only works for 3d fields, output same mesh size as input.
|
||||
|
@ -48,20 +49,19 @@ def lag2eul(
|
|||
if isinstance(val, (float, torch.Tensor)):
|
||||
val = [val]
|
||||
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
|
||||
raise ValueError('dis-val field mismatch')
|
||||
raise ValueError("dis-val field mismatch")
|
||||
|
||||
if any(d.dim() != 5 for d in dis):
|
||||
raise NotImplementedError('only support 3d fields for now')
|
||||
raise NotImplementedError("only support 3d fields for now")
|
||||
if any(d.shape[1] != 3 for d in dis):
|
||||
raise ValueError('only support 3d displacement fields')
|
||||
raise ValueError("only support 3d displacement fields")
|
||||
|
||||
# common mean displacement of all inputs
|
||||
# if removed, fewer particles go outside of the box
|
||||
# common for all inputs so outputs are comparable in the same coords
|
||||
d_mean = 0
|
||||
if rm_dis_mean:
|
||||
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
|
||||
for d in dis) / len(dis)
|
||||
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) for d in dis) / len(dis)
|
||||
|
||||
out = []
|
||||
if len(dis) == 1 and len(val) != 1:
|
||||
|
@ -85,12 +85,15 @@ def lag2eul(
|
|||
pos = (d - d_mean) * dis_norm
|
||||
del d
|
||||
|
||||
pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
|
||||
dtype=dtype, device=device)[:, None, None]
|
||||
pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
|
||||
dtype=dtype, device=device)[:, None]
|
||||
pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
|
||||
dtype=dtype, device=device)
|
||||
pos[:, 0] += torch.arange(
|
||||
0, DHW[0] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
|
||||
)[:, None, None]
|
||||
pos[:, 1] += torch.arange(
|
||||
0, DHW[1] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
|
||||
)[:, None]
|
||||
pos[:, 2] += torch.arange(
|
||||
0, DHW[2] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
|
||||
|
||||
|
@ -118,8 +121,7 @@ def lag2eul(
|
|||
if periodic:
|
||||
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
|
||||
|
||||
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
|
||||
) * DHW[2] + tgtpos[n, 2]
|
||||
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]) * DHW[2] + tgtpos[n, 2]
|
||||
src = v[n]
|
||||
|
||||
if not periodic:
|
||||
|
|
|
@ -3,8 +3,7 @@ import torch.nn as nn
|
|||
|
||||
|
||||
def narrow_by(a, c):
|
||||
"""Narrow a by size c symmetrically on all edges.
|
||||
"""
|
||||
"""Narrow a by size c symmetrically on all edges."""
|
||||
ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
|
||||
return a[ind]
|
||||
|
||||
|
|
|
@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.convs = nn.Sequential(
|
||||
ConvBlock(in_chan, 32, seq='CA'),
|
||||
ConvBlock(32, 64, seq='CBA'),
|
||||
ConvBlock(64, 128, seq='CBA'),
|
||||
ConvBlock(128, out_chan, seq='C'),
|
||||
ConvBlock(in_chan, 32, seq="CA"),
|
||||
ConvBlock(32, 64, seq="CBA"),
|
||||
ConvBlock(64, 128, seq="CBA"),
|
||||
ConvBlock(128, out_chan, seq="C"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -19,23 +19,20 @@ class PatchGAN(nn.Module):
|
|||
|
||||
|
||||
class PatchGAN42(nn.Module):
|
||||
"""PatchGAN similar to the one in pix2pix
|
||||
"""
|
||||
"""PatchGAN similar to the one in pix2pix"""
|
||||
|
||||
def __init__(self, in_chan, out_chan=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.convs = nn.Sequential(
|
||||
nn.Conv3d(in_chan, 64, 4, stride=2),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
nn.Conv3d(64, 128, 4, stride=2),
|
||||
nn.BatchNorm3d(128),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
nn.Conv3d(128, 256, 4, stride=2),
|
||||
nn.BatchNorm3d(256),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
nn.Conv3d(256, out_chan, 1),
|
||||
)
|
||||
|
||||
|
|
|
@ -26,8 +26,7 @@ def power(x: torch.Tensor):
|
|||
P = P.sum(dim=0)
|
||||
del x
|
||||
|
||||
k = [torch.arange(d, dtype=torch.float32, device=P.device)
|
||||
for d in P.shape]
|
||||
k = [torch.arange(d, dtype=torch.float32, device=P.device) for d in P.shape]
|
||||
k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
|
||||
k = torch.meshgrid(*k)
|
||||
k = torch.stack(k, dim=0)
|
||||
|
@ -49,9 +48,9 @@ def power(x: torch.Tensor):
|
|||
del kbin
|
||||
|
||||
# drop k=0 mode and cut at kmax (smallest Nyquist)
|
||||
k = k[1:1+kmax]
|
||||
P = P[1:1+kmax]
|
||||
N = N[1:1+kmax]
|
||||
k = k[1 : 1 + kmax]
|
||||
P = P[1 : 1 + kmax]
|
||||
N = N[1 : 1 + kmax]
|
||||
|
||||
k /= N
|
||||
P /= N
|
||||
|
|
|
@ -5,12 +5,11 @@ from .narrow import narrow_by
|
|||
|
||||
|
||||
def resample(x, scale_factor, narrow=True):
|
||||
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
|
||||
modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
|
||||
ndim = x.dim() - 2
|
||||
mode = modes[ndim]
|
||||
|
||||
x = F.interpolate(x, scale_factor=scale_factor,
|
||||
mode=mode, align_corners=False)
|
||||
x = F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=False)
|
||||
|
||||
if scale_factor > 1 and narrow == True:
|
||||
edges = round(scale_factor) // 2
|
||||
|
@ -25,18 +24,20 @@ class Resampler(nn.Module):
|
|||
|
||||
By default discard the inaccurate edges when upsampling.
|
||||
"""
|
||||
|
||||
def __init__(self, ndim, scale_factor, narrow=True):
|
||||
super().__init__()
|
||||
|
||||
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
|
||||
modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
|
||||
self.mode = modes[ndim]
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.narrow = narrow
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=self.scale_factor,
|
||||
mode=self.mode, align_corners=False)
|
||||
x = F.interpolate(
|
||||
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False
|
||||
)
|
||||
|
||||
if self.scale_factor > 1 and self.narrow == True:
|
||||
edges = round(self.scale_factor) // 2
|
||||
|
|
|
@ -4,8 +4,18 @@ from torch.nn.utils import spectral_norm, remove_spectral_norm
|
|||
|
||||
def add_spectral_norm(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
if isinstance(
|
||||
child,
|
||||
(
|
||||
nn.Linear,
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.ConvTranspose1d,
|
||||
nn.ConvTranspose2d,
|
||||
nn.ConvTranspose3d,
|
||||
),
|
||||
):
|
||||
setattr(module, name, spectral_norm(child))
|
||||
else:
|
||||
add_spectral_norm(child)
|
||||
|
@ -13,8 +23,18 @@ def add_spectral_norm(module):
|
|||
|
||||
def rm_spectral_norm(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
if isinstance(
|
||||
child,
|
||||
(
|
||||
nn.Linear,
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.ConvTranspose1d,
|
||||
nn.ConvTranspose2d,
|
||||
nn.ConvTranspose3d,
|
||||
),
|
||||
):
|
||||
setattr(module, name, remove_spectral_norm(child))
|
||||
else:
|
||||
rm_spectral_norm(child)
|
||||
|
|
|
@ -7,9 +7,17 @@ from .resample import Resampler
|
|||
|
||||
|
||||
class G(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, scale_factor=16,
|
||||
chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
in_chan,
|
||||
out_chan,
|
||||
scale_factor=16,
|
||||
chan_base=512,
|
||||
chan_min=64,
|
||||
chan_max=512,
|
||||
cat_noise=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
|
@ -30,15 +38,14 @@ class G(nn.Module):
|
|||
|
||||
self.blocks = nn.ModuleList()
|
||||
for b in range(num_blocks):
|
||||
prev_chan, next_chan = chan(b), chan(b+1)
|
||||
self.blocks.append(
|
||||
HBlock(prev_chan, next_chan, out_chan, cat_noise))
|
||||
prev_chan, next_chan = chan(b), chan(b + 1)
|
||||
self.blocks.append(HBlock(prev_chan, next_chan, out_chan, cat_noise))
|
||||
|
||||
def forward(self, x):
|
||||
y = x # direct upsampling from the input
|
||||
x = self.block0(x)
|
||||
|
||||
#y = None # no direct upsampling from the input
|
||||
# y = None # no direct upsampling from the input
|
||||
for block in self.blocks:
|
||||
x, y = block(x, y)
|
||||
|
||||
|
@ -71,6 +78,7 @@ class HBlock(nn.Module):
|
|||
-----
|
||||
next_size = 2 * prev_size - 6
|
||||
"""
|
||||
|
||||
def __init__(self, prev_chan, next_chan, out_chan, cat_noise):
|
||||
super().__init__()
|
||||
|
||||
|
@ -81,7 +89,6 @@ class HBlock(nn.Module):
|
|||
self.upsample,
|
||||
nn.Conv3d(prev_chan + int(cat_noise), next_chan, 3),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
|
||||
AddNoise(cat_noise, chan=next_chan),
|
||||
nn.Conv3d(next_chan + int(cat_noise), next_chan, 3),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
|
@ -114,6 +121,7 @@ class AddNoise(nn.Module):
|
|||
The number of channels `chan` should be 1 (StyleGAN2)
|
||||
or that of the input (StyleGAN).
|
||||
"""
|
||||
|
||||
def __init__(self, cat, chan=1):
|
||||
super().__init__()
|
||||
|
||||
|
@ -137,9 +145,16 @@ class AddNoise(nn.Module):
|
|||
|
||||
|
||||
class D(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, scale_factor=16,
|
||||
chan_base=512, chan_min=64, chan_max=512,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
in_chan,
|
||||
out_chan,
|
||||
scale_factor=16,
|
||||
chan_base=512,
|
||||
chan_min=64,
|
||||
chan_max=512,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
|
@ -163,7 +178,7 @@ class D(nn.Module):
|
|||
|
||||
self.blocks = nn.ModuleList()
|
||||
for b in reversed(range(num_blocks)):
|
||||
prev_chan, next_chan = chan(b+1), chan(b)
|
||||
prev_chan, next_chan = chan(b + 1), chan(b)
|
||||
self.blocks.append(ResBlock(prev_chan, next_chan))
|
||||
|
||||
self.block9 = nn.Sequential(
|
||||
|
@ -192,13 +207,13 @@ class ResBlock(nn.Module):
|
|||
-----
|
||||
next_size = (prev_size - 4) // 2
|
||||
"""
|
||||
|
||||
def __init__(self, prev_chan, next_chan):
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(prev_chan, prev_chan, 3),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
|
||||
nn.Conv3d(prev_chan, next_chan, 3),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@ class PixelNorm(nn.Module):
|
|||
|
||||
See ProGAN, StyleGAN.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -23,6 +24,7 @@ class LinearElr(nn.Module):
|
|||
|
||||
Useful at all if not for regularization(1706.05350)?
|
||||
"""
|
||||
|
||||
def __init__(self, in_size, out_size, bias=True, act=None):
|
||||
super().__init__()
|
||||
|
||||
|
@ -32,7 +34,7 @@ class LinearElr(nn.Module):
|
|||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_size))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.act = act
|
||||
|
||||
|
@ -52,20 +54,20 @@ class ConvElr3d(nn.Module):
|
|||
|
||||
Useful at all if not for regularization(1706.05350)?
|
||||
"""
|
||||
def __init__(self, in_chan, out_chan, kernel_size,
|
||||
stride=1, padding=0, bias=True):
|
||||
|
||||
def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
|
||||
)
|
||||
fan_in = in_chan * kernel_size ** 3
|
||||
fan_in = in_chan * kernel_size**3
|
||||
self.wnorm = 1 / math.sqrt(fan_in)
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_chan))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
@ -87,38 +89,49 @@ class ConvStyled3d(nn.Module):
|
|||
|
||||
Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
|
||||
"""
|
||||
def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
|
||||
bias=True, resample=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
style_size,
|
||||
in_chan,
|
||||
out_chan,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
bias=True,
|
||||
resample=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
|
||||
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
|
||||
mode='fan_in', nonlinearity='leaky_relu')
|
||||
nn.init.kaiming_uniform_(
|
||||
self.style_weight, a=math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"
|
||||
)
|
||||
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
|
||||
if resample is None:
|
||||
K3 = (kernel_size,) * 3
|
||||
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
||||
self.stride = stride
|
||||
self.conv = F.conv3d
|
||||
elif resample == 'U':
|
||||
elif resample == "U":
|
||||
K3 = (2,) * 3
|
||||
# NOTE not clear to me why convtranspose have channels swapped
|
||||
self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
|
||||
self.stride = 2
|
||||
self.conv = F.conv_transpose3d
|
||||
elif resample == 'D':
|
||||
elif resample == "D":
|
||||
K3 = (2,) * 3
|
||||
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
||||
self.stride = 2
|
||||
self.conv = F.conv3d
|
||||
else:
|
||||
raise ValueError('resample type {} not supported'.format(resample))
|
||||
raise ValueError("resample type {} not supported".format(resample))
|
||||
self.resample = resample
|
||||
|
||||
nn.init.kaiming_uniform_(
|
||||
self.weight, a=math.sqrt(5),
|
||||
mode='fan_in', # effectively 'fan_out' for 'D'
|
||||
nonlinearity='leaky_relu',
|
||||
self.weight,
|
||||
a=math.sqrt(5),
|
||||
mode="fan_in", # effectively 'fan_out' for 'D'
|
||||
nonlinearity="leaky_relu",
|
||||
)
|
||||
|
||||
if bias:
|
||||
|
@ -127,12 +140,12 @@ class ConvStyled3d(nn.Module):
|
|||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x, s, eps=1e-8):
|
||||
N, Cin, *DHWin = x.shape
|
||||
C0, C1, *K3 = self.weight.shape
|
||||
if self.resample == 'U':
|
||||
if self.resample == "U":
|
||||
Cin, Cout = C0, C1
|
||||
else:
|
||||
Cout, Cin = C0, C1
|
||||
|
@ -140,14 +153,14 @@ class ConvStyled3d(nn.Module):
|
|||
s = F.linear(s, self.style_weight, bias=self.style_bias)
|
||||
|
||||
# modulation
|
||||
if self.resample == 'U':
|
||||
if self.resample == "U":
|
||||
s = s.reshape(N, Cin, 1, 1, 1, 1)
|
||||
else:
|
||||
s = s.reshape(N, 1, Cin, 1, 1, 1)
|
||||
w = self.weight * s
|
||||
|
||||
# demodulation
|
||||
if self.resample == 'U':
|
||||
if self.resample == "U":
|
||||
fan_in_dim = (1, 3, 4, 5)
|
||||
else:
|
||||
fan_in_dim = (2, 3, 4, 5)
|
||||
|
@ -161,18 +174,22 @@ class ConvStyled3d(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
class BatchNormStyled3d(nn.BatchNorm3d) :
|
||||
""" Trivially does standard batch normalization, but accepts second argument
|
||||
|
||||
class BatchNormStyled3d(nn.BatchNorm3d):
|
||||
"""Trivially does standard batch normalization, but accepts second argument
|
||||
|
||||
for style array that is not used
|
||||
"""
|
||||
|
||||
def forward(self, x, s):
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class LeakyReLUStyled(nn.LeakyReLU):
|
||||
""" Trivially evaluates standard leaky ReLU, but accepts second argument
|
||||
"""Trivially evaluates standard leaky ReLU, but accepts second argument
|
||||
|
||||
for sytle array that is not used
|
||||
"""
|
||||
|
||||
def forward(self, x, s):
|
||||
return super().forward(x)
|
||||
|
|
|
@ -16,8 +16,17 @@ class ConvStyledBlock(nn.Module):
|
|||
'U': upsampling transposed convolution of kernel size 2 and stride 2
|
||||
'D': downsampling convolution of kernel size 2 and stride 2
|
||||
"""
|
||||
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
||||
kernel_size=3, stride=1, seq='CBA'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
style_size,
|
||||
in_chan,
|
||||
out_chan=None,
|
||||
mid_chan=None,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
seq="CBA",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if out_chan is None:
|
||||
|
@ -33,31 +42,34 @@ class ConvStyledBlock(nn.Module):
|
|||
|
||||
self.norm_chan = in_chan
|
||||
self.idx_conv = 0
|
||||
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
|
||||
self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
|
||||
|
||||
layers = [self._get_layer(l) for l in seq]
|
||||
|
||||
self.convs = nn.ModuleList(layers)
|
||||
|
||||
def _get_layer(self, l):
|
||||
if l == 'U':
|
||||
if l == "U":
|
||||
in_chan, out_chan = self._setup_conv()
|
||||
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
|
||||
resample = 'U')
|
||||
elif l == 'D':
|
||||
return ConvStyled3d(
|
||||
self.style_size, in_chan, out_chan, 2, stride=2, resample="U"
|
||||
)
|
||||
elif l == "D":
|
||||
in_chan, out_chan = self._setup_conv()
|
||||
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
|
||||
resample = 'D')
|
||||
elif l == 'C':
|
||||
return ConvStyled3d(
|
||||
self.style_size, in_chan, out_chan, 2, stride=2, resample="D"
|
||||
)
|
||||
elif l == "C":
|
||||
in_chan, out_chan = self._setup_conv()
|
||||
return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size,
|
||||
stride=self.stride)
|
||||
elif l == 'B':
|
||||
return ConvStyled3d(
|
||||
self.style_size, in_chan, out_chan, self.kernel_size, stride=self.stride
|
||||
)
|
||||
elif l == "B":
|
||||
return BatchNormStyled3d(self.norm_chan)
|
||||
elif l == 'A':
|
||||
elif l == "A":
|
||||
return LeakyReLUStyled()
|
||||
else:
|
||||
raise ValueError('layer type {} not supported'.format(l))
|
||||
raise ValueError("layer type {} not supported".format(l))
|
||||
|
||||
def _setup_conv(self):
|
||||
self.idx_conv += 1
|
||||
|
@ -92,13 +104,23 @@ class ResStyledBlock(ConvStyledBlock):
|
|||
|
||||
See `ConvStyledBlock` for `seq` types.
|
||||
"""
|
||||
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
||||
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
style_size,
|
||||
in_chan,
|
||||
out_chan=None,
|
||||
mid_chan=None,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
seq="CBACBA",
|
||||
last_act=None,
|
||||
):
|
||||
if last_act is None:
|
||||
last_act = seq[-1] == 'A'
|
||||
elif last_act and seq[-1] != 'A':
|
||||
last_act = seq[-1] == "A"
|
||||
elif last_act and seq[-1] != "A":
|
||||
warnings.warn(
|
||||
'Disabling last_act without trailing activation in seq',
|
||||
"Disabling last_act without trailing activation in seq",
|
||||
RuntimeWarning,
|
||||
)
|
||||
last_act = False
|
||||
|
@ -106,8 +128,15 @@ class ResStyledBlock(ConvStyledBlock):
|
|||
if last_act:
|
||||
seq = seq[:-1]
|
||||
|
||||
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
||||
kernel_size=kernel_size, stride=stride, seq=seq)
|
||||
super().__init__(
|
||||
style_size,
|
||||
in_chan,
|
||||
out_chan=out_chan,
|
||||
mid_chan=mid_chan,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
seq=seq,
|
||||
)
|
||||
|
||||
if last_act:
|
||||
self.act = LeakyReLUStyled()
|
||||
|
@ -119,9 +148,10 @@ class ResStyledBlock(ConvStyledBlock):
|
|||
else:
|
||||
self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
|
||||
|
||||
if 'U' in seq or 'D' in seq:
|
||||
raise NotImplementedError('upsample and downsample layers '
|
||||
'not supported yet')
|
||||
if "U" in seq or "D" in seq:
|
||||
raise NotImplementedError(
|
||||
"upsample and downsample layers " "not supported yet"
|
||||
)
|
||||
|
||||
def forward(self, x, s):
|
||||
y = x
|
||||
|
|
|
@ -15,17 +15,17 @@ class StyledVNet(nn.Module):
|
|||
|
||||
# activate non-identity skip connection in residual block
|
||||
# by explicitly setting out_chan
|
||||
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA')
|
||||
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
|
||||
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq="CACBA")
|
||||
self.down_l0 = ConvStyledBlock(style_size, 64, seq="DBA")
|
||||
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
|
||||
self.down_l1 = ConvStyledBlock(style_size, 64, seq="DBA")
|
||||
|
||||
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
self.conv_c = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
|
||||
|
||||
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
|
||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
||||
self.up_r1 = ConvStyledBlock(style_size, 64, seq="UBA")
|
||||
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq="CBACBA")
|
||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq="UBA")
|
||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq="CAC")
|
||||
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
|
|
|
@ -19,17 +19,17 @@ class UNet(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
|
||||
self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA')
|
||||
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||
self.conv_l1 = ConvBlock(64, seq='CBACBA')
|
||||
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||
self.conv_l0 = ConvBlock(in_chan, 64, seq="CACBA")
|
||||
self.down_l0 = ConvBlock(64, seq="DBA")
|
||||
self.conv_l1 = ConvBlock(64, seq="CBACBA")
|
||||
self.down_l1 = ConvBlock(64, seq="DBA")
|
||||
|
||||
self.conv_c = ConvBlock(64, seq='CBACBA')
|
||||
self.conv_c = ConvBlock(64, seq="CBACBA")
|
||||
|
||||
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r1 = ConvBlock(128, 64, seq='CBACBA')
|
||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
|
||||
self.up_r1 = ConvBlock(64, seq="UBA")
|
||||
self.conv_r1 = ConvBlock(128, 64, seq="CBACBA")
|
||||
self.up_r0 = ConvBlock(64, seq="UBA")
|
||||
self.conv_r0 = ConvBlock(128, out_chan, seq="CAC")
|
||||
|
||||
self.bypass = in_chan == out_chan
|
||||
|
||||
|
|
|
@ -23,17 +23,17 @@ class VNet(nn.Module):
|
|||
|
||||
# activate non-identity skip connection in residual block
|
||||
# by explicitly setting out_chan
|
||||
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
|
||||
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||
self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
|
||||
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||
self.conv_l0 = ResBlock(in_chan, 64, seq="CACBA")
|
||||
self.down_l0 = ConvBlock(64, seq="DBA")
|
||||
self.conv_l1 = ResBlock(64, 64, seq="CBACBA")
|
||||
self.down_l1 = ConvBlock(64, seq="DBA")
|
||||
|
||||
self.conv_c = ResBlock(64, 64, seq='CBACBA')
|
||||
self.conv_c = ResBlock(64, 64, seq="CBACBA")
|
||||
|
||||
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
|
||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
||||
self.up_r1 = ConvBlock(64, seq="UBA")
|
||||
self.conv_r1 = ResBlock(128, 64, seq="CBACBA")
|
||||
self.up_r0 = ConvBlock(64, seq="UBA")
|
||||
self.conv_r0 = ResBlock(128, out_chan, seq="CAC")
|
||||
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
|
|
|
@ -16,21 +16,21 @@ from .utils import import_attr, load_model_state_dict
|
|||
def test(args):
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.device_count() > 1:
|
||||
warnings.warn('Not parallelized but given more than 1 GPUs')
|
||||
warnings.warn("Not parallelized but given more than 1 GPUs")
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
device = torch.device('cuda', 0)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else: # CPU multithreading
|
||||
device = torch.device('cpu')
|
||||
device = torch.device("cpu")
|
||||
|
||||
if args.num_threads is None:
|
||||
args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE'])
|
||||
args.num_threads = int(os.environ["SLURM_CPUS_ON_NODE"])
|
||||
|
||||
torch.set_num_threads(args.num_threads)
|
||||
|
||||
print('pytorch {}'.format(torch.__version__))
|
||||
print("pytorch {}".format(torch.__version__))
|
||||
pprint(vars(args))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
@ -67,26 +67,33 @@ def test(args):
|
|||
out_chan = test_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(style_size, sum(in_chan), sum(out_chan),
|
||||
scale_factor=args.scale_factor, **args.misc_kwargs)
|
||||
model = model(
|
||||
style_size,
|
||||
sum(in_chan),
|
||||
sum(out_chan),
|
||||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
criterion = import_attr(args.criterion, torch.nn, models,
|
||||
callback_at=args.callback_at)
|
||||
criterion = import_attr(
|
||||
args.criterion, torch.nn, models, callback_at=args.callback_at
|
||||
)
|
||||
criterion = criterion()
|
||||
criterion.to(device)
|
||||
|
||||
state = torch.load(args.load_state, map_location=device)
|
||||
load_model_state_dict(model, state['model'], strict=args.load_state_strict)
|
||||
print('model state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state))
|
||||
load_model_state_dict(model, state["model"], strict=args.load_state_strict)
|
||||
print(
|
||||
"model state at epoch {} loaded from {}".format(state["epoch"], args.load_state)
|
||||
)
|
||||
del state
|
||||
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for i, data in enumerate(test_loader):
|
||||
style, input, target = data['style'], data['input'], data['target']
|
||||
style, input, target = data["style"], data["input"], data["target"]
|
||||
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
|
@ -94,21 +101,21 @@ def test(args):
|
|||
|
||||
output = model(input, style)
|
||||
if i < 5:
|
||||
print('##### sample :', i)
|
||||
print('style shape :', style.shape)
|
||||
print('input shape :', input.shape)
|
||||
print('output shape :', output.shape)
|
||||
print('target shape :', target.shape)
|
||||
print("##### sample :", i)
|
||||
print("style shape :", style.shape)
|
||||
print("input shape :", input.shape)
|
||||
print("output shape :", output.shape)
|
||||
print("target shape :", target.shape)
|
||||
|
||||
input, output, target = narrow_cast(input, output, target)
|
||||
if i < 5:
|
||||
print('narrowed shape :', output.shape, flush=True)
|
||||
print("narrowed shape :", output.shape, flush=True)
|
||||
|
||||
loss = criterion(output, target)
|
||||
|
||||
print('sample {} loss: {}'.format(i, loss.item()))
|
||||
print("sample {} loss: {}".format(i, loss.item()))
|
||||
|
||||
#if args.in_norms is not None:
|
||||
# if args.in_norms is not None:
|
||||
# start = 0
|
||||
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||
# norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
|
@ -119,12 +126,11 @@ def test(args):
|
|||
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
norm(output[:, start:stop], undo=True, **args.misc_kwargs)
|
||||
#norm(target[:, start:stop], undo=True, **args.misc_kwargs)
|
||||
# norm(target[:, start:stop], undo=True, **args.misc_kwargs)
|
||||
start = stop
|
||||
|
||||
#test_dataset.assemble('_in', in_chan, input,
|
||||
# test_dataset.assemble('_in', in_chan, input,
|
||||
# data['input_relpath'])
|
||||
test_dataset.assemble('_out', out_chan, output,
|
||||
data['target_relpath'])
|
||||
#test_dataset.assemble('_tgt', out_chan, target,
|
||||
test_dataset.assemble("_out", out_chan, output, data["target_relpath"])
|
||||
# test_dataset.assemble('_tgt', out_chan, target,
|
||||
# data['target_relpath'])
|
||||
|
|
269
map2map/train.py
269
map2map/train.py
|
@ -18,34 +18,34 @@ from .models import narrow_cast, resample
|
|||
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
|
||||
|
||||
|
||||
ckpt_link = 'checkpoint.pt'
|
||||
ckpt_link = "checkpoint.pt"
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
if 'SLURM_STEP_NUM_NODES' in os.environ:
|
||||
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
|
||||
elif 'SLURM_JOB_NUM_NODES' in os.environ:
|
||||
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
|
||||
if "SLURM_STEP_NUM_NODES" in os.environ:
|
||||
args.nodes = int(os.environ["SLURM_STEP_NUM_NODES"])
|
||||
elif "SLURM_JOB_NUM_NODES" in os.environ:
|
||||
args.nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
|
||||
else:
|
||||
raise KeyError('missing node counts in slurm env')
|
||||
raise KeyError("missing node counts in slurm env")
|
||||
args.gpus_per_node = torch.cuda.device_count()
|
||||
args.world_size = args.nodes * args.gpus_per_node
|
||||
|
||||
node = int(os.environ['SLURM_NODEID'])
|
||||
node = int(os.environ["SLURM_NODEID"])
|
||||
|
||||
if args.gpus_per_node < 1:
|
||||
raise RuntimeError('GPU not found on node {}'.format(node))
|
||||
raise RuntimeError("GPU not found on node {}".format(node))
|
||||
|
||||
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
|
||||
|
||||
|
||||
def gpu_worker(local_rank, node, args):
|
||||
#device = torch.device('cuda', local_rank)
|
||||
#torch.cuda.device(device) # env var recommended over this
|
||||
# device = torch.device('cuda', local_rank)
|
||||
# torch.cuda.device(device) # env var recommended over this
|
||||
|
||||
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
|
||||
device = torch.device('cuda', 0)
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
rank = args.gpus_per_node * node + local_rank
|
||||
|
||||
|
@ -53,7 +53,7 @@ def gpu_worker(local_rank, node, args):
|
|||
# Note DDP broadcasts initial model states from rank 0
|
||||
torch.manual_seed(args.seed + rank)
|
||||
# good practice to disable cudnn.benchmark if enabling cudnn.deterministic
|
||||
#torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
|
||||
dist_init(rank, args)
|
||||
|
||||
|
@ -77,9 +77,12 @@ def gpu_worker(local_rank, node, args):
|
|||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
||||
div_data=args.div_data,
|
||||
div_shuffle_dist=args.div_shuffle_dist)
|
||||
train_sampler = DistFieldSampler(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
div_data=args.div_data,
|
||||
div_shuffle_dist=args.div_shuffle_dist,
|
||||
)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
|
@ -110,9 +113,12 @@ def gpu_worker(local_rank, node, args):
|
|||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
||||
div_data=args.div_data,
|
||||
div_shuffle_dist=args.div_shuffle_dist)
|
||||
val_sampler = DistFieldSampler(
|
||||
val_dataset,
|
||||
shuffle=False,
|
||||
div_data=args.div_data,
|
||||
div_shuffle_dist=args.div_shuffle_dist,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batch_size,
|
||||
|
@ -127,14 +133,19 @@ def gpu_worker(local_rank, node, args):
|
|||
args.out_chan = train_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
|
||||
scale_factor=args.scale_factor, **args.misc_kwargs)
|
||||
model = model(
|
||||
args.style_size,
|
||||
sum(args.in_chan),
|
||||
sum(args.out_chan),
|
||||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
model.to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
model = DistributedDataParallel(
|
||||
model, device_ids=[device], process_group=dist.new_group()
|
||||
)
|
||||
|
||||
criterion = import_attr(args.criterion, nn, models,
|
||||
callback_at=args.callback_at)
|
||||
criterion = import_attr(args.criterion, nn, models, callback_at=args.callback_at)
|
||||
criterion = criterion()
|
||||
criterion.to(device)
|
||||
|
||||
|
@ -144,11 +155,13 @@ def gpu_worker(local_rank, node, args):
|
|||
lr=args.lr,
|
||||
**args.optimizer_args,
|
||||
)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, **args.scheduler_args)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **args.scheduler_args)
|
||||
|
||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state):
|
||||
if (
|
||||
args.load_state == ckpt_link
|
||||
and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state
|
||||
):
|
||||
if args.init_weight_std is not None:
|
||||
model.apply(init_weights)
|
||||
|
||||
|
@ -159,23 +172,28 @@ def gpu_worker(local_rank, node, args):
|
|||
else:
|
||||
state = torch.load(args.load_state, map_location=device)
|
||||
|
||||
start_epoch = state['epoch']
|
||||
start_epoch = state["epoch"]
|
||||
|
||||
load_model_state_dict(model.module, state['model'],
|
||||
strict=args.load_state_strict)
|
||||
load_model_state_dict(
|
||||
model.module, state["model"], strict=args.load_state_strict
|
||||
)
|
||||
|
||||
if 'optimizer' in state:
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
if 'scheduler' in state:
|
||||
scheduler.load_state_dict(state['scheduler'])
|
||||
if "optimizer" in state:
|
||||
optimizer.load_state_dict(state["optimizer"])
|
||||
if "scheduler" in state:
|
||||
scheduler.load_state_dict(state["scheduler"])
|
||||
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
torch.set_rng_state(state["rng"].cpu()) # move rng state back
|
||||
|
||||
if rank == 0:
|
||||
min_loss = state['min_loss']
|
||||
min_loss = state["min_loss"]
|
||||
|
||||
print('state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state), flush=True)
|
||||
print(
|
||||
"state at epoch {} loaded from {}".format(
|
||||
state["epoch"], args.load_state
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
del state
|
||||
|
||||
|
@ -189,21 +207,31 @@ def gpu_worker(local_rank, node, args):
|
|||
logger = SummaryWriter()
|
||||
|
||||
if rank == 0:
|
||||
print('pytorch {}'.format(torch.__version__))
|
||||
print("pytorch {}".format(torch.__version__))
|
||||
pprint(vars(args))
|
||||
sys.stdout.flush()
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
train_loss = train(epoch, train_loader, model, criterion,
|
||||
optimizer, scheduler, logger, device, args)
|
||||
train_loss = train(
|
||||
epoch,
|
||||
train_loader,
|
||||
model,
|
||||
criterion,
|
||||
optimizer,
|
||||
scheduler,
|
||||
logger,
|
||||
device,
|
||||
args,
|
||||
)
|
||||
epoch_loss = train_loss
|
||||
|
||||
if args.val:
|
||||
val_loss = validate(epoch, val_loader, model, criterion,
|
||||
logger, device, args)
|
||||
#epoch_loss = val_loss
|
||||
val_loss = validate(
|
||||
epoch, val_loader, model, criterion, logger, device, args
|
||||
)
|
||||
# epoch_loss = val_loss
|
||||
|
||||
if args.reduce_lr_on_plateau:
|
||||
scheduler.step(epoch_loss[2])
|
||||
|
@ -215,27 +243,26 @@ def gpu_worker(local_rank, node, args):
|
|||
min_loss = epoch_loss[2]
|
||||
|
||||
state = {
|
||||
'epoch': epoch + 1,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
'rng': torch.get_rng_state(),
|
||||
'min_loss': min_loss,
|
||||
"epoch": epoch + 1,
|
||||
"model": model.module.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"rng": torch.get_rng_state(),
|
||||
"min_loss": min_loss,
|
||||
}
|
||||
|
||||
state_file = 'state_{}.pt'.format(epoch + 1)
|
||||
state_file = "state_{}.pt".format(epoch + 1)
|
||||
torch.save(state, state_file)
|
||||
del state
|
||||
|
||||
tmp_link = '{}.pt'.format(time.time())
|
||||
tmp_link = "{}.pt".format(time.time())
|
||||
os.symlink(state_file, tmp_link) # workaround to overwrite
|
||||
os.rename(tmp_link, ckpt_link)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def train(epoch, loader, model, criterion,
|
||||
optimizer, scheduler, logger, device, args):
|
||||
def train(epoch, loader, model, criterion, optimizer, scheduler, logger, device, args):
|
||||
model.train()
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
@ -246,7 +273,7 @@ def train(epoch, loader, model, criterion,
|
|||
for i, data in enumerate(loader):
|
||||
batch = epoch * len(loader) + i + 1
|
||||
|
||||
style, input, target = data['style'], data['input'], data['target']
|
||||
style, input, target = data["style"], data["input"], data["target"]
|
||||
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
|
@ -254,23 +281,21 @@ def train(epoch, loader, model, criterion,
|
|||
|
||||
output = model(input, style)
|
||||
if batch <= 5 and rank == 0:
|
||||
print('##### batch :', batch)
|
||||
print('style shape :', style.shape)
|
||||
print('input shape :', input.shape)
|
||||
print('output shape :', output.shape)
|
||||
print('target shape :', target.shape)
|
||||
print("##### batch :", batch)
|
||||
print("style shape :", style.shape)
|
||||
print("input shape :", input.shape)
|
||||
print("output shape :", output.shape)
|
||||
print("target shape :", target.shape)
|
||||
|
||||
if (hasattr(model.module, 'scale_factor')
|
||||
and model.module.scale_factor != 1):
|
||||
if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
|
||||
input = resample(input, model.module.scale_factor, narrow=False)
|
||||
input, output, target = narrow_cast(input, output, target)
|
||||
if batch <= 5 and rank == 0:
|
||||
print('narrowed shape :', output.shape)
|
||||
print("narrowed shape :", output.shape)
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.detach()
|
||||
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -280,43 +305,45 @@ def train(epoch, loader, model, criterion,
|
|||
dist.all_reduce(loss)
|
||||
loss /= world_size
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/batch/train', loss.item(),
|
||||
global_step=batch)
|
||||
logger.add_scalar("loss/batch/train", loss.item(), global_step=batch)
|
||||
|
||||
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
||||
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
||||
logger.add_scalar("grad/first", grads[0], global_step=batch)
|
||||
logger.add_scalar("grad/last", grads[-1], global_step=batch)
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar("loss/epoch/train", epoch_loss[0], global_step=epoch + 1)
|
||||
|
||||
skip_chan = 0
|
||||
fig = plt_slices(
|
||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
||||
input[-1],
|
||||
output[-1, skip_chan:],
|
||||
target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
||||
title=["in", "out", "tgt", "out - tgt"],
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
logger.add_figure('fig/train', fig, global_step=epoch+1)
|
||||
logger.add_figure("fig/train", fig, global_step=epoch + 1)
|
||||
fig.clf()
|
||||
|
||||
fig = plt_power(
|
||||
input, output[:, skip_chan:], target[:, skip_chan:],
|
||||
label=['in', 'out', 'tgt'],
|
||||
input,
|
||||
output[:, skip_chan:],
|
||||
target[:, skip_chan:],
|
||||
label=["in", "out", "tgt"],
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
||||
logger.add_figure("fig/train/power/lag", fig, global_step=epoch + 1)
|
||||
fig.clf()
|
||||
|
||||
#fig = plt_power(1.0,
|
||||
# fig = plt_power(1.0,
|
||||
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
|
||||
# label=['in', 'out', 'tgt'],
|
||||
# **args.misc_kwargs,
|
||||
#)
|
||||
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
# )
|
||||
# logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||
# fig.clf()
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
@ -331,7 +358,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
|
||||
with torch.no_grad():
|
||||
for data in loader:
|
||||
style, input, target = data['style'], data['input'], data['target']
|
||||
style, input, target = data["style"], data["input"], data["target"]
|
||||
|
||||
style = style.to(device, non_blocking=True)
|
||||
input = input.to(device, non_blocking=True)
|
||||
|
@ -339,8 +366,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
|
||||
output = model(input, style)
|
||||
|
||||
if (hasattr(model.module, 'scale_factor')
|
||||
and model.module.scale_factor != 1):
|
||||
if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
|
||||
input = resample(input, model.module.scale_factor, narrow=False)
|
||||
input, output, target = narrow_cast(input, output, target)
|
||||
|
||||
|
@ -350,40 +376,43 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar("loss/epoch/val", epoch_loss[0], global_step=epoch + 1)
|
||||
skip_chan = 0
|
||||
|
||||
fig = plt_slices(
|
||||
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
|
||||
input[-1],
|
||||
output[-1, skip_chan:],
|
||||
target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
||||
title=["in", "out", "tgt", "out - tgt"],
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
logger.add_figure('fig/val', fig, global_step=epoch+1)
|
||||
logger.add_figure("fig/val", fig, global_step=epoch + 1)
|
||||
fig.clf()
|
||||
|
||||
fig = plt_power(
|
||||
input, output[:, skip_chan:], target[:, skip_chan:],
|
||||
label=['in', 'out', 'tgt'],
|
||||
input,
|
||||
output[:, skip_chan:],
|
||||
target[:, skip_chan:],
|
||||
label=["in", "out", "tgt"],
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
||||
logger.add_figure("fig/val/power/lag", fig, global_step=epoch + 1)
|
||||
fig.clf()
|
||||
|
||||
#fig = plt_power(1.0,
|
||||
# fig = plt_power(1.0,
|
||||
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
|
||||
# label=['in', 'out', 'tgt'],
|
||||
# **args.misc_kwargs,
|
||||
#)
|
||||
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
# )
|
||||
# logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||
# fig.clf()
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
def dist_init(rank, args):
|
||||
dist_file = 'dist_addr'
|
||||
dist_file = "dist_addr"
|
||||
|
||||
if rank == 0:
|
||||
addr = socket.gethostname()
|
||||
|
@ -393,15 +422,15 @@ def dist_init(rank, args):
|
|||
s.bind((addr, 0))
|
||||
_, port = s.getsockname()
|
||||
|
||||
args.dist_addr = 'tcp://{}:{}'.format(addr, port)
|
||||
args.dist_addr = "tcp://{}:{}".format(addr, port)
|
||||
|
||||
with open(dist_file, mode='w') as f:
|
||||
with open(dist_file, mode="w") as f:
|
||||
f.write(args.dist_addr)
|
||||
else:
|
||||
while not os.path.exists(dist_file):
|
||||
time.sleep(1)
|
||||
|
||||
with open(dist_file, mode='r') as f:
|
||||
with open(dist_file, mode="r") as f:
|
||||
args.dist_addr = f.read()
|
||||
|
||||
dist.init_process_group(
|
||||
|
@ -417,19 +446,40 @@ def dist_init(rank, args):
|
|||
|
||||
|
||||
def init_weights(m, args):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
if isinstance(
|
||||
m,
|
||||
(
|
||||
nn.Linear,
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.ConvTranspose1d,
|
||||
nn.ConvTranspose2d,
|
||||
nn.ConvTranspose3d,
|
||||
),
|
||||
):
|
||||
m.weight.data.normal_(0.0, args.init_weight_std)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
||||
elif isinstance(
|
||||
m,
|
||||
(
|
||||
nn.BatchNorm1d,
|
||||
nn.BatchNorm2d,
|
||||
nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm,
|
||||
nn.LayerNorm,
|
||||
nn.GroupNorm,
|
||||
nn.InstanceNorm1d,
|
||||
nn.InstanceNorm2d,
|
||||
nn.InstanceNorm3d,
|
||||
),
|
||||
):
|
||||
if m.affine:
|
||||
# NOTE: dispersion from DCGAN, why?
|
||||
m.weight.data.normal_(1.0, args.init_weight_std)
|
||||
m.bias.data.fill_(0)
|
||||
|
||||
|
||||
def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False):
|
||||
def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
|
@ -444,8 +494,7 @@ def get_grads(model: torch.nn.Module):
|
|||
Returns:
|
||||
A list containing the norms of the gradients of the first and the last layer weights.
|
||||
"""
|
||||
grads = list(p.grad for n, p in model.named_parameters()
|
||||
if '.weight' in n)
|
||||
grads = list(p.grad for n, p in model.named_parameters() if ".weight" in n)
|
||||
grads = [grads[0], grads[-1]]
|
||||
grads = [g.detach().norm() for g in grads]
|
||||
return grads
|
||||
|
|
|
@ -2,11 +2,13 @@ from math import log2, log10, ceil
|
|||
import torch
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
||||
from matplotlib.cm import ScalarMappable
|
||||
plt.rc('text', usetex=False)
|
||||
|
||||
plt.rc("text", usetex=False)
|
||||
|
||||
from ..models import lag2eul, power
|
||||
|
||||
|
@ -21,7 +23,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
Each field should have a channel dimension followed by spatial dimensions,
|
||||
i.e. no batch dimension.
|
||||
"""
|
||||
plt.close('all')
|
||||
plt.close("all")
|
||||
|
||||
assert all(isinstance(field, torch.Tensor) for field in fields)
|
||||
|
||||
|
@ -38,11 +40,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
im_size = 2
|
||||
cbar_height = 0.2
|
||||
fig, axes = plt.subplots(
|
||||
nc + 1, nf,
|
||||
nc + 1,
|
||||
nf,
|
||||
squeeze=False,
|
||||
figsize=(nf * im_size, nc * im_size + cbar_height),
|
||||
dpi=100,
|
||||
gridspec_kw={'height_ratios': nc * [im_size] + [cbar_height]}
|
||||
gridspec_kw={"height_ratios": nc * [im_size] + [cbar_height]},
|
||||
)
|
||||
|
||||
for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)):
|
||||
|
@ -51,11 +54,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
|
||||
if cmap_col is None:
|
||||
if all_non_neg:
|
||||
cmap_col = 'inferno'
|
||||
cmap_col = "inferno"
|
||||
elif all_non_pos:
|
||||
cmap_col = 'inferno_r'
|
||||
cmap_col = "inferno_r"
|
||||
else:
|
||||
cmap_col = 'RdBu_r'
|
||||
cmap_col = "RdBu_r"
|
||||
|
||||
if norm_col is None:
|
||||
l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
|
||||
|
@ -70,9 +73,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
if l1 < 0.1 * l2 or h2 == 0:
|
||||
norm_col = Normalize(vmin=-quantize(-l2), vmax=0)
|
||||
else:
|
||||
norm_col = SymLogNorm(linthresh=quantize(-h2),
|
||||
vmin=-quantize(-l2),
|
||||
vmax=-quantize(-h2))
|
||||
norm_col = SymLogNorm(
|
||||
linthresh=quantize(-h2),
|
||||
vmin=-quantize(-l2),
|
||||
vmax=-quantize(-h2),
|
||||
)
|
||||
else:
|
||||
vlim = quantize(max(-l2, h2))
|
||||
if w1 > 0.1 * w2 or l1 * h1 >= 0:
|
||||
|
@ -80,8 +85,13 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
else:
|
||||
linthresh = quantize(min(-l1, h1))
|
||||
linscale = np.log10(vlim / linthresh)
|
||||
norm_col = SymLogNorm(linthresh=linthresh, linscale=linscale,
|
||||
vmin=-vlim, vmax=vlim, base=10)
|
||||
norm_col = SymLogNorm(
|
||||
linthresh=linthresh,
|
||||
linscale=linscale,
|
||||
vmin=-vlim,
|
||||
vmax=vlim,
|
||||
base=10,
|
||||
)
|
||||
|
||||
for c in range(field.shape[0]):
|
||||
s = (c,) + tuple(d // 2 for d in field.shape[1:-2])
|
||||
|
@ -101,7 +111,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
|
||||
axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col)
|
||||
|
||||
axes[c, f].set_aspect('equal')
|
||||
axes[c, f].set_aspect("equal")
|
||||
|
||||
axes[c, f].set_xticks([])
|
||||
axes[c, f].set_yticks([])
|
||||
|
@ -110,12 +120,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||
axes[c, f].set_title(title[f])
|
||||
|
||||
for c in range(field.shape[0], nc):
|
||||
axes[c, f].axis('off')
|
||||
axes[c, f].axis("off")
|
||||
|
||||
fig.colorbar(
|
||||
ScalarMappable(norm=norm_col, cmap=cmap_col),
|
||||
cax=axes[-1, f],
|
||||
orientation='horizontal',
|
||||
orientation="horizontal",
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
|
@ -133,7 +143,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
|
|||
|
||||
See `map2map.models.power`.
|
||||
"""
|
||||
plt.close('all')
|
||||
plt.close("all")
|
||||
|
||||
if label is not None:
|
||||
assert len(label) == len(fields) or len(label) == len(dis)
|
||||
|
@ -159,8 +169,8 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
|
|||
axes.loglog(k, P, label=l, alpha=0.7)
|
||||
|
||||
axes.legend()
|
||||
axes.set_xlabel('unnormalized wavenumber')
|
||||
axes.set_ylabel('unnormalized power')
|
||||
axes.set_xlabel("unnormalized wavenumber")
|
||||
axes.set_ylabel("unnormalized power")
|
||||
|
||||
fig.tight_layout()
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ def import_attr(name, *pkgs, callback_at=None):
|
|||
first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally
|
||||
from 'path/to/cb_dir/mod.py'.
|
||||
"""
|
||||
if name.count('.') == 0:
|
||||
if name.count(".") == 0:
|
||||
attr = name
|
||||
|
||||
errors = []
|
||||
|
@ -34,23 +34,22 @@ def import_attr(name, *pkgs, callback_at=None):
|
|||
|
||||
raise Exception(errors)
|
||||
else:
|
||||
mod, attr = name.rsplit('.', 1)
|
||||
mod, attr = name.rsplit(".", 1)
|
||||
|
||||
errors = []
|
||||
|
||||
for pkg in pkgs:
|
||||
try:
|
||||
return getattr(
|
||||
importlib.import_module(pkg.__name__ + '.' + mod), attr)
|
||||
return getattr(importlib.import_module(pkg.__name__ + "." + mod), attr)
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
errors.append(e)
|
||||
|
||||
if callback_at is None:
|
||||
raise Exception(errors)
|
||||
|
||||
callback_at = os.path.join(callback_at, mod + '.py')
|
||||
callback_at = os.path.join(callback_at, mod + ".py")
|
||||
if not os.path.isfile(callback_at):
|
||||
raise FileNotFoundError('callback file not found')
|
||||
raise FileNotFoundError("callback file not found")
|
||||
|
||||
if mod in sys.modules:
|
||||
return getattr(sys.modules[mod], attr)
|
||||
|
|
|
@ -7,9 +7,13 @@ def load_model_state_dict(module, state_dict, strict=True):
|
|||
bad_keys = module.load_state_dict(state_dict, strict)
|
||||
|
||||
if len(bad_keys.missing_keys) > 0:
|
||||
warnings.warn('Missing keys in state_dict:\n{}'.format(
|
||||
pformat(bad_keys.missing_keys)))
|
||||
warnings.warn(
|
||||
"Missing keys in state_dict:\n{}".format(pformat(bad_keys.missing_keys))
|
||||
)
|
||||
if len(bad_keys.unexpected_keys) > 0:
|
||||
warnings.warn('Unexpected keys in state_dict:\n{}'.format(
|
||||
pformat(bad_keys.unexpected_keys)))
|
||||
warnings.warn(
|
||||
"Unexpected keys in state_dict:\n{}".format(
|
||||
pformat(bad_keys.unexpected_keys)
|
||||
)
|
||||
)
|
||||
sys.stderr.flush()
|
||||
|
|
Loading…
Reference in a new issue