Add misc kwargs passing to custom models and norms
parent 8a3cd1843d
commit 04d0bea17e
@@ -133,9 +133,7 @@ The model `__init__` requires two positional arguments, the number of
 input and output channels.
 Other hyperparameters can be specified as keyword arguments, including
 the `scale_factor` useful for super-resolution tasks.
-Note that the `**kwargs` is necessary when `scale_factor` is not
-specified, because `scale_factor` is always passed when instantiating
-a model.
+Note that the `**kwargs` is necessary for compatibility.


 ### Training
@@ -89,7 +89,10 @@ def add_common_args(parser):
        help='directory of custorm code defining callbacks for models, '
        'norms, criteria, and optimizers. Disabled if not set. '
        'This is appended to the default locations, '
-       'thus has the lowest priority.')
+       'thus has the lowest priority')
+   parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
+       help='miscellaneous keyword arguments for custom models and '
+       'norms. Be careful with name collisions')


 def add_train_args(parser):
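The new `--misc-kwargs` flag takes a JSON object on the command line; `json.loads` turns it into a plain dict that the scripts below splat into dataset, norm, model, and plotting calls. A minimal sketch of the parsing behaviour (the example keys `z` and `dis_std` are just the cosmology-norm parameters introduced further down):

    import json

    # e.g.  --misc-kwargs '{"z": 0.5, "dis_std": 6.0}'
    misc_kwargs = json.loads('{"z": 0.5, "dis_std": 6.0}')
    assert misc_kwargs == {'z': 0.5, 'dis_std': 6.0}
    # downstream the dict is forwarded as **args.misc_kwargs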
@@ -47,7 +47,8 @@ class FieldDataset(Dataset):
                  in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_shift=None, aug_add=None, aug_mul=None,
                  crop=None, crop_start=None, crop_stop=None, crop_step=None,
-                 in_pad=0, tgt_pad=0, scale_factor=1):
+                 in_pad=0, tgt_pad=0, scale_factor=1,
+                 **kwargs):
         self.style_files = sorted(glob(style_pattern))

         in_file_lists = [sorted(glob(p)) for p in in_patterns]
@@ -138,6 +139,8 @@ class FieldDataset(Dataset):

         self.nsample = self.nfile * self.ncrop

+        self.kwargs = kwargs
+
         self.assembly_line = {}

         self.commonpath = os.path.commonpath(
@@ -187,11 +190,11 @@ class FieldDataset(Dataset):
         if self.in_norms is not None:
             for norm, x in zip(self.in_norms, in_fields):
                 norm = import_attr(norm, norms, callback_at=self.callback_at)
-                norm(x)
+                norm(x, **self.kwargs)
         if self.tgt_norms is not None:
             for norm, x in zip(self.tgt_norms, tgt_fields):
                 norm = import_attr(norm, norms, callback_at=self.callback_at)
-                norm(x)
+                norm(x, **self.kwargs)

         if self.augment:
             flip_axes = flip(in_fields, None, self.ndim)
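The dataset simply stores the extra keyword arguments and forwards them to every norm call as `norm(x, **self.kwargs)`. A hedged sketch of direct use — the leading glob-pattern arguments and the `'cosmology.dis'` norm name are assumptions about the surrounding API; only the keyword forwarding is what this hunk adds:

    dataset = FieldDataset('styles/*.npy', ['dis/*.npy'], ['dis/*.npy'],
                           in_norms=['cosmology.dis'], tgt_norms=['cosmology.dis'],
                           scale_factor=1,
                           z=0.5, dis_std=6.0)   # ends up in dataset.kwargs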
@@ -1,2 +1,2 @@
-def identity(x, undo=False):
+def identity(x, undo=False, **kwargs):
     pass
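Custom norm callbacks loaded via `--callback-at` follow the same convention: operate on `x` in place, accept `undo`, and end the signature with `**kwargs` so unrelated entries of `--misc-kwargs` are absorbed harmlessly. A minimal hypothetical example (the name `my_scale` and its `scale` parameter are illustrative only):

    def my_scale(x, undo=False, scale=10.0, **kwargs):
        # in-place normalization, mirroring the built-in norms;
        # unused misc kwargs are swallowed by **kwargs
        factor = 1 / scale if not undo else scale
        x *= factor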
@@ -2,34 +2,22 @@ import numpy as np
 from scipy.special import hyp2f1


-def dis(x, undo=False):
-    z = 0  # FIXME
-    dis_norm = 6 * D(z)  # [Mpc/h]
+def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
+    dis_norm = dis_std * D(z)  # [Mpc/h]

     if not undo:
         dis_norm = 1 / dis_norm

     x *= dis_norm

-def vel(x, undo=False):
-    z = 0  # FIXME
-    vel_norm = 6 * D(z) * H(z) * f(z) / (1 + z)  # [km/s]
+def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
+    vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z)  # [km/s]

     if not undo:
         vel_norm = 1 / vel_norm

     x *= vel_norm

-def den(x, undo=False):
-    raise NotImplementedError
-    z = 0  # FIXME
-    den_norm = 0  # FIXME
-
-    if not undo:
-        den_norm = 1 / den_norm
-
-    x *= den_norm
-

 def D(z, Om=0.31):
     """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
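Since `D` is normalized to 1 at redshift zero, the new defaults reproduce the old hard-coded behaviour, `dis_norm = 6 * D(0) = 6` Mpc/h. A quick hedged round-trip check, assuming `dis` is imported from this cosmology norm module:

    import numpy as np

    x = np.random.randn(3, 32, 32, 32)
    orig = x.copy()
    dis(x, z=0.0, dis_std=6.0)               # forward: x /= 6 * D(0)
    dis(x, undo=True, z=0.0, dis_std=6.0)    # undo: x *= 6 * D(0)
    assert np.allclose(x, orig)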
@@ -1,25 +1,25 @@
 import torch


-def exp(x, undo=False):
+def exp(x, undo=False, **kwargs):
     if not undo:
         torch.exp(x, out=x)
     else:
         torch.log(x, out=x)

-def log(x, eps=1e-8, undo=False):
+def log(x, eps=1e-8, undo=False, **kwargs):
     if not undo:
         torch.log(x + eps, out=x)
     else:
         torch.exp(x, out=x)

-def expm1(x, undo=False):
+def expm1(x, undo=False, **kwargs):
     if not undo:
         torch.expm1(x, out=x)
     else:
         torch.log1p(x, out=x)

-def log1p(x, eps=1e-7, undo=False):
+def log1p(x, eps=1e-7, undo=False, **kwargs):
     if not undo:
         torch.log1p(x + eps, out=x)
     else:
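The only change here is the trailing `**kwargs`, which lets every norm accept the full `--misc-kwargs` dict even when none of its entries apply to it. A small hedged illustration, assuming `log` is imported from this module:

    import torch

    x = torch.rand(4) + 0.1
    log(x, z=0.5, dis_std=6.0)   # cosmology keys are irrelevant here and simply ignored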
@@ -1,8 +1,12 @@
 import torch
 import torch.nn as nn

+from ..data.norms.cosmology import D
+

-def lag2eul(*xs, rm_dis_mean=True, periodic=False):
+def lag2eul(*xs, rm_dis_mean=True, periodic=False,
+            z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
+            **kwargs):
     """Transform fields from Lagrangian description to Eulerian description

     Only works for 3d fields, output same mesh size as input.
@@ -12,14 +16,15 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False):
     latter from Lagrangian to Eulerian positions and then "paint" with CIC
     (trilinear) scheme. Use 1 if the latter is empty.

+    Note that the box and mesh sizes don't have to be that of the inputs, as
+    long as their ratio gives the right resolution. One can therefore set them
+    to the values of the whole fields, and use smaller inputs.
+
     Implementation follows pmesh/cic.py by Yu Feng.
     """
-    # FIXME for other redshift, box and mesh sizes
-    from ..data.norms.cosmology import D
-    z = 0
-    Boxsize = 1000
-    Nmesh = 512
-    dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
+    # NOTE the following factor assumes normalized displacements
+    # and thus undoes it
+    dis_norm = dis_std * D(z) * meshsize / boxsize  # to mesh unit

     if any(x.dim() != 5 for x in xs):
         raise NotImplementedError('only support 3d fields for now')
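For the default values this factor works out as follows (D(0) = 1 by normalization; treating `boxsize` as Mpc/h like `dis_std` is an assumption consistent with the comments above):

    # dis_norm = dis_std * D(z) * meshsize / boxsize
    #          = 6 [Mpc/h] * 1 * 512 / 1000 [Mpc/h]
    #          ≈ 3.07 mesh cells per unit of normalized displacement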
@@ -8,7 +8,8 @@ from .resample import Resampler

 class G(nn.Module):
     def __init__(self, in_chan, out_chan, scale_factor=16,
-                 chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
+                 chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
+                 **kwargs):
         super().__init__()

         self.scale_factor = scale_factor
@@ -137,7 +138,8 @@ class AddNoise(nn.Module):

 class D(nn.Module):
     def __init__(self, in_chan, out_chan, scale_factor=16,
-                 chan_base=512, chan_min=64, chan_max=512):
+                 chan_base=512, chan_min=64, chan_max=512,
+                 **kwargs):
         super().__init__()

         self.scale_factor = scale_factor
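Custom models follow the same pattern as `G` and `D` above: channel counts as positional arguments, everything else by keyword, and a trailing `**kwargs` so that `scale_factor` plus any `--misc-kwargs` entries can always be passed. A hedged sketch of a compatible custom model (the class name `MyNet` and its `hidden` parameter are hypothetical):

    import torch.nn as nn

    class MyNet(nn.Module):
        def __init__(self, in_chan, out_chan, scale_factor=1, hidden=64, **kwargs):
            super().__init__()
            # hidden comes from --misc-kwargs; anything unused is absorbed by **kwargs
            self.net = nn.Sequential(
                nn.Conv3d(in_chan, hidden, 3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(hidden, out_chan, 3, padding=1),
            )

        def forward(self, x):
            return self.net(x)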
@@ -34,6 +34,7 @@ def test(args):
        in_pad=args.in_pad,
        tgt_pad=args.tgt_pad,
        scale_factor=args.scale_factor,
+       **args.misc_kwargs,
     )
     test_loader = DataLoader(
         test_dataset,
@@ -47,7 +48,8 @@ def test(args):
     out_chan = test_dataset.tgt_chan

     model = import_attr(args.model, models, callback_at=args.callback_at)
-    model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
+    model = model(style_size, sum(in_chan), sum(out_chan),
+                  scale_factor=args.scale_factor, **args.misc_kwargs)
     criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
     criterion = criterion()

@@ -75,14 +77,14 @@ def test(args):
            # start = 0
            # for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
            #     norm = import_attr(norm, norms, callback_at=args.callback_at)
-           #     norm(input[:, start:stop], undo=True)
+           #     norm(input[:, start:stop], undo=True, **args.misc_kwargs)
            #     start = stop
            if args.tgt_norms is not None:
                start = 0
                for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
                    norm = import_attr(norm, norms, callback_at=args.callback_at)
-                   norm(output[:, start:stop], undo=True)
+                   norm(output[:, start:stop], undo=True, **args.misc_kwargs)
-                   #norm(target[:, start:stop], undo=True)
+                   #norm(target[:, start:stop], undo=True, **args.misc_kwargs)
                    start = stop

                #test_dataset.assemble('_in', in_chan, input,
@@ -76,6 +76,7 @@ def gpu_worker(local_rank, node, args):
        in_pad=args.in_pad,
        tgt_pad=args.tgt_pad,
        scale_factor=args.scale_factor,
+       **args.misc_kwargs,
     )
     train_sampler = DistFieldSampler(train_dataset, shuffle=True,
                                      div_data=args.div_data,
@@ -108,6 +109,7 @@ def gpu_worker(local_rank, node, args):
        in_pad=args.in_pad,
        tgt_pad=args.tgt_pad,
        scale_factor=args.scale_factor,
+       **args.misc_kwargs,
     )
     val_sampler = DistFieldSampler(val_dataset, shuffle=False,
                                    div_data=args.div_data,
@@ -127,7 +129,7 @@ def gpu_worker(local_rank, node, args):

     model = import_attr(args.model, models, callback_at=args.callback_at)
     model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
-                  scale_factor=args.scale_factor)
+                  scale_factor=args.scale_factor, **args.misc_kwargs)
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
                                     process_group=dist.new_group())
@@ -314,16 +316,18 @@ def train(epoch, loader, model, criterion,
                eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
                title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
                       'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+               **args.misc_kwargs,
            )
            logger.add_figure('fig/train', fig, global_step=epoch+1)
            fig.clf()

-           #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
+           #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
+           #                **args.misc_kwargs)
            #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
            #fig.clf()

            #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
-           #                label=['in', 'out', 'tgt'])
+           #                label=['in', 'out', 'tgt'], **args.misc_kwargs)
            #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
            #fig.clf()

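This forwarding is also where the help text's warning about name collisions bites: `plt_slices` is already called with an explicit `title=...`, so a `title` key in `--misc-kwargs` would be passed twice. A hypothetical illustration of the failure mode:

    # --misc-kwargs '{"title": "foo"}' would make the call above raise
    #   TypeError: plt_slices() got multiple values for keyword argument 'title'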
@@ -378,16 +382,18 @@ def validate(epoch, loader, model, criterion, logger, device, args):
                eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
                title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
                       'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+               **args.misc_kwargs,
            )
            logger.add_figure('fig/val', fig, global_step=epoch+1)
            fig.clf()

-           #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
+           #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
+           #                **args.misc_kwargs)
            #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
            #fig.clf()

            #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
-           #                label=['in', 'out', 'tgt'])
+           #                label=['in', 'out', 'tgt'], **args.misc_kwargs)
            #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
            #fig.clf()

@@ -14,7 +14,7 @@ def quantize(x):
     return 2 ** round(log2(x), ndigits=1)


-def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
+def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
     """Plot slices of fields of more than 2 spatial dimensions.

     Each field should have a channel dimension followed by spatial dimensions,
@@ -122,7 +122,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
     return fig


-def plt_power(*fields, l2e=False, label=None):
+def plt_power(*fields, l2e=False, label=None, **kwargs):
     """Plot power spectra of fields.

     Each field should have batch and channel dimensions followed by spatial
@@ -141,7 +141,7 @@ def plt_power(*fields, l2e=False, label=None):

     with torch.no_grad():
         if l2e:
-            fields = lag2eul(*fields)
+            fields = lag2eul(*fields, **kwargs)

         ks, Ps = [], []
         for field in fields: