Change Lag2Eul to lag2eul as a function
parent 3eb1b0bccc
commit 670364e54c
@@ -5,7 +5,7 @@ from .patchgan import PatchGAN, PatchGAN42
 from .narrow import narrow_by, narrow_cast, narrow_like
 from .resample import resample, Resampler
 
-from .lag2eul import Lag2Eul
+from .lag2eul import lag2eul
 from .power import power
 
 from .dice import DiceLoss, dice_loss
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 
-class Lag2Eul(nn.Module):
+def lag2eul(*xs, rm_dis_mean=True, periodic=False):
     """Transform fields from Lagrangian description to Eulerian description
 
     Only works for 3d fields, output same mesh size as input.
@@ -14,17 +14,15 @@ class Lag2Eul(nn.Module):
 
     Implementation follows pmesh/cic.py by Yu Feng.
     """
-    def __init__(self):
-        super().__init__()
-
     # FIXME for other redshift, box and mesh sizes
     from ..data.norms.cosmology import D
     z = 0
     Boxsize = 1000
     Nmesh = 512
-    self.dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
+    dis_norm = 6 * D(z) * Nmesh / Boxsize  # to mesh unit
 
-    def forward(self, *xs, rm_dis_mean=True, periodic=False):
     if any(x.dim() != 5 for x in xs):
         raise NotImplementedError('only support 3d fields for now')
     if any(x.shape[1] < 3 for x in xs):
         raise ValueError('displacement not available with <3 channels')
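
With the class converted to a function, call sites no longer construct an nn.Module. A minimal before/after sketch of the call pattern; the package path, the input shapes, and the assumption that the function returns one Eulerian field per variadic input are illustrative, inferred from the signature rather than shown in this diff:

    import torch

    from map2map.models import lag2eul  # package path assumed for illustration

    # (N, C, D, H, W) fields: the first 3 channels are the Lagrangian
    # displacement, any remaining channels are field values (illustrative)
    output = torch.randn(2, 4, 16, 16, 16)
    target = torch.randn(2, 4, 16, 16, 16)

    # before this commit:
    #   lag2eul = Lag2Eul()                        # stateless nn.Module
    #   eul_out, eul_tgt = lag2eul(output, target)
    # after this commit, the same transform is a plain function call:
    eul_out, eul_tgt = lag2eul(output, target)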
@@ -48,7 +46,7 @@ class Lag2Eul(nn.Module):
     val = x[:, 3:].contiguous().view(N, Cout, -1, 1)
     mesh = torch.zeros(N, Cout, *DHW, dtype=x.dtype, device=x.device)
 
-    pos = (x[:, :3] - dis_mean) * self.dis_norm
+    pos = (x[:, :3] - dis_mean) * dis_norm
 
     pos[:, 0] += torch.arange(0.5, DHW[0], device=x.device)[:, None, None]
     pos[:, 1] += torch.arange(0.5, DHW[1], device=x.device)[:, None]
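
As a standalone illustration of the broadcasting above: the half-integer arange along each axis adds cell-center coordinates to the normalized displacements, turning them into absolute positions in mesh units. Shapes here are arbitrary and the third axis is completed by analogy, for demonstration only:

    import torch

    DHW = (4, 4, 4)
    pos = torch.zeros(1, 3, *DHW)  # stand-in for normalized displacements
    # the trailing [:, None, ...] indexing aligns each 1d range
    # with its own spatial axis before broadcasting
    pos[:, 0] += torch.arange(0.5, DHW[0])[:, None, None]
    pos[:, 1] += torch.arange(0.5, DHW[1])[:, None]
    pos[:, 2] += torch.arange(0.5, DHW[2])
    print(pos[0, :, 1, 2, 3])  # tensor([1.5000, 2.5000, 3.5000])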
@@ -16,7 +16,7 @@ from torch.utils.tensorboard import SummaryWriter
 from .data import FieldDataset, DistFieldSampler
 from .data.figures import plt_slices, plt_power
 from . import models
-from .models import narrow_cast, resample, Lag2Eul
+from .models import narrow_cast, resample, lag2eul
 from .utils import import_attr, load_model_state_dict
 
 
@@ -126,8 +126,6 @@ def gpu_worker(local_rank, node, args):
     model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())
 
-    lag2eul = Lag2Eul()
-
     criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
     criterion = criterion()
     criterion.to(device)
@@ -193,12 +191,12 @@ def gpu_worker(local_rank, node, args):
     for epoch in range(start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
 
-        train_loss = train(epoch, train_loader, model, lag2eul, criterion,
+        train_loss = train(epoch, train_loader, model, criterion,
                            optimizer, scheduler, logger, device, args)
         epoch_loss = train_loss
 
         if args.val:
-            val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
+            val_loss = validate(epoch, val_loader, model, criterion,
                                 logger, device, args)
            #epoch_loss = val_loss
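
With lag2eul imported at module scope in train.py, it no longer has to be threaded through train() and validate() as an argument. A hedged sketch of how a batch inside those loops might use it after this change; the variable names and the exact loss composition are illustrative, not shown in this diff:

    # inside the batch loop (illustrative)
    output = model(input)
    output, target = narrow_cast(output, target)
    eul_out, eul_tgt = lag2eul(output, target)  # module-level function now
    loss = criterion(eul_out, eul_tgt)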
@@ -229,7 +227,7 @@ def gpu_worker(local_rank, node, args):
     dist.destroy_process_group()
 
 
-def train(epoch, loader, model, lag2eul, criterion,
+def train(epoch, loader, model, criterion,
           optimizer, scheduler, logger, device, args):
     model.train()
 
@@ -321,7 +319,7 @@ def train(epoch, loader, model, lag2eul, criterion,
     return epoch_loss
 
 
-def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
+def validate(epoch, loader, model, criterion, logger, device, args):
    model.eval()
 
    rank = dist.get_rank()