Add misc kwargs passing to custom models and norms
This commit is contained in:
parent
8a3cd1843d
commit
04d0bea17e
11 changed files with 56 additions and 49 deletions
|
@ -133,9 +133,7 @@ The model `__init__` requires two positional arguments, the number of
|
|||
input and output channels.
|
||||
Other hyperparameters can be specified as keyword arguments, including
|
||||
the `scale_factor` useful for super-resolution tasks.
|
||||
Note that the `**kwargs` is necessary when `scale_factor` is not
|
||||
specified, because `scale_factor` is always passed when instantiating
|
||||
a model.
|
||||
Note that the `**kwargs` is necessary for compatibility.
|
||||
|
||||
|
||||
### Training
|
||||
|
|
|
@ -89,7 +89,10 @@ def add_common_args(parser):
|
|||
help='directory of custorm code defining callbacks for models, '
|
||||
'norms, criteria, and optimizers. Disabled if not set. '
|
||||
'This is appended to the default locations, '
|
||||
'thus has the lowest priority.')
|
||||
'thus has the lowest priority')
|
||||
parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
|
||||
help='miscellaneous keyword arguments for custom models and '
|
||||
'norms. Be careful with name collisions')
|
||||
|
||||
|
||||
def add_train_args(parser):
|
||||
|
|
|
@ -47,7 +47,8 @@ class FieldDataset(Dataset):
|
|||
in_norms=None, tgt_norms=None, callback_at=None,
|
||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
in_pad=0, tgt_pad=0, scale_factor=1):
|
||||
in_pad=0, tgt_pad=0, scale_factor=1,
|
||||
**kwargs):
|
||||
self.style_files = sorted(glob(style_pattern))
|
||||
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
|
@ -138,6 +139,8 @@ class FieldDataset(Dataset):
|
|||
|
||||
self.nsample = self.nfile * self.ncrop
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.assembly_line = {}
|
||||
|
||||
self.commonpath = os.path.commonpath(
|
||||
|
@ -187,11 +190,11 @@ class FieldDataset(Dataset):
|
|||
if self.in_norms is not None:
|
||||
for norm, x in zip(self.in_norms, in_fields):
|
||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
norm(x)
|
||||
norm(x, **self.kwargs)
|
||||
if self.tgt_norms is not None:
|
||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
norm(x)
|
||||
norm(x, **self.kwargs)
|
||||
|
||||
if self.augment:
|
||||
flip_axes = flip(in_fields, None, self.ndim)
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
def identity(x, undo=False):
|
||||
def identity(x, undo=False, **kwargs):
|
||||
pass
|
||||
|
|
|
@ -2,34 +2,22 @@ import numpy as np
|
|||
from scipy.special import hyp2f1
|
||||
|
||||
|
||||
def dis(x, undo=False):
|
||||
z = 0 # FIXME
|
||||
dis_norm = 6 * D(z) # [Mpc/h]
|
||||
def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
|
||||
dis_norm = dis_std * D(z) # [Mpc/h]
|
||||
|
||||
if not undo:
|
||||
dis_norm = 1 / dis_norm
|
||||
|
||||
x *= dis_norm
|
||||
|
||||
def vel(x, undo=False):
|
||||
z = 0 # FIXME
|
||||
vel_norm = 6 * D(z) * H(z) * f(z) / (1 + z) # [km/s]
|
||||
def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
|
||||
vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z) # [km/s]
|
||||
|
||||
if not undo:
|
||||
vel_norm = 1 / vel_norm
|
||||
|
||||
x *= vel_norm
|
||||
|
||||
def den(x, undo=False):
|
||||
raise NotImplementedError
|
||||
z = 0 # FIXME
|
||||
den_norm = 0 # FIXME
|
||||
|
||||
if not undo:
|
||||
den_norm = 1 / den_norm
|
||||
|
||||
x *= den_norm
|
||||
|
||||
|
||||
def D(z, Om=0.31):
|
||||
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
import torch
|
||||
|
||||
|
||||
def exp(x, undo=False):
|
||||
def exp(x, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.exp(x, out=x)
|
||||
else:
|
||||
torch.log(x, out=x)
|
||||
|
||||
def log(x, eps=1e-8, undo=False):
|
||||
def log(x, eps=1e-8, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.log(x + eps, out=x)
|
||||
else:
|
||||
torch.exp(x, out=x)
|
||||
|
||||
def expm1(x, undo=False):
|
||||
def expm1(x, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.expm1(x, out=x)
|
||||
else:
|
||||
torch.log1p(x, out=x)
|
||||
|
||||
def log1p(x, eps=1e-7, undo=False):
|
||||
def log1p(x, eps=1e-7, undo=False, **kwargs):
|
||||
if not undo:
|
||||
torch.log1p(x + eps, out=x)
|
||||
else:
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..data.norms.cosmology import D
|
||||
|
||||
def lag2eul(*xs, rm_dis_mean=True, periodic=False):
|
||||
|
||||
def lag2eul(*xs, rm_dis_mean=True, periodic=False,
|
||||
z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
|
||||
**kwargs):
|
||||
"""Transform fields from Lagrangian description to Eulerian description
|
||||
|
||||
Only works for 3d fields, output same mesh size as input.
|
||||
|
@ -12,14 +16,15 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False):
|
|||
latter from Lagrangian to Eulerian positions and then "paint" with CIC
|
||||
(trilinear) scheme. Use 1 if the latter is empty.
|
||||
|
||||
Note that the box and mesh sizes don't have to be that of the inputs, as
|
||||
long as their ratio gives the right resolution. One can therefore set them
|
||||
to the values of the whole fields, and use smaller inputs.
|
||||
|
||||
Implementation follows pmesh/cic.py by Yu Feng.
|
||||
"""
|
||||
# FIXME for other redshift, box and mesh sizes
|
||||
from ..data.norms.cosmology import D
|
||||
z = 0
|
||||
Boxsize = 1000
|
||||
Nmesh = 512
|
||||
dis_norm = 6 * D(z) * Nmesh / Boxsize # to mesh unit
|
||||
# NOTE the following factor assumes normalized displacements
|
||||
# and thus undoes it
|
||||
dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
|
||||
|
||||
if any(x.dim() != 5 for x in xs):
|
||||
raise NotImplementedError('only support 3d fields for now')
|
||||
|
|
|
@ -8,7 +8,8 @@ from .resample import Resampler
|
|||
|
||||
class G(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, scale_factor=16,
|
||||
chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
|
||||
chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
|
@ -137,7 +138,8 @@ class AddNoise(nn.Module):
|
|||
|
||||
class D(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, scale_factor=16,
|
||||
chan_base=512, chan_min=64, chan_max=512):
|
||||
chan_base=512, chan_min=64, chan_max=512,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
|
|
|
@ -34,6 +34,7 @@ def test(args):
|
|||
in_pad=args.in_pad,
|
||||
tgt_pad=args.tgt_pad,
|
||||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
|
@ -47,7 +48,8 @@ def test(args):
|
|||
out_chan = test_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
||||
model = model(style_size, sum(in_chan), sum(out_chan),
|
||||
scale_factor=args.scale_factor, **args.misc_kwargs)
|
||||
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
|
||||
criterion = criterion()
|
||||
|
||||
|
@ -75,14 +77,14 @@ def test(args):
|
|||
# start = 0
|
||||
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||
# norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
# norm(input[:, start:stop], undo=True)
|
||||
# norm(input[:, start:stop], undo=True, **args.misc_kwargs)
|
||||
# start = stop
|
||||
if args.tgt_norms is not None:
|
||||
start = 0
|
||||
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
norm(output[:, start:stop], undo=True)
|
||||
#norm(target[:, start:stop], undo=True)
|
||||
norm(output[:, start:stop], undo=True, **args.misc_kwargs)
|
||||
#norm(target[:, start:stop], undo=True, **args.misc_kwargs)
|
||||
start = stop
|
||||
|
||||
#test_dataset.assemble('_in', in_chan, input,
|
||||
|
|
|
@ -76,6 +76,7 @@ def gpu_worker(local_rank, node, args):
|
|||
in_pad=args.in_pad,
|
||||
tgt_pad=args.tgt_pad,
|
||||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
||||
div_data=args.div_data,
|
||||
|
@ -108,6 +109,7 @@ def gpu_worker(local_rank, node, args):
|
|||
in_pad=args.in_pad,
|
||||
tgt_pad=args.tgt_pad,
|
||||
scale_factor=args.scale_factor,
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
||||
div_data=args.div_data,
|
||||
|
@ -127,7 +129,7 @@ def gpu_worker(local_rank, node, args):
|
|||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
|
||||
scale_factor=args.scale_factor)
|
||||
scale_factor=args.scale_factor, **args.misc_kwargs)
|
||||
model.to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
@ -314,16 +316,18 @@ def train(epoch, loader, model, criterion,
|
|||
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
||||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
logger.add_figure('fig/train', fig, global_step=epoch+1)
|
||||
fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
|
||||
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
|
||||
# **args.misc_kwargs)
|
||||
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
||||
# label=['in', 'out', 'tgt'])
|
||||
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
|
||||
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
|
@ -378,16 +382,18 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
||||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||
**args.misc_kwargs,
|
||||
)
|
||||
logger.add_figure('fig/val', fig, global_step=epoch+1)
|
||||
fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
|
||||
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
|
||||
# **args.misc_kwargs)
|
||||
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
||||
# label=['in', 'out', 'tgt'])
|
||||
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
|
||||
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ def quantize(x):
|
|||
return 2 ** round(log2(x), ndigits=1)
|
||||
|
||||
|
||||
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
||||
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
||||
"""Plot slices of fields of more than 2 spatial dimensions.
|
||||
|
||||
Each field should have a channel dimension followed by spatial dimensions,
|
||||
|
@ -122,7 +122,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
|||
return fig
|
||||
|
||||
|
||||
def plt_power(*fields, l2e=False, label=None):
|
||||
def plt_power(*fields, l2e=False, label=None, **kwargs):
|
||||
"""Plot power spectra of fields.
|
||||
|
||||
Each field should have batch and channel dimensions followed by spatial
|
||||
|
@ -141,7 +141,7 @@ def plt_power(*fields, l2e=False, label=None):
|
|||
|
||||
with torch.no_grad():
|
||||
if l2e:
|
||||
fields = lag2eul(*fields)
|
||||
fields = lag2eul(*fields, **kwargs)
|
||||
|
||||
ks, Ps = [], []
|
||||
for field in fields:
|
||||
|
|
Loading…
Reference in a new issue