Wrap Dice loss with nn.Module

This commit is contained in:
Yin Li 2020-02-03 22:19:38 -05:00
parent db69e9f953
commit a22fb64d12
2 changed files with 15 additions and 0 deletions

View File

@ -1,4 +1,7 @@
from .unet import UNet
from .vnet import VNet, VNetFat
from .patchgan import PatchGAN
from .conv import narrow_like
from .dice import DiceLoss, dice_loss

View File

@ -1,3 +1,15 @@
import torch.nn as nn
class DiceLoss(nn.Module):
def __init__(self, eps=0.):
super().__init__()
self.eps = eps
def forward(self, input, target):
return dice_loss(input, target, self.eps)
def dice_loss(input, target, eps=0.):
input = input.view(-1)
target = target.view(-1)