Add misc kwargs passing to custom models and norms

This commit is contained in:
Yin Li 2021-05-12 16:40:00 -04:00
parent 8a3cd1843d
commit 04d0bea17e
11 changed files with 56 additions and 49 deletions

View File

@ -133,9 +133,7 @@ The model `__init__` requires two positional arguments, the number of
input and output channels.
Other hyperparameters can be specified as keyword arguments, including
the `scale_factor` useful for super-resolution tasks.
Note that the `**kwargs` is necessary for compatibility: `scale_factor`
and every entry of `--misc-kwargs` are always passed when instantiating
a model, so a model that does not declare them must absorb them.
### Training

View File

@ -89,7 +89,10 @@ def add_common_args(parser):
help='directory of custorm code defining callbacks for models, ' help='directory of custorm code defining callbacks for models, '
'norms, criteria, and optimizers. Disabled if not set. ' 'norms, criteria, and optimizers. Disabled if not set. '
'This is appended to the default locations, ' 'This is appended to the default locations, '
'thus has the lowest priority.') 'thus has the lowest priority')
parser.add_argument('--misc-kwargs', default='{}', type=json.loads,
help='miscellaneous keyword arguments for custom models and '
'norms. Be careful with name collisions')
def add_train_args(parser): def add_train_args(parser):

View File

@ -47,7 +47,8 @@ class FieldDataset(Dataset):
in_norms=None, tgt_norms=None, callback_at=None, in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None, crop=None, crop_start=None, crop_stop=None, crop_step=None,
in_pad=0, tgt_pad=0, scale_factor=1): in_pad=0, tgt_pad=0, scale_factor=1,
**kwargs):
self.style_files = sorted(glob(style_pattern)) self.style_files = sorted(glob(style_pattern))
in_file_lists = [sorted(glob(p)) for p in in_patterns] in_file_lists = [sorted(glob(p)) for p in in_patterns]
@ -138,6 +139,8 @@ class FieldDataset(Dataset):
self.nsample = self.nfile * self.ncrop self.nsample = self.nfile * self.ncrop
self.kwargs = kwargs
self.assembly_line = {} self.assembly_line = {}
self.commonpath = os.path.commonpath( self.commonpath = os.path.commonpath(
@ -187,11 +190,11 @@ class FieldDataset(Dataset):
if self.in_norms is not None: if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields): for norm, x in zip(self.in_norms, in_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at) norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x) norm(x, **self.kwargs)
if self.tgt_norms is not None: if self.tgt_norms is not None:
for norm, x in zip(self.tgt_norms, tgt_fields): for norm, x in zip(self.tgt_norms, tgt_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at) norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x) norm(x, **self.kwargs)
if self.augment: if self.augment:
flip_axes = flip(in_fields, None, self.ndim) flip_axes = flip(in_fields, None, self.ndim)

View File

@ -1,2 +1,2 @@
def identity(x, undo=False, **kwargs):
    """No-op normalization: leave ``x`` untouched in both directions.

    Args:
        x: Field to "normalize"; it is neither read nor modified.
        undo: Accepted for signature parity with the other norms; ignored.
        **kwargs: Extra keyword arguments (e.g. forwarded misc kwargs);
            ignored so that callers can pass the same set to every norm.

    Returns:
        None.
    """
    return None

View File

@ -2,34 +2,22 @@ import numpy as np
from scipy.special import hyp2f1 from scipy.special import hyp2f1
def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
    """Normalize (or un-normalize) a displacement field in place.

    The scale is ``dis_std * D(z)``; the forward direction divides ``x`` by
    it and ``undo=True`` multiplies ``x`` by it.

    Args:
        x: Field to rescale; modified in place via ``*=``.
        undo: If true, multiply by the scale instead of dividing by it.
        z: Value passed to the growth function ``D`` (redshift).
        dis_std: Multiplier of ``D(z)`` in the scale.
            NOTE(review): presumably the displacement standard deviation
            at redshift zero in Mpc/h -- confirm.
        **kwargs: Extra keyword arguments; ignored so that the same set of
            misc kwargs can be passed to every norm.

    Returns:
        None; the result is written into ``x``.
    """
    dis_norm = dis_std * D(z)  # [Mpc/h]

    if not undo:
        # forward normalization divides by the scale
        dis_norm = 1 / dis_norm

    x *= dis_norm
def vel(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
    """Normalize (or un-normalize) a velocity field in place.

    The scale is ``dis_std * D(z) * H(z) * f(z) / (1 + z)``; the forward
    direction divides ``x`` by it and ``undo=True`` multiplies ``x`` by it.

    Args:
        x: Field to rescale; modified in place via ``*=``.
        undo: If true, multiply by the scale instead of dividing by it.
        z: Value passed to ``D``, ``H`` and ``f`` (redshift).
        dis_std: Multiplier shared with the displacement normalization.
            NOTE(review): presumably the displacement standard deviation
            at redshift zero in Mpc/h -- confirm.
        **kwargs: Extra keyword arguments; ignored so that the same set of
            misc kwargs can be passed to every norm.

    Returns:
        None; the result is written into ``x``.
    """
    # NOTE(review): H and f are defined outside this view; presumably the
    # Hubble rate and the growth rate -- confirm.
    vel_norm = dis_std * D(z) * H(z) * f(z) / (1 + z)  # [km/s]

    if not undo:
        # forward normalization divides by the scale
        vel_norm = 1 / vel_norm

    x *= vel_norm
def den(x, undo=False, **kwargs):
    """Placeholder for a density normalization; not implemented.

    The statements that used to follow the ``raise`` were unreachable (and
    would have divided by zero), so only the explicit failure is kept.

    Args:
        x: Field that would be normalized; unused.
        undo: Direction flag; unused.
        **kwargs: Extra keyword arguments; accepted for parity with the
            other norms and ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('density normalization is not implemented')
def D(z, Om=0.31): def D(z, Om=0.31):
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero """linear growth function for flat LambdaCDM, normalized to 1 at redshift zero

View File

@ -1,25 +1,25 @@
import torch import torch
def exp(x, undo=False, **kwargs):
    """Exponentiate ``x`` in place; take its natural log when undoing.

    Args:
        x: Tensor to transform; overwritten through ``out=x``.
        undo: If true, apply ``torch.log`` instead of ``torch.exp``.
        **kwargs: Extra keyword arguments; ignored so that the same set of
            misc kwargs can be passed to every norm.

    Returns:
        None; the result is written into ``x``.
    """
    transform = torch.log if undo else torch.exp
    transform(x, out=x)
def log(x, eps=1e-8, undo=False, **kwargs):
    """Take the natural log of ``x + eps`` in place; exponentiate to undo.

    Args:
        x: Tensor to transform; overwritten through ``out=x``.
        eps: Offset added before the logarithm in the forward direction.
        undo: If true, apply ``torch.exp`` instead. ``eps`` is not
            subtracted afterwards, so undo recovers ``x + eps``.
        **kwargs: Extra keyword arguments; ignored so that the same set of
            misc kwargs can be passed to every norm.

    Returns:
        None; the result is written into ``x``.
    """
    if undo:
        torch.exp(x, out=x)
    else:
        torch.log(x + eps, out=x)
def expm1(x, undo=False, **kwargs):
    """Apply ``exp(x) - 1`` to ``x`` in place; apply ``log(1 + x)`` to undo.

    Args:
        x: Tensor to transform; overwritten through ``out=x``.
        undo: If true, apply ``torch.log1p`` instead of ``torch.expm1``.
        **kwargs: Extra keyword arguments; ignored so that the same set of
            misc kwargs can be passed to every norm.

    Returns:
        None; the result is written into ``x``.
    """
    transform = torch.log1p if undo else torch.expm1
    transform(x, out=x)
def log1p(x, eps=1e-7, undo=False): def log1p(x, eps=1e-7, undo=False, **kwargs):
if not undo: if not undo:
torch.log1p(x + eps, out=x) torch.log1p(x + eps, out=x)
else: else:

View File

@ -1,8 +1,12 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..data.norms.cosmology import D
def lag2eul(*xs, rm_dis_mean=True, periodic=False):
def lag2eul(*xs, rm_dis_mean=True, periodic=False,
z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
**kwargs):
"""Transform fields from Lagrangian description to Eulerian description """Transform fields from Lagrangian description to Eulerian description
Only works for 3d fields, output same mesh size as input. Only works for 3d fields, output same mesh size as input.
@ -12,14 +16,15 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False):
latter from Lagrangian to Eulerian positions and then "paint" with CIC latter from Lagrangian to Eulerian positions and then "paint" with CIC
(trilinear) scheme. Use 1 if the latter is empty. (trilinear) scheme. Use 1 if the latter is empty.
Note that the box and mesh sizes don't have to be that of the inputs, as
long as their ratio gives the right resolution. One can therefore set them
to the values of the whole fields, and use smaller inputs.
Implementation follows pmesh/cic.py by Yu Feng. Implementation follows pmesh/cic.py by Yu Feng.
""" """
# FIXME for other redshift, box and mesh sizes # NOTE the following factor assumes normalized displacements
from ..data.norms.cosmology import D # and thus undoes it
z = 0 dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
Boxsize = 1000
Nmesh = 512
dis_norm = 6 * D(z) * Nmesh / Boxsize # to mesh unit
if any(x.dim() != 5 for x in xs): if any(x.dim() != 5 for x in xs):
raise NotImplementedError('only support 3d fields for now') raise NotImplementedError('only support 3d fields for now')

View File

@ -8,7 +8,8 @@ from .resample import Resampler
class G(nn.Module): class G(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16, def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512, cat_noise=False): chan_base=512, chan_min=64, chan_max=512, cat_noise=False,
**kwargs):
super().__init__() super().__init__()
self.scale_factor = scale_factor self.scale_factor = scale_factor
@ -137,7 +138,8 @@ class AddNoise(nn.Module):
class D(nn.Module): class D(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16, def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512): chan_base=512, chan_min=64, chan_max=512,
**kwargs):
super().__init__() super().__init__()
self.scale_factor = scale_factor self.scale_factor = scale_factor

