Change Lag2Eul to lag2eul as a function

Yin Li 2020-08-22 23:25:08 -04:00
parent 3eb1b0bccc
commit 670364e54c
3 changed files with 68 additions and 72 deletions


@@ -5,7 +5,7 @@ from .patchgan import PatchGAN, PatchGAN42
 from .narrow import narrow_by, narrow_cast, narrow_like
 from .resample import resample, Resampler
-from .lag2eul import Lag2Eul
+from .lag2eul import lag2eul
 from .power import power
 from .dice import DiceLoss, dice_loss
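The only change here is the re-exported name: the package now exposes a plain function instead of an nn.Module subclass. At a call site the difference is roughly the following (a minimal sketch; the single-input, single-mesh return convention is an assumption, not shown in this diff):

# before: construct the module once, then call it
lag2eul = Lag2Eul()
mesh = lag2eul(dis)  # dis: (N, C >= 3, D, H, W), first 3 channels are displacement

# after: import lag2eul and call it directly, no instance to construct or move around
mesh = lag2eul(dis)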


@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn


-class Lag2Eul(nn.Module):
+def lag2eul(*xs, rm_dis_mean=True, periodic=False):
     """Transform fields from Lagrangian description to Eulerian description

     Only works for 3d fields, output same mesh size as input.
@@ -14,17 +14,15 @@ class Lag2Eul(nn.Module):
     Implementation follows pmesh/cic.py by Yu Feng.
     """
-    def __init__(self):
-        super().__init__()
-
     # FIXME for other redshift, box and mesh sizes
     from ..data.norms.cosmology import D
     z = 0
     Boxsize = 1000
     Nmesh = 512
-    self.dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
+    dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit

-    def forward(self, *xs, rm_dis_mean=True, periodic=False):
     if any(x.dim() != 5 for x in xs):
         raise NotImplementedError('only support 3d fields for now')
     if any(x.shape[1] < 3 for x in xs):
         raise ValueError('displacement not available with <3 channels')
@@ -48,7 +46,7 @@ class Lag2Eul(nn.Module):
         val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
         mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)

-        pos = (x[:, :3] - dis_mean) * self.dis_norm
+        pos = (x[:, :3] - dis_mean) * dis_norm
         pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
         pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
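For context, the transform paints each Lagrangian grid point, moved by its displacement, onto an Eulerian mesh. Below is a minimal sketch of a simplified nearest-grid-point (NGP) variant of that painting; the function in this commit instead uses cloud-in-cell (CIC) interpolation following pmesh/cic.py and the channel layout shown above (x[:, :3] displacement plus extra value channels). The name lag2eul_ngp and its signature are illustrative only, not part of the repository.

import torch

def lag2eul_ngp(dis, val=None, dis_norm=1.0):
    # Hypothetical, simplified NGP painting; not the repository's CIC code.
    # dis: (N, 3, D, H, W) displacement field; dis_norm converts it to mesh units.
    # val: (N, C, D, H, W) values carried by each particle; defaults to unit mass.
    N, _, D, H, W = dis.shape
    if val is None:
        val = dis.new_ones(N, 1, D, H, W)
    C = val.shape[1]

    # undisplaced (Lagrangian) cell-centered mesh coordinates
    d = torch.arange(0.5, D, device=dis.device)
    h = torch.arange(0.5, H, device=dis.device)
    w = torch.arange(0.5, W, device=dis.device)
    grid = torch.stack(torch.meshgrid(d, h, w, indexing='ij'))  # (3, D, H, W)

    pos = dis * dis_norm + grid        # Eulerian positions in mesh units
    idx = pos.floor().long()           # nearest grid cell per particle
    size = torch.tensor([D, H, W], device=dis.device)
    idx = idx % size[:, None, None, None]   # periodic wrap

    flat = (idx[:, 0] * H + idx[:, 1]) * W + idx[:, 2]   # (N, D, H, W) flat cell index
    flat = flat.reshape(N, 1, -1).expand(N, C, -1)

    mesh = torch.zeros(N, C, D * H * W, dtype=val.dtype, device=val.device)
    mesh.scatter_add_(2, flat, val.reshape(N, C, -1))    # deposit particle values
    return mesh.reshape(N, C, D, H, W)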


@@ -16,7 +16,7 @@ from torch.utils.tensorboard import SummaryWriter
 from .data import FieldDataset, DistFieldSampler
 from .data.figures import plt_slices, plt_power
 from . import models
-from .models import narrow_cast, resample, Lag2Eul
+from .models import narrow_cast, resample, lag2eul
 from .utils import import_attr, load_model_state_dict
@@ -126,8 +126,6 @@ def gpu_worker(local_rank, node, args):
     model = DistributedDataParallel(model, device_ids=[device],
                                     process_group=dist.new_group())

-    lag2eul = Lag2Eul()
-
     criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
     criterion = criterion()
     criterion.to(device)
@@ -193,12 +191,12 @@ def gpu_worker(local_rank, node, args):
     for epoch in range(start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)

-        train_loss = train(epoch, train_loader, model, lag2eul, criterion,
+        train_loss = train(epoch, train_loader, model, criterion,
                            optimizer, scheduler, logger, device, args)
         epoch_loss = train_loss

         if args.val:
-            val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
+            val_loss = validate(epoch, val_loader, model, criterion,
                                 logger, device, args)
             #epoch_loss = val_loss
@@ -229,7 +227,7 @@ def gpu_worker(local_rank, node, args):
     dist.destroy_process_group()


-def train(epoch, loader, model, lag2eul, criterion,
+def train(epoch, loader, model, criterion,
           optimizer, scheduler, logger, device, args):
     model.train()
@@ -321,7 +319,7 @@ def train(epoch, loader, model, lag2eul, criterion,
     return epoch_loss


-def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
+def validate(epoch, loader, model, criterion, logger, device, args):
     model.eval()

     rank = dist.get_rank()
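With the module gone, train() and validate() no longer take a lag2eul argument; the function imported at the top of the file is used directly. The loop body is not shown in this diff, so the call site below is hypothetical (the narrow_cast step and the one-mesh-per-input return convention are assumptions):

output = model(input)
output, target = narrow_cast(output, target)

# paint both Lagrangian fields onto Eulerian meshes before comparing them
eul_out, eul_tgt = lag2eul(output, target)
loss = criterion(eul_out, eul_tgt)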