-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlenet5.cpp
21 lines (19 loc) · 961 Bytes
/
lenet5.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "lenet5.h"

#include <stdexcept>
#include <string>
/// @brief Builds the LeNet-5 layer stack for square single-channel inputs.
///
/// Layers: C1 conv(1->6, 5x5, zero-padded), C3 conv(6->16, 5x5),
/// C5 conv(16->120, 5x5), F6 linear(120->84), OUTPUT linear(84->10).
///
/// @param input_size Side length of the (square) input image; must be in [1, 35].
/// @throws std::invalid_argument if input_size is outside [1, 35].
LeNet5Impl::LeNet5Impl(int input_size) : Module("LeNet5")
{
    // C1's zero-padding is chosen so the padded side P lands in [32, 35].
    // With 5x5 valid convolutions and the 2x2 (floor-mode) pooling used in forward(),
    // the spatial size evolves as
    //   P -> P-4 (C1) -> (P-4)/2 (S2) -> (P-4)/2-4 (C3) -> ((P-4)/2-4)/2 (S4),
    // and only P in [32, 35] leaves exactly 5x5 for C5. C5 then emits the 1x1x120
    // map that F6 (120 inputs) expects after flattening.
    if (input_size < 1 || input_size > 35)
    {
        throw std::invalid_argument("LeNet5: input_size must be in [1, 35], got " +
                                    std::to_string(input_size));
    }

    // For even sizes below 32, (33 - s) / 2 == (32 - s) / 2, so P == 32 (e.g. 28 -> pad 2).
    // For odd sizes it rounds up so P == 33 instead of the unusable 31.
    // Sizes 32..35 need no padding.
    const int padding = input_size < 32 ? (33 - input_size) / 2 : 0;

    C1 = register_module("C1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5).padding(padding)));
    C3 = register_module("C3", torch::nn::Conv2d(6, 16, 5));
    C5 = register_module("C5", torch::nn::Conv2d(16, 120, 5));
    F6 = register_module("F6", torch::nn::Linear(120, 84));
    OUTPUT = register_module("OUTPUT", torch::nn::Linear(84, 10));
}
/// @brief Runs the LeNet-5 forward pass.
///
/// Pipeline: ReLU(C1) -> 2x2 max-pool -> ReLU(C3) -> 2x2 max-pool -> ReLU(C5)
/// -> flatten(1) -> ReLU(F6) -> OUTPUT.
///
/// @param x Input tensor. NOTE(review): presumably (N, 1, H, W) given C1's single
///          input channel — confirm against callers.
/// @return Raw OUTPUT-layer scores (no softmax is applied here).
torch::Tensor LeNet5Impl::forward(torch::Tensor x)
{
    namespace F = torch::nn::functional;

    // Shared 2x2 max-pooling step used for both subsampling stages (S2 and S4).
    const auto subsample = [](const torch::Tensor& t) {
        return F::max_pool2d(t, F::MaxPool2dFuncOptions({2, 2}));
    };

    auto features = subsample(F::relu(C1(x)));      // C1 -> S2
    features = subsample(F::relu(C3(features)));    // C3 -> S4

    // C5 activation, then everything from dim 1 onward is flattened for F6.
    // F6 takes 120 inputs, so C5 must leave a 1x1 spatial map.
    features = F::relu(C5(features)).flatten(1);

    const auto hidden = F::relu(F6(features));      // F6
    return OUTPUT(hidden);                          // OUTPUT
}