Fix bug with shared displacement field
This commit is contained in:
parent
b7a97bbcdc
commit
6a3990f90c
@ -3,6 +3,7 @@ import torch
|
||||
|
||||
from ..data.norms.cosmology import D
|
||||
|
||||
|
||||
def lag2eul(
|
||||
dis,
|
||||
val=1.0,
|
||||
@ -60,12 +61,12 @@ def lag2eul(
|
||||
d_mean = 0
|
||||
if rm_dis_mean:
|
||||
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
|
||||
for d in dis) / len(dis)
|
||||
for d in dis) / len(dis)
|
||||
|
||||
out = []
|
||||
if len(dis) == 1 and len(val) != 1:
|
||||
dis = itertools.repeat(dis[0])
|
||||
if len(dis) != 1 and len(val) == 1:
|
||||
elif len(dis) != 1 and len(val) == 1:
|
||||
val = itertools.repeat(val[0])
|
||||
for d, v in zip(dis, val):
|
||||
dtype, device = d.dtype, d.device
|
||||
|
Loading…
Reference in New Issue
Block a user