Fix bug with shared displacement field

This commit is contained in:
Yin Li 2021-05-20 15:16:42 -04:00
parent b7a97bbcdc
commit 6a3990f90c

View File

@ -3,6 +3,7 @@ import torch
from ..data.norms.cosmology import D from ..data.norms.cosmology import D
def lag2eul( def lag2eul(
dis, dis,
val=1.0, val=1.0,
@ -65,7 +66,7 @@ def lag2eul(
out = [] out = []
if len(dis) == 1 and len(val) != 1: if len(dis) == 1 and len(val) != 1:
dis = itertools.repeat(dis[0]) dis = itertools.repeat(dis[0])
if len(dis) != 1 and len(val) == 1: elif len(dis) != 1 and len(val) == 1:
val = itertools.repeat(val[0]) val = itertools.repeat(val[0])
for d, v in zip(dis, val): for d, v in zip(dis, val):
dtype, device = d.dtype, d.device dtype, device = d.dtype, d.device