diff --git a/map2map/models/lag2eul.py b/map2map/models/lag2eul.py index 97b4960..9c06f93 100644 --- a/map2map/models/lag2eul.py +++ b/map2map/models/lag2eul.py @@ -3,6 +3,7 @@ import torch from ..data.norms.cosmology import D + def lag2eul( dis, val=1.0, @@ -60,12 +61,12 @@ def lag2eul( d_mean = 0 if rm_dis_mean: d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) - for d in dis) / len(dis) + for d in dis) / len(dis) out = [] if len(dis) == 1 and len(val) != 1: dis = itertools.repeat(dis[0]) - if len(dis) != 1 and len(val) == 1: + elif len(dis) != 1 and len(val) == 1: val = itertools.repeat(val[0]) for d, v in zip(dis, val): dtype, device = d.dtype, d.device