chore: Blackify the source code for pretty formatting

This commit is contained in:
Guilhem Lavaux 2024-05-20 17:50:32 +02:00
parent ec9732df21
commit 7361c3eb9d
27 changed files with 908 additions and 568 deletions

View file

@ -7,18 +7,16 @@ from .train import ckpt_link
def get_args():
"""Parse arguments and set runtime defaults.
"""
parser = argparse.ArgumentParser(
description='Transform field(s) to field(s)')
"""Parse arguments and set runtime defaults."""
parser = argparse.ArgumentParser(description="Transform field(s) to field(s)")
subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
subparsers = parser.add_subparsers(title="modes", dest="mode", required=True)
train_parser = subparsers.add_parser(
'train',
"train",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
test_parser = subparsers.add_parser(
'test',
"test",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
@ -27,163 +25,293 @@ def get_args():
args = parser.parse_args()
if args.mode == 'train':
if args.mode == "train":
set_train_args(args)
elif args.mode == 'test':
elif args.mode == "test":
set_test_args(args)
return args
def add_common_args(parser):
parser.add_argument('--in-norms', type=str_list, help='comma-sep. list '
'of input normalization functions')
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
'of target normalization functions')
parser.add_argument('--crop', type=int_tuple,
help='size to crop the input and target data. Default is the '
'field size. Comma-sep. list of 1 or d integers')
parser.add_argument('--crop-start', type=int_tuple,
help='starting point of the first crop. Default is the origin. '
'Comma-sep. list of 1 or d integers')
parser.add_argument('--crop-stop', type=int_tuple,
help='stopping point of the last crop. Default is the opposite '
'corner to the origin. Comma-sep. list of 1 or d integers')
parser.add_argument('--crop-step', type=int_tuple,
help='spacing between crops. Default is the crop size. '
'Comma-sep. list of 1 or d integers')
parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition. Comma-sep. list of 1, d, or dx2 '
'integers, to pad equally along all axes, symmetrically on each, '
'or by the specified size on every boundary, respectively')
parser.add_argument('--tgt-pad', default=0, type=int_tuple,
help='size to pad the target data beyond the crop size, assuming '
'periodic boundary condition, useful for super-resolution. '
'Comma-sep. list with the same format as --in-pad')
parser.add_argument('--scale-factor', default=1, type=int,
help='upsampling factor for super-resolution, in which case '
'crop and pad are sizes of the input resolution')
parser.add_argument(
"--in-norms",
type=str_list,
help="comma-sep. list " "of input normalization functions",
)
parser.add_argument(
"--tgt-norms",
type=str_list,
help="comma-sep. list " "of target normalization functions",
)
parser.add_argument(
"--crop",
type=int_tuple,
help="size to crop the input and target data. Default is the "
"field size. Comma-sep. list of 1 or d integers",
)
parser.add_argument(
"--crop-start",
type=int_tuple,
help="starting point of the first crop. Default is the origin. "
"Comma-sep. list of 1 or d integers",
)
parser.add_argument(
"--crop-stop",
type=int_tuple,
help="stopping point of the last crop. Default is the opposite "
"corner to the origin. Comma-sep. list of 1 or d integers",
)
parser.add_argument(
"--crop-step",
type=int_tuple,
help="spacing between crops. Default is the crop size. "
"Comma-sep. list of 1 or d integers",
)
parser.add_argument(
"--in-pad",
"--pad",
default=0,
type=int_tuple,
help="size to pad the input data beyond the crop size, assuming "
"periodic boundary condition. Comma-sep. list of 1, d, or dx2 "
"integers, to pad equally along all axes, symmetrically on each, "
"or by the specified size on every boundary, respectively",
)
parser.add_argument(
"--tgt-pad",
default=0,
type=int_tuple,
help="size to pad the target data beyond the crop size, assuming "
"periodic boundary condition, useful for super-resolution. "
"Comma-sep. list with the same format as --in-pad",
)
parser.add_argument(
"--scale-factor",
default=1,
type=int,
help="upsampling factor for super-resolution, in which case "
"crop and pad are sizes of the input resolution",
)
parser.add_argument('--model', type=str, required=True,
help='(generator) model')
parser.add_argument('--criterion', default='MSELoss', type=str,
help='loss function')
parser.add_argument('--load-state', default=ckpt_link, type=str,
help='path to load the states of model, optimizer, rng, etc. '
'Default is the checkpoint. '
'Start from scratch in case of empty string or missing checkpoint')
parser.add_argument('--load-state-non-strict', action='store_false',
help='allow incompatible keys when loading model states',
dest='load_state_strict')
parser.add_argument("--model", type=str, required=True, help="(generator) model")
parser.add_argument(
"--criterion", default="MSELoss", type=str, help="loss function"
)
parser.add_argument(
"--load-state",
default=ckpt_link,
type=str,
help="path to load the states of model, optimizer, rng, etc. "
"Default is the checkpoint. "
"Start from scratch in case of empty string or missing checkpoint",
)
parser.add_argument(
"--load-state-non-strict",
action="store_false",
help="allow incompatible keys when loading model states",
dest="load_state_strict",
)
# somehow I named it "batches" instead of batch_size at first
# "batches" is kept for now for backward compatibility
parser.add_argument('--batch-size', '--batches', type=int, required=True,
help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', default=8, type=int,
help='number of subprocesses per data loader. '
'0 to disable multiprocessing')
parser.add_argument(
"--batch-size",
"--batches",
type=int,
required=True,
help="mini-batch size, per GPU in training or in total in testing",
)
parser.add_argument(
"--loader-workers",
default=8,
type=int,
help="number of subprocesses per data loader. " "0 to disable multiprocessing",
)
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
help='directory of custorm code defining callbacks for models, '
'norms, criteria, and optimizers. Disabled if not set. '
'This is appended to the default locations, '
'thus has the lowest priority')
parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
help='miscellaneous keyword arguments for custom models and '
'norms. Be careful with name collisions')
parser.add_argument(
"--callback-at",
type=lambda s: os.path.abspath(s),
help="directory of custorm code defining callbacks for models, "
"norms, criteria, and optimizers. Disabled if not set. "
"This is appended to the default locations, "
"thus has the lowest priority",
)
parser.add_argument(
"--misc-kwargs",
default="{}",
type=json.loads,
help="miscellaneous keyword arguments for custom models and "
"norms. Be careful with name collisions",
)
def add_train_args(parser):
add_common_args(parser)
parser.add_argument('--train-style-pattern', type=str, required=True,
help='glob pattern for training data styles')
parser.add_argument('--train-in-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for training input data')
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for training target data')
parser.add_argument('--val-style-pattern', type=str,
help='glob pattern for validation data styles')
parser.add_argument('--val-in-patterns', type=str_list,
help='comma-sep. list of glob patterns for validation input data')
parser.add_argument('--val-tgt-patterns', type=str_list,
help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true',
help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-shift', type=int_tuple,
help='data augmentation by shifting cropping by [0, aug_shift) pixels, '
'useful for models that treat neighboring pixels differently, '
'e.g. with strided convolutions. '
'Comma-sep. list of 1 or d integers')
parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, '
'same factor for all fields')
parser.add_argument('--aug-mul', type=float,
help='multiplicative data augmentation, (log-normal) std, '
'same factor for all fields')
parser.add_argument(
"--train-style-pattern",
type=str,
required=True,
help="glob pattern for training data styles",
)
parser.add_argument(
"--train-in-patterns",
type=str_list,
required=True,
help="comma-sep. list of glob patterns for training input data",
)
parser.add_argument(
"--train-tgt-patterns",
type=str_list,
required=True,
help="comma-sep. list of glob patterns for training target data",
)
parser.add_argument(
"--val-style-pattern", type=str, help="glob pattern for validation data styles"
)
parser.add_argument(
"--val-in-patterns",
type=str_list,
help="comma-sep. list of glob patterns for validation input data",
)
parser.add_argument(
"--val-tgt-patterns",
type=str_list,
help="comma-sep. list of glob patterns for validation target data",
)
parser.add_argument(
"--augment",
action="store_true",
help="enable data augmentation of axis flipping and permutation",
)
parser.add_argument(
"--aug-shift",
type=int_tuple,
help="data augmentation by shifting cropping by [0, aug_shift) pixels, "
"useful for models that treat neighboring pixels differently, "
"e.g. with strided convolutions. "
"Comma-sep. list of 1 or d integers",
)
parser.add_argument(
"--aug-add",
type=float,
help="additive data augmentation, (normal) std, " "same factor for all fields",
)
parser.add_argument(
"--aug-mul",
type=float,
help="multiplicative data augmentation, (log-normal) std, "
"same factor for all fields",
)
parser.add_argument('--optimizer', default='Adam', type=str,
help='optimization algorithm')
parser.add_argument('--lr', type=float, required=True,
help='initial learning rate')
parser.add_argument('--optimizer-args', default='{}', type=json.loads,
help='optimizer arguments in addition to the learning rate, '
'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
help='Enable ReduceLROnPlateau learning rate scheduler')
parser.add_argument('--scheduler-args', default='{"verbose": true}',
type=json.loads,
help='arguments for the ReduceLROnPlateau scheduler')
parser.add_argument('--init-weight-std', type=float,
help='weight initialization std')
parser.add_argument('--epochs', default=128, type=int,
help='total number of epochs to run')
parser.add_argument('--seed', default=42, type=int,
help='seed for initializing training')
parser.add_argument(
"--optimizer", default="Adam", type=str, help="optimization algorithm"
)
parser.add_argument("--lr", type=float, required=True, help="initial learning rate")
parser.add_argument(
"--optimizer-args",
default="{}",
type=json.loads,
help="optimizer arguments in addition to the learning rate, "
"e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'",
)
parser.add_argument(
"--reduce-lr-on-plateau",
action="store_true",
help="Enable ReduceLROnPlateau learning rate scheduler",
)
parser.add_argument(
"--scheduler-args",
default='{"verbose": true}',
type=json.loads,
help="arguments for the ReduceLROnPlateau scheduler",
)
parser.add_argument(
"--init-weight-std", type=float, help="weight initialization std"
)
parser.add_argument(
"--epochs", default=128, type=int, help="total number of epochs to run"
)
parser.add_argument(
"--seed", default=42, type=int, help="seed for initializing training"
)
parser.add_argument('--div-data', action='store_true',
help='enable data division among GPUs for better page caching. '
'Data division is shuffled every epoch. '
'Only relevant if there are multiple crops in each field')
parser.add_argument('--div-shuffle-dist', default=1, type=float,
help='distance to further shuffle cropped samples relative to '
'their fields, to be used with --div-data. '
'Only relevant if there are multiple crops in each file. '
'The order of each sample is randomly displaced by this value. '
'Setting it to 0 turn off this randomization, and setting it to N '
'limits the shuffling within a distance of N files. '
'Change this to balance cache locality and stochasticity')
parser.add_argument('--dist-backend', default='nccl', type=str,
choices=['gloo', 'nccl'], help='distributed backend')
parser.add_argument('--log-interval', default=100, type=int,
help='interval (batches) between logging training loss')
parser.add_argument('--detect-anomaly', action='store_true',
help='enable anomaly detection for the autograd engine')
parser.add_argument(
"--div-data",
action="store_true",
help="enable data division among GPUs for better page caching. "
"Data division is shuffled every epoch. "
"Only relevant if there are multiple crops in each field",
)
parser.add_argument(
"--div-shuffle-dist",
default=1,
type=float,
help="distance to further shuffle cropped samples relative to "
"their fields, to be used with --div-data. "
"Only relevant if there are multiple crops in each file. "
"The order of each sample is randomly displaced by this value. "
"Setting it to 0 turn off this randomization, and setting it to N "
"limits the shuffling within a distance of N files. "
"Change this to balance cache locality and stochasticity",
)
parser.add_argument(
"--dist-backend",
default="nccl",
type=str,
choices=["gloo", "nccl"],
help="distributed backend",
)
parser.add_argument(
"--log-interval",
default=100,
type=int,
help="interval (batches) between logging training loss",
)
parser.add_argument(
"--detect-anomaly",
action="store_true",
help="enable anomaly detection for the autograd engine",
)
def add_test_args(parser):
add_common_args(parser)
parser.add_argument('--test-style-pattern', type=str, required=True,
help='glob pattern for test data styles')
parser.add_argument('--test-in-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for test input data')
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
help='comma-sep. list of glob patterns for test target data')
parser.add_argument(
"--test-style-pattern",
type=str,
required=True,
help="glob pattern for test data styles",
)
parser.add_argument(
"--test-in-patterns",
type=str_list,
required=True,
help="comma-sep. list of glob patterns for test input data",
)
parser.add_argument(
"--test-tgt-patterns",
type=str_list,
required=True,
help="comma-sep. list of glob patterns for test target data",
)
parser.add_argument('--num-threads', type=int,
help='number of CPU threads when cuda is unavailable. '
'Default is the number of CPUs on the node by slurm')
parser.add_argument(
"--num-threads",
type=int,
help="number of CPU threads when cuda is unavailable. "
"Default is the number of CPUs on the node by slurm",
)
def str_list(s):
return s.split(',')
return s.split(",")
def int_tuple(s):
t = s.split(',')
t = s.split(",")
t = tuple(int(i) for i in t)
if len(t) == 1:
return t[0]
@ -198,8 +326,7 @@ def set_common_args(args):
def set_train_args(args):
set_common_args(args)
args.val = args.val_in_patterns is not None and \
args.val_tgt_patterns is not None
args.val = args.val_in_patterns is not None and args.val_tgt_patterns is not None
def set_test_args(args):

