Add Dice loss
This commit is contained in:
parent
0721301113
commit
698b2a8df7
11
map2map/models/dice.py
Normal file
11
map2map/models/dice.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
def dice_loss(input, target, eps=0.):
|
||||||
|
input = input.view(-1)
|
||||||
|
target = target.view(-1)
|
||||||
|
|
||||||
|
prod = (input * target).sum()
|
||||||
|
in_sq = (input * input).sum()
|
||||||
|
tgt_sq = (target * target).sum()
|
||||||
|
|
||||||
|
dice = (2 * prod + eps) / (in_sq + tgt_sq + eps)
|
||||||
|
|
||||||
|
return 1 - dice
|
Loading…
Reference in New Issue
Block a user