-
Notifications
You must be signed in to change notification settings - Fork 0
/
OCT_54_vf_model.py
39 lines (32 loc) · 1.11 KB
/
OCT_54_vf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
from torchvision import models
class OCTVF54Model(nn.Module):
    """Single-output regression network built on a torchvision ResNet.

    The backbone's final ``fc`` layer is replaced by a two-layer MLP head
    (``Linear -> ReLU -> Linear``) that maps the backbone features to one
    value per sample. The training loss is an unreduced (element-wise) MSE.

    Args:
        config: Mapping with at least the key ``'arch'``, which must be
            ``'resnet18'`` or ``'resnet50'``.

    Raises:
        ValueError: If ``config['arch']`` is not a supported architecture.
    """

    def __init__(self, config):
        super().__init__()
        # reduction='none' returns per-element losses; reducing them is left
        # to the caller.
        self.loss_fn = nn.MSELoss(reduction='none')
        self.config = config
        self.arch = self.config['arch']
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; kept as-is for compatibility with older releases.
        if self.arch == 'resnet18':
            self.net = models.resnet18(pretrained=True)
            n_out = 512  # width of the features entering resnet18's fc
        elif self.arch == 'resnet50':
            self.net = models.resnet50(pretrained=True)
            n_out = 2048  # width of the features entering resnet50's fc
        else:
            # Previously fell through and crashed with UnboundLocalError on n_out.
            raise ValueError(
                f"unsupported arch {self.arch!r}; expected 'resnet18' or 'resnet50'"
            )
        self.net.fc = nn.Sequential(
            nn.Linear(n_out, n_out),
            nn.ReLU(),
            nn.Linear(n_out, 1),
        )

    def backbone_parameters(self):
        """Return an iterator over every parameter of ``net`` outside ``net.fc``."""
        return (
            param
            for name, param in self.net.named_parameters()
            if not name.startswith('fc.')
        )

    def head_parameters(self):
        """Return an iterator over the parameters of the MLP head ``net.fc``."""
        return self.net.fc.parameters()

    def _forward(self, img):
        """Run the backbone and head on ``img`` and return the raw output."""
        return self.net(img)

    def forward(self, img, label=None, loss_weights=None):
        """Compute predictions and, when ``label`` is given, the element-wise MSE.

        Args:
            img: Input batch for the backbone (presumably ``(B, 3, H, W)``
                images — TODO confirm against the data loader).
            label: Optional regression targets. If it has the same number of
                elements as the prediction but a different shape (e.g. ``(B,)``
                versus the head's ``(B, 1)``), it is reshaped to match. It is
                always cast to the prediction's dtype.
            loss_weights: Accepted for interface compatibility but currently
                unused. NOTE(review): callers may expect per-sample weighting
                here — confirm before relying on it.

        Returns:
            ``(y, loss)`` where ``y`` is the network output and ``loss`` is the
            unreduced MSE (same shape as ``y`` when ``label`` has as many
            elements), or ``None`` when ``label`` is not given.
        """
        y = self._forward(img)
        loss = None
        if label is not None:
            label = label.to(dtype=y.dtype)
            # A (B,) target against a (B, 1) prediction would otherwise
            # broadcast to (B, B) and silently yield a wrong loss.
            if label.shape != y.shape and label.numel() == y.numel():
                label = label.reshape(y.shape)
            loss = self.loss_fn(y, label)
        return y, loss