-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
58 lines (48 loc) · 2.15 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn.functional as F
def sensitivity(preds, masks, smooth=1e-6):
    """Return sensitivity (recall / true-positive rate): TP / (TP + FN).

    Both inputs are cast to ``bool`` before counting, so any nonzero entry is
    treated as positive. This lets bool, integer (including 0/255 uint8) and
    float 0/1 tensors be passed alike, consistent with ``specificity``.
    Binarize soft predictions before calling.

    Args:
        preds: predicted mask tensor.
        masks: ground-truth mask tensor, same shape as ``preds``.
        smooth: added to numerator and denominator; when ``masks`` has no
            positives the result is ``smooth / smooth`` == 1.0 instead of 0/0.

    Returns:
        0-dim floating-point tensor.
    """
    pred_pos = preds.bool()
    gt_pos = masks.bool()
    true_positive = (pred_pos & gt_pos).sum()
    actual_positive = gt_pos.sum()
    return (true_positive + smooth) / (actual_positive + smooth)
def specificity(preds, masks, smooth=1e-6):
    """Return specificity (true-negative rate): TN / (TN + FP), as a Python float.

    Inputs are cast to ``bool``; an element is "negative" when it is zero.
    ``smooth`` is added to numerator and denominator so a mask with no
    negatives yields 1.0 rather than 0/0.
    """
    pred_neg = ~preds.bool()
    gt_neg = ~masks.bool()
    tn = (pred_neg & gt_neg).sum().float()
    negatives = gt_neg.sum().float()
    ratio = (tn + smooth) / (negatives + smooth)
    return ratio.item()  # plain number, not a tensor
def precision(preds, masks, smooth=1e-6):
    """Return precision (positive predictive value): TP / (TP + FP).

    Both inputs are cast to ``bool`` before counting, so any nonzero entry is
    treated as positive. This lets bool, integer (including 0/255 uint8) and
    float 0/1 tensors be passed alike, consistent with ``specificity``.
    Binarize soft predictions before calling.

    Args:
        preds: predicted mask tensor.
        masks: ground-truth mask tensor, same shape as ``preds``.
        smooth: added to numerator and denominator; when ``preds`` has no
            positives the result is ``smooth / smooth`` == 1.0 instead of 0/0.

    Returns:
        0-dim floating-point tensor.
    """
    pred_pos = preds.bool()
    gt_pos = masks.bool()
    true_positive = (pred_pos & gt_pos).sum()
    predicted_positive = pred_pos.sum()
    return (true_positive + smooth) / (predicted_positive + smooth)
def f1_score(preds, masks, smooth=1e-6):
    """Return the F1 score: the smoothed harmonic mean of precision and sensitivity.

    Delegates to ``precision`` and ``sensitivity`` with the same ``smooth``;
    ``smooth`` is also added to the final denominator to guard the division.
    """
    p = precision(preds, masks, smooth)
    r = sensitivity(preds, masks, smooth)
    denominator = p + r + smooth
    return 2 * (p * r) / denominator
def dice_score(pred, target, eps=1e-7, dims=None):
    """Return the batch-mean Dice coefficient 2|P∩T| / (|P| + |T|).

    Computed with products and sums, so soft (probability) predictions work
    as well as hard 0/1 masks.

    Args:
        pred: tensor whose dim 0 indexes samples, e.g. (N, H, W) or (N, C, H, W).
        target: tensor with the same shape as ``pred``.
        eps: smoothing added to numerator and denominator; a sample where both
            ``pred`` and ``target`` are empty scores 1.0.
        dims: dimensions reduced per sample. Defaults to every dimension except
            dim 0. For 3-D input this is exactly the former ``(1, 2)``. For 4-D
            input it reduces channel and both spatial axes, instead of leaving
            the last axis unreduced.

    Returns:
        0-dim tensor: per-sample Dice averaged over the batch.
    """
    if dims is None:
        dims = tuple(range(1, pred.dim()))
    intersection = (pred * target).sum(dim=dims)
    cardinality_sum = pred.sum(dim=dims) + target.sum(dim=dims)
    dice = (2.0 * intersection + eps) / (cardinality_sum + eps)
    return dice.mean()
def dice_loss(pred, target):
    """Return the soft Dice loss, ``1 - dice_score(pred, target)``.

    Lower values indicate greater overlap between ``pred`` and ``target``.
    """
    score = dice_score(pred, target)
    return 1 - score
def dice_binary(pred, target, threshold=0.5):
    """Return the Dice score after hard-thresholding both inputs.

    Entries ``>= threshold`` become 1.0 and all others 0.0, for ``pred`` and
    ``target`` alike; the resulting binary masks are scored by ``dice_score``.
    """
    hard_pred, hard_target = ((x >= threshold).float() for x in (pred, target))
    return dice_score(hard_pred, hard_target)
def combined_loss(output, target, beta=0.5):
    """Blend BCE and Dice loss computed from raw logits.

    Args:
        output: raw (pre-sigmoid) logits.
        target: target tensor, same shape as ``output`` (as required by
            ``binary_cross_entropy_with_logits``).
        beta: weight on the BCE term; the Dice term is weighted ``1 - beta``.

    Returns:
        ``beta * BCE(output, target) + (1 - beta) * dice_loss(sigmoid(output), target)``.
    """
    bce_term = F.binary_cross_entropy_with_logits(output, target)
    # dice_loss expects probabilities, so squash the logits first.
    probs = torch.sigmoid(output)
    dice_term = dice_loss(probs, target)
    return beta * bce_term + (1 - beta) * dice_term
def pixel_accuracy(preds, targets, threshold=0.5):
    """Return the fraction of elements where thresholded preds match targets.

    Both ``preds`` and ``targets`` are binarized with ``>= threshold``, the
    same convention ``iou`` and ``dice_binary`` use. Soft or bool targets
    therefore compare correctly. For targets already in {0, 1} and
    ``0 < threshold <= 1``, the result equals comparing the targets directly.

    Returns:
        0-dim float tensor in [0, 1] (nan if ``targets`` is empty).
    """
    preds_bin = preds >= threshold
    targets_bin = targets >= threshold
    correct = (preds_bin == targets_bin).float().sum()
    return correct / targets.numel()
def iou(preds, targets, threshold=0.5, eps=1e-7):
    """Return the intersection-over-union (Jaccard index) of two thresholded masks.

    Both inputs are binarized with ``>= threshold`` and counted over all
    elements. ``eps`` in numerator and denominator makes two empty masks
    score 1.0.
    """
    pred_mask, target_mask = ((x >= threshold).float() for x in (preds, targets))
    overlap = (pred_mask * target_mask).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = (pred_mask + target_mask).sum() - overlap
    return (overlap + eps) / (union + eps)