Add scale factor and pad to lag2eul Eulerian field

This commit is contained in:
Yin Li 2021-05-12 18:30:19 -04:00
parent 04d0bea17e
commit 6b5530256b
4 changed files with 89 additions and 49 deletions

View File

@ -1,67 +1,103 @@
import itertools
import torch
import torch.nn as nn
from ..data.norms.cosmology import D
def lag2eul(*xs, rm_dis_mean=True, periodic=False,
z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
def lag2eul(
dis,
val=1.0,
eul_scale_factor=1,
eul_pad=0,
rm_dis_mean=True,
periodic=False,
z=0.0,
dis_std=6.0,
boxsize=1000.,
meshsize=512,
**kwargs):
"""Transform fields from Lagrangian description to Eulerian description
Only works for 3d fields, output same mesh size as input.
Input of shape `(N, C, ...)` is first split into `(N, 3, ...)` and
`(N, C-3, ...)`. Take the former as the displacement field to map the
latter from Lagrangian to Eulerian positions and then "paint" with CIC
(trilinear) scheme. Use 1 if the latter is empty.
Use displacement fields `dis` to map the value fields `val` from Lagrangian
to Eulerian positions and then "paint" with CIC (trilinear) scheme.
Displacement and value fields are paired when are sequences of the same
length. If the displacement (value) field has only one entry, it is shared
by all the value (displacement) fields.
The Eulerian size is scaled by the `eul_scale_factor` and then padded by
the `eul_pad`.
Common mean displacement of all inputs can be removed to bring more
particles inside the box. Periodic boundary condition can be turned on.
Note that the box and mesh sizes don't have to be that of the inputs, as
long as their ratio gives the right resolution. One can therefore set them
to the values of the whole fields, and use smaller inputs.
to the values of the whole Lagrangian fields, and use smaller inputs.
Implementation follows pmesh/cic.py by Yu Feng.
"""
# NOTE the following factor assumes normalized displacements
# and thus undoes it
dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
dis_norm *= eul_scale_factor
if any(x.dim() != 5 for x in xs):
if isinstance(dis, torch.Tensor):
dis = [dis]
if isinstance(val, (float, torch.Tensor)):
val = [val]
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
raise ValueError('dis-val field mismatch')
if any(d.dim() != 5 for d in dis):
raise NotImplementedError('only support 3d fields for now')
if any(x.shape[1] < 3 for x in xs):
raise ValueError('displacement not available with <3 channels')
if any(d.shape[1] != 3 for d in dis):
raise ValueError('only support 3d displacement fields')
# common mean displacement of all inputs
# if removed, fewer particles go outside of the box
# common for all inputs so outputs are comparable in the same coords
dis_mean = 0
d_mean = 0
if rm_dis_mean:
dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
for x in xs) / len(xs)
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
for d in dis) / len(dis)
out = []
for x in xs:
N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
if len(dis) == 1 and len(val) != 1:
dis = itertools.repeat(dis[0])
if len(dis) != 1 and len(val) == 1:
val = itertools.repeat(val[0])
for d, v in zip(dis, val):
dtype, device = d.dtype, d.device
if Cin == 3:
Cout = 1
val = 1
N, DHW = d.shape[0], d.shape[2:]
DHW = torch.Size([s * eul_scale_factor + 2 * eul_pad for s in DHW])
if isinstance(v, float):
C = 1
else:
Cout = Cin - 3
val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
C = v.shape[1]
v = v.contiguous().flatten(start_dim=2).unsqueeze(-1)
pos = (x[:, :3] - dis_mean) * dis_norm
mesh = torch.zeros(N, C, *DHW, dtype=dtype, device=device)
pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
pos = (d - d_mean) * dis_norm
del d
pos = pos.contiguous().view(N, 3, -1, 1)
pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)[:, None, None]
pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)[:, None]
pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
intpos = pos.floor().to(torch.int)
neighbors = (torch.arange(8, device=x.device)
>> torch.arange(3, device=x.device)[:, None, None] ) & 1
neighbors = (
torch.arange(8, device=device)
>> torch.arange(3, device=device)[:, None, None]
) & 1
tgtpos = intpos + neighbors
del intpos, neighbors
@ -69,28 +105,28 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False,
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
del pos
val = val * kernel
v = v * kernel
del kernel
tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
val = val.view(N, Cout, -1)
v = v.view(N, C, -1)
for n in range(N): # because ind has variable length
bounds = torch.tensor(DHW, device=x.device)[:, None]
bounds = torch.tensor(DHW, device=device)[:, None]
if periodic:
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
) * DHW[2] + tgtpos[n, 2]
src = val[n]
src = v[n]
if not periodic:
mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
ind = ind[mask]
src = src[:, mask]
mesh[n].view(Cout, -1).index_add_(1, ind, src)
mesh[n].view(C, -1).index_add_(1, ind, src)
out.append(mesh)

View File

@ -1,7 +1,5 @@
import torch
from .lag2eul import lag2eul
def power(x):
"""Compute power spectra of input fields

View File

@ -269,7 +269,7 @@ def train(epoch, loader, model, criterion,
print('narrowed shape :', output.shape, flush=True)
lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt)
@ -326,8 +326,11 @@ def train(epoch, loader, model, criterion,
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
#fig = plt_power(1.0,
# dis=[input, lag_out, lag_tgt],
# label=['in', 'out', 'tgt'],
# **args.misc_kwargs,
#)
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf()
@ -358,7 +361,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
input, output, target = narrow_cast(input, output, target)
lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt)
@ -392,8 +395,11 @@ def validate(epoch, loader, model, criterion, logger, device, args):
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
#fig = plt_power(1.0,
# dis=[input, lag_out, lag_tgt],
# label=['in', 'out', 'tgt'],
# **args.misc_kwargs,
#)
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf()

View File

@ -122,26 +122,26 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
return fig
def plt_power(*fields, l2e=False, label=None, **kwargs):
def plt_power(*fields, dis=None, label=None, **kwargs):
"""Plot power spectra of fields.
Each field should have batch and channel dimensions followed by spatial
dimensions.
Optionally the field can be transformed by lag2eul first.
Optionally the field can be transformed by lag2eul first if given `dis`.
See `map2map.models.power`.
"""
plt.close('all')
if label is not None:
assert len(label) == len(fields)
assert len(label) == len(fields) or len(label) == len(dis)
else:
label = [None] * len(fields)
with torch.no_grad():
if l2e:
fields = lag2eul(*fields, **kwargs)
if dis:
fields = lag2eul(dis, val=fields, **kwargs)
ks, Ps = [], []
for field in fields: