Fix bug with shared displacement field

This commit is contained in:
Yin Li 2021-05-20 15:16:42 -04:00
parent b7a97bbcdc
commit 6a3990f90c

View File

@ -3,6 +3,7 @@ import torch
from ..data.norms.cosmology import D
def lag2eul(
dis,
val=1.0,
@ -60,12 +61,12 @@ def lag2eul(
d_mean = 0
if rm_dis_mean:
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
for d in dis) / len(dis)
for d in dis) / len(dis)
out = []
if len(dis) == 1 and len(val) != 1:
dis = itertools.repeat(dis[0])
if len(dis) != 1 and len(val) == 1:
elif len(dis) != 1 and len(val) == 1:
val = itertools.repeat(val[0])
for d, v in zip(dis, val):
dtype, device = d.dtype, d.device