Add misc kwargs passing to custom models and norms

This commit is contained in:
Yin Li 2021-05-12 16:40:00 -04:00
parent 8a3cd1843d
commit 04d0bea17e
11 changed files with 56 additions and 49 deletions

View file

@ -133,9 +133,7 @@ The model `__init__` requires two positional arguments, the number of
input and output channels.
Other hyperparameters can be specified as keyword arguments, including
the `scale_factor` useful for super-resolution tasks.
Note that the `**kwargs` is necessary when `scale_factor` is not
specified, because `scale_factor` is always passed when instantiating
a model.
Note that the `**kwargs` is necessary for compatibility.
### Training

View file

@ -89,7 +89,10 @@ def add_common_args(parser):
help='directory of custorm code defining callbacks for models, '
'norms, criteria, and optimizers. Disabled if not set. '
'This is appended to the default locations, '
'thus has the lowest priority.')
'thus has the lowest priority')
parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
help='miscellaneous keyword arguments for custom models and '
'norms. Be careful with name collisions')
def add_train_args(parser):

View file

@ -47,7 +47,8 @@ class FieldDataset(Dataset):
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
in_pad=0, tgt_pad=0, scale_factor=1):
in_pad=0, tgt_pad=0, scale_factor=1,
**kwargs):
self.style_files = sorted(glob(style_pattern))
in_file_lists = [sorted(glob(p)) for p in in_patterns]
@ -138,6 +139,8 @@ class FieldDataset(Dataset):
self.nsample = self.nfile * self.ncrop
self.kwargs = kwargs
self.assembly_line = {}
self.commonpath = os.path.commonpath(
@ -187,11 +190,11 @@ class FieldDataset(Dataset):
if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x)
norm(x, **self.kwargs)
if self.tgt_norms is not None:
for norm, x in zip(self.tgt_norms, tgt_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x)
norm(x, **self.kwargs)
if self.augment:
flip_axes = flip(in_fields, None, self.ndim)

View file

@ -1,2 +1,2 @@
def identity(x, undo=False):
def identity(x, undo=False, **kwargs):
pass

View file

@ -2,34 +2,22 @@ import numpy as np
from scipy.special import hyp2f1
def dis(x, undo=False):
z = 0 # FIXME
dis_norm = 6 * D(z) # [Mpc/h]
def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
dis_norm = dis_std * D(z) # [Mpc/h]
if not undo:
dis_norm = 1 / dis_norm
x *= dis_norm
def vel(x, undo=False):
z = 0 # FIXME
vel_norm = 6 * D(z) * H(z) * f(z) / (1 + z) # [km/s]
def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z) # [km/s]
if not undo:
vel_norm = 1 / vel_norm
x *= vel_norm
def den(x, undo=False):
raise NotImplementedError
z = 0 # FIXME
den_norm = 0 # FIXME
if not undo:
den_norm = 1 / den_norm
x *= den_norm
def D(z, Om=0.31):
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero

View file

@ -1,25 +1,25 @@
import torch
def exp(x, undo=False):
def exp(x, undo=False, **kwargs):
if not undo:
torch.exp(x, out=x)
else:
torch.log(x, out=x)
def log(x, eps=1e-8, undo=False):
def log(x, eps=1e-8, undo=False, **kwargs):
if not undo:
torch.log(x + eps, out=x)
else:
torch.exp(x, out=x)
def expm1(x, undo=False):
def expm1(x, undo=False, **kwargs):
if not undo:
torch.expm1(x, out=x)
else:
torch.log1p(x, out=x)
def log1p(x, eps=1e-7, undo=False):
def log1p(x, eps=1e-7, undo=False, **kwargs):
if not undo:
torch.log1p(x + eps, out=x)
else:

View file

