diff --git a/map2map/models/lag2eul.py b/map2map/models/lag2eul.py index 63913ba..97b4960 100644 --- a/map2map/models/lag2eul.py +++ b/map2map/models/lag2eul.py @@ -1,67 +1,103 @@ +import itertools import torch -import torch.nn as nn from ..data.norms.cosmology import D - -def lag2eul(*xs, rm_dis_mean=True, periodic=False, - z=0.0, dis_std=6.0, boxsize=1000., meshsize=512, - **kwargs): +def lag2eul( + dis, + val=1.0, + eul_scale_factor=1, + eul_pad=0, + rm_dis_mean=True, + periodic=False, + z=0.0, + dis_std=6.0, + boxsize=1000., + meshsize=512, + **kwargs): """Transform fields from Lagrangian description to Eulerian description Only works for 3d fields, output same mesh size as input. - Input of shape `(N, C, ...)` is first split into `(N, 3, ...)` and - `(N, C-3, ...)`. Take the former as the displacement field to map the - latter from Lagrangian to Eulerian positions and then "paint" with CIC - (trilinear) scheme. Use 1 if the latter is empty. + Use displacement fields `dis` to map the value fields `val` from Lagrangian + to Eulerian positions and then "paint" with CIC (trilinear) scheme. + Displacement and value fields are paired when are sequences of the same + length. If the displacement (value) field has only one entry, it is shared + by all the value (displacement) fields. + + The Eulerian size is scaled by the `eul_scale_factor` and then padded by + the `eul_pad`. + + Common mean displacement of all inputs can be removed to bring more + particles inside the box. Periodic boundary condition can be turned on. Note that the box and mesh sizes don't have to be that of the inputs, as long as their ratio gives the right resolution. One can therefore set them - to the values of the whole fields, and use smaller inputs. + to the values of the whole Lagrangian fields, and use smaller inputs. Implementation follows pmesh/cic.py by Yu Feng. """ # NOTE the following factor assumes normalized displacements # and thus undoes it dis_norm = dis_std * D(z) * meshsize / boxsize # to mesh unit + dis_norm *= eul_scale_factor - if any(x.dim() != 5 for x in xs): + if isinstance(dis, torch.Tensor): + dis = [dis] + if isinstance(val, (float, torch.Tensor)): + val = [val] + if len(dis) != len(val) and len(dis) != 1 and len(val) != 1: + raise ValueError('dis-val field mismatch') + + if any(d.dim() != 5 for d in dis): raise NotImplementedError('only support 3d fields for now') - if any(x.shape[1] < 3 for x in xs): - raise ValueError('displacement not available with <3 channels') + if any(d.shape[1] != 3 for d in dis): + raise ValueError('only support 3d displacement fields') # common mean displacement of all inputs # if removed, fewer particles go outside of the box # common for all inputs so outputs are comparable in the same coords - dis_mean = 0 + d_mean = 0 if rm_dis_mean: - dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True) - for x in xs) / len(xs) + d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) + for d in dis) / len(dis) out = [] - for x in xs: - N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:] + if len(dis) == 1 and len(val) != 1: + dis = itertools.repeat(dis[0]) + if len(dis) != 1 and len(val) == 1: + val = itertools.repeat(val[0]) + for d, v in zip(dis, val): + dtype, device = d.dtype, d.device - if Cin == 3: - Cout = 1 - val = 1 + N, DHW = d.shape[0], d.shape[2:] + DHW = torch.Size([s * eul_scale_factor + 2 * eul_pad for s in DHW]) + + if isinstance(v, float): + C = 1 else: - Cout = Cin - 3 - val = x[:, 3:].contiguous().view(N, Cout, -1, 1) - mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device) + C = v.shape[1] + v = v.contiguous().flatten(start_dim=2).unsqueeze(-1) - pos = (x[:, :3] - dis_mean) * dis_norm + mesh = torch.zeros(N, C, *DHW, dtype=dtype, device=device) - pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None] - pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None] - pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device) + pos = (d - d_mean) * dis_norm + del d - pos = pos.contiguous().view(N, 3, -1, 1) + pos[:, 0] += torch.arange(0, DHW[0] - 2 * eul_pad, eul_scale_factor, + dtype=dtype, device=device)[:, None, None] + pos[:, 1] += torch.arange(0, DHW[1] - 2 * eul_pad, eul_scale_factor, + dtype=dtype, device=device)[:, None] + pos[:, 2] += torch.arange(0, DHW[2] - 2 * eul_pad, eul_scale_factor, + dtype=dtype, device=device) + + pos = pos.contiguous().view(N, 3, -1, 1) # last axis for neighbors intpos = pos.floor().to(torch.int) - neighbors = (torch.arange(8, device=x.device) - >> torch.arange(3, device=x.device)[:, None, None] ) & 1 + neighbors = ( + torch.arange(8, device=device) + >> torch.arange(3, device=device)[:, None, None] + ) & 1 tgtpos = intpos + neighbors del intpos, neighbors @@ -69,28 +105,28 @@ def lag2eul(*xs, rm_dis_mean=True, periodic=False, kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True) del pos - val = val * kernel + v = v * kernel del kernel tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes - val = val.view(N, Cout, -1) + v = v.view(N, C, -1) for n in range(N): # because ind has variable length - bounds = torch.tensor(DHW, device=x.device)[:, None] + bounds = torch.tensor(DHW, device=device)[:, None] if periodic: torch.remainder(tgtpos[n], bounds, out=tgtpos[n]) ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1] ) * DHW[2] + tgtpos[n, 2] - src = val[n] + src = v[n] if not periodic: mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0) ind = ind[mask] src = src[:, mask] - mesh[n].view(Cout, -1).index_add_(1, ind, src) + mesh[n].view(C, -1).index_add_(1, ind, src) out.append(mesh) diff --git a/map2map/models/power.py b/map2map/models/power.py index 81fb7ac..7aaf249 100644 --- a/map2map/models/power.py +++ b/map2map/models/power.py @@ -1,7 +1,5 @@ import torch -from .lag2eul import lag2eul - def power(x): """Compute power spectra of input fields diff --git a/map2map/train.py b/map2map/train.py index 6307649..2388f7a 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -269,7 +269,7 @@ def train(epoch, loader, model, criterion, print('narrowed shape :', output.shape, flush=True) lag_out, lag_tgt = output, target - eul_out, eul_tgt = lag2eul(lag_out, lag_tgt) + eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) lag_loss = criterion(lag_out, lag_tgt) eul_loss = criterion(eul_out, eul_tgt) @@ -326,8 +326,11 @@ def train(epoch, loader, model, criterion, #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) #fig.clf() - #fig = plt_power(input, lag_out, lag_tgt, l2e=True, - # label=['in', 'out', 'tgt'], **args.misc_kwargs) + #fig = plt_power(1.0, + # dis=[input, lag_out, lag_tgt], + # label=['in', 'out', 'tgt'], + # **args.misc_kwargs, + #) #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) #fig.clf() @@ -358,7 +361,7 @@ def validate(epoch, loader, model, criterion, logger, device, args): input, output, target = narrow_cast(input, output, target) lag_out, lag_tgt = output, target - eul_out, eul_tgt = lag2eul(lag_out, lag_tgt) + eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) lag_loss = criterion(lag_out, lag_tgt) eul_loss = criterion(eul_out, eul_tgt) @@ -392,8 +395,11 @@ def validate(epoch, loader, model, criterion, logger, device, args): #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) #fig.clf() - #fig = plt_power(input, lag_out, lag_tgt, l2e=True, - # label=['in', 'out', 'tgt'], **args.misc_kwargs) + #fig = plt_power(1.0, + # dis=[input, lag_out, lag_tgt], + # label=['in', 'out', 'tgt'], + # **args.misc_kwargs, + #) #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) #fig.clf() diff --git a/map2map/utils/figures.py b/map2map/utils/figures.py index 500e1f8..7e98d18 100644 --- a/map2map/utils/figures.py +++ b/map2map/utils/figures.py @@ -122,26 +122,26 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None, **kwargs): return fig -def plt_power(*fields, l2e=False, label=None, **kwargs): +def plt_power(*fields, dis=None, label=None, **kwargs): """Plot power spectra of fields. Each field should have batch and channel dimensions followed by spatial dimensions. - Optionally the field can be transformed by lag2eul first. + Optionally the field can be transformed by lag2eul first if given `dis`. See `map2map.models.power`. """ plt.close('all') if label is not None: - assert len(label) == len(fields) + assert len(label) == len(fields) or len(label) == len(dis) else: label = [None] * len(fields) with torch.no_grad(): - if l2e: - fields = lag2eul(*fields, **kwargs) + if dis: + fields = lag2eul(dis, val=fields, **kwargs) ks, Ps = [], [] for field in fields: