From 6a3990f90c4e55dc2430b402b0bcb9e85377a3c4 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 20 May 2021 15:16:42 -0400 Subject: [PATCH] Fix bug with shared displacement field --- map2map/models/lag2eul.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/map2map/models/lag2eul.py b/map2map/models/lag2eul.py index 97b4960..9c06f93 100644 --- a/map2map/models/lag2eul.py +++ b/map2map/models/lag2eul.py @@ -3,6 +3,7 @@ import torch from ..data.norms.cosmology import D + def lag2eul( dis, val=1.0, @@ -60,12 +61,12 @@ def lag2eul( d_mean = 0 if rm_dis_mean: d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True) - for d in dis) / len(dis) + for d in dis) / len(dis) out = [] if len(dis) == 1 and len(val) != 1: dis = itertools.repeat(dis[0]) - if len(dis) != 1 and len(val) == 1: + elif len(dis) != 1 and len(val) == 1: val = itertools.repeat(val[0]) for d, v in zip(dis, val): dtype, device = d.dtype, d.device