Commit
ADD: improve MLP example
T-K-233 committed Jun 9, 2024
1 parent cb939ee commit 70c4360
Showing 4 changed files with 52 additions and 27 deletions.
42 changes: 30 additions & 12 deletions example/mlp/main.c
@@ -11,7 +11,7 @@

#include "nn.h"

-#define DIM 3
+#define DIM 16


// load the weight data block from the model.bin file
@@ -24,6 +24,12 @@ typedef struct {
  Tensor *input;
  Tensor *fc1_weight;
  Tensor *fc1_bias;
+  Tensor *fc1_out;
+  Tensor *fc2_weight;
+  Tensor *fc2_bias;
+  Tensor *fc2_out;
+  Tensor *fc3_weight;
+  Tensor *fc3_bias;
  Tensor *output;
} Model;

@@ -32,27 +38,43 @@ typedef struct {
 */
void init(Model *model) {
  uint8_t *array_pointer = weights_data;

-  model->input = NN_ones(2, (size_t[]){1, DIM}, DTYPE_F32);

+  model->input = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, NULL);
  model->fc1_weight = NN_tensor(2, (size_t[]){DIM, DIM}, DTYPE_F32, array_pointer);
  array_pointer += DIM * DIM * sizeof(float);
  model->fc1_bias = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, array_pointer);
  array_pointer += DIM * sizeof(float);
+  model->fc1_out = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, NULL);
+
+  model->fc2_weight = NN_tensor(2, (size_t[]){DIM, DIM}, DTYPE_F32, array_pointer);
+  array_pointer += DIM * DIM * sizeof(float);
+  model->fc2_bias = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, array_pointer);
+  array_pointer += DIM * sizeof(float);
+  model->fc2_out = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, NULL);
+
+  model->fc3_weight = NN_tensor(2, (size_t[]){DIM, DIM}, DTYPE_F32, array_pointer);
+  array_pointer += DIM * DIM * sizeof(float);
+  model->fc3_bias = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, array_pointer);
+  array_pointer += DIM * sizeof(float);
+
  model->output = NN_tensor(2, (size_t[]){1, DIM}, DTYPE_F32, NULL);

-  printf("fc1_weight: \n");
-  NN_printf(model->fc1_weight);
-  printf("fc1_bias: \n");
-  NN_printf(model->fc1_bias);
+  // printf("fc1_weight: \n");
+  // NN_printf(model->fc1_weight);
+  // printf("fc1_bias: \n");
+  // NN_printf(model->fc1_bias);
}

/**
 * Forward pass of the model
 */
void forward(Model *model) {
-  NN_linear_F32(model->output, model->input, model->fc1_weight, model->fc1_bias);
-  NN_relu_F32(model->output, model->output);
+  NN_Linear_F32(model->fc1_out, model->input, model->fc1_weight, model->fc1_bias);
+  NN_ReLU_F32(model->fc1_out, model->fc1_out);
+  NN_Linear_F32(model->fc2_out, model->fc1_out, model->fc2_weight, model->fc2_bias);
+  NN_ReLU_F32(model->fc2_out, model->fc2_out);
+  NN_Linear_F32(model->output, model->fc2_out, model->fc3_weight, model->fc3_bias);
}

@@ -62,10 +84,6 @@ int main() {
  Model *model = malloc(sizeof(Model));

  init(model);

-  ((float *)model->input->data)[0] = 1.;
-  ((float *)model->input->data)[1] = 2.;
-  ((float *)model->input->data)[2] = 3.;

  forward(model);

Binary file modified example/mlp/model.bin
Binary file not shown.
Binary file modified example/mlp/model.pth
Binary file not shown.
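Since the updated model.bin is shown only as a binary stub, it helps to note the layout that both sides agree on: run.py writes the three layers' float32 weights and biases back to back with no header, and init() in main.c walks the same order with pointer offsets. A small standalone Python sketch of that offset arithmetic (illustrative only, not part of this commit):

# Sketch of the model.bin layout shared by run.py and init() in main.c.
# Assumes DIM = 16 and 4-byte float32 parameters, as in this commit.
DIM = 16
FLOAT_SIZE = 4

offset = 0
for layer in ("fc1", "fc2", "fc3"):
    print(f"{layer}.weight starts at byte {offset}")  # DIM x DIM matrix
    offset += DIM * DIM * FLOAT_SIZE
    print(f"{layer}.bias starts at byte {offset}")    # 1 x DIM vector
    offset += DIM * FLOAT_SIZE

print(f"expected file size: {offset} bytes")          # 3 * (16*16 + 16) * 4 = 3264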
37 changes: 22 additions & 15 deletions example/mlp/scripts/run.py
@@ -4,21 +4,24 @@

torch.manual_seed(0)

-class Simple(nn.Module):
-    """
-    Simple model with one linear layer
-    """
+DIM = 16
+
+
+class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
+        self.fc2 = nn.Linear(dim, dim)
+        self.fc3 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        y = nn.functional.relu(self.fc1(x))
+        y = nn.functional.relu(self.fc2(y))
+        y = self.fc3(y)
        return y

# Create model
-model = Simple(dim=3)
+model = MLP(dim=DIM)

# Save model
torch.save(model, "model.pth")
@@ -31,23 +34,27 @@ def forward(self, x: torch.Tensor):

w1 = model.state_dict().get("fc1.weight").contiguous().numpy()
b1 = model.state_dict().get("fc1.bias").contiguous().numpy()
+w2 = model.state_dict().get("fc2.weight").contiguous().numpy()
+b2 = model.state_dict().get("fc2.bias").contiguous().numpy()
+w3 = model.state_dict().get("fc3.weight").contiguous().numpy()
+b3 = model.state_dict().get("fc3.bias").contiguous().numpy()

-print("w1:\n", w1)
-print("b1:\n", b1)
-
-w1_flat = w1.astype(np.float32).flatten()
-b1_flat = b1.astype(np.float32).flatten()
+# print("w1:\n", w1)
+# print("b1:\n", b1)

with open("model.bin", "wb") as f:
-    f.write(w1_flat.tobytes())
-    f.write(b1_flat.tobytes())
+    f.write(w1.astype(np.float32).flatten().tobytes())
+    f.write(b1.astype(np.float32).flatten().tobytes())
+    f.write(w2.astype(np.float32).flatten().tobytes())
+    f.write(b2.astype(np.float32).flatten().tobytes())
+    f.write(w3.astype(np.float32).flatten().tobytes())
+    f.write(b3.astype(np.float32).flatten().tobytes())



# Test model
-test_input = np.array([
-    [1.0, 2.0, 3.0],
-], dtype=np.float32)
+test_input = np.ones((1, DIM), dtype=np.float32)

test_tensor = torch.tensor(test_input, dtype=torch.float32)

output = model.forward(test_tensor)
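A quick way to confirm the exported blob matches what the C side will read is to load model.bin back with numpy and compare it against the state dict. A minimal sketch, assuming it is appended to the end of run.py so that model and np are already in scope (this check is not part of the commit):

# Hypothetical round-trip check: re-read model.bin and compare it to the
# weights exported above, in the same fc1/fc2/fc3 weight-then-bias order.
blob = np.fromfile("model.bin", dtype=np.float32)

expected = np.concatenate([
    model.state_dict().get(name).contiguous().numpy().astype(np.float32).flatten()
    for name in ("fc1.weight", "fc1.bias",
                 "fc2.weight", "fc2.bias",
                 "fc3.weight", "fc3.bias")
])

assert blob.shape == expected.shape  # 3 * (16*16 + 16) = 816 float32 values
assert np.allclose(blob, expected)
print("model.bin matches the exported state dict")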