View file

@ -43,54 +43,72 @@ class FieldDataset(Dataset):
the input for super-resolution, in which case `crop` and `pad` are sizes of
the input resolution.
"""
def __init__(self, style_pattern, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
in_pad=0, tgt_pad=0, scale_factor=1,
**kwargs):
def __init__(
self,
style_pattern,
in_patterns,
tgt_patterns,
in_norms=None,
tgt_norms=None,
callback_at=None,
augment=False,
aug_shift=None,
aug_add=None,
aug_mul=None,
crop=None,
crop_start=None,
crop_stop=None,
crop_step=None,
in_pad=0,
tgt_pad=0,
scale_factor=1,
**kwargs,
):
self.style_files = sorted(glob(style_pattern))
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
self.in_files = list(zip(*in_file_lists))
tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns]
self.tgt_files = list(zip(* tgt_file_lists))
self.tgt_files = list(zip(*tgt_file_lists))
# self.style_files = self.style_files[:1]
# self.in_files = self.in_files[:1]
# self.tgt_files = self.tgt_files[:1]
if len(self.style_files) != len(self.in_files) != len(self.tgt_files):
raise ValueError('number of style, input, and target files do not match')
raise ValueError("number of style, input, and target files do not match")
self.nfile = len(self.in_files)
if self.nfile == 0:
raise FileNotFoundError('file not found for {}'.format(in_patterns))
raise FileNotFoundError("file not found for {}".format(in_patterns))
self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0]
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.in_files[0]]
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.tgt_files[0]]
self.in_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.in_files[0]]
self.tgt_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.tgt_files[0]]
self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
self.size = np.load(self.in_files[0][0], mmap_mode="r").shape[1:]
self.size = np.asarray(self.size)
self.ndim = len(self.size)
if in_norms is not None and len(in_patterns) != len(in_norms):
raise ValueError('numbers of input normalization functions and fields do not match')
raise ValueError(
"numbers of input normalization functions and fields do not match"
)
self.in_norms = in_norms
if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms):
raise ValueError('numbers of target normalization functions and fields do not match')
raise ValueError(
"numbers of target normalization functions and fields do not match"
)
self.tgt_norms = tgt_norms
self.callback_at = callback_at
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
raise ValueError("cannot augment 1D fields")
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
self.aug_add = aug_add
self.aug_mul = aug_mul
@ -116,10 +134,15 @@ class FieldDataset(Dataset):
crop_step = np.broadcast_to(crop_step, (self.ndim,))
self.crop_step = crop_step
self.anchors = np.stack(np.mgrid[tuple(
slice(crop_start[d], crop_stop[d], crop_step[d])
for d in range(self.ndim)
)], axis=-1).reshape(-1, self.ndim)
self.anchors = np.stack(
np.mgrid[
tuple(
slice(crop_start[d], crop_stop[d], crop_step[d])
for d in range(self.ndim)
)
],
axis=-1,
).reshape(-1, self.ndim)
self.ncrop = len(self.anchors)
def format_pad(pad, ndim):
@ -130,15 +153,16 @@ class FieldDataset(Dataset):
elif isinstance(pad, tuple) and len(pad) == ndim * 2:
pad = np.array(pad)
else:
raise ValueError('pad and ndim mismatch')
raise ValueError("pad and ndim mismatch")
return pad.reshape(ndim, 2)
self.in_pad = format_pad(in_pad, self.ndim)
self.tgt_pad = format_pad(tgt_pad, self.ndim)
if scale_factor != 1:
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
tgt_size = np.load(self.tgt_files[0][0], mmap_mode="r").shape[1:]
if any(self.size * scale_factor != tgt_size):
raise ValueError('input size x scale factor != target size')
raise ValueError("input size x scale factor != target size")
self.scale_factor = scale_factor
self.nsample = self.nfile * self.ncrop
@ -148,9 +172,7 @@ class FieldDataset(Dataset):
self.assembly_line = {}
self.commonpath = os.path.commonpath(
file
for files in self.in_files[:2] + self.tgt_files[:2]
for file in files
file for files in self.in_files[:2] + self.tgt_files[:2] for file in files
)
def __len__(self):
@ -180,12 +202,18 @@ class FieldDataset(Dataset):
else:
argsort_perm_axes = slice(None)
crop(in_fields, anchor,
self.crop[argsort_perm_axes],
self.in_pad[argsort_perm_axes])
crop(tgt_fields, anchor * self.scale_factor,
self.crop[argsort_perm_axes] * self.scale_factor,
self.tgt_pad[argsort_perm_axes])
crop(
in_fields,
anchor,
self.crop[argsort_perm_axes],
self.in_pad[argsort_perm_axes],
)
crop(
tgt_fields,
anchor * self.scale_factor,
self.crop[argsort_perm_axes] * self.scale_factor,
self.tgt_pad[argsort_perm_axes],
)
style = torch.from_numpy(style).to(torch.float32)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@ -221,17 +249,19 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
#in_relpath = [os.path.relpath(file, start=self.commonpath)
# in_relpath = [os.path.relpath(file, start=self.commonpath)
# for file in self.in_files[ifile]]
tgt_relpath = [os.path.relpath(file, start=self.commonpath)
for file in self.tgt_files[ifile]]
tgt_relpath = [
os.path.relpath(file, start=self.commonpath)
for file in self.tgt_files[ifile]
]
return {
'style': style,
'input': in_fields,
'target': tgt_fields,
"style": style,
"input": in_fields,
"target": tgt_fields,
#'input_relpath': in_relpath,
'target_relpath': tgt_relpath,
"target_relpath": tgt_relpath,
}
def assemble(self, label, chan, patches, paths):
@ -255,29 +285,29 @@ class FieldDataset(Dataset):
if isinstance(patches, torch.Tensor):
patches = patches.detach().cpu().numpy()
assert patches.ndim == 2 + self.ndim, 'ndim mismatch'
assert patches.ndim == 2 + self.ndim, "ndim mismatch"
if any(self.crop_step > patches.shape[2:]):
raise RuntimeError('patch too small to tile')
raise RuntimeError("patch too small to tile")
# the batched paths are a list of lists with shape (channel, batch)
# since pytorch default_collate batches list of strings transposedly
# therefore we transpose below back to (batch, channel)
assert patches.shape[1] == sum(chan), 'number of channels mismatch'
assert len(paths) == len(chan), 'number of fields mismatch'
paths = list(zip(* paths))
assert patches.shape[0] == len(paths), 'batch size mismatch'
assert patches.shape[1] == sum(chan), "number of channels mismatch"
assert len(paths) == len(chan), "number of fields mismatch"
paths = list(zip(*paths))
assert patches.shape[0] == len(paths), "batch size mismatch"
patches = list(patches)
if label in self.assembly_line:
self.assembly_line[label] += patches
self.assembly_line[label + 'path'] += paths
self.assembly_line[label + "path"] += paths
else:
self.assembly_line[label] = patches
self.assembly_line[label + 'path'] = paths
self.assembly_line[label + "path"] = paths
del patches, paths
patches = self.assembly_line[label]
paths = self.assembly_line[label + 'path']
paths = self.assembly_line[label + "path"]
# NOTE anchor positioning assumes sufficient target padding and
# symmetric narrowing (more on the right if odd) see `models/narrow.py`
@ -285,31 +315,26 @@ class FieldDataset(Dataset):
anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2
while len(patches) >= self.ncrop:
fields = np.zeros(patches[0].shape[:1] + tuple(self.size),
patches[0].dtype)
fields = np.zeros(patches[0].shape[:1] + tuple(self.size), patches[0].dtype)
for patch, anchor in zip(patches, anchors):
fill(fields, patch, anchor)
for field, path in zip(
np.split(fields, np.cumsum(chan), axis=0),
paths[0]):
pathlib.Path(os.path.dirname(path)).mkdir(parents=True,
exist_ok=True)
for field, path in zip(np.split(fields, np.cumsum(chan), axis=0), paths[0]):
pathlib.Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
path = label.join(os.path.splitext(path))
np.save(path, field)
del patches[:self.ncrop], paths[:self.ncrop]
del patches[: self.ncrop], paths[: self.ncrop]
def fill(field, patch, anchor):
ndim = len(anchor)
assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch'
assert field.ndim == patch.ndim == 1 + ndim, "ndim mismatch"
ind = [slice(None)]
for d, (p, a, s) in enumerate(zip(
patch.shape[1:], anchor, field.shape[1:])):
for d, (p, a, s) in enumerate(zip(patch.shape[1:], anchor, field.shape[1:])):
i = np.arange(a, a + p)
i %= s
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
@ -320,10 +345,10 @@ def fill(field, patch, anchor):
def crop(fields, anchor, crop, pad):
assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch'
assert all(x.shape == fields[0].shape for x in fields), "shape mismatch"
size = fields[0].shape[1:]
ndim = len(size)
assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch'
assert ndim == len(anchor) == len(crop) == len(pad), "ndim mismatch"
ind = [slice(None)]
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
@ -342,7 +367,7 @@ def crop(fields, anchor, crop, pad):
def flip(fields, axes, ndim):
assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors'
assert ndim > 1, "flipping is ambiguous for 1D scalars/vectors"
if axes is None:
axes = torch.randint(2, (ndim,), dtype=torch.bool)
@ -350,7 +375,7 @@ def flip(fields, axes, ndim):
for i, x in enumerate(fields):
if x.shape[0] == ndim: # flip vector components
x[axes] = - x[axes]
x[axes] = -x[axes]
shifted_axes = (1 + axes).tolist()
x = torch.flip(x, shifted_axes)
@ -361,7 +386,7 @@ def flip(fields, axes, ndim):
def perm(fields, axes, ndim):
assert ndim > 1, 'permutation is not necessary for 1D fields'
assert ndim > 1, "permutation is not necessary for 1D fields"
if axes is None:
axes = torch.randperm(ndim)

View file

@ -10,6 +10,7 @@ def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
x *= dis_norm
def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z) # [km/s]
@ -20,25 +21,28 @@ def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
def D(z, Om=0.31):
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
"""
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero"""
OL = 1 - Om
a = 1 / (1+z)
return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \
/ hyp2f1(1, 1/3, 11/6, - OL / Om)
a = 1 / (1 + z)
return (
a
* hyp2f1(1, 1 / 3, 11 / 6, -OL * a**3 / Om)
/ hyp2f1(1, 1 / 3, 11 / 6, -OL / Om)
)
def f(z, Om=0.31):
"""linear growth rate for flat LambdaCDM
"""
"""linear growth rate for flat LambdaCDM"""
OL = 1 - Om
a = 1 / (1+z)
a = 1 / (1 + z)
aa3 = OL * a**3 / Om
return 1 - 6/11*aa3 * hyp2f1(2, 4/3, 17/6, -aa3) \
/ hyp2f1(1, 1/3, 11/6, -aa3)
return 1 - 6 / 11 * aa3 * hyp2f1(2, 4 / 3, 17 / 6, -aa3) / hyp2f1(
1, 1 / 3, 11 / 6, -aa3
)
def H(z, Om=0.31):
"""Hubble in [h km/s/Mpc] for flat LambdaCDM
"""
"""Hubble in [h km/s/Mpc] for flat LambdaCDM"""
OL = 1 - Om
a = 1 / (1+z)
a = 1 / (1 + z)
return 100 * np.sqrt(Om / a**3 + OL)

View file

@ -7,18 +7,21 @@ def exp(x, undo=False, **kwargs):
else:
torch.log(x, out=x)
def log(x, eps=1e-8, undo=False, **kwargs):
if not undo:
torch.log(x + eps, out=x)
else:
torch.exp(x, out=x)
def expm1(x, undo=False, **kwargs):
if not undo:
torch.expm1(x, out=x)
else:
torch.log1p(x, out=x)
def log1p(x, eps=1e-7, undo=False, **kwargs):
if not undo:
torch.log1p(x + eps, out=x)

View file

@ -22,8 +22,8 @@ class DistFieldSampler(Sampler):
Like `DistributedSampler`, `set_epoch()` should be called at the beginning
of each epoch during training.
"""
def __init__(self, dataset, shuffle,
div_data=False, div_shuffle_dist=0):
def __init__(self, dataset, shuffle, div_data=False, div_shuffle_dist=0):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@ -50,8 +50,10 @@ class DistFieldSampler(Sampler):
ind = ind.flatten()
# displace crops with respect to files
dis = torch.rand((self.nfile, self.ncrop),
generator=g) * self.div_shuffle_dist
dis = (
torch.rand((self.nfile, self.ncrop), generator=g)
* self.div_shuffle_dist
)
loc = torch.arange(self.nfile)
loc = loc[:, None] + dis
loc = loc.flatten() % self.nfile # periodic in files

View file

@ -3,6 +3,7 @@ from . import test
import click
import os
import yaml
try:
from yaml import CLoader as Loader
except ImportError:
@ -12,20 +13,24 @@ import importlib.resources
import json
from functools import partial
def _load_resource_file(resource_path):
# Import the package
pkg_files = importlib.resources.files() / resource_path
with pkg_files.open() as file:
return file.read() # Read the file and return its content
return file.read() # Read the file and return its content
def _str_list(value):
return value.split(',')
return value.split(",")
def _int_tuple(value):
t = value.split(',')
t = value.split(",")
t = tuple(int(i) for i in t)
return t
class VariadicType(click.ParamType):
"""
A custom parameter type for Click command-line interface.
@ -54,14 +59,14 @@ class VariadicType(click.ParamType):
if typename in self._mapper:
self._type = self._mapper[typename]
elif type(typename) == dict:
self._type = self._mapper[typename["type"]]
self._type = self._mapper[typename["type"]]
self.args = typename["opts"]
else:
raise ValueError(f"Unknown type: {typename}")
raise ValueError(f"Unknown type: {typename}")
self._typename = typename
self.name = self._type["type"]
if "func" not in self._type:
self._type["func"] = eval(self._type['type'])
self._type["func"] = eval(self._type["type"])
def convert(self, value, param, ctx):
try:
@ -69,24 +74,26 @@ class VariadicType(click.ParamType):
except Exception as e:
self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
def _apply_options(options_file, f):
common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
common_args = common_args['arguments']
common_args = common_args["arguments"]
for arg in common_args:
argopt = common_args[arg]
if 'type' in argopt:
if type(argopt['type']) == dict and argopt['type']['type'] == 'choice':
argopt['type'] = click.Choice(argopt['type']['opts'])
if "type" in argopt:
if type(argopt["type"]) == dict and argopt["type"]["type"] == "choice":
argopt["type"] = click.Choice(argopt["type"]["opts"])
else:
argopt['type'] = VariadicType(argopt['type'])
f = click.option(f'--{arg}', **argopt)(f)
argopt["type"] = VariadicType(argopt["type"])
f = click.option(f"--{arg}", **argopt)(f)
else:
f = click.option(f'--{arg}', **argopt)(f)
f = click.option(f"--{arg}", **argopt)(f)
return f
m2m_options=partial(_apply_options,"common_args.yaml")
m2m_options = partial(_apply_options, "common_args.yaml")
@click.group()
@ -94,23 +101,26 @@ m2m_options=partial(_apply_options,"common_args.yaml")
@click.pass_context
def main(ctx, config):
if config is not None and os.path.exists(config):
with open(config, 'r') as f:
with open(config, "r") as f:
config = yaml.load(f.read(), Loader=Loader)
ctx.default_map = config
# Make a class that provides access to dict with the attribute mechanism
class DictProxy:
def __init__(self, d):
self.__dict__ = d
@main.command()
@m2m_options
@partial(_apply_options, "train_args.yaml")
def train(**kwargs):
train.node_worker(DictProxy(kwargs))
@main.command()
@m2m_options
@partial(_apply_options, "test_args.yaml")
def test(**kwargs):
test.test(DictProxy(kwargs))
test.test(DictProxy(kwargs))

View file

@ -5,6 +5,7 @@ def adv_model_wrapper(module):
"""Wrap an adversary model to also take lists of Tensors as input,
to be concatenated along the batch dimension
"""
class _new_module(module):
def forward(self, x):
if not isinstance(x, torch.Tensor):
@ -22,6 +23,7 @@ def adv_criterion_wrapper(module):
* expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors
"""
class _new_module(module):
def forward(self, input, target):
assert isinstance(input, torch.Tensor)
@ -35,8 +37,9 @@ def adv_criterion_wrapper(module):
target = [t.expand_as(i) for i, t in zip(input, target)]
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
loss = [
super(new_module, self).forward(i, t) for i, t in zip(input, target)
]
return loss

