Skip to content

Commit

Permalink
Change to recommended fn naming
Browse files Browse the repository at this point in the history
  • Loading branch information
rozukke committed Jun 26, 2024
1 parent 57f1bb0 commit 2d353d0
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ void read_tensor(matrix* a, const char* fileName) {
fclose(file);
}

void propagate(const matrix* weights, const matrix* input_layer, matrix* output_layer, const matrix* biases) {
void propagate_fwd(const matrix* weights, const matrix* input_layer, matrix* output_layer, const matrix* biases) {
matrix_mul(weights, input_layer, output_layer);
matrix_add(output_layer, biases);
}

// Get result from output layer
int get_result(matrix* a) {
int get_max(matrix* a) {
int idx = 0;
float res = (a->data)[0][0];
for (int i = 0; i < a->rows; i++) {
Expand All @@ -112,23 +112,23 @@ int infer(matrix* input) {
mdl_layers[5] = new_matrix(40, 1);
mdl_layers[6] = new_matrix(52, 1);

propagate(weights[0], input, mdl_layers[0], biases[0]);
propagate_fwd(weights[0], input, mdl_layers[0], biases[0]);
relu(mdl_layers[0]);
propagate(weights[1], mdl_layers[0], mdl_layers[1], biases[1]);
propagate_fwd(weights[1], mdl_layers[0], mdl_layers[1], biases[1]);
relu(mdl_layers[1]);
propagate(weights[2], mdl_layers[1], mdl_layers[2], biases[2]);
propagate_fwd(weights[2], mdl_layers[1], mdl_layers[2], biases[2]);
relu(mdl_layers[2]);
propagate(weights[3], mdl_layers[2], mdl_layers[3], biases[3]);
propagate_fwd(weights[3], mdl_layers[2], mdl_layers[3], biases[3]);
relu(mdl_layers[3]);
propagate(weights[4], mdl_layers[3], mdl_layers[4], biases[4]);
propagate_fwd(weights[4], mdl_layers[3], mdl_layers[4], biases[4]);
relu(mdl_layers[4]);
propagate(weights[5], mdl_layers[4], mdl_layers[5], biases[5]);
propagate_fwd(weights[5], mdl_layers[4], mdl_layers[5], biases[5]);
relu(mdl_layers[5]);

propagate(weights[6], mdl_layers[5], mdl_layers[6], biases[6]);
propagate_fwd(weights[6], mdl_layers[5], mdl_layers[6], biases[6]);
softmax(mdl_layers[6]);

return get_result(mdl_layers[6]);
return get_max(mdl_layers[6]);
}

int main(int argc, char* argv[]) {
Expand Down

0 comments on commit 2d353d0

Please sign in to comment.