From 670364e54cbf06b8a293af54f36426dd83a9e963 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Sat, 22 Aug 2020 23:25:08 -0400
Subject: [PATCH] Change Lag2Eul to lag2eul as a function

---
 map2map/models/__init__.py |   2 +-
 map2map/models/lag2eul.py  | 120 ++++++++++++++++++-------------
 map2map/train.py           |  18 +++---
 3 files changed, 68 insertions(+), 72 deletions(-)

diff --git a/map2map/models/__init__.py b/map2map/models/__init__.py
index c3fbd64..eca6efc 100644
--- a/map2map/models/__init__.py
+++ b/map2map/models/__init__.py
@@ -5,7 +5,7 @@ from .patchgan import PatchGAN, PatchGAN42
 
 from .narrow import narrow_by, narrow_cast, narrow_like
 from .resample import resample, Resampler
-from .lag2eul import Lag2Eul
+from .lag2eul import lag2eul
 from .power import power
 
 from .dice import DiceLoss, dice_loss
diff --git a/map2map/models/lag2eul.py b/map2map/models/lag2eul.py
index 960e227..89a2738 100644
--- a/map2map/models/lag2eul.py
+++ b/map2map/models/lag2eul.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 
-class Lag2Eul(nn.Module):
+def lag2eul(*xs, rm_dis_mean=True, periodic=False):
     """Transform fields from Lagrangian description to Eulerian description
 
     Only works for 3d fields, output same mesh size as input.
@@ -14,81 +14,79 @@ class Lag2Eul(nn.Module):
     Implementation follows pmesh/cic.py by Yu Feng.
     """
 
-    def __init__(self):
-        super().__init__()
+    # FIXME for other redshift, box and mesh sizes
+    from ..data.norms.cosmology import D
+    z = 0
+    Boxsize = 1000
+    Nmesh = 512
+    dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
 
-        # FIXME for other redshift, box and mesh sizes
-        from ..data.norms.cosmology import D
-        z = 0
-        Boxsize = 1000
-        Nmesh = 512
-        self.dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
+    if any(x.dim() != 5 for x in xs):
+        raise NotImplementedError('only support 3d fields for now')
+    if any(x.shape[1] < 3 for x in xs):
+        raise ValueError('displacement not available with <3 channels')
 
-    def forward(self, *xs, rm_dis_mean=True, periodic=False):
-        if any(x.shape[1] < 3 for x in xs):
-            raise ValueError('displacement not available with <3 channels')
+    # common mean displacement of all inputs
+    # if removed, fewer particles go outside of the box
+    # common for all inputs so outputs are comparable in the same coords
+    dis_mean = 0
+    if rm_dis_mean:
+        dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
+                       for x in xs) / len(xs)
 
-        # common mean displacement of all inputs
-        # if removed, fewer particles go outside of the box
-        # common for all inputs so outputs are comparable in the same coords
-        dis_mean = 0
-        if rm_dis_mean:
-            dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
-                           for x in xs) / len(xs)
+    out = []
+    for x in xs:
+        N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
 
-        out = []
-        for x in xs:
-            N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
+        if Cin == 3:
+            Cout = 1
+            val = 1
+        else:
+            Cout = Cin - 3
+            val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
+        mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
 
-            if Cin == 3:
-                Cout = 1
-                val = 1
-            else:
-                Cout = Cin - 3
-                val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
-            mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
+        pos = (x[:, :3] - dis_mean) * dis_norm
 
-            pos = (x[:, :3] - dis_mean) * self.dis_norm
+        pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
+        pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
+        pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
 
-            pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
-            pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
-            pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
+        pos = pos.contiguous().view(N, 3, -1, 1)
 
-            pos = pos.contiguous().view(N, 3, -1, 1)
+        intpos = pos.floor().to(torch.int)
+        neighbors = (torch.arange(8, device=x.device)
+                     >> torch.arange(3, device=x.device)[:, None, None] ) & 1
+        tgtpos = intpos + neighbors
+        del intpos, neighbors
 
-            intpos = pos.floor().to(torch.int)
-            neighbors = (torch.arange(8, device=x.device)
-                         >> torch.arange(3, device=x.device)[:, None, None] ) & 1
-            tgtpos = intpos + neighbors
-            del intpos, neighbors
+        # CIC
+        kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
+        del pos
 
-            # CIC
-            kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
-            del pos
+        val = val * kernel
+        del kernel
 
-            val = val * kernel
-            del kernel
+        tgtpos = tgtpos.view(N, 3, -1)  # fuse spatial and neighbor axes
+        val = val.view(N, Cout, -1)
 
-            tgtpos = tgtpos.view(N, 3, -1)  # fuse spatial and neighbor axes
-            val = val.view(N, Cout, -1)
+        for n in range(N):  # because ind has variable length
+            bounds = torch.tensor(DHW, device=x.device)[:, None]
 
-            for n in range(N):  # because ind has variable length
-                bounds = torch.tensor(DHW, device=x.device)[:, None]
+            if periodic:
+                torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
 
-                if periodic:
-                    torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
+            ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
+                   ) * DHW[2] + tgtpos[n, 2]
+            src = val[n]
 
-                ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
-                       ) * DHW[2] + tgtpos[n, 2]
-                src = val[n]
+            if not periodic:
+                mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
+                ind = ind[mask]
+                src = src[:, mask]
 
-                if not periodic:
-                    mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
-                    ind = ind[mask]
-                    src = src[:, mask]
+            mesh[n].view(Cout, -1).index_add_(1, ind, src)
 
-                mesh[n].view(Cout, -1).index_add_(1, ind, src)
+        out.append(mesh)
 
-            out.append(mesh)
-
-        return out
+    return out
diff --git a/map2map/train.py b/map2map/train.py
index 1291d4d..e88a0bb 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -16,7 +16,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from .data import FieldDataset, DistFieldSampler
 from .data.figures import plt_slices, plt_power
 from . import models
-from .models import narrow_cast, resample, Lag2Eul
+from .models import narrow_cast, resample, lag2eul
 from .utils import import_attr, load_model_state_dict
 
@@ -126,8 +126,6 @@ def gpu_worker(local_rank, node, args):
     model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())
 
-    lag2eul = Lag2Eul()
-
     criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
     criterion = criterion()
     criterion.to(device)
@@ -193,13 +191,13 @@ def gpu_worker(local_rank, node, args):
     for epoch in range(start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
 
-        train_loss = train(epoch, train_loader, model, lag2eul, criterion,
-                optimizer, scheduler, logger, device, args)
+        train_loss = train(epoch, train_loader, model, criterion,
+                optimizer, scheduler, logger, device, args)
         epoch_loss = train_loss
 
         if args.val:
-            val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
-                    logger, device, args)
+            val_loss = validate(epoch, val_loader, model, criterion,
+                    logger, device, args)
             #epoch_loss = val_loss
 
         if args.reduce_lr_on_plateau:
@@ -229,8 +227,8 @@ def gpu_worker(local_rank, node, args):
     dist.destroy_process_group()
 
 
-def train(epoch, loader, model, lag2eul, criterion,
-        optimizer, scheduler, logger, device, args):
+def train(epoch, loader, model, criterion,
+        optimizer, scheduler, logger, device, args):
     model.train()
 
     rank = dist.get_rank()
@@ -321,7 +319,7 @@ def train(epoch, loader, model, lag2eul, criterion,
     return epoch_loss
 
 
-def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
+def validate(epoch, loader, model, criterion, logger, device, args):
     model.eval()
 
     rank = dist.get_rank()
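
Usage note (editor's addition, not part of the commit): a minimal sketch of how a
call site changes with this patch. The tensor shapes, the random test inputs, and
the variable names eul_out/eul_tgt below are illustrative assumptions, not code
from the repository; only lag2eul and its signature come from the diff above.

    import torch
    from map2map.models import lag2eul

    # hypothetical inputs: 5-d batches of shape (N, C, D, H, W) whose first
    # 3 channels are the Lagrangian displacement field, in normalized units
    out = torch.randn(2, 3, 32, 32, 32)
    tgt = torch.randn(2, 3, 32, 32, 32)

    # before this patch: construct the module, then call it
    #     lag2eul = Lag2Eul()
    #     eul_out, eul_tgt = lag2eul(out, tgt)
    # after this patch: call the plain function directly
    eul_out, eul_tgt = lag2eul(out, tgt)  # one Eulerian mesh per input

    # with Cin == 3 channels, each mesh has Cout == 1 channel: the CIC mass
    # deposited by particles displaced off a uniform grid; both meshes share
    # the same coordinates because the removed mean displacement is common
    assert eul_out.shape == (2, 1, 32, 32, 32)

The module held no learnable parameters or buffers worth registering, so the
commit replaces it with a plain function and drops the lag2eul argument that
was previously threaded through train() and validate().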