# Debug.py
import torch
from dataset import ImageDataset, RandomIdentitySampler
from market1501 import Market1501
from transform import Train_Transform, Val_Transform
from torch.utils.data import DataLoader
from model import Train_Model
from tripletloss import TripletLoss
from cross_entropy_smooth import CrossEntropySmooth
from optimizer import Make_Optimizer, Warmup
import os
import time
# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print("Support GPU!")
# --- Data -------------------------------------------------------------------
# Market1501 wrapper rooted at the current working directory.
market = Market1501(root='./')

# Separate preprocessing pipelines for the training and validation images.
train_transform = Train_Transform(True)
val_transform = Val_Transform()

# NOTE(review): 16 and 4 are presumably identities-per-batch and
# images-per-identity (16 * 4 == the batch size of 64 used below) --
# confirm against RandomIdentitySampler in dataset.py.
train_sampler = RandomIdentitySampler(market.train, 16, 4)

train_dataset = ImageDataset(dataset=market.train, transform=train_transform)
# Validation runs over the test list and the query list concatenated.
val_dataset = ImageDataset(
    dataset=market.test + market.query,
    transform=val_transform,
)

# The training loader takes its ordering from the sampler, so shuffle is off.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=False,
    sampler=train_sampler,
)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

# --- Model, losses, optimisation --------------------------------------------
Model = Train_Model().to(device)
# Classification loss built from the number of training identities.
IDloss = CrossEntropySmooth(market.num_train_id)
# NOTE(review): 3.5e-5 is presumably the base learning rate -- confirm
# against Make_Optimizer in optimizer.py.
optimizer = Make_Optimizer(Model, 3.5e-5)
# NOTE(review): 0.3 is presumably the triplet margin -- confirm against
# TripletLoss in tripletloss.py.
tripletloss = TripletLoss(0.3)
warmup = Warmup(optimizer)
# Total number of training epochs.
EPOCH = 120

# Main training loop.
#
# Each epoch: advance the warm-up schedule, run one pass over
# ``train_dataloader`` optimising ID loss + triplet loss, print the
# sample-averaged loss and accuracy plus the elapsed wall-clock time, and
# write a checkpoint of the model weights every tenth epoch.
for epoch in range(EPOCH):
    print('Epoch {}/{}'.format(epoch, EPOCH - 1))
    print('-' * 10)
    # NOTE(review): the schedule is stepped before any optimizer.step() of
    # the epoch; if Warmup wraps a torch LR scheduler, confirm this ordering
    # is intended.
    warmup.step()
    Model.train()
    for phase in ['Train', 'Val']:
        if phase == 'Train':
            # time.clock() was removed in Python 3.8; perf_counter() is the
            # monotonic wall-clock replacement.
            start = time.perf_counter()
            running_loss = 0.0
            running_corrects = 0.0
            seen_samples = 0
            for Data, Label in train_dataloader:
                Data = Data.to(device)
                Label = Label.to(device)
                batch_size = Label.size(0)

                # Clear gradients left over from the previous iteration;
                # otherwise backward() keeps accumulating into them.
                optimizer.zero_grad()

                ft, fi, out = Model(Data)
                _, preds = torch.max(out, 1)
                running_corrects += torch.sum(preds == Label).item()

                idloss = IDloss(out, Label)
                # Call the loss *instance* built during setup, not the
                # TripletLoss class itself.
                triloss = tripletloss(ft, Label)
                loss = idloss + triloss

                loss.backward()
                optimizer.step()

                # .item() yields a plain float so the running total does not
                # keep each batch's autograd graph alive.
                running_loss += loss.item() * batch_size
                seen_samples += batch_size

            # Normalise by the number of samples actually drawn this epoch
            # (the sampler need not yield exactly num_train_img items);
            # max() guards against an empty loader.
            denominator = max(seen_samples, 1)
            running_loss /= denominator
            running_corrects /= denominator
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, running_loss, running_corrects))
            duration = time.perf_counter() - start
            print('Training complete in {:.0f}m {:.0f}s'.format(
                duration // 60, duration % 60))

        # The validation phase is disabled: the block below is an inert
        # string literal, so the 'Val' iteration currently does nothing.
        # Before re-enabling it, replace time.clock() and call the
        # ``tripletloss`` instance instead of the TripletLoss class.
        '''
        if phase == 'Val':
            start = time.clock()
            running_loss = 0.0
            running_corrects = 0.0
            running_times = 0.0
            for index,data in enumerate(val_dataloader):
                running_times += 1
                Data, Label = data
                Data = Data.to(device)
                Label = Label.to(device)
                with torch.no_grad():
                    ft,fi,out = Model(Data)
                idloss = IDloss(out,Label)
                triloss = TripletLoss(ft,0.3,4)
                _,preds = torch.max(out,1)
                corrects = torch.sum(preds == Label)
                running_corrects += float(corrects)
                running_loss += (idloss+triloss)
            running_loss /= running_times
            running_corrects /= (market.num_query_img + market.num_test_img)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase,running_loss, running_corrects))
            duration = (time.clock() - start)
            print('Val complete in {:.0f}m {:.0f}s'.format(
                duration // 60, duration % 60))
            print()
        '''

    # Checkpoint the weights after epochs 9, 19, 29, ...
    if epoch % 10 == 9:
        torch.save(Model.state_dict(), './Model0_{}.pth'.format(epoch))