Skip to content

Commit

Permalink
Update mlas (#1457)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Update the mlas library to the latest version.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <[email protected]>
  • Loading branch information
JinHai-CN authored Jul 10, 2024
1 parent b478b79 commit b1b2e5e
Show file tree
Hide file tree
Showing 79 changed files with 26,376 additions and 485 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:

- name: Unit test debug version
if: ${{ !cancelled() && !failure() }}
run: sudo docker exec ${BUILDER_CONTAINER} bash -c "mkdir -p /var/infinity && cd /infinity/ && cmake-build-debug/src/test_main > unittest_debug.log 2>&1"
run: sudo docker exec ${BUILDER_CONTAINER} bash -c "mkdir -p /var/infinity && cd /infinity/ && ASAN_OPTIONS=detect_leaks=0 cmake-build-debug/src/test_main > unittest_debug.log 2>&1"

- name: Collect infinity unit test debug output
if: ${{ !cancelled() }}
Expand Down
12 changes: 2 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,8 @@ elseif ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
set(CMAKE_C_FLAGS "-O0 -g")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-stack-protector -fno-var-tracking ")
add_compile_options(-fsanitize=address -fsanitize-recover=all)
add_link_options(-fsanitize=address -fsanitize-recover=all)

option(LEAK "Memory leak detection" OFF)
message("Check memory leak: " "${LEAK}")
if(LEAK)
message("Check memory leak")
add_compile_options(-fsanitize=leak)
add_link_options(-fsanitize=leak)
endif()
add_compile_options(-fsanitize=address -fsanitize-recover=all -fsanitize=leak)
add_link_options(-fsanitize=address -fsanitize-recover=all -fsanitize=leak)

add_compile_options("-fno-omit-frame-pointer")
add_link_options("-fno-omit-frame-pointer")
Expand Down
415 changes: 384 additions & 31 deletions third_party/mlas/CMakeLists.txt

Large diffs are not rendered by default.

144 changes: 140 additions & 4 deletions third_party/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ Module Name:
#endif
#endif

#if defined(__loongarch64)
#define MLAS_TARGET_LARCH64
#endif
//
// Define the support levels for the target architecture.
//
Expand All @@ -87,7 +90,7 @@ Module Name:

#define MLAS_F16VEC_INTRINSICS_SUPPORTED

#endif //
#endif //
#endif // ARM64
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic

Expand Down Expand Up @@ -1219,6 +1222,26 @@ MlasQuantizeLinear(
OutputType ZeroPoint
);

void
MLASCALL
MlasQuantizeLinearU4(
const float* Input,
uint8_t* Output,
size_t N,
float Scale,
int8_t ZeroPoint
);

void
MLASCALL
MlasQuantizeLinearS4(
const float* Input,
uint8_t* Output,
size_t N,
float Scale,
int8_t ZeroPoint
);

/**
* @brief Requantize a block of the intermediate buffer to the output buffer,
* optionally adding the supplied bias
Expand Down Expand Up @@ -1611,6 +1634,119 @@ MlasHalfGemmConvertPackB(
void* PackedB
);

#if defined(__aarch64__) && defined(__linux__)
/**
* @brief Whether current CPU supports Bfloat16(bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

/**
* @brief Interface for bf16 gemm post processors.
*
* Example implementation of this interface includes activations,
* conversion from single precision to half precision, etc.
*
* SBGEMM is computed tile by tile. When a tile of result matrix
* is produced, the method Process() is called to process this tile.
* Parameters of this method describe the location and shape of the
* tile.
*/
class MLAS_SBGEMM_POSTPROCESSOR
{
public:
virtual void Process(float*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

virtual ~MLAS_SBGEMM_POSTPROCESSOR() {}
};

/**
* @brief bfloat16 precision activation functions, with optional sum tensor.
* Supplied sum tensor must be the same layout as the GEMM output tensor.
* And the supplied sum tensor will be added to the tensor before activation.
*/
class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR
{
public:
    /**
     * @brief Construct with the activation to apply and an optional sum tensor.
     * @param Activation  activation descriptor applied to each result tile
     * @param SumBuf      optional tensor (same layout as the GEMM output) added
     *                    to the tile before the activation; may be nullptr
     */
    MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr)
        : Activation_(Activation), SumBuf_(SumBuf)
    {
    }

    /** Applies the optional sum tensor and the activation to one result tile. */
    void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc)
        const override;

private:
    const MLAS_ACTIVATION& Activation_; /**< activation applied per tile */
    const float* SumBuf_;               /**< optional sum tensor; nullptr when unused */
};

/**
* @brief Data parameters for bfloat16 precision GEMM routine
* All except C are [in] parameters
*/
struct MLAS_SBGEMM_DATA_PARAMS {
    const void* A = nullptr;     /**< address of A */
    const void* B = nullptr;     /**< address of B */
    const float* Bias = nullptr; /**< address of Bias, vector size N */
    float* C = nullptr;          /**< address of result matrix */
    size_t lda = 0;              /**< leading dimension of A */
    size_t ldb = 0;              /**< leading dimension of B, 0 when B is pre-packed*/
    size_t ldc = 0;              /**< leading dimension of C*/
    const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; /**< optional per-tile post processor; nullptr to skip */
    bool AIsfp32 = false;        /**< matrix A is fp32, needs to be converted to bf16*/
    bool BIsfp32 = false;        /**< matrix B is fp32, needs to be converted to bf16*/
};

/**
* @brief Bfloat16 precision Batched GEMM: C = A * B + Bias
 * B can be either fp32 or bf16
*
 * Note: We only support uniform batching, so the shapes and types of the
 * input must be the same across all parameter blocks.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BatchN number of batches
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] ThreadPool
* @return
*/
void MLASCALL
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);

/**
* @brief For bfloat16 precision GEMM, returns size of the
* packing buffer needed for right hand side
* @param[in] N Number of columns
* @param[in] K Number of rows
* @return size of the packing buffer,
* 0 if operation not supported
*/
size_t MLASCALL
MlasSBGemmPackBSize(size_t N, size_t K);

/**
* @brief For bfloat16 precision GEMM, convert the float matrix B
 * to bfloat16 precision and pack it into a packing buffer
*
* @param[in] N Number of columns
* @param[in] K Number of rows
* @param[in] B Address of matrix B
* @param[in] ldb leading dimension of input matrix B
* @param[out] PackedB Address of the packed matrix
*/
void MLASCALL
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
#endif

/**
* @brief Indirect Depthwise convolution for fp16
* @param Input Supplies the indirect buffer for NHWC input
Expand All @@ -1619,7 +1755,7 @@ MlasHalfGemmConvertPackB(
* @param Channels # of input channels
* @param OutputCount # of output pixels
* @param KernelSize # kernel size
* @return
* @return
*/
void
MLASCALL
Expand Down Expand Up @@ -1657,7 +1793,7 @@ MlasTranspose(
* @param Channels C in NHWC
* @param OutputCount Number of output pixels
* @param KernelSize Size of the kernel
* @return
* @return
*/
void
MLASCALL
Expand All @@ -1676,7 +1812,7 @@ MlasNhwcMaxPool(
* @param Channels C in NHWC
* @param OutputCount Number of output pixels
* @param KernelSize size of the kernel
* @return
* @return
*/
void
MLASCALL
Expand Down
33 changes: 33 additions & 0 deletions third_party/mlas/inc/mlas_gemm_postprocessor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
mlas_gemm_postprocessor.h
Abstract:
This module contains a base class for custom postprocessing following a
GEMM.
--*/

#pragma once

/**
 * @brief Base class for custom postprocessing applied after a GEMM.
 *
 * @tparam T element type of the result matrix (e.g. float).
 *
 * Process() is invoked on a sub-range of the result matrix C so an
 * implementation can transform it in place.
 */
template<typename T>
class MLAS_GEMM_POSTPROCESSOR
{
public:
    virtual void Process(T* C,                /**< the address of matrix to process */
                         size_t RangeStartM,  /**< the start row index of matrix */
                         size_t RangeStartN,  /**< the start col index of matrix */
                         size_t RangeCountM,  /**< the element count per row to process */
                         size_t RangeCountN,  /**< the element count per col to process */
                         size_t ldc           /**< the leading dimension of matrix */
    ) const = 0;

    /** Virtual destructor so implementations can be destroyed via a base pointer. */
    virtual ~MLAS_GEMM_POSTPROCESSOR() {}
};
Loading

0 comments on commit b1b2e5e

Please sign in to comment.