View File

@ -34,6 +34,7 @@ def test(args):
in_pad=args.in_pad, in_pad=args.in_pad,
tgt_pad=args.tgt_pad, tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
**args.misc_kwargs,
) )
test_loader = DataLoader( test_loader = DataLoader(
test_dataset, test_dataset,
@ -47,7 +48,8 @@ def test(args):
out_chan = test_dataset.tgt_chan out_chan = test_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) model = model(style_size, sum(in_chan), sum(out_chan),
scale_factor=args.scale_factor, **args.misc_kwargs)
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at) criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
criterion = criterion() criterion = criterion()
@ -75,14 +77,14 @@ def test(args):
# start = 0 # start = 0
# for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): # for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
# norm = import_attr(norm, norms, callback_at=args.callback_at) # norm = import_attr(norm, norms, callback_at=args.callback_at)
# norm(input[:, start:stop], undo=True) # norm(input[:, start:stop], undo=True, **args.misc_kwargs)
# start = stop # start = stop
if args.tgt_norms is not None: if args.tgt_norms is not None:
start = 0 start = 0
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at) norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True) norm(output[:, start:stop], undo=True, **args.misc_kwargs)
#norm(target[:, start:stop], undo=True) #norm(target[:, start:stop], undo=True, **args.misc_kwargs)
start = stop start = stop
#test_dataset.assemble('_in', in_chan, input, #test_dataset.assemble('_in', in_chan, input,

View File

@ -76,6 +76,7 @@ def gpu_worker(local_rank, node, args):
in_pad=args.in_pad, in_pad=args.in_pad,
tgt_pad=args.tgt_pad, tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
**args.misc_kwargs,
) )
train_sampler = DistFieldSampler(train_dataset, shuffle=True, train_sampler = DistFieldSampler(train_dataset, shuffle=True,
div_data=args.div_data, div_data=args.div_data,
@ -108,6 +109,7 @@ def gpu_worker(local_rank, node, args):
in_pad=args.in_pad, in_pad=args.in_pad,
tgt_pad=args.tgt_pad, tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
**args.misc_kwargs,
) )
val_sampler = DistFieldSampler(val_dataset, shuffle=False, val_sampler = DistFieldSampler(val_dataset, shuffle=False,
div_data=args.div_data, div_data=args.div_data,
@ -127,7 +129,7 @@ def gpu_worker(local_rank, node, args):
model = import_attr(args.model, models, callback_at=args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(args.style_size, sum(args.in_chan), sum(args.out_chan), model = model(args.style_size, sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor) scale_factor=args.scale_factor, **args.misc_kwargs)
model.to(device) model.to(device)
model = DistributedDataParallel(model, device_ids=[device], model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group()) process_group=dist.new_group())
@ -314,16 +316,18 @@ def train(epoch, loader, model, criterion,
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1], eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
**args.misc_kwargs,
) )
logger.add_figure('fig/train', fig, global_step=epoch+1) logger.add_figure('fig/train', fig, global_step=epoch+1)
fig.clf() fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt']) #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
# **args.misc_kwargs)
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf() #fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True, #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt']) # label=['in', 'out', 'tgt'], **args.misc_kwargs)
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf() #fig.clf()
@ -378,16 +382,18 @@ def validate(epoch, loader, model, criterion, logger, device, args):
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1], eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
**args.misc_kwargs,
) )
logger.add_figure('fig/val', fig, global_step=epoch+1) logger.add_figure('fig/val', fig, global_step=epoch+1)
fig.clf() fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt']) #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
# **args.misc_kwargs)
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf() #fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True, #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt']) # label=['in', 'out', 'tgt'], **args.misc_kwargs)
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf() #fig.clf()

View File

@ -14,7 +14,7 @@ def quantize(x):
return 2 ** round(log2(x), ndigits=1) return 2 ** round(log2(x), ndigits=1)
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
"""Plot slices of fields of more than 2 spatial dimensions. """Plot slices of fields of more than 2 spatial dimensions.
Each field should have a channel dimension followed by spatial dimensions, Each field should have a channel dimension followed by spatial dimensions,
@ -122,7 +122,7 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
return fig return fig
def plt_power(*fields, l2e=False, label=None): def plt_power(*fields, l2e=False, label=None, **kwargs):
"""Plot power spectra of fields. """Plot power spectra of fields.
Each field should have batch and channel dimensions followed by spatial Each field should have batch and channel dimensions followed by spatial
@ -141,7 +141,7 @@ def plt_power(*fields, l2e=False, label=None):
with torch.no_grad(): with torch.no_grad():
if l2e: if l2e:
fields = lag2eul(*fields) fields = lag2eul(*fields, **kwargs)
ks, Ps = [], [] ks, Ps = [], []
for field in fields: for field in fields: