diff --git a/nntrainer/tensor/blas_interface.cpp b/nntrainer/tensor/blas_interface.cpp
index d322009e7..81c134aec 100644
--- a/nntrainer/tensor/blas_interface.cpp
+++ b/nntrainer/tensor/blas_interface.cpp
@@ -989,13 +989,18 @@ void scopy_s8_to_float32(const unsigned int N, const int8_t *X, const int incX,
   }
 }
 
+static inline void copy_s16_fp32_fallback(const unsigned int N,
+                                          const int16_t *X, float *Y) {
+  for (unsigned int idx = 0; idx < N; ++idx) {
+    Y[idx] = (float)X[idx];
+  }
+}
+
 void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
 #ifdef USE_NEON
   nntrainer::neon::copy_s16_fp32(N, X, Y);
 #else
-  for (unsigned int idx = 0; idx < N; ++idx) {
-    Y[idx] = (float)X[idx];
-  }
+  copy_s16_fp32_fallback(N, X, Y);
 #endif
 }
 
diff --git a/nntrainer/tensor/blas_neon.cpp b/nntrainer/tensor/blas_neon.cpp
index 89997661d..49d556cd1 100644
--- a/nntrainer/tensor/blas_neon.cpp
+++ b/nntrainer/tensor/blas_neon.cpp
@@ -756,20 +756,119 @@ void copy_s8_to_fp32(const unsigned int N, const int8_t *X, float *Y) {
   }
 }
 
-void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
-  /// @todo implement int16_t to fp32
+void copy_fp32_s16(const unsigned int N, const float *X, int16_t *Y) {
   unsigned int idx = 0;
+  for (; (N - idx) >= 8; ++idx) {
+    vst1q_s16(&Y[idx],
+              vcombine_s16(vqmovn_s32(vcvtq_s32_f32(vld1q_f32(&X[idx]))),
+                           vqmovn_s32(vcvtq_s32_f32(vld1q_f32(&X[idx + 4])))));
+  }
   for (; (N - idx) >= 1; ++idx) {
     Y[idx] = X[idx];
   }
 }
 
-void copy_fp32_s16(const unsigned int N, const float *X, int16_t *Y) {
+void copy_int8_to_fp32(const unsigned int N, const uint8_t *X, float *Y) {
   unsigned int idx = 0;
-  for (; (N - idx) >= 8; ++idx) {
-    vst1q_s16(&Y[idx],
-              vcombine_s16(vqmovn_s32(vcvtq_s32_f32(vld1q_f32(&X[idx]))),
-                           vqmovn_s32(vcvtq_s32_f32(vld1q_f32(&X[idx + 4])))));
+  for (; (N - idx) >= 16; idx += 16) {
+    uint8x16_t batch = vld1q_u8(&X[idx]);
+    uint8x8_t low = vget_low_u8(batch);
+    uint8x8_t high = vget_high_u8(batch);
+
+    // convert to u16
+    uint16x8_t batch_low_u16 = vmovl_u8(low);
+    uint16x8_t batch_high_u16 = vmovl_u8(high);
+
+    // convert to u32
+    uint32x4_t batch_low_u32_low = vmovl_u16(vget_low_u16(batch_low_u16));
+    uint32x4_t batch_low_u32_high = vmovl_u16(vget_high_u16(batch_low_u16));
+    uint32x4_t batch_high_u32_low = vmovl_u16(vget_low_u16(batch_high_u16));
+    uint32x4_t batch_high_u32_high = vmovl_u16(vget_high_u16(batch_high_u16));
+
+    // todo : experiment with vcvt_f32_u32_ bitwise operation w.r.t.
+    // time/accuracy
+    vst1q_f32(&Y[idx], vcvtq_f32_u32(batch_low_u32_low));
+    vst1q_f32(&Y[idx + 4], vcvtq_f32_u32(batch_low_u32_high));
+    vst1q_f32(&Y[idx + 8], vcvtq_f32_u32(batch_high_u32_low));
+    vst1q_f32(&Y[idx + 12], vcvtq_f32_u32(batch_high_u32_high));
+  }
+  for (; (N - idx) >= 8; idx += 8) {
+    uint8x8_t batch = vld1_u8(&X[idx]);
+
+    // convert to u16
+    uint16x8_t batch_u16 = vmovl_u8(batch);
+
+    // convert to u32
+    uint32x4_t batch_u32_low = vmovl_u16(vget_low_u16(batch_u16));
+    uint32x4_t batch_u32_high = vmovl_u16(vget_high_u16(batch_u16));
+
+    vst1q_f32(&Y[idx], vcvtq_f32_u32(batch_u32_low));
+    vst1q_f32(&Y[idx + 4], vcvtq_f32_u32(batch_u32_high));
+  }
+  for (; (N - idx) >= 1; ++idx) {
+    Y[idx] = X[idx];
+  }
+}
+
+void copy_int8_to_fp32(const unsigned int N, const int8_t *X, float *Y) {
+  unsigned int idx = 0;
+  for (; (N - idx) >= 16; idx += 16) {
+    int8x16_t batch = vld1q_s8(&X[idx]);
+    int8x8_t low = vget_low_s8(batch);
+    int8x8_t high = vget_high_s8(batch);
+
+    // convert to s16
+    int16x8_t batch_low_s16 = vmovl_s8(low);
+    int16x8_t batch_high_s16 = vmovl_s8(high);
+
+    // convert to s32
+    int32x4_t batch_low_s32_low = vmovl_s16(vget_low_s16(batch_low_s16));
+    int32x4_t batch_low_s32_high = vmovl_s16(vget_high_s16(batch_low_s16));
+    int32x4_t batch_high_s32_low = vmovl_s16(vget_low_s16(batch_high_s16));
+    int32x4_t batch_high_s32_high = vmovl_s16(vget_high_s16(batch_high_s16));
+
+    // todo : experiment with vcvt_f32_s32_ bitwise operation w.r.t.
+    // time/accuracy
+    vst1q_f32(&Y[idx], vcvtq_f32_s32(batch_low_s32_low));
+    vst1q_f32(&Y[idx + 4], vcvtq_f32_s32(batch_low_s32_high));
+    vst1q_f32(&Y[idx + 8], vcvtq_f32_s32(batch_high_s32_low));
+    vst1q_f32(&Y[idx + 12], vcvtq_f32_s32(batch_high_s32_high));
+  }
+  for (; (N - idx) >= 8; idx += 8) {
+    int8x8_t batch = vld1_s8(&X[idx]);
+
+    // convert to s16
+    int16x8_t batch_s16 = vmovl_s8(batch);
+
+    // convert to s32
+    int32x4_t batch_s32_low = vmovl_s16(vget_low_s16(batch_s16));
+    int32x4_t batch_s32_high = vmovl_s16(vget_high_s16(batch_s16));
+
+    vst1q_f32(&Y[idx], vcvtq_f32_s32(batch_s32_low));
+    vst1q_f32(&Y[idx + 4], vcvtq_f32_s32(batch_s32_high));
+  }
+  for (; (N - idx) >= 1; ++idx) {
+    Y[idx] = X[idx];
+  }
+}
+
+void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
+  unsigned int idx = 0;
+  for (; (N - idx) >= 8; idx += 8) {
+    int16x8_t batch = vld1q_s16(&X[idx]);
+    int16x4_t low = vget_low_s16(batch);
+    int16x4_t high = vget_high_s16(batch);
+
+    // widen to s32
+    int32x4_t low_s32 = vmovl_s16(low);
+    int32x4_t high_s32 = vmovl_s16(high);
+
+    // convert to f32
+    float32x4_t low_f32 = vcvtq_f32_s32(low_s32);
+    float32x4_t high_f32 = vcvtq_f32_s32(high_s32);
+
+    vst1q_f32(&Y[idx], low_f32);
+    vst1q_f32(&Y[idx + 4], high_f32);
   }
   for (; (N - idx) >= 1; ++idx) {
     Y[idx] = X[idx];
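
For context on the conversion pattern the new NEON kernel relies on, below is a minimal standalone sketch (not part of this patch) of the same s16 -> s32 -> f32 widening used by the new copy_s16_fp32: it rebuilds the loop with the same intrinsics that appear in the diff (vld1q_s16, vmovl_s16, vcvtq_f32_s32, vst1q_f32) and checks it against a scalar reference equivalent to the copy_s16_fp32_fallback path. The helper name s16_to_f32_neon_sketch and the test values are illustrative only, and the snippet assumes an ARM/AArch64 toolchain providing <arm_neon.h>.

// Illustration only, not part of this patch: mirrors the widen-then-convert
// pattern of the new copy_s16_fp32 NEON kernel and verifies it against a
// plain scalar conversion. Assumes <arm_neon.h> is available.
#include <arm_neon.h>
#include <cassert>
#include <cstdint>
#include <vector>

static void s16_to_f32_neon_sketch(unsigned int N, const int16_t *X, float *Y) {
  unsigned int idx = 0;
  for (; (N - idx) >= 8; idx += 8) {
    int16x8_t batch = vld1q_s16(&X[idx]);            // load 8 x s16
    int32x4_t lo = vmovl_s16(vget_low_s16(batch));   // widen low half to s32
    int32x4_t hi = vmovl_s16(vget_high_s16(batch));  // widen high half to s32
    vst1q_f32(&Y[idx], vcvtq_f32_s32(lo));           // convert + store 4 floats
    vst1q_f32(&Y[idx + 4], vcvtq_f32_s32(hi));       // convert + store next 4
  }
  for (; idx < N; ++idx) { // scalar tail for the remaining < 8 elements
    Y[idx] = static_cast<float>(X[idx]);
  }
}

int main() {
  const unsigned int N = 21; // deliberately not a multiple of 8 to hit the tail
  std::vector<int16_t> x(N);
  std::vector<float> y_neon(N), y_ref(N);
  for (unsigned int i = 0; i < N; ++i) {
    x[i] = static_cast<int16_t>(static_cast<int>(i) * 773 - 10000);
    y_ref[i] = static_cast<float>(x[i]); // scalar reference, as in the fallback
  }
  s16_to_f32_neon_sketch(N, x.data(), y_neon.data());
  for (unsigned int i = 0; i < N; ++i) {
    assert(y_neon[i] == y_ref[i]); // every int16_t is exactly representable in f32
  }
  return 0;
}

The exact comparison is valid because every int16_t value fits losslessly in a 32-bit float, which is also why the kernel can widen and convert without rounding concerns.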