Fix bug with shared displacement field

This commit is contained in:
Yin Li 2021-05-20 15:16:42 -04:00
parent b7a97bbcdc
commit 6a3990f90c

View File

@ -3,6 +3,7 @@ import torch
from ..data.norms.cosmology import D from ..data.norms.cosmology import D
def lag2eul( def lag2eul(
dis, dis,
val=1.0, val=1.0,
@ -60,12 +61,12 @@ def lag2eul(
d_mean = 0 d_mean = 0
if rm_dis_mean: if rm_dis_mean:
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
for d in dis) / len(dis) for d in dis) / len(dis)
out = [] out = []
if len(dis) == 1 and len(val) != 1: if len(dis) == 1 and len(val) != 1:
dis = itertools.repeat(dis[0]) dis = itertools.repeat(dis[0])
if len(dis) != 1 and len(val) == 1: elif len(dis) != 1 and len(val) == 1:
val = itertools.repeat(val[0]) val = itertools.repeat(val[0])
for d, v in zip(dis, val): for d, v in zip(dis, val):
dtype, device = d.dtype, d.device dtype, device = d.dtype, d.device