Fix bug with shared displacement field

This commit is contained in:
Yin Li 2021-05-20 15:16:42 -04:00
parent b7a97bbcdc
commit 6a3990f90c

View File

@ -3,6 +3,7 @@ import torch
from ..data.norms.cosmology import D
def lag2eul(
dis,
val=1.0,
@ -65,7 +66,7 @@ def lag2eul(
out = []
if len(dis) == 1 and len(val) != 1:
dis = itertools.repeat(dis[0])
if len(dis) != 1 and len(val) == 1:
elif len(dis) != 1 and len(val) == 1:
val = itertools.repeat(val[0])
for d, v in zip(dis, val):
dtype, device = d.dtype, d.device