Add Dice loss
This commit is contained in:
parent
0721301113
commit
698b2a8df7
1 changed files with 11 additions and 0 deletions
11
map2map/models/dice.py
Normal file
11
map2map/models/dice.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
def dice_loss(input, target, eps=0.):
|
||||
input = input.view(-1)
|
||||
target = target.view(-1)
|
||||
|
||||
prod = (input * target).sum()
|
||||
in_sq = (input * input).sum()
|
||||
tgt_sq = (target * target).sum()
|
||||
|
||||
dice = (2 * prod + eps) / (in_sq + tgt_sq + eps)
|
||||
|
||||
return 1 - dice
|
Loading…
Add table
Reference in a new issue