map2map/map2map/models/dice.py
2020-02-03 22:19:38 -05:00

24 lines
499 B
Python

import torch.nn as nn
class DiceLoss(nn.Module):
def __init__(self, eps=0.):
super().__init__()
self.eps = eps
def forward(self, input, target):
return dice_loss(input, target, self.eps)
def dice_loss(input, target, eps=0.):
input = input.view(-1)
target = target.view(-1)
prod = (input * target).sum()
in_sq = (input * input).sum()
tgt_sq = (target * target).sum()
dice = (2 * prod + eps) / (in_sq + tgt_sq + eps)
return 1 - dice