Add scale factor and pad to lag2eul Eulerian field
parent 04d0bea17e
commit 6b5530256b
@@ -1,67 +1,103 @@
+import itertools
 import torch
 import torch.nn as nn
 
 from ..data.norms.cosmology import D
 
 
-def lag2eul(*xs, rm_dis_mean=True, periodic=False,
-            z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
+def lag2eul(
+        dis,
+        val=1.0,
+        eul_scale_factor=1,
+        eul_pad=0,
+        rm_dis_mean=True,
+        periodic=False,
+        z=0.0,
+        dis_std=6.0,
+        boxsize=1000.,
+        meshsize=512,
+        **kwargs):
     """Transform fields from Lagrangian description to Eulerian description
 
-    Only works for 3d fields, output same mesh size as input.
-
-    Input of shape `(N, C, ...)` is first split into `(N, 3, ...)` and
-    `(N, C-3, ...)`. Take the former as the displacement field to map the
-    latter from Lagrangian to Eulerian positions and then "paint" with CIC
-    (trilinear) scheme. Use 1 if the latter is empty.
+    Use displacement fields `dis` to map the value fields `val` from Lagrangian
+    to Eulerian positions and then "paint" with CIC (trilinear) scheme.
+    Displacement and value fields are paired when they are sequences of the
+    same length. If the displacement (value) field has only one entry, it is
+    shared by all the value (displacement) fields.
+
+    The Eulerian size is scaled by the `eul_scale_factor` and then padded by
+    the `eul_pad`.
 
     Common mean displacement of all inputs can be removed to bring more
     particles inside the box. Periodic boundary condition can be turned on.
 
     Note that the box and mesh sizes don't have to be that of the inputs, as
     long as their ratio gives the right resolution. One can therefore set them
-    to the values of the whole fields, and use smaller inputs.
+    to the values of the whole Lagrangian fields, and use smaller inputs.
 
     Implementation follows pmesh/cic.py by Yu Feng.
     """
     # NOTE the following factor assumes normalized displacements
     # and thus undoes it
     dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
+    dis_norm *= eul_scale_factor
 
-    if any(x.dim() != 5 for x in xs):
+    if isinstance(dis, torch.Tensor):
+        dis = [dis]
+    if isinstance(val, (float, torch.Tensor)):
+        val = [val]
+    if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
+        raise ValueError('dis-val field mismatch')
+
+    if any(d.dim() != 5 for d in dis):
         raise NotImplementedError('only support 3d fields for now')
-    if any(x.shape[1] < 3 for x in xs):
-        raise ValueError('displacement not available with <3 channels')
+    if any(d.shape[1] != 3 for d in dis):
+        raise ValueError('only support 3d displacement fields')
 
     # common mean displacement of all inputs
     # if removed, fewer particles go outside of the box
     # common for all inputs so outputs are comparable in the same coords
-    dis_mean = 0
+    d_mean = 0
     if rm_dis_mean:
-        dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
-                       for x in xs) / len(xs)
+        d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
+                     for d in dis) / len(dis)
 
     out = []
-    for x in xs:
-        N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
+    if len(dis) == 1 and len(val) != 1:
+        dis = itertools.repeat(dis[0])
+    if len(dis) != 1 and len(val) == 1:
+        val = itertools.repeat(val[0])
+    for d, v in zip(dis, val):
+        dtype, device = d.dtype, d.device
 
-        if Cin == 3:
-            Cout = 1
-            val = 1
+        N, DHW = d.shape[0], d.shape[2:]
+        DHW = torch.Size([s * eul_scale_factor + 2 * eul_pad for s in DHW])
+
+        if isinstance(v, float):
+            C = 1
         else:
-            Cout = Cin - 3
-            val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
-        mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
+            C = v.shape[1]
+            v = v.contiguous().flatten(start_dim=2).unsqueeze(-1)
 
-        pos = (x[:, :3] - dis_mean) * dis_norm
+        mesh = torch.zeros(N, C, *DHW, dtype=dtype, device=device)
 
-        pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
-        pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
-        pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
+        pos = (d - d_mean) * dis_norm
+        del d
 
-        pos = pos.contiguous().view(N, 3, -1, 1)
+        pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
+                                  dtype=dtype, device=device)[:, None, None]
+        pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
+                                  dtype=dtype, device=device)[:, None]
+        pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
+                                  dtype=dtype, device=device)
+
+        pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
 
         intpos = pos.floor().to(torch.int)
-        neighbors = (torch.arange(8, device=x.device)
-                     >> torch.arange(3, device=x.device)[:, None, None] ) & 1
+        neighbors = (
+            torch.arange(8, device=device)
+            >> torch.arange(3, device=device)[:, None, None]
+        ) & 1
         tgtpos = intpos + neighbors
         del intpos, neighbors
 
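A minimal usage sketch of the new interface, based only on the signature and docstring above; the import path, tensor shapes, and keyword values are illustrative assumptions, and the return value is taken to be the `out` list built in the loop.

import torch

from map2map.models.lag2eul import lag2eul  # import path assumed from the package layout

# Illustrative shapes only: batch of 1, 3 displacement channels, a 16^3 Lagrangian mesh.
dis = torch.randn(1, 3, 16, 16, 16)
val = torch.ones(1, 1, 16, 16, 16)

# One displacement field shared by two value fields, per the pairing rule in the docstring.
mesh_a, mesh_b = lag2eul([dis], val=[val, 2 * val],
                         eul_scale_factor=2, eul_pad=1,
                         boxsize=1000., meshsize=512)

# Each Eulerian mesh side is scaled and padded: 16 * eul_scale_factor + 2 * eul_pad = 34.
print(mesh_a.shape)  # expected: torch.Size([1, 1, 34, 34, 34])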
@@ -69,28 +105,28 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False,
         kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
         del pos
 
-        val = val * kernel
+        v = v * kernel
         del kernel
 
         tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
-        val = val.view(N, Cout, -1)
+        v = v.view(N, C, -1)
 
         for n in range(N): # because ind has variable length
-            bounds = torch.tensor(DHW, device=x.device)[:, None]
+            bounds = torch.tensor(DHW, device=device)[:, None]
 
             if periodic:
                 torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
 
             ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
                    ) * DHW[2] + tgtpos[n, 2]
-            src = val[n]
+            src = v[n]
 
             if not periodic:
                 mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
                 ind = ind[mask]
                 src = src[:, mask]
 
-            mesh[n].view(Cout, -1).index_add_(1, ind, src)
+            mesh[n].view(C, -1).index_add_(1, ind, src)
 
         out.append(mesh)
 
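The painting above enumerates the 8 CIC (trilinear) neighbor cells of each particle with a bit trick before weighting them with the kernel and scattering via `index_add_`. A small self-contained check of that expression, written with the same shift as in the function:

import torch

# Bit k of n in 0..7 selects the offset along axis k, so the 8 columns below are
# the corner offsets of a unit cube; broadcasting against intpos of shape
# (N, 3, -1, 1) then puts the 8 target cells per particle on the last axis.
neighbors = (
    torch.arange(8)
    >> torch.arange(3)[:, None, None]
) & 1
print(neighbors.squeeze(1))
# tensor([[0, 1, 0, 1, 0, 1, 0, 1],
#         [0, 0, 1, 1, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1, 1, 1]])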
@@ -1,7 +1,5 @@
 import torch
 
-from .lag2eul import lag2eul
-
 
 def power(x):
     """Compute power spectra of input fields
@@ -269,7 +269,7 @@ def train(epoch, loader, model, criterion,
             print('narrowed shape :', output.shape, flush=True)
 
         lag_out, lag_tgt = output, target
-        eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
+        eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
 
         lag_loss = criterion(lag_out, lag_tgt)
         eul_loss = criterion(eul_out, eul_tgt)
@@ -326,8 +326,11 @@ def train(epoch, loader, model, criterion,
             #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
             #fig.clf()
 
-            #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
-            #    label=['in', 'out', 'tgt'], **args.misc_kwargs)
+            #fig = plt_power(1.0,
+            #    dis=[input, lag_out, lag_tgt],
+            #    label=['in', 'out', 'tgt'],
+            #    **args.misc_kwargs,
+            #)
             #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
             #fig.clf()
 
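The training and validation loops now forward `**args.misc_kwargs` to `lag2eul`, so the new Eulerian options can be configured per run. A hypothetical sketch mirroring the changed call site; the actual contents of `args.misc_kwargs` depend on how a run is configured, and the names below are simply keywords from the `lag2eul` signature above.

import torch

from map2map.models.lag2eul import lag2eul  # import path assumed

# Hypothetical keyword arguments; all names come from the lag2eul signature.
misc_kwargs = dict(eul_scale_factor=2, eul_pad=1,
                   z=0.0, dis_std=6.0, boxsize=1000., meshsize=512)

# Stand-ins for the narrowed model output and target displacement fields.
lag_out = torch.randn(1, 3, 8, 8, 8)
lag_tgt = torch.randn(1, 3, 8, 8, 8)

# As at the changed call site: one shared mean displacement and equal-size
# meshes, so the two Eulerian fields can be compared by the same criterion.
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **misc_kwargs)
assert eul_out.shape == eul_tgt.shape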
@@ -358,7 +361,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
             input, output, target = narrow_cast(input, output, target)
 
             lag_out, lag_tgt = output, target
-            eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
+            eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
 
             lag_loss = criterion(lag_out, lag_tgt)
             eul_loss = criterion(eul_out, eul_tgt)
@@ -392,8 +395,11 @@ def validate(epoch, loader, model, criterion, logger, device, args):
             #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
             #fig.clf()
 
-            #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
-            #    label=['in', 'out', 'tgt'], **args.misc_kwargs)
+            #fig = plt_power(1.0,
+            #    dis=[input, lag_out, lag_tgt],
+            #    label=['in', 'out', 'tgt'],
+            #    **args.misc_kwargs,
+            #)
             #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
             #fig.clf()
 
@@ -122,26 +122,26 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
     return fig
 
 
-def plt_power(*fields, l2e=False, label=None, **kwargs):
+def plt_power(*fields, dis=None, label=None, **kwargs):
     """Plot power spectra of fields.
 
     Each field should have batch and channel dimensions followed by spatial
     dimensions.
 
-    Optionally the field can be transformed by lag2eul first.
+    Optionally the field can be transformed by lag2eul first if given `dis`.
 
     See `map2map.models.power`.
     """
     plt.close('all')
 
     if label is not None:
-        assert len(label) == len(fields)
+        assert len(label) == len(fields) or len(label) == len(dis)
     else:
         label = [None] * len(fields)
 
     with torch.no_grad():
-        if l2e:
-            fields = lag2eul(*fields, **kwargs)
+        if dis:
+            fields = lag2eul(dis, val=fields, **kwargs)
 
         ks, Ps = [], []
         for field in fields:
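With the new `dis` keyword, `plt_power` paints the fields to Eulerian space via `lag2eul` before measuring power. A sketch mirroring the commented-out call in the training script; the import path and the returned matplotlib figure are assumptions.

import torch

from map2map.utils.figures import plt_power  # import path assumed

dis_in = torch.randn(1, 3, 16, 16, 16)
dis_out = torch.randn(1, 3, 16, 16, 16)
dis_tgt = torch.randn(1, 3, 16, 16, 16)

# A single value field of 1.0 with a list of displacements plots the power
# spectra of the CIC-painted density meshes, one curve per label.
fig = plt_power(1.0,
                dis=[dis_in, dis_out, dis_tgt],
                label=['in', 'out', 'tgt'],
                boxsize=1000., meshsize=512)
fig.savefig('power_eul.png')

Note that the relaxed assertion in the diff is what allows three labels to accompany a single positional value field here.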