Change Lag2Eul to lag2eul as a function

Yin Li 2020-08-22 23:25:08 -04:00
parent 3eb1b0bccc
commit 670364e54c
3 changed files with 68 additions and 72 deletions


@@ -5,7 +5,7 @@ from .patchgan import PatchGAN, PatchGAN42
 from .narrow import narrow_by, narrow_cast, narrow_like
 from .resample import resample, Resampler
-from .lag2eul import Lag2Eul
+from .lag2eul import lag2eul
 from .power import power
 from .dice import DiceLoss, dice_loss


@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 
-class Lag2Eul(nn.Module):
+def lag2eul(*xs, rm_dis_mean=True, periodic=False):
     """Transform fields from Lagrangian description to Eulerian description
 
     Only works for 3d fields, output same mesh size as input.
@@ -14,81 +14,79 @@ class Lag2Eul(nn.Module):
     Implementation follows pmesh/cic.py by Yu Feng.
     """
-    def __init__(self):
-        super().__init__()
-
-        # FIXME for other redshift, box and mesh sizes
-        from ..data.norms.cosmology import D
-        z = 0
-        Boxsize = 1000
-        Nmesh = 512
-        self.dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
-
-    def forward(self, *xs, rm_dis_mean=True, periodic=False):
-        if any(x.shape[1] < 3 for x in xs):
-            raise ValueError('displacement not available with <3 channels')
+    # FIXME for other redshift, box and mesh sizes
+    from ..data.norms.cosmology import D
+    z = 0
+    Boxsize = 1000
+    Nmesh = 512
+    dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
+
+    if any(x.dim() != 5 for x in xs):
+        raise NotImplementedError('only support 3d fields for now')
+    if any(x.shape[1] < 3 for x in xs):
+        raise ValueError('displacement not available with <3 channels')
 
-        # common mean displacement of all inputs
-        # if removed, fewer particles go outside of the box
-        # common for all inputs so outputs are comparable in the same coords
-        dis_mean = 0
-        if rm_dis_mean:
-            dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
-                           for x in xs) / len(xs)
+    # common mean displacement of all inputs
+    # if removed, fewer particles go outside of the box
+    # common for all inputs so outputs are comparable in the same coords
+    dis_mean = 0
+    if rm_dis_mean:
+        dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
+                       for x in xs) / len(xs)
 
-        out = []
-        for x in xs:
-            N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
+    out = []
+    for x in xs:
+        N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
 
-            if Cin == 3:
-                Cout = 1
-                val = 1
-            else:
-                Cout = Cin - 3
-                val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
+        if Cin == 3:
+            Cout = 1
+            val = 1
+        else:
+            Cout = Cin - 3
+            val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
 
-            mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
+        mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
 
-            pos = (x[:, :3] - dis_mean) * self.dis_norm
+        pos = (x[:, :3] - dis_mean) * dis_norm
 
-            pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
-            pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
-            pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
+        pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
+        pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
+        pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
 
-            pos = pos.contiguous().view(N, 3, -1, 1)
+        pos = pos.contiguous().view(N, 3, -1, 1)
 
-            intpos = pos.floor().to(torch.int)
-            neighbors = (torch.arange(8, device=x.device)
-                         >> torch.arange(3, device=x.device)[:, None, None]) & 1
-            tgtpos = intpos + neighbors
-            del intpos, neighbors
+        intpos = pos.floor().to(torch.int)
+        neighbors = (torch.arange(8, device=x.device)
+                     >> torch.arange(3, device=x.device)[:, None, None]) & 1
+        tgtpos = intpos + neighbors
+        del intpos, neighbors
 
-            # CIC
-            kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
-            del pos
+        # CIC
+        kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
+        del pos
 
-            val = val * kernel
-            del kernel
+        val = val * kernel
+        del kernel
 
-            tgtpos = tgtpos.view(N, 3, -1)  # fuse spatial and neighbor axes
-            val = val.view(N, Cout, -1)
+        tgtpos = tgtpos.view(N, 3, -1)  # fuse spatial and neighbor axes
+        val = val.view(N, Cout, -1)
 
-            for n in range(N):  # because ind has variable length
-                bounds = torch.tensor(DHW, device=x.device)[:, None]
+        for n in range(N):  # because ind has variable length
+            bounds = torch.tensor(DHW, device=x.device)[:, None]
 
-                if periodic:
-                    torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
+            if periodic:
+                torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
 
-                ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
-                       ) * DHW[2] + tgtpos[n, 2]
-                src = val[n]
+            ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
+                   ) * DHW[2] + tgtpos[n, 2]
+            src = val[n]
 
-                if not periodic:
-                    mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
-                    ind = ind[mask]
-                    src = src[:, mask]
+            if not periodic:
+                mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
+                ind = ind[mask]
+                src = src[:, mask]
 
-                mesh[n].view(Cout, -1).index_add_(1, ind, src)
+            mesh[n].view(Cout, -1).index_add_(1, ind, src)
 
-            out.append(mesh)
+        out.append(mesh)
 
-        return out
+    return out
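
The trickiest steps above are the corner enumeration and the CIC kernel. Below is a minimal, self-contained sketch of just those two steps; the toy particle position is illustrative, not from the commit.

import torch

# Bit d of the integers 0..7 gives the offset along dimension d,
# enumerating the 8 corners of the unit cube around each particle.
neighbors = (torch.arange(8) >> torch.arange(3)[:, None, None]) & 1
print(neighbors.squeeze(1))
# tensor([[0, 1, 0, 1, 0, 1, 0, 1],
#         [0, 0, 1, 1, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1, 1, 1]])

# One toy particle at a fractional mesh position, shaped (N, 3, M, 1)
# like pos in lag2eul after the final view.
pos = torch.tensor([0.7, 1.2, 2.5]).view(1, 3, 1, 1)
tgtpos = pos.floor().to(torch.int) + neighbors  # (1, 3, 1, 8) corner coords

# CIC weight of each corner: product over the 3 axes of the linear
# distance to the particle; the 8 weights sum to 1 (a partition of
# unity), so mass is conserved when they are scattered with index_add_.
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
print(kernel.sum())  # tensor(1.)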


@@ -16,7 +16,7 @@ from torch.utils.tensorboard import SummaryWriter
 from .data import FieldDataset, DistFieldSampler
 from .data.figures import plt_slices, plt_power
 from . import models
-from .models import narrow_cast, resample, Lag2Eul
+from .models import narrow_cast, resample, lag2eul
 from .utils import import_attr, load_model_state_dict
@@ -126,8 +126,6 @@ def gpu_worker(local_rank, node, args):
     model = DistributedDataParallel(model, device_ids=[device],
                                     process_group=dist.new_group())
 
-    lag2eul = Lag2Eul()
-
     criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
     criterion = criterion()
     criterion.to(device)
@@ -193,13 +191,13 @@ def gpu_worker(local_rank, node, args):
     for epoch in range(start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
 
-        train_loss = train(epoch, train_loader, model, lag2eul, criterion,
+        train_loss = train(epoch, train_loader, model, criterion,
                            optimizer, scheduler, logger, device, args)
         epoch_loss = train_loss
 
         if args.val:
-            val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
+            val_loss = validate(epoch, val_loader, model, criterion,
                                 logger, device, args)
             #epoch_loss = val_loss
 
         if args.reduce_lr_on_plateau:
@@ -229,8 +227,8 @@ def gpu_worker(local_rank, node, args):
     dist.destroy_process_group()
 
 
-def train(epoch, loader, model, lag2eul, criterion,
+def train(epoch, loader, model, criterion,
           optimizer, scheduler, logger, device, args):
     model.train()
 
     rank = dist.get_rank()
@@ -321,7 +319,7 @@ def train(epoch, loader, model, lag2eul, criterion,
     return epoch_loss
 
 
-def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
+def validate(epoch, loader, model, criterion, logger, device, args):
     model.eval()
 
     rank = dist.get_rank()
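
With the lag2eul argument dropped from train and validate, their bodies (not shown in this diff) can call the imported function directly. A hedged sketch of what the loss step inside train might look like after this commit; the helper name and the variable names input, output, target are assumptions, not from the diff.

def hypothetical_loss_step(model, criterion, input, target):
    # Sketch: paint model output and target from Lagrangian to Eulerian
    # space in one lag2eul call, so both meshes share the same mean
    # displacement, then compare them with the criterion.
    output = model(input)
    output, target = narrow_cast(output, target)  # match spatial extents
    lag_out, lag_tgt = lag2eul(output, target)
    return criterion(lag_out, lag_tgt)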