Skip to content

Commit

Permalink
ADD: minor improve on tensor initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
T-K-233 committed Jun 19, 2024
1 parent ddc629e commit eb1375f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion nn/inc/nn_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ static inline void NN_deleteTensor(Tensor *tensor) {
* @param dtype: data type
* @param data: pointer to data, if NULL, the data will be allocated
*/
void NN_initTensor(Tensor *tensor, size_t ndim, const size_t *shape, DataType dtype, void *data);
void NN_initTensor(Tensor *tensor, const size_t ndim, const size_t *shape, DataType dtype, const void *data);

/**
* Create a new tensor
Expand Down
15 changes: 6 additions & 9 deletions nn/src/nn_tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,26 @@
#include "nn_tensor.h"


void NN_initTensor(Tensor *tensor, size_t ndim, const size_t *shape, DataType dtype, void *data) {
tensor->ndim = ndim;
void NN_initTensor(Tensor *tensor, const size_t ndim, const size_t *shape, DataType dtype, const void *data) {
tensor->dtype = dtype;
tensor->ndim = ndim;

// set shape
for (size_t i = 0; i < ndim; i += 1) {
tensor->shape[i] = shape[i];
}
for (size_t i = ndim; i < MAX_DIMS; i += 1) {
tensor->shape[i] = 0;
}
memcpy(tensor->shape, shape, ndim * sizeof(size_t));
memset(tensor->shape + ndim, 0, (MAX_DIMS - ndim) * sizeof(size_t));

// calculate size (number of elements)
tensor->size = 1;
for (size_t i = 0; i < ndim; i += 1) {
tensor->size *= tensor->shape[i];
tensor->size *= shape[i];
}

if (data != NULL) {
tensor->data = data;
return;
}

// if this is a scalar tensor
if (tensor->ndim == 0) {
tensor->data = malloc(NN_sizeof(dtype));
return;
Expand Down

0 comments on commit eb1375f

Please sign in to comment.