Add scale factor and pad to lag2eul Eulerian field
This commit is contained in:
parent
04d0bea17e
commit
6b5530256b
@ -1,67 +1,103 @@
|
|||||||
|
import itertools
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from ..data.norms.cosmology import D
|
from ..data.norms.cosmology import D
|
||||||
|
|
||||||
|
def lag2eul(
|
||||||
def lag2eul(*xs, rm_dis_mean=True, periodic=False,
|
dis,
|
||||||
z=0.0, dis_std=6.0, boxsize=1000., meshsize=512,
|
val=1.0,
|
||||||
**kwargs):
|
eul_scale_factor=1,
|
||||||
|
eul_pad=0,
|
||||||
|
rm_dis_mean=True,
|
||||||
|
periodic=False,
|
||||||
|
z=0.0,
|
||||||
|
dis_std=6.0,
|
||||||
|
boxsize=1000.,
|
||||||
|
meshsize=512,
|
||||||
|
**kwargs):
|
||||||
"""Transform fields from Lagrangian description to Eulerian description
|
"""Transform fields from Lagrangian description to Eulerian description
|
||||||
|
|
||||||
Only works for 3d fields, output same mesh size as input.
|
Only works for 3d fields, output same mesh size as input.
|
||||||
|
|
||||||
Input of shape `(N, C, ...)` is first split into `(N, 3, ...)` and
|
Use displacement fields `dis` to map the value fields `val` from Lagrangian
|
||||||
`(N, C-3, ...)`. Take the former as the displacement field to map the
|
to Eulerian positions and then "paint" with CIC (trilinear) scheme.
|
||||||
latter from Lagrangian to Eulerian positions and then "paint" with CIC
|
Displacement and value fields are paired when are sequences of the same
|
||||||
(trilinear) scheme. Use 1 if the latter is empty.
|
length. If the displacement (value) field has only one entry, it is shared
|
||||||
|
by all the value (displacement) fields.
|
||||||
|
|
||||||
|
The Eulerian size is scaled by the `eul_scale_factor` and then padded by
|
||||||
|
the `eul_pad`.
|
||||||
|
|
||||||
|
Common mean displacement of all inputs can be removed to bring more
|
||||||
|
particles inside the box. Periodic boundary condition can be turned on.
|
||||||
|
|
||||||
Note that the box and mesh sizes don't have to be that of the inputs, as
|
Note that the box and mesh sizes don't have to be that of the inputs, as
|
||||||
long as their ratio gives the right resolution. One can therefore set them
|
long as their ratio gives the right resolution. One can therefore set them
|
||||||
to the values of the whole fields, and use smaller inputs.
|
to the values of the whole Lagrangian fields, and use smaller inputs.
|
||||||
|
|
||||||
Implementation follows pmesh/cic.py by Yu Feng.
|
Implementation follows pmesh/cic.py by Yu Feng.
|
||||||
"""
|
"""
|
||||||
# NOTE the following factor assumes normalized displacements
|
# NOTE the following factor assumes normalized displacements
|
||||||
# and thus undoes it
|
# and thus undoes it
|
||||||
dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
|
dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit
|
||||||
|
dis_norm *= eul_scale_factor
|
||||||
|
|
||||||
if any(x.dim() != 5 for x in xs):
|
if isinstance(dis, torch.Tensor):
|
||||||
|
dis = [dis]
|
||||||
|
if isinstance(val, (float, torch.Tensor)):
|
||||||
|
val = [val]
|
||||||
|
if len(dis) != len(val) and len(dis) != 1 and len(val) != 1:
|
||||||
|
raise ValueError('dis-val field mismatch')
|
||||||
|
|
||||||
|
if any(d.dim() != 5 for d in dis):
|
||||||
raise NotImplementedError('only support 3d fields for now')
|
raise NotImplementedError('only support 3d fields for now')
|
||||||
if any(x.shape[1] < 3 for x in xs):
|
if any(d.shape[1] != 3 for d in dis):
|
||||||
raise ValueError('displacement not available with <3 channels')
|
raise ValueError('only support 3d displacement fields')
|
||||||
|
|
||||||
# common mean displacement of all inputs
|
# common mean displacement of all inputs
|
||||||
# if removed, fewer particles go outside of the box
|
# if removed, fewer particles go outside of the box
|
||||||
# common for all inputs so outputs are comparable in the same coords
|
# common for all inputs so outputs are comparable in the same coords
|
||||||
dis_mean = 0
|
d_mean = 0
|
||||||
if rm_dis_mean:
|
if rm_dis_mean:
|
||||||
dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
|
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
|
||||||
for x in xs) / len(xs)
|
for d in dis) / len(dis)
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
for x in xs:
|
if len(dis) == 1 and len(val) != 1:
|
||||||
N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
|
dis = itertools.repeat(dis[0])
|
||||||
|
if len(dis) != 1 and len(val) == 1:
|
||||||
|
val = itertools.repeat(val[0])
|
||||||
|
for d, v in zip(dis, val):
|
||||||
|
dtype, device = d.dtype, d.device
|
||||||
|
|
||||||
if Cin == 3:
|
N, DHW = d.shape[0], d.shape[2:]
|
||||||
Cout = 1
|
DHW = torch.Size([s * eul_scale_factor + 2 * eul_pad for s in DHW])
|
||||||
val = 1
|
|
||||||
|
if isinstance(v, float):
|
||||||
|
C = 1
|
||||||
else:
|
else:
|
||||||
Cout = Cin - 3
|
C = v.shape[1]
|
||||||
val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
|
v = v.contiguous().flatten(start_dim=2).unsqueeze(-1)
|
||||||
mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
pos = (x[:, :3] - dis_mean) * dis_norm
|
mesh = torch.zeros(N, C, *DHW, dtype=dtype, device=device)
|
||||||
|
|
||||||
pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
|
pos = (d - d_mean) * dis_norm
|
||||||
pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
|
del d
|
||||||
pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
|
|
||||||
|
|
||||||
pos = pos.contiguous().view(N, 3, -1, 1)
|
pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor,
|
||||||
|
dtype=dtype, device=device)[:, None, None]
|
||||||
|
pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor,
|
||||||
|
dtype=dtype, device=device)[:, None]
|
||||||
|
pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor,
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
|
||||||
|
pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors
|
||||||
|
|
||||||
intpos = pos.floor().to(torch.int)
|
intpos = pos.floor().to(torch.int)
|
||||||
neighbors = (torch.arange(8, device=x.device)
|
neighbors = (
|
||||||
>> torch.arange(3, device=x.device)[:, None, None] ) & 1
|
torch.arange(8, device=device)
|
||||||
|
>> torch.arange(3, device=device)[:, None, None]
|
||||||
|
) & 1
|
||||||
tgtpos = intpos + neighbors
|
tgtpos = intpos + neighbors
|
||||||
del intpos, neighbors
|
del intpos, neighbors
|
||||||
|
|
||||||
@ -69,28 +105,28 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False,
|
|||||||
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
|
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
|
||||||
del pos
|
del pos
|
||||||
|
|
||||||
val = val * kernel
|
v = v * kernel
|
||||||
del kernel
|
del kernel
|
||||||
|
|
||||||
tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
|
tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
|
||||||
val = val.view(N, Cout, -1)
|
v = v.view(N, C, -1)
|
||||||
|
|
||||||
for n in range(N): # because ind has variable length
|
for n in range(N): # because ind has variable length
|
||||||
bounds = torch.tensor(DHW, device=x.device)[:, None]
|
bounds = torch.tensor(DHW, device=device)[:, None]
|
||||||
|
|
||||||
if periodic:
|
if periodic:
|
||||||
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
|
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
|
||||||
|
|
||||||
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
|
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
|
||||||
) * DHW[2] + tgtpos[n, 2]
|
) * DHW[2] + tgtpos[n, 2]
|
||||||
src = val[n]
|
src = v[n]
|
||||||
|
|
||||||
if not periodic:
|
if not periodic:
|
||||||
mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
|
mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
|
||||||
ind = ind[mask]
|
ind = ind[mask]
|
||||||
src = src[:, mask]
|
src = src[:, mask]
|
||||||
|
|
||||||
mesh[n].view(Cout, -1).index_add_(1, ind, src)
|
mesh[n].view(C, -1).index_add_(1, ind, src)
|
||||||
|
|
||||||
out.append(mesh)
|
out.append(mesh)
|
||||||
|
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .lag2eul import lag2eul
|
|
||||||
|
|
||||||
|
|
||||||
def power(x):
|
def power(x):
|
||||||
"""Compute power spectra of input fields
|
"""Compute power spectra of input fields
|
||||||
|
@ -269,7 +269,7 @@ def train(epoch, loader, model, criterion,
|
|||||||
print('narrowed shape :', output.shape, flush=True)
|
print('narrowed shape :', output.shape, flush=True)
|
||||||
|
|
||||||
lag_out, lag_tgt = output, target
|
lag_out, lag_tgt = output, target
|
||||||
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
|
||||||
|
|
||||||
lag_loss = criterion(lag_out, lag_tgt)
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
eul_loss = criterion(eul_out, eul_tgt)
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
@ -326,8 +326,11 @@ def train(epoch, loader, model, criterion,
|
|||||||
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
||||||
#fig.clf()
|
#fig.clf()
|
||||||
|
|
||||||
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
#fig = plt_power(1.0,
|
||||||
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
|
# dis=[input, lag_out, lag_tgt],
|
||||||
|
# label=['in', 'out', 'tgt'],
|
||||||
|
# **args.misc_kwargs,
|
||||||
|
#)
|
||||||
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||||
#fig.clf()
|
#fig.clf()
|
||||||
|
|
||||||
@ -358,7 +361,7 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
lag_out, lag_tgt = output, target
|
lag_out, lag_tgt = output, target
|
||||||
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
|
||||||
|
|
||||||
lag_loss = criterion(lag_out, lag_tgt)
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
eul_loss = criterion(eul_out, eul_tgt)
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
@ -392,8 +395,11 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
||||||
#fig.clf()
|
#fig.clf()
|
||||||
|
|
||||||
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
#fig = plt_power(1.0,
|
||||||
# label=['in', 'out', 'tgt'], **args.misc_kwargs)
|
# dis=[input, lag_out, lag_tgt],
|
||||||
|
# label=['in', 'out', 'tgt'],
|
||||||
|
# **args.misc_kwargs,
|
||||||
|
#)
|
||||||
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||||
#fig.clf()
|
#fig.clf()
|
||||||
|
|
||||||
|
@ -122,26 +122,26 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs):
|
|||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def plt_power(*fields, l2e=False, label=None, **kwargs):
|
def plt_power(*fields, dis=None, label=None, **kwargs):
|
||||||
"""Plot power spectra of fields.
|
"""Plot power spectra of fields.
|
||||||
|
|
||||||
Each field should have batch and channel dimensions followed by spatial
|
Each field should have batch and channel dimensions followed by spatial
|
||||||
dimensions.
|
dimensions.
|
||||||
|
|
||||||
Optionally the field can be transformed by lag2eul first.
|
Optionally the field can be transformed by lag2eul first if given `dis`.
|
||||||
|
|
||||||
See `map2map.models.power`.
|
See `map2map.models.power`.
|
||||||
"""
|
"""
|
||||||
plt.close('all')
|
plt.close('all')
|
||||||
|
|
||||||
if label is not None:
|
if label is not None:
|
||||||
assert len(label) == len(fields)
|
assert len(label) == len(fields) or len(label) == len(dis)
|
||||||
else:
|
else:
|
||||||
label = [None] * len(fields)
|
label = [None] * len(fields)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if l2e:
|
if dis:
|
||||||
fields = lag2eul(*fields, **kwargs)
|
fields = lag2eul(dis, val=fields, **kwargs)
|
||||||
|
|
||||||
ks, Ps = [], []
|
ks, Ps = [], []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
|
Loading…
Reference in New Issue
Block a user