@ -1,8 +1,12 @@
import torch
import torch.nn as nn
from ..data.norms.cosmology import D
def lag2eul(*xs, rm_dis_mean=True, periodic=False):
def lag2eul(*xs, rm_dis_mean=True, periodic=False,
z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
**kwargs):
"""Transform fields from Lagrangian description to Eulerian description
Only works for 3d fields, output same mesh size as input.
@ -12,14 +16,15 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False):
latter from Lagrangian to Eulerian positions and then "paint" with CIC
(trilinear) scheme. Use 1 if the latter is empty.
Note that the box and mesh sizes don't have to be that of the inputs, as
long as their ratio gives the right resolution. One can therefore set them
to the values of the whole fields, and use smaller inputs.
Implementation follows pmesh/cic.py by Yu Feng.
"""
# FIXME for other redshift, box and mesh sizes
from ..data.norms.cosmology import D
z = 0
Boxsize = 1000
Nmesh = 512
dis_norm = 6 * D(z) * Nmesh / Boxsize # to mesh unit
# NOTE the following factor assumes normalized displacements
# and thus undoes it
dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
if any(x.dim() != 5 for x in xs):
raise NotImplementedError('only support 3d fields for now')

View file

@ -8,7 +8,8 @@ from .resample import Resampler
class G(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
**kwargs):
super().__init__()
self.scale_factor = scale_factor
@ -137,7 +138,8 @@ class AddNoise(nn.Module):
class D(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512):
chan_base=512, chan_min=64, chan_max=512,
**kwargs):
super().__init__()
self.scale_factor = scale_factor

View file

@ -34,6 +34,7 @@ def test(args):
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
test_loader = DataLoader(
test_dataset,
@ -47,7 +48,8 @@ def test(args):
out_chan = test_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
model = model(style_size, sum(in_chan), sum(out_chan),
scale_factor=args.scale_factor, **args.misc_kwargs)
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
criterion = criterion()
@ -75,14 +77,14 @@ def test(args):
# start = 0
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
# norm = import_attr(norm, norms, callback_at=args.callback_at)
# norm(input[:, start:stop], undo=True)
# norm(input[:, start:stop], undo=True, **args.misc_kwargs)
# start = stop
if args.tgt_norms is not None:
start = 0
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True)
#norm(target[:, start:stop], undo=True)
norm(output[:, start:stop], undo=True, **args.misc_kwargs)
#norm(target[:, start:stop], undo=True, **args.misc_kwargs)
start = stop
#test_dataset.assemble('_in', in_chan, input,

View file

@ -76,6 +76,7 @@ def gpu_worker(local_rank, node, args):
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
div_data=args.div_data,
@ -108,6 +109,7 @@ def gpu_worker(local_rank, node, args):
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
**args.misc_kwargs,
)
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
div_data=args.div_data,
@ -127,7 +129,7 @@ def gpu_worker(local_rank, node, args):
model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor)
scale_factor=args.scale_factor, **args.misc_kwargs)
model.to(device)
model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group())
@ -314,16 +316,18 @@ def train(epoch, loader, model, criterion,
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
**args.misc_kwargs,
)
logger.add_figure('fig/train', fig, global_step=epoch+1)
fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
# **args.misc_kwargs)
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf()
@ -378,16 +382,18 @@ def validate(epoch, loader, model, criterion, logger, device, args):
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
**args.misc_kwargs,
)
logger.add_figure('fig/val', fig, global_step=epoch+1)
fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
# **args.misc_kwargs)
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf()

View file

@ -14,7 +14,7 @@ def quantize(x):
return 2 ** round(log2(x), ndigits=1)
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
"""Plot slices of fields of more than 2 spatial dimensions.
Each field should have a channel dimension followed by spatial dimensions,
@ -122,7 +122,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
return fig
def plt_power(*fields, l2e=False, label=None):
def plt_power(*fields, l2e=False, label=None, **kwargs):
"""Plot power spectra of fields.
Each field should have batch and channel dimensions followed by spatial
@ -141,7 +141,7 @@ def plt_power(*fields, l2e=False, label=None):
with torch.no_grad():
if l2e:
fields = lag2eul(*fields)
fields = lag2eul(*fields, **kwargs)
ks, Ps = [], []
for field in fields: