Add Dice loss
This commit is contained in:
parent
0721301113
commit
698b2a8df7
11
map2map/models/dice.py
Normal file
11
map2map/models/dice.py
Normal file
@ -0,0 +1,11 @@
|
||||
def dice_loss(input, target, eps=0.):
|
||||
input = input.view(-1)
|
||||
target = target.view(-1)
|
||||
|
||||
prod = (input * target).sum()
|
||||
in_sq = (input * input).sum()
|
||||
tgt_sq = (target * target).sum()
|
||||
|
||||
dice = (2 * prod + eps) / (in_sq + tgt_sq + eps)
|
||||
|
||||
return 1 - dice
|
Loading…
Reference in New Issue
Block a user