Add scale factor and pad to lag2eul Eulerian field

This commit is contained in:
Yin Li 2021-05-12 18:30:19 -04:00
parent 04d0bea17e
commit 6b5530256b
4 changed files with 89 additions and 49 deletions

View File

@ -1,67 +1,103 @@
import itertools
import torch import torch
import torch.nn as nn
from ..data.norms.cosmology import D from ..data.norms.cosmology import D
def lag2eul(
def lag2eul(*xs, rm_dis_mean=True, periodic=False, dis,
z=0.0, dis_std=6.0, boxsize=1000., meshsize=512, val=1.0,
eul_scale_factor=1,
eul_pad=0,
rm_dis_mean=True,
periodic=False,
z=0.0,
dis_std=6.0,
boxsize=1000.,
meshsize=512,
**kwargs): **kwargs):
"""Transform fields from Lagrangian description to Eulerian description """Transform fields from Lagrangian description to Eulerian description
Only works for 3d fields, output same mesh size as input. Only works for 3d fields, output same mesh size as input.
Input of shape `(N, C, ...)` is first split into `(N, 3, ...)` and Use displacement fields `dis` to map the value fields `val` from Lagrangian
`(N, C-3, ...)`. Take the former as the displacement field to map the to Eulerian positions and then "paint" with CIC (trilinear) scheme.
latter from Lagrangian to Eulerian positions and then "paint" with CIC Displacement and value fields are paired when are sequences of the same
(trilinear) scheme. Use 1 if the latter is empty. length. If the displacement (value) field has only one entry, it is shared
by all the value (displacement) fields.
The Eulerian size is scaled by the `eul_scale_factor` and then padded by
the `eul_pad`.
Common mean displacement of all inputs can be removed to bring more
particles inside the box. Periodic boundary condition can be turned on.
Note that the box and mesh sizes don't have to be that of the inputs, as Note that the box and mesh sizes don't have to be that of the inputs, as
long as their ratio gives the right resolution. One can therefore set them long as their ratio gives the right resolution. One can therefore set them
to the values of the whole fields, and use smaller inputs. to the values of the whole Lagrangian fields, and use smaller inputs.
Implementation follows pmesh/cic.py by Yu Feng. Implementation follows pmesh/cic.py by Yu Feng.
""" """
# NOTE the following factor assumes normalized displacements # NOTE the following factor assumes normalized displacements
# and thus undoes it # and thus undoes it
dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
dis_norm *= eul_scale_factor
if any(x.dim() != 5 for x in xs): if isinstance(dis, torch.Tensor):
dis = [dis]
if isinstance(val, (float, torch.Tensor)):
val = [val]
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
raise ValueError('dis-val field mismatch')
if any(d.dim() != 5 for d in dis):
raise NotImplementedError('only support 3d fields for now') raise NotImplementedError('only support 3d fields for now')
if any(x.shape[1] < 3 for x in xs): if any(d.shape[1] != 3 for d in dis):
raise ValueError('displacement not available with <3 channels') raise ValueError('only support 3d displacement fields')
# common mean displacement of all inputs # common mean displacement of all inputs
# if removed, fewer particles go outside of the box # if removed, fewer particles go outside of the box
# common for all inputs so outputs are comparable in the same coords # common for all inputs so outputs are comparable in the same coords
dis_mean = 0 d_mean = 0
if rm_dis_mean: if rm_dis_mean:
dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True) d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
for x in xs) / len(xs) for d in dis) / len(dis)
out = [] out = []
for x in xs: if len(dis) == 1 and len(val) != 1:
N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:] dis = itertools.repeat(dis[0])
if len(dis) != 1 and len(val) == 1:
val = itertools.repeat(val[0])
for d, v in zip(dis, val):
dtype, device = d.dtype, d.device
if Cin == 3: N, DHW = d.shape[0], d.shape[2:]
Cout = 1 DHW = torch.Size([s * eul_scale_factor + 2 * eul_pad for s in DHW])
val = 1
if isinstance(v, float):
C = 1
else: else:
Cout = Cin - 3 C = v.shape[1]
val = x[:, 3:].contiguous().view(N, Cout, -1, 1) v = v.contiguous().flatten(start_dim=2).unsqueeze(-1)
mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
pos = (x[:, :3] - dis_mean) * dis_norm mesh = torch.zeros(N, C, *DHW, dtype=dtype, device=device)
pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None] pos = (d - d_mean) * dis_norm
pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None] del d
pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
pos = pos.contiguous().view(N, 3, -1, 1) pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)[:, None, None]
pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)[:, None]
pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
dtype=dtype, device=device)
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
intpos = pos.floor().to(torch.int) intpos = pos.floor().to(torch.int)
neighbors = (torch.arange(8, device=x.device) neighbors = (
>> torch.arange(3, device=x.device)[:, None, None] ) & 1 torch.arange(8, device=device)
>> torch.arange(3, device=device)[:, None, None]
) & 1
tgtpos = intpos + neighbors tgtpos = intpos + neighbors
del intpos, neighbors del intpos, neighbors
@ -69,28 +105,28 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False,
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True) kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
del pos del pos
val = val * kernel v = v * kernel
del kernel del kernel
tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
val = val.view(N, Cout, -1) v = v.view(N, C, -1)
for n in range(N): # because ind has variable length for n in range(N): # because ind has variable length
bounds = torch.tensor(DHW, device=x.device)[:, None] bounds = torch.tensor(DHW, device=device)[:, None]
if periodic: if periodic:
torch.remainder(tgtpos[n], bounds, out=tgtpos[n]) torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1] ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
) * DHW[2] + tgtpos[n, 2] ) * DHW[2] + tgtpos[n, 2]
src = val[n] src = v[n]
if not periodic: if not periodic:
mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0) mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
ind = ind[mask] ind = ind[mask]
src = src[:, mask] src = src[:, mask]
mesh[n].view(Cout, -1).index_add_(1, ind, src) mesh[n].view(C, -1).index_add_(1, ind, src)
out.append(mesh) out.append(mesh)