View file

@ -15,8 +15,10 @@ class ConvBlock(nn.Module):
'U': upsampling transposed convolution of kernel size 2 and stride 2
'D': downsampling convolution of kernel size 2 and stride 2
"""
def __init__(self, in_chan, out_chan=None, mid_chan=None,
kernel_size=3, stride=1, seq='CBA'):
def __init__(
self, in_chan, out_chan=None, mid_chan=None, kernel_size=3, stride=1, seq="CBA"
):
super().__init__()
if out_chan is None:
@ -31,31 +33,30 @@ class ConvBlock(nn.Module):
self.norm_chan = in_chan
self.idx_conv = 0
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
layers = [self._get_layer(l) for l in seq]
self.convs = nn.Sequential(*layers)
def _get_layer(self, l):
if l == 'U':
if l == "U":
in_chan, out_chan = self._setup_conv()
return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2)
elif l == 'D':
elif l == "D":
in_chan, out_chan = self._setup_conv()
return nn.Conv3d(in_chan, out_chan, 2, stride=2)
elif l == 'C':
elif l == "C":
in_chan, out_chan = self._setup_conv()
return nn.Conv3d(in_chan, out_chan, self.kernel_size,
stride=self.stride)
elif l == 'B':
return nn.Conv3d(in_chan, out_chan, self.kernel_size, stride=self.stride)
elif l == "B":
return nn.BatchNorm3d(self.norm_chan)
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
#return nn.InstanceNorm3d(self.norm_chan)
elif l == 'A':
# return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
# return nn.InstanceNorm3d(self.norm_chan)
elif l == "A":
return nn.LeakyReLU()
else:
raise ValueError('layer type {} not supported'.format(l))
raise ValueError("layer type {} not supported".format(l))
def _setup_conv(self):
self.idx_conv += 1
@ -88,13 +89,22 @@ class ResBlock(ConvBlock):
See `ConvBlock` for `seq` types.
"""
def __init__(self, in_chan, out_chan=None, mid_chan=None,
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
def __init__(
self,
in_chan,
out_chan=None,
mid_chan=None,
kernel_size=3,
stride=1,
seq="CBACBA",
last_act=None,
):
if last_act is None:
last_act = seq[-1] == 'A'
elif last_act and seq[-1] != 'A':
last_act = seq[-1] == "A"
elif last_act and seq[-1] != "A":
warnings.warn(
'Disabling last_act without trailing activation in seq',
"Disabling last_act without trailing activation in seq",
RuntimeWarning,
)
last_act = False
@ -102,8 +112,14 @@ class ResBlock(ConvBlock):
if last_act:
seq = seq[:-1]
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan,
kernel_size=kernel_size, stride=stride, seq=seq)
super().__init__(
in_chan,
out_chan=out_chan,
mid_chan=mid_chan,
kernel_size=kernel_size,
stride=stride,
seq=seq,
)
if last_act:
self.act = nn.LeakyReLU()
@ -115,9 +131,10 @@ class ResBlock(ConvBlock):
else:
self.skip = nn.Conv3d(in_chan, out_chan, 1)
if 'U' in seq or 'D' in seq:
raise NotImplementedError('upsample and downsample layers '
'not supported yet')
if "U" in seq or "D" in seq:
raise NotImplementedError(
"upsample and downsample layers " "not supported yet"
)
def forward(self, x):
y = x

View file

@ -2,7 +2,7 @@ import torch.nn as nn
class DiceLoss(nn.Module):
def __init__(self, eps=0.):
def __init__(self, eps=0.0):
super().__init__()
self.eps = eps
@ -10,7 +10,7 @@ class DiceLoss(nn.Module):
return dice_loss(input, target, self.eps)
def dice_loss(input, target, eps=0.):
def dice_loss(input, target, eps=0.0):
input = input.view(-1)
target = target.view(-1)

View file

@ -1,10 +1,11 @@
import torch
class InstanceNoise:
"""Instance noise, with a linear decaying schedule
"""
"""Instance noise, with a linear decaying schedule"""
def __init__(self, init_std, batches):
assert init_std >= 0, 'Noise std cannot be negative'
assert init_std >= 0, "Noise std cannot be negative"
self.init_std = init_std
self._std = init_std
self.batches = batches

View file

