forked from jeshraghian/snntorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_fcn.py
125 lines (104 loc) · 3.94 KB
/
train_fcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) 2020 Graphcore Ltd. All rights reserved.
import argparse
import ctypes
import os
import numpy as np
import popart
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import poptorch
import snntorch as snn
from snntorch.functional import loss as SF
from snntorch import surrogate
from snntorch import spikegen
import csv
import time
from datetime import datetime
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
# --- Network / training hyper-parameters ------------------------------------
alpha = 0.745  # not referenced anywhere else in this file
beta = 0.9  # handed to both snn.Leaky layers as their `beta` argument
num_inputs = 784  # in_features of the first Linear layer (28 * 28 pixels)
num_output = 10  # out_features of the last Linear layer
num_hidden = 1000  # width of the hidden Linear layer
num_steps = 25  # number of iterations of the time loop in Model.forward
batch_size = 128

# --- Dataset ----------------------------------------------------------------
data_path = '/home/jasonh/mnist'

# Every image is resized to 28x28, converted to grayscale and to a tensor,
# then passed through Normalize with mean 0 and std 1.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

mnist_train = datasets.MNIST(
    data_path, train=True, download=True, transform=transform
)
# The test split is built here but is not consumed by the code in this file.
mnist_test = datasets.MNIST(
    data_path, train=False, download=True, transform=transform
)

# --- PopTorch configuration -------------------------------------------------
opts = poptorch.Options()
opts.Precision.halfFloatCasting(
    poptorch.HalfFloatCastingBehavior.HalfUpcastToFloat
)

# Training data loader built on the PopTorch options above.
train_loader = poptorch.DataLoader(
    options=opts,
    dataset=mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=20,
)

# --- Surrogate gradient -----------------------------------------------------
# NOTE(review): `spike_grad` is created here but never passed to a layer in
# this file, and `snn.slope` is only assigned, never read here -- confirm
# whether either one still has an effect with the installed snntorch version.
spike_grad = surrogate.straight_through_estimator()
snn.slope = 50
class Model(torch.nn.Module):
    """Two-layer fully connected spiking network: Linear -> Leaky -> Linear -> Leaky.

    Layer sizes come from the module-level ``num_inputs``, ``num_hidden`` and
    ``num_output``; both leaky neurons use the module-level ``beta``. The
    loss function is stored on the module so that ``forward`` can return the
    loss together with the output when the module is in training mode.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_output)
        self.lif2 = snn.Leaky(beta=beta)
        self.loss_fn = SF.ce_count_loss()

    def forward(self, x, labels=None):
        """Run the network for ``num_steps`` time steps on a static input.

        Args:
            x: Input batch. Everything after the leading dimension is
                flattened before the first linear layer, so it must flatten
                to ``num_inputs`` features per sample.
            labels: Targets handed to the loss function. Required while the
                module is in training mode; unused otherwise.

        Returns:
            In training mode, a tuple ``(spk2_rec, loss)`` where ``loss`` is
            wrapped in ``poptorch.identity_loss``. Otherwise only
            ``spk2_rec``. ``spk2_rec`` stacks the output-layer spikes of all
            time steps along a new leading dimension.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Flatten using the input's own leading dimension rather than the
        # global `batch_size`, so any batch size reshapes correctly. The same
        # input is presented at every time step, so the first layer's current
        # is loop-invariant and is computed once here.
        cur1 = self.fc1(x.view(x.size(0), -1))

        spk2_rec = []
        mem2_rec = []
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        spk2_rec = torch.stack(spk2_rec)
        mem2_rec = torch.stack(mem2_rec)

        if self.training:
            # NOTE(review): the loss is evaluated on the recorded membrane
            # potentials, not on the recorded spikes -- confirm this is the
            # intended input for ce_count_loss.
            loss = self.loss_fn(mem2_rec, labels)
            return spk2_rec, poptorch.identity_loss(loss, "none")
        return spk2_rec
if __name__ == '__main__':
    # Build the network and wrap it, together with its optimizer, in a
    # PopTorch training model. Calling the wrapper with (data, labels)
    # returns what Model.forward returns in training mode: (spikes, loss).
    net = Model()
    optimizer = poptorch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))
    poptorch_model = poptorch.trainingModel(net, options=opts, optimizer=optimizer)

    # One CSV file per run; the timestamp keeps runs from sharing a file.
    date = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
    csv_file = f"Batch_Size_{batch_size}_Graphcore_Throughput_{date}.csv"
    with open(csv_file, 'a', newline='') as file:
        csv.writer(file).writerow(["Epoch", "Steps", "Accuracy", "Throughput"])

    epochs = 2
    log_every = 250  # report accuracy/throughput once every this many steps

    for epoch in tqdm(range(epochs), desc="epochs"):
        for i, (data, labels) in enumerate(train_loader):
            data = data.half()

            # perf_counter is a monotonic clock intended for measuring short
            # intervals; time.time() can jump if the system clock is adjusted.
            start_time = time.perf_counter()
            output, _ = poptorch_model(data, labels)
            end_time = time.perf_counter()

            if i % log_every != 0:
                continue

            # Predicted class = output unit with the most spikes over time
            # (sum over the leading time dimension, then arg-max per sample).
            _, pred = output.sum(dim=0).max(1)
            accuracy = (labels == pred).sum().item() / len(labels)

            # Samples per second for this single call.
            # NOTE(review): the very first call may include one-off model
            # compilation time, which would skew the step-0 figure -- confirm.
            throughput = len(data) / (end_time - start_time)

            print("accuracy: ", accuracy)
            with open(csv_file, 'a', newline='') as file:
                csv.writer(file).writerow([epoch, i, accuracy, throughput])