View File

@ -1,7 +1,5 @@
import torch import torch
from .lag2eul import lag2eul
def power(x): def power(x):
"""Compute power spectra of input fields """Compute power spectra of input fields

View File

@ -269,7 +269,7 @@ def train(epoch, loader, model, criterion,
print('narrowed shape :', output.shape, flush=True) print('narrowed shape :', output.shape, flush=True)
lag_out, lag_tgt = output, target lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt) eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
lag_loss = criterion(lag_out, lag_tgt) lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt) eul_loss = criterion(eul_out, eul_tgt)
@ -326,8 +326,11 @@ def train(epoch, loader, model, criterion,
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf() #fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True, #fig = plt_power(1.0,
# label=['in', 'out', 'tgt'], **args.misc_kwargs) # dis=[input, lag_out, lag_tgt],
# label=['in', 'out', 'tgt'],
# **args.misc_kwargs,
#)
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf() #fig.clf()
@ -358,7 +361,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
lag_out, lag_tgt = output, target lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt) eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
lag_loss = criterion(lag_out, lag_tgt) lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt) eul_loss = criterion(eul_out, eul_tgt)
@ -392,8 +395,11 @@ def validate(epoch, loader, model, criterion, logger, device, args):
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf() #fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True, #fig = plt_power(1.0,
# label=['in', 'out', 'tgt'], **args.misc_kwargs) # dis=[input, lag_out, lag_tgt],
# label=['in', 'out', 'tgt'],
# **args.misc_kwargs,
#)
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf() #fig.clf()

View File

@ -122,26 +122,26 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
return fig return fig
def plt_power(*fields, l2e=False, label=None, **kwargs): def plt_power(*fields, dis=None, label=None, **kwargs):
"""Plot power spectra of fields. """Plot power spectra of fields.
Each field should have batch and channel dimensions followed by spatial Each field should have batch and channel dimensions followed by spatial
dimensions. dimensions.
Optionally the field can be transformed by lag2eul first. Optionally the field can be transformed by lag2eul first if given `dis`.
See `map2map.models.power`. See `map2map.models.power`.
""" """
plt.close('all') plt.close('all')
if label is not None: if label is not None:
assert len(label) == len(fields) assert len(label) == len(fields) or len(label) == len(dis)
else: else:
label = [None] * len(fields) label = [None] * len(fields)
with torch.no_grad(): with torch.no_grad():
if l2e: if dis:
fields = lag2eul(*fields, **kwargs) fields = lag2eul(dis, val=fields, **kwargs)
ks, Ps = [], [] ks, Ps = [], []
for field in fields: for field in fields: