Wrap Dice loss with nn.Module
This commit is contained in:
parent
db69e9f953
commit
a22fb64d12
@ -1,4 +1,7 @@
|
||||
from .unet import UNet
|
||||
from .vnet import VNet, VNetFat
|
||||
from .patchgan import PatchGAN
|
||||
|
||||
from .conv import narrow_like
|
||||
|
||||
from .dice import DiceLoss, dice_loss
|
||||
|
@ -1,3 +1,15 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DiceLoss(nn.Module):
|
||||
def __init__(self, eps=0.):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, input, target):
|
||||
return dice_loss(input, target, self.eps)
|
||||
|
||||
|
||||
def dice_loss(input, target, eps=0.):
|
||||
input = input.view(-1)
|
||||
target = target.view(-1)
|
||||
|
Loading…
Reference in New Issue
Block a user