chore: Blackify the source code for pretty formatting

Guilhem Lavaux 2024-05-20 17:50:32 +02:00
parent ec9732df21
commit 7361c3eb9d
27 changed files with 908 additions and 568 deletions
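The changes below are mechanical: Black normalizes string literals to double quotes, splits long calls with one argument per line and a trailing comma, and re-flows manual line continuations. As a representative before/after pair taken from the argument-parser hunk that follows (indentation of the old call approximate), the old

    parser.add_argument('--criterion', default='MSELoss', type=str,
                        help='loss function')

is rewritten as

    parser.add_argument(
        "--criterion", default="MSELoss", type=str, help="loss function"
    )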


@@ -7,18 +7,16 @@ from .train import ckpt_link
def get_args():
    """Parse arguments and set runtime defaults."""
    parser = argparse.ArgumentParser(description="Transform field(s) to field(s)")

    subparsers = parser.add_subparsers(title="modes", dest="mode", required=True)
    train_parser = subparsers.add_parser(
        "train",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    test_parser = subparsers.add_parser(
        "test",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

@@ -27,163 +25,293 @@ def get_args():
    args = parser.parse_args()

    if args.mode == "train":
        set_train_args(args)
    elif args.mode == "test":
        set_test_args(args)

    return args


def add_common_args(parser):
    parser.add_argument(
        "--in-norms",
        type=str_list,
        help="comma-sep. list " "of input normalization functions",
    )
    parser.add_argument(
        "--tgt-norms",
        type=str_list,
        help="comma-sep. list " "of target normalization functions",
    )
    parser.add_argument(
        "--crop",
        type=int_tuple,
        help="size to crop the input and target data. Default is the "
        "field size. Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--crop-start",
        type=int_tuple,
        help="starting point of the first crop. Default is the origin. "
        "Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--crop-stop",
        type=int_tuple,
        help="stopping point of the last crop. Default is the opposite "
        "corner to the origin. Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--crop-step",
        type=int_tuple,
        help="spacing between crops. Default is the crop size. "
        "Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--in-pad",
        "--pad",
        default=0,
        type=int_tuple,
        help="size to pad the input data beyond the crop size, assuming "
        "periodic boundary condition. Comma-sep. list of 1, d, or dx2 "
        "integers, to pad equally along all axes, symmetrically on each, "
        "or by the specified size on every boundary, respectively",
    )
    parser.add_argument(
        "--tgt-pad",
        default=0,
        type=int_tuple,
        help="size to pad the target data beyond the crop size, assuming "
        "periodic boundary condition, useful for super-resolution. "
        "Comma-sep. list with the same format as --in-pad",
    )
    parser.add_argument(
        "--scale-factor",
        default=1,
        type=int,
        help="upsampling factor for super-resolution, in which case "
        "crop and pad are sizes of the input resolution",
    )

    parser.add_argument("--model", type=str, required=True, help="(generator) model")
    parser.add_argument(
        "--criterion", default="MSELoss", type=str, help="loss function"
    )
    parser.add_argument(
        "--load-state",
        default=ckpt_link,
        type=str,
        help="path to load the states of model, optimizer, rng, etc. "
        "Default is the checkpoint. "
        "Start from scratch in case of empty string or missing checkpoint",
    )
    parser.add_argument(
        "--load-state-non-strict",
        action="store_false",
        help="allow incompatible keys when loading model states",
        dest="load_state_strict",
    )

    # somehow I named it "batches" instead of batch_size at first
    # "batches" is kept for now for backward compatibility
    parser.add_argument(
        "--batch-size",
        "--batches",
        type=int,
        required=True,
        help="mini-batch size, per GPU in training or in total in testing",
    )
    parser.add_argument(
        "--loader-workers",
        default=8,
        type=int,
        help="number of subprocesses per data loader. " "0 to disable multiprocessing",
    )

    parser.add_argument(
        "--callback-at",
        type=lambda s: os.path.abspath(s),
        help="directory of custorm code defining callbacks for models, "
        "norms, criteria, and optimizers. Disabled if not set. "
        "This is appended to the default locations, "
        "thus has the lowest priority",
    )
    parser.add_argument(
        "--misc-kwargs",
        default="{}",
        type=json.loads,
        help="miscellaneous keyword arguments for custom models and "
        "norms. Be careful with name collisions",
    )


def add_train_args(parser):
    add_common_args(parser)

    parser.add_argument(
        "--train-style-pattern",
        type=str,
        required=True,
        help="glob pattern for training data styles",
    )
    parser.add_argument(
        "--train-in-patterns",
        type=str_list,
        required=True,
        help="comma-sep. list of glob patterns for training input data",
    )
    parser.add_argument(
        "--train-tgt-patterns",
        type=str_list,
        required=True,
        help="comma-sep. list of glob patterns for training target data",
    )
    parser.add_argument(
        "--val-style-pattern", type=str, help="glob pattern for validation data styles"
    )
    parser.add_argument(
        "--val-in-patterns",
        type=str_list,
        help="comma-sep. list of glob patterns for validation input data",
    )
    parser.add_argument(
        "--val-tgt-patterns",
        type=str_list,
        help="comma-sep. list of glob patterns for validation target data",
    )
    parser.add_argument(
        "--augment",
        action="store_true",
        help="enable data augmentation of axis flipping and permutation",
    )
    parser.add_argument(
        "--aug-shift",
        type=int_tuple,
        help="data augmentation by shifting cropping by [0, aug_shift) pixels, "
        "useful for models that treat neighboring pixels differently, "
        "e.g. with strided convolutions. "
        "Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--aug-add",
        type=float,
        help="additive data augmentation, (normal) std, " "same factor for all fields",
    )
    parser.add_argument(
        "--aug-mul",
        type=float,
        help="multiplicative data augmentation, (log-normal) std, "
        "same factor for all fields",
    )

    parser.add_argument(
        "--optimizer", default="Adam", type=str, help="optimization algorithm"
    )
    parser.add_argument("--lr", type=float, required=True, help="initial learning rate")
    parser.add_argument(
        "--optimizer-args",
        default="{}",
        type=json.loads,
        help="optimizer arguments in addition to the learning rate, "
        "e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'",
    )
    parser.add_argument(
        "--reduce-lr-on-plateau",
        action="store_true",
        help="Enable ReduceLROnPlateau learning rate scheduler",
    )
    parser.add_argument(
        "--scheduler-args",
        default='{"verbose": true}',
        type=json.loads,
        help="arguments for the ReduceLROnPlateau scheduler",
    )
    parser.add_argument(
        "--init-weight-std", type=float, help="weight initialization std"
    )
    parser.add_argument(
        "--epochs", default=128, type=int, help="total number of epochs to run"
    )
    parser.add_argument(
        "--seed", default=42, type=int, help="seed for initializing training"
    )

    parser.add_argument(
        "--div-data",
        action="store_true",
        help="enable data division among GPUs for better page caching. "
        "Data division is shuffled every epoch. "
        "Only relevant if there are multiple crops in each field",
    )
    parser.add_argument(
        "--div-shuffle-dist",
        default=1,
        type=float,
        help="distance to further shuffle cropped samples relative to "
        "their fields, to be used with --div-data. "
        "Only relevant if there are multiple crops in each file. "
        "The order of each sample is randomly displaced by this value. "
        "Setting it to 0 turn off this randomization, and setting it to N "
        "limits the shuffling within a distance of N files. "
        "Change this to balance cache locality and stochasticity",
    )
    parser.add_argument(
        "--dist-backend",
        default="nccl",
        type=str,
        choices=["gloo", "nccl"],
        help="distributed backend",
    )
    parser.add_argument(
        "--log-interval",
        default=100,
        type=int,
        help="interval (batches) between logging training loss",
    )
    parser.add_argument(
        "--detect-anomaly",
        action="store_true",
        help="enable anomaly detection for the autograd engine",
    )


def add_test_args(parser):
    add_common_args(parser)

    parser.add_argument(
        "--test-style-pattern",
        type=str,
        required=True,
        help="glob pattern for test data styles",
    )
    parser.add_argument(
        "--test-in-patterns",
        type=str_list,
        required=True,
        help="comma-sep. list of glob patterns for test input data",
    )
    parser.add_argument(
        "--test-tgt-patterns",
        type=str_list,
        required=True,
        help="comma-sep. list of glob patterns for test target data",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        help="number of CPU threads when cuda is unavailable. "
        "Default is the number of CPUs on the node by slurm",
    )


def str_list(s):
    return s.split(",")


def int_tuple(s):
    t = s.split(",")
    t = tuple(int(i) for i in t)
    if len(t) == 1:
        return t[0]

@@ -198,8 +326,7 @@ def set_common_args(args):
def set_train_args(args):
    set_common_args(args)

    args.val = args.val_in_patterns is not None and args.val_tgt_patterns is not None


def set_test_args(args):
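As a quick check of the str_list and int_tuple helpers above (plain Python, invented values; the multi-value branch of int_tuple lies outside this hunk):

    >>> str_list("dis,vel")
    ['dis', 'vel']
    >>> int_tuple("128")
    128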


@@ -43,54 +43,72 @@ class FieldDataset(Dataset):
    the input for super-resolution, in which case `crop` and `pad` are sizes of
    the input resolution.
    """

    def __init__(
        self,
        style_pattern,
        in_patterns,
        tgt_patterns,
        in_norms=None,
        tgt_norms=None,
        callback_at=None,
        augment=False,
        aug_shift=None,
        aug_add=None,
        aug_mul=None,
        crop=None,
        crop_start=None,
        crop_stop=None,
        crop_step=None,
        in_pad=0,
        tgt_pad=0,
        scale_factor=1,
        **kwargs,
    ):
        self.style_files = sorted(glob(style_pattern))
        in_file_lists = [sorted(glob(p)) for p in in_patterns]
        self.in_files = list(zip(*in_file_lists))
        tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
        self.tgt_files = list(zip(*tgt_file_lists))

        # self.style_files = self.style_files[:1]
        # self.in_files = self.in_files[:1]
        # self.tgt_files = self.tgt_files[:1]

        if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
            raise ValueError("number of style, input, and target files do not match")
        self.nfile = len(self.in_files)

        if self.nfile == 0:
            raise FileNotFoundError("file not found for {}".format(in_patterns))

        self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
        self.in_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.in_files[0]]
        self.tgt_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.tgt_files[0]]

        self.size = np.load(self.in_files[0][0], mmap_mode="r").shape[1:]
        self.size = np.asarray(self.size)
        self.ndim = len(self.size)

        if in_norms is not None and len(in_patterns) != len(in_norms):
            raise ValueError(
                "numbers of input normalization functions and fields do not match"
            )
        self.in_norms = in_norms

        if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
            raise ValueError(
                "numbers of target normalization functions and fields do not match"
            )
        self.tgt_norms = tgt_norms

        self.callback_at = callback_at

        self.augment = augment
        if self.ndim == 1 and self.augment:
            raise ValueError("cannot augment 1D fields")
        self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
        self.aug_add = aug_add
        self.aug_mul = aug_mul

@@ -116,10 +134,15 @@ class FieldDataset(Dataset):
        crop_step = np.broadcast_to(crop_step, (self.ndim,))
        self.crop_step = crop_step

        self.anchors = np.stack(
            np.mgrid[
                tuple(
                    slice(crop_start[d], crop_stop[d], crop_step[d])
                    for d in range(self.ndim)
                )
            ],
            axis=-1,
        ).reshape(-1, self.ndim)
        self.ncrop = len(self.anchors)

        def format_pad(pad, ndim):

@@ -130,15 +153,16 @@ class FieldDataset(Dataset):
            elif isinstance(pad, tuple) and len(pad) == ndim * 2:
                pad = np.array(pad)
            else:
                raise ValueError("pad and ndim mismatch")
            return pad.reshape(ndim, 2)

        self.in_pad = format_pad(in_pad, self.ndim)
        self.tgt_pad = format_pad(tgt_pad, self.ndim)

        if scale_factor != 1:
            tgt_size = np.load(self.tgt_files[0][0], mmap_mode="r").shape[1:]
            if any(self.size * scale_factor != tgt_size):
                raise ValueError("input size x scale factor != target size")
        self.scale_factor = scale_factor

        self.nsample = self.nfile * self.ncrop

@@ -148,9 +172,7 @@ class FieldDataset(Dataset):
        self.assembly_line = {}

        self.commonpath = os.path.commonpath(
            file for files in self.in_files[:2] + self.tgt_files[:2] for file in files
        )

    def __len__(self):

@@ -180,12 +202,18 @@ class FieldDataset(Dataset):
        else:
            argsort_perm_axes = slice(None)

        crop(
            in_fields,
            anchor,
            self.crop[argsort_perm_axes],
            self.in_pad[argsort_perm_axes],
        )
        crop(
            tgt_fields,
            anchor * self.scale_factor,
            self.crop[argsort_perm_axes] * self.scale_factor,
            self.tgt_pad[argsort_perm_axes],
        )

        style = torch.from_numpy(style).to(torch.float32)
        in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]

@@ -221,17 +249,19 @@ class FieldDataset(Dataset):
            in_fields = torch.cat(in_fields, dim=0)
            tgt_fields = torch.cat(tgt_fields, dim=0)

        # in_relpath = [os.path.relpath(file, start=self.commonpath)
        #              for file in self.in_files[ifile]]
        tgt_relpath = [
            os.path.relpath(file, start=self.commonpath)
            for file in self.tgt_files[ifile]
        ]

        return {
            "style": style,
            "input": in_fields,
            "target": tgt_fields,
            #'input_relpath': in_relpath,
            "target_relpath": tgt_relpath,
        }

    def assemble(self, label, chan, patches, paths):

@@ -255,29 +285,29 @@ class FieldDataset(Dataset):
        if isinstance(patches, torch.Tensor):
            patches = patches.detach().cpu().numpy()

        assert patches.ndim == 2 + self.ndim, "ndim mismatch"
        if any(self.crop_step > patches.shape[2:]):
            raise RuntimeError("patch too small to tile")

        # the batched paths are a list of lists with shape (channel, batch)
        # since pytorch default_collate batches list of strings transposedly
        # therefore we transpose below back to (batch, channel)
        assert patches.shape[1] == sum(chan), "number of channels mismatch"
        assert len(paths) == len(chan), "number of fields mismatch"
        paths = list(zip(*paths))
        assert patches.shape[0] == len(paths), "batch size mismatch"

        patches = list(patches)
        if label in self.assembly_line:
            self.assembly_line[label] += patches
            self.assembly_line[label + "path"] += paths
        else:
            self.assembly_line[label] = patches
            self.assembly_line[label + "path"] = paths

        del patches, paths
        patches = self.assembly_line[label]
        paths = self.assembly_line[label + "path"]

        # NOTE anchor positioning assumes sufficient target padding and
        # symmetric narrowing (more on the right if odd) see `models/narrow.py`

@@ -285,31 +315,26 @@ class FieldDataset(Dataset):
        anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2

        while len(patches) >= self.ncrop:
            fields = np.zeros(patches[0].shape[:1] + tuple(self.size), patches[0].dtype)

            for patch, anchor in zip(patches, anchors):
                fill(fields, patch, anchor)

            for field, path in zip(np.split(fields, np.cumsum(chan), axis=0), paths[0]):
                pathlib.Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
                path = label.join(os.path.splitext(path))
                np.save(path, field)

            del patches[: self.ncrop], paths[: self.ncrop]


def fill(field, patch, anchor):
    ndim = len(anchor)
    assert field.ndim == patch.ndim == 1 + ndim, "ndim mismatch"

    ind = [slice(None)]
    for d, (p, a, s) in enumerate(zip(patch.shape[1:], anchor, field.shape[1:])):
        i = np.arange(a, a + p)
        i %= s
        i = i.reshape((-1,) + (1,) * (ndim - d - 1))

@@ -320,10 +345,10 @@ def fill(field, patch, anchor):
def crop(fields, anchor, crop, pad):
    assert all(x.shape == fields[0].shape for x in fields), "shape mismatch"
    size = fields[0].shape[1:]
    ndim = len(size)
    assert ndim == len(anchor) == len(crop) == len(pad), "ndim mismatch"

    ind = [slice(None)]
    for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):

@@ -342,7 +367,7 @@ def crop(fields, anchor, crop, pad):
def flip(fields, axes, ndim):
    assert ndim > 1, "flipping is ambiguous for 1D scalars/vectors"

    if axes is None:
        axes = torch.randint(2, (ndim,), dtype=torch.bool)

@@ -350,7 +375,7 @@ def flip(fields, axes, ndim):
    for i, x in enumerate(fields):
        if x.shape[0] == ndim:  # flip vector components
            x[axes] = -x[axes]

        shifted_axes = (1 + axes).tolist()
        x = torch.flip(x, shifted_axes)

@@ -361,7 +386,7 @@ def flip(fields, axes, ndim):
def perm(fields, axes, ndim):
    assert ndim > 1, "permutation is not necessary for 1D fields"

    if axes is None:
        axes = torch.randperm(ndim)


@@ -10,6 +10,7 @@ def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
    x *= dis_norm


def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
    vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z)  # [km/s]

@@ -20,25 +21,28 @@ def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
def D(z, Om=0.31):
    """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero"""
    OL = 1 - Om
    a = 1 / (1 + z)
    return (
        a
        * hyp2f1(1, 1 / 3, 11 / 6, -OL * a**3 / Om)
        / hyp2f1(1, 1 / 3, 11 / 6, -OL / Om)
    )


def f(z, Om=0.31):
    """linear growth rate for flat LambdaCDM"""
    OL = 1 - Om
    a = 1 / (1 + z)
    aa3 = OL * a**3 / Om
    return 1 - 6 / 11 * aa3 * hyp2f1(2, 4 / 3, 17 / 6, -aa3) / hyp2f1(
        1, 1 / 3, 11 / 6, -aa3
    )


def H(z, Om=0.31):
    """Hubble in [h km/s/Mpc] for flat LambdaCDM"""
    OL = 1 - Om
    a = 1 / (1 + z)
    return 100 * np.sqrt(Om / a**3 + OL)
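Read directly off the code above, with a = 1/(1+z) and Ω_Λ = 1 − Ω_m, the growth function is normalized as

    D(a) = a · ₂F₁(1, 1/3; 11/6; −Ω_Λ a³ / Ω_m) / ₂F₁(1, 1/3; 11/6; −Ω_Λ / Ω_m)

so that D = 1 at z = 0, and H(z) = 100 · sqrt(Ω_m / a³ + Ω_Λ) in h km/s/Mpc.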


@@ -7,18 +7,21 @@ def exp(x, undo=False, **kwargs):
    else:
        torch.log(x, out=x)


def log(x, eps=1e-8, undo=False, **kwargs):
    if not undo:
        torch.log(x + eps, out=x)
    else:
        torch.exp(x, out=x)


def expm1(x, undo=False, **kwargs):
    if not undo:
        torch.expm1(x, out=x)
    else:
        torch.log1p(x, out=x)


def log1p(x, eps=1e-7, undo=False, **kwargs):
    if not undo:
        torch.log1p(x + eps, out=x)


@@ -22,8 +22,8 @@ class DistFieldSampler(Sampler):
    Like `DistributedSampler`, `set_epoch()` should be called at the beginning
    of each epoch during training.
    """

    def __init__(self, dataset, shuffle, div_data=False, div_shuffle_dist=0):
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

@@ -50,8 +50,10 @@ class DistFieldSampler(Sampler):
            ind = ind.flatten()

            # displace crops with respect to files
            dis = (
                torch.rand((self.nfile, self.ncrop), generator=g)
                * self.div_shuffle_dist
            )
            loc = torch.arange(self.nfile)
            loc = loc[:, None] + dis
            loc = loc.flatten() % self.nfile  # periodic in files


@@ -3,6 +3,7 @@ from . import test
import click
import os
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:

@@ -12,20 +13,24 @@ import importlib.resources
import json
from functools import partial


def _load_resource_file(resource_path):
    # Import the package
    pkg_files = importlib.resources.files() / resource_path
    with pkg_files.open() as file:
        return file.read()  # Read the file and return its content


def _str_list(value):
    return value.split(",")


def _int_tuple(value):
    t = value.split(",")
    t = tuple(int(i) for i in t)
    return t


class VariadicType(click.ParamType):
    """
    A custom parameter type for Click command-line interface.

@@ -61,7 +66,7 @@ class VariadicType(click.ParamType):
            self._typename = typename
        self.name = self._type["type"]
        if "func" not in self._type:
            self._type["func"] = eval(self._type["type"])

    def convert(self, value, param, ctx):
        try:

@@ -69,24 +74,26 @@ class VariadicType(click.ParamType):
        except Exception as e:
            self.fail(f"Could not parse {self._typename}: {e}", param, ctx)


def _apply_options(options_file, f):
    common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
    common_args = common_args["arguments"]

    for arg in common_args:
        argopt = common_args[arg]
        if "type" in argopt:
            if type(argopt["type"]) == dict and argopt["type"]["type"] == "choice":
                argopt["type"] = click.Choice(argopt["type"]["opts"])
            else:
                argopt["type"] = VariadicType(argopt["type"])
            f = click.option(f"--{arg}", **argopt)(f)
        else:
            f = click.option(f"--{arg}", **argopt)(f)

    return f


m2m_options = partial(_apply_options, "common_args.yaml")


@click.group()

@@ -94,21 +101,24 @@ m2m_options=partial(_apply_options,"common_args.yaml")
@click.pass_context
def main(ctx, config):
    if config is not None and os.path.exists(config):
        with open(config, "r") as f:
            config = yaml.load(f.read(), Loader=Loader)
        ctx.default_map = config


# Make a class that provides access to dict with the attribute mechanism
class DictProxy:
    def __init__(self, d):
        self.__dict__ = d


@main.command()
@m2m_options
@partial(_apply_options, "train_args.yaml")
def train(**kwargs):
    train.node_worker(DictProxy(kwargs))


@main.command()
@m2m_options
@partial(_apply_options, "test_args.yaml")
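For reference, _apply_options above expects the options file to load into a mapping with a top-level "arguments" key whose entries are click.option keyword arguments; a hypothetical illustration of that structure as the loaded Python object (option names and values invented, only the layout is read off the loop above):

    {
        "arguments": {
            "batch-size": {"type": "int", "required": True, "help": "mini-batch size"},
            "dist-backend": {"type": {"type": "choice", "opts": ["gloo", "nccl"]}},
        }
    }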


@@ -5,6 +5,7 @@ def adv_model_wrapper(module):
    """Wrap an adversary model to also take lists of Tensors as input,
    to be concatenated along the batch dimension
    """

    class _new_module(module):
        def forward(self, x):
            if not isinstance(x, torch.Tensor):

@@ -22,6 +23,7 @@ def adv_criterion_wrapper(module):
    * expand target shape as that of input
    * return a list of losses, one for each pair of input and target Tensors
    """

    class _new_module(module):
        def forward(self, input, target):
            assert isinstance(input, torch.Tensor)

@@ -35,8 +37,9 @@ def adv_criterion_wrapper(module):
            target = [t.expand_as(i) for i, t in zip(input, target)]

            loss = [
                super(new_module, self).forward(i, t) for i, t in zip(input, target)
            ]

            return loss


@@ -15,8 +15,10 @@ class ConvBlock(nn.Module):
    'U': upsampling transposed convolution of kernel size 2 and stride 2
    'D': downsampling convolution of kernel size 2 and stride 2
    """

    def __init__(
        self, in_chan, out_chan=None, mid_chan=None, kernel_size=3, stride=1, seq="CBA"
    ):
        super().__init__()

        if out_chan is None:

@@ -31,31 +33,30 @@ class ConvBlock(nn.Module):
        self.norm_chan = in_chan
        self.idx_conv = 0
        self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])

        layers = [self._get_layer(l) for l in seq]

        self.convs = nn.Sequential(*layers)

    def _get_layer(self, l):
        if l == "U":
            in_chan, out_chan = self._setup_conv()
            return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2)
        elif l == "D":
            in_chan, out_chan = self._setup_conv()
            return nn.Conv3d(in_chan, out_chan, 2, stride=2)
        elif l == "C":
            in_chan, out_chan = self._setup_conv()
            return nn.Conv3d(in_chan, out_chan, self.kernel_size, stride=self.stride)
        elif l == "B":
            return nn.BatchNorm3d(self.norm_chan)
            # return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
            # return nn.InstanceNorm3d(self.norm_chan)
        elif l == "A":
            return nn.LeakyReLU()
        else:
            raise ValueError("layer type {} not supported".format(l))

    def _setup_conv(self):
        self.idx_conv += 1

@@ -88,13 +89,22 @@ class ResBlock(ConvBlock):
    See `ConvBlock` for `seq` types.
    """

    def __init__(
        self,
        in_chan,
        out_chan=None,
        mid_chan=None,
        kernel_size=3,
        stride=1,
        seq="CBACBA",
        last_act=None,
    ):
        if last_act is None:
            last_act = seq[-1] == "A"
        elif last_act and seq[-1] != "A":
            warnings.warn(
                "Disabling last_act without trailing activation in seq",
                RuntimeWarning,
            )
            last_act = False

@@ -102,8 +112,14 @@ class ResBlock(ConvBlock):
        if last_act:
            seq = seq[:-1]

        super().__init__(
            in_chan,
            out_chan=out_chan,
            mid_chan=mid_chan,
            kernel_size=kernel_size,
            stride=stride,
            seq=seq,
        )

        if last_act:
            self.act = nn.LeakyReLU()

@@ -115,9 +131,10 @@ class ResBlock(ConvBlock):
        else:
            self.skip = nn.Conv3d(in_chan, out_chan, 1)

        if "U" in seq or "D" in seq:
            raise NotImplementedError(
                "upsample and downsample layers " "not supported yet"
            )

    def forward(self, x):
        y = x
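A hypothetical usage sketch of the seq grammar documented above (channel counts invented; assumes ConvBlock is imported from this module):

    block = ConvBlock(3, 64, seq="CBA")  # Conv3d -> BatchNorm3d -> LeakyReLU
    down = ConvBlock(64, 64, seq="DBA")  # kernel-2, stride-2 downsampling convolution instead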


@@ -2,7 +2,7 @@ import torch.nn as nn
class DiceLoss(nn.Module):
    def __init__(self, eps=0.0):
        super().__init__()
        self.eps = eps

@@ -10,7 +10,7 @@ class DiceLoss(nn.Module):
        return dice_loss(input, target, self.eps)


def dice_loss(input, target, eps=0.0):
    input = input.view(-1)
    target = target.view(-1)


@@ -1,10 +1,11 @@
import torch


class InstanceNoise:
    """Instance noise, with a linear decaying schedule"""

    def __init__(self, init_std, batches):
        assert init_std >= 0, "Noise std cannot be negative"
        self.init_std = init_std
        self._std = init_std
        self.batches = batches


@@ -13,9 +13,10 @@ def lag2eul(
    periodic=False,
    z=0.0,
    dis_std=6.0,
    boxsize=1000.0,
    meshsize=512,
    **kwargs,
):
    """Transform fields from Lagrangian description to Eulerian description

    Only works for 3d fields, output same mesh size as input.

@@ -48,20 +49,19 @@ def lag2eul(
    if isinstance(val, (float, torch.Tensor)):
        val = [val]
    if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
        raise ValueError("dis-val field mismatch")

    if any(d.dim() != 5 for d in dis):
        raise NotImplementedError("only support 3d fields for now")
    if any(d.shape[1] != 3 for d in dis):
        raise ValueError("only support 3d displacement fields")

    # common mean displacement of all inputs
    # if removed, fewer particles go outside of the box
    # common for all inputs so outputs are comparable in the same coords
    d_mean = 0
    if rm_dis_mean:
        d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) for d in dis) / len(dis)

    out = []
    if len(dis) == 1 and len(val) != 1:

@@ -85,12 +85,15 @@ def lag2eul(
            pos = (d - d_mean) * dis_norm
            del d

            pos[:, 0] += torch.arange(
                0, DHW[0] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
            )[:, None, None]
            pos[:, 1] += torch.arange(
                0, DHW[1] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
            )[:, None]
            pos[:, 2] += torch.arange(
                0, DHW[2] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
            )

            pos = pos.contiguous().view(N, 3, -1, 1)  # last axis for neighbors

@@ -118,8 +121,7 @@ def lag2eul(
            if periodic:
                torch.remainder(tgtpos[n], bounds, out=tgtpos[n])

            ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]) * DHW[2] + tgtpos[n, 2]
            src = v[n]

            if not periodic:


@@ -3,8 +3,7 @@ import torch.nn as nn
def narrow_by(a, c):
    """Narrow a by size c symmetrically on all edges."""
    ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
    return a[ind]
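For example, narrow_by(a, 1) applied to a tensor of shape (N, C, 8, 8, 8) returns the (N, C, 6, 6, 6) view, with one cell trimmed from every edge of each spatial dimension.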


@@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
        super().__init__()

        self.convs = nn.Sequential(
            ConvBlock(in_chan, 32, seq="CA"),
            ConvBlock(32, 64, seq="CBA"),
            ConvBlock(64, 128, seq="CBA"),
            ConvBlock(128, out_chan, seq="C"),
        )

    def forward(self, x):

@@ -19,23 +19,20 @@ class PatchGAN(nn.Module):
class PatchGAN42(nn.Module):
    """PatchGAN similar to the one in pix2pix"""

    def __init__(self, in_chan, out_chan=1, **kwargs):
        super().__init__()

        self.convs = nn.Sequential(
            nn.Conv3d(in_chan, 64, 4, stride=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(64, 128, 4, stride=2),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(128, 256, 4, stride=2),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(256, out_chan, 1),
        )


@@ -26,8 +26,7 @@ def power(x: torch.Tensor):
    P = P.sum(dim=0)
    del x

    k = [torch.arange(d, dtype=torch.float32, device=P.device) for d in P.shape]
    k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
    k = torch.meshgrid(*k)
    k = torch.stack(k, dim=0)

@@ -49,9 +48,9 @@ def power(x: torch.Tensor):
    del kbin

    # drop k=0 mode and cut at kmax (smallest Nyquist)
    k = k[1 : 1 + kmax]
    P = P[1 : 1 + kmax]
    N = N[1 : 1 + kmax]

    k /= N
    P /= N


@@ -5,12 +5,11 @@ from .narrow import narrow_by
def resample(x, scale_factor, narrow=True):
    modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
    ndim = x.dim() - 2
    mode = modes[ndim]

    x = F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=False)

    if scale_factor > 1 and narrow == True:
        edges = round(scale_factor) // 2

@@ -25,18 +24,20 @@ class Resampler(nn.Module):
    By default discard the inaccurate edges when upsampling.
    """

    def __init__(self, ndim, scale_factor, narrow=True):
        super().__init__()

        modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
        self.mode = modes[ndim]

        self.scale_factor = scale_factor

        self.narrow = narrow

    def forward(self, x):
        x = F.interpolate(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False
        )

        if self.scale_factor > 1 and self.narrow == True:
            edges = round(self.scale_factor) // 2
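A small usage sketch of the resample function above (shapes invented; assumes torch and resample are in scope):

    x = torch.randn(1, 1, 8, 8, 8)  # (batch, channel, D, H, W)
    y = resample(x, 2)  # trilinear upsampling; the inaccurate edges are narrowed away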


@@ -4,8 +4,18 @@ from torch.nn.utils import spectral_norm, remove_spectral_norm
def add_spectral_norm(module):
    for name, child in module.named_children():
        if isinstance(
            child,
            (
                nn.Linear,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose1d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            ),
        ):
            setattr(module, name, spectral_norm(child))
        else:
            add_spectral_norm(child)

@@ -13,8 +23,18 @@ def add_spectral_norm(module):
def rm_spectral_norm(module):
    for name, child in module.named_children():
        if isinstance(
            child,
            (
                nn.Linear,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose1d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            ),
        ):
            setattr(module, name, remove_spectral_norm(child))
        else:
            rm_spectral_norm(child)
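A minimal sketch of how the helper above is meant to be used (the module tree is invented; assumes torch.nn is imported as nn, as in the file above):

    net = nn.Sequential(nn.Conv3d(1, 8, 3), nn.LeakyReLU(), nn.Conv3d(8, 1, 3))
    add_spectral_norm(net)  # every Conv/Linear child is now wrapped by spectral_norm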


@@ -7,9 +7,17 @@ from .resample import Resampler
class G(nn.Module):
    def __init__(
        self,
        in_chan,
        out_chan,
        scale_factor=16,
        chan_base=512,
        chan_min=64,
        chan_max=512,
        cat_noise=False,
        **kwargs,
    ):
        super().__init__()

        self.scale_factor = scale_factor

@@ -30,15 +38,14 @@ class G(nn.Module):
        self.blocks = nn.ModuleList()
        for b in range(num_blocks):
            prev_chan, next_chan = chan(b), chan(b + 1)
            self.blocks.append(HBlock(prev_chan, next_chan, out_chan, cat_noise))

    def forward(self, x):
        y = x  # direct upsampling from the input
        x = self.block0(x)

        # y = None  # no direct upsampling from the input

        for block in self.blocks:
            x, y = block(x, y)

@@ -71,6 +78,7 @@ class HBlock(nn.Module):
    -----
    next_size = 2 * prev_size - 6
    """

    def __init__(self, prev_chan, next_chan, out_chan, cat_noise):
        super().__init__()

@@ -81,7 +89,6 @@ class HBlock(nn.Module):
            self.upsample,
            nn.Conv3d(prev_chan + int(cat_noise), next_chan, 3),
            nn.LeakyReLU(0.2, True),
            AddNoise(cat_noise, chan=next_chan),
            nn.Conv3d(next_chan + int(cat_noise), next_chan, 3),
            nn.LeakyReLU(0.2, True),

@@ -114,6 +121,7 @@ class AddNoise(nn.Module):
    The number of channels `chan` should be 1 (StyleGAN2)
    or that of the input (StyleGAN).
    """

    def __init__(self, cat, chan=1):
        super().__init__()

@@ -137,9 +145,16 @@ class AddNoise(nn.Module):
class D(nn.Module):
    def __init__(
        self,
        in_chan,
        out_chan,
        scale_factor=16,
        chan_base=512,
        chan_min=64,
        chan_max=512,
        **kwargs,
    ):
        super().__init__()

        self.scale_factor = scale_factor

@@ -163,7 +178,7 @@ class D(nn.Module):
        self.blocks = nn.ModuleList()
        for b in reversed(range(num_blocks)):
            prev_chan, next_chan = chan(b + 1), chan(b)
            self.blocks.append(ResBlock(prev_chan, next_chan))

        self.block9 = nn.Sequential(

@@ -192,13 +207,13 @@ class ResBlock(nn.Module):
    -----
    next_size = (prev_size - 4) // 2
    """

    def __init__(self, prev_chan, next_chan):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv3d(prev_chan, prev_chan, 3),
            nn.LeakyReLU(0.2, True),
            nn.Conv3d(prev_chan, next_chan, 3),
            nn.LeakyReLU(0.2, True),
        )
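As a worked example of the size notes in the docstrings above: an HBlock grows a crop of linear size 32 to 2 * 32 - 6 = 58, while a discriminator ResBlock shrinks 32 to (32 - 4) // 2 = 14.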


@ -9,6 +9,7 @@ class PixelNorm(nn.Module):
See ProGAN, StyleGAN. See ProGAN, StyleGAN.
""" """
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -23,6 +24,7 @@ class LinearElr(nn.Module):
Useful at all if not for regularization(1706.05350)? Useful at all if not for regularization(1706.05350)?
""" """
def __init__(self, in_size, out_size, bias=True, act=None): def __init__(self, in_size, out_size, bias=True, act=None):
super().__init__() super().__init__()
@ -32,7 +34,7 @@ class LinearElr(nn.Module):
if bias: if bias:
self.bias = nn.Parameter(torch.zeros(out_size)) self.bias = nn.Parameter(torch.zeros(out_size))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
self.act = act self.act = act
@ -52,20 +54,20 @@ class ConvElr3d(nn.Module):
Useful at all if not for regularization(1706.05350)? Useful at all if not for regularization(1706.05350)?
""" """
def __init__(self, in_chan, out_chan, kernel_size,
stride=1, padding=0, bias=True): def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, bias=True):
super().__init__() super().__init__()
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.randn(out_chan, in_chan, *(kernel_size,) * 3), torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
) )
fan_in = in_chan * kernel_size ** 3 fan_in = in_chan * kernel_size**3
self.wnorm = 1 / math.sqrt(fan_in) self.wnorm = 1 / math.sqrt(fan_in)
if bias: if bias:
self.bias = nn.Parameter(torch.zeros(out_chan)) self.bias = nn.Parameter(torch.zeros(out_chan))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
@ -87,38 +89,49 @@ class ConvStyled3d(nn.Module):
Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`. Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
""" """
def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
bias=True, resample=None): def __init__(
self,
style_size,
in_chan,
out_chan,
kernel_size=3,
stride=1,
bias=True,
resample=None,
):
super().__init__() super().__init__()
self.style_weight = nn.Parameter(torch.empty(in_chan, style_size)) self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5), nn.init.kaiming_uniform_(
mode='fan_in', nonlinearity='leaky_relu') self.style_weight, a=math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"
)
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1 self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
if resample is None: if resample is None:
K3 = (kernel_size,) * 3 K3 = (kernel_size,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3)) self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
self.stride = stride self.stride = stride
self.conv = F.conv3d self.conv = F.conv3d
elif resample == 'U': elif resample == "U":
K3 = (2,) * 3 K3 = (2,) * 3
            # NOTE not clear to me why convtranspose has channels swapped            # NOTE not clear to me why convtranspose has channels swapped
self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3)) self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
self.stride = 2 self.stride = 2
self.conv = F.conv_transpose3d self.conv = F.conv_transpose3d
elif resample == 'D': elif resample == "D":
K3 = (2,) * 3 K3 = (2,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3)) self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
self.stride = 2 self.stride = 2
self.conv = F.conv3d self.conv = F.conv3d
else: else:
raise ValueError('resample type {} not supported'.format(resample)) raise ValueError("resample type {} not supported".format(resample))
self.resample = resample self.resample = resample
nn.init.kaiming_uniform_( nn.init.kaiming_uniform_(
self.weight, a=math.sqrt(5), self.weight,
mode='fan_in', # effectively 'fan_out' for 'D' a=math.sqrt(5),
nonlinearity='leaky_relu', mode="fan_in", # effectively 'fan_out' for 'D'
nonlinearity="leaky_relu",
) )
if bias: if bias:
@ -127,12 +140,12 @@ class ConvStyled3d(nn.Module):
bound = 1 / math.sqrt(fan_in) bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound) nn.init.uniform_(self.bias, -bound, bound)
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
def forward(self, x, s, eps=1e-8): def forward(self, x, s, eps=1e-8):
N, Cin, *DHWin = x.shape N, Cin, *DHWin = x.shape
C0, C1, *K3 = self.weight.shape C0, C1, *K3 = self.weight.shape
if self.resample == 'U': if self.resample == "U":
Cin, Cout = C0, C1 Cin, Cout = C0, C1
else: else:
Cout, Cin = C0, C1 Cout, Cin = C0, C1
@ -140,14 +153,14 @@ class ConvStyled3d(nn.Module):
s = F.linear(s, self.style_weight, bias=self.style_bias) s = F.linear(s, self.style_weight, bias=self.style_bias)
# modulation # modulation
if self.resample == 'U': if self.resample == "U":
s = s.reshape(N, Cin, 1, 1, 1, 1) s = s.reshape(N, Cin, 1, 1, 1, 1)
else: else:
s = s.reshape(N, 1, Cin, 1, 1, 1) s = s.reshape(N, 1, Cin, 1, 1, 1)
w = self.weight * s w = self.weight * s
# demodulation # demodulation
if self.resample == 'U': if self.resample == "U":
fan_in_dim = (1, 3, 4, 5) fan_in_dim = (1, 3, 4, 5)
else: else:
fan_in_dim = (2, 3, 4, 5) fan_in_dim = (2, 3, 4, 5)
@ -161,18 +174,22 @@ class ConvStyled3d(nn.Module):
return x return x
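
The forward above does StyleGAN2-style weight modulation/demodulation: the style vector is projected to per-input-channel scales (the F.linear on style_weight/style_bias), multiplied into the convolution weight, and each output filter is then rescaled by the inverse root of its squared sum over the fan-in dimensions. A minimal sketch of just that arithmetic for a single sample in the non-resampling case; the batched grouped-convolution application is outside this hunk:

import torch

def modulate_demodulate(weight, s, eps=1e-8):
    """weight: (Cout, Cin, k, k, k); s: (Cin,) per-input-channel scales,
    i.e. what F.linear(style, style_weight, style_bias) produces per sample."""
    w = weight * s.reshape(1, -1, 1, 1, 1)                        # modulation
    d = torch.rsqrt(w.pow(2).sum(dim=(1, 2, 3, 4), keepdim=True) + eps)
    return w * d                                                  # demodulation

w = torch.randn(8, 4, 3, 3, 3)
s = torch.rand(4) + 0.5
wd = modulate_demodulate(w, s)
print(wd.pow(2).sum(dim=(1, 2, 3, 4)))  # ~1 for every output filter
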
class BatchNormStyled3d(nn.BatchNorm3d) :
""" Trivially does standard batch normalization, but accepts second argument class BatchNormStyled3d(nn.BatchNorm3d):
"""Trivially does standard batch normalization, but accepts second argument
for style array that is not used for style array that is not used
""" """
def forward(self, x, s): def forward(self, x, s):
return super().forward(x) return super().forward(x)
class LeakyReLUStyled(nn.LeakyReLU): class LeakyReLUStyled(nn.LeakyReLU):
""" Trivially evaluates standard leaky ReLU, but accepts second argument """Trivially evaluates standard leaky ReLU, but accepts second argument
    for style array that is not used    for style array that is not used
""" """
def forward(self, x, s): def forward(self, x, s):
return super().forward(x) return super().forward(x)
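
BatchNormStyled3d and LeakyReLUStyled exist only so that every layer in a styled block can be called with the same (x, s) signature, whether or not it uses the style. A small illustration of that uniform call pattern; ScaleByStyle is a hypothetical stand-in for a layer that actually consumes s:

import torch
from torch import nn

class LeakyReLUStyled(nn.LeakyReLU):
    """As above: standard leaky ReLU that simply ignores the style argument."""
    def forward(self, x, s):
        return super().forward(x)

class ScaleByStyle(nn.Module):
    """Hypothetical styled layer, standing in for ConvStyled3d, that does use s."""
    def forward(self, x, s):
        return x * s.mean(dim=1).reshape(-1, 1, 1, 1, 1)

x, s = torch.randn(2, 3, 8, 8, 8), torch.randn(2, 16)
for layer in [ScaleByStyle(), LeakyReLUStyled(0.2)]:
    x = layer(x, s)          # one call signature for styled and style-free layers alike
print(x.shape)               # torch.Size([2, 3, 8, 8, 8])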

View File

@ -16,8 +16,17 @@ class ConvStyledBlock(nn.Module):
'U': upsampling transposed convolution of kernel size 2 and stride 2 'U': upsampling transposed convolution of kernel size 2 and stride 2
'D': downsampling convolution of kernel size 2 and stride 2 'D': downsampling convolution of kernel size 2 and stride 2
""" """
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
kernel_size=3, stride=1, seq='CBA'): def __init__(
self,
style_size,
in_chan,
out_chan=None,
mid_chan=None,
kernel_size=3,
stride=1,
seq="CBA",
):
super().__init__() super().__init__()
if out_chan is None: if out_chan is None:
@ -33,31 +42,34 @@ class ConvStyledBlock(nn.Module):
self.norm_chan = in_chan self.norm_chan = in_chan
self.idx_conv = 0 self.idx_conv = 0
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']]) self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
layers = [self._get_layer(l) for l in seq] layers = [self._get_layer(l) for l in seq]
self.convs = nn.ModuleList(layers) self.convs = nn.ModuleList(layers)
def _get_layer(self, l): def _get_layer(self, l):
if l == 'U': if l == "U":
in_chan, out_chan = self._setup_conv() in_chan, out_chan = self._setup_conv()
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2, return ConvStyled3d(
resample = 'U') self.style_size, in_chan, out_chan, 2, stride=2, resample="U"
elif l == 'D': )
elif l == "D":
in_chan, out_chan = self._setup_conv() in_chan, out_chan = self._setup_conv()
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2, return ConvStyled3d(
resample = 'D') self.style_size, in_chan, out_chan, 2, stride=2, resample="D"
elif l == 'C': )
elif l == "C":
in_chan, out_chan = self._setup_conv() in_chan, out_chan = self._setup_conv()
return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size, return ConvStyled3d(
stride=self.stride) self.style_size, in_chan, out_chan, self.kernel_size, stride=self.stride
elif l == 'B': )
elif l == "B":
return BatchNormStyled3d(self.norm_chan) return BatchNormStyled3d(self.norm_chan)
elif l == 'A': elif l == "A":
return LeakyReLUStyled() return LeakyReLUStyled()
else: else:
raise ValueError('layer type {} not supported'.format(l)) raise ValueError("layer type {} not supported".format(l))
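
The seq string is a tiny layer DSL: 'C' convolution, 'U'/'D' up/down-sampling convolutions, 'B' batch norm, 'A' activation, with num_conv counting the conv-type letters so _setup_conv can assign channel widths by position. A rough sketch of the counting and dispatch, using placeholder names instead of the real ConvStyled3d constructors:

seq = "CBACBA"
num_conv = sum(seq.count(l) for l in "UDC")        # -> 2 convolution-type layers
print(num_conv)

names = {"C": "conv", "U": "up-conv", "D": "down-conv", "B": "batchnorm", "A": "leaky-relu"}
idx_conv, layers = 0, []
for l in seq:
    if l in "UDC":                                  # mirrors _setup_conv's counter
        idx_conv += 1
        layers.append("{}#{}".format(names[l], idx_conv))
    else:
        layers.append(names[l])
print(layers)   # ['conv#1', 'batchnorm', 'leaky-relu', 'conv#2', 'batchnorm', 'leaky-relu']
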
def _setup_conv(self): def _setup_conv(self):
self.idx_conv += 1 self.idx_conv += 1
@ -92,13 +104,23 @@ class ResStyledBlock(ConvStyledBlock):
See `ConvStyledBlock` for `seq` types. See `ConvStyledBlock` for `seq` types.
""" """
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
kernel_size=3, stride=1, seq='CBACBA', last_act=None): def __init__(
self,
style_size,
in_chan,
out_chan=None,
mid_chan=None,
kernel_size=3,
stride=1,
seq="CBACBA",
last_act=None,
):
if last_act is None: if last_act is None:
last_act = seq[-1] == 'A' last_act = seq[-1] == "A"
elif last_act and seq[-1] != 'A': elif last_act and seq[-1] != "A":
warnings.warn( warnings.warn(
'Disabling last_act without trailing activation in seq', "Disabling last_act without trailing activation in seq",
RuntimeWarning, RuntimeWarning,
) )
last_act = False last_act = False
@ -106,8 +128,15 @@ class ResStyledBlock(ConvStyledBlock):
if last_act: if last_act:
seq = seq[:-1] seq = seq[:-1]
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, super().__init__(
kernel_size=kernel_size, stride=stride, seq=seq) style_size,
in_chan,
out_chan=out_chan,
mid_chan=mid_chan,
kernel_size=kernel_size,
stride=stride,
seq=seq,
)
if last_act: if last_act:
self.act = LeakyReLUStyled() self.act = LeakyReLUStyled()
@ -119,9 +148,10 @@ class ResStyledBlock(ConvStyledBlock):
else: else:
self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1) self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
if 'U' in seq or 'D' in seq: if "U" in seq or "D" in seq:
raise NotImplementedError('upsample and downsample layers ' raise NotImplementedError(
'not supported yet') "upsample and downsample layers " "not supported yet"
)
def forward(self, x, s): def forward(self, x, s):
y = x y = x

View File

@ -15,17 +15,17 @@ class StyledVNet(nn.Module):
# activate non-identity skip connection in residual block # activate non-identity skip connection in residual block
# by explicitly setting out_chan # by explicitly setting out_chan
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA') self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq="CACBA")
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA') self.down_l0 = ConvStyledBlock(style_size, 64, seq="DBA")
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA') self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA') self.down_l1 = ConvStyledBlock(style_size, 64, seq="DBA")
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA') self.conv_c = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA') self.up_r1 = ConvStyledBlock(style_size, 64, seq="UBA")
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA') self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq="CBACBA")
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA') self.up_r0 = ConvStyledBlock(style_size, 64, seq="UBA")
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC') self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq="CAC")
if bypass is None: if bypass is None:
self.bypass = in_chan == out_chan self.bypass = in_chan == out_chan
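
Inferred from the channel counts above (conv_r1 and conv_r0 each take 128 = 64 skip + 64 upsampled channels), the two-level V-Net data flow would look roughly as follows. This is only a sketch: the actual forward, including its exact cropping and the bypass handling, is not part of this diff, and center_crop_to is a stand-in for the repository's narrowing of skip connections:

import torch

def center_crop_to(a, ref):
    """Center-crop a's spatial dims to ref's (skips shrink because the convs are unpadded)."""
    for d in range(2, a.dim()):
        extra = a.shape[d] - ref.shape[d]
        a = a.narrow(d, extra // 2, ref.shape[d])
    return a

def styled_vnet_forward_sketch(net, x, s):
    """Two-level V-Net data flow inferred from the constructor's channel counts."""
    y0 = net.conv_l0(x, s)                        # in_chan -> 64
    y1 = net.conv_l1(net.down_l0(y0, s), s)       # 64 -> 64, half resolution
    c = net.conv_c(net.down_l1(y1, s), s)         # bottleneck, quarter resolution
    u1 = net.up_r1(c, s)
    y = net.conv_r1(torch.cat([center_crop_to(y1, u1), u1], dim=1), s)   # 128 -> 64
    u0 = net.up_r0(y, s)
    y = net.conv_r0(torch.cat([center_crop_to(y0, u0), u0], dim=1), s)   # 128 -> out_chan
    return y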

View File

@ -19,17 +19,17 @@ class UNet(nn.Module):
""" """
super().__init__() super().__init__()
self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA') self.conv_l0 = ConvBlock(in_chan, 64, seq="CACBA")
self.down_l0 = ConvBlock(64, seq='DBA') self.down_l0 = ConvBlock(64, seq="DBA")
self.conv_l1 = ConvBlock(64, seq='CBACBA') self.conv_l1 = ConvBlock(64, seq="CBACBA")
self.down_l1 = ConvBlock(64, seq='DBA') self.down_l1 = ConvBlock(64, seq="DBA")
self.conv_c = ConvBlock(64, seq='CBACBA') self.conv_c = ConvBlock(64, seq="CBACBA")
self.up_r1 = ConvBlock(64, seq='UBA') self.up_r1 = ConvBlock(64, seq="UBA")
self.conv_r1 = ConvBlock(128, 64, seq='CBACBA') self.conv_r1 = ConvBlock(128, 64, seq="CBACBA")
self.up_r0 = ConvBlock(64, seq='UBA') self.up_r0 = ConvBlock(64, seq="UBA")
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC') self.conv_r0 = ConvBlock(128, out_chan, seq="CAC")
self.bypass = in_chan == out_chan self.bypass = in_chan == out_chan

View File

@ -23,17 +23,17 @@ class VNet(nn.Module):
# activate non-identity skip connection in residual block # activate non-identity skip connection in residual block
# by explicitly setting out_chan # by explicitly setting out_chan
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA') self.conv_l0 = ResBlock(in_chan, 64, seq="CACBA")
self.down_l0 = ConvBlock(64, seq='DBA') self.down_l0 = ConvBlock(64, seq="DBA")
self.conv_l1 = ResBlock(64, 64, seq='CBACBA') self.conv_l1 = ResBlock(64, 64, seq="CBACBA")
self.down_l1 = ConvBlock(64, seq='DBA') self.down_l1 = ConvBlock(64, seq="DBA")
self.conv_c = ResBlock(64, 64, seq='CBACBA') self.conv_c = ResBlock(64, 64, seq="CBACBA")
self.up_r1 = ConvBlock(64, seq='UBA') self.up_r1 = ConvBlock(64, seq="UBA")
self.conv_r1 = ResBlock(128, 64, seq='CBACBA') self.conv_r1 = ResBlock(128, 64, seq="CBACBA")
self.up_r0 = ConvBlock(64, seq='UBA') self.up_r0 = ConvBlock(64, seq="UBA")
self.conv_r0 = ResBlock(128, out_chan, seq='CAC') self.conv_r0 = ResBlock(128, out_chan, seq="CAC")
if bypass is None: if bypass is None:
self.bypass = in_chan == out_chan self.bypass = in_chan == out_chan

View File

@ -16,21 +16,21 @@ from .utils import import_attr, load_model_state_dict
def test(args): def test(args):
if torch.cuda.is_available(): if torch.cuda.is_available():
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
warnings.warn('Not parallelized but given more than 1 GPUs') warnings.warn("Not parallelized but given more than 1 GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device('cuda', 0) device = torch.device("cuda", 0)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
else: # CPU multithreading else: # CPU multithreading
device = torch.device('cpu') device = torch.device("cpu")
if args.num_threads is None: if args.num_threads is None:
args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE']) args.num_threads = int(os.environ["SLURM_CPUS_ON_NODE"])
torch.set_num_threads(args.num_threads) torch.set_num_threads(args.num_threads)
print('pytorch {}'.format(torch.__version__)) print("pytorch {}".format(torch.__version__))
pprint(vars(args)) pprint(vars(args))
sys.stdout.flush() sys.stdout.flush()
@ -67,26 +67,33 @@ def test(args):
out_chan = test_dataset.tgt_chan out_chan = test_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(style_size, sum(in_chan), sum(out_chan), model = model(
scale_factor=args.scale_factor, **args.misc_kwargs) style_size,
sum(in_chan),
sum(out_chan),
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
model.to(device) model.to(device)
criterion = import_attr(args.criterion, torch.nn, models, criterion = import_attr(
callback_at=args.callback_at) args.criterion, torch.nn, models, callback_at=args.callback_at
)
criterion = criterion() criterion = criterion()
criterion.to(device) criterion.to(device)
state = torch.load(args.load_state, map_location=device) state = torch.load(args.load_state, map_location=device)
load_model_state_dict(model, state['model'], strict=args.load_state_strict) load_model_state_dict(model, state["model"], strict=args.load_state_strict)
print('model state at epoch {} loaded from {}'.format( print(
state['epoch'], args.load_state)) "model state at epoch {} loaded from {}".format(state["epoch"], args.load_state)
)
del state del state
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
for i, data in enumerate(test_loader): for i, data in enumerate(test_loader):
style, input, target = data['style'], data['input'], data['target'] style, input, target = data["style"], data["input"], data["target"]
style = style.to(device, non_blocking=True) style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True) input = input.to(device, non_blocking=True)
@ -94,21 +101,21 @@ def test(args):
output = model(input, style) output = model(input, style)
if i < 5: if i < 5:
print('##### sample :', i) print("##### sample :", i)
print('style shape :', style.shape) print("style shape :", style.shape)
print('input shape :', input.shape) print("input shape :", input.shape)
print('output shape :', output.shape) print("output shape :", output.shape)
print('target shape :', target.shape) print("target shape :", target.shape)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
if i < 5: if i < 5:
print('narrowed shape :', output.shape, flush=True) print("narrowed shape :", output.shape, flush=True)
loss = criterion(output, target) loss = criterion(output, target)
print('sample {} loss: {}'.format(i, loss.item())) print("sample {} loss: {}".format(i, loss.item()))
#if args.in_norms is not None: # if args.in_norms is not None:
# start = 0 # start = 0
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): # for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
# norm = import_attr(norm, norms, callback_at=args.callback_at) # norm = import_attr(norm, norms, callback_at=args.callback_at)
@ -119,12 +126,11 @@ def test(args):
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at) norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True, **args.misc_kwargs) norm(output[:, start:stop], undo=True, **args.misc_kwargs)
#norm(target[:, start:stop], undo=True, **args.misc_kwargs) # norm(target[:, start:stop], undo=True, **args.misc_kwargs)
start = stop start = stop
#test_dataset.assemble('_in', in_chan, input, # test_dataset.assemble('_in', in_chan, input,
# data['input_relpath']) # data['input_relpath'])
test_dataset.assemble('_out', out_chan, output, test_dataset.assemble("_out", out_chan, output, data["target_relpath"])
data['target_relpath']) # test_dataset.assemble('_tgt', out_chan, target,
#test_dataset.assemble('_tgt', out_chan, target,
# data['target_relpath']) # data['target_relpath'])
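
The undo loop above relies on np.cumsum over the per-field channel counts so that each normalization function is reversed on its own channel slice. A tiny self-contained illustration, with made-up channel counts and a stand-in norm function:

import numpy as np
import torch

out_chan = [1, 3]                      # e.g. one scalar field and one 3-vector field
output = torch.randn(2, sum(out_chan), 8, 8, 8)

def demo_norm(x, undo=False):          # stand-in for an imported norms.* function
    x.mul_(2.0) if undo else x.div_(2.0)

start = 0
for norm, stop in zip([demo_norm, demo_norm], np.cumsum(out_chan)):
    norm(output[:, start:stop], undo=True)   # undo each field's normalization in place
    start = stop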

View File

@ -18,34 +18,34 @@ from .models import narrow_cast, resample
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
ckpt_link = 'checkpoint.pt' ckpt_link = "checkpoint.pt"
def node_worker(args): def node_worker(args):
if 'SLURM_STEP_NUM_NODES' in os.environ: if "SLURM_STEP_NUM_NODES" in os.environ:
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES']) args.nodes = int(os.environ["SLURM_STEP_NUM_NODES"])
elif 'SLURM_JOB_NUM_NODES' in os.environ: elif "SLURM_JOB_NUM_NODES" in os.environ:
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES']) args.nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
else: else:
raise KeyError('missing node counts in slurm env') raise KeyError("missing node counts in slurm env")
args.gpus_per_node = torch.cuda.device_count() args.gpus_per_node = torch.cuda.device_count()
args.world_size = args.nodes * args.gpus_per_node args.world_size = args.nodes * args.gpus_per_node
node = int(os.environ['SLURM_NODEID']) node = int(os.environ["SLURM_NODEID"])
if args.gpus_per_node < 1: if args.gpus_per_node < 1:
raise RuntimeError('GPU not found on node {}'.format(node)) raise RuntimeError("GPU not found on node {}".format(node))
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node) spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
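
node_worker derives the process layout from SLURM variables plus the local GPU count, and gpu_worker later computes rank = gpus_per_node * node + local_rank. The arithmetic with made-up numbers:

nodes, gpus_per_node = 4, 8
world_size = nodes * gpus_per_node                 # 32 processes in total
for node in range(nodes):
    for local_rank in range(gpus_per_node):
        rank = gpus_per_node * node + local_rank   # unique in [0, world_size)
assert world_size == 32 and rank == 31
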
def gpu_worker(local_rank, node, args): def gpu_worker(local_rank, node, args):
#device = torch.device('cuda', local_rank) # device = torch.device('cuda', local_rank)
#torch.cuda.device(device) # env var recommended over this # torch.cuda.device(device) # env var recommended over this
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank) os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
device = torch.device('cuda', 0) device = torch.device("cuda", 0)
rank = args.gpus_per_node * node + local_rank rank = args.gpus_per_node * node + local_rank
@ -53,7 +53,7 @@ def gpu_worker(local_rank, node, args):
# Note DDP broadcasts initial model states from rank 0 # Note DDP broadcasts initial model states from rank 0
torch.manual_seed(args.seed + rank) torch.manual_seed(args.seed + rank)
# good practice to disable cudnn.benchmark if enabling cudnn.deterministic # good practice to disable cudnn.benchmark if enabling cudnn.deterministic
#torch.backends.cudnn.deterministic = True # torch.backends.cudnn.deterministic = True
dist_init(rank, args) dist_init(rank, args)
@ -77,9 +77,12 @@ def gpu_worker(local_rank, node, args):
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
**args.misc_kwargs, **args.misc_kwargs,
) )
train_sampler = DistFieldSampler(train_dataset, shuffle=True, train_sampler = DistFieldSampler(
train_dataset,
shuffle=True,
div_data=args.div_data, div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist) div_shuffle_dist=args.div_shuffle_dist,
)
train_loader = DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
@ -110,9 +113,12 @@ def gpu_worker(local_rank, node, args):
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
**args.misc_kwargs, **args.misc_kwargs,
) )
val_sampler = DistFieldSampler(val_dataset, shuffle=False, val_sampler = DistFieldSampler(
val_dataset,
shuffle=False,
div_data=args.div_data, div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist) div_shuffle_dist=args.div_shuffle_dist,
)
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
@ -127,14 +133,19 @@ def gpu_worker(local_rank, node, args):
args.out_chan = train_dataset.tgt_chan args.out_chan = train_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan), model = model(
scale_factor=args.scale_factor, **args.misc_kwargs) args.style_size,
sum(args.in_chan),
sum(args.out_chan),
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
model.to(device) model.to(device)
model = DistributedDataParallel(model, device_ids=[device], model = DistributedDataParallel(
process_group=dist.new_group()) model, device_ids=[device], process_group=dist.new_group()
)
criterion = import_attr(args.criterion, nn, models, criterion = import_attr(args.criterion, nn, models, callback_at=args.callback_at)
callback_at=args.callback_at)
criterion = criterion() criterion = criterion()
criterion.to(device) criterion.to(device)
@ -144,11 +155,13 @@ def gpu_worker(local_rank, node, args):
lr=args.lr, lr=args.lr,
**args.optimizer_args, **args.optimizer_args,
) )
scheduler = optim.lr_scheduler.ReduceLROnPlateau( scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **args.scheduler_args)
optimizer, **args.scheduler_args)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link) if (
or not args.load_state): args.load_state == ckpt_link
and not os.path.isfile(ckpt_link)
or not args.load_state
):
if args.init_weight_std is not None: if args.init_weight_std is not None:
model.apply(init_weights) model.apply(init_weights)
@ -159,23 +172,28 @@ def gpu_worker(local_rank, node, args):
else: else:
state = torch.load(args.load_state, map_location=device) state = torch.load(args.load_state, map_location=device)
start_epoch = state['epoch'] start_epoch = state["epoch"]
load_model_state_dict(model.module, state['model'], load_model_state_dict(
strict=args.load_state_strict) model.module, state["model"], strict=args.load_state_strict
)
if 'optimizer' in state: if "optimizer" in state:
optimizer.load_state_dict(state['optimizer']) optimizer.load_state_dict(state["optimizer"])
if 'scheduler' in state: if "scheduler" in state:
scheduler.load_state_dict(state['scheduler']) scheduler.load_state_dict(state["scheduler"])
torch.set_rng_state(state['rng'].cpu()) # move rng state back torch.set_rng_state(state["rng"].cpu()) # move rng state back
if rank == 0: if rank == 0:
min_loss = state['min_loss'] min_loss = state["min_loss"]
print('state at epoch {} loaded from {}'.format( print(
state['epoch'], args.load_state), flush=True) "state at epoch {} loaded from {}".format(
state["epoch"], args.load_state
),
flush=True,
)
del state del state
@ -189,21 +207,31 @@ def gpu_worker(local_rank, node, args):
logger = SummaryWriter() logger = SummaryWriter()
if rank == 0: if rank == 0:
print('pytorch {}'.format(torch.__version__)) print("pytorch {}".format(torch.__version__))
pprint(vars(args)) pprint(vars(args))
sys.stdout.flush() sys.stdout.flush()
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_loss = train(epoch, train_loader, model, criterion, train_loss = train(
optimizer, scheduler, logger, device, args) epoch,
train_loader,
model,
criterion,
optimizer,
scheduler,
logger,
device,
args,
)
epoch_loss = train_loss epoch_loss = train_loss
if args.val: if args.val:
val_loss = validate(epoch, val_loader, model, criterion, val_loss = validate(
logger, device, args) epoch, val_loader, model, criterion, logger, device, args
#epoch_loss = val_loss )
# epoch_loss = val_loss
if args.reduce_lr_on_plateau: if args.reduce_lr_on_plateau:
scheduler.step(epoch_loss[2]) scheduler.step(epoch_loss[2])
@ -215,27 +243,26 @@ def gpu_worker(local_rank, node, args):
min_loss = epoch_loss[2] min_loss = epoch_loss[2]
state = { state = {
'epoch': epoch + 1, "epoch": epoch + 1,
'model': model.module.state_dict(), "model": model.module.state_dict(),
'optimizer': optimizer.state_dict(), "optimizer": optimizer.state_dict(),
'scheduler': scheduler.state_dict(), "scheduler": scheduler.state_dict(),
'rng': torch.get_rng_state(), "rng": torch.get_rng_state(),
'min_loss': min_loss, "min_loss": min_loss,
} }
state_file = 'state_{}.pt'.format(epoch + 1) state_file = "state_{}.pt".format(epoch + 1)
torch.save(state, state_file) torch.save(state, state_file)
del state del state
tmp_link = '{}.pt'.format(time.time()) tmp_link = "{}.pt".format(time.time())
os.symlink(state_file, tmp_link) # workaround to overwrite os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link) os.rename(tmp_link, ckpt_link)
dist.destroy_process_group() dist.destroy_process_group()
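
The symlink-then-rename at the end of each epoch works around the fact that os.symlink refuses to overwrite an existing link, while os.rename replaces its target atomically, so checkpoint.pt always points at a complete state file. The same pattern in isolation (file names here are placeholders):

import os, time

state_file = "state_42.pt"           # placeholder for the freshly saved state
ckpt_link = "checkpoint.pt"
open(state_file, "w").close()        # stand-in for torch.save(state, state_file)

tmp_link = "{}.pt".format(time.time())
os.symlink(state_file, tmp_link)     # os.symlink cannot overwrite an existing link...
os.rename(tmp_link, ckpt_link)       # ...but rename replaces checkpoint.pt atomically
print(os.readlink(ckpt_link))
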
def train(epoch, loader, model, criterion, def train(epoch, loader, model, criterion, optimizer, scheduler, logger, device, args):
optimizer, scheduler, logger, device, args):
model.train() model.train()
rank = dist.get_rank() rank = dist.get_rank()
@ -246,7 +273,7 @@ def train(epoch, loader, model, criterion,
for i, data in enumerate(loader): for i, data in enumerate(loader):
batch = epoch * len(loader) + i + 1 batch = epoch * len(loader) + i + 1
style, input, target = data['style'], data['input'], data['target'] style, input, target = data["style"], data["input"], data["target"]
style = style.to(device, non_blocking=True) style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True) input = input.to(device, non_blocking=True)
@ -254,23 +281,21 @@ def train(epoch, loader, model, criterion,
output = model(input, style) output = model(input, style)
if batch <= 5 and rank == 0: if batch <= 5 and rank == 0:
print('##### batch :', batch) print("##### batch :", batch)
print('style shape :', style.shape) print("style shape :", style.shape)
print('input shape :', input.shape) print("input shape :", input.shape)
print('output shape :', output.shape) print("output shape :", output.shape)
print('target shape :', target.shape) print("target shape :", target.shape)
if (hasattr(model.module, 'scale_factor') if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False) input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
if batch <= 5 and rank == 0: if batch <= 5 and rank == 0:
print('narrowed shape :', output.shape) print("narrowed shape :", output.shape)
loss = criterion(output, target) loss = criterion(output, target)
epoch_loss[0] += loss.detach() epoch_loss[0] += loss.detach()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -280,43 +305,45 @@ def train(epoch, loader, model, criterion,
dist.all_reduce(loss) dist.all_reduce(loss)
loss /= world_size loss /= world_size
if rank == 0: if rank == 0:
logger.add_scalar('loss/batch/train', loss.item(), logger.add_scalar("loss/batch/train", loss.item(), global_step=batch)
global_step=batch)
logger.add_scalar('grad/first', grads[0], global_step=batch) logger.add_scalar("grad/first", grads[0], global_step=batch)
logger.add_scalar('grad/last', grads[-1], global_step=batch) logger.add_scalar("grad/last", grads[-1], global_step=batch)
dist.all_reduce(epoch_loss) dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size epoch_loss /= len(loader) * world_size
if rank == 0: if rank == 0:
logger.add_scalar('loss/epoch/train', epoch_loss[0], logger.add_scalar("loss/epoch/train", epoch_loss[0], global_step=epoch + 1)
global_step=epoch+1)
skip_chan = 0 skip_chan = 0
fig = plt_slices( fig = plt_slices(
input[-1], output[-1, skip_chan:], target[-1, skip_chan:], input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:], output[-1, skip_chan:] - target[-1, skip_chan:],
title=['in', 'out', 'tgt', 'out - tgt'], title=["in", "out", "tgt", "out - tgt"],
**args.misc_kwargs, **args.misc_kwargs,
) )
logger.add_figure('fig/train', fig, global_step=epoch+1) logger.add_figure("fig/train", fig, global_step=epoch + 1)
fig.clf() fig.clf()
fig = plt_power( fig = plt_power(
input, output[:, skip_chan:], target[:, skip_chan:], input,
label=['in', 'out', 'tgt'], output[:, skip_chan:],
target[:, skip_chan:],
label=["in", "out", "tgt"],
**args.misc_kwargs, **args.misc_kwargs,
) )
logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) logger.add_figure("fig/train/power/lag", fig, global_step=epoch + 1)
fig.clf() fig.clf()
#fig = plt_power(1.0, # fig = plt_power(1.0,
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]], # dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
# label=['in', 'out', 'tgt'], # label=['in', 'out', 'tgt'],
# **args.misc_kwargs, # **args.misc_kwargs,
#) # )
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) # logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf() # fig.clf()
return epoch_loss return epoch_loss
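
Both the per-batch and the per-epoch losses are summed across ranks with dist.all_reduce and then divided, so every rank logs the same mean. The reduction in isolation, using a single-process gloo group only to make the call executable (address and port are arbitrary):

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
loss = torch.tensor(0.75)
dist.all_reduce(loss)                 # sums the per-rank losses (a no-op with one rank)
loss /= dist.get_world_size()         # divide by world size to get the mean
print(loss.item())
dist.destroy_process_group()
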
@ -331,7 +358,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
with torch.no_grad(): with torch.no_grad():
for data in loader: for data in loader:
style, input, target = data['style'], data['input'], data['target'] style, input, target = data["style"], data["input"], data["target"]
style = style.to(device, non_blocking=True) style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True) input = input.to(device, non_blocking=True)
@ -339,8 +366,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
output = model(input, style) output = model(input, style)
if (hasattr(model.module, 'scale_factor') if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False) input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
@ -350,40 +376,43 @@ def validate(epoch, loader, model, criterion, logger, device, args):
dist.all_reduce(epoch_loss) dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size epoch_loss /= len(loader) * world_size
if rank == 0: if rank == 0:
logger.add_scalar('loss/epoch/val', epoch_loss[0], logger.add_scalar("loss/epoch/val", epoch_loss[0], global_step=epoch + 1)
global_step=epoch+1)
skip_chan = 0 skip_chan = 0
fig = plt_slices( fig = plt_slices(
input[-1], output[-1, skip_chan:], target[-1, skip_chan:], input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:], output[-1, skip_chan:] - target[-1, skip_chan:],
title=['in', 'out', 'tgt', 'out - tgt'], title=["in", "out", "tgt", "out - tgt"],
**args.misc_kwargs, **args.misc_kwargs,
) )
logger.add_figure('fig/val', fig, global_step=epoch+1) logger.add_figure("fig/val", fig, global_step=epoch + 1)
fig.clf() fig.clf()
fig = plt_power( fig = plt_power(
input, output[:, skip_chan:], target[:, skip_chan:], input,
label=['in', 'out', 'tgt'], output[:, skip_chan:],
target[:, skip_chan:],
label=["in", "out", "tgt"],
**args.misc_kwargs, **args.misc_kwargs,
) )
logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) logger.add_figure("fig/val/power/lag", fig, global_step=epoch + 1)
fig.clf() fig.clf()
#fig = plt_power(1.0, # fig = plt_power(1.0,
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]], # dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
# label=['in', 'out', 'tgt'], # label=['in', 'out', 'tgt'],
# **args.misc_kwargs, # **args.misc_kwargs,
#) # )
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) # logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf() # fig.clf()
return epoch_loss return epoch_loss
def dist_init(rank, args): def dist_init(rank, args):
dist_file = 'dist_addr' dist_file = "dist_addr"
if rank == 0: if rank == 0:
addr = socket.gethostname() addr = socket.gethostname()
@ -393,15 +422,15 @@ def dist_init(rank, args):
s.bind((addr, 0)) s.bind((addr, 0))
_, port = s.getsockname() _, port = s.getsockname()
args.dist_addr = 'tcp://{}:{}'.format(addr, port) args.dist_addr = "tcp://{}:{}".format(addr, port)
with open(dist_file, mode='w') as f: with open(dist_file, mode="w") as f:
f.write(args.dist_addr) f.write(args.dist_addr)
else: else:
while not os.path.exists(dist_file): while not os.path.exists(dist_file):
time.sleep(1) time.sleep(1)
with open(dist_file, mode='r') as f: with open(dist_file, mode="r") as f:
args.dist_addr = f.read() args.dist_addr = f.read()
dist.init_process_group( dist.init_process_group(
@ -417,19 +446,40 @@ def dist_init(rank, args):
def init_weights(m, args): def init_weights(m, args):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, if isinstance(
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): m,
(
nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
),
):
m.weight.data.normal_(0.0, args.init_weight_std) m.weight.data.normal_(0.0, args.init_weight_std)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, elif isinstance(
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm, m,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)): (
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.SyncBatchNorm,
nn.LayerNorm,
nn.GroupNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
),
):
if m.affine: if m.affine:
# NOTE: dispersion from DCGAN, why? # NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_std) m.weight.data.normal_(1.0, args.init_weight_std)
m.bias.data.fill_(0) m.bias.data.fill_(0)
def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False): def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False):
for param in module.parameters(): for param in module.parameters():
param.requires_grad = requires_grad param.requires_grad = requires_grad
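
set_requires_grad is the usual helper for freezing a sub-network, e.g. holding one network fixed while another is optimized. A minimal check (the Linear module is just a placeholder):

import torch
from torch import nn

def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False):
    for param in module.parameters():
        param.requires_grad = requires_grad

frozen = nn.Linear(4, 4)
set_requires_grad(frozen, False)
print(all(not p.requires_grad for p in frozen.parameters()))   # True
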
@ -444,8 +494,7 @@ def get_grads(model: torch.nn.Module):
Returns: Returns:
A list containing the norms of the gradients of the first and the last layer weights. A list containing the norms of the gradients of the first and the last layer weights.
""" """
grads = list(p.grad for n, p in model.named_parameters() grads = list(p.grad for n, p in model.named_parameters() if ".weight" in n)
if '.weight' in n)
grads = [grads[0], grads[-1]] grads = [grads[0], grads[-1]]
grads = [g.detach().norm() for g in grads] grads = [g.detach().norm() for g in grads]
return grads return grads
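
get_grads is read right after loss.backward() so the logger can track the gradient norms of the first and last weight tensors. A minimal run on a toy model (the model and data are placeholders, not anything from this repository):

import torch
from torch import nn

def get_grads(model: torch.nn.Module):
    """Norms of the gradients of the first and last '.weight' parameters, as above."""
    grads = [p.grad for n, p in model.named_parameters() if ".weight" in n]
    grads = [grads[0], grads[-1]]
    return [g.detach().norm() for g in grads]

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()
print([g.item() for g in get_grads(model)])   # first- and last-layer gradient norms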

View File

@ -2,11 +2,13 @@ from math import log2, log10, ceil
import torch import torch
import numpy as np import numpy as np
import matplotlib import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, SymLogNorm from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable from matplotlib.cm import ScalarMappable
plt.rc('text', usetex=False)
plt.rc("text", usetex=False)
from ..models import lag2eul, power from ..models import lag2eul, power
@ -21,7 +23,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
Each field should have a channel dimension followed by spatial dimensions, Each field should have a channel dimension followed by spatial dimensions,
i.e. no batch dimension. i.e. no batch dimension.
""" """
plt.close('all') plt.close("all")
assert all(isinstance(field, torch.Tensor) for field in fields) assert all(isinstance(field, torch.Tensor) for field in fields)
@ -38,11 +40,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
im_size = 2 im_size = 2
cbar_height = 0.2 cbar_height = 0.2
fig, axes = plt.subplots( fig, axes = plt.subplots(
nc + 1, nf, nc + 1,
nf,
squeeze=False, squeeze=False,
figsize=(nf * im_size, nc * im_size + cbar_height), figsize=(nf * im_size, nc * im_size + cbar_height),
dpi=100, dpi=100,
gridspec_kw={'height_ratios': nc * [im_size] + [cbar_height]} gridspec_kw={"height_ratios": nc * [im_size] + [cbar_height]},
) )
for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)): for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)):
@ -51,11 +54,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
if cmap_col is None: if cmap_col is None:
if all_non_neg: if all_non_neg:
cmap_col = 'inferno' cmap_col = "inferno"
elif all_non_pos: elif all_non_pos:
cmap_col = 'inferno_r' cmap_col = "inferno_r"
else: else:
cmap_col = 'RdBu_r' cmap_col = "RdBu_r"
if norm_col is None: if norm_col is None:
l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5]) l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
@ -70,9 +73,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
if l1 < 0.1 * l2 or h2 == 0: if l1 < 0.1 * l2 or h2 == 0:
norm_col = Normalize(vmin=-quantize(-l2), vmax=0) norm_col = Normalize(vmin=-quantize(-l2), vmax=0)
else: else:
norm_col = SymLogNorm(linthresh=quantize(-h2), norm_col = SymLogNorm(
linthresh=quantize(-h2),
vmin=-quantize(-l2), vmin=-quantize(-l2),
vmax=-quantize(-h2)) vmax=-quantize(-h2),
)
else: else:
vlim = quantize(max(-l2, h2)) vlim = quantize(max(-l2, h2))
if w1 > 0.1 * w2 or l1 * h1 >= 0: if w1 > 0.1 * w2 or l1 * h1 >= 0:
@ -80,8 +85,13 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
else: else:
linthresh = quantize(min(-l1, h1)) linthresh = quantize(min(-l1, h1))
linscale = np.log10(vlim / linthresh) linscale = np.log10(vlim / linthresh)
norm_col = SymLogNorm(linthresh=linthresh, linscale=linscale, norm_col = SymLogNorm(
vmin=-vlim, vmax=vlim, base=10) linthresh=linthresh,
linscale=linscale,
vmin=-vlim,
vmax=vlim,
base=10,
)
for c in range(field.shape[0]): for c in range(field.shape[0]):
s = (c,) + tuple(d // 2 for d in field.shape[1:-2]) s = (c,) + tuple(d // 2 for d in field.shape[1:-2])
@ -101,7 +111,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col) axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col)
axes[c, f].set_aspect('equal') axes[c, f].set_aspect("equal")
axes[c, f].set_xticks([]) axes[c, f].set_xticks([])
axes[c, f].set_yticks([]) axes[c, f].set_yticks([])
@ -110,12 +120,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
axes[c, f].set_title(title[f]) axes[c, f].set_title(title[f])
for c in range(field.shape[0], nc): for c in range(field.shape[0], nc):
axes[c, f].axis('off') axes[c, f].axis("off")
fig.colorbar( fig.colorbar(
ScalarMappable(norm=norm_col, cmap=cmap_col), ScalarMappable(norm=norm_col, cmap=cmap_col),
cax=axes[-1, f], cax=axes[-1, f],
orientation='horizontal', orientation="horizontal",
) )
fig.tight_layout() fig.tight_layout()
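
A simplified version of the percentile-driven color scaling above, dropping the quantize() rounding and the extra special cases: signed fields get a SymLogNorm centered on zero with the linear region set by the inner percentiles, while (mostly) non-negative fields fall back to a plain Normalize:

import numpy as np
from matplotlib.colors import Normalize, SymLogNorm

field = np.random.randn(64, 64) * 3          # stand-in for one channel slice
l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])

if l2 >= 0:                                  # (mostly) non-negative field: linear scale
    norm = Normalize(vmin=0, vmax=h2)
else:                                        # signed field: symmetric log around zero
    vlim = max(-l2, h2)
    linthresh = min(-l1, h1)
    norm = SymLogNorm(linthresh=linthresh, linscale=np.log10(vlim / linthresh),
                      vmin=-vlim, vmax=vlim, base=10)
print(type(norm).__name__)
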
@ -133,7 +143,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
See `map2map.models.power`. See `map2map.models.power`.
""" """
plt.close('all') plt.close("all")
if label is not None: if label is not None:
assert len(label) == len(fields) or len(label) == len(dis) assert len(label) == len(fields) or len(label) == len(dis)
@ -159,8 +169,8 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
axes.loglog(k, P, label=l, alpha=0.7) axes.loglog(k, P, label=l, alpha=0.7)
axes.legend() axes.legend()
axes.set_xlabel('unnormalized wavenumber') axes.set_xlabel("unnormalized wavenumber")
axes.set_ylabel('unnormalized power') axes.set_ylabel("unnormalized power")
fig.tight_layout() fig.tight_layout()

View File

@ -21,7 +21,7 @@ def import_attr(name, *pkgs, callback_at=None):
first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally
from 'path/to/cb_dir/mod.py'. from 'path/to/cb_dir/mod.py'.
""" """
if name.count('.') == 0: if name.count(".") == 0:
attr = name attr = name
errors = [] errors = []
@ -34,23 +34,22 @@ def import_attr(name, *pkgs, callback_at=None):
raise Exception(errors) raise Exception(errors)
else: else:
mod, attr = name.rsplit('.', 1) mod, attr = name.rsplit(".", 1)
errors = [] errors = []
for pkg in pkgs: for pkg in pkgs:
try: try:
return getattr( return getattr(importlib.import_module(pkg.__name__ + "." + mod), attr)
importlib.import_module(pkg.__name__ + '.' + mod), attr)
except (ModuleNotFoundError, AttributeError) as e: except (ModuleNotFoundError, AttributeError) as e:
errors.append(e) errors.append(e)
if callback_at is None: if callback_at is None:
raise Exception(errors) raise Exception(errors)
callback_at = os.path.join(callback_at, mod + '.py') callback_at = os.path.join(callback_at, mod + ".py")
if not os.path.isfile(callback_at): if not os.path.isfile(callback_at):
raise FileNotFoundError('callback file not found') raise FileNotFoundError("callback file not found")
if mod in sys.modules: if mod in sys.modules:
return getattr(sys.modules[mod], attr) return getattr(sys.modules[mod], attr)
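
For dotted names, import_attr simply walks importlib.import_module over each candidate package and falls back to a callback file. The core pattern in isolation (the callback-file fallback is omitted, and the example name resolves torch.nn.Conv3d purely for illustration):

import importlib

def import_dotted(name, *pkgs):
    """Resolve 'mod.attr' against several parent packages, mirroring import_attr's
    dotted-name branch."""
    mod, attr = name.rsplit(".", 1)
    errors = []
    for pkg in pkgs:
        try:
            return getattr(importlib.import_module(pkg.__name__ + "." + mod), attr)
        except (ModuleNotFoundError, AttributeError) as e:
            errors.append(e)
    raise Exception(errors)

import torch
print(import_dotted("nn.Conv3d", torch))   # resolves torch.nn.Conv3d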

View File

@ -7,9 +7,13 @@ def load_model_state_dict(module, state_dict, strict=True):
bad_keys = module.load_state_dict(state_dict, strict) bad_keys = module.load_state_dict(state_dict, strict)
if len(bad_keys.missing_keys) > 0: if len(bad_keys.missing_keys) > 0:
warnings.warn('Missing keys in state_dict:\n{}'.format( warnings.warn(
pformat(bad_keys.missing_keys))) "Missing keys in state_dict:\n{}".format(pformat(bad_keys.missing_keys))
)
if len(bad_keys.unexpected_keys) > 0: if len(bad_keys.unexpected_keys) > 0:
warnings.warn('Unexpected keys in state_dict:\n{}'.format( warnings.warn(
pformat(bad_keys.unexpected_keys))) "Unexpected keys in state_dict:\n{}".format(
pformat(bad_keys.unexpected_keys)
)
)
sys.stderr.flush() sys.stderr.flush()
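
With strict=False the mismatched keys only produce warnings rather than an exception, which is what makes the load_state_strict option useful for partially compatible checkpoints. A small demonstration on two deliberately different toy modules (names are illustrative):

import warnings
from pprint import pformat
import torch
from torch import nn

def load_model_state_dict(module, state_dict, strict=True):
    """Same behaviour as above: warn about missing/unexpected keys instead of failing."""
    bad_keys = module.load_state_dict(state_dict, strict)
    if len(bad_keys.missing_keys) > 0:
        warnings.warn("Missing keys in state_dict:\n{}".format(pformat(bad_keys.missing_keys)))
    if len(bad_keys.unexpected_keys) > 0:
        warnings.warn("Unexpected keys in state_dict:\n{}".format(pformat(bad_keys.unexpected_keys)))

src = nn.Sequential(nn.Linear(4, 4))
dst = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
load_model_state_dict(dst, src.state_dict(), strict=False)  # warns about missing '1.*' keys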