-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathmodel.py
96 lines (88 loc) · 3.1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from PIL import Image, ImageFilter
from models import *
#from utils import progress_bar
# Progress banner: emitted once, when this module is first imported
# (it is a module-level statement, not part of get_model / get_teacher_model).
print("==> Building model..")
def _timm_with_new_head(model_name, num_classes=10):
    """Load a pretrained timm model and replace its ``head`` with a fresh linear classifier.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        num_classes: output size of the replacement head (10 for all current callers).

    Returns:
        The timm model with ``net.head`` swapped for ``nn.Linear(in_features, num_classes)``.
    """
    # Imported lazily so timm is only needed when a pretrained timm variant is requested.
    import timm
    net = timm.create_model(model_name, pretrained=True)
    net.head = nn.Linear(net.head.in_features, num_classes)
    return net


def _build_net(name, args):
    """Instantiate the network registered under ``name``.

    Shared by ``get_model`` and ``get_teacher_model`` so both accept the same
    architecture names and construct them identically.

    Args:
        name: architecture identifier, e.g. ``'ResNet'``/``'resnet'``, ``'ViT_pt'``.
        args: parsed arguments; only ``args.widen_factor`` is read, and only
            when ``name == 'WideResNet'``.

    Returns:
        The freshly constructed network, before any DataParallel wrapping or
        device placement.

    Raises:
        ValueError: if ``name`` is not a recognised architecture.
    """
    if name in ('ResNet', 'resnet'):
        return ResNet18()
    if name in ('VGG', 'vgg'):
        return VGG('VGG19')
    if name == 'GoogLeNet':
        return GoogLeNet()
    if name in ('DenseNet', 'densenet'):
        return DenseNet121()
    if name == 'MobileNet':
        return MobileNetV2()
    if name == 'LeNet':
        return LeNet()
    if name in ('FCNet', 'fcnet'):
        return FCNet()
    if name in ('ViT4', 'vit'):
        return ViT4()
    if name == 'ViT_pt_interpolate':
        return ViT_pt_interpolate()
    if name == 'ViT_npt_interpolate':
        return ViT_pt_interpolate(pretrained=False)
    if name == 'ViT_pt':
        # from https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py
        return _timm_with_new_head("vit_small_patch16_224")
    if name == 'MLPMixer4':
        return MLPMixer4()
    if name == 'MLPMixer_pt':
        return _timm_with_new_head("mixer_s16_224")
    if name == 'WideResNet':
        return WideResNet(depth=28, num_classes=10, widen_factor=args.widen_factor)
    # Previously an unknown name fell through and surfaced later as an opaque
    # UnboundLocalError on ``net``; fail fast with the offending value instead.
    raise ValueError('Unknown network architecture: %r' % (name,))


def _place_on_device(net, device):
    """Optionally wrap ``net`` in DataParallel, then move it to ``device``.

    Args:
        net: the network to place.
        device: device string; ``'cuda'`` with more than one visible GPU
            triggers DataParallel wrapping.

    Returns:
        ``net`` (possibly wrapped in ``torch.nn.DataParallel``) after ``.to(device)``.
    """
    if device == 'cuda' and torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
        # NOTE(review): cudnn autotuning is enabled only on the multi-GPU path;
        # single-GPU runs keep the default setting. Confirm this is intended.
        cudnn.benchmark = True
    return net.to(device)


def get_model(args, device):
    """Build the network selected by ``args.net`` and place it on ``device``.

    Args:
        args: parsed arguments; reads ``args.net`` (and ``args.widen_factor``
            for ``'WideResNet'``).
        device: target device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The constructed network on ``device``, wrapped in DataParallel when
        ``device == 'cuda'`` and several GPUs are visible.

    Raises:
        ValueError: if ``args.net`` is not a recognised architecture.
    """
    return _place_on_device(_build_net(args.net, args), device)


def get_teacher_model(args, device):
    """Build the teacher network selected by ``args.teacher_net`` and place it on ``device``.

    Accepts the same architecture names as ``get_model``.

    Args:
        args: parsed arguments; reads ``args.teacher_net`` (and
            ``args.widen_factor`` for ``'WideResNet'``).
        device: target device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The constructed teacher network on ``device``, wrapped in DataParallel
        when ``device == 'cuda'`` and several GPUs are visible.

    Raises:
        ValueError: if ``args.teacher_net`` is not a recognised architecture.
    """
    return _place_on_device(_build_net(args.teacher_net, args), device)