Change Lag2Eul to lag2eul as a function
This commit is contained in:
parent
3eb1b0bccc
commit
670364e54c
@ -5,7 +5,7 @@ from .patchgan import PatchGAN, PatchGAN42
|
|||||||
from .narrow import narrow_by, narrow_cast, narrow_like
|
from .narrow import narrow_by, narrow_cast, narrow_like
|
||||||
from .resample import resample, Resampler
|
from .resample import resample, Resampler
|
||||||
|
|
||||||
from .lag2eul import Lag2Eul
|
from .lag2eul import lag2eul
|
||||||
from .power import power
|
from .power import power
|
||||||
|
|
||||||
from .dice import DiceLoss, dice_loss
|
from .dice import DiceLoss, dice_loss
|
||||||
|
@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class Lag2Eul(nn.Module):
|
def lag2eul(*xs, rm_dis_mean=True, periodic=False):
|
||||||
"""Transform fields from Lagrangian description to Eulerian description
|
"""Transform fields from Lagrangian description to Eulerian description
|
||||||
|
|
||||||
Only works for 3d fields, output same mesh size as input.
|
Only works for 3d fields, output same mesh size as input.
|
||||||
@ -14,81 +14,79 @@ class Lag2Eul(nn.Module):
|
|||||||
|
|
||||||
Implementation follows pmesh/cic.py by Yu Feng.
|
Implementation follows pmesh/cic.py by Yu Feng.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
# FIXME for other redshift, box and mesh sizes
|
||||||
super().__init__()
|
from ..data.norms.cosmology import D
|
||||||
|
z = 0
|
||||||
|
Boxsize = 1000
|
||||||
|
Nmesh = 512
|
||||||
|
dis_norm = 6 * D(z) * Nmesh / Boxsize # to mesh unit
|
||||||
|
|
||||||
# FIXME for other redshift, box and mesh sizes
|
if any(x.dim() != 5 for x in xs):
|
||||||
from ..data.norms.cosmology import D
|
raise NotImplementedError('only support 3d fields for now')
|
||||||
z = 0
|
if any(x.shape[1] < 3 for x in xs):
|
||||||
Boxsize = 1000
|
raise ValueError('displacement not available with <3 channels')
|
||||||
Nmesh = 512
|
|
||||||
self.dis_norm = 6 * D(z) * Nmesh / Boxsize # to mesh unit
|
|
||||||
|
|
||||||
def forward(self, *xs, rm_dis_mean=True, periodic=False):
|
# common mean displacement of all inputs
|
||||||
if any(x.shape[1] < 3 for x in xs):
|
# if removed, fewer particles go outside of the box
|
||||||
raise ValueError('displacement not available with <3 channels')
|
# common for all inputs so outputs are comparable in the same coords
|
||||||
|
dis_mean = 0
|
||||||
|
if rm_dis_mean:
|
||||||
|
dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
|
||||||
|
for x in xs) / len(xs)
|
||||||
|
|
||||||
# common mean displacement of all inputs
|
out = []
|
||||||
# if removed, fewer particles go outside of the box
|
for x in xs:
|
||||||
# common for all inputs so outputs are comparable in the same coords
|
N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
|
||||||
dis_mean = 0
|
|
||||||
if rm_dis_mean:
|
|
||||||
dis_mean = sum(x[:, :3].detach().mean((2, 3, 4), keepdim=True)
|
|
||||||
for x in xs) / len(xs)
|
|
||||||
|
|
||||||
out = []
|
if Cin == 3:
|
||||||
for x in xs:
|
Cout = 1
|
||||||
N, Cin, DHW = x.shape[0], x.shape[1], x.shape[2:]
|
val = 1
|
||||||
|
else:
|
||||||
|
Cout = Cin - 3
|
||||||
|
val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
|
||||||
|
mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
if Cin == 3:
|
pos = (x[:, :3] - dis_mean) * dis_norm
|
||||||
Cout = 1
|
|
||||||
val = 1
|
|
||||||
else:
|
|
||||||
Cout = Cin - 3
|
|
||||||
val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
|
|
||||||
mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
pos = (x[:, :3] - dis_mean) * self.dis_norm
|
pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
|
||||||
|
pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
|
||||||
|
pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
|
||||||
|
|
||||||
pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
|
pos = pos.contiguous().view(N, 3, -1, 1)
|
||||||
pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
|
|
||||||
pos[:, 2] += torch.arange(0.5, DHW[2], device=x.device)
|
|
||||||
|
|
||||||
pos = pos.contiguous().view(N, 3, -1, 1)
|
intpos = pos.floor().to(torch.int)
|
||||||
|
neighbors = (torch.arange(8, device=x.device)
|
||||||
|
>> torch.arange(3, device=x.device)[:, None, None] ) & 1
|
||||||
|
tgtpos = intpos + neighbors
|
||||||
|
del intpos, neighbors
|
||||||
|
|
||||||
intpos = pos.floor().to(torch.int)
|
# CIC
|
||||||
neighbors = (torch.arange(8, device=x.device)
|
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
|
||||||
>> torch.arange(3, device=x.device)[:, None, None] ) & 1
|
del pos
|
||||||
tgtpos = intpos + neighbors
|
|
||||||
del intpos, neighbors
|
|
||||||
|
|
||||||
# CIC
|
val = val * kernel
|
||||||
kernel = (1.0 - torch.abs(pos - tgtpos)).prod(1, keepdim=True)
|
del kernel
|
||||||
del pos
|
|
||||||
|
|
||||||
val = val * kernel
|
tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
|
||||||
del kernel
|
val = val.view(N, Cout, -1)
|
||||||
|
|
||||||
tgtpos = tgtpos.view(N, 3, -1) # fuse spatial and neighbor axes
|
for n in range(N): # because ind has variable length
|
||||||
val = val.view(N, Cout, -1)
|
bounds = torch.tensor(DHW, device=x.device)[:, None]
|
||||||
|
|
||||||
for n in range(N): # because ind has variable length
|
if periodic:
|
||||||
bounds = torch.tensor(DHW, device=x.device)[:, None]
|
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
|
||||||
|
|
||||||
if periodic:
|
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
|
||||||
torch.remainder(tgtpos[n], bounds, out=tgtpos[n])
|
) * DHW[2] + tgtpos[n, 2]
|
||||||
|
src = val[n]
|
||||||
|
|
||||||
ind = (tgtpos[n, 0] * DHW[1] + tgtpos[n, 1]
|
if not periodic:
|
||||||
) * DHW[2] + tgtpos[n, 2]
|
mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
|
||||||
src = val[n]
|
ind = ind[mask]
|
||||||
|
src = src[:, mask]
|
||||||
|
|
||||||
if not periodic:
|
mesh[n].view(Cout, -1).index_add_(1, ind, src)
|
||||||
mask = ((tgtpos[n] >= 0) & (tgtpos[n] < bounds)).all(0)
|
|
||||||
ind = ind[mask]
|
|
||||||
src = src[:, mask]
|
|
||||||
|
|
||||||
mesh[n].view(Cout, -1).index_add_(1, ind, src)
|
out.append(mesh)
|
||||||
|
|
||||||
out.append(mesh)
|
return out
|
||||||
|
|
||||||
return out
|
|
||||||
|
@ -16,7 +16,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from .data import FieldDataset, DistFieldSampler
|
from .data import FieldDataset, DistFieldSampler
|
||||||
from .data.figures import plt_slices, plt_power
|
from .data.figures import plt_slices, plt_power
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_cast, resample, Lag2Eul
|
from .models import narrow_cast, resample, lag2eul
|
||||||
from .utils import import_attr, load_model_state_dict
|
from .utils import import_attr, load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
@ -126,8 +126,6 @@ def gpu_worker(local_rank, node, args):
|
|||||||
model = DistributedDataParallel(model, device_ids=[device],
|
model = DistributedDataParallel(model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
|
||||||
lag2eul = Lag2Eul()
|
|
||||||
|
|
||||||
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(device)
|
criterion.to(device)
|
||||||
@ -193,13 +191,13 @@ def gpu_worker(local_rank, node, args):
|
|||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
train_loss = train(epoch, train_loader, model, lag2eul, criterion,
|
train_loss = train(epoch, train_loader, model, criterion,
|
||||||
optimizer, scheduler, logger, device, args)
|
optimizer, scheduler, logger, device, args)
|
||||||
epoch_loss = train_loss
|
epoch_loss = train_loss
|
||||||
|
|
||||||
if args.val:
|
if args.val:
|
||||||
val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
|
val_loss = validate(epoch, val_loader, model, criterion,
|
||||||
logger, device, args)
|
logger, device, args)
|
||||||
#epoch_loss = val_loss
|
#epoch_loss = val_loss
|
||||||
|
|
||||||
if args.reduce_lr_on_plateau:
|
if args.reduce_lr_on_plateau:
|
||||||
@ -229,8 +227,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, loader, model, lag2eul, criterion,
|
def train(epoch, loader, model, criterion,
|
||||||
optimizer, scheduler, logger, device, args):
|
optimizer, scheduler, logger, device, args):
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
@ -321,7 +319,7 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
def validate(epoch, loader, model, criterion, logger, device, args):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
Loading…
Reference in New Issue
Block a user