-
Notifications
You must be signed in to change notification settings - Fork 21
/
BitNetMCU_MNIST_dll.c
73 lines (59 loc) · 2.21 KB
/
BitNetMCU_MNIST_dll.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <stdint.h>
#include <stdio.h>
#include "BitNetMCU_model.h"
#include "BitNetMCU_inference.c"
/**
 * @file BitNetMCU_MNIST_dll.c
 * @brief DLL wrapper for the BitMnist model.
 *
 * Build as a 64-bit DLL with Visual Studio:
 *   cl /LD BitNetMCU_MNIST_dll.c /MD /FeBitnet_inf.dll /link /MACHINE:X64
 *
 * The exported entry point Inference() is only compiled when _DLL is defined
 * (MSVC defines _DLL when building against the DLL runtime, e.g. with /MD).
 */

/**
 * @brief Runs the BitMnist model on one input (defined later in this file).
 *
 * Declared here so that Inference() below can call it before its definition.
 *
 * @param input The input data for the inference.
 * @return The result of the inference.
 */
uint32_t BitMnistInference(int8_t *input);

#ifdef _DLL
/* MSVC predefines _WIN32, not WIN32 (WIN32 normally comes from IDE project
 * settings). Checking only WIN32 made a plain command-line "cl" build pick the
 * GCC/Clang __attribute__ branch, which MSVC cannot compile. Accept either. */
#if defined(_WIN32) || defined(WIN32)
#define EXPORT __declspec(dllexport)
#else
#define EXPORT __attribute__((visibility("default")))
#endif

/**
 * @brief Exported DLL entry point; forwards directly to BitMnistInference().
 *
 * @param input The input data for the inference.
 * @return The value returned by BitMnistInference().
 */
EXPORT uint32_t Inference(int8_t *input) {
    return BitMnistInference(input);
}
#endif
/**
 * @brief Debug helper: writes an activation buffer to stdout.
 *
 * Each value is printed in decimal followed by ", ". A newline is emitted
 * after every 16th value. No trailing newline is added when the count is not
 * a multiple of 16. Nothing is printed when n_activations <= 0.
 *
 * @param activations   Buffer of activation values to print (not modified).
 * @param n_activations Number of entries to print.
 */
void printactivations(uint8_t *activations, int32_t n_activations)
{
    const uint8_t *cursor = activations;
    int printed = 0;

    while (printed < n_activations) {
        printf("%d, ", *cursor++);
        printed++;
        /* Break the dump into rows of 16 values. */
        if (printed % 16 == 0) {
            printf("\n");
        }
    }
}
/**
 * @brief Performs inference on the BitMnist model.
 *
 * Runs the input through fully connected layers L1..L3, and also through L4
 * when the model header defines L4_active. Each layer is a processfclayer()
 * call followed by ReLUNorm(). The same two scratch buffers are reused for
 * every layer: layer_out is passed to processfclayer() and then read by
 * ReLUNorm(), and layer_in is passed to ReLUNorm() and then fed to the next
 * processfclayer().
 *
 * Define BITNETMCU_PRINT_ACTIVATIONS at build time to dump the L1 and L2
 * activations to stdout via printactivations().
 *
 * @param input The input data for the inference.
 * @return The result of the inference: the value returned by the last
 *         ReLUNorm() call (L3, or L4 when L4_active is defined).
 */
uint32_t BitMnistInference(int8_t *input) {
    /* Size of the shared scratch buffers.
     * NOTE(review): assumes no layer in BitNetMCU_model.h has more than this
     * many outputs; a wider layer would overflow layer_out/layer_in. Confirm
     * this whenever the model header is regenerated. */
    enum { BITMNIST_SCRATCH_LEN = 256 };
    int32_t layer_out[BITMNIST_SCRATCH_LEN];
    uint8_t layer_in[BITMNIST_SCRATCH_LEN];
    uint32_t result;

    processfclayer(input, L1_weights, L1_bitperweight, L1_incoming_weights, L1_outgoing_weights, layer_out);
    ReLUNorm(layer_out, layer_in, L1_outgoing_weights);
#ifdef BITNETMCU_PRINT_ACTIVATIONS
    printf("L1 activations: \n");
    printactivations(layer_in, L1_outgoing_weights);
#endif

    processfclayer(layer_in, L2_weights, L2_bitperweight, L2_incoming_weights, L2_outgoing_weights, layer_out);
    ReLUNorm(layer_out, layer_in, L2_outgoing_weights);
#ifdef BITNETMCU_PRINT_ACTIVATIONS
    printf("L2 activations: \n");
    printactivations(layer_in, L2_outgoing_weights);
#endif

    /* L3 is always present. Previously this call was duplicated in both
     * branches of the L4_active conditional. */
    processfclayer(layer_in, L3_weights, L3_bitperweight, L3_incoming_weights, L3_outgoing_weights, layer_out);
    result = ReLUNorm(layer_out, layer_in, L3_outgoing_weights);

#ifdef L4_active
    /* Optional fourth layer: its ReLUNorm() result supersedes L3's. */
    processfclayer(layer_in, L4_weights, L4_bitperweight, L4_incoming_weights, L4_outgoing_weights, layer_out);
    result = ReLUNorm(layer_out, layer_in, L4_outgoing_weights);
#endif

    return result;
}