From 04d0bea17e0c11fcfb6d60976640671967488a39 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 12 May 2021 16:40:00 -0400 Subject: [PATCH] Add misc kwargs passing to custom models and norms --- README.md | 4 +--- map2map/args.py | 5 ++++- map2map/data/fields.py | 9 ++++++--- map2map/data/norms/__init__.py | 2 +- map2map/data/norms/cosmology.py | 20 ++++---------------- map2map/data/norms/torch.py | 8 ++++---- map2map/models/lag2eul.py | 19 ++++++++++++------- map2map/models/srsgan.py | 6 ++++-- map2map/test.py | 10 ++++++---- map2map/train.py | 16 +++++++++++----- map2map/utils/figures.py | 6 +++--- 11 files changed, 56 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index ec8257f..bcf42b5 100644 --- a/README.md +++ b/README.md @@ -133,9 +133,7 @@ The model `__init__` requires two positional arguments, the number of input and output channels. Other hyperparameters can be specified as keyword arguments, including the `scale_factor` useful for super-resolution tasks. -Note that the `**kwargs` is necessary when `scale_factor` is not -specified, because `scale_factor` is always passed when instantiating -a model. +Note that the `**kwargs` is necessary for compatibility. ### Training diff --git a/map2map/args.py b/map2map/args.py index ef0f146..dcce6a8 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -89,7 +89,10 @@ def add_common_args(parser): help='directory of custorm code defining callbacks for models, ' 'norms, criteria, and optimizers. Disabled if not set. ' 'This is appended to the default locations, ' - 'thus has the lowest priority.') + 'thus has the lowest priority') + parser.add_argument('--misc-kwargs', default='{}', type=json.loads, + help='miscellaneous keyword arguments for custom models and ' + 'norms. Be careful with name collisions') def add_train_args(parser): diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 62ea348..4ef0c9e 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -47,7 +47,8 @@ class FieldDataset(Dataset): in_norms=None, tgt_norms=None, callback_at=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None, crop=None, crop_start=None, crop_stop=None, crop_step=None, - in_pad=0, tgt_pad=0, scale_factor=1): + in_pad=0, tgt_pad=0, scale_factor=1, + **kwargs): self.style_files = sorted(glob(style_pattern)) in_file_lists = [sorted(glob(p)) for p in in_patterns] @@ -138,6 +139,8 @@ class FieldDataset(Dataset): self.nsample = self.nfile * self.ncrop + self.kwargs = kwargs + self.assembly_line = {} self.commonpath = os.path.commonpath( @@ -187,11 +190,11 @@ class FieldDataset(Dataset): if self.in_norms is not None: for norm, x in zip(self.in_norms, in_fields): norm = import_attr(norm, norms, callback_at=self.callback_at) - norm(x) + norm(x, **self.kwargs) if self.tgt_norms is not None: for norm, x in zip(self.tgt_norms, tgt_fields): norm = import_attr(norm, norms, callback_at=self.callback_at) - norm(x) + norm(x, **self.kwargs) if self.augment: flip_axes = flip(in_fields, None, self.ndim) diff --git a/map2map/data/norms/__init__.py b/map2map/data/norms/__init__.py index 060df41..bdb4711 100644 --- a/map2map/data/norms/__init__.py +++ b/map2map/data/norms/__init__.py @@ -1,2 +1,2 @@ -def identity(x, undo=False): +def identity(x, undo=False, **kwargs): pass diff --git a/map2map/data/norms/cosmology.py b/map2map/data/norms/cosmology.py index ee1cfa4..a310875 100644 --- a/map2map/data/norms/cosmology.py +++ b/map2map/data/norms/cosmology.py @@ -2,34 +2,22 @@ import numpy as np from scipy.special import hyp2f1 -def dis(x, undo=False): - z = 0 # FIXME - dis_norm = 6 * D(z) # [Mpc/h] +def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs): + dis_norm = dis_std * D(z) # [Mpc/h] if not undo: dis_norm = 1 / dis_norm x *= dis_norm -def vel(x, undo=False): - z = 0 # FIXME - vel_norm = 6 * D(z) * H(z) * f(z) / (1 + z) # [km/s] +def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs): + vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z) # [km/s] if not undo: vel_norm = 1 / vel_norm x *= vel_norm -def den(x, undo=False): - raise NotImplementedError - z = 0 # FIXME - den_norm = 0 # FIXME - - if not undo: - den_norm = 1 / den_norm - - x *= den_norm - def D(z, Om=0.31): """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero diff --git a/map2map/data/norms/torch.py b/map2map/data/norms/torch.py index f9ab2af..64a1ffa 100644 --- a/map2map/data/norms/torch.py +++ b/map2map/data/norms/torch.py @@ -1,25 +1,25 @@ import torch -def exp(x, undo=False): +def exp(x, undo=False, **kwargs): if not undo: torch.exp(x, out=x) else: torch.log(x, out=x) -def log(x, eps=1e-8, undo=False): +def log(x, eps=1e-8, undo=False, **kwargs): if not undo: torch.log(x + eps, out=x) else: torch.exp(x, out=x) -def expm1(x, undo=False): +def expm1(x, undo=False, **kwargs): if not undo: torch.expm1(x, out=x) else: torch.log1p(x, out=x) -def log1p(x, eps=1e-7, undo=False): +def log1p(x, eps=1e-7, undo=False, **kwargs): if not undo: torch.log1p(x + eps, out=x) else: diff --git a/map2map/models/lag2eul.py b/map2map/models/lag2eul.py index 89a2738..63913ba 100644 --- a/map2map/models/lag2eul.py +++ b/map2map/models/lag2eul.py @@ -1,8 +1,12 @@ import torch import torch.nn as nn +from ..data.norms.cosmology import D -def lag2eul(*xs, rm_dis_mean=True, periodic=False): + +def lag2eul(*xs, rm_dis_mean=True, periodic=False, + z=0.0, dis_std=6.0, boxsize=1000., meshsize=512, + **kwargs): """Transform fields from Lagrangian description to Eulerian description Only works for 3d fields, output same mesh size as input. @@ -12,14 +16,15 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False): latter from Lagrangian to Eulerian positions and then "paint" with CIC (trilinear) scheme. Use 1 if the latter is empty. + Note that the box and mesh sizes don't have to be that of the inputs, as + long as their ratio gives the right resolution. One can therefore set them + to the values of the whole fields, and use smaller inputs. + Implementation follows pmesh/cic.py by Yu Feng. """ - # FIXME for other redshift, box and mesh sizes - from ..data.norms.cosmology import D - z = 0 - Boxsize = 1000 - Nmesh = 512 - dis_norm = 6 * D(z) * Nmesh / Boxsize # to mesh unit + # NOTE the following factor assumes normalized displacements + # and thus undoes it + dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit if any(x.dim() != 5 for x in xs): raise NotImplementedError('only support 3d fields for now') diff --git a/map2map/models/srsgan.py b/map2map/models/srsgan.py index ff0de9a..18faa42 100644 --- a/map2map/models/srsgan.py +++ b/map2map/models/srsgan.py @@ -8,7 +8,8 @@ from .resample import Resampler class G(nn.Module): def __init__(self, in_chan, out_chan, scale_factor=16, - chan_base=512, chan_min=64, chan_max=512, cat_noise=False): + chan_base=512, chan_min=64, chan_max=512, cat_noise=False, + **kwargs): super().__init__() self.scale_factor = scale_factor @@ -137,7 +138,8 @@ class AddNoise(nn.Module): class D(nn.Module): def __init__(self, in_chan, out_chan, scale_factor=16, - chan_base=512, chan_min=64, chan_max=512): + chan_base=512, chan_min=64, chan_max=512, + **kwargs): super().__init__() self.scale_factor = scale_factor diff --git a/map2map/test.py b/map2map/test.py index 314b14e..1cdd6c8 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -34,6 +34,7 @@ def test(args): in_pad=args.in_pad, tgt_pad=args.tgt_pad, scale_factor=args.scale_factor, + **args.misc_kwargs, ) test_loader = DataLoader( test_dataset, @@ -47,7 +48,8 @@ def test(args): out_chan = test_dataset.tgt_chan model = import_attr(args.model, models, callback_at=args.callback_at) - model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) + model = model(style_size, sum(in_chan), sum(out_chan), + scale_factor=args.scale_factor, **args.misc_kwargs) criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at) criterion = criterion() @@ -75,14 +77,14 @@ def test(args): # start = 0 # for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): # norm = import_attr(norm, norms, callback_at=args.callback_at) - # norm(input[:, start:stop], undo=True) + # norm(input[:, start:stop], undo=True, **args.misc_kwargs) # start = stop if args.tgt_norms is not None: start = 0 for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): norm = import_attr(norm, norms, callback_at=args.callback_at) - norm(output[:, start:stop], undo=True) - #norm(target[:, start:stop], undo=True) + norm(output[:, start:stop], undo=True, **args.misc_kwargs) + #norm(target[:, start:stop], undo=True, **args.misc_kwargs) start = stop #test_dataset.assemble('_in', in_chan, input, diff --git a/map2map/train.py b/map2map/train.py index 9dac857..6307649 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -76,6 +76,7 @@ def gpu_worker(local_rank, node, args): in_pad=args.in_pad, tgt_pad=args.tgt_pad, scale_factor=args.scale_factor, + **args.misc_kwargs, ) train_sampler = DistFieldSampler(train_dataset, shuffle=True, div_data=args.div_data, @@ -108,6 +109,7 @@ def gpu_worker(local_rank, node, args): in_pad=args.in_pad, tgt_pad=args.tgt_pad, scale_factor=args.scale_factor, + **args.misc_kwargs, ) val_sampler = DistFieldSampler(val_dataset, shuffle=False, div_data=args.div_data, @@ -127,7 +129,7 @@ def gpu_worker(local_rank, node, args): model = import_attr(args.model, models, callback_at=args.callback_at) model = model(args.style_size, sum(args.in_chan), sum(args.out_chan), - scale_factor=args.scale_factor) + scale_factor=args.scale_factor, **args.misc_kwargs) model.to(device) model = DistributedDataParallel(model, device_ids=[device], process_group=dist.new_group()) @@ -314,16 +316,18 @@ def train(epoch, loader, model, criterion, eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1], title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], + **args.misc_kwargs, ) logger.add_figure('fig/train', fig, global_step=epoch+1) fig.clf() - #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt']) + #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], + # **args.misc_kwargs) #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) #fig.clf() #fig = plt_power(input, lag_out, lag_tgt, l2e=True, - # label=['in', 'out', 'tgt']) + # label=['in', 'out', 'tgt'], **args.misc_kwargs) #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) #fig.clf() @@ -378,16 +382,18 @@ def validate(epoch, loader, model, criterion, logger, device, args): eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1], title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], + **args.misc_kwargs, ) logger.add_figure('fig/val', fig, global_step=epoch+1) fig.clf() - #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt']) + #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], + # **args.misc_kwargs) #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) #fig.clf() #fig = plt_power(input, lag_out, lag_tgt, l2e=True, - # label=['in', 'out', 'tgt']) + # label=['in', 'out', 'tgt'], **args.misc_kwargs) #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) #fig.clf() diff --git a/map2map/utils/figures.py b/map2map/utils/figures.py index 03d8b7e..500e1f8 100644 --- a/map2map/utils/figures.py +++ b/map2map/utils/figures.py @@ -14,7 +14,7 @@ def quantize(x): return 2 ** round(log2(x), ndigits=1) -def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): +def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): """Plot slices of fields of more than 2 spatial dimensions. Each field should have a channel dimension followed by spatial dimensions, @@ -122,7 +122,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): return fig -def plt_power(*fields, l2e=False, label=None): +def plt_power(*fields, l2e=False, label=None, **kwargs): """Plot power spectra of fields. Each field should have batch and channel dimensions followed by spatial @@ -141,7 +141,7 @@ def plt_power(*fields, l2e=False, label=None): with torch.no_grad(): if l2e: - fields = lag2eul(*fields) + fields = lag2eul(*fields, **kwargs) ks, Ps = [], [] for field in fields: