Skip to content

Commit

Permalink
Align arm_vector_sum_s8 behaviour between default/MVE case
Browse files Browse the repository at this point in the history
Change-Id: Icedd6b84b7dadfa535bca4d388967c8a6bf416d5
  • Loading branch information
AdrianLundell committed Feb 14, 2024
1 parent 5aeada7 commit a516162
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 96 deletions.
5 changes: 2 additions & 3 deletions Include/arm_nnfunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_nnfunctions.h
* Description: Public header file for CMSIS NN Library
*
* $Date: 19 January 2024
* $Revision: V.13.0.0
* $Date: 13 February 2024
* $Revision: V.13.0.1
*
* Target : Arm(R) M-Profile Architecture
Expand Down Expand Up @@ -1524,7 +1524,6 @@ arm_cmsis_nn_status arm_fully_connected_s8(const cmsis_nn_context *ctx,
* @param[in] bias_data Vector of bias data, added to each sum.
* @return The function returns
* <code>ARM_CMSIS_NN_SUCCESS</code> - Successful operation
* <code>ARM_CMSIS_NN_ARG_ERROR</code> - If not for Arm(R) Helium Architecture case.
*/
arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t vector_cols,
Expand Down
170 changes: 77 additions & 93 deletions Source/FullyConnectedFunctions/arm_vector_sum_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@
* Title: arm_vector_sum_s8
* Description: Generic function for calculating vector sums
*
* $Date: 26 January 2024
* $Revision: V.2.0.0
* $Date: 12 February 2024
* $Revision: V.2.1.0
*
* Target : Arm(R) M-Profile Architecture
*
* -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
* @ingroup Public
*/
Expand All @@ -53,94 +52,6 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
const int32_t lhs_offset,
const int32_t *bias_data)
{
#if defined(ARM_MATH_MVEI)
const int32_t row_loop_cnt = vector_rows / 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{
const int32_t col_loop_cnt = (vector_cols + 15) / 16;
const int8_t *vector_0 = vector_data;
const int8_t *vector_1 = vector_data + vector_cols;
const int8_t *vector_2 = vector_data + 2 * vector_cols;
const int8_t *vector_3 = vector_data + 3 * vector_cols;
const int8_t *vector_4 = vector_data + 4 * vector_cols;
int32_t vector_sum_0 = 0;
int32_t vector_sum_1 = 0;
int32_t vector_sum_2 = 0;
int32_t vector_sum_3 = 0;
int32_t vector_sum_4 = 0;
uint32_t col_cnt = (uint32_t)vector_cols;
for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p);
vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1);
const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p);
vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2);
const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p);
vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3);
const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p);
vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4);
vector_0 += 16;
vector_1 += 16;
vector_2 += 16;
vector_3 += 16;
vector_4 += 16;
}
vector_data += 5 * vector_cols;
if (lhs_offset)
{
vector_sum_0 *= lhs_offset;
vector_sum_1 *= lhs_offset;
vector_sum_2 *= lhs_offset;
vector_sum_3 *= lhs_offset;
vector_sum_4 *= lhs_offset;
}
if (bias_data)
{
vector_sum_0 += *bias_data++;
vector_sum_1 += *bias_data++;
vector_sum_2 += *bias_data++;
vector_sum_3 += *bias_data++;
vector_sum_4 += *bias_data++;
}
vector_sum_buf[0] = vector_sum_0;
vector_sum_buf[1] = vector_sum_1;
vector_sum_buf[2] = vector_sum_2;
vector_sum_buf[3] = vector_sum_3;
vector_sum_buf[4] = vector_sum_4;
vector_sum_buf += 5;
}
const int32_t loop_cnt = vector_rows % 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
{
const int32_t col_loop_cnt = (vector_cols + 15) / 16;
const int8_t *vector_0 = vector_data;
int32_t vector_sum_0 = 0;
uint32_t col_cnt = (uint32_t)vector_cols;
for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
vector_0 += 16;
}
vector_data += vector_cols;
if (lhs_offset)
{
vector_sum_0 *= lhs_offset;
}
if (bias_data)
{
vector_sum_0 += *bias_data++;
}
vector_sum_buf[i_row_loop_cnt] = vector_sum_0;
}
return (ARM_CMSIS_NN_SUCCESS);
#else

if (bias_data)
{
Expand All @@ -153,6 +64,79 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,

if (lhs_offset)
{
#if defined(ARM_MATH_MVEI)

const int32_t row_loop_cnt = vector_rows / 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{
const int32_t col_loop_cnt = (vector_cols + 15) / 16;
const int8_t *vector_0 = vector_data;
const int8_t *vector_1 = vector_data + vector_cols;
const int8_t *vector_2 = vector_data + 2 * vector_cols;
const int8_t *vector_3 = vector_data + 3 * vector_cols;
const int8_t *vector_4 = vector_data + 4 * vector_cols;
int32_t vector_sum_0 = 0;
int32_t vector_sum_1 = 0;
int32_t vector_sum_2 = 0;
int32_t vector_sum_3 = 0;
int32_t vector_sum_4 = 0;
uint32_t col_cnt = (uint32_t)vector_cols;
for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
const int8x16_t ker_1 = vldrbq_z_s8(vector_1, p);
vector_sum_1 = vaddvaq_s8(vector_sum_1, ker_1);
const int8x16_t ker_2 = vldrbq_z_s8(vector_2, p);
vector_sum_2 = vaddvaq_s8(vector_sum_2, ker_2);
const int8x16_t ker_3 = vldrbq_z_s8(vector_3, p);
vector_sum_3 = vaddvaq_s8(vector_sum_3, ker_3);
const int8x16_t ker_4 = vldrbq_z_s8(vector_4, p);
vector_sum_4 = vaddvaq_s8(vector_sum_4, ker_4);
vector_0 += 16;
vector_1 += 16;
vector_2 += 16;
vector_3 += 16;
vector_4 += 16;
}
vector_data += 5 * vector_cols;

vector_sum_0 *= lhs_offset;
vector_sum_1 *= lhs_offset;
vector_sum_2 *= lhs_offset;
vector_sum_3 *= lhs_offset;
vector_sum_4 *= lhs_offset;

vector_sum_buf[0] += vector_sum_0;
vector_sum_buf[1] += vector_sum_1;
vector_sum_buf[2] += vector_sum_2;
vector_sum_buf[3] += vector_sum_3;
vector_sum_buf[4] += vector_sum_4;
vector_sum_buf += 5;
}
const int32_t loop_cnt = vector_rows % 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
{
const int32_t col_loop_cnt = (vector_cols + 15) / 16;
const int8_t *vector_0 = vector_data;
int32_t vector_sum_0 = 0;
uint32_t col_cnt = (uint32_t)vector_cols;
for (int i = 0; i < col_loop_cnt; i++)
{
mve_pred16_t p = vctp8q(col_cnt);
col_cnt -= 16;
const int8x16_t ker_0 = vldrbq_z_s8(vector_0, p);
vector_sum_0 = vaddvaq_s8(vector_sum_0, ker_0);
vector_0 += 16;
}
vector_data += vector_cols;
vector_sum_0 *= lhs_offset;

vector_sum_buf[i_row_loop_cnt] += vector_sum_0;
}
#else
for (int i = 0; i < vector_rows; i++)
{
int32_t sum = 0;
Expand All @@ -162,10 +146,10 @@ arm_cmsis_nn_status arm_vector_sum_s8(int32_t *vector_sum_buf,
}
*vector_sum_buf++ += sum * lhs_offset;
}
#endif
}
return (ARM_CMSIS_NN_SUCCESS);

#endif
return (ARM_CMSIS_NN_SUCCESS);
}

/**
Expand Down

0 comments on commit a516162

Please sign in to comment.