Wrap Dice loss with nn.Module
This commit is contained in:
parent
db69e9f953
commit
a22fb64d12
@ -1,4 +1,7 @@
|
|||||||
from .unet import UNet
|
from .unet import UNet
|
||||||
from .vnet import VNet, VNetFat
|
from .vnet import VNet, VNetFat
|
||||||
from .patchgan import PatchGAN
|
from .patchgan import PatchGAN
|
||||||
|
|
||||||
from .conv import narrow_like
|
from .conv import narrow_like
|
||||||
|
|
||||||
|
from .dice import DiceLoss, dice_loss
|
||||||
|
@ -1,3 +1,15 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class DiceLoss(nn.Module):
|
||||||
|
def __init__(self, eps=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, input, target):
|
||||||
|
return dice_loss(input, target, self.eps)
|
||||||
|
|
||||||
|
|
||||||
def dice_loss(input, target, eps=0.):
|
def dice_loss(input, target, eps=0.):
|
||||||
input = input.view(-1)
|
input = input.view(-1)
|
||||||
target = target.view(-1)
|
target = target.view(-1)
|
||||||
|
Loading…
Reference in New Issue
Block a user