diff --git a/map2map/args.py b/map2map/args.py index 835517d..501df5b 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -7,18 +7,16 @@ from .train import ckpt_link def get_args(): - """Parse arguments and set runtime defaults. - """ - parser = argparse.ArgumentParser( - description='Transform field(s) to field(s)') + """Parse arguments and set runtime defaults.""" + parser = argparse.ArgumentParser(description="Transform field(s) to field(s)") - subparsers = parser.add_subparsers(title='modes', dest='mode', required=True) + subparsers = parser.add_subparsers(title="modes", dest="mode", required=True) train_parser = subparsers.add_parser( - 'train', + "train", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) test_parser = subparsers.add_parser( - 'test', + "test", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -27,163 +25,293 @@ def get_args(): args = parser.parse_args() - if args.mode == 'train': + if args.mode == "train": set_train_args(args) - elif args.mode == 'test': + elif args.mode == "test": set_test_args(args) return args def add_common_args(parser): - parser.add_argument('--in-norms', type=str_list, help='comma-sep. list ' - 'of input normalization functions') - parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list ' - 'of target normalization functions') - parser.add_argument('--crop', type=int_tuple, - help='size to crop the input and target data. Default is the ' - 'field size. Comma-sep. list of 1 or d integers') - parser.add_argument('--crop-start', type=int_tuple, - help='starting point of the first crop. Default is the origin. ' - 'Comma-sep. list of 1 or d integers') - parser.add_argument('--crop-stop', type=int_tuple, - help='stopping point of the last crop. Default is the opposite ' - 'corner to the origin. Comma-sep. list of 1 or d integers') - parser.add_argument('--crop-step', type=int_tuple, - help='spacing between crops. Default is the crop size. ' - 'Comma-sep. list of 1 or d integers') - parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple, - help='size to pad the input data beyond the crop size, assuming ' - 'periodic boundary condition. Comma-sep. list of 1, d, or dx2 ' - 'integers, to pad equally along all axes, symmetrically on each, ' - 'or by the specified size on every boundary, respectively') - parser.add_argument('--tgt-pad', default=0, type=int_tuple, - help='size to pad the target data beyond the crop size, assuming ' - 'periodic boundary condition, useful for super-resolution. ' - 'Comma-sep. list with the same format as --in-pad') - parser.add_argument('--scale-factor', default=1, type=int, - help='upsampling factor for super-resolution, in which case ' - 'crop and pad are sizes of the input resolution') + parser.add_argument( + "--in-norms", + type=str_list, + help="comma-sep. list " "of input normalization functions", + ) + parser.add_argument( + "--tgt-norms", + type=str_list, + help="comma-sep. list " "of target normalization functions", + ) + parser.add_argument( + "--crop", + type=int_tuple, + help="size to crop the input and target data. Default is the " + "field size. Comma-sep. list of 1 or d integers", + ) + parser.add_argument( + "--crop-start", + type=int_tuple, + help="starting point of the first crop. Default is the origin. " + "Comma-sep. list of 1 or d integers", + ) + parser.add_argument( + "--crop-stop", + type=int_tuple, + help="stopping point of the last crop. Default is the opposite " + "corner to the origin. Comma-sep. 
list of 1 or d integers", + ) + parser.add_argument( + "--crop-step", + type=int_tuple, + help="spacing between crops. Default is the crop size. " + "Comma-sep. list of 1 or d integers", + ) + parser.add_argument( + "--in-pad", + "--pad", + default=0, + type=int_tuple, + help="size to pad the input data beyond the crop size, assuming " + "periodic boundary condition. Comma-sep. list of 1, d, or dx2 " + "integers, to pad equally along all axes, symmetrically on each, " + "or by the specified size on every boundary, respectively", + ) + parser.add_argument( + "--tgt-pad", + default=0, + type=int_tuple, + help="size to pad the target data beyond the crop size, assuming " + "periodic boundary condition, useful for super-resolution. " + "Comma-sep. list with the same format as --in-pad", + ) + parser.add_argument( + "--scale-factor", + default=1, + type=int, + help="upsampling factor for super-resolution, in which case " + "crop and pad are sizes of the input resolution", + ) - parser.add_argument('--model', type=str, required=True, - help='(generator) model') - parser.add_argument('--criterion', default='MSELoss', type=str, - help='loss function') - parser.add_argument('--load-state', default=ckpt_link, type=str, - help='path to load the states of model, optimizer, rng, etc. ' - 'Default is the checkpoint. ' - 'Start from scratch in case of empty string or missing checkpoint') - parser.add_argument('--load-state-non-strict', action='store_false', - help='allow incompatible keys when loading model states', - dest='load_state_strict') + parser.add_argument("--model", type=str, required=True, help="(generator) model") + parser.add_argument( + "--criterion", default="MSELoss", type=str, help="loss function" + ) + parser.add_argument( + "--load-state", + default=ckpt_link, + type=str, + help="path to load the states of model, optimizer, rng, etc. " + "Default is the checkpoint. " + "Start from scratch in case of empty string or missing checkpoint", + ) + parser.add_argument( + "--load-state-non-strict", + action="store_false", + help="allow incompatible keys when loading model states", + dest="load_state_strict", + ) # somehow I named it "batches" instead of batch_size at first # "batches" is kept for now for backward compatibility - parser.add_argument('--batch-size', '--batches', type=int, required=True, - help='mini-batch size, per GPU in training or in total in testing') - parser.add_argument('--loader-workers', default=8, type=int, - help='number of subprocesses per data loader. ' - '0 to disable multiprocessing') + parser.add_argument( + "--batch-size", + "--batches", + type=int, + required=True, + help="mini-batch size, per GPU in training or in total in testing", + ) + parser.add_argument( + "--loader-workers", + default=8, + type=int, + help="number of subprocesses per data loader. " "0 to disable multiprocessing", + ) - parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s), - help='directory of custorm code defining callbacks for models, ' - 'norms, criteria, and optimizers. Disabled if not set. ' - 'This is appended to the default locations, ' - 'thus has the lowest priority') - parser.add_argument('--misc-kwargs', default='{}', type=json.loads, - help='miscellaneous keyword arguments for custom models and ' - 'norms. Be careful with name collisions') + parser.add_argument( + "--callback-at", + type=lambda s: os.path.abspath(s), + help="directory of custorm code defining callbacks for models, " + "norms, criteria, and optimizers. Disabled if not set. 
" + "This is appended to the default locations, " + "thus has the lowest priority", + ) + parser.add_argument( + "--misc-kwargs", + default="{}", + type=json.loads, + help="miscellaneous keyword arguments for custom models and " + "norms. Be careful with name collisions", + ) def add_train_args(parser): add_common_args(parser) - parser.add_argument('--train-style-pattern', type=str, required=True, - help='glob pattern for training data styles') - parser.add_argument('--train-in-patterns', type=str_list, required=True, - help='comma-sep. list of glob patterns for training input data') - parser.add_argument('--train-tgt-patterns', type=str_list, required=True, - help='comma-sep. list of glob patterns for training target data') - parser.add_argument('--val-style-pattern', type=str, - help='glob pattern for validation data styles') - parser.add_argument('--val-in-patterns', type=str_list, - help='comma-sep. list of glob patterns for validation input data') - parser.add_argument('--val-tgt-patterns', type=str_list, - help='comma-sep. list of glob patterns for validation target data') - parser.add_argument('--augment', action='store_true', - help='enable data augmentation of axis flipping and permutation') - parser.add_argument('--aug-shift', type=int_tuple, - help='data augmentation by shifting cropping by [0, aug_shift) pixels, ' - 'useful for models that treat neighboring pixels differently, ' - 'e.g. with strided convolutions. ' - 'Comma-sep. list of 1 or d integers') - parser.add_argument('--aug-add', type=float, - help='additive data augmentation, (normal) std, ' - 'same factor for all fields') - parser.add_argument('--aug-mul', type=float, - help='multiplicative data augmentation, (log-normal) std, ' - 'same factor for all fields') + parser.add_argument( + "--train-style-pattern", + type=str, + required=True, + help="glob pattern for training data styles", + ) + parser.add_argument( + "--train-in-patterns", + type=str_list, + required=True, + help="comma-sep. list of glob patterns for training input data", + ) + parser.add_argument( + "--train-tgt-patterns", + type=str_list, + required=True, + help="comma-sep. list of glob patterns for training target data", + ) + parser.add_argument( + "--val-style-pattern", type=str, help="glob pattern for validation data styles" + ) + parser.add_argument( + "--val-in-patterns", + type=str_list, + help="comma-sep. list of glob patterns for validation input data", + ) + parser.add_argument( + "--val-tgt-patterns", + type=str_list, + help="comma-sep. list of glob patterns for validation target data", + ) + parser.add_argument( + "--augment", + action="store_true", + help="enable data augmentation of axis flipping and permutation", + ) + parser.add_argument( + "--aug-shift", + type=int_tuple, + help="data augmentation by shifting cropping by [0, aug_shift) pixels, " + "useful for models that treat neighboring pixels differently, " + "e.g. with strided convolutions. " + "Comma-sep. 
list of 1 or d integers", + ) + parser.add_argument( + "--aug-add", + type=float, + help="additive data augmentation, (normal) std, " "same factor for all fields", + ) + parser.add_argument( + "--aug-mul", + type=float, + help="multiplicative data augmentation, (log-normal) std, " + "same factor for all fields", + ) - parser.add_argument('--optimizer', default='Adam', type=str, - help='optimization algorithm') - parser.add_argument('--lr', type=float, required=True, - help='initial learning rate') - parser.add_argument('--optimizer-args', default='{}', type=json.loads, - help='optimizer arguments in addition to the learning rate, ' - 'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'') - parser.add_argument('--reduce-lr-on-plateau', action='store_true', - help='Enable ReduceLROnPlateau learning rate scheduler') - parser.add_argument('--scheduler-args', default='{"verbose": true}', - type=json.loads, - help='arguments for the ReduceLROnPlateau scheduler') - parser.add_argument('--init-weight-std', type=float, - help='weight initialization std') - parser.add_argument('--epochs', default=128, type=int, - help='total number of epochs to run') - parser.add_argument('--seed', default=42, type=int, - help='seed for initializing training') + parser.add_argument( + "--optimizer", default="Adam", type=str, help="optimization algorithm" + ) + parser.add_argument("--lr", type=float, required=True, help="initial learning rate") + parser.add_argument( + "--optimizer-args", + default="{}", + type=json.loads, + help="optimizer arguments in addition to the learning rate, " + "e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'", + ) + parser.add_argument( + "--reduce-lr-on-plateau", + action="store_true", + help="Enable ReduceLROnPlateau learning rate scheduler", + ) + parser.add_argument( + "--scheduler-args", + default='{"verbose": true}', + type=json.loads, + help="arguments for the ReduceLROnPlateau scheduler", + ) + parser.add_argument( + "--init-weight-std", type=float, help="weight initialization std" + ) + parser.add_argument( + "--epochs", default=128, type=int, help="total number of epochs to run" + ) + parser.add_argument( + "--seed", default=42, type=int, help="seed for initializing training" + ) - parser.add_argument('--div-data', action='store_true', - help='enable data division among GPUs for better page caching. ' - 'Data division is shuffled every epoch. ' - 'Only relevant if there are multiple crops in each field') - parser.add_argument('--div-shuffle-dist', default=1, type=float, - help='distance to further shuffle cropped samples relative to ' - 'their fields, to be used with --div-data. ' - 'Only relevant if there are multiple crops in each file. ' - 'The order of each sample is randomly displaced by this value. ' - 'Setting it to 0 turn off this randomization, and setting it to N ' - 'limits the shuffling within a distance of N files. ' - 'Change this to balance cache locality and stochasticity') - parser.add_argument('--dist-backend', default='nccl', type=str, - choices=['gloo', 'nccl'], help='distributed backend') - parser.add_argument('--log-interval', default=100, type=int, - help='interval (batches) between logging training loss') - parser.add_argument('--detect-anomaly', action='store_true', - help='enable anomaly detection for the autograd engine') + parser.add_argument( + "--div-data", + action="store_true", + help="enable data division among GPUs for better page caching. " + "Data division is shuffled every epoch. 
" + "Only relevant if there are multiple crops in each field", + ) + parser.add_argument( + "--div-shuffle-dist", + default=1, + type=float, + help="distance to further shuffle cropped samples relative to " + "their fields, to be used with --div-data. " + "Only relevant if there are multiple crops in each file. " + "The order of each sample is randomly displaced by this value. " + "Setting it to 0 turn off this randomization, and setting it to N " + "limits the shuffling within a distance of N files. " + "Change this to balance cache locality and stochasticity", + ) + parser.add_argument( + "--dist-backend", + default="nccl", + type=str, + choices=["gloo", "nccl"], + help="distributed backend", + ) + parser.add_argument( + "--log-interval", + default=100, + type=int, + help="interval (batches) between logging training loss", + ) + parser.add_argument( + "--detect-anomaly", + action="store_true", + help="enable anomaly detection for the autograd engine", + ) def add_test_args(parser): add_common_args(parser) - parser.add_argument('--test-style-pattern', type=str, required=True, - help='glob pattern for test data styles') - parser.add_argument('--test-in-patterns', type=str_list, required=True, - help='comma-sep. list of glob patterns for test input data') - parser.add_argument('--test-tgt-patterns', type=str_list, required=True, - help='comma-sep. list of glob patterns for test target data') + parser.add_argument( + "--test-style-pattern", + type=str, + required=True, + help="glob pattern for test data styles", + ) + parser.add_argument( + "--test-in-patterns", + type=str_list, + required=True, + help="comma-sep. list of glob patterns for test input data", + ) + parser.add_argument( + "--test-tgt-patterns", + type=str_list, + required=True, + help="comma-sep. list of glob patterns for test target data", + ) - parser.add_argument('--num-threads', type=int, - help='number of CPU threads when cuda is unavailable. ' - 'Default is the number of CPUs on the node by slurm') + parser.add_argument( + "--num-threads", + type=int, + help="number of CPU threads when cuda is unavailable. " + "Default is the number of CPUs on the node by slurm", + ) def str_list(s): - return s.split(',') + return s.split(",") def int_tuple(s): - t = s.split(',') + t = s.split(",") t = tuple(int(i) for i in t) if len(t) == 1: return t[0] @@ -198,8 +326,7 @@ def set_common_args(args): def set_train_args(args): set_common_args(args) - args.val = args.val_in_patterns is not None and \ - args.val_tgt_patterns is not None + args.val = args.val_in_patterns is not None and args.val_tgt_patterns is not None def set_test_args(args): diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 65a9867..055d957 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -43,54 +43,72 @@ class FieldDataset(Dataset): the input for super-resolution, in which case `crop` and `pad` are sizes of the input resolution. 
""" - def __init__(self, style_pattern, in_patterns, tgt_patterns, - in_norms=None, tgt_norms=None, callback_at=None, - augment=False, aug_shift=None, aug_add=None, aug_mul=None, - crop=None, crop_start=None, crop_stop=None, crop_step=None, - in_pad=0, tgt_pad=0, scale_factor=1, - **kwargs): + + def __init__( + self, + style_pattern, + in_patterns, + tgt_patterns, + in_norms=None, + tgt_norms=None, + callback_at=None, + augment=False, + aug_shift=None, + aug_add=None, + aug_mul=None, + crop=None, + crop_start=None, + crop_stop=None, + crop_step=None, + in_pad=0, + tgt_pad=0, + scale_factor=1, + **kwargs, + ): self.style_files = sorted(glob(style_pattern)) in_file_lists = [sorted(glob(p)) for p in in_patterns] - self.in_files = list(zip(* in_file_lists)) + self.in_files = list(zip(*in_file_lists)) tgt_file_lists = [sorted(glob(p)) for p in tgt_patterns] - self.tgt_files = list(zip(* tgt_file_lists)) - + self.tgt_files = list(zip(*tgt_file_lists)) + # self.style_files = self.style_files[:1] # self.in_files = self.in_files[:1] # self.tgt_files = self.tgt_files[:1] if len(self.style_files) != len(self.in_files) != len(self.tgt_files): - raise ValueError('number of style, input, and target files do not match') + raise ValueError("number of style, input, and target files do not match") self.nfile = len(self.in_files) if self.nfile == 0: - raise FileNotFoundError('file not found for {}'.format(in_patterns)) + raise FileNotFoundError("file not found for {}".format(in_patterns)) self.style_size = np.loadtxt(self.style_files[0], ndmin=1).shape[0] - self.in_chan = [np.load(f, mmap_mode='r').shape[0] - for f in self.in_files[0]] - self.tgt_chan = [np.load(f, mmap_mode='r').shape[0] - for f in self.tgt_files[0]] + self.in_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.in_files[0]] + self.tgt_chan = [np.load(f, mmap_mode="r").shape[0] for f in self.tgt_files[0]] - self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:] + self.size = np.load(self.in_files[0][0], mmap_mode="r").shape[1:] self.size = np.asarray(self.size) self.ndim = len(self.size) if in_norms is not None and len(in_patterns) != len(in_norms): - raise ValueError('numbers of input normalization functions and fields do not match') + raise ValueError( + "numbers of input normalization functions and fields do not match" + ) self.in_norms = in_norms if tgt_norms is not None and len(tgt_patterns) != len(tgt_norms): - raise ValueError('numbers of target normalization functions and fields do not match') + raise ValueError( + "numbers of target normalization functions and fields do not match" + ) self.tgt_norms = tgt_norms self.callback_at = callback_at self.augment = augment if self.ndim == 1 and self.augment: - raise ValueError('cannot augment 1D fields') + raise ValueError("cannot augment 1D fields") self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,)) self.aug_add = aug_add self.aug_mul = aug_mul @@ -116,10 +134,15 @@ class FieldDataset(Dataset): crop_step = np.broadcast_to(crop_step, (self.ndim,)) self.crop_step = crop_step - self.anchors = np.stack(np.mgrid[tuple( - slice(crop_start[d], crop_stop[d], crop_step[d]) - for d in range(self.ndim) - )], axis=-1).reshape(-1, self.ndim) + self.anchors = np.stack( + np.mgrid[ + tuple( + slice(crop_start[d], crop_stop[d], crop_step[d]) + for d in range(self.ndim) + ) + ], + axis=-1, + ).reshape(-1, self.ndim) self.ncrop = len(self.anchors) def format_pad(pad, ndim): @@ -130,15 +153,16 @@ class FieldDataset(Dataset): elif isinstance(pad, tuple) and len(pad) == ndim * 2: pad = 
np.array(pad) else: - raise ValueError('pad and ndim mismatch') + raise ValueError("pad and ndim mismatch") return pad.reshape(ndim, 2) + self.in_pad = format_pad(in_pad, self.ndim) self.tgt_pad = format_pad(tgt_pad, self.ndim) if scale_factor != 1: - tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:] + tgt_size = np.load(self.tgt_files[0][0], mmap_mode="r").shape[1:] if any(self.size * scale_factor != tgt_size): - raise ValueError('input size x scale factor != target size') + raise ValueError("input size x scale factor != target size") self.scale_factor = scale_factor self.nsample = self.nfile * self.ncrop @@ -148,9 +172,7 @@ class FieldDataset(Dataset): self.assembly_line = {} self.commonpath = os.path.commonpath( - file - for files in self.in_files[:2] + self.tgt_files[:2] - for file in files + file for files in self.in_files[:2] + self.tgt_files[:2] for file in files ) def __len__(self): @@ -180,12 +202,18 @@ class FieldDataset(Dataset): else: argsort_perm_axes = slice(None) - crop(in_fields, anchor, - self.crop[argsort_perm_axes], - self.in_pad[argsort_perm_axes]) - crop(tgt_fields, anchor * self.scale_factor, - self.crop[argsort_perm_axes] * self.scale_factor, - self.tgt_pad[argsort_perm_axes]) + crop( + in_fields, + anchor, + self.crop[argsort_perm_axes], + self.in_pad[argsort_perm_axes], + ) + crop( + tgt_fields, + anchor * self.scale_factor, + self.crop[argsort_perm_axes] * self.scale_factor, + self.tgt_pad[argsort_perm_axes], + ) style = torch.from_numpy(style).to(torch.float32) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] @@ -221,17 +249,19 @@ class FieldDataset(Dataset): in_fields = torch.cat(in_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0) - #in_relpath = [os.path.relpath(file, start=self.commonpath) + # in_relpath = [os.path.relpath(file, start=self.commonpath) # for file in self.in_files[ifile]] - tgt_relpath = [os.path.relpath(file, start=self.commonpath) - for file in self.tgt_files[ifile]] + tgt_relpath = [ + os.path.relpath(file, start=self.commonpath) + for file in self.tgt_files[ifile] + ] return { - 'style': style, - 'input': in_fields, - 'target': tgt_fields, + "style": style, + "input": in_fields, + "target": tgt_fields, #'input_relpath': in_relpath, - 'target_relpath': tgt_relpath, + "target_relpath": tgt_relpath, } def assemble(self, label, chan, patches, paths): @@ -255,29 +285,29 @@ class FieldDataset(Dataset): if isinstance(patches, torch.Tensor): patches = patches.detach().cpu().numpy() - assert patches.ndim == 2 + self.ndim, 'ndim mismatch' + assert patches.ndim == 2 + self.ndim, "ndim mismatch" if any(self.crop_step > patches.shape[2:]): - raise RuntimeError('patch too small to tile') + raise RuntimeError("patch too small to tile") # the batched paths are a list of lists with shape (channel, batch) # since pytorch default_collate batches list of strings transposedly # therefore we transpose below back to (batch, channel) - assert patches.shape[1] == sum(chan), 'number of channels mismatch' - assert len(paths) == len(chan), 'number of fields mismatch' - paths = list(zip(* paths)) - assert patches.shape[0] == len(paths), 'batch size mismatch' + assert patches.shape[1] == sum(chan), "number of channels mismatch" + assert len(paths) == len(chan), "number of fields mismatch" + paths = list(zip(*paths)) + assert patches.shape[0] == len(paths), "batch size mismatch" patches = list(patches) if label in self.assembly_line: self.assembly_line[label] += patches - self.assembly_line[label + 'path'] += paths + 
self.assembly_line[label + "path"] += paths else: self.assembly_line[label] = patches - self.assembly_line[label + 'path'] = paths + self.assembly_line[label + "path"] = paths del patches, paths patches = self.assembly_line[label] - paths = self.assembly_line[label + 'path'] + paths = self.assembly_line[label + "path"] # NOTE anchor positioning assumes sufficient target padding and # symmetric narrowing (more on the right if odd) see `models/narrow.py` @@ -285,31 +315,26 @@ class FieldDataset(Dataset): anchors = self.anchors - self.tgt_pad[:, 0] + narrow // 2 while len(patches) >= self.ncrop: - fields = np.zeros(patches[0].shape[:1] + tuple(self.size), - patches[0].dtype) + fields = np.zeros(patches[0].shape[:1] + tuple(self.size), patches[0].dtype) for patch, anchor in zip(patches, anchors): fill(fields, patch, anchor) - for field, path in zip( - np.split(fields, np.cumsum(chan), axis=0), - paths[0]): - pathlib.Path(os.path.dirname(path)).mkdir(parents=True, - exist_ok=True) + for field, path in zip(np.split(fields, np.cumsum(chan), axis=0), paths[0]): + pathlib.Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) path = label.join(os.path.splitext(path)) np.save(path, field) - del patches[:self.ncrop], paths[:self.ncrop] + del patches[: self.ncrop], paths[: self.ncrop] def fill(field, patch, anchor): ndim = len(anchor) - assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch' + assert field.ndim == patch.ndim == 1 + ndim, "ndim mismatch" ind = [slice(None)] - for d, (p, a, s) in enumerate(zip( - patch.shape[1:], anchor, field.shape[1:])): + for d, (p, a, s) in enumerate(zip(patch.shape[1:], anchor, field.shape[1:])): i = np.arange(a, a + p) i %= s i = i.reshape((-1,) + (1,) * (ndim - d - 1)) @@ -320,10 +345,10 @@ def fill(field, patch, anchor): def crop(fields, anchor, crop, pad): - assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch' + assert all(x.shape == fields[0].shape for x in fields), "shape mismatch" size = fields[0].shape[1:] ndim = len(size) - assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch' + assert ndim == len(anchor) == len(crop) == len(pad), "ndim mismatch" ind = [slice(None)] for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): @@ -342,7 +367,7 @@ def crop(fields, anchor, crop, pad): def flip(fields, axes, ndim): - assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors' + assert ndim > 1, "flipping is ambiguous for 1D scalars/vectors" if axes is None: axes = torch.randint(2, (ndim,), dtype=torch.bool) @@ -350,7 +375,7 @@ def flip(fields, axes, ndim): for i, x in enumerate(fields): if x.shape[0] == ndim: # flip vector components - x[axes] = - x[axes] + x[axes] = -x[axes] shifted_axes = (1 + axes).tolist() x = torch.flip(x, shifted_axes) @@ -361,7 +386,7 @@ def flip(fields, axes, ndim): def perm(fields, axes, ndim): - assert ndim > 1, 'permutation is not necessary for 1D fields' + assert ndim > 1, "permutation is not necessary for 1D fields" if axes is None: axes = torch.randperm(ndim) diff --git a/map2map/data/norms/cosmology.py b/map2map/data/norms/cosmology.py index a310875..be83a1d 100644 --- a/map2map/data/norms/cosmology.py +++ b/map2map/data/norms/cosmology.py @@ -10,6 +10,7 @@ def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs): x *= dis_norm + def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs): vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z) # [km/s] @@ -20,25 +21,28 @@ def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs): def D(z, Om=0.31): - """linear growth function for flat 
LambdaCDM, normalized to 1 at redshift zero - """ + """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero""" OL = 1 - Om - a = 1 / (1+z) - return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \ - / hyp2f1(1, 1/3, 11/6, - OL / Om) + a = 1 / (1 + z) + return ( + a + * hyp2f1(1, 1 / 3, 11 / 6, -OL * a**3 / Om) + / hyp2f1(1, 1 / 3, 11 / 6, -OL / Om) + ) + def f(z, Om=0.31): - """linear growth rate for flat LambdaCDM - """ + """linear growth rate for flat LambdaCDM""" OL = 1 - Om - a = 1 / (1+z) + a = 1 / (1 + z) aa3 = OL * a**3 / Om - return 1 - 6/11*aa3 * hyp2f1(2, 4/3, 17/6, -aa3) \ - / hyp2f1(1, 1/3, 11/6, -aa3) + return 1 - 6 / 11 * aa3 * hyp2f1(2, 4 / 3, 17 / 6, -aa3) / hyp2f1( + 1, 1 / 3, 11 / 6, -aa3 + ) + def H(z, Om=0.31): - """Hubble in [h km/s/Mpc] for flat LambdaCDM - """ + """Hubble in [h km/s/Mpc] for flat LambdaCDM""" OL = 1 - Om - a = 1 / (1+z) + a = 1 / (1 + z) return 100 * np.sqrt(Om / a**3 + OL) diff --git a/map2map/data/norms/torch.py b/map2map/data/norms/torch.py index 64a1ffa..407c3f1 100644 --- a/map2map/data/norms/torch.py +++ b/map2map/data/norms/torch.py @@ -7,18 +7,21 @@ def exp(x, undo=False, **kwargs): else: torch.log(x, out=x) + def log(x, eps=1e-8, undo=False, **kwargs): if not undo: torch.log(x + eps, out=x) else: torch.exp(x, out=x) + def expm1(x, undo=False, **kwargs): if not undo: torch.expm1(x, out=x) else: torch.log1p(x, out=x) + def log1p(x, eps=1e-7, undo=False, **kwargs): if not undo: torch.log1p(x + eps, out=x) diff --git a/map2map/data/sampler.py b/map2map/data/sampler.py index ef5d5d7..f525efd 100644 --- a/map2map/data/sampler.py +++ b/map2map/data/sampler.py @@ -22,8 +22,8 @@ class DistFieldSampler(Sampler): Like `DistributedSampler`, `set_epoch()` should be called at the beginning of each epoch during training. """ - def __init__(self, dataset, shuffle, - div_data=False, div_shuffle_dist=0): + + def __init__(self, dataset, shuffle, div_data=False, div_shuffle_dist=0): self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -50,8 +50,10 @@ class DistFieldSampler(Sampler): ind = ind.flatten() # displace crops with respect to files - dis = torch.rand((self.nfile, self.ncrop), - generator=g) * self.div_shuffle_dist + dis = ( + torch.rand((self.nfile, self.ncrop), generator=g) + * self.div_shuffle_dist + ) loc = torch.arange(self.nfile) loc = loc[:, None] + dis loc = loc.flatten() % self.nfile # periodic in files diff --git a/map2map/main.py b/map2map/main.py index a31643c..bd88e98 100644 --- a/map2map/main.py +++ b/map2map/main.py @@ -3,6 +3,7 @@ from . import test import click import os import yaml + try: from yaml import CLoader as Loader except ImportError: @@ -12,20 +13,24 @@ import importlib.resources import json from functools import partial + def _load_resource_file(resource_path): # Import the package pkg_files = importlib.resources.files() / resource_path with pkg_files.open() as file: - return file.read() # Read the file and return its content + return file.read() # Read the file and return its content + def _str_list(value): - return value.split(',') + return value.split(",") + def _int_tuple(value): - t = value.split(',') + t = value.split(",") t = tuple(int(i) for i in t) return t + class VariadicType(click.ParamType): """ A custom parameter type for Click command-line interface. 
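The comma-list helpers added above mirror `str_list` and `int_tuple` in `map2map/args.py`, except that `_int_tuple` here never collapses a single value to a scalar. A small self-contained sketch of the inferred behaviour (illustration only, not part of the patch):

```python
# Helpers copied from the patch above; the asserts illustrate inferred behaviour.
def _str_list(value):
    return value.split(",")


def _int_tuple(value):
    t = value.split(",")
    t = tuple(int(i) for i in t)
    return t


assert _str_list("a,b,c") == ["a", "b", "c"]
assert _int_tuple("1,2,3") == (1, 2, 3)
# unlike int_tuple() in map2map/args.py, a single value stays a 1-tuple here
assert _int_tuple("4") == (4,)
```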
@@ -54,14 +59,14 @@ class VariadicType(click.ParamType): if typename in self._mapper: self._type = self._mapper[typename] elif type(typename) == dict: - self._type = self._mapper[typename["type"]] + self._type = self._mapper[typename["type"]] self.args = typename["opts"] else: - raise ValueError(f"Unknown type: {typename}") + raise ValueError(f"Unknown type: {typename}") self._typename = typename self.name = self._type["type"] if "func" not in self._type: - self._type["func"] = eval(self._type['type']) + self._type["func"] = eval(self._type["type"]) def convert(self, value, param, ctx): try: @@ -69,24 +74,26 @@ class VariadicType(click.ParamType): except Exception as e: self.fail(f"Could not parse {self._typename}: {e}", param, ctx) + def _apply_options(options_file, f): common_args = yaml.load(_load_resource_file(options_file), Loader=Loader) - common_args = common_args['arguments'] + common_args = common_args["arguments"] for arg in common_args: argopt = common_args[arg] - if 'type' in argopt: - if type(argopt['type']) == dict and argopt['type']['type'] == 'choice': - argopt['type'] = click.Choice(argopt['type']['opts']) + if "type" in argopt: + if type(argopt["type"]) == dict and argopt["type"]["type"] == "choice": + argopt["type"] = click.Choice(argopt["type"]["opts"]) else: - argopt['type'] = VariadicType(argopt['type']) - f = click.option(f'--{arg}', **argopt)(f) + argopt["type"] = VariadicType(argopt["type"]) + f = click.option(f"--{arg}", **argopt)(f) else: - f = click.option(f'--{arg}', **argopt)(f) + f = click.option(f"--{arg}", **argopt)(f) return f -m2m_options=partial(_apply_options,"common_args.yaml") + +m2m_options = partial(_apply_options, "common_args.yaml") @click.group() @@ -94,23 +101,26 @@ m2m_options=partial(_apply_options,"common_args.yaml") @click.pass_context def main(ctx, config): if config is not None and os.path.exists(config): - with open(config, 'r') as f: + with open(config, "r") as f: config = yaml.load(f.read(), Loader=Loader) ctx.default_map = config + # Make a class that provides access to dict with the attribute mechanism class DictProxy: def __init__(self, d): self.__dict__ = d + @main.command() @m2m_options @partial(_apply_options, "train_args.yaml") def train(**kwargs): train.node_worker(DictProxy(kwargs)) + @main.command() @m2m_options @partial(_apply_options, "test_args.yaml") def test(**kwargs): - test.test(DictProxy(kwargs)) + test.test(DictProxy(kwargs)) diff --git a/map2map/models/adversary.py b/map2map/models/adversary.py index 4a20509..840aa7c 100644 --- a/map2map/models/adversary.py +++ b/map2map/models/adversary.py @@ -5,6 +5,7 @@ def adv_model_wrapper(module): """Wrap an adversary model to also take lists of Tensors as input, to be concatenated along the batch dimension """ + class _new_module(module): def forward(self, x): if not isinstance(x, torch.Tensor): @@ -22,6 +23,7 @@ def adv_criterion_wrapper(module): * expand target shape as that of input * return a list of losses, one for each pair of input and target Tensors """ + class _new_module(module): def forward(self, input, target): assert isinstance(input, torch.Tensor) @@ -35,8 +37,9 @@ def adv_criterion_wrapper(module): target = [t.expand_as(i) for i, t in zip(input, target)] - loss = [super(new_module, self).forward(i, t) - for i, t in zip(input, target)] + loss = [ + super(new_module, self).forward(i, t) for i, t in zip(input, target) + ] return loss diff --git a/map2map/models/conv.py b/map2map/models/conv.py index 7d8fc9f..be4b1e4 100644 --- a/map2map/models/conv.py +++ 
b/map2map/models/conv.py @@ -15,8 +15,10 @@ class ConvBlock(nn.Module): 'U': upsampling transposed convolution of kernel size 2 and stride 2 'D': downsampling convolution of kernel size 2 and stride 2 """ - def __init__(self, in_chan, out_chan=None, mid_chan=None, - kernel_size=3, stride=1, seq='CBA'): + + def __init__( + self, in_chan, out_chan=None, mid_chan=None, kernel_size=3, stride=1, seq="CBA" + ): super().__init__() if out_chan is None: @@ -31,31 +33,30 @@ class ConvBlock(nn.Module): self.norm_chan = in_chan self.idx_conv = 0 - self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']]) + self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]]) layers = [self._get_layer(l) for l in seq] self.convs = nn.Sequential(*layers) def _get_layer(self, l): - if l == 'U': + if l == "U": in_chan, out_chan = self._setup_conv() return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2) - elif l == 'D': + elif l == "D": in_chan, out_chan = self._setup_conv() return nn.Conv3d(in_chan, out_chan, 2, stride=2) - elif l == 'C': + elif l == "C": in_chan, out_chan = self._setup_conv() - return nn.Conv3d(in_chan, out_chan, self.kernel_size, - stride=self.stride) - elif l == 'B': + return nn.Conv3d(in_chan, out_chan, self.kernel_size, stride=self.stride) + elif l == "B": return nn.BatchNorm3d(self.norm_chan) - #return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True) - #return nn.InstanceNorm3d(self.norm_chan) - elif l == 'A': + # return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True) + # return nn.InstanceNorm3d(self.norm_chan) + elif l == "A": return nn.LeakyReLU() else: - raise ValueError('layer type {} not supported'.format(l)) + raise ValueError("layer type {} not supported".format(l)) def _setup_conv(self): self.idx_conv += 1 @@ -88,13 +89,22 @@ class ResBlock(ConvBlock): See `ConvBlock` for `seq` types. 
""" - def __init__(self, in_chan, out_chan=None, mid_chan=None, - kernel_size=3, stride=1, seq='CBACBA', last_act=None): + + def __init__( + self, + in_chan, + out_chan=None, + mid_chan=None, + kernel_size=3, + stride=1, + seq="CBACBA", + last_act=None, + ): if last_act is None: - last_act = seq[-1] == 'A' - elif last_act and seq[-1] != 'A': + last_act = seq[-1] == "A" + elif last_act and seq[-1] != "A": warnings.warn( - 'Disabling last_act without trailing activation in seq', + "Disabling last_act without trailing activation in seq", RuntimeWarning, ) last_act = False @@ -102,8 +112,14 @@ class ResBlock(ConvBlock): if last_act: seq = seq[:-1] - super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, - kernel_size=kernel_size, stride=stride, seq=seq) + super().__init__( + in_chan, + out_chan=out_chan, + mid_chan=mid_chan, + kernel_size=kernel_size, + stride=stride, + seq=seq, + ) if last_act: self.act = nn.LeakyReLU() @@ -115,9 +131,10 @@ class ResBlock(ConvBlock): else: self.skip = nn.Conv3d(in_chan, out_chan, 1) - if 'U' in seq or 'D' in seq: - raise NotImplementedError('upsample and downsample layers ' - 'not supported yet') + if "U" in seq or "D" in seq: + raise NotImplementedError( + "upsample and downsample layers " "not supported yet" + ) def forward(self, x): y = x diff --git a/map2map/models/dice.py b/map2map/models/dice.py index 84e2d27..21e1c38 100644 --- a/map2map/models/dice.py +++ b/map2map/models/dice.py @@ -2,7 +2,7 @@ import torch.nn as nn class DiceLoss(nn.Module): - def __init__(self, eps=0.): + def __init__(self, eps=0.0): super().__init__() self.eps = eps @@ -10,7 +10,7 @@ class DiceLoss(nn.Module): return dice_loss(input, target, self.eps) -def dice_loss(input, target, eps=0.): +def dice_loss(input, target, eps=0.0): input = input.view(-1) target = target.view(-1) diff --git a/map2map/models/instance_noise.py b/map2map/models/instance_noise.py index 897d1ed..2bf7927 100644 --- a/map2map/models/instance_noise.py +++ b/map2map/models/instance_noise.py @@ -1,10 +1,11 @@ import torch + class InstanceNoise: - """Instance noise, with a linear decaying schedule - """ + """Instance noise, with a linear decaying schedule""" + def __init__(self, init_std, batches): - assert init_std >= 0, 'Noise std cannot be negative' + assert init_std >= 0, "Noise std cannot be negative" self.init_std = init_std self._std = init_std self.batches = batches diff --git a/map2map/models/lag2eul.py b/map2map/models/lag2eul.py index eaaec9c..8d5a2c0 100644 --- a/map2map/models/lag2eul.py +++ b/map2map/models/lag2eul.py @@ -5,17 +5,18 @@ from ..data.norms.cosmology import D def lag2eul( - dis, - val=1.0, - eul_scale_factor=1, - eul_pad=0, - rm_dis_mean=True, - periodic=False, - z=0.0, - dis_std=6.0, - boxsize=1000., - meshsize=512, - **kwargs): + dis, + val=1.0, + eul_scale_factor=1, + eul_pad=0, + rm_dis_mean=True, + periodic=False, + z=0.0, + dis_std=6.0, + boxsize=1000.0, + meshsize=512, + **kwargs, +): """Transform fields from Lagrangian description to Eulerian description Only works for 3d fields, output same mesh size as input. 
@@ -48,20 +49,19 @@ def lag2eul( if isinstance(val, (float, torch.Tensor)): val = [val] if len(dis) != len(val) and len(dis) != 1 and len(val) != 1: - raise ValueError('dis-val field mismatch') + raise ValueError("dis-val field mismatch") if any(d.dim() != 5 for d in dis): - raise NotImplementedError('only support 3d fields for now') + raise NotImplementedError("only support 3d fields for now") if any(d.shape[1] != 3 for d in dis): - raise ValueError('only support 3d displacement fields') + raise ValueError("only support 3d displacement fields") # common mean displacement of all inputs # if removed, fewer particles go outside of the box # common for all inputs so outputs are comparable in the same coords d_mean = 0 if rm_dis_mean: - d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) - for d in dis) / len(dis) + d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) for d in dis) / len(dis) out = [] if len(dis) == 1 and len(val) != 1: @@ -85,12 +85,15 @@ def lag2eul( pos = (d - d_mean) * dis_norm del d - pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor, - dtype=dtype, device=device)[:, None, None] - pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor, - dtype=dtype, device=device)[:, None] - pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor, - dtype=dtype, device=device) + pos[:, 0] += torch.arange( + 0, DHW[0] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device + )[:, None, None] + pos[:, 1] += torch.arange( + 0, DHW[1] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device + )[:, None] + pos[:, 2] += torch.arange( + 0, DHW[2] - 2 * eul_pad, eul_scale_factor, dtype=dtype, device=device + ) pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors @@ -118,8 +121,7 @@ def lag2eul( if periodic: torch.remainder(tgtpos[n], bounds, out=tgtpos[n]) - ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1] - ) * DHW[2] + tgtpos[n, 2] + ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]) * DHW[2] + tgtpos[n, 2] src = v[n] if not periodic: diff --git a/map2map/models/narrow.py b/map2map/models/narrow.py index eb1187a..3ccdb91 100644 --- a/map2map/models/narrow.py +++ b/map2map/models/narrow.py @@ -3,8 +3,7 @@ import torch.nn as nn def narrow_by(a, c): - """Narrow a by size c symmetrically on all edges. 
- """ + """Narrow a by size c symmetrically on all edges.""" ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2) return a[ind] diff --git a/map2map/models/patchgan.py b/map2map/models/patchgan.py index af24832..3721844 100644 --- a/map2map/models/patchgan.py +++ b/map2map/models/patchgan.py @@ -8,10 +8,10 @@ class PatchGAN(nn.Module): super().__init__() self.convs = nn.Sequential( - ConvBlock(in_chan, 32, seq='CA'), - ConvBlock(32, 64, seq='CBA'), - ConvBlock(64, 128, seq='CBA'), - ConvBlock(128, out_chan, seq='C'), + ConvBlock(in_chan, 32, seq="CA"), + ConvBlock(32, 64, seq="CBA"), + ConvBlock(64, 128, seq="CBA"), + ConvBlock(128, out_chan, seq="C"), ) def forward(self, x): @@ -19,23 +19,20 @@ class PatchGAN(nn.Module): class PatchGAN42(nn.Module): - """PatchGAN similar to the one in pix2pix - """ + """PatchGAN similar to the one in pix2pix""" + def __init__(self, in_chan, out_chan=1, **kwargs): super().__init__() self.convs = nn.Sequential( nn.Conv3d(in_chan, 64, 4, stride=2), nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(64, 128, 4, stride=2), nn.BatchNorm3d(128), nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(128, 256, 4, stride=2), nn.BatchNorm3d(256), nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(256, out_chan, 1), ) diff --git a/map2map/models/power.py b/map2map/models/power.py index 19e32b7..f3a6a59 100644 --- a/map2map/models/power.py +++ b/map2map/models/power.py @@ -26,8 +26,7 @@ def power(x: torch.Tensor): P = P.sum(dim=0) del x - k = [torch.arange(d, dtype=torch.float32, device=P.device) - for d in P.shape] + k = [torch.arange(d, dtype=torch.float32, device=P.device) for d in P.shape] k = [j - len(j) * (j > len(j) // 2) for j in k[:-1]] + [k[-1]] k = torch.meshgrid(*k) k = torch.stack(k, dim=0) @@ -49,9 +48,9 @@ def power(x: torch.Tensor): del kbin # drop k=0 mode and cut at kmax (smallest Nyquist) - k = k[1:1+kmax] - P = P[1:1+kmax] - N = N[1:1+kmax] + k = k[1 : 1 + kmax] + P = P[1 : 1 + kmax] + N = N[1 : 1 + kmax] k /= N P /= N diff --git a/map2map/models/resample.py b/map2map/models/resample.py index 308fc80..d424096 100644 --- a/map2map/models/resample.py +++ b/map2map/models/resample.py @@ -5,12 +5,11 @@ from .narrow import narrow_by def resample(x, scale_factor, narrow=True): - modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'} + modes = {1: "linear", 2: "bilinear", 3: "trilinear"} ndim = x.dim() - 2 mode = modes[ndim] - x = F.interpolate(x, scale_factor=scale_factor, - mode=mode, align_corners=False) + x = F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=False) if scale_factor > 1 and narrow == True: edges = round(scale_factor) // 2 @@ -25,18 +24,20 @@ class Resampler(nn.Module): By default discard the inaccurate edges when upsampling. 
""" + def __init__(self, ndim, scale_factor, narrow=True): super().__init__() - modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'} + modes = {1: "linear", 2: "bilinear", 3: "trilinear"} self.mode = modes[ndim] self.scale_factor = scale_factor self.narrow = narrow def forward(self, x): - x = F.interpolate(x, scale_factor=self.scale_factor, - mode=self.mode, align_corners=False) + x = F.interpolate( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False + ) if self.scale_factor > 1 and self.narrow == True: edges = round(self.scale_factor) // 2 diff --git a/map2map/models/spectral_norm.py b/map2map/models/spectral_norm.py index b9b22b1..d5dd7b8 100644 --- a/map2map/models/spectral_norm.py +++ b/map2map/models/spectral_norm.py @@ -4,8 +4,18 @@ from torch.nn.utils import spectral_norm, remove_spectral_norm def add_spectral_norm(module): for name, child in module.named_children(): - if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, - nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + if isinstance( + child, + ( + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ), + ): setattr(module, name, spectral_norm(child)) else: add_spectral_norm(child) @@ -13,8 +23,18 @@ def add_spectral_norm(module): def rm_spectral_norm(module): for name, child in module.named_children(): - if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, - nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + if isinstance( + child, + ( + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ), + ): setattr(module, name, remove_spectral_norm(child)) else: rm_spectral_norm(child) diff --git a/map2map/models/srsgan.py b/map2map/models/srsgan.py index af8d9c9..1a09f09 100644 --- a/map2map/models/srsgan.py +++ b/map2map/models/srsgan.py @@ -7,9 +7,17 @@ from .resample import Resampler class G(nn.Module): - def __init__(self, in_chan, out_chan, scale_factor=16, - chan_base=512, chan_min=64, chan_max=512, cat_noise=False, - **kwargs): + def __init__( + self, + in_chan, + out_chan, + scale_factor=16, + chan_base=512, + chan_min=64, + chan_max=512, + cat_noise=False, + **kwargs, + ): super().__init__() self.scale_factor = scale_factor @@ -30,15 +38,14 @@ class G(nn.Module): self.blocks = nn.ModuleList() for b in range(num_blocks): - prev_chan, next_chan = chan(b), chan(b+1) - self.blocks.append( - HBlock(prev_chan, next_chan, out_chan, cat_noise)) + prev_chan, next_chan = chan(b), chan(b + 1) + self.blocks.append(HBlock(prev_chan, next_chan, out_chan, cat_noise)) def forward(self, x): y = x # direct upsampling from the input x = self.block0(x) - #y = None # no direct upsampling from the input + # y = None # no direct upsampling from the input for block in self.blocks: x, y = block(x, y) @@ -71,6 +78,7 @@ class HBlock(nn.Module): ----- next_size = 2 * prev_size - 6 """ + def __init__(self, prev_chan, next_chan, out_chan, cat_noise): super().__init__() @@ -81,7 +89,6 @@ class HBlock(nn.Module): self.upsample, nn.Conv3d(prev_chan + int(cat_noise), next_chan, 3), nn.LeakyReLU(0.2, True), - AddNoise(cat_noise, chan=next_chan), nn.Conv3d(next_chan + int(cat_noise), next_chan, 3), nn.LeakyReLU(0.2, True), @@ -114,6 +121,7 @@ class AddNoise(nn.Module): The number of channels `chan` should be 1 (StyleGAN2) or that of the input (StyleGAN). 
""" + def __init__(self, cat, chan=1): super().__init__() @@ -137,9 +145,16 @@ class AddNoise(nn.Module): class D(nn.Module): - def __init__(self, in_chan, out_chan, scale_factor=16, - chan_base=512, chan_min=64, chan_max=512, - **kwargs): + def __init__( + self, + in_chan, + out_chan, + scale_factor=16, + chan_base=512, + chan_min=64, + chan_max=512, + **kwargs, + ): super().__init__() self.scale_factor = scale_factor @@ -163,7 +178,7 @@ class D(nn.Module): self.blocks = nn.ModuleList() for b in reversed(range(num_blocks)): - prev_chan, next_chan = chan(b+1), chan(b) + prev_chan, next_chan = chan(b + 1), chan(b) self.blocks.append(ResBlock(prev_chan, next_chan)) self.block9 = nn.Sequential( @@ -192,13 +207,13 @@ class ResBlock(nn.Module): ----- next_size = (prev_size - 4) // 2 """ + def __init__(self, prev_chan, next_chan): super().__init__() self.conv = nn.Sequential( nn.Conv3d(prev_chan, prev_chan, 3), nn.LeakyReLU(0.2, True), - nn.Conv3d(prev_chan, next_chan, 3), nn.LeakyReLU(0.2, True), ) diff --git a/map2map/models/style.py b/map2map/models/style.py index 24a8349..48a24dc 100644 --- a/map2map/models/style.py +++ b/map2map/models/style.py @@ -9,6 +9,7 @@ class PixelNorm(nn.Module): See ProGAN, StyleGAN. """ + def __init__(self): super().__init__() @@ -23,6 +24,7 @@ class LinearElr(nn.Module): Useful at all if not for regularization(1706.05350)? """ + def __init__(self, in_size, out_size, bias=True, act=None): super().__init__() @@ -32,7 +34,7 @@ class LinearElr(nn.Module): if bias: self.bias = nn.Parameter(torch.zeros(out_size)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.act = act @@ -52,20 +54,20 @@ class ConvElr3d(nn.Module): Useful at all if not for regularization(1706.05350)? """ - def __init__(self, in_chan, out_chan, kernel_size, - stride=1, padding=0, bias=True): + + def __init__(self, in_chan, out_chan, kernel_size, stride=1, padding=0, bias=True): super().__init__() self.weight = nn.Parameter( torch.randn(out_chan, in_chan, *(kernel_size,) * 3), ) - fan_in = in_chan * kernel_size ** 3 + fan_in = in_chan * kernel_size**3 self.wnorm = 1 / math.sqrt(fan_in) if bias: self.bias = nn.Parameter(torch.zeros(out_chan)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.stride = stride self.padding = padding @@ -87,38 +89,49 @@ class ConvStyled3d(nn.Module): Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`. 
""" - def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1, - bias=True, resample=None): + + def __init__( + self, + style_size, + in_chan, + out_chan, + kernel_size=3, + stride=1, + bias=True, + resample=None, + ): super().__init__() self.style_weight = nn.Parameter(torch.empty(in_chan, style_size)) - nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5), - mode='fan_in', nonlinearity='leaky_relu') + nn.init.kaiming_uniform_( + self.style_weight, a=math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu" + ) self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1 if resample is None: K3 = (kernel_size,) * 3 self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3)) self.stride = stride self.conv = F.conv3d - elif resample == 'U': + elif resample == "U": K3 = (2,) * 3 # NOTE not clear to me why convtranspose have channels swapped self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3)) self.stride = 2 self.conv = F.conv_transpose3d - elif resample == 'D': + elif resample == "D": K3 = (2,) * 3 self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3)) self.stride = 2 self.conv = F.conv3d else: - raise ValueError('resample type {} not supported'.format(resample)) + raise ValueError("resample type {} not supported".format(resample)) self.resample = resample nn.init.kaiming_uniform_( - self.weight, a=math.sqrt(5), - mode='fan_in', # effectively 'fan_out' for 'D' - nonlinearity='leaky_relu', + self.weight, + a=math.sqrt(5), + mode="fan_in", # effectively 'fan_out' for 'D' + nonlinearity="leaky_relu", ) if bias: @@ -127,12 +140,12 @@ class ConvStyled3d(nn.Module): bound = 1 / math.sqrt(fan_in) nn.init.uniform_(self.bias, -bound, bound) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, x, s, eps=1e-8): N, Cin, *DHWin = x.shape C0, C1, *K3 = self.weight.shape - if self.resample == 'U': + if self.resample == "U": Cin, Cout = C0, C1 else: Cout, Cin = C0, C1 @@ -140,14 +153,14 @@ class ConvStyled3d(nn.Module): s = F.linear(s, self.style_weight, bias=self.style_bias) # modulation - if self.resample == 'U': + if self.resample == "U": s = s.reshape(N, Cin, 1, 1, 1, 1) else: s = s.reshape(N, 1, Cin, 1, 1, 1) w = self.weight * s # demodulation - if self.resample == 'U': + if self.resample == "U": fan_in_dim = (1, 3, 4, 5) else: fan_in_dim = (2, 3, 4, 5) @@ -161,18 +174,22 @@ class ConvStyled3d(nn.Module): return x -class BatchNormStyled3d(nn.BatchNorm3d) : - """ Trivially does standard batch normalization, but accepts second argument + +class BatchNormStyled3d(nn.BatchNorm3d): + """Trivially does standard batch normalization, but accepts second argument for style array that is not used """ + def forward(self, x, s): return super().forward(x) + class LeakyReLUStyled(nn.LeakyReLU): - """ Trivially evaluates standard leaky ReLU, but accepts second argument + """Trivially evaluates standard leaky ReLU, but accepts second argument for sytle array that is not used """ + def forward(self, x, s): return super().forward(x) diff --git a/map2map/models/styled_conv.py b/map2map/models/styled_conv.py index 2043537..fe2e8a7 100644 --- a/map2map/models/styled_conv.py +++ b/map2map/models/styled_conv.py @@ -16,8 +16,17 @@ class ConvStyledBlock(nn.Module): 'U': upsampling transposed convolution of kernel size 2 and stride 2 'D': downsampling convolution of kernel size 2 and stride 2 """ - def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None, - kernel_size=3, stride=1, seq='CBA'): + + def __init__( + 
self, + style_size, + in_chan, + out_chan=None, + mid_chan=None, + kernel_size=3, + stride=1, + seq="CBA", + ): super().__init__() if out_chan is None: @@ -33,31 +42,34 @@ class ConvStyledBlock(nn.Module): self.norm_chan = in_chan self.idx_conv = 0 - self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']]) + self.num_conv = sum([seq.count(l) for l in ["U", "D", "C"]]) layers = [self._get_layer(l) for l in seq] self.convs = nn.ModuleList(layers) def _get_layer(self, l): - if l == 'U': + if l == "U": in_chan, out_chan = self._setup_conv() - return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2, - resample = 'U') - elif l == 'D': + return ConvStyled3d( + self.style_size, in_chan, out_chan, 2, stride=2, resample="U" + ) + elif l == "D": in_chan, out_chan = self._setup_conv() - return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2, - resample = 'D') - elif l == 'C': + return ConvStyled3d( + self.style_size, in_chan, out_chan, 2, stride=2, resample="D" + ) + elif l == "C": in_chan, out_chan = self._setup_conv() - return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size, - stride=self.stride) - elif l == 'B': + return ConvStyled3d( + self.style_size, in_chan, out_chan, self.kernel_size, stride=self.stride + ) + elif l == "B": return BatchNormStyled3d(self.norm_chan) - elif l == 'A': + elif l == "A": return LeakyReLUStyled() else: - raise ValueError('layer type {} not supported'.format(l)) + raise ValueError("layer type {} not supported".format(l)) def _setup_conv(self): self.idx_conv += 1 @@ -92,13 +104,23 @@ class ResStyledBlock(ConvStyledBlock): See `ConvStyledBlock` for `seq` types. """ - def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None, - kernel_size=3, stride=1, seq='CBACBA', last_act=None): + + def __init__( + self, + style_size, + in_chan, + out_chan=None, + mid_chan=None, + kernel_size=3, + stride=1, + seq="CBACBA", + last_act=None, + ): if last_act is None: - last_act = seq[-1] == 'A' - elif last_act and seq[-1] != 'A': + last_act = seq[-1] == "A" + elif last_act and seq[-1] != "A": warnings.warn( - 'Disabling last_act without trailing activation in seq', + "Disabling last_act without trailing activation in seq", RuntimeWarning, ) last_act = False @@ -106,8 +128,15 @@ class ResStyledBlock(ConvStyledBlock): if last_act: seq = seq[:-1] - super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, - kernel_size=kernel_size, stride=stride, seq=seq) + super().__init__( + style_size, + in_chan, + out_chan=out_chan, + mid_chan=mid_chan, + kernel_size=kernel_size, + stride=stride, + seq=seq, + ) if last_act: self.act = LeakyReLUStyled() @@ -119,9 +148,10 @@ class ResStyledBlock(ConvStyledBlock): else: self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1) - if 'U' in seq or 'D' in seq: - raise NotImplementedError('upsample and downsample layers ' - 'not supported yet') + if "U" in seq or "D" in seq: + raise NotImplementedError( + "upsample and downsample layers " "not supported yet" + ) def forward(self, x, s): y = x diff --git a/map2map/models/styled_vnet.py b/map2map/models/styled_vnet.py index 151aa59..990beeb 100644 --- a/map2map/models/styled_vnet.py +++ b/map2map/models/styled_vnet.py @@ -15,17 +15,17 @@ class StyledVNet(nn.Module): # activate non-identity skip connection in residual block # by explicitly setting out_chan - self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='CACBA') - self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA') - self.conv_l1 = ResStyledBlock(style_size, 64, 64, 
seq='CBACBA') - self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA') + self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq="CACBA") + self.down_l0 = ConvStyledBlock(style_size, 64, seq="DBA") + self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq="CBACBA") + self.down_l1 = ConvStyledBlock(style_size, 64, seq="DBA") - self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA') + self.conv_c = ResStyledBlock(style_size, 64, 64, seq="CBACBA") - self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA') - self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA') - self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA') - self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC') + self.up_r1 = ConvStyledBlock(style_size, 64, seq="UBA") + self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq="CBACBA") + self.up_r0 = ConvStyledBlock(style_size, 64, seq="UBA") + self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq="CAC") if bypass is None: self.bypass = in_chan == out_chan diff --git a/map2map/models/unet.py b/map2map/models/unet.py index bf78221..3b3eed1 100644 --- a/map2map/models/unet.py +++ b/map2map/models/unet.py @@ -19,17 +19,17 @@ class UNet(nn.Module): """ super().__init__() - self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA') - self.down_l0 = ConvBlock(64, seq='DBA') - self.conv_l1 = ConvBlock(64, seq='CBACBA') - self.down_l1 = ConvBlock(64, seq='DBA') + self.conv_l0 = ConvBlock(in_chan, 64, seq="CACBA") + self.down_l0 = ConvBlock(64, seq="DBA") + self.conv_l1 = ConvBlock(64, seq="CBACBA") + self.down_l1 = ConvBlock(64, seq="DBA") - self.conv_c = ConvBlock(64, seq='CBACBA') + self.conv_c = ConvBlock(64, seq="CBACBA") - self.up_r1 = ConvBlock(64, seq='UBA') - self.conv_r1 = ConvBlock(128, 64, seq='CBACBA') - self.up_r0 = ConvBlock(64, seq='UBA') - self.conv_r0 = ConvBlock(128, out_chan, seq='CAC') + self.up_r1 = ConvBlock(64, seq="UBA") + self.conv_r1 = ConvBlock(128, 64, seq="CBACBA") + self.up_r0 = ConvBlock(64, seq="UBA") + self.conv_r0 = ConvBlock(128, out_chan, seq="CAC") self.bypass = in_chan == out_chan diff --git a/map2map/models/vnet.py b/map2map/models/vnet.py index a2f1fee..0b2f4eb 100644 --- a/map2map/models/vnet.py +++ b/map2map/models/vnet.py @@ -23,17 +23,17 @@ class VNet(nn.Module): # activate non-identity skip connection in residual block # by explicitly setting out_chan - self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA') - self.down_l0 = ConvBlock(64, seq='DBA') - self.conv_l1 = ResBlock(64, 64, seq='CBACBA') - self.down_l1 = ConvBlock(64, seq='DBA') + self.conv_l0 = ResBlock(in_chan, 64, seq="CACBA") + self.down_l0 = ConvBlock(64, seq="DBA") + self.conv_l1 = ResBlock(64, 64, seq="CBACBA") + self.down_l1 = ConvBlock(64, seq="DBA") - self.conv_c = ResBlock(64, 64, seq='CBACBA') + self.conv_c = ResBlock(64, 64, seq="CBACBA") - self.up_r1 = ConvBlock(64, seq='UBA') - self.conv_r1 = ResBlock(128, 64, seq='CBACBA') - self.up_r0 = ConvBlock(64, seq='UBA') - self.conv_r0 = ResBlock(128, out_chan, seq='CAC') + self.up_r1 = ConvBlock(64, seq="UBA") + self.conv_r1 = ResBlock(128, 64, seq="CBACBA") + self.up_r0 = ConvBlock(64, seq="UBA") + self.conv_r0 = ResBlock(128, out_chan, seq="CAC") if bypass is None: self.bypass = in_chan == out_chan diff --git a/map2map/test.py b/map2map/test.py index 1d8e3f7..eb4aacc 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -16,21 +16,21 @@ from .utils import import_attr, load_model_state_dict def test(args): if torch.cuda.is_available(): if torch.cuda.device_count() > 1: - warnings.warn('Not parallelized 
but given more than 1 GPUs')
+            warnings.warn("Not parallelized but given more than 1 GPU")

-        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-        device = torch.device('cuda', 0)
+        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+        device = torch.device("cuda", 0)
         torch.backends.cudnn.benchmark = True
     else:  # CPU multithreading
-        device = torch.device('cpu')
+        device = torch.device("cpu")

         if args.num_threads is None:
-            args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE'])
+            args.num_threads = int(os.environ["SLURM_CPUS_ON_NODE"])

         torch.set_num_threads(args.num_threads)

-    print('pytorch {}'.format(torch.__version__))
+    print("pytorch {}".format(torch.__version__))
     pprint(vars(args))
     sys.stdout.flush()

@@ -67,26 +67,33 @@ def test(args):
     out_chan = test_dataset.tgt_chan

     model = import_attr(args.model, models, callback_at=args.callback_at)
-    model = model(style_size, sum(in_chan), sum(out_chan),
-                  scale_factor=args.scale_factor, **args.misc_kwargs)
+    model = model(
+        style_size,
+        sum(in_chan),
+        sum(out_chan),
+        scale_factor=args.scale_factor,
+        **args.misc_kwargs,
+    )
     model.to(device)

-    criterion = import_attr(args.criterion, torch.nn, models,
-                            callback_at=args.callback_at)
+    criterion = import_attr(
+        args.criterion, torch.nn, models, callback_at=args.callback_at
+    )
     criterion = criterion()
     criterion.to(device)

     state = torch.load(args.load_state, map_location=device)
-    load_model_state_dict(model, state['model'], strict=args.load_state_strict)
-    print('model state at epoch {} loaded from {}'.format(
-        state['epoch'], args.load_state))
+    load_model_state_dict(model, state["model"], strict=args.load_state_strict)
+    print(
+        "model state at epoch {} loaded from {}".format(state["epoch"], args.load_state)
+    )
     del state

     model.eval()

     with torch.no_grad():
         for i, data in enumerate(test_loader):
-            style, input, target = data['style'], data['input'], data['target']
+            style, input, target = data["style"], data["input"], data["target"]

             style = style.to(device, non_blocking=True)
             input = input.to(device, non_blocking=True)
@@ -94,21 +101,21 @@ def test(args):
             output = model(input, style)

             if i < 5:
-                print('##### sample :', i)
-                print('style shape :', style.shape)
-                print('input shape :', input.shape)
-                print('output shape :', output.shape)
-                print('target shape :', target.shape)
+                print("##### sample :", i)
+                print("style shape :", style.shape)
+                print("input shape :", input.shape)
+                print("output shape :", output.shape)
+                print("target shape :", target.shape)

             input, output, target = narrow_cast(input, output, target)
             if i < 5:
-                print('narrowed shape :', output.shape, flush=True)
+                print("narrowed shape :", output.shape, flush=True)

             loss = criterion(output, target)
-            print('sample {} loss: {}'.format(i, loss.item()))
+            print("sample {} loss: {}".format(i, loss.item()))

-            #if args.in_norms is not None:
+            # if args.in_norms is not None:
             #    start = 0
             #    for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
             #        norm = import_attr(norm, norms, callback_at=args.callback_at)
@@ -119,12 +126,11 @@
                 for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
                     norm = import_attr(norm, norms, callback_at=args.callback_at)
                     norm(output[:, start:stop], undo=True, **args.misc_kwargs)
-                    #norm(target[:, start:stop], undo=True, **args.misc_kwargs)
+                    # norm(target[:, start:stop], undo=True, **args.misc_kwargs)
                     start = stop

-            #test_dataset.assemble('_in', in_chan, input,
+            # test_dataset.assemble('_in', in_chan, input,
             #                      data['input_relpath'])
-            test_dataset.assemble('_out', out_chan, output,
-
data['target_relpath']) - #test_dataset.assemble('_tgt', out_chan, target, + test_dataset.assemble("_out", out_chan, output, data["target_relpath"]) + # test_dataset.assemble('_tgt', out_chan, target, # data['target_relpath']) diff --git a/map2map/train.py b/map2map/train.py index d4fb380..867049b 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -18,34 +18,34 @@ from .models import narrow_cast, resample from .utils import import_attr, load_model_state_dict, plt_slices, plt_power -ckpt_link = 'checkpoint.pt' +ckpt_link = "checkpoint.pt" def node_worker(args): - if 'SLURM_STEP_NUM_NODES' in os.environ: - args.nodes = int(os.environ['SLURM_STEP_NUM_NODES']) - elif 'SLURM_JOB_NUM_NODES' in os.environ: - args.nodes = int(os.environ['SLURM_JOB_NUM_NODES']) + if "SLURM_STEP_NUM_NODES" in os.environ: + args.nodes = int(os.environ["SLURM_STEP_NUM_NODES"]) + elif "SLURM_JOB_NUM_NODES" in os.environ: + args.nodes = int(os.environ["SLURM_JOB_NUM_NODES"]) else: - raise KeyError('missing node counts in slurm env') + raise KeyError("missing node counts in slurm env") args.gpus_per_node = torch.cuda.device_count() args.world_size = args.nodes * args.gpus_per_node - node = int(os.environ['SLURM_NODEID']) + node = int(os.environ["SLURM_NODEID"]) if args.gpus_per_node < 1: - raise RuntimeError('GPU not found on node {}'.format(node)) + raise RuntimeError("GPU not found on node {}".format(node)) spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node) def gpu_worker(local_rank, node, args): - #device = torch.device('cuda', local_rank) - #torch.cuda.device(device) # env var recommended over this + # device = torch.device('cuda', local_rank) + # torch.cuda.device(device) # env var recommended over this - os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' - os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank) - device = torch.device('cuda', 0) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank) + device = torch.device("cuda", 0) rank = args.gpus_per_node * node + local_rank @@ -53,7 +53,7 @@ def gpu_worker(local_rank, node, args): # Note DDP broadcasts initial model states from rank 0 torch.manual_seed(args.seed + rank) # good practice to disable cudnn.benchmark if enabling cudnn.deterministic - #torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.deterministic = True dist_init(rank, args) @@ -77,9 +77,12 @@ def gpu_worker(local_rank, node, args): scale_factor=args.scale_factor, **args.misc_kwargs, ) - train_sampler = DistFieldSampler(train_dataset, shuffle=True, - div_data=args.div_data, - div_shuffle_dist=args.div_shuffle_dist) + train_sampler = DistFieldSampler( + train_dataset, + shuffle=True, + div_data=args.div_data, + div_shuffle_dist=args.div_shuffle_dist, + ) train_loader = DataLoader( train_dataset, batch_size=args.batch_size, @@ -110,9 +113,12 @@ def gpu_worker(local_rank, node, args): scale_factor=args.scale_factor, **args.misc_kwargs, ) - val_sampler = DistFieldSampler(val_dataset, shuffle=False, - div_data=args.div_data, - div_shuffle_dist=args.div_shuffle_dist) + val_sampler = DistFieldSampler( + val_dataset, + shuffle=False, + div_data=args.div_data, + div_shuffle_dist=args.div_shuffle_dist, + ) val_loader = DataLoader( val_dataset, batch_size=args.batch_size, @@ -127,14 +133,19 @@ def gpu_worker(local_rank, node, args): args.out_chan = train_dataset.tgt_chan model = import_attr(args.model, models, callback_at=args.callback_at) - model = model(args.style_size, sum(args.in_chan), sum(args.out_chan), - 
scale_factor=args.scale_factor, **args.misc_kwargs) + model = model( + args.style_size, + sum(args.in_chan), + sum(args.out_chan), + scale_factor=args.scale_factor, + **args.misc_kwargs, + ) model.to(device) - model = DistributedDataParallel(model, device_ids=[device], - process_group=dist.new_group()) + model = DistributedDataParallel( + model, device_ids=[device], process_group=dist.new_group() + ) - criterion = import_attr(args.criterion, nn, models, - callback_at=args.callback_at) + criterion = import_attr(args.criterion, nn, models, callback_at=args.callback_at) criterion = criterion() criterion.to(device) @@ -144,11 +155,13 @@ def gpu_worker(local_rank, node, args): lr=args.lr, **args.optimizer_args, ) - scheduler = optim.lr_scheduler.ReduceLROnPlateau( - optimizer, **args.scheduler_args) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **args.scheduler_args) - if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link) - or not args.load_state): + if ( + args.load_state == ckpt_link + and not os.path.isfile(ckpt_link) + or not args.load_state + ): if args.init_weight_std is not None: model.apply(init_weights) @@ -159,23 +172,28 @@ def gpu_worker(local_rank, node, args): else: state = torch.load(args.load_state, map_location=device) - start_epoch = state['epoch'] + start_epoch = state["epoch"] - load_model_state_dict(model.module, state['model'], - strict=args.load_state_strict) + load_model_state_dict( + model.module, state["model"], strict=args.load_state_strict + ) - if 'optimizer' in state: - optimizer.load_state_dict(state['optimizer']) - if 'scheduler' in state: - scheduler.load_state_dict(state['scheduler']) + if "optimizer" in state: + optimizer.load_state_dict(state["optimizer"]) + if "scheduler" in state: + scheduler.load_state_dict(state["scheduler"]) - torch.set_rng_state(state['rng'].cpu()) # move rng state back + torch.set_rng_state(state["rng"].cpu()) # move rng state back if rank == 0: - min_loss = state['min_loss'] + min_loss = state["min_loss"] - print('state at epoch {} loaded from {}'.format( - state['epoch'], args.load_state), flush=True) + print( + "state at epoch {} loaded from {}".format( + state["epoch"], args.load_state + ), + flush=True, + ) del state @@ -189,21 +207,31 @@ def gpu_worker(local_rank, node, args): logger = SummaryWriter() if rank == 0: - print('pytorch {}'.format(torch.__version__)) + print("pytorch {}".format(torch.__version__)) pprint(vars(args)) sys.stdout.flush() for epoch in range(start_epoch, args.epochs): train_sampler.set_epoch(epoch) - train_loss = train(epoch, train_loader, model, criterion, - optimizer, scheduler, logger, device, args) + train_loss = train( + epoch, + train_loader, + model, + criterion, + optimizer, + scheduler, + logger, + device, + args, + ) epoch_loss = train_loss if args.val: - val_loss = validate(epoch, val_loader, model, criterion, - logger, device, args) - #epoch_loss = val_loss + val_loss = validate( + epoch, val_loader, model, criterion, logger, device, args + ) + # epoch_loss = val_loss if args.reduce_lr_on_plateau: scheduler.step(epoch_loss[2]) @@ -215,27 +243,26 @@ def gpu_worker(local_rank, node, args): min_loss = epoch_loss[2] state = { - 'epoch': epoch + 1, - 'model': model.module.state_dict(), - 'optimizer': optimizer.state_dict(), - 'scheduler': scheduler.state_dict(), - 'rng': torch.get_rng_state(), - 'min_loss': min_loss, + "epoch": epoch + 1, + "model": model.module.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "rng": 
torch.get_rng_state(), + "min_loss": min_loss, } - state_file = 'state_{}.pt'.format(epoch + 1) + state_file = "state_{}.pt".format(epoch + 1) torch.save(state, state_file) del state - tmp_link = '{}.pt'.format(time.time()) + tmp_link = "{}.pt".format(time.time()) os.symlink(state_file, tmp_link) # workaround to overwrite os.rename(tmp_link, ckpt_link) dist.destroy_process_group() -def train(epoch, loader, model, criterion, - optimizer, scheduler, logger, device, args): +def train(epoch, loader, model, criterion, optimizer, scheduler, logger, device, args): model.train() rank = dist.get_rank() @@ -246,7 +273,7 @@ def train(epoch, loader, model, criterion, for i, data in enumerate(loader): batch = epoch * len(loader) + i + 1 - style, input, target = data['style'], data['input'], data['target'] + style, input, target = data["style"], data["input"], data["target"] style = style.to(device, non_blocking=True) input = input.to(device, non_blocking=True) @@ -254,23 +281,21 @@ def train(epoch, loader, model, criterion, output = model(input, style) if batch <= 5 and rank == 0: - print('##### batch :', batch) - print('style shape :', style.shape) - print('input shape :', input.shape) - print('output shape :', output.shape) - print('target shape :', target.shape) + print("##### batch :", batch) + print("style shape :", style.shape) + print("input shape :", input.shape) + print("output shape :", output.shape) + print("target shape :", target.shape) - if (hasattr(model.module, 'scale_factor') - and model.module.scale_factor != 1): + if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1: input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) if batch <= 5 and rank == 0: - print('narrowed shape :', output.shape) + print("narrowed shape :", output.shape) loss = criterion(output, target) epoch_loss[0] += loss.detach() - optimizer.zero_grad() loss.backward() optimizer.step() @@ -280,43 +305,45 @@ def train(epoch, loader, model, criterion, dist.all_reduce(loss) loss /= world_size if rank == 0: - logger.add_scalar('loss/batch/train', loss.item(), - global_step=batch) + logger.add_scalar("loss/batch/train", loss.item(), global_step=batch) - logger.add_scalar('grad/first', grads[0], global_step=batch) - logger.add_scalar('grad/last', grads[-1], global_step=batch) + logger.add_scalar("grad/first", grads[0], global_step=batch) + logger.add_scalar("grad/last", grads[-1], global_step=batch) dist.all_reduce(epoch_loss) epoch_loss /= len(loader) * world_size if rank == 0: - logger.add_scalar('loss/epoch/train', epoch_loss[0], - global_step=epoch+1) + logger.add_scalar("loss/epoch/train", epoch_loss[0], global_step=epoch + 1) skip_chan = 0 fig = plt_slices( - input[-1], output[-1, skip_chan:], target[-1, skip_chan:], + input[-1], + output[-1, skip_chan:], + target[-1, skip_chan:], output[-1, skip_chan:] - target[-1, skip_chan:], - title=['in', 'out', 'tgt', 'out - tgt'], + title=["in", "out", "tgt", "out - tgt"], **args.misc_kwargs, ) - logger.add_figure('fig/train', fig, global_step=epoch+1) + logger.add_figure("fig/train", fig, global_step=epoch + 1) fig.clf() fig = plt_power( - input, output[:, skip_chan:], target[:, skip_chan:], - label=['in', 'out', 'tgt'], + input, + output[:, skip_chan:], + target[:, skip_chan:], + label=["in", "out", "tgt"], **args.misc_kwargs, ) - logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) + logger.add_figure("fig/train/power/lag", fig, global_step=epoch + 1) fig.clf() - #fig = 
plt_power(1.0, + # fig = plt_power(1.0, # dis=[input, output[:, skip_chan:], target[:, skip_chan:]], # label=['in', 'out', 'tgt'], # **args.misc_kwargs, - #) - #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) - #fig.clf() + # ) + # logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) + # fig.clf() return epoch_loss @@ -331,7 +358,7 @@ def validate(epoch, loader, model, criterion, logger, device, args): with torch.no_grad(): for data in loader: - style, input, target = data['style'], data['input'], data['target'] + style, input, target = data["style"], data["input"], data["target"] style = style.to(device, non_blocking=True) input = input.to(device, non_blocking=True) @@ -339,8 +366,7 @@ def validate(epoch, loader, model, criterion, logger, device, args): output = model(input, style) - if (hasattr(model.module, 'scale_factor') - and model.module.scale_factor != 1): + if hasattr(model.module, "scale_factor") and model.module.scale_factor != 1: input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) @@ -350,40 +376,43 @@ def validate(epoch, loader, model, criterion, logger, device, args): dist.all_reduce(epoch_loss) epoch_loss /= len(loader) * world_size if rank == 0: - logger.add_scalar('loss/epoch/val', epoch_loss[0], - global_step=epoch+1) + logger.add_scalar("loss/epoch/val", epoch_loss[0], global_step=epoch + 1) skip_chan = 0 fig = plt_slices( - input[-1], output[-1, skip_chan:], target[-1, skip_chan:], + input[-1], + output[-1, skip_chan:], + target[-1, skip_chan:], output[-1, skip_chan:] - target[-1, skip_chan:], - title=['in', 'out', 'tgt', 'out - tgt'], + title=["in", "out", "tgt", "out - tgt"], **args.misc_kwargs, ) - logger.add_figure('fig/val', fig, global_step=epoch+1) + logger.add_figure("fig/val", fig, global_step=epoch + 1) fig.clf() fig = plt_power( - input, output[:, skip_chan:], target[:, skip_chan:], - label=['in', 'out', 'tgt'], + input, + output[:, skip_chan:], + target[:, skip_chan:], + label=["in", "out", "tgt"], **args.misc_kwargs, ) - logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) + logger.add_figure("fig/val/power/lag", fig, global_step=epoch + 1) fig.clf() - #fig = plt_power(1.0, + # fig = plt_power(1.0, # dis=[input, output[:, skip_chan:], target[:, skip_chan:]], # label=['in', 'out', 'tgt'], # **args.misc_kwargs, - #) - #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) - #fig.clf() + # ) + # logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) + # fig.clf() return epoch_loss def dist_init(rank, args): - dist_file = 'dist_addr' + dist_file = "dist_addr" if rank == 0: addr = socket.gethostname() @@ -393,15 +422,15 @@ def dist_init(rank, args): s.bind((addr, 0)) _, port = s.getsockname() - args.dist_addr = 'tcp://{}:{}'.format(addr, port) + args.dist_addr = "tcp://{}:{}".format(addr, port) - with open(dist_file, mode='w') as f: + with open(dist_file, mode="w") as f: f.write(args.dist_addr) else: while not os.path.exists(dist_file): time.sleep(1) - with open(dist_file, mode='r') as f: + with open(dist_file, mode="r") as f: args.dist_addr = f.read() dist.init_process_group( @@ -417,19 +446,40 @@ def dist_init(rank, args): def init_weights(m, args): - if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, - nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + if isinstance( + m, + ( + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ), + 
): m.weight.data.normal_(0.0, args.init_weight_std) - elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, - nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm, - nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)): + elif isinstance( + m, + ( + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + nn.LayerNorm, + nn.GroupNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + ), + ): if m.affine: # NOTE: dispersion from DCGAN, why? m.weight.data.normal_(1.0, args.init_weight_std) m.bias.data.fill_(0) -def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False): +def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False): for param in module.parameters(): param.requires_grad = requires_grad @@ -444,8 +494,7 @@ def get_grads(model: torch.nn.Module): Returns: A list containing the norms of the gradients of the first and the last layer weights. """ - grads = list(p.grad for n, p in model.named_parameters() - if '.weight' in n) + grads = list(p.grad for n, p in model.named_parameters() if ".weight" in n) grads = [grads[0], grads[-1]] grads = [g.detach().norm() for g in grads] return grads diff --git a/map2map/utils/figures.py b/map2map/utils/figures.py index c6ae4a8..359e6ae 100644 --- a/map2map/utils/figures.py +++ b/map2map/utils/figures.py @@ -2,11 +2,13 @@ from math import log2, log10, ceil import torch import numpy as np import matplotlib -matplotlib.use('Agg') + +matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.colors import Normalize, LogNorm, SymLogNorm from matplotlib.cm import ScalarMappable -plt.rc('text', usetex=False) + +plt.rc("text", usetex=False) from ..models import lag2eul, power @@ -21,7 +23,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): Each field should have a channel dimension followed by spatial dimensions, i.e. no batch dimension. 
""" - plt.close('all') + plt.close("all") assert all(isinstance(field, torch.Tensor) for field in fields) @@ -38,11 +40,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): im_size = 2 cbar_height = 0.2 fig, axes = plt.subplots( - nc + 1, nf, + nc + 1, + nf, squeeze=False, figsize=(nf * im_size, nc * im_size + cbar_height), dpi=100, - gridspec_kw={'height_ratios': nc * [im_size] + [cbar_height]} + gridspec_kw={"height_ratios": nc * [im_size] + [cbar_height]}, ) for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)): @@ -51,11 +54,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): if cmap_col is None: if all_non_neg: - cmap_col = 'inferno' + cmap_col = "inferno" elif all_non_pos: - cmap_col = 'inferno_r' + cmap_col = "inferno_r" else: - cmap_col = 'RdBu_r' + cmap_col = "RdBu_r" if norm_col is None: l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5]) @@ -70,9 +73,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): if l1 < 0.1 * l2 or h2 == 0: norm_col = Normalize(vmin=-quantize(-l2), vmax=0) else: - norm_col = SymLogNorm(linthresh=quantize(-h2), - vmin=-quantize(-l2), - vmax=-quantize(-h2)) + norm_col = SymLogNorm( + linthresh=quantize(-h2), + vmin=-quantize(-l2), + vmax=-quantize(-h2), + ) else: vlim = quantize(max(-l2, h2)) if w1 > 0.1 * w2 or l1 * h1 >= 0: @@ -80,8 +85,13 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): else: linthresh = quantize(min(-l1, h1)) linscale = np.log10(vlim / linthresh) - norm_col = SymLogNorm(linthresh=linthresh, linscale=linscale, - vmin=-vlim, vmax=vlim, base=10) + norm_col = SymLogNorm( + linthresh=linthresh, + linscale=linscale, + vmin=-vlim, + vmax=vlim, + base=10, + ) for c in range(field.shape[0]): s = (c,) + tuple(d // 2 for d in field.shape[1:-2]) @@ -101,7 +111,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col) - axes[c, f].set_aspect('equal') + axes[c, f].set_aspect("equal") axes[c, f].set_xticks([]) axes[c, f].set_yticks([]) @@ -110,12 +120,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): axes[c, f].set_title(title[f]) for c in range(field.shape[0], nc): - axes[c, f].axis('off') + axes[c, f].axis("off") fig.colorbar( ScalarMappable(norm=norm_col, cmap=cmap_col), cax=axes[-1, f], - orientation='horizontal', + orientation="horizontal", ) fig.tight_layout() @@ -133,7 +143,7 @@ def plt_power(*fields, dis=None, label=None, **kwargs): See `map2map.models.power`. """ - plt.close('all') + plt.close("all") if label is not None: assert len(label) == len(fields) or len(label) == len(dis) @@ -159,8 +169,8 @@ def plt_power(*fields, dis=None, label=None, **kwargs): axes.loglog(k, P, label=l, alpha=0.7) axes.legend() - axes.set_xlabel('unnormalized wavenumber') - axes.set_ylabel('unnormalized power') + axes.set_xlabel("unnormalized wavenumber") + axes.set_ylabel("unnormalized power") fig.tight_layout() diff --git a/map2map/utils/imp.py b/map2map/utils/imp.py index 48717e1..3ba477e 100644 --- a/map2map/utils/imp.py +++ b/map2map/utils/imp.py @@ -21,7 +21,7 @@ def import_attr(name, *pkgs, callback_at=None): first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally from 'path/to/cb_dir/mod.py'. 
""" - if name.count('.') == 0: + if name.count(".") == 0: attr = name errors = [] @@ -34,23 +34,22 @@ def import_attr(name, *pkgs, callback_at=None): raise Exception(errors) else: - mod, attr = name.rsplit('.', 1) + mod, attr = name.rsplit(".", 1) errors = [] for pkg in pkgs: try: - return getattr( - importlib.import_module(pkg.__name__ + '.' + mod), attr) + return getattr(importlib.import_module(pkg.__name__ + "." + mod), attr) except (ModuleNotFoundError, AttributeError) as e: errors.append(e) if callback_at is None: raise Exception(errors) - callback_at = os.path.join(callback_at, mod + '.py') + callback_at = os.path.join(callback_at, mod + ".py") if not os.path.isfile(callback_at): - raise FileNotFoundError('callback file not found') + raise FileNotFoundError("callback file not found") if mod in sys.modules: return getattr(sys.modules[mod], attr) diff --git a/map2map/utils/state.py b/map2map/utils/state.py index c35cfab..5a338c3 100644 --- a/map2map/utils/state.py +++ b/map2map/utils/state.py @@ -7,9 +7,13 @@ def load_model_state_dict(module, state_dict, strict=True): bad_keys = module.load_state_dict(state_dict, strict) if len(bad_keys.missing_keys) > 0: - warnings.warn('Missing keys in state_dict:\n{}'.format( - pformat(bad_keys.missing_keys))) + warnings.warn( + "Missing keys in state_dict:\n{}".format(pformat(bad_keys.missing_keys)) + ) if len(bad_keys.unexpected_keys) > 0: - warnings.warn('Unexpected keys in state_dict:\n{}'.format( - pformat(bad_keys.unexpected_keys))) + warnings.warn( + "Unexpected keys in state_dict:\n{}".format( + pformat(bad_keys.unexpected_keys) + ) + ) sys.stderr.flush()