@ -5,17 +5,18 @@ from ..data.norms.cosmology import D
def lag2eul(
dis,
val=1.0,
eul_scale_factor=1,
eul_pad=0,
rm_dis_mean=True,
periodic=False,
z=0.0,
dis_std=6.0,
boxsize=1000.,
meshsize=512,
**kwargs):
dis,
val=1.0,
eul_scale_factor=1,
eul_pad=0,
rm_dis_mean=True,
periodic=False,
z=0.0,
dis_std=6.0,
boxsize=1000.0,
meshsize=512,
**kwargs,
):
"""Transform fields from Lagrangian description to Eulerian description
Only works for 3d fields, output same mesh size as input.
@ -48,20 +49,19 @@ def lag2eul(
if isinstance(val, (float, torch.Tensor)):
val = [val]
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
raise ValueError('dis-val field mismatch')
raise ValueError("dis-val field mismatch")
if any(d.dim() != 5 for d in dis):
raise NotImplementedError('only support 3d fields for now')
raise NotImplementedError("only support 3d fields for now")
if any(d.shape[1] != 3 for d in dis):
raise ValueError('only support 3d displacement fields')
raise ValueError("only support 3d displacement fields")
# common mean displacement of all inputs
# if removed, fewer particles go outside of the box
# common for all inputs so outputs are comparable in the same coords
d_mean = 0
if rm_dis_mean:
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
for d in dis) / len(dis)
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) for d in dis) / len(dis)
out = []
if len(dis) == 1 and len(val) != 1:
@ -85,12 +85,15 @@ def lag2eul(
pos = (d - d_mean) * dis_norm
del d
pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)[:, None, None]
pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)[:, None]
pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)
pos[:, 0] += torch.arange(
0, DHW[0] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
)[:, None, None]
pos[:, 1] += torch.arange(
0, DHW[1] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
)[:, None]
pos[:, 2] += torch.arange(
0, DHW[2] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device
)
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
@ -118,8 +121,7 @@ def lag2eul(
if periodic:
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
) * DHW[2] + tgtpos[n, 2]
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]) * DHW[2] + tgtpos[n, 2]
src = v[n]
if not periodic:

View file

@ -3,8 +3,7 @@ import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
"""Narrow a by size c symmetrically on all edges."""
ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
return a[ind]

View file

@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
super().__init__()
self.convs = nn.Sequential(
ConvBlock(in_chan, 32, seq='CA'),
ConvBlock(32, 64, seq='CBA'),
ConvBlock(64, 128, seq='CBA'),
ConvBlock(128, out_chan, seq='C'),
ConvBlock(in_chan, 32, seq="CA"),
ConvBlock(32, 64, seq="CBA"),
ConvBlock(64, 128, seq="CBA"),
ConvBlock(128, out_chan, seq="C"),
)
def forward(self, x):
@ -19,23 +19,20 @@ class PatchGAN(nn.Module):
class PatchGAN42(nn.Module):
"""PatchGAN similar to the one in pix2pix
"""
"""PatchGAN similar to the one in pix2pix"""
def __init__(self, in_chan, out_chan=1, **kwargs):
super().__init__()
self.convs = nn.Sequential(
nn.Conv3d(in_chan, 64, 4, stride=2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(64, 128, 4, stride=2),
nn.BatchNorm3d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(128, 256, 4, stride=2),
nn.BatchNorm3d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(256, out_chan, 1),
)

View file

@ -26,8 +26,7 @@ def power(x: torch.Tensor):
P = P.sum(dim=0)
del x
k = [torch.arange(d, dtype=torch.float32, device=P.device)
for d in P.shape]
k = [torch.arange(d, dtype=torch.float32, device=P.device) for d in P.shape]
k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]]
k = torch.meshgrid(*k)
k = torch.stack(k, dim=0)
@ -49,9 +48,9 @@ def power(x: torch.Tensor):
del kbin
# drop k=0 mode and cut at kmax (smallest Nyquist)
k = k[1:1+kmax]
P = P[1:1+kmax]
N = N[1:1+kmax]
k = k[1 : 1 + kmax]
P = P[1 : 1 + kmax]
N = N[1 : 1 + kmax]
k /= N
P /= N

View file

@ -5,12 +5,11 @@ from .narrow import narrow_by
def resample(x, scale_factor, narrow=True):
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
ndim = x.dim() - 2
mode = modes[ndim]
x = F.interpolate(x, scale_factor=scale_factor,
mode=mode, align_corners=False)
x = F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=False)
if scale_factor > 1 and narrow == True:
edges = round(scale_factor) // 2
@ -25,18 +24,20 @@ class Resampler(nn.Module):
By default discard the inaccurate edges when upsampling.
"""
def __init__(self, ndim, scale_factor, narrow=True):
super().__init__()
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
self.mode = modes[ndim]
self.scale_factor = scale_factor
self.narrow = narrow
def forward(self, x):
x = F.interpolate(x, scale_factor=self.scale_factor,
mode=self.mode, align_corners=False)
x = F.interpolate(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False
)
if self.scale_factor > 1 and self.narrow == True:
edges = round(self.scale_factor) // 2

View file

@ -4,8 +4,18 @@ from torch.nn.utils import spectral_norm, remove_spectral_norm
def add_spectral_norm(module):
for name, child in module.named_children():
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if isinstance(
child,
(
nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
),
):
setattr(module, name, spectral_norm(child))
else:
add_spectral_norm(child)
@ -13,8 +23,18 @@ def add_spectral_norm(module):
def rm_spectral_norm(module):
for name, child in module.named_children():
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if isinstance(
child,
(
nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
),
):
setattr(module, name, remove_spectral_norm(child))
else:
rm_spectral_norm(child)

View file

@ -7,9 +7,17 @@ from .resample import Resampler
class G(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
**kwargs):
def __init__(
self,
in_chan,
out_chan,
scale_factor=16,
chan_base=512,
chan_min=64,
chan_max=512,
cat_noise=False,
**kwargs,
):
super().__init__()
self.scale_factor = scale_factor
@ -30,15 +38,14 @@ class G(nn.Module):
self.blocks = nn.ModuleList()
for b in range(num_blocks):
prev_chan, next_chan = chan(b), chan(b+1)
self.blocks.append(
HBlock(prev_chan, next_chan, out_chan, cat_noise))
prev_chan, next_chan = chan(b), chan(b + 1)
self.blocks.append(HBlock(prev_chan, next_chan, out_chan, cat_noise))
def forward(self, x):
y = x # direct upsampling from the input
x = self.block0(x)
#y = None # no direct upsampling from the input
# y = None # no direct upsampling from the input
for block in self.blocks:
x, y = block(x, y)
@ -71,6 +78,7 @@ class HBlock(nn.Module):
-----
next_size = 2 * prev_size - 6
"""
def __init__(self, prev_chan, next_chan, out_chan, cat_noise):
super().__init__()
@ -81,7 +89,6 @@ class HBlock(nn.Module):
self.upsample,
nn.Conv3d(prev_chan + int(cat_noise), next_chan, 3),
nn.LeakyReLU(0.2, True),
AddNoise(cat_noise, chan=next_chan),
nn.Conv3d(next_chan + int(cat_noise), next_chan, 3),
nn.LeakyReLU(0.2, True),
@ -114,6 +121,7 @@ class AddNoise(nn.Module):
The number of channels `chan` should be 1 (StyleGAN2)
or that of the input (StyleGAN).
"""
def __init__(self, cat, chan=1):
super().__init__()
@ -137,9 +145,16 @@ class AddNoise(nn.Module):
class D(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512,
**kwargs):
def __init__(
self,
in_chan,
out_chan,
scale_factor=16,
chan_base=512,
chan_min=64,
chan_max=512,
**kwargs,
):
super().__init__()
self.scale_factor = scale_factor
@ -163,7 +178,7 @@ class D(nn.Module):
self.blocks = nn.ModuleList()
for b in reversed(range(num_blocks)):
prev_chan, next_chan = chan(b+1), chan(b)
prev_chan, next_chan = chan(b + 1), chan(b)
self.blocks.append(ResBlock(prev_chan, next_chan))
self.block9 = nn.Sequential(
@ -192,13 +207,13 @@ class ResBlock(nn.Module):
-----
next_size = (prev_size - 4) // 2
"""
def __init__(self, prev_chan, next_chan):
super().__init__()
self.conv = nn.Sequential(
nn.Conv3d(prev_chan, prev_chan, 3),
nn.LeakyReLU(0.2, True),
nn.Conv3d(prev_chan, next_chan, 3),
nn.LeakyReLU(0.2, True),
)

View file

@ -9,6 +9,7 @@ class PixelNorm(nn.Module):
See ProGAN, StyleGAN.
"""
def __init__(self):
super().__init__()
@ -23,6 +24,7 @@ class LinearElr(nn.Module):
Useful at all if not for regularization(1706.05350)?
"""
def __init__(self, in_size, out_size, bias=True, act=None):
super().__init__()
@ -32,7 +34,7 @@ class LinearElr(nn.Module):
if bias:
self.bias = nn.Parameter(torch.zeros(out_size))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
self.act = act
@ -52,20 +54,20 @@ class ConvElr3d(nn.Module):
Useful at all if not for regularization(1706.05350)?
"""
def __init__(self, in_chan, out_chan, kernel_size,
stride=1, padding=0, bias=True):
def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, bias=True):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
)
fan_in = in_chan * kernel_size ** 3
fan_in = in_chan * kernel_size**3
self.wnorm = 1 / math.sqrt(fan_in)
if bias:
self.bias = nn.Parameter(torch.zeros(out_chan))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
self.stride = stride
self.padding = padding
@ -87,38 +89,49 @@ class ConvStyled3d(nn.Module):
Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
"""
def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
bias=True, resample=None):
def __init__(
self,
style_size,
in_chan,
out_chan,
kernel_size=3,
stride=1,
bias=True,
resample=None,
):
super().__init__()
self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
mode='fan_in', nonlinearity='leaky_relu')
nn.init.kaiming_uniform_(
self.style_weight, a=math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"
)
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
if resample is None:
K3 = (kernel_size,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
self.stride = stride
self.conv = F.conv3d
elif resample == 'U':
elif resample == "U":
K3 = (2,) * 3
# NOTE not clear to me why convtranspose have channels swapped
self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
self.stride = 2
self.conv = F.conv_transpose3d
elif resample == 'D':
elif resample == "D":
K3 = (2,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
self.stride = 2
self.conv = F.conv3d
else:
raise ValueError('resample type {} not supported'.format(resample))
raise ValueError("resample type {} not supported".format(resample))
self.resample = resample
nn.init.kaiming_uniform_(
self.weight, a=math.sqrt(5),
mode='fan_in', # effectively 'fan_out' for 'D'
nonlinearity='leaky_relu',
self.weight,
a=math.sqrt(5),
mode="fan_in", # effectively 'fan_out' for 'D'
nonlinearity="leaky_relu",
)
if bias:
@ -127,12 +140,12 @@ class ConvStyled3d(nn.Module):
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
def forward(self, x, s, eps=1e-8):
N, Cin, *DHWin = x.shape
C0, C1, *K3 = self.weight.shape
if self.resample == 'U':
if self.resample == "U":
Cin, Cout = C0, C1
else:
Cout, Cin = C0, C1
@ -140,14 +153,14 @@ class ConvStyled3d(nn.Module):
s = F.linear(s, self.style_weight, bias=self.style_bias)
# modulation
if self.resample == 'U':
if self.resample == "U":
s = s.reshape(N, Cin, 1, 1, 1, 1)
else:
s = s.reshape(N, 1, Cin, 1, 1, 1)
w = self.weight * s
# demodulation
if self.resample == 'U':
if self.resample == "U":
fan_in_dim = (1, 3, 4, 5)
else:
fan_in_dim = (2, 3, 4, 5)
@ -161,18 +174,22 @@ class ConvStyled3d(nn.Module):
return x
class BatchNormStyled3d(nn.BatchNorm3d) :
""" Trivially does standard batch normalization, but accepts second argument
class BatchNormStyled3d(nn.BatchNorm3d):
"""Trivially does standard batch normalization, but accepts second argument
for style array that is not used
"""
def forward(self, x, s):
return super().forward(x)
class LeakyReLUStyled(nn.LeakyReLU):
""" Trivially evaluates standard leaky ReLU, but accepts second argument
"""Trivially evaluates standard leaky ReLU, but accepts second argument
for sytle array that is not used
"""
def forward(self, x, s):
return super().forward(x)

View file

@ -16,8 +16,17 @@ class ConvStyledBlock(nn.Module):
'U': upsampling transposed convolution of kernel size 2 and stride 2
'D': downsampling convolution of kernel size 2 and stride 2
"""
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
kernel_size=3, stride=1, seq='CBA'):
def __init__(
self,
style_size,
in_chan,
out_chan=None,
mid_chan=None,
kernel_size=3,
stride=1,
seq="CBA",
):
super().__init__()
if out_chan is None:
@ -33,31 +42,34 @@ class ConvStyledBlock(nn.Module):
self.norm_chan = in_chan
self.idx_conv = 0
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]])
layers = [self._get_layer(l) for l in seq]
self.convs = nn.ModuleList(layers)
def _get_layer(self, l):
if l == 'U':
if l == "U":
in_chan, out_chan = self._setup_conv()
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
resample = 'U')
elif l == 'D':
return ConvStyled3d(
self.style_size, in_chan, out_chan, 2, stride=2, resample="U"
)
elif l == "D":
in_chan, out_chan = self._setup_conv()
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
resample = 'D')
elif l == 'C':
return ConvStyled3d(
self.style_size, in_chan, out_chan, 2, stride=2, resample="D"
)
elif l == "C":
in_chan, out_chan = self._setup_conv()
return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size,
stride=self.stride)
elif l == 'B':
return ConvStyled3d(
self.style_size, in_chan, out_chan, self.kernel_size, stride=self.stride
)
elif l == "B":
return BatchNormStyled3d(self.norm_chan)
elif l == 'A':
elif l == "A":
return LeakyReLUStyled()
else:
raise ValueError('layer type {} not supported'.format(l))
raise ValueError("layer type {} not supported".format(l))
def _setup_conv(self):
self.idx_conv += 1
@ -92,13 +104,23 @@ class ResStyledBlock(ConvStyledBlock):
See `ConvStyledBlock` for `seq` types.
"""
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
def __init__(
self,
style_size,
in_chan,
out_chan=None,
mid_chan=None,
kernel_size=3,
stride=1,
seq="CBACBA",
last_act=None,
):
if last_act is None:
last_act = seq[-1] == 'A'
elif last_act and seq[-1] != 'A':
last_act = seq[-1] == "A"
elif last_act and seq[-1] != "A":
warnings.warn(
'Disabling last_act without trailing activation in seq',
"Disabling last_act without trailing activation in seq",
RuntimeWarning,
)
last_act = False
@ -106,8 +128,15 @@ class ResStyledBlock(ConvStyledBlock):
if last_act:
seq = seq[:-1]
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
kernel_size=kernel_size, stride=stride, seq=seq)
super().__init__(
style_size,
in_chan,
out_chan=out_chan,
mid_chan=mid_chan,
kernel_size=kernel_size,
stride=stride,
seq=seq,
)
if last_act:
self.act = LeakyReLUStyled()
@ -119,9 +148,10 @@ class ResStyledBlock(ConvStyledBlock):
else:
self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
if 'U' in seq or 'D' in seq:
raise NotImplementedError('upsample and downsample layers '
'not supported yet')
if "U" in seq or "D" in seq:
raise NotImplementedError(
"upsample and downsample layers " "not supported yet"
)
def forward(self, x, s):
y = x

View file

@ -15,17 +15,17 @@ class StyledVNet(nn.Module):
# activate non-identity skip connection in residual block
# by explicitly setting out_chan
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA')
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq="CACBA")
self.down_l0 = ConvStyledBlock(style_size, 64, seq="DBA")
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
self.down_l1 = ConvStyledBlock(style_size, 64, seq="DBA")
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
self.conv_c = ResStyledBlock(style_size, 64, 64, seq="CBACBA")
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
self.up_r1 = ConvStyledBlock(style_size, 64, seq="UBA")
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq="CBACBA")
self.up_r0 = ConvStyledBlock(style_size, 64, seq="UBA")
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq="CAC")
if bypass is None:
self.bypass = in_chan == out_chan

View file

@ -19,17 +19,17 @@ class UNet(nn.Module):
"""
super().__init__()
self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA')
self.down_l0 = ConvBlock(64, seq='DBA')
self.conv_l1 = ConvBlock(64, seq='CBACBA')
self.down_l1 = ConvBlock(64, seq='DBA')
self.conv_l0 = ConvBlock(in_chan, 64, seq="CACBA")
self.down_l0 = ConvBlock(64, seq="DBA")
self.conv_l1 = ConvBlock(64, seq="CBACBA")
self.down_l1 = ConvBlock(64, seq="DBA")
self.conv_c = ConvBlock(64, seq='CBACBA')
self.conv_c = ConvBlock(64, seq="CBACBA")
self.up_r1 = ConvBlock(64, seq='UBA')
self.conv_r1 = ConvBlock(128, 64, seq='CBACBA')
self.up_r0 = ConvBlock(64, seq='UBA')
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
self.up_r1 = ConvBlock(64, seq="UBA")
self.conv_r1 = ConvBlock(128, 64, seq="CBACBA")
self.up_r0 = ConvBlock(64, seq="UBA")
self.conv_r0 = ConvBlock(128, out_chan, seq="CAC")
self.bypass = in_chan == out_chan

