-
Notifications
You must be signed in to change notification settings - Fork 5
/
softmaxMnist.py
90 lines (77 loc) · 3.06 KB
/
softmaxMnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
File Name:softmaxMnist
Description : mnist data sets, softmax model
pytorch 不需要进行 one-hot 编码, 使用类别即可
Email : [email protected]
Date:18-1-16
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Module, functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
def get_data(flag=True, batch_size=None, root='../datasets/mnist/'):
    """Build a DataLoader over one split of MNIST.

    Args:
        flag: True selects the training split (shuffled); False selects the
            test split (not shuffled).
        batch_size: samples per batch; defaults to ``config['batch_size']``
            from the script-level config defined under ``__main__``.
        root: directory where the MNIST files are stored / downloaded.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches, with images
        converted by ``transforms.ToTensor()``. The last partial batch is kept.
    """
    if batch_size is None:
        batch_size = config['batch_size']
    # Previously ``download=flag``: the test split (flag=False) could only be
    # loaded if an earlier call had already fetched the files. torchvision
    # skips the download when the files are already present, so always allow it.
    mnist = MNIST(root, train=flag, transform=transforms.ToTensor(), download=True)
    return DataLoader(mnist, batch_size=batch_size, shuffle=flag, drop_last=False)
# Network model definition
class Network(Module):
    """Five-layer fully connected classifier with ReLU between layers.

    Layer widths: in_feature -> 500 -> 350 -> 200 -> 130 -> out_feature.
    ``forward`` returns raw logits (no softmax is applied here); the script
    pairs it with ``nn.CrossEntropyLoss``.
    """

    def __init__(self, in_feature=None, out_feature=None):
        """Create the linear layers.

        Args:
            in_feature: input features per sample; defaults to
                ``config['in_feature']`` from the script-level config.
            out_feature: number of output classes; defaults to
                ``config['out_feature']``.
        """
        super().__init__()
        if in_feature is None:
            in_feature = config['in_feature']
        if out_feature is None:
            out_feature = config['out_feature']
        # Remembered so forward() does not depend on the global config at call time.
        self.in_feature = in_feature
        self.l1 = nn.Linear(in_feature, 500)
        self.l2 = nn.Linear(500, 350)
        self.l3 = nn.Linear(350, 200)
        self.l4 = nn.Linear(200, 130)
        self.l5 = nn.Linear(130, out_feature)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_feature)`` and return the class logits."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        data = x.reshape(-1, self.in_feature)
        y = F.relu(self.l1(data))
        y = F.relu(self.l2(y))
        y = F.relu(self.l3(y))
        y = F.relu(self.l4(y))
        return self.l5(y)
def train_m(mod, data_loader, *, optimizer=None, criterion=None, epoch=None, log_interval=10):
    """Run one training pass of ``mod`` over ``data_loader``.

    Args:
        mod: model to train; it is switched to train mode first.
        data_loader: yields ``(data, target)`` batches.
        optimizer: optimizer over ``mod``'s parameters; defaults to the
            script-level ``optimizer`` global created under ``__main__``.
        criterion: loss function; defaults to the script-level ``criterion`` global.
        epoch: zero-based epoch index, used only in the progress line; defaults
            to the script-level ``epoch`` loop variable.
        log_interval: print a progress line every this many batches.
    """
    # Fall back to the globals the script defines when arguments are omitted.
    scope = globals()
    if optimizer is None:
        optimizer = scope['optimizer']
    if criterion is None:
        criterion = scope['criterion']
    if epoch is None:
        epoch = scope['epoch']

    mod.train()
    num_samples = len(data_loader.dataset)
    num_batches = len(data_loader)
    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        # Call the modules (not .forward) so registered hooks run.
        output = mod(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            len1 = batch_idx * len(data)
            pec = 100. * batch_idx / num_batches
            # loss is a 0-dim tensor: .item() replaces the loss.data[0] indexing
            # that raises IndexError on PyTorch >= 0.4.
            print(f"Train Epoch: {epoch + 1} [{len1:5d}/{num_samples:5d} ({pec:3.2f}%)] \t Loss: {loss.item():.5f}")
def test_m(mod, data_loader, *, criterion=None):
    """Evaluate ``mod`` on every batch of ``data_loader`` and print loss/accuracy.

    Args:
        mod: model to evaluate; it is switched to eval mode and left there.
        data_loader: yields ``(data, target)`` batches with integer class targets.
        criterion: loss function; defaults to the script-level ``criterion``
            global created under ``__main__``.

    The printed "Average loss" is the per-sample mean over the whole dataset.
    """
    if criterion is None:
        criterion = globals()['criterion']
    # A mean-reduced criterion (the nn.CrossEntropyLoss default) returns the
    # batch average; scale it back to a batch sum so that dividing by the
    # dataset size below yields a true per-sample average.
    mean_reduced = getattr(criterion, 'reduction', 'mean') == 'mean'
    mod.eval()
    test_loss, correct = 0.0, 0
    # no_grad replaces the removed ``Variable(..., volatile=True)`` flag.
    with torch.no_grad():
        for data, target in data_loader:
            output = mod(data)
            # sum up batch loss
            batch_loss = criterion(output, target).item()
            if mean_reduced:
                batch_loss *= target.size(0)
            test_loss += batch_loss
            # the index of the max logit is the predicted class
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    len1 = len(data_loader.dataset)
    test_loss /= len1
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len1, 100. * correct / len1))


# The helper's name starts with ``test_``; mark it so pytest does not try to
# collect it as a test (treating ``mod``/``data_loader`` as fixtures) if it is
# ever imported into a test module.
test_m.__test__ = False
if __name__ == '__main__':
    # Hyper-parameters and layer sizes; get_data and Network read this dict
    # as a module-level global.
    config = dict(
        batch_size=64,
        epoch_num=100,
        lr=0.001,
        in_feature=28 * 28,
        out_feature=10,
    )
    train_loader = get_data()
    test_loader = get_data(flag=False)

    # Model instance, loss function and optimizer; train_m / test_m read
    # criterion and optimizer as module-level globals.
    model = Network()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=config['lr'])

    # Each epoch: one training pass, then an evaluation on the test split.
    # ``epoch`` is also read as a global by train_m's progress line.
    for epoch in range(config['epoch_num']):
        train_m(model, train_loader)
        test_m(model, test_loader)