Fix bug with shared displacement field
This commit is contained in:
parent
b7a97bbcdc
commit
6a3990f90c
@ -3,6 +3,7 @@ import torch
|
|||||||
|
|
||||||
from ..data.norms.cosmology import D
|
from ..data.norms.cosmology import D
|
||||||
|
|
||||||
|
|
||||||
def lag2eul(
|
def lag2eul(
|
||||||
dis,
|
dis,
|
||||||
val=1.0,
|
val=1.0,
|
||||||
@ -60,12 +61,12 @@ def lag2eul(
|
|||||||
d_mean = 0
|
d_mean = 0
|
||||||
if rm_dis_mean:
|
if rm_dis_mean:
|
||||||
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
|
d_mean = sum(d.detach().mean((2, 3, 4), keepdim=True)
|
||||||
for d in dis) / len(dis)
|
for d in dis) / len(dis)
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
if len(dis) == 1 and len(val) != 1:
|
if len(dis) == 1 and len(val) != 1:
|
||||||
dis = itertools.repeat(dis[0])
|
dis = itertools.repeat(dis[0])
|
||||||
if len(dis) != 1 and len(val) == 1:
|
elif len(dis) != 1 and len(val) == 1:
|
||||||
val = itertools.repeat(val[0])
|
val = itertools.repeat(val[0])
|
||||||
for d, v in zip(dis, val):
|
for d, v in zip(dis, val):
|
||||||
dtype, device = d.dtype, d.device
|
dtype, device = d.dtype, d.device
|
||||||
|
Loading…
Reference in New Issue
Block a user