View file

@ -23,17 +23,17 @@ class VNet(nn.Module):
# activate non-identity skip connection in residual block
# by explicitly setting out_chan
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
self.down_l0 = ConvBlock(64, seq='DBA')
self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
self.down_l1 = ConvBlock(64, seq='DBA')
self.conv_l0 = ResBlock(in_chan, 64, seq="CACBA")
self.down_l0 = ConvBlock(64, seq="DBA")
self.conv_l1 = ResBlock(64, 64, seq="CBACBA")
self.down_l1 = ConvBlock(64, seq="DBA")
self.conv_c = ResBlock(64, 64, seq='CBACBA')
self.conv_c = ResBlock(64, 64, seq="CBACBA")
self.up_r1 = ConvBlock(64, seq='UBA')
self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
self.up_r0 = ConvBlock(64, seq='UBA')
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
self.up_r1 = ConvBlock(64, seq="UBA")
self.conv_r1 = ResBlock(128, 64, seq="CBACBA")
self.up_r0 = ConvBlock(64, seq="UBA")
self.conv_r0 = ResBlock(128, out_chan, seq="CAC")
if bypass is None:
self.bypass = in_chan == out_chan

View file

@ -16,21 +16,21 @@ from .utils import import_attr, load_model_state_dict
def test(args):
if torch.cuda.is_available():
if torch.cuda.device_count() > 1:
warnings.warn('Not parallelized but given more than 1 GPUs')
warnings.warn("Not parallelized but given more than 1 GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda', 0)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda", 0)
torch.backends.cudnn.benchmark = True
else: # CPU multithreading
device = torch.device('cpu')
device = torch.device("cpu")
if args.num_threads is None:
args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE'])
args.num_threads = int(os.environ["SLURM_CPUS_ON_NODE"])
torch.set_num_threads(args.num_threads)
print('pytorch {}'.format(torch.__version__))
print("pytorch {}".format(torch.__version__))
pprint(vars(args))
sys.stdout.flush()
@ -67,26 +67,33 @@ def test(args):
out_chan = test_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(style_size, sum(in_chan), sum(out_chan),
scale_factor=args.scale_factor, **args.misc_kwargs)
model = model(
style_size,
sum(in_chan),
sum(out_chan),
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
model.to(device)
criterion = import_attr(args.criterion, torch.nn, models,
callback_at=args.callback_at)
criterion = import_attr(
args.criterion, torch.nn, models, callback_at=args.callback_at
)
criterion = criterion()
criterion.to(device)
state = torch.load(args.load_state, map_location=device)
load_model_state_dict(model, state['model'], strict=args.load_state_strict)
print('model state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state))
load_model_state_dict(model, state["model"], strict=args.load_state_strict)
print(
"model state at epoch {} loaded from {}".format(state["epoch"], args.load_state)
)
del state
model.eval()
with torch.no_grad():
for i, data in enumerate(test_loader):
style, input, target = data['style'], data['input'], data['target']
style, input, target = data["style"], data["input"], data["target"]
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
@ -94,21 +101,21 @@ def test(args):
output = model(input, style)
if i < 5:
print('##### sample :', i)
print('style shape :', style.shape)
print('input shape :', input.shape)
print('output shape :', output.shape)
print('target shape :', target.shape)
print("##### sample :", i)
print("style shape :", style.shape)
print("input shape :", input.shape)
print("output shape :", output.shape)
print("target shape :", target.shape)
input, output, target = narrow_cast(input, output, target)
if i < 5:
print('narrowed shape :', output.shape, flush=True)
print("narrowed shape :", output.shape, flush=True)
loss = criterion(output, target)
print('sample {} loss: {}'.format(i, loss.item()))
print("sample {} loss: {}".format(i, loss.item()))
#if args.in_norms is not None:
# if args.in_norms is not None:
# start = 0
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
# norm = import_attr(norm, norms, callback_at=args.callback_at)
@ -119,12 +126,11 @@ def test(args):
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True, **args.misc_kwargs)
#norm(target[:, start:stop], undo=True, **args.misc_kwargs)
# norm(target[:, start:stop], undo=True, **args.misc_kwargs)
start = stop
#test_dataset.assemble('_in', in_chan, input,
# test_dataset.assemble('_in', in_chan, input,
# data['input_relpath'])
test_dataset.assemble('_out', out_chan, output,
data['target_relpath'])
#test_dataset.assemble('_tgt', out_chan, target,
test_dataset.assemble("_out", out_chan, output, data["target_relpath"])
# test_dataset.assemble('_tgt', out_chan, target,
# data['target_relpath'])

View file

@ -18,34 +18,34 @@ from .models import narrow_cast, resample
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
ckpt_link = 'checkpoint.pt'
ckpt_link = "checkpoint.pt"
def node_worker(args):
if 'SLURM_STEP_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
elif 'SLURM_JOB_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
if "SLURM_STEP_NUM_NODES" in os.environ:
args.nodes = int(os.environ["SLURM_STEP_NUM_NODES"])
elif "SLURM_JOB_NUM_NODES" in os.environ:
args.nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
else:
raise KeyError('missing node counts in slurm env')
raise KeyError("missing node counts in slurm env")
args.gpus_per_node = torch.cuda.device_count()
args.world_size = args.nodes * args.gpus_per_node
node = int(os.environ['SLURM_NODEID'])
node = int(os.environ["SLURM_NODEID"])
if args.gpus_per_node < 1:
raise RuntimeError('GPU not found on node {}'.format(node))
raise RuntimeError("GPU not found on node {}".format(node))
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
def gpu_worker(local_rank, node, args):
#device = torch.device('cuda', local_rank)
#torch.cuda.device(device) # env var recommended over this
# device = torch.device('cuda', local_rank)
# torch.cuda.device(device) # env var recommended over this
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
device = torch.device('cuda', 0)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
device = torch.device("cuda", 0)
rank = args.gpus_per_node * node + local_rank
@ -53,7 +53,7 @@ def gpu_worker(local_rank, node, args):
# Note DDP broadcasts initial model states from rank 0
torch.manual_seed(args.seed + rank)
# good practice to disable cudnn.benchmark if enabling cudnn.deterministic
#torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.deterministic = True
dist_init(rank, args)
@ -77,9 +77,12 @@ def gpu_worker(local_rank, node, args):
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist)
train_sampler = DistFieldSampler(
train_dataset,
shuffle=True,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist,
)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
@ -110,9 +113,12 @@ def gpu_worker(local_rank, node, args):
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist)
val_sampler = DistFieldSampler(
val_dataset,
shuffle=False,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist,
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
@ -127,14 +133,19 @@ def gpu_worker(local_rank, node, args):
args.out_chan = train_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor, **args.misc_kwargs)
model = model(
args.style_size,
sum(args.in_chan),
sum(args.out_chan),
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
model.to(device)
model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group())
model = DistributedDataParallel(
model, device_ids=[device], process_group=dist.new_group()
)
criterion = import_attr(args.criterion, nn, models,
callback_at=args.callback_at)
criterion = import_attr(args.criterion, nn, models, callback_at=args.callback_at)
criterion = criterion()
criterion.to(device)
@ -144,11 +155,13 @@ def gpu_worker(local_rank, node, args):
lr=args.lr,
**args.optimizer_args,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, **args.scheduler_args)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **args.scheduler_args)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
or not args.load_state):
if (
args.load_state == ckpt_link
and not os.path.isfile(ckpt_link)
or not args.load_state
):
if args.init_weight_std is not None:
model.apply(init_weights)
@ -159,23 +172,28 @@ def gpu_worker(local_rank, node, args):
else:
state = torch.load(args.load_state, map_location=device)
start_epoch = state['epoch']
start_epoch = state["epoch"]
load_model_state_dict(model.module, state['model'],
strict=args.load_state_strict)
load_model_state_dict(
model.module, state["model"], strict=args.load_state_strict
)
if 'optimizer' in state:
optimizer.load_state_dict(state['optimizer'])
if 'scheduler' in state:
scheduler.load_state_dict(state['scheduler'])
if "optimizer" in state:
optimizer.load_state_dict(state["optimizer"])
if "scheduler" in state:
scheduler.load_state_dict(state["scheduler"])
torch.set_rng_state(state['rng'].cpu()) # move rng state back
torch.set_rng_state(state["rng"].cpu()) # move rng state back
if rank == 0:
min_loss = state['min_loss']
min_loss = state["min_loss"]
print('state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state), flush=True)
print(
"state at epoch {} loaded from {}".format(
state["epoch"], args.load_state
),
flush=True,
)
del state
@ -189,21 +207,31 @@ def gpu_worker(local_rank, node, args):
logger = SummaryWriter()
if rank == 0:
print('pytorch {}'.format(torch.__version__))
print("pytorch {}".format(torch.__version__))
pprint(vars(args))
sys.stdout.flush()
for epoch in range(start_epoch, args.epochs):
train_sampler.set_epoch(epoch)
train_loss = train(epoch, train_loader, model, criterion,
optimizer, scheduler, logger, device, args)
train_loss = train(
epoch,
train_loader,
model,
criterion,
optimizer,
scheduler,
logger,
device,
args,
)
epoch_loss = train_loss
if args.val:
val_loss = validate(epoch, val_loader, model, criterion,
logger, device, args)
#epoch_loss = val_loss
val_loss = validate(
epoch, val_loader, model, criterion, logger, device, args
)
# epoch_loss = val_loss
if args.reduce_lr_on_plateau:
scheduler.step(epoch_loss[2])
@ -215,27 +243,26 @@ def gpu_worker(local_rank, node, args):
min_loss = epoch_loss[2]
state = {
'epoch': epoch + 1,
'model': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'rng': torch.get_rng_state(),
'min_loss': min_loss,
"epoch": epoch + 1,
"model": model.module.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"rng": torch.get_rng_state(),
"min_loss": min_loss,
}
state_file = 'state_{}.pt'.format(epoch + 1)
state_file = "state_{}.pt".format(epoch + 1)
torch.save(state, state_file)
del state
tmp_link = '{}.pt'.format(time.time())
tmp_link = "{}.pt".format(time.time())
os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link)
dist.destroy_process_group()
def train(epoch, loader, model, criterion,
optimizer, scheduler, logger, device, args):
def train(epoch, loader, model, criterion, optimizer, scheduler, logger, device, args):
model.train()
rank = dist.get_rank()
@ -246,7 +273,7 @@ def train(epoch, loader, model, criterion,
for i, data in enumerate(loader):
batch = epoch * len(loader) + i + 1
style, input, target = data['style'], data['input'], data['target']
style, input, target = data["style"], data["input"], data["target"]
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
@ -254,23 +281,21 @@ def train(epoch, loader, model, criterion,
output = model(input, style)
if batch <= 5 and rank == 0:
print('##### batch :', batch)
print('style shape :', style.shape)
print('input shape :', input.shape)
print('output shape :', output.shape)
print('target shape :', target.shape)
print("##### batch :", batch)
print("style shape :", style.shape)
print("input shape :", input.shape)
print("output shape :", output.shape)
print("target shape :", target.shape)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
if batch <= 5 and rank == 0:
print('narrowed shape :', output.shape)
print("narrowed shape :", output.shape)
loss = criterion(output, target)
epoch_loss[0] += loss.detach()
optimizer.zero_grad()
loss.backward()
optimizer.step()
@ -280,43 +305,45 @@ def train(epoch, loader, model, criterion,
dist.all_reduce(loss)
loss /= world_size
if rank == 0:
logger.add_scalar('loss/batch/train', loss.item(),
global_step=batch)
logger.add_scalar("loss/batch/train", loss.item(), global_step=batch)
logger.add_scalar('grad/first', grads[0], global_step=batch)
logger.add_scalar('grad/last', grads[-1], global_step=batch)
logger.add_scalar("grad/first", grads[0], global_step=batch)
logger.add_scalar("grad/last", grads[-1], global_step=batch)
dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size
if rank == 0:
logger.add_scalar('loss/epoch/train', epoch_loss[0],
global_step=epoch+1)
logger.add_scalar("loss/epoch/train", epoch_loss[0], global_step=epoch + 1)
skip_chan = 0
fig = plt_slices(
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:],
title=['in', 'out', 'tgt', 'out - tgt'],
title=["in", "out", "tgt", "out - tgt"],
**args.misc_kwargs,
)
logger.add_figure('fig/train', fig, global_step=epoch+1)
logger.add_figure("fig/train", fig, global_step=epoch + 1)
fig.clf()
fig = plt_power(
input, output[:, skip_chan:], target[:, skip_chan:],
label=['in', 'out', 'tgt'],
input,
output[:, skip_chan:],
target[:, skip_chan:],
label=["in", "out", "tgt"],
**args.misc_kwargs,
)
logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
logger.add_figure("fig/train/power/lag", fig, global_step=epoch + 1)
fig.clf()
#fig = plt_power(1.0,
# fig = plt_power(1.0,
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
# label=['in', 'out', 'tgt'],
# **args.misc_kwargs,
#)
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf()
# )
# logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
# fig.clf()
return epoch_loss
@ -331,7 +358,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
with torch.no_grad():
for data in loader:
style, input, target = data['style'], data['input'], data['target']
style, input, target = data["style"], data["input"], data["target"]
style = style.to(device, non_blocking=True)
input = input.to(device, non_blocking=True)
@ -339,8 +366,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
output = model(input, style)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1:
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
@ -350,40 +376,43 @@ def validate(epoch, loader, model, criterion, logger, device, args):
dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size
if rank == 0:
logger.add_scalar('loss/epoch/val', epoch_loss[0],
global_step=epoch+1)
logger.add_scalar("loss/epoch/val", epoch_loss[0], global_step=epoch + 1)
skip_chan = 0
fig = plt_slices(
input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:],
title=['in', 'out', 'tgt', 'out - tgt'],
title=["in", "out", "tgt", "out - tgt"],
**args.misc_kwargs,
)
logger.add_figure('fig/val', fig, global_step=epoch+1)
logger.add_figure("fig/val", fig, global_step=epoch + 1)
fig.clf()
fig = plt_power(
input, output[:, skip_chan:], target[:, skip_chan:],
label=['in', 'out', 'tgt'],
input,
output[:, skip_chan:],
target[:, skip_chan:],
label=["in", "out", "tgt"],
**args.misc_kwargs,
)
logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
logger.add_figure("fig/val/power/lag", fig, global_step=epoch + 1)
fig.clf()
#fig = plt_power(1.0,
# fig = plt_power(1.0,
# dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
# label=['in', 'out', 'tgt'],
# **args.misc_kwargs,
#)
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf()
# )
# logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
# fig.clf()
return epoch_loss
def dist_init(rank, args):
dist_file = 'dist_addr'
dist_file = "dist_addr"
if rank == 0:
addr = socket.gethostname()
@ -393,15 +422,15 @@ def dist_init(rank, args):
s.bind((addr, 0))
_, port = s.getsockname()
args.dist_addr = 'tcp://{}:{}'.format(addr, port)
args.dist_addr = "tcp://{}:{}".format(addr, port)
with open(dist_file, mode='w') as f:
with open(dist_file, mode="w") as f:
f.write(args.dist_addr)
else:
while not os.path.exists(dist_file):
time.sleep(1)
with open(dist_file, mode='r') as f:
with open(dist_file, mode="r") as f:
args.dist_addr = f.read()
dist.init_process_group(
@ -417,19 +446,40 @@ def dist_init(rank, args):
def init_weights(m, args):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if isinstance(
m,
(
nn.Linear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
),
):
m.weight.data.normal_(0.0, args.init_weight_std)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
elif isinstance(
m,
(
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.SyncBatchNorm,
nn.LayerNorm,
nn.GroupNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
),
):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_std)
m.bias.data.fill_(0)
def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False):
def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False):
for param in module.parameters():
param.requires_grad = requires_grad
@ -444,8 +494,7 @@ def get_grads(model: torch.nn.Module):
Returns:
A list containing the norms of the gradients of the first and the last layer weights.
"""
grads = list(p.grad for n, p in model.named_parameters()
if '.weight' in n)
grads = list(p.grad for n, p in model.named_parameters() if ".weight" in n)
grads = [grads[0], grads[-1]]
grads = [g.detach().norm() for g in grads]
return grads

View file

@ -2,11 +2,13 @@ from math import log2, log10, ceil
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable
plt.rc('text', usetex=False)
plt.rc("text", usetex=False)
from ..models import lag2eul, power
@ -21,7 +23,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
Each field should have a channel dimension followed by spatial dimensions,
i.e. no batch dimension.
"""
plt.close('all')
plt.close("all")
assert all(isinstance(field, torch.Tensor) for field in fields)
@ -38,11 +40,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
im_size = 2
cbar_height = 0.2
fig, axes = plt.subplots(
nc + 1, nf,
nc + 1,
nf,
squeeze=False,
figsize=(nf * im_size, nc * im_size + cbar_height),
dpi=100,
gridspec_kw={'height_ratios': nc * [im_size] + [cbar_height]}
gridspec_kw={"height_ratios": nc * [im_size] + [cbar_height]},
)
for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)):
@ -51,11 +54,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
if cmap_col is None:
if all_non_neg:
cmap_col = 'inferno'
cmap_col = "inferno"
elif all_non_pos:
cmap_col = 'inferno_r'
cmap_col = "inferno_r"
else:
cmap_col = 'RdBu_r'
cmap_col = "RdBu_r"
if norm_col is None:
l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
@ -70,9 +73,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
if l1 < 0.1 * l2 or h2 == 0:
norm_col = Normalize(vmin=-quantize(-l2), vmax=0)
else:
norm_col = SymLogNorm(linthresh=quantize(-h2),
vmin=-quantize(-l2),
vmax=-quantize(-h2))
norm_col = SymLogNorm(
linthresh=quantize(-h2),
vmin=-quantize(-l2),
vmax=-quantize(-h2),
)
else:
vlim = quantize(max(-l2, h2))
if w1 > 0.1 * w2 or l1 * h1 >= 0:
@ -80,8 +85,13 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
else:
linthresh = quantize(min(-l1, h1))
linscale = np.log10(vlim / linthresh)
norm_col = SymLogNorm(linthresh=linthresh, linscale=linscale,
vmin=-vlim, vmax=vlim, base=10)
norm_col = SymLogNorm(
linthresh=linthresh,
linscale=linscale,
vmin=-vlim,
vmax=vlim,
base=10,
)
for c in range(field.shape[0]):
s = (c,) + tuple(d // 2 for d in field.shape[1:-2])
@ -101,7 +111,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col)
axes[c, f].set_aspect('equal')
axes[c, f].set_aspect("equal")
axes[c, f].set_xticks([])
axes[c, f].set_yticks([])
@ -110,12 +120,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
axes[c, f].set_title(title[f])
for c in range(field.shape[0], nc):
axes[c, f].axis('off')
axes[c, f].axis("off")
fig.colorbar(
ScalarMappable(norm=norm_col, cmap=cmap_col),
cax=axes[-1, f],
orientation='horizontal',
orientation="horizontal",
)
fig.tight_layout()
@ -133,7 +143,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
See `map2map.models.power`.
"""
plt.close('all')
plt.close("all")
if label is not None:
assert len(label) == len(fields) or len(label) == len(dis)
@ -159,8 +169,8 @@ def plt_power(*fields, dis=None, label=None, **kwargs):
axes.loglog(k, P, label=l, alpha=0.7)
axes.legend()
axes.set_xlabel('unnormalized wavenumber')
axes.set_ylabel('unnormalized power')
axes.set_xlabel("unnormalized wavenumber")
axes.set_ylabel("unnormalized power")
fig.tight_layout()

View file

@ -21,7 +21,7 @@ def import_attr(name, *pkgs, callback_at=None):
first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally
from 'path/to/cb_dir/mod.py'.
"""
if name.count('.') == 0:
if name.count(".") == 0:
attr = name
errors = []
@ -34,23 +34,22 @@ def import_attr(name, *pkgs, callback_at=None):
raise Exception(errors)
else:
mod, attr = name.rsplit('.', 1)
mod, attr = name.rsplit(".", 1)
errors = []
for pkg in pkgs:
try:
return getattr(
importlib.import_module(pkg.__name__ + '.' + mod), attr)
return getattr(importlib.import_module(pkg.__name__ + "." + mod), attr)
except (ModuleNotFoundError, AttributeError) as e:
errors.append(e)
if callback_at is None:
raise Exception(errors)
callback_at = os.path.join(callback_at, mod + '.py')
callback_at = os.path.join(callback_at, mod + ".py")
if not os.path.isfile(callback_at):
raise FileNotFoundError('callback file not found')
raise FileNotFoundError("callback file not found")
if mod in sys.modules:
return getattr(sys.modules[mod], attr)

View file

@ -7,9 +7,13 @@ def load_model_state_dict(module, state_dict, strict=True):
bad_keys = module.load_state_dict(state_dict, strict)
if len(bad_keys.missing_keys) > 0:
warnings.warn('Missing keys in state_dict:\n{}'.format(
pformat(bad_keys.missing_keys)))
warnings.warn(
"Missing keys in state_dict:\n{}".format(pformat(bad_keys.missing_keys))
)
if len(bad_keys.unexpected_keys) > 0:
warnings.warn('Unexpected keys in state_dict:\n{}'.format(
pformat(bad_keys.unexpected_keys)))
warnings.warn(
"Unexpected keys in state_dict:\n{}".format(
pformat(bad_keys.unexpected_keys)
)
)
sys.stderr.flush()