diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cb80af5081..033f66af04 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,7 +62,7 @@ jobs: - name: Unit test debug version if: ${{ !cancelled() && !failure() }} - run: sudo docker exec ${BUILDER_CONTAINER} bash -c "mkdir -p /var/infinity && cd /infinity/ && cmake-build-debug/src/test_main > unittest_debug.log 2>&1" + run: sudo docker exec ${BUILDER_CONTAINER} bash -c "mkdir -p /var/infinity && cd /infinity/ && ASAN_OPTIONS=detect_leaks=0 cmake-build-debug/src/test_main > unittest_debug.log 2>&1" - name: Collect infinity unit test debug output if: ${{ !cancelled() }} diff --git a/CMakeLists.txt b/CMakeLists.txt index fcb0d7e4d8..4abbb4b60f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,16 +117,8 @@ elseif ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") set(CMAKE_C_FLAGS "-O0 -g") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-stack-protector -fno-var-tracking ") - add_compile_options(-fsanitize=address -fsanitize-recover=all) - add_link_options(-fsanitize=address -fsanitize-recover=all) - - option(LEAK "Memory leak detection" OFF) - message("Check memory leak: " "${LEAK}") - if(LEAK) - message("Check memory leak") - add_compile_options(-fsanitize=leak) - add_link_options(-fsanitize=leak) - endif() + add_compile_options(-fsanitize=address -fsanitize-recover=all -fsanitize=leak) + add_link_options(-fsanitize=address -fsanitize-recover=all -fsanitize=leak) add_compile_options("-fno-omit-frame-pointer") add_link_options("-fno-omit-frame-pointer") diff --git a/third_party/mlas/CMakeLists.txt b/third_party/mlas/CMakeLists.txt index 05b985288e..e56f791c5b 100644 --- a/third_party/mlas/CMakeLists.txt +++ b/third_party/mlas/CMakeLists.txt @@ -1,34 +1,239 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
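The preceding hunks drop the optional LEAK switch: Debug builds now always compile and link with both AddressSanitizer and LeakSanitizer (`-fsanitize=address -fsanitize=leak`), and the CI unit-test step silences leak reports at runtime with `ASAN_OPTIONS=detect_leaks=0`. For reference, the same runtime control is available from inside a test binary through compiler-rt's LSan interface; the sketch below is illustrative only (not part of this patch), and `noisy_fixture` is a hypothetical stand-in for code whose leaks are tolerated.

```cpp
// Sketch: runtime control of LeakSanitizer in a binary built with
// -fsanitize=address -fsanitize=leak, matching the Debug flags above.
// Running with ASAN_OPTIONS=detect_leaks=0 disables the same checks globally.
#include <sanitizer/lsan_interface.h>
#include <cstdlib>

void noisy_fixture() {        // hypothetical helper with known, tolerated leaks
    (void)std::malloc(64);    // intentionally leaked for illustration
}

int main() {
    __lsan_disable();         // allocations made while disabled are ignored by leak reports
    noisy_fixture();
    __lsan_enable();

    __lsan_do_leak_check();   // optional: report leaks now instead of at process exit
    return 0;
}
```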
-set(MLAS_SRC_DIR lib) +set(MLAS_ROOT ${CMAKE_SOURCE_DIR}/third_party/mlas) +set(MLAS_SRC_DIR ${MLAS_ROOT}/lib) +set(MLAS_INC_DIR ${MLAS_ROOT}/inc) + +function(onnxruntime_add_include_to_target dst_target) + foreach(src_target ${ARGN}) + if(TARGET ${src_target}) + target_include_directories(${dst_target} PRIVATE $) + target_compile_definitions(${dst_target} PRIVATE $) + target_sources(${dst_target} PRIVATE $) + endif() + endforeach() +endfunction() -set(MLAS_AMX_SUPPORTED FALSE) +function(onnxruntime_set_compile_flags target_name) + if (CPUINFO_SUPPORTED) + onnxruntime_add_include_to_target(${target_name} cpuinfo::cpuinfo) + endif() + if(onnxruntime_ENABLE_LAZY_TENSOR) + target_compile_definitions(${target_name} PRIVATE ENABLE_LAZY_TENSOR) + endif() + if (onnxruntime_ENABLE_CPU_FP16_OPS) + target_compile_definitions(${target_name} PRIVATE ENABLE_CPU_FP16_TRAINING_OPS) + endif() + if(onnxruntime_DISABLE_ABSEIL) + target_compile_definitions(${target_name} PRIVATE DISABLE_ABSEIL) + endif() + if(UNIX) + target_compile_definitions(${target_name} PRIVATE PLATFORM_POSIX) + endif() + target_compile_definitions(${target_name} PRIVATE EIGEN_USE_THREADS) + if (onnxruntime_DISABLE_CONTRIB_OPS) + target_compile_definitions(${target_name} PRIVATE DISABLE_CONTRIB_OPS) + endif() -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11) - # match assembler version, AMX instructions are supported from 2.38 - if (CMAKE_ASM_COMPILER_ID STREQUAL "GNU") - execute_process( - COMMAND as --version - OUTPUT_VARIABLE _as_version - ) - # 2.38 or later - if (_as_version MATCHES "GNU.[Aa]ssembler.*(2\\.38|2\\.39|2\\.[4-9][0-9]|[3-9]\\.[0-9][0-9])") - set(MLAS_AMX_SUPPORTED TRUE) + if (onnxruntime_DISABLE_ML_OPS) + target_compile_definitions(${target_name} PRIVATE DISABLE_ML_OPS) endif() - endif() -endif() -if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - set(MLAS_AMX_SUPPORTED TRUE) -endif() + if (onnxruntime_DISABLE_SPARSE_TENSORS) + target_compile_definitions(${target_name} PRIVATE DISABLE_SPARSE_TENSORS) + endif() + + if (onnxruntime_DISABLE_OPTIONAL_TYPE) + target_compile_definitions(${target_name} PRIVATE DISABLE_OPTIONAL_TYPE) + endif() + + if (onnxruntime_DISABLE_FLOAT8_TYPES) + target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT8_TYPES) + endif() + + if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) + endif() + + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) + endif() + + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) + if (onnxruntime_USE_CUDA) + # Suppress a "conversion_function_not_usable" warning in gsl/span + target_compile_options(${target_name} PRIVATE "$<$:SHELL:-Xcudafe \"--diag_suppress=conversion_function_not_usable\">") + target_compile_definitions(${target_name} PRIVATE -DDISABLE_CUSPARSE_DEPRECATED) + endif() + if (MSVC) + foreach(CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORY ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_compile_options(${target_name} PRIVATE "$<$:/external:I${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORY}>") + endforeach() + + foreach(onnxruntime_external_lib IN LISTS onnxruntime_EXTERNAL_LIBRARIES) + #TODO: the list contains cmake keywords like "debug". We should exclude them. 
+ if(TARGET ${onnxruntime_external_lib}) + get_target_property(onnxruntime_external_lib_include_dirs ${onnxruntime_external_lib} INTERFACE_INCLUDE_DIRECTORIES) + foreach(onnxruntime_external_lib_include_dir IN LISTS onnxruntime_external_lib_include_dirs) + if(onnxruntime_external_lib_include_dir MATCHES "^\\$") + if(onnxruntime_external_lib_include_dir MATCHES "^\\$]+)>$") + string(REGEX REPLACE "^\\$]+)>$" "\\1" onnxruntime_external_lib_include_dir_cmake "${onnxruntime_external_lib_include_dir}") + cmake_path(NATIVE_PATH onnxruntime_external_lib_include_dir_cmake NORMALIZE onnxruntime_external_lib_include_dir_native) + target_compile_options(${target_name} PRIVATE "$<$:/external:I${onnxruntime_external_lib_include_dir_native}>") + endif() + else() + cmake_path(NATIVE_PATH onnxruntime_external_lib_include_dir NORMALIZE onnxruntime_external_lib_include_dir_native) + target_compile_options(${target_name} PRIVATE "$<$:/external:I${onnxruntime_external_lib_include_dir_native}>") + endif() + endforeach() + endif() + endforeach() + target_compile_definitions(${target_name} PRIVATE -DPLATFORM_WINDOWS -DNOGDI -DNOMINMAX -D_USE_MATH_DEFINES -D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS) + if (onnxruntime_ENABLE_MEMLEAK_CHECKER) + target_compile_definitions(${target_name} PRIVATE -DONNXRUNTIME_ENABLE_MEMLEAK_CHECK) + endif() + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$:/utf-8>") + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /sdl>" "$<$:/sdl>") + set_target_properties(${target_name} + PROPERTIES VS_GLOBAL_CAExcludePath "${ORT_BINARY_DIR};${ORT_SOURCE_DIR}") + # We do not treat warnings from 3rd-party libraries as errors. In order to do that, we need to add their header files locations to /external:I. + target_compile_options(${target_name} PRIVATE "$<$:/experimental:external>" "$<$:SHELL:--compiler-options /experimental:external>") + target_compile_options(${target_name} PRIVATE "$<$:/external:W0>" "$<$:SHELL:--compiler-options /external:W0>") + target_compile_options(${target_name} PRIVATE "$<$:/external:templates->" "$<$:SHELL:--compiler-options /external:templates->") + target_compile_options(${target_name} PRIVATE "$<$:/external:I${CMAKE_CURRENT_SOURCE_DIR}>" "$<$:SHELL:--compiler-options /external:I${CMAKE_CURRENT_SOURCE_DIR}>") + target_compile_options(${target_name} PRIVATE "$<$:/external:I${CMAKE_CURRENT_BINARY_DIR}>" "$<$:SHELL:--compiler-options /external:I${CMAKE_CURRENT_BINARY_DIR}>") + if (onnxruntime_ENABLE_STATIC_ANALYSIS) + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /analyze>" "$<$:/analyze>") + if (onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE) + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /analyze:autolog:ext.sarif>" "$<$:/analyze:autolog:ext.sarif>") + endif() + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /analyze:external->" "$<$:/analyze:external->") + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /wd6385>" ) + # There are many such warnings from STL: + # include\list(148): warning C6011: Dereferencing NULL pointer '_Mycont'. 
: Lines: 146, 147, 148 + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options /wd6011>" ) + endif() + else() + # Enable warning + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options -Wall>" "$<$>:-Wall>") + target_compile_options(${target_name} PRIVATE "$<$>:-Wextra>") + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + #external/protobuf/src/google/protobuf/arena.h:445:18: error: unused parameter 'p' + target_compile_options(${target_name} PRIVATE "-Wno-unused-parameter") + endif() + target_compile_definitions(${target_name} PUBLIC -DNSYNC_ATOMIC_CPP11) + onnxruntime_add_include_to_target(${target_name} nsync::nsync_cpp) + endif() + foreach(ORT_FLAG ${ORT_PROVIDER_FLAGS}) + target_compile_definitions(${target_name} PRIVATE ${ORT_FLAG}) + endforeach() + if (HAS_DEPRECATED_COPY) + #too many such errors in eigen + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options -Wno-deprecated-copy>" "$<$:-Wno-deprecated-copy>") + endif() + foreach(FLAG ${ORT_WARNING_FLAGS}) + target_compile_options(${target_name} PRIVATE "$<$:${FLAG}>") + endforeach() + if (onnxruntime_USE_CUDA) + foreach(FLAG ${ORT_WARNING_FLAGS}) + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") + endforeach() + if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") + target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") + endif() + if (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda") + target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") + endif() + endif() + if (onnxruntime_USE_ROCM) + # flags are detected with CXX language mode, some flags are not supported with hipclang + # because we may mix gcc and hipclang + set(ORT_HIP_WARNING_FLAGS ${ORT_WARNING_FLAGS}) + list(REMOVE_ITEM ORT_HIP_WARNING_FLAGS -Wno-nonnull-compare) + + # float16.h:90:12: error: ‘tmp’ is used uninitialized + list(APPEND ORT_HIP_WARNING_FLAGS -Wno-uninitialized) + list(APPEND ORT_HIP_WARNING_FLAGS -Wno-deprecated-copy) + + # some #pragma unroll will fail, do not treat them as error + # #warning must not be treated as error + list(APPEND ORT_HIP_WARNING_FLAGS -Wno-error=pass-failed "-Wno-error=#warnings") + + # otherwise error: builtin __has_trivial_assign is deprecated; use __is_trivially_assignable instead + if (ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.4") + list(APPEND ORT_HIP_WARNING_FLAGS "-Wno-deprecated-builtins") + endif() + + foreach(FLAG ${ORT_HIP_WARNING_FLAGS}) + target_compile_options(${target_name} PRIVATE "$<$:SHELL:${FLAG}>") + endforeach() + endif() +endfunction() + +function(onnxruntime_set_source_file_properties target_name) + get_target_property(srcs ${target_name} SOURCES) + + # enable ARC for Objective-C/C++ + set(objective_c_cc_srcs ${srcs}) + list(FILTER objective_c_cc_srcs INCLUDE REGEX "\\.mm?$") + set_property(SOURCE ${objective_c_cc_srcs} APPEND PROPERTY COMPILE_OPTIONS "-fobjc-arc") +endfunction() + +function(onnxruntime_configure_target target_name) + target_link_directories(${target_name} PRIVATE ${onnxruntime_LINK_DIRS}) + onnxruntime_set_compile_flags(${target_name}) + onnxruntime_set_source_file_properties(${target_name}) + if(WIN32 AND onnxruntime_ENABLE_STATIC_ANALYSIS AND onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES) + set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props) + endif() + target_include_directories(${target_name} PRIVATE 
${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${abseil_cpp_SOURCE_DIR}) + if (onnxruntime_ENABLE_TRAINING_OPS) + target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT}) + endif() + if (onnxruntime_ENABLE_LTO) + set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) + set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELWITHDEBINFO TRUE) + set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL TRUE) + endif() + + if (onnxruntime_BUILD_KERNEL_EXPLORER) + get_target_property(target_type ${target_name} TYPE) + if (target_type STREQUAL "MODULE_LIBRARY" OR target_type STREQUAL "SHARED_LIBRARY") + set_property(TARGET ${target_name} + APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker --version-script=${ONNXRUNTIME_ROOT}/python/tools/kernel_explorer/version_script.lds ") + endif() + endif() + + # Keep BinSkim happy + if(MSVC AND NOT onnxruntime_target_platform MATCHES "ARM") + target_link_options(${target_name} PRIVATE "/CETCOMPAT") + endif() + +endfunction() + +function(onnxruntime_add_executable target_name) + add_executable(${target_name} ${ARGN}) + onnxruntime_configure_target(${target_name}) + if (MSVC AND onnxruntime_target_platform STREQUAL "x86") + target_link_options(${target_name} PRIVATE /SAFESEH) + endif() +endfunction() + + +function(onnxruntime_add_static_library target_name) + add_library(${target_name} STATIC ${ARGN}) + onnxruntime_configure_target(${target_name}) +endfunction() # # All hardware agnostic source files here # hardware specific files would cause trouble in # multi-target build # -add_library(onnxruntime_mlas + +onnxruntime_add_static_library(onnxruntime_mlas + ${MLAS_SRC_DIR}/mlasi.h ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp @@ -53,12 +258,24 @@ add_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.h + ${MLAS_SRC_DIR}/sqnbitgemm.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ) -if(MLAS_AMX_SUPPORTED) - target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_AMX_SUPPORTED) -else() - message(WARNING "AMX instructions NOT supported due to lack of compiler tool chain!") +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + +if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4_dq.cpp + ${MLAS_SRC_DIR}/q4gemm.cpp + ) endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) @@ -87,6 +304,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -172,6 +390,9 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm @@ -202,12 +423,18 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/sgemma.asm ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm 
${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm ) + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + endif() else() target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp @@ -218,7 +445,7 @@ function(setup_mlas_source_for_windows) endif() endfunction() -if (onnxruntime_BUILD_WEBASSEMBLY) +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/wasm_simd/*.cpp" @@ -291,14 +518,16 @@ else() set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) endif() endif() if(APPLE) get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGH GREATER 1) + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) endif() #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below @@ -343,20 +572,33 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -514,6 +756,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) 
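The kernels listed in this file are compiled with file-specific ISA flags (`-march=armv8.2-a+dotprod/+i8mm/+bf16` on AArch64, `-mavx2 -mfma` and the AVX-512 variants on x86_64), so each translation unit may emit instructions the baseline target lacks. That is safe because MLAS only reaches these kernels through a runtime dispatch on detected CPU features (platform.cpp). The sketch below illustrates the general pattern with placeholder kernel names, not MLAS's actual dispatcher.

```cpp
// Illustration of the dispatch pattern that makes per-file ISA flags safe:
// ISA-specific code is only called after a runtime capability check, so the
// baseline binary never executes unsupported instructions on older CPUs.
#include <cstdio>

static void KernelBaseline() { std::puts("baseline kernel"); }
static void KernelAvx2()     { std::puts("AVX2 kernel"); }  // would live in a TU built with -mavx2 -mfma

using KernelFn = void (*)();

static KernelFn SelectKernel() {
#if defined(__GNUC__) && defined(__x86_64__)
    if (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma")) {
        return KernelAvx2;
    }
#endif
    return KernelBaseline;
}

int main() {
    SelectKernel()();  // pick once at startup, then reuse the chosen kernel
    return 0;
}
```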
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") @@ -521,6 +764,7 @@ else() ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp @@ -532,9 +776,15 @@ else() ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") + set(mlas_platform_srcs_avx512vnni + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set(mlas_platform_srcs ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp @@ -546,16 +796,25 @@ else() ${mlas_platform_srcs_avx2} ${mlas_platform_srcs_avx512f} ${mlas_platform_srcs_avx512core} + ${mlas_platform_srcs_avx512vnni} ) - if(MLAS_AMX_SUPPORTED) + if (NOT onnxruntime_ORT_MINIMAL_BUILD) set(mlas_platform_srcs ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + if(NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mamx-tile -mamx-int8 -mavx2 -mavx512bw -mavx512dq -mavx512vl") + ) + set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -567,6 +826,26 @@ else() set(MLAS_SOURCE_IS_NOT_SET 0) endif() endif() + if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") @@ -574,10 +853,26 @@ else() target_sources(onnxruntime_mlas PRIVATE 
${mlas_platform_srcs}) endif() + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE inc ${MLAS_SRC_DIR}) + target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") endforeach() -set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") + +if (WIN32) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") + if (onnxruntime_ENABLE_STATIC_ANALYSIS) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") + endif() +endif() + +if (PLATFORM_NAME STREQUAL "macabi") + # Needed for maccatalyst C compilation + # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" + target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) +endif() if (NOT onnxruntime_BUILD_SHARED_LIB) install(TARGETS onnxruntime_mlas @@ -586,3 +881,61 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() + source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs}) +endblock() + +# +#if (NOT onnxruntime_ORT_MINIMAL_BUILD) +# +# # +# # Command line tool for quantization and de-quantization of 2-D fp32 tensors +# # based on block-wise quantization of int4 +# # +# +# onnxruntime_add_executable(onnxruntime_mlas_q4dq +# ${MLAS_SRC_DIR}/q4_dq_cli.cpp +# ) +# target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) +# set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") +# +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) +# if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) +# endif() +# if(NOT WIN32) +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) +# endif() +# if (CMAKE_SYSTEM_NAME STREQUAL "Android") +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) +# endif() +# +# if(WIN32) +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) +# endif() +# if (onnxruntime_LINK_LIBATOMIC) +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) +# endif() +# target_link_libraries(onnxruntime_mlas_q4dq PRIVATE Threads::Threads) +# +# if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") +# if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) +# set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") +# else() +# set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") +# endif() +# endif() +# +#endif() diff --git a/third_party/mlas/inc/mlas.h b/third_party/mlas/inc/mlas.h index fd6b3df934..cdfd283899 100644 --- a/third_party/mlas/inc/mlas.h +++ b/third_party/mlas/inc/mlas.h @@ -69,6 +69,9 @@ Module Name: #endif #endif +#if 
defined(__loongarch64) +#define MLAS_TARGET_LARCH64 +#endif // // Define the support levels for the target architecture. // @@ -87,7 +90,7 @@ Module Name: #define MLAS_F16VEC_INTRINSICS_SUPPORTED -#endif // +#endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic @@ -1219,6 +1222,26 @@ MlasQuantizeLinear( OutputType ZeroPoint ); +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias @@ -1611,6 +1634,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. 
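The new `MlasQuantizeLinearU4`/`MlasQuantizeLinearS4` entry points added earlier in this header follow the existing `MlasQuantizeLinear` pattern but emit 4-bit values into a `uint8_t` buffer. A minimal usage sketch follows; the assumption that two 4-bit values are packed per output byte (so `(N + 1) / 2` bytes suffice) is this sketch's reading of the `uint8_t*` output type, not something the prototype itself states.

```cpp
// Usage sketch for the 4-bit linear quantizers declared above.
// Assumption: the uint8_t output packs two 4-bit values per byte.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas.h"

void QuantizeToU4Example(const std::vector<float>& input) {
    const size_t N = input.size();
    std::vector<uint8_t> packed((N + 1) / 2);  // assumed packing: two nibbles per byte

    const float scale = 0.05f;    // q = saturate(round(x / scale) + zero_point)
    const int8_t zero_point = 8;  // midpoint of the unsigned 4-bit range [0, 15]

    MlasQuantizeLinearU4(input.data(), packed.data(), N, scale, zero_point);
    // MlasQuantizeLinearS4 has the same signature but targets the signed range [-8, 7].
}
```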
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input @@ -1619,7 +1755,7 @@ MlasHalfGemmConvertPackB( * @param Channels # of input channels * @param OutputCount # of output pixels * @param KernelSize # kernel size - * @return + * @return */ void MLASCALL @@ -1657,7 +1793,7 @@ MlasTranspose( * @param Channels C in NHWC * @param OutputCount Number of output pixels * @param KernelSize Size of the kernel - * @return + * @return */ void MLASCALL @@ -1676,7 +1812,7 @@ MlasNhwcMaxPool( * @param Channels C in NHWC * @param OutputCount Number of output pixels * @param KernelSize size of the kernel - * @return + * @return */ void MLASCALL diff --git a/third_party/mlas/inc/mlas_gemm_postprocessor.h b/third_party/mlas/inc/mlas_gemm_postprocessor.h new file mode 100644 index 0000000000..7ea29eb091 --- /dev/null +++ b/third_party/mlas/inc/mlas_gemm_postprocessor.h @@ -0,0 +1,33 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_gemm_postprocessor.h + +Abstract: + + This module contains a base class for custom postprocessing following a + GEMM. + +--*/ + +#pragma once + +template +class MLAS_GEMM_POSTPROCESSOR +{ + public: + virtual void Process(T* C, /**< the address of matrix to process */ + size_t RangeStartM, /**< the start row index of matrix */ + size_t RangeStartN, /**< the start col index of matrix */ + size_t RangeCountM, /**< the element count per row to process */ + size_t RangeCountN, /**< the element count per col to process */ + size_t ldc /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_GEMM_POSTPROCESSOR() {} +}; diff --git a/third_party/mlas/inc/mlas_q4.h b/third_party/mlas/inc/mlas_q4.h new file mode 100644 index 0000000000..898fb23cf3 --- /dev/null +++ b/third_party/mlas/inc/mlas_q4.h @@ -0,0 +1,427 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_q4.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for blocked int4 quantization and dequantization. + + Int4 block quantization is used to compress weight tensors of large + language models. 
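The bf16 SBGEMM API introduced above in mlas.h (`MlasBf16AccelerationSupported`, `MlasSBGemmPackBSize`, `MlasSBGemmConvertPackB`, `MlasSBGemmBatch`) is only declared for AArch64 Linux. The single-batch sketch below follows those declarations; the field choices (`ldb = 0` for a pre-packed B, `BIsfp32 = false` once B has been converted) are this sketch's reading of the struct comments.

```cpp
// Single-batch usage sketch for the bf16 SBGEMM API (aarch64 Linux only):
// pack the fp32 right-hand side to bf16 once, then run C = A * B.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas.h"

#if defined(__aarch64__) && defined(__linux__)
void SBGemmExample(const float* A, const float* B, float* C,
                   size_t M, size_t N, size_t K) {
    if (!MlasBf16AccelerationSupported()) {
        return;  // fall back to the regular fp32 SGEMM path in real code
    }

    const size_t packedSize = MlasSBGemmPackBSize(N, K);
    if (packedSize == 0) {
        return;  // packing not supported for this shape
    }
    std::vector<uint8_t> packedB(packedSize);
    MlasSBGemmConvertPackB(N, K, B, /*ldb=*/N, packedB.data());

    MLAS_SBGEMM_DATA_PARAMS params;
    params.A = A;
    params.AIsfp32 = true;   // A is fp32 and converted to bf16 inside the kernel
    params.lda = K;
    params.B = packedB.data();
    params.BIsfp32 = false;  // already converted by MlasSBGemmConvertPackB
    params.ldb = 0;          // 0 marks a pre-packed B, per the struct comment
    params.C = C;
    params.ldc = N;

    MlasSBGemmBatch(M, N, K, /*BatchN=*/1, &params);
}
#endif
```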
+ +--*/ + +#pragma once + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +#include +#include + +/** + * @brief Define types of block quantization + */ +typedef enum { + BlkQ4Sym = 0, /*!< int4 Symmetric Block Quantization, zero_point = 0 */ + BlkQ4Zp8 = 1, /*!< int4 Block Quantization, zero_point is int8 type */ + BlkQ4Sym64 = 2, /*!< int4 Symmetric Block Quantization, 64 values per block*/ + BlkQ4Sym128 = 4 /*!< int4 Symmetric Block Quantization, 128 values per block*/ +} MLAS_BLK_QUANT_TYPE; + +/** + * @brief Computes the number of bytes required to pack and int4-quantize + * a weight matrix + * @param QType type of block quantization + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @return size of the packing buffer, 0 if the operation is not yet supported. +*/ +size_t +MLASCALL +MlasQ4GemmPackBSize( + MLAS_BLK_QUANT_TYPE QType, + size_t N, + size_t K + ); + +/** + * @brief Prepack and Quantize fp32 weight tensor to int4 blocks + * + * @param QType type of block quantization + * @param PackedBuf destination buffer + * @param FpData the pointer to fp32 matrix + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B +*/ +void +MLASCALL +MlasQ4GemmPackB( + MLAS_BLK_QUANT_TYPE QType, + void* PackedBuf, + const float* FpData, + size_t N, + size_t K, + size_t ldb + ); + + +/** + * @brief Unpack and dequantize from int4 to fp32, reverse operation of + * MlasQ4GemmPackB + * @param QType type of block quantization + * @param FpData destination buffer, the fp32 matrix + * @param PackedBuf int4 quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + */ +void +MLASCALL +MlasQ4GemmUnPackB( + MLAS_BLK_QUANT_TYPE QType, + float* FpData, + const void* PackedBuf, + size_t N, + size_t K, + size_t ldb + ); + + +/** + * @brief Data parameters for Q4 GEMM routine + * C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed int4 blob + * All except C are [in] parameters + */ +struct MLAS_Q4_GEMM_DATA_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed int4 blob + * + * @param[in] QType type of block quantization used in B + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasQ4GemmBatch( + MLAS_BLK_QUANT_TYPE QType, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr + ); + + +/** + * @brief Calculate the buffer size needed for int8 block quantize + * @param[in] QType Type of block quantization used + * @param[in] M Number of rows of the input matrix + * @param[in] K 
Number of columns of the input matrix + * @return buffer size (in bytes) needed, 0 if not yet supported on current hardware +*/ +size_t +MLASCALL +MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K); + +/** + * @brief Given an input float 2-D matrix, perform blocked int8 quantize + * + * @param QType Type of block quantization used + * @param Qblob Pointer to the output buffer + * @param A Pointer to the float matrix + * @param M Number of rows of the input matrix + * @param K Number of columns of the input matrix + * @param lda leading dimension of the input matrix + * @param ThreadPool +*/ +void +MLASCALL +MlasQ80BlkQuant( + MLAS_BLK_QUANT_TYPE QType, + void* Qblob, + const float* A, + size_t M, + size_t K, + size_t lda, + MLAS_THREADPOOL* ThreadPool + ); + + +/** + * @brief Data parameters for Q8Q4 GEMM routine + * C = A * B + Bias + * A must be a block quantized int8 matrix + * B must be a block quantized and packed int4 blob + * All except C are [in] parameters + */ +struct MLAS_Q8Q4_GEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A (quantized int8 blob)*/ + const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a quantized int8 blob + * B must be a quantized and packed int4 blob + * + * @param[in] QType type of block quantization used in B + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasQ8Q4GemmBatch( + MLAS_BLK_QUANT_TYPE QType, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool + ); + + +//////////////////////////////////////////////////////////// +// Blockwise quantization and dequantization where quantization +// parameters are packed into separate buffers. +// + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantization parameter matrix [meta_rows, meta_cols] +*/ +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantized matrix [q_rows, q_cols]. The quantized matrix + * is in column major layout, with bits packed on the column. + * + * @tparam T + * @tparam qbits + * @param block_size + * @param columnwise + * @param rows + * @param columns + * @param q_rows + * @param q_cols +*/ +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +/** + * @brief Compute the sizes of the quantized data and quantization parameter buffers. + * + * @param qbits The bit width of each quantized value. + * @param block_size The number of quantized values in a block. + * @param columnwise Whether a block contains values from a matrix column (true) or row (false). + * @param rows Number of matrix rows. 
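Putting the fp32-activation / int4-weight pieces declared earlier in this header together, a typical flow is: size and fill the packed B blob once, then call the batched GEMM per inference. A minimal single-batch sketch based on those declarations:

```cpp
// Usage sketch: quantize-and-pack B to symmetric int4 blocks once,
// then compute C = A * B with fp32 activations.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas_q4.h"

void Q4GemmExample(const float* A, const float* Bfp32, float* C,
                   size_t M, size_t N, size_t K) {
    const MLAS_BLK_QUANT_TYPE qtype = BlkQ4Sym;  // symmetric blocks, zero_point = 0

    const size_t packedSize = MlasQ4GemmPackBSize(qtype, N, K);
    if (packedSize == 0) {
        return;  // int4 GEMM not supported here; use the fp32 path instead
    }

    std::vector<uint8_t> packedB(packedSize);
    MlasQ4GemmPackB(qtype, packedB.data(), Bfp32, N, K, /*ldb=*/N);

    MLAS_Q4_GEMM_DATA_PARAMS params;
    params.A = A;
    params.lda = K;
    params.B = packedB.data();
    params.C = C;
    params.ldc = N;

    MlasQ4GemmBatch(qtype, M, N, K, /*BatchN=*/1, &params);
}
```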
+ * @param columns Number of matrix columns. + * @param[out] q_data_size_in_bytes The size in bytes of the quantized data. + * @param[out] q_scale_num_elements The size in elements of the scale quantization parameters. + * @param[out] q_zero_point_size_in_bytes The size in bytes of the zero point quantization parameters. Optional. + * + * If the qbits or block_size values are unsupported the output sizes will be zero. + */ +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + + +/** + * @brief Blockwise 4 bits quantization, resulting elements and quantization + * parameters (scales, zero points) are packed into separate matrices + * all in column major layout for faster access during subsequent matrix + * multiplication. + * + * @tparam ElementT type of the input matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * + * @param dst points to the quantized matrix, shape [rows, columns] column major + * @param scales points to the scales matrix, column major + * @param zero_points points to the zero_points matrix, column major + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool +*/ +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +/** + * @brief Blockwise 4 bits dequantization, quantized elements and quantization + * parameters (scales, zero points) are from separate matrices packed + * in column major layout. Output is a floating point matrix in column + * major layout for faster access during subsequent matrix multiplication. + * + * @tparam ElementT type of the dequantized matrix element, usually floating point + * @tparam qbits number of bits used for quantization, 4 for int4 + * + * @param dst points to dequantized matrix shape [rows, columns] column major + * @param src points to quantized matrix, column major + * @param scales points to quantization scales, column major + * @param zero_points points to quantization zero points, column major + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param thread_pool +*/ +template +void +MlasDequantizeBlockwise( + ElementT* dst, + const uint8_t* src, + const ElementT* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool + ); + +/** + * @brief Blockwise 2 bits or 4 bits quantization. After quantization, the weights and zero points + * are packed row-wise. In terms of the qbits type, dst and src have the same shape, and + * scales and zero_points have the same shape. + * columns must be multiple of 8 / qbits. 
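The blockwise quantize/dequantize templates above are used together with `MlasBlockwiseQuantizedBufferSizes` to size the packed data, scale, and zero-point buffers. The round-trip sketch below assumes the template parameters are `<ElementT, qbits>` as the `@tparam` comments indicate (the explicit `<float, 4>` instantiation is this sketch's reconstruction):

```cpp
// Round-trip sketch: blockwise int4 quantize then dequantize, with buffers
// sized by MlasBlockwiseQuantizedBufferSizes.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas_q4.h"

void BlockwiseRoundTrip(const std::vector<float>& src, int rows, int columns) {
    const int block_size = 32;
    const bool columnwise = true;  // blocks run down a column, typical for weights

    size_t data_bytes = 0, scale_elems = 0, zp_bytes = 0;
    MlasBlockwiseQuantizedBufferSizes(/*qbits=*/4, block_size, columnwise,
                                      rows, columns,
                                      data_bytes, scale_elems, &zp_bytes);
    if (data_bytes == 0) {
        return;  // unsupported qbits/block_size combination
    }

    std::vector<uint8_t> qdata(data_bytes);
    std::vector<float> scales(scale_elems);
    std::vector<uint8_t> zero_points(zp_bytes);

    MlasQuantizeBlockwise<float, 4>(qdata.data(), scales.data(), zero_points.data(),
                                    src.data(), block_size, columnwise,
                                    rows, columns, /*leading_dimension=*/columns,
                                    /*thread_pool=*/nullptr);

    std::vector<float> restored(static_cast<size_t>(rows) * columns);
    MlasDequantizeBlockwise<float, 4>(restored.data(), qdata.data(), scales.data(),
                                      zero_points.data(), block_size, columnwise,
                                      rows, columns, /*thread_pool=*/nullptr);
}
```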
+ * @tparam Tin + * @tparam qbits number of bits used for quantization, 2 or 4 + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param scales points to the scales matrix, row major + * @param zero_points points to the zero_points matrix, row major + * @param dst points to the quantized matrix, shape [rows, columns] row major in qbits type. + * In uint8_t type, shape is [rows, columns * qbits / 8]. + * @param columnwise true when quantize elements in a column, false when quantize elements in a row. + * @param rows + * @param columns + * @param quant_block_size number of elements in a quantize block + * @param thread_pool + */ +template +void +MlasQDQQuantizeBlockwise( + const Tin* src, + Tin* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +/** + * @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero + * points are packed row-wise. The dst tensors are column major. dst weights and zero points + * are packed column-wise. + * @tparam Tin + * @tparam qbits number of bits used for quantization, 2 or 4 + * @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type. + * In uint8_t type, shape is [rows, columns * qbits / 8]. + * @param src_scales points to the scales matrix, row major + * @param src_zero_points points to the zero_points matrix, row major. Packed row-wise. + * @param dst_weights points to the quantized matrix, column major. Packed column-wise. + * @param dst_scales points to the scales matrix, column major + * @param dst_zero_points points to the zero_points matrix, column major. Packed column-wise. + * @param columnwise true when quantize elements in a column, false when quantize elements in a row. + * @param rows + * @param columns + * @param quant_block_size number of elements in a quantize block + * @param thread_pool + */ +template +void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const Tin* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + Tin* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); diff --git a/third_party/mlas/inc/mlas_qnbit.h b/third_party/mlas/inc/mlas_qnbit.h new file mode 100644 index 0000000000..32e9cc9810 --- /dev/null +++ b/third_party/mlas/inc/mlas_qnbit.h @@ -0,0 +1,181 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_qnbit.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for blocked n-bit quantized GEMM. + + N-bit block quantization is used to compress weight tensors of large + language models. + +--*/ + +#pragma once + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +/** + * @brief Define compute types of block quantization, in order of decreasing accuracy. 
+ */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32, /*!< input fp32, accumulator fp32 */ + CompFp16, /*!< input fp16, accumulator fp16 */ + CompBf16, /*!< input bf16, accumulator fp32 */ + CompInt8, /*!< input int8, accumulator int32 */ + + // special values that should be the first and last actual values + + CompMostAccurate = CompUndef, + CompLeastAccurate = CompInt8, +} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + +/** + * @brief Data parameters for float/n-bit quantized int GEMM routine. + */ +struct MLAS_SQNBIT_GEMM_DATA_PARAMS { + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C + + ///< optional post processing to apply to result matrix + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed n-bit int matrix + * + * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * + * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasSQNBitGemmPackQuantBData(). + * + * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * point to an intermediate workspace buffer. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] Workspace Address of intermediate workspace buffer. + If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + buffer with at least that many bytes. Otherwise, it may be nullptr. + * @param[in] ThreadPool optional thread pool to use + */ +void MLASCALL +MlasSQNBitGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool = nullptr +); + +/** + * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. 
+ * + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM + * implementation. If zero, no intermediate workspace is required. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ +size_t MLASCALL +MlasSQNBitGemmBatchWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Gets the size in bytes of the packed quantized B data. + * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). + * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasSQNBitGemmBatch(). + * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +/** + * @brief Packs the quantized B data in a format that the kernel expects. + * + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[out] PackedQuantBData packed quantized B data + * @param[in] ThreadPool optional thread pool to use + */ +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool = nullptr +); diff --git a/third_party/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/third_party/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S new file mode 100644 index 0000000000..e18846c890 --- /dev/null +++ b/third_party/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmS8S8KernelSmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). 
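The doc comments in mlas_qnbit.h above prescribe a specific call sequence for the float32 / n-bit quantized GEMM: check availability, repack B if the kernel wants a packed layout, allocate the workspace if one is required, then run the batch. A single-batch sketch of that sequence (4-bit B, 32-value blocks, fp32 compute):

```cpp
// End-to-end sketch of the MlasSQNBitGemm* call sequence documented above.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas_qnbit.h"

void SQNBitGemmExample(const float* A, const void* quantB, const float* quantBScale,
                       const void* quantBZeroPoint, float* C,
                       size_t M, size_t N, size_t K) {
    constexpr size_t BlkBitWidth = 4;
    constexpr size_t BlkLen = 32;
    const MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompFp32;

    if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
        return;  // no kernel for this configuration on the current CPU
    }

    // Repack B only if the selected kernel asks for it.
    std::vector<uint8_t> packedB;
    const void* bForGemm = quantB;
    const size_t packedSize =
        MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
    if (packedSize != 0) {
        packedB.resize(packedSize);
        MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                     quantB, packedB.data());
        bForGemm = packedB.data();
    }

    // Allocate the intermediate workspace if the kernel needs one.
    std::vector<uint8_t> workspace(
        MlasSQNBitGemmBatchWorkspaceSize(M, N, K, /*BatchN=*/1,
                                         BlkBitWidth, BlkLen, ComputeType));

    MLAS_SQNBIT_GEMM_DATA_PARAMS params;
    params.A = A;
    params.lda = K;
    params.QuantBData = bForGemm;
    params.QuantBScale = quantBScale;
    params.QuantBZeroPoint = quantBZeroPoint;  // may be nullptr for symmetric quantization
    params.C = C;
    params.ldc = N;

    MlasSQNBitGemmBatch(M, N, K, /*BatchN=*/1, BlkBitWidth, BlkLen, ComputeType,
                        &params, workspace.empty() ? nullptr : workspace.data());
}
```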
+ +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. 
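The accumulator seeding described above (ZeroPointB * RowSum + ColumnSum) works because the zero-point correction of an int8 GEMM depends only on per-row sums of A and per-column sums of B, which can be precomputed during packing. The scalar sketch below shows that algebra with its own sign convention; MLAS folds the signs into the packed row/column sums, so the exact terms it loads differ in form but not in substance.

```cpp
// Scalar reference for the zero-point correction folded into the accumulator
// initialization: expand (a - za) * (b - zb) and the correction terms reduce
// to a row sum of A, a column sum of B, and a constant.
#include <cstddef>
#include <cstdint>

int32_t DequantizedDot(const int8_t* a_row, const int8_t* b_col, size_t K,
                       int32_t za, int32_t zb) {
    int32_t dot = 0, row_sum = 0, col_sum = 0;
    for (size_t k = 0; k < K; ++k) {
        dot += static_cast<int32_t>(a_row[k]) * b_col[k];
        row_sum += a_row[k];
        col_sum += b_col[k];
    }
    // sum_k (a - za) * (b - zb)
    //   = dot - zb * row_sum - za * col_sum + K * za * zb
    return dot - zb * row_sum - za * col_sum + static_cast<int32_t>(K) * za * zb;
}
```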
+// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. 
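The MultiplyAccumulateRow macro above is built on the AArch64 `smmla` instruction, which treats its two 16-byte sources as a 2x8 and an 8x2 signed int8 matrix and accumulates their 2x2 int32 product into the destination. The scalar model below mirrors those semantics as a reference for reading the kernel.

```cpp
// Scalar model of one `smmla vD.4s, vN.16b, vM.16b` step: accumulator lane
// (2*i + j) gains the dot product of 8-byte row i of vn with 8-byte row j of
// vm, i.e. a 2x8 * 8x2 int8 matrix multiply accumulated into a 2x2 int32 tile.
#include <cstdint>

void SmmlaReference(int32_t acc[4], const int8_t vn[16], const int8_t vm[16]) {
    for (int i = 0; i < 2; ++i) {      // row of the 2x8 matrix held in vn
        for (int j = 0; j < 2; ++j) {  // column of the 8x2 matrix held in vm
            int32_t sum = acc[2 * i + j];
            for (int k = 0; k < 8; ++k) {
                sum += static_cast<int32_t>(vn[8 * i + k]) *
                       static_cast<int32_t>(vm[8 * j + k]);
            }
            acc[2 * i + j] = sum;
        }
    }
}
```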
+// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, 
AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, 
v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. 
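+// Columns are processed in progressively narrower blocks (ComputeBlockLoop
+// widths 8, 6, 4 and 2); the OutputNElementsOnly paths store a trailing
+// partial-width block when the remaining column count does not fill a block.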
+// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. 
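+    ColumnSumBuffer and ZeroPointB are passed on the stack; the kernel loads
+    them into x10 and x11, respectively, at function entry.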
+ +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmS8S8KernelSmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmS8S8KernelSmmlaFunction Zero + QgemmS8S8KernelSmmlaFunction Add + + .end diff --git a/third_party/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S b/third_party/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S new file mode 100644 index 0000000000..baf6e21e6f --- /dev/null +++ b/third_party/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the ummla kernel. 
d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + 
uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + ummla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + ummla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + ummla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + ummla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. 
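+// Each pair of accumulators holds two adjacent 2x2 int32 tiles; the uzp1/uzp2
+// operations below gather their row-0 halves and row-1 halves so each output
+// row of C can be stored contiguously (and, in Add mode, summed with the
+// existing C values first).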
+// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str 
q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + 
mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. 
The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmU8X8KernelUmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmU8X8KernelUmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. 
+// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmU8X8KernelUmmlaFunction Zero + QgemmU8X8KernelUmmlaFunction Add + + .end diff --git a/third_party/mlas/lib/aarch64/SbgemmKernelNeon.S b/third_party/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000..e424c30515 --- /dev/null +++ b/third_party/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. 
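+// When a Bias pointer is supplied, the bias values are broadcast and zipped
+// into the interleaved 2x2 tile layout used by bfmmla; otherwise the
+// accumulators are cleared to zero.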
+// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. 
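+// Each iteration converts four float values from each active row of A to
+// bfloat16 with bfcvtn/bfcvtn2 (packing a row pair per register), loads the
+// packed bfloat16 B panel, and issues bfmmla instructions, each of which
+// multiplies a 2x4 bfloat16 tile by a 4x2 bfloat16 tile and accumulates the
+// product into a 2x2 float tile.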
+// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. 
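+// As in the integer kernels, the uzp1/uzp2 operations below split each pair
+// of 2x2 float tiles into row-0 and row-1 halves; in Add mode the halves are
+// summed into the existing C values with fadd before being stored.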
+// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s 
+ + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + 
mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. 
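+    The kernel processes up to eight rows of C per invocation, and each
+    bfmmla instruction accumulates a 2x2 single-precision tile from a
+    2x4-by-4x2 bfloat16 product.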
+ +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. 
+// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/third_party/mlas/lib/activate.cpp b/third_party/mlas/lib/activate.cpp index 6c4ab8ae11..df3b884a7e 100644 --- a/third_party/mlas/lib/activate.cpp +++ b/third_party/mlas/lib/activate.cpp @@ -143,6 +143,8 @@ struct MLAS_ACTIVATION_FUNCTION return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); #elif defined(MLAS_VSX_INTRINSICS) return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value)); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value)); #else return MlasBlendFloat32x4(ValueTimesAlpha, Value, ZeroFloat32x4 < Value); #endif diff --git a/third_party/mlas/lib/amd64/SoftmaxKernelAvx512F.asm b/third_party/mlas/lib/amd64/SoftmaxKernelAvx512F.asm new file mode 100644 index 0000000000..3e83bc852f --- /dev/null +++ b/third_party/mlas/lib/amd64/SoftmaxKernelAvx512F.asm @@ -0,0 +1,103 @@ +;++ +; +;Copyright (c) Microsoft Corporation. All rights reserved. +; +;Licensed under the MIT License. +; +;Module Name: +; +; SoftmaxKernelAvx512F.asm +; +;Abstract: +; +; This module implements the kernels for the single precision softmax +; operation. +; +; This implementation uses AVX512F instructions. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + EXTERN MlasMinimumF32Value:NEAR + +;++ +; +;Routine Description: +; +; This routine implements a vectorized kernel to find the maximum value of +; the supplied buffer. +; +;Arguments: +; +; Input (rcx) - Supplies the input buffer. +; +; N (rdx) - Supplies the number of elements to process. +; +;Return Value: +; +; Returns the maximum value of the supplied buffer. +; +;-- + + LEAF_ENTRY MlasReduceMaximumF32KernelAvx512F, _TEXT + + vbroadcastss zmm0,DWORD PTR [MlasMinimumF32Value] + test rdx,rdx + jz ExitKernel + cmp rdx,16 + jb ProcessRemainingCountBy1 + cmp rdx,64 + jb ProcessRemainingCountBy16 + vmovaps zmm1,zmm0 + vmovaps zmm2,zmm0 + vmovaps zmm3,zmm0 + +ProcessRemainingCountBy64: + vmaxps zmm0,zmm0,ZMMWORD PTR [rcx] + vmaxps zmm1,zmm1,ZMMWORD PTR [rcx+16*4] + sub rdx,64 + vmaxps zmm2,zmm2,ZMMWORD PTR [rcx+32*4] + vmaxps zmm3,zmm3,ZMMWORD PTR [rcx+48*4] + add rcx,64*4 ; advance input by 64 elements + cmp rdx,64 + jae ProcessRemainingCountBy64 + vmaxps zmm0,zmm0,zmm1 ; reduce to single vector + vmaxps zmm2,zmm2,zmm3 + vmaxps zmm0,zmm0,zmm2 + +ProcessRemainingCountBy16: + cmp rdx,16 + jb ProcessRemainingCountLessThan16 + vmaxps zmm0,zmm0,ZMMWORD PTR [rcx] + sub rdx,16 + add rcx,16*4 ; advance input by 16 elements + jmp ProcessRemainingCountBy16 + +ProcessRemainingCountLessThan16: + vextractf32x8 ymm1,zmm0,1 ; reduce to single scalar + vmaxps ymm0,ymm0,ymm1 + vextractf128 xmm1,ymm0,1 + vmaxps xmm0,xmm0,xmm1 + vshufps xmm1,xmm0,xmm0,0EEh + vmaxps xmm0,xmm0,xmm1 + vshufps xmm1,xmm0,xmm0,055h + vmaxss xmm0,xmm0,xmm1 + test rdx,rdx + jz ExitKernel + +ProcessRemainingCountBy1: + vmaxss xmm0,xmm0,DWORD PTR [rcx] + add rcx,4 ; advance input by 1 element + dec edx + jnz ProcessRemainingCountBy1 + +ExitKernel: + vzeroupper + ret + + LEAF_END MlasReduceMaximumF32KernelAvx512F, _TEXT + + END diff --git a/third_party/mlas/lib/amx_common.h b/third_party/mlas/lib/amx_common.h new file mode 100644 index 0000000000..caf94af023 --- /dev/null +++ b/third_party/mlas/lib/amx_common.h @@ -0,0 +1,80 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. 
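The AVX512F reduce-maximum kernel above unrolls by 64 elements across four accumulators, drops to 16-element steps, and finishes with a horizontal reduction plus a scalar tail. A scalar reference that produces the same result, assuming MlasMinimumF32Value is the lowest finite float:

```cpp
// Scalar reference (sketch) for MlasReduceMaximumF32KernelAvx512F.
#include <algorithm>
#include <cstddef>
#include <limits>

float ReduceMaximumF32Reference(const float* input, std::size_t n) {
    float maximum = std::numeric_limits<float>::lowest();  // assumed MlasMinimumF32Value
    for (std::size_t i = 0; i < n; ++i) {
        maximum = std::max(maximum, input[i]);
    }
    return maximum;                                        // also returned when n == 0
}
```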
+ +Licensed under the MIT License. + +Module Name: + + amx_common.h + +Abstract: + + Intrinsic and inline functions for amx processing. + +--*/ + +#pragma once + +#include "mlasi.h" + +#ifdef _WIN32 +#define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) + +#define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) + +#define tile_dpbusd(dst, src1, src2) _tile_dpbusd(dst, src1, src2) + +#define tile_dpbuud(dst, src1, src2) _tile_dpbuud(dst, src1, src2) + +#define tile_loadd(dst, base, stride) _tile_loadd(dst, base, stride) + +#define tile_stream_loadd(dst, base, stride) _tile_stream_loadd(dst, base, stride) + +#define tile_stored(dst, base, stride) _tile_stored(dst, base, stride) + +#define tile_loadconfig(config) \ + _tile_loadconfig(config) + +#define tile_storeconfig(config) _tile_storeconfig(config) + +#else + +#define tile_dpbusd_internal(dst,src1,src2) \ +__asm__ volatile (".set Payload1, 0x01\n\t" \ + ".set Payload1, Payload1 + (("#src2" & 15) ^ 15) << 3\n\t" \ + ".set ModRMByte, 0xC0\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".set ModRMByte, ModRMByte + ("#src1")\n\t" \ + ".byte 0xC4, 0xE2, Payload1, 0x5E, ModRMByte\n\t") + +#define tile_dpbusd(dst,src1,src2) \ +tile_dpbusd_internal(dst,src1,src2) + +#define tile_loadd_internal1(dst,base,stride) \ + __asm__ volatile (".set ModRMByte, 0x04\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".byte 0xC4, 0xE2, 0x7B, 0x4B, ModRMByte, 0x18\n\t" \ + :: "a" ((const void*) (base)), "b" ((long) (stride))) + +#define tile_loadd(dst,base,stride) \ + tile_loadd_internal1(dst, base, stride) + + +#define tile_stored_internal1(dst,base,stride) \ + __asm__ volatile (".set ModRMByte, 0x04\n\t" \ + ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ + ".byte 0xC4, 0xE2, 0x7A, 0x4B, ModRMByte, 0x18\n\t" \ + :: "a" ((const void*) (base)), "b" ((long) (stride))) + +#define tile_stored(dst,base,stride) \ +tile_stored_internal1(dst, base, stride) + + +#define tile_loadconfig(config) \ +__asm__ volatile (".byte 0xC4, 0xE2, 0x78, 0x49, 0x00" :: "a" (((const void *)config))) \ + +#define tile_storeconfig(config) \ +__asm__ volatile (".byte 0xC4, 0xE2, 0x79, 0x49, 0x00" :: "a" (((const void *)config))) \ + +#endif diff --git a/third_party/mlas/lib/compute.cpp b/third_party/mlas/lib/compute.cpp index 1183510551..f4c1e3da69 100644 --- a/third_party/mlas/lib/compute.cpp +++ b/third_party/mlas/lib/compute.cpp @@ -148,6 +148,9 @@ Return Value: // instead. normal = _mm_min_epi16(normal, MaximumExponent); normal = _mm_max_epi16(normal, MinimumExponent); +#elif defined(MLAS_LSX_INTRINSICS) + normal = __lsx_vmin_h(normal, MaximumExponent); + normal = __lsx_vmax_h(normal, MinimumExponent); #else normal = MlasMinimumInt32x4(normal, MaximumExponent); normal = MlasMaximumInt32x4(normal, MinimumExponent); @@ -215,6 +218,8 @@ Return Value: // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle // and use zeroes for the upper elements. Vector = _mm_load_ss(Input); +#elif defined(MLAS_LSX_INTRINSICS) + Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0); #else Vector = MlasBroadcastFloat32x4(Input); #endif @@ -467,6 +472,8 @@ Return Value: // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and // use zeroes for the upper elements. 
MLAS_FLOAT32X4 Vector = _mm_load_ss(Input); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0); #else MLAS_FLOAT32X4 Vector = MlasBroadcastFloat32x4(Input); #endif @@ -843,13 +850,29 @@ Return Value: const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; +#if defined(MLAS_SSE2_INTRINSICS) + // TODO: Use std::hardware_constructive_interference_size + constexpr size_t CacheLineSize = 64; + constexpr size_t ElementsPerCacheLine = CacheLineSize / sizeof(float); +#endif + while (CountN > 0) { +#if defined(MLAS_SSE2_INTRINSICS) + // + // Prefetch the next row of the input buffer. + // + + for (size_t i = 0; i * ElementsPerCacheLine < D; i++) { + _mm_prefetch((char*)(Input + D) + i * CacheLineSize, _MM_HINT_T0); + } +#endif + // // Find the maximum value for the row. // -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); @@ -874,7 +897,7 @@ Return Value: float Parameters[] = { NegativeMaximum, std::log(Accumulation)}; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); #else MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); @@ -899,7 +922,7 @@ Return Value: float Parameters[] = { 1.0f / Accumulation }; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); #else MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters); diff --git a/third_party/mlas/lib/dgemm.cpp b/third_party/mlas/lib/dgemm.cpp index 1ef63d03c8..50c62744f1 100644 --- a/third_party/mlas/lib/dgemm.cpp +++ b/third_party/mlas/lib/dgemm.cpp @@ -530,7 +530,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined (MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/third_party/mlas/lib/loongarch64/DgemmKernelCommon.h b/third_party/mlas/lib/loongarch64/DgemmKernelCommon.h new file mode 100644 index 0000000000..8d812baabd --- /dev/null +++ b/third_party/mlas/lib/loongarch64/DgemmKernelCommon.h @@ -0,0 +1,27 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the double + precision matrix/matrix multiply operation (DGEMM). + +--*/ + +#define LFgemmElementShift 3 +#define LFgemmElementSize (1 << LFgemmElementShift) +#define LFgemmYmmElementCount (32/LFgemmElementSize) + +#include "FgemmKernelCommon.h" + +FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.d) +FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.d) +FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.d) +FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.d) diff --git a/third_party/mlas/lib/loongarch64/DgemmKernelLasx.S b/third_party/mlas/lib/loongarch64/DgemmKernelLasx.S new file mode 100644 index 0000000000..2f197d6891 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/DgemmKernelLasx.S @@ -0,0 +1,32 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. 
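The DGEMM constants above fix the element width for the shared Fgemm macros. The sketch below spells out what they expand to; the single-precision counterpart (a shift of 2, eight floats per LASX vector) is an assumption for comparison only, since SgemmKernelCommon.h is not part of this hunk.

```cpp
// Expansion of the LFgemm constants for the DGEMM specialization.
#include <cstddef>

constexpr std::size_t DgemmElementShift = 3;                                     // LFgemmElementShift
constexpr std::size_t DgemmElementSize  = std::size_t{1} << DgemmElementShift;   // 8 bytes
constexpr std::size_t DgemmYmmElements  = 32 / DgemmElementSize;                 // LFgemmYmmElementCount

static_assert(DgemmElementSize == sizeof(double), "DGEMM works on 8-byte elements");
static_assert(DgemmYmmElements == 4, "one 32-byte LASX register holds 4 doubles");
```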
All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelLasx.s + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "DgemmKernelCommon.h" +#include "FgemmKernelLasxCommon.h" + + .text + +// +// Generate the GEMM kernel. +// + +FgemmKernelLasxFunction MlasGemmDoubleKernelLasx + + .end diff --git a/third_party/mlas/lib/loongarch64/DgemmKernelLsx.S b/third_party/mlas/lib/loongarch64/DgemmKernelLsx.S new file mode 100644 index 0000000000..63395631a9 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/DgemmKernelLsx.S @@ -0,0 +1,217 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelLsx.s + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "FgemmKernelLsxCommon.h" + +FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.d) +/*++ + +Macro Description: + + This macro multiplies and accumulates for a 8xN block of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + a1 (rsi) - Supplies the address into the matrix B data. + + vr0-vr1 - Supplies up to two elements loaded from matrix A and matrix A + plus one row. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockSseBy8 RowCount + + vld $vr4, $a1, 0 + vld $vr5, $a1, 16 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.d $vr8, $vr4, $vr0, $vr8 + vfmadd.d $vr9, $vr5, $vr0, $vr9 +.if \RowCount\() == 2 + vfmadd.d $vr12, $vr6, $vr1, $vr12 + vfmadd.d $vr13, $vr7, $vr1, $vr13 +.endif + vld $vr4, $a1, 32 + vld $vr5, $a1, 48 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.d $vr10, $vr4, $vr0, $vr10 + vfmadd.d $vr11, $vr5, $vr0, $vr11 +.if \RowCount\() == 2 + vfmadd.d $vr14, $vr6, $vr1, $vr14 + vfmadd.d $vr15, $vr7, $vr1, $vr15 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t8 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t7 - Supplies the length in bytes of a row from matrix A. + + t5 - Supplies the length in bytes of a row from matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. 
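ComputeBlockSseBy8 performs one rank-1 update per K step: each active A row contributes a single broadcast double multiplied against eight packed doubles of B. A scalar reference sketch (in the real kernel the second A row is loaded at A plus lda bytes):

```cpp
// Scalar reference (sketch) for one ComputeBlockSseBy8 step.
#include <cstddef>

void ComputeBlockBy8Reference(const double a[2],      // current A element per row
                              const double bRow[8],   // eight packed doubles from B
                              double acc[2][8],       // accumulators vr8-vr11 / vr12-vr15
                              std::size_t rowCount) { // 1 or 2
    for (std::size_t row = 0; row < rowCount; ++row) {
        for (std::size_t col = 0; col < 8; ++col) {
            acc[row][col] += a[row] * bRow[col];      // vreplgr2vr.d broadcast + vfmadd.d
        }
    }
}
```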
+ +--*/ + + .macro ProcessCountM RowCount, Fallthrough +.LProcessNextColumnLoop8xN\@: + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8,$vr8,$vr8" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9,$vr9,$vr9" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10,$vr10,$vr10" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11,$vr11,$vr11" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12,$vr12,$vr12" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13,$vr13,$vr13" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14,$vr14,$vr14" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15,$vr15,$vr15" + move $t7,$a3 # reload CountK +.LCompute8xNBlockBy1Loop\@: + EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a0, 0" + EmitIfCountGE \RowCount\(), 1, "vreplgr2vr.d $vr0, $s0" + EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a0, $t0" + EmitIfCountGE \RowCount\(), 2, "vreplgr2vr.d $vr1, $s0" + ComputeBlockSseBy8 \RowCount\() + addi.d $a1, $a1, 8*8 # advance matrix B by 8 columns + addi.d $a0, $a0, 8 # advance matrix A by 1 column + addi.d $t7, $t7, -1 + bnez $t7, .LCompute8xNBlockBy1Loop\@ + +.LOutput8xNBlock\@: + movfr2gr.d $s0, $f24 + vreplgr2vr.d $vr2, $s0 + # multiply by alpha + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr8, $vr8, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr9, $vr9, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr10,$vr10, $vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr11,$vr11, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr12,$vr12, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr13,$vr13, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr14,$vr14, $vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr15,$vr15, $vr2" + li.d $s0, 8 + blt $a5, $s0, .LOutputPartial8xNBlock\@ + sub.d $a5, $a5, $s0 + AccumulateAndStoreBlock \RowCount\(), 4 + addi.d $a2, $a2, 8*8 # advance matrix C by 8 columns + move $a0, $t1 # reload matrix A + bnez $a5, .LProcessNextColumnLoop8xN\@ + b .LExitKernel + +// +// Output a partial 8xN block to the matrix. +// + +.LOutputPartial8xNBlock\@: + li.d $s0, 2 + blt $a5, $s0, .LOutputPartial1xNBlock\@ + li.d $s0, 4 + blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ + li.d $s0, 6 + blt $a5, $s0, .LOutputPartialLessThan6xNBlock\@ + AccumulateAndStoreBlock \RowCount\(), 3 + andi $s0, $a5, 1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr11" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr15" + addi.d $a2, $a2, 6*8 # advance matrix C by 6 columns + b .LOutputPartial1xNBlock\@ + +.LOutputPartialLessThan6xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 2 + andi $s0, $a5,1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr10" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr14" + addi.d $a2, $a2, 4*8 # advance matrix C by 4 columns + b .LOutputPartial1xNBlock\@ + +.LOutputPartialLessThan4xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 1 + andi $s0, $a5,1 # check if remaining count is small + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr9" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr13" + addi.d $a2, $a2, 2*8 # advance matrix C by 2 columns + +.LOutputPartial1xNBlock\@: + bnez $t5, .LSkipAccumulateOutput1xN\@ # ZeroMode? 
+ + EmitIfCountGE \RowCount\(), 1, "fld.d $f15, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "fadd.d $f15, $f15, $f8" + EmitIfCountGE \RowCount\(), 2, "fldx.d $f16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "fadd.d $f16, $f16, $f12" + +.LSkipAccumulateOutput1xN\@: + EmitIfCountGE \RowCount\(), 1, "fst.d $f15, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "fstx.d $f16, $a2, $t6" +.ifb \Fallthrough\() + b .LExitKernel +.endif + + .endm + +// +// Generate the GEMM kernel. +// + +FgemmKernelLsxFunction MlasGemmDoubleKernelLSX + + .end diff --git a/third_party/mlas/lib/loongarch64/FgemmKernelCommon.h b/third_party/mlas/lib/loongarch64/FgemmKernelCommon.h new file mode 100644 index 0000000000..777a592590 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/FgemmKernelCommon.h @@ -0,0 +1,100 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the floating + point matrix/matrix multiply operation (SGEMM and DGEMM). + +--*/ + +// +// Define the typed instruction template. +// + +#define FGEMM_TYPED_INSTRUCTION(Untyped, Typed) \ + .macro Untyped Operand:vararg; Typed \Operand\(); .endm; + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. + +Arguments: + + ComputeBlock - Supplies the macro to compute a single block. + + RowCount - Supplies the number of rows to process. + + AdvanceMatrixAPlusRows - Supplies a non-zero value if the data pointer + in rbx should also be advanced as part of the loop. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 3 rows. + + a1 - Supplies the address into the matrix B data. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + vr4-vr15 - Supplies the block accumulators. 
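The store path above branches on ZeroMode: a non-zero value overwrites C with the freshly computed block, otherwise the existing contents of C are accumulated first. A minimal sketch of that convention, shown for doubles:

```cpp
// Minimal sketch of the ZeroMode convention used by the store paths.
#include <cstddef>

void StoreOutputRowSketch(double* c, const double* block, std::size_t count, bool zeroMode) {
    for (std::size_t i = 0; i < count; ++i) {
        c[i] = zeroMode ? block[i] : c[i] + block[i];  // skip or take the accumulate path
    }
}
```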
+ +--*/ + + .macro ComputeBlockLoop ComputeBlock, RowCount, AdvanceMatrixAPlusRows + + move $t8, $a3 # reload CountK + li.d $s0, 4 + blt $t8, $s0, .LProcessRemainingBlocks\@ + +.LComputeBlockBy4Loop\@: + \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*0, 64*4 + \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*1, 64*4 + addi.d $a1, $a1, 2*2*32 # advance matrix B by 128 bytes + \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*2, 64*4 + \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*3, 64*4 + addi.d $a1, $a1, 2*2*32 # advance matrix B by 128 bytes + addi.d $a0, $a0, 4*LFgemmElementSize # advance matrix A by 4 elements +.if \RowCount\() > 3 + addi.d $t7, $t7, 4*LFgemmElementSize # advance matrix A plus rows by 4 elements +.if \RowCount\() == 12 + addi.d $t3, $t3, 4*LFgemmElementSize + addi.d $t4,, $t4, 4*LFgemmElementSize +.endif +.endif + addi.d $t8, $t8, -4 + li.d $s0, 4 + bge $t8, $s0, .LComputeBlockBy4Loop\@ + +.LProcessRemainingBlocks\@: + beqz $t8, .LOutputBlock\@ + +.LComputeBlockBy1Loop\@: + \ComputeBlock\() \RowCount\(), 0, 0 + addi.d $a1, $a1, 2*32 # advance matrix B by 64 bytes + addi.d $a0, $a0, LFgemmElementSize # advance matrix A by 1 element +.if \RowCount\() > 3 + addi.d $t7, $t7, LFgemmElementSize # advance matrix A plus rows by 1 element +.if \RowCount\() == 12 + addi.d $t3, $t3, LFgemmElementSize + addi.d $t4, $t4, LFgemmElementSize +.endif +.endif + addi.d $t8, $t8, -1 + bnez $t8, .LComputeBlockBy1Loop\@ + +.LOutputBlock\@: + + .endm diff --git a/third_party/mlas/lib/loongarch64/FgemmKernelLasxCommon.h b/third_party/mlas/lib/loongarch64/FgemmKernelLasxCommon.h new file mode 100644 index 0000000000..b96db84861 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/FgemmKernelLasxCommon.h @@ -0,0 +1,546 @@ + +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelLasxCommon.h + +Abstract: + + This module implements the kernels for the floating point matrix/matrix + multiply operation (SGEMM and DGEMM). + + This implementation uses LASX instructions. + +--*/ + +/*++ + +Macro Description: + + This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output + matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + + PrefetchOffset - Optionally supplies the byte offset from matrix B to + prefetch elements. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 2 rows. + + a1 - Supplies the address into the matrix B data. + + t0 - Supplies the length in bytes of a row from matrix A. + + xr8-xr15 - Supplies the block accumulators. 
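ComputeBlockLoop consumes the K dimension four columns at a time and cleans up the remainder one column at a time, advancing A and B between steps. A structural sketch with the block-compute macro abstracted into a callable:

```cpp
// Structural sketch of ComputeBlockLoop; pointer advancement is noted in comments.
#include <cstddef>

template <typename ComputeBlock>
void ComputeBlockLoopSketch(std::size_t countK, ComputeBlock computeBlock) {
    std::size_t k = countK;
    while (k >= 4) {          // .LComputeBlockBy4Loop
        computeBlock(4);      // four unrolled ComputeBlock invocations,
        k -= 4;               // then advance A by 4 elements and B by 4 column blocks
    }
    while (k > 0) {           // .LComputeBlockBy1Loop
        computeBlock(1);      // single ComputeBlock invocation,
        k -= 1;               // then advance A by 1 element and B by 1 column block
    }
    // .LOutputBlock follows in the calling macro
}
```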
+ +--*/ + + .macro ComputeBlockLasxBy16 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset + +.if \RowCount\() == 1 + xvldrepl.w $xr3, $a0, \BroadcastOffset\() + xvld $xr4, $a1, \VectorOffset\() + xvfmadd $xr8, $xr4, $xr3, $xr8 + xvld $xr5, $a1, \VectorOffset\()+32 + xvfmadd $xr9, $xr5, $xr3, $xr9 +.else + xvld $xr0, $a1, \VectorOffset\() + xvld $xr1, $a1, \VectorOffset\()+32 + EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3,$a0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr8, $xr3, $xr0, $xr8" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr1, $xr9" + EmitIfCountGE \RowCount\(), 2, "add.d $s0,$a0, $t0" + EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3,$s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr10, $xr3, $xr0, $xr10" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr1, $xr11" + + EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3,$t7, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr12, $xr3, $xr0, $xr12" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr1, $xr13" + EmitIfCountGE \RowCount\(), 4, "add.d $s0,$t7, $t0" + EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3,$s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr14, $xr3, $xr0, $xr14" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr1, $xr15" +.endif + + .endm + +/*++ + +Macro Description: + + This macro multiplies and accumulates for 1 YMMWORD by N rows of the output + matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. + + PrefetchOffset - Optionally supplies the byte offset from matrix B to + prefetch elements. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + t7 - Supplies the address into the matrix A data plus 2 rows. + + a1 - Supplies the address into the matrix B data. + + t0 - Supplies the length in bytes of a row from matrix A. + + xr8-xr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLasxBy8 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset + +.if \RowCount\() == 1 + xvldrepl.w $xr3, $a0, \BroadcastOffset\() + xvld $xr5, $a1, \VectorOffset\() + xvfmadd.s $xr9, $xr5, $xr3, $xr9 +.else + xvld $xr0, $a1, \VectorOffset\() + EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3, $a0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr0, $xr9" + + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a0, $t0" + EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3, $s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr0, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3, $t7, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr0, $xr13" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t0" + EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3, $s0, \BroadcastOffset\()" + EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr0, $xr15" +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to execute the block compute macro multiple + times and advancing the matrix A and matrix B data pointers. + +Arguments: + + ComputeBlock - Supplies the macro to compute a single block. + + RowCount - Supplies the number of rows to process. + +Implicit Arguments: + + a0 - Supplies the address into the matrix A data. + + a1 - Supplies the address into the matrix B data. 
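Both LASX compute macros broadcast a single A element per active row and fuse-multiply it against one or two 256-bit vectors of packed B. A scalar reference for the two-vector (ComputeBlockLasxBy16) case, shown here for single precision, i.e. two eight-float vectors; the typed macros specialize the element width:

```cpp
// Scalar reference (sketch) for one ComputeBlockLasxBy16 step, single precision.
#include <cstddef>

void ComputeBlockLasxBy16Reference(const float* a, std::size_t ldaElements,
                                   const float bRow[16],   // two packed LASX vectors of B
                                   float acc[4][16],       // xr8/xr9 ... xr14/xr15
                                   std::size_t rowCount) { // 1..4
    for (std::size_t row = 0; row < rowCount; ++row) {
        const float broadcast = a[row * ldaElements];       // xvldrepl from A + row * lda
        for (std::size_t col = 0; col < 16; ++col) {
            acc[row][col] += broadcast * bRow[col];          // xvfmadd
        }
    }
}
```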
+ + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t0 - Supplies the length in bytes of a row from matrix A. + + vr4-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockLasxLoop ComputeBlock, RowCount + +.if \RowCount\() > 2 + # compute matrix A plus 2 rows + slli.d $s0, $t0, 1 + add.d $t7, $a0, $s0 +.endif + ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 2 +.if \RowCount\() > 2 + # compute matrix C plus 2 rows + slli.d $s0, $t6, 1 + add.d $t7, $a2, $s0 +.endif + + .endm + + .macro store_n src, num, dst + move $s2, \num\() + beqz $s2, .Lstore_exit\@ + xvstelm.w \src\(), \dst\(), 0, 0 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 4, 1 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 8, 2 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 12, 3 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 16, 4 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 20, 5 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + + xvstelm.w \src\(), \dst\(), 24, 6 + addi.d $s2, $s2, -1 + beqz $s2, .Lstore_exit\@ + +.Lstore_exit\@: + .endm +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t1 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t0 - Supplies the length in bytes of a row from matrix A. + + t6 - Supplies the length in bytes of a row from matrix C. + + t5 - Stores the ZeroMode argument from the stack frame. + +--*/ + + .macro ProcessCountM RowCount, Fallthrough + + ori $s1, $r0, LFgemmYmmElementCount + bgeu $s1, $a5, .LProcessRemainingCountN\@ + +.LProcessNextColumnLoop2xN\@: + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr8, $xr8, $xr8" + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr10, $xr10, $xr10" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr12, $xr12, $xr12" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr14, $xr14, $xr14" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" + + ComputeBlockLasxLoop ComputeBlockLasxBy16, \RowCount\() + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr8, $xr8, $xr2" + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr10, $xr10, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr12, $xr12, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr14, $xr14, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" + + sub.d $a5, $a5, $s1 + sub.d $a5, $a5, $s1 + blt $a5, $zero, .LOutputMasked2xNBlock\@ + andi $s0, $t5, 0xff # ZeroMode? 
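The store_n helper writes only the first `num` lanes of a 256-bit vector, one element at a time, and is used for tails of at most seven columns (a full eight-lane block takes the vector-store path instead). A minimal sketch:

```cpp
// Minimal sketch of store_n: element-wise store of the leading lanes only.
#include <cstddef>

void StoreNSketch(float* dst, const float lanes[8], std::size_t count) {
    for (std::size_t i = 0; i < count && i < 7; ++i) {
        dst[i] = lanes[i];    // xvstelm.w lane i at byte offset 4 * i
    }
}
```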
+ bnez $s0, .LStore2xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0x20" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvld $xr16, $s0, 0x20" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0x20" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvld $xr16, $s0, 0x20" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" + +.LStore2xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0x20" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvst $xr11, $s0, 0x20" + EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0x20" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvst $xr15, $s0, 0x20" + + addi.d $a2, $a2, 0x40 # advance matrix C by 2 XRWORDs + move $a0, $t1 # reload matrix A + bltu $s1, $a5, .LProcessNextColumnLoop2xN\@ + beqz $a5, .LExitKernel + +.LProcessRemainingCountN\@: + EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" + EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" + EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" + + + ComputeBlockLasxLoop ComputeBlockLasxBy8, \RowCount\() + EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" + EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" + EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" + EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" + bltu $a5, $s1, .LOutputMasked1xNBlock\@ + andi $s0, $t5, 0xff # ZeroMode? + bnez $s0, .LStore1xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" + +.LStore1xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr11, $a2, $t6" + EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr15, $t7, $t6" + b .LExitKernel + +.LOutputMasked2xNBlock\@: + andi $s0, $t5, 0xff # ZeroMode? 
+ bnez $s0, .LStoreMasked2xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" + +.LStoreMasked2xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" + EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" + EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" + addi.d $a2, $a2, 0x20 # advance matrix C by YMMWORD +.if \RowCount\() > 2 + addi.d $t7, $t7, 0x20 # advance matrix C plus 2 rows by YMMWORD + +.endif + addi.d $a5, $a5, LFgemmYmmElementCount # correct for over-subtract above + + +.LOutputMasked1xNBlock\@: + +.if \RowCount\() > 2 + slli.d $s0, $t0, 1 + add.d $t7, $a0, $s0 +.endif + +.if \RowCount\() == 1 +.else +.endif + +.if \RowCount\() > 2 + slli.d $s0, $t6, 1 + add.d $t7, $a2, $s0 +.endif + + sub.d $a5, $zero, $a5 + la.global $a0, MlasMaskMoveTableLasx + ori $s0, $r0, LFgemmElementSize + mul.d $s0, $a5, $s0 + addi.d $s0, $s0, 8*4 + xvldx $xr0, $a0, $s0 + andi $s0, $t5, 0xff + + sub.d $a5, $zero, $a5 + + bnez $s0, .LStoreMasked1xNBlock\@ + EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "xvand.v $xr8, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "xvand.v $xr10, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" + EmitIfCountGE \RowCount\(), 3, "xvand.v $xr12, $xr16, $xr0" + EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" + EmitIfCountGE \RowCount\(), 4, "xvand.v $xr14, $xr16, $xr0" + + EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr8" + EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr10" + EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr12" + EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr14" +.LStoreMasked1xNBlock\@: + EmitIfCountGE \RowCount\(), 1, "store_n $xr9, $a5, $a2" + + add.d $s3, $a2, $t6 + EmitIfCountGE \RowCount\(), 2, "store_n $xr11, $a5, $s3" + + EmitIfCountGE \RowCount\(), 3, "store_n $xr13, $a5, $t7" + + add.d $s3, $t7, $t6 + EmitIfCountGE \RowCount\(), 4, "store_n $xr15, $a5, $s3" + sub.d $a5, $zero, $a5 +.ifb \Fallthrough\() + b .LExitKernel +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates the inner kernel to compute matrix multiplication. + +Arguments: + + FunctionName - Supplies the name for the generated function. + +--*/ + + .macro FgemmKernelLasxFunction FunctionName + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A a0 - Supplies the address of matrix A. + + B a1 - Supplies the address of matrix B. The matrix data has been packed + using MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C a2 - Supplies the address of matrix C. + + CountK a3 - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM a4 - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. 
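The masked tail above loads a full vector from C, ANDs it with a mask fetched from MlasMaskMoveTableLasx, accumulates, and then writes back only the valid lanes with store_n. The sketch below builds the equivalent lane mask directly; the table's layout (a run of all-ones words followed by zeros, indexed from the negated remaining count) is an assumption, as the table itself is defined outside this diff.

```cpp
// Sketch of the masked tail accumulation; assumes cRow is readable for all
// eight lanes, matching the full-vector load performed by the kernel.
#include <cstddef>
#include <cstdint>
#include <cstring>

void MaskedAccumulateSketch(float acc[8], const float* cRow, std::size_t valid /* 1..7 */) {
    for (std::size_t i = 0; i < 8; ++i) {
        const std::uint32_t laneMask = (i < valid) ? ~std::uint32_t{0} : 0u;  // mask word
        std::uint32_t bits = 0;
        std::memcpy(&bits, &cRow[i], sizeof(bits));   // xvld / xvldx (full-vector load)
        bits &= laneMask;                             // xvand.v: out-of-range lanes become zero
        float masked = 0.0f;
        std::memcpy(&masked, &bits, sizeof(masked));
        acc[i] += masked;                             // xvfadd
    }
}
```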
+ + CountN a5 - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda a6 - Supplies the first dimension of matrix A. + + ldc a7 - Supplies the first dimension of matrix C. + + Alpha f0 - Supplies the scalar alpha multiplier (see GEMM definition). + + ZeroMode (sp + 0)- Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + + FUNCTION_ENTRY \FunctionName\() + + addi.d $sp, $sp, -64 + st.d $ra, $sp, 56 + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + fst.s $f0, $sp, 2*8 + fst.d $f16, $sp,3*8 + st.d $s2, $sp, 4*8 + st.d $s3, $sp, 5*8 + + move $t1, $a0 + slli.d $t0, $a6, 2 # convert lda to bytes + slli.d $t6, $a7, 2 # convert ldc to bytes + ld.d $t5, $sp, 64 # get zeromode + fst.s $f0, $sp, 2*8 + xvldrepl.w $xr2, $sp, 0x10 + +// +// Process 4 rows of the matrices. +// + + ori $s0, $zero, 4 + bltu $a4, $s0, .LProcessCountMLessThan4 + li.d $a4, 4 # return 4 rows handled + ProcessCountM 4, Fallthrough + +// +// Restore non-volatile registers and return. +// + +.LExitKernel: + bstrpick.d $a0, $a4, 31, 0 + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + fld.d $f16, $sp,3*8 + ld.d $s2, $sp, 4*8 + ld.d $s3, $sp, 5*8 + ld.d $ra, $sp, 7*8 + addi.d $sp, $sp, 64 + jr $ra + +// +// Process 2 rows of the matrices. +// + +.LProcessCountMLessThan4: + ori $s0, $r0, 2 + bltu $a4, $s0, .LProcessCountMLessThan2 + li.d $a4, 2 # return 2 rows handled + ProcessCountM 2 + +// +// Process 1 row of the matrices. +// + +.LProcessCountMLessThan2: + ProcessCountM 1 + + .endm diff --git a/third_party/mlas/lib/loongarch64/FgemmKernelLsxCommon.h b/third_party/mlas/lib/loongarch64/FgemmKernelLsxCommon.h new file mode 100644 index 0000000000..0333af792b --- /dev/null +++ b/third_party/mlas/lib/loongarch64/FgemmKernelLsxCommon.h @@ -0,0 +1,170 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelLsxCommon.h + +Abstract: + + This module implements the kernels for the floating point matrix/matrix + multiply operation (SGEMM and DGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "FgemmKernelCommon.h" +/*++ + +Macro Description: + + This stores the block accumulators to the output matrix with an optional + accumulation of the existing contents of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorCount - Supplies the number of vector columns to process. + +Implicit Arguments: + + t5 - Supplies the length in bytes of a row from matrix C. + + a2 - Supplies the address of matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro AccumulateAndStoreBlock RowCount, VectorCount + + and $s0, $t5,$t5 # ZeroMode? 
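The routine description above matches how dgemm.cpp in this diff dispatches through GetMlasPlatform().GemmDoubleKernel. The prototype and driver loop below are a sketch of that contract; the exact typedef lives in mlasi.h and is not shown here, so the parameter types are assumptions.

```cpp
// Sketch of the C-level kernel contract and a typical driver loop.
#include <cstddef>

using GemmDoubleKernelSketch = std::size_t (*)(const double* A, const double* B, double* C,
                                               std::size_t CountK, std::size_t CountM,
                                               std::size_t CountN, std::size_t lda,
                                               std::size_t ldc, double alpha, bool ZeroMode);

// The kernel clamps CountM (to 4/2/1 here) and reports how many rows it
// handled, so the caller advances A and C by that amount.
void GemmDriverSketch(GemmDoubleKernelSketch kernel, const double* a, const double* b, double* c,
                      std::size_t k, std::size_t m, std::size_t n,
                      std::size_t lda, std::size_t ldc, double alpha, bool zeroMode) {
    while (m > 0) {
        const std::size_t rowsHandled = kernel(a, b, c, k, m, n, lda, ldc, alpha, zeroMode);
        a += rowsHandled * lda;
        c += rowsHandled * ldc;
        m -= rowsHandled;
    }
}
```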
+ bnez $s0 , .LSkipAccumulateOutput\@ + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vld $vr0, $a2, 0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vld $vr1, $a2, 16" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vld $vr2, $a2, 32" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vld $vr3, $a2, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vldx $vr4, $a2, $t6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vldx $vr5, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vldx $vr6, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vldx $vr7, $a2, $s0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vfadd $vr8, $vr8, $vr0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vfadd $vr9, $vr9, $vr1" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vfadd $vr10,$vr10,$vr2" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vfadd $vr11,$vr11,$vr3" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vfadd $vr12,$vr12,$vr4" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vfadd $vr13,$vr13,$vr5" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vfadd $vr14,$vr14,$vr6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vfadd $vr15,$vr15,$vr7" + +.LSkipAccumulateOutput\@: + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vst $vr8, $a2, 0" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vst $vr9, $a2, 16" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vst $vr10, $a2, 32" + EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vst $vr11, $a2, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vstx $vr12, $a2, $t6" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vstx $vr13, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vstx $vr14, $a2, $s0" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" + EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vstx $vr15, $a2, $s0" + + .endm +/*++ + +Macro Description: + + This macro generates the inner kernel to compute matrix multiplication. + +Arguments: + + FunctionName - Supplies the name for the generated function. + +--*/ + + .macro FgemmKernelLsxFunction FunctionName + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (a0) - Supplies the address of matrix A. + + B (a1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C (a2) - Supplies the address of matrix C. + + CountK (a3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (a4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (a5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (a6) Supplies the first dimension of matrix A. + + ldc (a7) Supplies the first dimension of matrix C. 
+ + Alpha (f0) - Supplies the scalar alpha multiplier (see GEMM definition). + + ZeroMode (sp 0) - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ + +FUNCTION_ENTRY \FunctionName\() + addi.d $sp, $sp, -64 + st.d $t5, $sp, 0 + st.d $s0, $sp, 1*8 + st.d $s1, $sp, 2*8 + st.d $s2, $sp, 3*8 + st.d $s3, $sp, 4*8 + move $t1, $a0 + slli.d $t0, $a6, 2 //convert lda to bytes + slli.d $t6, $a7, 2 //convert ldc to bytes + ld.d $t5, $sp, 64 + fmov.s $f24, $f0 //f0 destroyed by lsx + + li.d $s0, 2 + blt $a4, $s0, .LProcessCountM1 + + li.d $a4, 2 + ProcessCountM 2, Fallthrough + +.LExitKernel: + ld.d $t5, $sp, 0 + ld.d $s0, $sp, 1*8 + ld.d $s1, $sp, 2*8 + ld.d $s2, $sp, 3*8 + ld.d $s3, $sp, 4*8 + addi.d $sp, $sp, 64 + move $a0, $a4 + jr $ra + +.LProcessCountM1: + ProcessCountM 1 + .endm diff --git a/third_party/mlas/lib/loongarch64/SconvKernelLasx.S b/third_party/mlas/lib/loongarch64/SconvKernelLasx.S new file mode 100644 index 0000000000..e035035219 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SconvKernelLasx.S @@ -0,0 +1,412 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLasx.S + +Abstract: + + This module implements the kernels for the single precision convolution + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "SconvKernelLasxCommon.h" + + .text + +/*++ + +Macro Description: + + This macro multiplies and accumulates for FilterCount by OutputCount block + of the output buffer. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + + VectorOffset - Supplies the byte offset from the filter buffer to fetch + elements. + + BroadcastOffset - Supplies the byte offset from the input buffer to fetch + elements. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a2 - Supplies the address of the filter buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t7 - Supplies the address of the filter buffer plus 2 * FilterStride. + + a5 - Supplies the StrideWidth parameter (see function description). + + xr0-xr7 - Supplies the block accumulators. 
+ +--*/ + + .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset + +.ifeqs "\KernelType\()","Depthwise" + xvld $xr12, $a2, 0 + EmitIfCountGE \OutputCount\(), 1, "xvld $xr8, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr12, $xr0" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr9, $a3, $a5" + EmitIfCountGE \OutputCount\(), 2, "xvfmadd.s $xr4, $xr9, $xr12, $xr4" + +.else + EmitIfCountGE \OutputCount\(), 1, "xvldrepl.w $xr13, $a3, \BroadcastOffset\()" + EmitIfCountGE \OutputCount\(), 2, "add.d $s0, $a3, $a5" + EmitIfCountGE \OutputCount\(), 2, "xvldrepl.w $xr14, $s0, \BroadcastOffset\()" +.if \OutputCount\() == 1 + EmitIfCountGE \FilterCount\(), 1, "xvld $xr8, $a2, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr13, $xr0" + EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr9, $s0, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 2, "xvfmadd.s $xr1, $xr9, $xr13, $xr1" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr10, $t7, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 3, "xvfmadd.s $xr2, $xr10, $xr13, $xr2" + EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr11, $s0, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 4, "xvfmadd.s $xr3, $xr11, $xr13, $xr3" +.else + EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a2, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmadd.s $xr0, $xr12, $xr13, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmadd.s $xr4, $xr12, $xr14, $xr4" + EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr12, $s0, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmadd.s $xr1, $xr13, $xr12, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmadd.s $xr5, $xr14, $xr12, $xr5" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr12, $t7, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmadd.s $xr2, $xr13, $xr12, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmadd.s $xr6, $xr14, $xr12, $xr6" + EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr12, $s0, \VectorOffset\()" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmadd.s $xr3, $xr13, $xr12, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmadd.s $xr7, $xr14, $xr12, $xr7" +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + t7 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t5 - Supplies the InputStride parameter (see function description). 
+ +--*/ + + .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount + +// +// Process the output blocks that include left padding. +// + + ld.d $t0, $sp, OutputCountLeftPad_arg + beqz $t0, .L\KernelType\().\FilterCount\().ProcessOutputCount + bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() + +// +// Process the output blocks that do not include any padding. +// + +.L\KernelType\().\FilterCount\().ProcessOutputCount: + ld.d $t0, $sp, OutputCount_arg + li.d $s0, 2 + bltu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount + +.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2: + ProcessOutputCountN Lasx, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 2 + slli.d $s0, $a5, 1 # advance input by 2 elements + add.d $a0, $a0, $s0 + addi.d $t0, $t0, -2 + li.d $s0, 2 + bgeu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2 + +.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: + +// +// Process the output blocks that include right padding plus any remaining output +// blocks from above. +// + +.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\KernelType\().ExitKernel + bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows for a pointwise convolution. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t0 - Supplies the OutputCount parameter (see function description). + + t2 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseFilterCountN FilterCount + li.d $s0, 2 + bltu $t0, $s0, .LPointwise.\FilterCount\().ProcessRemainingOutputCount + +.LPointwise.\FilterCount\().ProcessNextOutputCountBy2: + ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 2 + slli.d $s0, $a5, 1 # advance input by 2 elements + add.d $a0, $a0, $s0 + addi.d $t0, $t0, -2 + li.d $s0, 2 + bgeu $t0, $s0, .LPointwise.\FilterCount\().ProcessNextOutputCountBy2 + +.LPointwise.\FilterCount\().ProcessRemainingOutputCount: + beqz $t0, .LPointwise.ExitKernel + ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 1 + + .endm + +// +// Generate the convolution kernels. +// + + SconvKernelFunction Nchw, 8, Lasx + SconvKernelFunction Nchwc, 8, Lasx, BiasFilter + SconvKernelDepthwiseFunction 8, Lasx + SconvKernelPointwiseFunction Lasx, BiasFilter + +/*++ + +Macro Description: + + This macro generates code to process an output block after the inner + convolution kernel has executed and then stores the output block to the + output buffer. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
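ProcessFilterCountN splits an output row into three phases: left-padding positions go through the single-output helper (which re-checks padding per kernel tap), interior positions are produced two at a time, and any leftover interior block is folded into the right-padding pass. A structural sketch with hypothetical callables standing in for the helpers:

```cpp
// Structural sketch of ProcessFilterCountN's output blocking.
#include <cstddef>

template <typename SingleOutput, typename DoubleOutput>
void ProcessOutputRowSketch(std::size_t leftPad, std::size_t interior, std::size_t rightPad,
                            SingleOutput singleOutput, DoubleOutput doubleOutput) {
    if (leftPad != 0) {
        singleOutput(leftPad);          // MlasConv...FloatSingle... helper, padding-aware
    }
    while (interior >= 2) {             // ProcessNextOutputCountBy2
        doubleOutput();                 // ProcessOutputCountN ..., OutputCount = 2
        interior -= 2;
    }
    if (interior + rightPad != 0) {     // remaining interior plus right padding
        singleOutput(interior + rightPad);
    }
}
```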
+ +--*/ + + .macro PostProcessBlock FilterCount, OutputCount + + .globl MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() + .hidden MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() +MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\(): + + .globl MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() + .hidden MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() +MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\(): + +.if \FilterCount\() > 2 + slli.d $s0, $t6, 1 # compute output plus 2 rows + add.d $t7, $a4, $s0 +.endif + +// +// Test if the existing contents of the output buffer should be accumulated +// with the output block. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvld $xr16, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvld $xr16, $a4, 32" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvld $xr16, $a4, 0x40" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvldx $xr16, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvld $xr16, $s0, 0x20" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr16" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvld $xr16, $s0, 0x40" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr16" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvld $xr16,$t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvld $xr16,$t7, 0x20" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvld $xr16,$t7, 0x40" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvldx $xr16,$t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvld $xr16,$s0, 0x20" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr16" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvld $xr16,$s0, 0x40" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr16" + + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: + +// +// Test if the bias buffer should be accumulated with the output block. 
+// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition +.if \OutputCount\() == 1 + EmitIfCountGE \FilterCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \FilterCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr16, $a3, 0x20" + EmitIfCountGE \FilterCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr16, $a3, 0x40" + EmitIfCountGE \FilterCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr16, $a3, 0x60" + EmitIfCountGE \FilterCount\(), 4, "xvfadd.s $xr3, $xr3, $xr16" +.else + EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a3, 0" + EmitIfCountGE \FilterCount\(), 2, "xvld $xr13, $a3, 0x20" + EmitIfCountGE \FilterCount\(), 3, "xvld $xr14, $a3, 0x40" + EmitIfCountGE \FilterCount\(), 4, "xvld $xr15, $a3, 0x60" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr12" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr12" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr12" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr13" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr13" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr13" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr14" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr14" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr15" + +.endif + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: + +// +// Test for fused ReLU activation. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation + xvxor.v $xr15, $xr15, $xr15 + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmax.s $xr0, $xr15, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmax.s $xr4, $xr15, $xr4" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfmax.s $xr8, $xr15, $xr8" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmax.s $xr1, $xr15, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmax.s $xr5, $xr15, $xr5" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfmax.s $xr9, $xr15, $xr9" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmax.s $xr2, $xr15, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmax.s $xr6, $xr15, $xr6" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfmax.s $xr10, $xr15, $xr10" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmax.s $xr3, $xr15, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmax.s $xr7, $xr15, $xr7" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfmax.s $xr11, $xr15, $xr11" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: + +// +// Store the output block in the output buffer. 
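PostProcessBlock applies, in order: optional accumulation with the existing output, optional bias addition, and an optional fused ReLU, gated by the flag bits defined in SconvKernelLasxCommon.h later in this diff. A scalar sketch for a single output element:

```cpp
// Scalar sketch of the PostProcessBlock sequence for one output element.
#include <algorithm>
#include <cstdint>

constexpr std::uint32_t kAccumulateOutput = 0x00000001;  // MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT
constexpr std::uint32_t kBiasAddition     = 0x00000002;  // MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION
constexpr std::uint32_t kReluActivation   = 0x00000004;  // MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION

float PostProcessSketch(float value, float existingOutput, float bias, std::uint32_t flags) {
    if (flags & kAccumulateOutput) {
        value += existingOutput;          // xvfadd with the previously stored block
    }
    if (flags & kBiasAddition) {
        value += bias;                    // xvfadd with the bias vector
    }
    if (flags & kReluActivation) {
        value = std::max(value, 0.0f);    // xvfmax.s against a zeroed register
    }
    return value;                         // xvst / xvstx writes the block back
}
```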
+// + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvst $xr0, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvst $xr4, $a4, 0x20" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvst $xr8, $a4, 0x40" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvstx $xr1, $a4, $t6" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvst $xr5, $s0, 0x20" + + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvst $xr9, $s0, 0x40" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvst $xr2, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvst $xr6, $t7, 0x20" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvst $xr10, $t7, 0x40" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvstx $xr3, $t7, $t6" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvst $xr7, $s0, 0x20" + + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvst $xr11, $s0, 0x40" + + add_immed $a4,\OutputCount\()*8*4 # advance output by N nchw8c blocks + jr $ra + + .endm + + .irp FilterCount, 1, 2, 3, 4 + .irp OutputCount, 1, 2, 3 + PostProcessBlock \FilterCount\(), \OutputCount\() + .endr + .endr + + .end diff --git a/third_party/mlas/lib/loongarch64/SconvKernelLasxCommon.h b/third_party/mlas/lib/loongarch64/SconvKernelLasxCommon.h new file mode 100644 index 0000000000..bd2db816ed --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SconvKernelLasxCommon.h @@ -0,0 +1,868 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLasxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision convolution operation for the Lasx kernels. + +--*/ + + +#define SP_SIZE 32*8 + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#define OutputStride_arg 6*8 +#define KernelHeight_arg 7*8 +#define KernelWidth_arg 8*8 +#define InputBase_arg 9*8 +#define InputWidth_arg 10*8 +#define DilatedInputWidth_arg 11*8 +#define OutputCountLeftPad_arg 12*8 +#define OutputCount_arg 13*8 +#define OutputCountRightPad_arg 14*8 +#define Bias_arg 15*8 +#define Flags_arg 16*8 +#define InputChannels_arg 17*8 +#define Filter_save_offset 18*8 + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
+ +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t5 - Supplies the InputStride parameter (see function description). +--*/ + .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount + + move $a3, $a0 +.ifeqs "\KernelType\()","Depthwise" + move $a2, $a1 +.else + ld.d $a2, $sp, Filter_save_offset +.endif + ld.d $t1, $sp, KernelHeight_arg + ld.d $t2, $sp, KernelWidth_arg +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $zero, $t3 +.endif + ClearBlock \FilterCount\(), \OutputCount\() + beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: + move $t6, $t2 # reload kernel width remaining + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 # compute (Input - InputBase) + # (Input - InputBase) >= InputWidth? + bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding +.endif +.if \OutputCount\() > 3 + slli.d $s0, $a5, 1 + add.d $s0, $s0, $a5 + add.d $t4, $a3, $s0 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + slli.d $s0, $a1, 1 # compute filter plus 2 rows + add.d $t7, $a2, $s0 +.endif +.ifeqs "\KernelType\()","Nchwc" +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif +.else + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 +.endif + +.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: + # advance input by dilation width + add.d $a3, $a3, $t8 +.ifeqs "\KernelType\()","Nchwc" + # advance filter by 8i8o/16i16o block + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 +.else + addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block +.endif + addi.d $t6, $t6, -1 + bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t5 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + # advance input base to next row + sub.d $t3, $t3, $s0 +.endif + addi.d $t1, $t1, -1 # decrement rows remaining + bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow + +// +// Handle post processing of the output block. +// + +.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: + ld.w $a2, $sp, Flags_arg +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. 
+ + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a4) - Supplies the length in bytes of the blocked dilation + width. + + FilterCount (a5) - Supplies the number of filters to process in this + iteration. + + InputStride (a6)- Supplies the length in bytes to advance the input buffer to + the next input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + KernelHeight (sp + 8)- Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (sp + 0x10)- Supplies the width of the kernel to apply. + + InputBase (sp + 0x18)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 0x20)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x28)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x30)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x38)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x40)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp + 0x48)- Supplies the address of the bias buffer. + + Flags (sp + 0x50)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
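+
+Remarks:
+
+    As a rough mental model only, the NCHWc case of the generated
+    MlasConv\KernelType\()FloatKernel\Isa\() routine computes approximately
+    the following (hypothetical C; the function name, the use of
+    float-element strides instead of byte strides, and the omission of the
+    InputBase/InputWidth padding checks are simplifications for this sketch,
+    not part of MLAS):
+
+        #include <stddef.h>
+
+        static void
+        conv_nchwc_sketch(const float* input, const float* filter,
+                          float* output, size_t block, size_t filter_count,
+                          size_t output_count, size_t kernel_h,
+                          size_t kernel_w, size_t stride_w,
+                          size_t dilation_w, size_t input_stride,
+                          size_t filter_stride, size_t output_stride)
+        {
+            for (size_t f = 0; f < filter_count; f++) {
+                for (size_t o = 0; o < output_count; o++) {
+                    float acc[16] = {0.0f};          // one nchw8c/16c block
+                    for (size_t kh = 0; kh < kernel_h; kh++) {
+                        for (size_t kw = 0; kw < kernel_w; kw++) {
+                            const float* in = input + o * stride_w +
+                                kh * input_stride + kw * dilation_w;
+                            const float* flt = filter + f * filter_stride +
+                                (kh * kernel_w + kw) * block * block;
+                            for (size_t ic = 0; ic < block; ic++)     // 8i/16i
+                                for (size_t oc = 0; oc < block; oc++) // 8o/16o
+                                    acc[oc] += in[ic] * flt[ic * block + oc];
+                        }
+                    }
+                    // PostProcessBlock then applies the Flags-driven
+                    // accumulate/bias/ReLU steps before the store.
+                    float* out = output + f * output_stride + o * block;
+                    for (size_t oc = 0; oc < block; oc++)
+                        out[oc] = acc[oc];
+                }
+            }
+        }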
+ +--*/ + + FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, OutputStride_arg + st.d $t1, $sp, KernelHeight_arg + st.d $t2, $sp, KernelWidth_arg + st.d $t3, $sp, InputBase_arg + ld.d $t0, $sp, SP_SIZE+4*8 + ld.d $t1, $sp, SP_SIZE+5*8 + ld.d $t2, $sp, SP_SIZE+6*8 + ld.d $t3, $sp, SP_SIZE+7*8 + st.d $t0, $sp, InputWidth_arg + st.d $t1, $sp, DilatedInputWidth_arg + st.d $t2, $sp, OutputCountLeftPad_arg + st.d $t3, $sp, OutputCount_arg + ld.d $t0, $sp, SP_SIZE+8*8 + ld.d $t1, $sp, SP_SIZE+9*8 + ld.d $t2, $sp, SP_SIZE+10*8 + st.d $t0, $sp, OutputCountRightPad_arg + st.d $t1, $sp, Bias_arg + st.d $t2, $sp, Flags_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $a1, $a1, 4*8*4 +.endif + st.d $a1, $sp, Filter_save_offset + move $a1, $a7 + move $t5, $a6 + move $t8, $a4 + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ori $s0, $zero, 3 + beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 + bltu $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 4 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount3: + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 3 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCountLessThan3: + ori $s0, $zero, 2 + bltu $t1, $s0, .L\KernelType\().ProcessFilterCount1 + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 2 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount1: + ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 1 + +// +// Restore non-volatile registers and return. +// + +.L\KernelType\().ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jirl $zero, $ra, 0 + +.ifnes "\Isa\()","LSX" + +// +// Generate out-of-band helpers for handling output blocks involving padding. 
+// + + .irp FilterCount, 1, 2, 3, 4 + +MlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): + st.d $ra, $sp, 19*8 +loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): + ProcessOutputCountN \Isa\(), LSconvKernelSingleFrame, \KernelType\(), \BlockSize\(), \FilterCount\(), 1 + add.d $a0, $a0, $a5 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + bnez $t0, loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\() + ld.d $ra, $sp, 19*8 + jr $ra + + .endr + +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case of a depthwise separable convolution. + +Arguments: + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SconvKernelDepthwiseFunction BlockSize, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Depthwise separable convolutions are a form of grouped convolution where + the number of input and output channels per group are one. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a4) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a5) - Supplies the length in bytes to advance the input buffer + to the next input row. + + KernelHeight (a6)- Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7)- Supplies the width of the kernel to apply. + + InputBase (sp + 0 )- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 8 )- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x20)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp + 0x30)- Supplies the address of the bias buffer. + + Flags (sp + 0x38)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
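+
+Remarks:
+
+    As a rough mental model only, the generated
+    MlasConvDepthwiseFloatKernel\Isa\() routine convolves each channel of one
+    nchw8c/16c block with its own filter, approximately as follows
+    (hypothetical C; the function name and the float-element strides are
+    illustrative, and the InputBase/InputWidth padding checks are omitted):
+
+        #include <stddef.h>
+
+        static void
+        conv_depthwise_sketch(const float* input, const float* filter,
+                              float* output, size_t block,
+                              size_t output_count, size_t kernel_h,
+                              size_t kernel_w, size_t stride_w,
+                              size_t dilation_w, size_t input_stride)
+        {
+            for (size_t o = 0; o < output_count; o++) {
+                float acc[16] = {0.0f};
+                for (size_t kh = 0; kh < kernel_h; kh++)
+                    for (size_t kw = 0; kw < kernel_w; kw++)
+                        for (size_t c = 0; c < block; c++) // one multiply per channel
+                            acc[c] += input[o * stride_w + kh * input_stride +
+                                            kw * dilation_w + c] *
+                                      filter[(kh * kernel_w + kw) * block + c];
+                // The Flags-driven accumulate/bias/ReLU post-processing runs
+                // before the store, exactly as in the non-depthwise kernels.
+                for (size_t c = 0; c < block; c++)
+                    output[o * block + c] = acc[c];
+            }
+        }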
+ +--*/ + + FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + st.d $a6, $sp, KernelHeight_arg + st.d $a7, $sp, KernelWidth_arg + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, InputBase_arg + st.d $t1, $sp, InputWidth_arg + st.d $t2, $sp, DilatedInputWidth_arg + st.d $t3, $sp, OutputCountLeftPad_arg + ld.d $t0, $sp, SP_SIZE+4*8 + ld.d $t1, $sp, SP_SIZE+5*8 + ld.d $t2, $sp, SP_SIZE+6*8 + ld.d $t3, $sp, SP_SIZE+7*8 + st.d $t0, $sp, OutputCount_arg + st.d $t1, $sp, OutputCountRightPad_arg + st.d $t2, $sp, Bias_arg + st.d $t3, $sp, Flags_arg + + move $t8, $a4 + move $t5, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ProcessFilterCountN LSconvKernelDepthwiseFrame, Depthwise, 1 + +// +// Restore non-volatile registers and return. +// + +.LDepthwise.ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + +.ifnes "\Isa\()","LSX" + +// +// Generate out-of-band helpers for handling output blocks involving padding. +// + +MlasConvDepthwiseFloatSingle\Isa\()Filter1: + st.d $ra, $sp, 20*8 +MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop: + ProcessOutputCountN \Isa\(), LSconvKernelDepthwiseSingleFrame, Depthwise, \BlockSize\(), 1, 1 + add.d $a0, $a0, $a5 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + + bnez $t0, MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop + ld.d $ra, $sp, 20*8 + jr $ra + +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks + for a pointwise convolution. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). 
+ + t2 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount + + move $a3, $a0 + move $a2, $t2 + ld.d $t1, $sp, InputChannels_arg + ClearBlock \FilterCount\(), \OutputCount\() + +.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: +.if \OutputCount\() > 3 + slli.d $s0, $a5, 1 + add.d $s0, $s0, $a5 + add.d $t4, $s0, $a3 +.endif +.if \FilterCount\() > 2 + slli.d $s0, $a1, 1 + add.d $t7, $a2, $s0 +.endif +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif + add.d $a3, $a3, $t8 # advance input to next channel block + + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 # advance filter by 8i8o/16i16o block + addi.d $t1, $t1, -1 # decrement input blocks remaining + + bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock + +// +// Handle post processing of the output block. +// + + ld.w $a2, $sp, Flags_arg +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() + + .endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case where the kernel dimensions are 1. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelPointwiseFunction Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Pointwise convolutions have a kernel size of one. To simplify this + implementation, no input padding is allowed, which matches typical usage in + models. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + InputChannels (a4) - Supplies the number of input channels to process. + + FilterCount (a5) - Supplies the number of rows from the filter to process. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input channel of the same input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + OutputCount (sp + 8)- Supplies the number of output elements. + + Bias (sp + 0x10)- Supplies the address of the bias buffer. + + Flags (sp + 0x18)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
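+
+Remarks:
+
+    As a rough mental model only, the generated
+    MlasConvPointwiseFloatKernel\Isa\() routine reduces over the
+    InputChannels blocks of each spatial position, approximately as follows
+    (hypothetical C; the function name and the float-element strides are
+    illustrative):
+
+        #include <stddef.h>
+
+        static void
+        conv_pointwise_sketch(const float* input, const float* filter,
+                              float* output, size_t block,
+                              size_t channel_blocks, size_t filter_count,
+                              size_t output_count, size_t stride_w,
+                              size_t input_stride, size_t filter_stride,
+                              size_t output_stride)
+        {
+            for (size_t f = 0; f < filter_count; f++) {
+                for (size_t o = 0; o < output_count; o++) {
+                    float acc[16] = {0.0f};
+                    for (size_t cb = 0; cb < channel_blocks; cb++) {
+                        const float* in = input + o * stride_w +
+                            cb * input_stride;
+                        const float* flt = filter + f * filter_stride +
+                            cb * block * block;
+                        for (size_t ic = 0; ic < block; ic++)
+                            for (size_t oc = 0; oc < block; oc++)
+                                acc[oc] += in[ic] * flt[ic * block + oc];
+                    }
+                    // The Flags-driven accumulate/bias/ReLU post-processing
+                    // runs before the store.
+                    float* out = output + f * output_stride + o * block;
+                    for (size_t oc = 0; oc < block; oc++)
+                        out[oc] = acc[oc];
+                }
+            }
+        }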
+ +--*/ + + FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $ra, $sp, 5*8 + + ld.d $t0, $sp, SP_SIZE+0*8 + ld.d $t1, $sp, SP_SIZE+1*8 + ld.d $t2, $sp, SP_SIZE+2*8 + ld.d $t3, $sp, SP_SIZE+3*8 + st.d $t0, $sp, OutputStride_arg + st.d $t1, $sp, OutputCount_arg + st.d $t2, $sp, Bias_arg + st.d $t3, $sp, Flags_arg + st.d $a4, $sp, InputChannels_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $t2, $a1, 4*8*4 +.else + move $t2, $a1 +.endif + ld.d $t0, $sp, OutputCount_arg + move $a1, $a7 + move $t8, $a6 + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + + ori $s0, $zero, 3 + beq $t1, $s0, .LPointwise.ProcessFilterCount3 + bltu $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 + ProcessPointwiseFilterCountN 4 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount3: + ProcessPointwiseFilterCountN 3 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCountLessThan3: + ori $s0, $zero, 2 + bltu $t1, $s0, .LPointwise.ProcessFilterCount1 + ProcessPointwiseFilterCountN 2 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount1: + ProcessPointwiseFilterCountN 1 + +// +// Restore non-volatile registers and return. +// + +.LPointwise.ExitKernel: +.ifnes "\Isa\()","LSX" + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 +.endif + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + +/*++ + +Macro Description: + + This macro generates code to clear the block accumulators. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + xr0-xr11 - Supplies the block accumulators. 
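+
+    The accumulators are tiled as FilterCount rows by OutputCount columns:
+
+        output block:    1      2      3
+        filter row 1:  xr0    xr4    xr8
+        filter row 2:  xr1    xr5    xr9
+        filter row 3:  xr2    xr6    xr10
+        filter row 4:  xr3    xr7    xr11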
+ +--*/ + + .macro ClearBlock FilterCount, OutputCount + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvxor.v $xr4, $xr4, $xr4" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvxor.v $xr8, $xr8, $xr8" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvxor.v $xr1, $xr1, $xr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvxor.v $xr5, $xr5, $xr5" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvxor.v $xr9, $xr9, $xr9" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvxor.v $xr2, $xr2, $xr2" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvxor.v $xr6, $xr6, $xr6" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvxor.v $xr10, $xr10, $xr10" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvxor.v $xr3, $xr3, $xr3" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvxor.v $xr7, $xr7, $xr7" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvxor.v $xr11, $xr11, $xr11" + + .endm diff --git a/third_party/mlas/lib/loongarch64/SconvKernelLsx.S b/third_party/mlas/lib/loongarch64/SconvKernelLsx.S new file mode 100644 index 0000000000..04b8dc14d0 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SconvKernelLsx.S @@ -0,0 +1,339 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLsx.S + +Abstract: + + This module implements the kernels for the single precision convolution + operation. + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "SconvKernelLsxCommon.h" + +/*++ + +Macro Description: + + This macro generates code to clear the block accumulators. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + vr0-vr7 - Supplies the block accumulators. + +--*/ + + .macro ClearBlock FilterCount, OutputCount + + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr0,$vr0,$vr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr1,$vr1,$vr1" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr2,$vr2,$vr2" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr3,$vr3,$vr3" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr4,$vr4,$vr4" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr5,$vr5,$vr5" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr6,$vr6,$vr6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr7,$vr7,$vr7" + + .endm + +/*++ + +Macro Description: + + This macro multiplies and accumulates for FilterCount by OutputCount block + of the output buffer. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + + VectorOffset - Supplies the byte offset from the filter buffer to fetch + elements. + + BroadcastOffset - Supplies the byte offset from the input buffer to fetch + elements. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a2 - Supplies the address of the filter buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + t6 - Supplies the address of the filter buffer plus 2 * FilterStride. 
+ + a5 - Supplies the StrideWidth parameter (see function description). + + vr0-vr7 - Supplies the block accumulators. + +--*/ + .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset + +.ifeqs "\KernelType\()","Depthwise" + vld $vr8, $a2, 0 + vld $vr9, $a2, 16 + vld $vr10, $a3, 0 + vld $vr11, $a3, 16 + vfmadd.s $vr0, $vr8, $vr10, $vr0 + vfmadd.s $vr1, $vr9, $vr11, $vr1 +.else + EmitIfCountGE \OutputCount\(), 1, "ld.w $s0, $a3, \BroadcastOffset\()" + EmitIfCountGE \OutputCount\(), 1, "vreplgr2vr.w $vr12, $s0" + EmitIfCountGE \FilterCount\(), 1, "vld $vr8, $a2, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 1, "vld $vr9, $a2, \VectorOffset\()+16" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr0, $vr8, $vr12, $vr0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr1, $vr9, $vr12, $vr1" + EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, +\VectorOffset\()" + EmitIfCountGE \FilterCount\(), 2, "vldx $vr8, $a2, $s0" + EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, +\VectorOffset\()+16" + EmitIfCountGE \FilterCount\(), 2, "vldx $vr9, $a2, $s0" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr2, $vr8, $vr12, $vr2" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr3, $vr9, $vr12, $vr3" + EmitIfCountGE \FilterCount\(), 3, "vld $vr8, $t7, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 3, "vld $vr9, $t7, \VectorOffset\()+16" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr4, $vr8, $vr12, $vr4" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr5, $vr9, $vr12, $vr5" + EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()" + EmitIfCountGE \FilterCount\(), 4, "vldx $vr8, $t7, $s0" + EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()+16" + EmitIfCountGE \FilterCount\(), 4, "vldx $vr9, $t7, $s0" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr6, $vr8, $vr12, $vr6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr7, $vr9, $vr12, $vr7" +.endif + .endm +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + s3 - Supplies the InputStride parameter (see function description). 
+ +--*/ + + .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount + ld.d $s0, $sp, OutputCountLeftPad_arg //OutputCountLeftPad + ld.d $s1, $sp, OutputCount_arg //OutputCount + add.d $s0, $s0, $s1 + ld.d $s1, $sp, OutputCountRightPad_arg //OutputCountRightPad + add.d $t0, $s0, $s1 +.L\KernelType\().\FilterCount\().ProcessNextOutputCount: + ProcessOutputCountN Sse, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 1 + add.d $a0, $a0, $a5 + addi.d $t0, $t0, -1 + bnez $t0, .L\KernelType\().\FilterCount\().ProcessNextOutputCount + .endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a specified number + of filter rows for a pointwise convolution. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description). + + s8 - Supplies the InputStride parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + t7 - Supplies the OutputCount parameter (see function description). + + s5 - Supplies the address of the filter buffer. + +--*/ + + .macro ProcessPointwiseFilterCountN FilterCount +.LPointwise.\FilterCount\().ProcessNextOutputCount: + ProcessPointwiseOutputCountN Sse, 8, \FilterCount\(), 1 + add.d $a0, $a0, $a5 + addi.d $t0, $t0, -1 + bnez $t0, .LPointwise.\FilterCount\().ProcessNextOutputCount + .endm + +// +// Generate the convolution kernels. +// + + SconvKernelFunction Nchw, 8, LSX + SconvKernelFunction Nchwc, 8, LSX, BiasFilter + SconvKernelDepthwiseFunction 8, LSX + SconvKernelPointwiseFunction LSX, BiasFilter + +/*++ + +Macro Description: + + This macro generates code to process an output block after the inner + convolution kernel has executed and then stores the output block to the + output buffer. + +Arguments: + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. 
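+
+    Each generated helper is equivalent to the following per-block step
+    (hypothetical C; the function name is illustrative, and the flag macros
+    are the MLAS_CONV_KERNEL_FLAG_* values defined in
+    SconvKernelLsxCommon.h):
+
+        static void
+        post_process_block_sketch(float* output, const float* acc,
+                                  const float* bias, unsigned flags)
+        {
+            for (int i = 0; i < 8; i++) {
+                float v = acc[i];
+                if (flags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)
+                    v += output[i];
+                if (flags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)
+                    v += bias[i];
+                if (flags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)
+                    v = (v > 0.0f) ? v : 0.0f;
+                output[i] = v;
+            }
+        }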
+--*/ + + .macro PostProcessBlock FilterCount, OutputCount + + .globl MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#if !defined(__APPLE__) + .hidden MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() +#endif +MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\(): + +.if \FilterCount\() > 2 + li.d $s0, 2 + mul.d $s0, $s0, $t6 + add.d $t7, $a4, $s0 +.endif + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a4, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a4, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr10, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr11, $a4, $s0" + + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $t7, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr14, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr15, $t7, $s0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: +// +// Test if the bias buffer should be accumulated with the output block. 
+// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a3, 0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a3, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr10, $a3, 32" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr11, $a3, 48" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $a3, 64" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $a3, 80" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr14, $a3, 96" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr15, $a3, 112" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: + +// +// Test for fused ReLU activation. +// + + andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION + andi $s0, $s0, 0xff + beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation + vxor.v $vr15,$vr15, $vr15 + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr0, $vr0, $vr15" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr1, $vr1, $vr15" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr2, $vr2, $vr15" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr3, $vr3, $vr15" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr4, $vr4, $vr15" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr5, $vr5, $vr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr6, $vr6, $vr15" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr7, $vr7, $vr15" + +.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: + +// +// Store the output block in the output buffer. 
+// + + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr0, $a4,0" + EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr1, $a4, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr2, $a4, $t6" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr3, $a4, $s0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr4, $t7, 0" + EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr5, $t7, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr6, $t7, $t6" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" + EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr7, $t7, $s0" + add_immed $a4, \OutputCount\()*8*4 # advance output by N nchw8c blocks + jr $ra + + .endm + + .irp FilterCount, 1, 2, 3, 4 + .irp OutputCount, 1 + PostProcessBlock \FilterCount\(), \OutputCount\() + .endr + .endr + + .end diff --git a/third_party/mlas/lib/loongarch64/SconvKernelLsxCommon.h b/third_party/mlas/lib/loongarch64/SconvKernelLsxCommon.h new file mode 100644 index 0000000000..d03714f654 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SconvKernelLsxCommon.h @@ -0,0 +1,669 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SconvKernelLsxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision convolution operation for the Lsx kernels. + +--*/ + +#define SP_SIZE 32*8 + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#define Filter_save_offset 18*8 + +#define OutputStride_arg 6*8 +#define KernelHeight_arg 7*8 +#define KernelWidth_arg 8*8 +#define InputBase_arg 9*8 +#define InputWidth_arg 10*8 +#define DilatedInputWidth_arg 11*8 +#define OutputCountLeftPad_arg 12*8 +#define OutputCount_arg 13*8 +#define OutputCountRightPad_arg 14*8 +#define Bias_arg 15*8 +#define Flags_arg 16*8 +#define InputChannels_arg 17*8 + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a1 - Supplies the FilterStride parameter (see function description) when + KernelType!=Depthwise. Supplies the address of the filter buffer when + KernelType=Depthwise. + + s8 - Supplies the DilationWidth parameter (see function description). + + a4 - Supplies the address of the output buffer. + + a5 - Supplies the StrideWidth parameter (see function description). + + s3 - Supplies the InputStride parameter (see function description). 
+--*/ + + .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount + move $a3, $a0 +.ifeqs "\KernelType\()","Depthwise" + move $a2, $a1 +.else + ld.d $a2, $sp, Filter_save_offset +.endif + ld.d $t1, $sp, KernelHeight_arg //KernelHeight + ld.d $t2, $sp, KernelWidth_arg //KernelWidth +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg //InputBase + ld.d $t4, $sp, InputWidth_arg //InputWidth + sub.d $t3, $zero, $t3 # keep negative for lea usage below +.endif + ClearBlock \FilterCount\(), \OutputCount\() + beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing + +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: + move $t6, $t2 # reload kernel width remaining +.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 + bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding +.endif +.if \OutputCount\() > 3 + li.d $s2, 2 + mul.d $s2, $a5, $s2 + add.d $t4, $a5, $s2 + + add.d $t4, $t4, $a3 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + li.d $s2, 2 + mul.d $s2, $s2, $a1 + add.d $t7, $a2, $s2 //t6 is rbx used by ComputeBlock +.endif +.ifeqs "\KernelType\()","Nchwc" +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif +.else + ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 +.endif +.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $t8 # advance input by dilation width +.ifeqs "\KernelType\()","Nchwc" + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 + # advance filter by 8i8o/16i16o block +.else + addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block +.endif + addi.d $t6, $t6, -1 # decrement columns remaining + bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t5 +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg #DilatedInputWidth + sub.d $t3, $t3, $s0 + # advance input base to next row +.endif + addi.d $t1, $t1, -1 # decrement rows remaining + bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow + +// +// Handle post processing of the output block. +// +.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: + ld.w $a2, $sp, Flags_arg + +.if \FilterCount\() > 1 + ld.d $t6, $sp, OutputStride_arg +.endif + ld.d $a3, $sp, Bias_arg + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() +.endm +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel. + +Arguments: + + KernelType - Supplies the type of kernel to be generated. + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + + BiasFilter - Supplies a non-blank value if the address of the filter buffer + should be biased to point to the middle of a OIhw8i8o block in order to + reduce the code size from relative byte offsets. + +--*/ + + .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. 
+ +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a4) - Supplies the length in bytes of the blocked dilation + width. + + FilterCount (a5) - Supplies the number of filters to process in this + iteration. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input row. + + FilterStride (a7)- Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp,8*0) - Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + KernelHeight (sp,8*1)- Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (sp, 8*2)- Supplies the width of the kernel to apply. + + InputBase (sp, 8*3)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp, 8*4)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp, 8*5)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp, 8*6)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp, 8*7)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp, 8*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp, 8*9)- Supplies the address of the bias buffer. + + Flags (sp, 8*10)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
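+
+Remarks:
+
+    The prologue below copies the stack-passed arguments into a fixed local
+    frame whose slots are the *_arg offsets defined at the top of this
+    header. Viewed as a hypothetical C struct (the type and field names are
+    illustrative only):
+
+        #include <stdint.h>
+
+        struct sconv_lsx_frame {             // SP_SIZE = 32*8 bytes
+            uint64_t saved_s[5];             //  0*8 .. 4*8  callee-saved s0-s4
+            uint64_t saved_ra;               //  5*8
+            uint64_t output_stride;          //  6*8  OutputStride_arg
+            uint64_t kernel_height;          //  7*8  KernelHeight_arg
+            uint64_t kernel_width;           //  8*8  KernelWidth_arg
+            uint64_t input_base;             //  9*8  InputBase_arg
+            uint64_t input_width;            // 10*8  InputWidth_arg
+            uint64_t dilated_input_width;    // 11*8  DilatedInputWidth_arg
+            uint64_t output_count_left_pad;  // 12*8  OutputCountLeftPad_arg
+            uint64_t output_count;           // 13*8  OutputCount_arg
+            uint64_t output_count_right_pad; // 14*8  OutputCountRightPad_arg
+            uint64_t bias;                   // 15*8  Bias_arg
+            uint64_t flags;                  // 16*8  Flags_arg
+            uint64_t input_channels;         // 17*8  InputChannels_arg
+            uint64_t filter_save;            // 18*8  Filter_save_offset
+            uint64_t unused[13];             // 19*8 .. 31*8
+        };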
+ +--*/ + + FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, OutputStride_arg + st.d $s1, $sp, KernelHeight_arg + st.d $s2, $sp, KernelWidth_arg + st.d $s3, $sp, InputBase_arg + ld.d $s0, $sp, SP_SIZE+4*8 + ld.d $s1, $sp, SP_SIZE+5*8 + ld.d $s2, $sp, SP_SIZE+6*8 + ld.d $s3, $sp, SP_SIZE+7*8 + st.d $s0, $sp, InputWidth_arg + st.d $s1, $sp, DilatedInputWidth_arg + st.d $s2, $sp, OutputCountLeftPad_arg + st.d $s3, $sp, OutputCount_arg + ld.d $s0, $sp, SP_SIZE+8*8 + ld.d $s1, $sp, SP_SIZE+9*8 + ld.d $s2, $sp, SP_SIZE+10*8 + st.d $s0, $sp, OutputCountRightPad_arg + st.d $s1, $sp, Bias_arg + st.d $s2, $sp, Flags_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $a1, $a1,4*8*4 +.endif + st.d $a1, $sp, Filter_save_offset //store Filter + move $a1, $a7 + move $t5, $a6 + move $t8, $a4 # shuffle to Win64 register usage + move $t1, $a5 + move $a4, $a2 + move $a5, $a3 + + li.d $s0, 3 + beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 + blt $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 4 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount3: + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 3 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCountLessThan3: + li.d $s0,2 + blt $t1, $s0, .L\KernelType\().ProcessFilterCount1 + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 2 + b .L\KernelType\().ExitKernel + +.L\KernelType\().ProcessFilterCount1: + ProcessFilterCountN SconvKernelFrame, \KernelType\(), 1 + +// +// Restore non-volatile registers and return. +// + +.L\KernelType\().ExitKernel: + ld.d $a1, $sp, Filter_save_offset //restore Filter + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + + addi.d $sp, $sp, SP_SIZE + jr $ra +.endm + +/*++ + +Macro Description: + + This macro generates code for the inner convolution kernel for the special + case of a depthwise separable convolution. + +Arguments: + + BlockSize - Supplies the number of elements per block. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SconvKernelDepthwiseFunction BlockSize, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Depthwise separable convolutions are a form of grouped convolution where + the number of input and output channels per group are one. + +Arguments: + + Input a0 - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Filter a1 - Supplies the address of the filter buffer. + + Output a2 - Supplies the address of the output buffer. + + StrideWidth a3 - Supplies the length in bytes of the blocked stride width. + + DilationWidth a4 - Supplies the length in bytes of the blocked dilation + width. + + InputStride a5 - Supplies the length in bytes to advance the input buffer + to the next input row. + + KernelHeight a6 - Supplies the height of the kernel to apply. 
This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth a7- Supplies the width of the kernel to apply. + + InputBase (sp, 0*8)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp, 1*8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp, 2*8)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp, 3*8)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp, 4*8)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp, 5*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + + Bias (sp, 6*8)- Supplies the address of the bias buffer. + + Flags (sp, 7*8)- Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + + st.d $a6, $sp, KernelHeight_arg + st.d $a7, $sp, KernelWidth_arg + + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, InputBase_arg + st.d $s1, $sp, InputWidth_arg + st.d $s2, $sp, DilatedInputWidth_arg + st.d $s3, $sp, OutputCountLeftPad_arg + ld.d $s0, $sp, SP_SIZE+4*8 + ld.d $s1, $sp, SP_SIZE+5*8 + ld.d $s2, $sp, SP_SIZE+6*8 + ld.d $s3, $sp, SP_SIZE+7*8 + st.d $s0, $sp, OutputCount_arg + st.d $s1, $sp, OutputCountRightPad_arg + st.d $s2, $sp, Bias_arg + st.d $s3, $sp, Flags_arg +// +// Process the specified number of filter rows. +// + move $t8, $a4 // shuffle to Win64 register usage + move $t5, $a5 + move $a4, $a2 + move $a5, $a3 + ProcessFilterCountN SconvKernelDepthwiseFrame, Depthwise, 1 + +// +// Restore non-volatile registers and return. + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE +// + jr $ra +.endm + +/*++ + +Macro Description: + + This macro generates code to compute the convolution for a vector of input + blocks and a vector of filter blocks to produce a matrix of output blocks + for a pointwise convolution. + +Arguments: + + Isa - Supplies the instruction set architecture string for function tags. + + BlockSize - Supplies the number of elements per block. + + FilterCount - Supplies the number of rows from the filter to process. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + (a0) - Supplies the address of the input buffer. + + (a1) - Supplies the FilterStride parameter (see function description). + + (s8) - Supplies the InputStride parameter (see function description). + + (a4) - Supplies the address of the output buffer. + + (a5) - Supplies the StrideWidth parameter (see function description). + + (s5) - Supplies the address of the filter buffer. 
+ +--*/ + + .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount + + move $a3, $a0 + move $a2, $t2 + ld.d $t1, $sp, InputChannels_arg + ClearBlock \FilterCount\(), \OutputCount\() + +.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: +.if \OutputCount\() > 3 + li.d $s0, 2 + mul $s0, $s0, $a5 + add.d $t4, $a5, $s0 + add.d $t4, $t4, $a3 # compute input plus 3 blocks +.endif +.if \FilterCount\() > 2 + li.d $s0, 2 # compute filter plus 2 rows + mul.d $s0, $s0, $a1 + add.d $t7, $a2, $s0 +.endif + +.if \BlockSize\() == 16 + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 + .endr +.else + .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 + ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 + .endr +.endif + add.d $a3, $a3, $t8 # advance input to next channel block + addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 + # advance filter by 8i8o/16i16o block + addi.d $t1, $t1, -1 //InputChannels decrement input blocks remaining + bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock + +// +// Handle post processing of the output block. +// + ld.w $a2, $sp, Flags_arg #load flag +.if \FilterCount\() > 1 + ld.d $t6 ,$sp, OutputStride_arg #load .LSconvKernelPointwiseFrame_OutputStride +.endif + ld.d $a3, $sp, Bias_arg # load .LSconvKernelPointwiseFrame_Bias + bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() +.endm + + .macro SconvKernelPointwiseFunction Isa, BiasFilter + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + + Pointwise convolutions have a kernel size of one. To simplify this + implementation, no input padding is allowed, which matches typical usage in + models. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + Filter (a1) - Supplies the address of the filter buffer. + + Output (a2) - Supplies the address of the output buffer. + + StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. + + InputChannels (a4) - Supplies the number of input channels to process. + + FilterCount (a5) - Supplies the number of rows from the filter to process. + + InputStride (a6) - Supplies the length in bytes to advance the input buffer to + the next input channel of the same input row. + + FilterStride (a7) - Supplies the length in bytes to advance the filter buffer + to the next set of filters. + + OutputStride (sp+0) - Supplies the length in bytes to advance the output buffer + to the next output address associated with the next set of filters. + + OutputCount (sp+8) - Supplies the number of output elements. + + Bias (sp+16) - Supplies the address of the bias buffer. + + Flags (sp+24) - Supplies additional flags controlling the convolution operation, + especially post calculation options. + +Return Value: + + None. 
+ +--*/ + + FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + + ld.d $s0, $sp, SP_SIZE+0*8 + ld.d $s1, $sp, SP_SIZE+1*8 + ld.d $s2, $sp, SP_SIZE+2*8 + ld.d $s3, $sp, SP_SIZE+3*8 + st.d $s0, $sp, OutputStride_arg + st.d $s1, $sp, OutputCount_arg + st.d $s2, $sp, Bias_arg + st.d $s3, $sp, Flags_arg + st.d $a4, $sp, InputChannels_arg + +.ifeqs "\BiasFilter\()","BiasFilter" + addi.d $t2, $a1, 4*8*4 +.else + move $t2, $a1 +.endif + + ld.d $t0, $sp, OutputCount_arg //OutputCount + move $a1, $a7 // FilterStride + move $t8, $a6 // InputStride + move $t1, $a5 // shuffle to Win64 register usage + move $a4, $a2 + move $a5, $a3 + +// +// Process the specified number of filter rows. +// + li.d $s0, 3 + beq $t1, $s0, .LPointwise.ProcessFilterCount3 + blt $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 + ProcessPointwiseFilterCountN 4 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount3: + ProcessPointwiseFilterCountN 3 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCountLessThan3: + li.d $s0, 2 + blt $t1, $s0, .LPointwise.ProcessFilterCount1 + ProcessPointwiseFilterCountN 2 + b .LPointwise.ExitKernel + +.LPointwise.ProcessFilterCount1: + ProcessPointwiseFilterCountN 1 + +// +// Restore non-volatile registers and return. +// +.LPointwise.ExitKernel: + + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra +.endm diff --git a/third_party/mlas/lib/loongarch64/SgemmKernelCommon.h b/third_party/mlas/lib/loongarch64/SgemmKernelCommon.h new file mode 100644 index 0000000000..93b109c90a --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SgemmKernelCommon.h @@ -0,0 +1,35 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision matrix/matrix multiply operation (SGEMM). + +--*/ + +// +// Define the single precision parameters. +// + +#define LFgemmElementShift 2 +#define LFgemmElementSize (1 << LFgemmElementShift) +#define LFgemmYmmElementCount (32/LFgemmElementSize) + +#include "FgemmKernelCommon.h" + +// +// Define the typed instructions for single precision. +// + +FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.s) +FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.s) +FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.w) +FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.s) diff --git a/third_party/mlas/lib/loongarch64/SgemmKernelLasx.S b/third_party/mlas/lib/loongarch64/SgemmKernelLasx.S new file mode 100644 index 0000000000..d537742016 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SgemmKernelLasx.S @@ -0,0 +1,33 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + + This implementation uses LASX instructions. + +--*/ + +#include "asmmacro.h" +#include "SgemmKernelCommon.h" +#include "FgemmKernelLasxCommon.h" + + + .text + +// +// Generate the GEMM kernel. 
+// + +FgemmKernelLasxFunction MlasGemmFloatKernelLasx + + .end diff --git a/third_party/mlas/lib/loongarch64/SgemmKernelLsx.S b/third_party/mlas/lib/loongarch64/SgemmKernelLsx.S new file mode 100644 index 0000000000..86b5ef8b51 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SgemmKernelLsx.S @@ -0,0 +1,267 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelLsx.s + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" +#include "FgemmKernelLsxCommon.h" + +FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.s) + +/*++ + +Macro Description: + + This macro multiplies and accumulates for a 16xN block of the output matrix. + +Arguments: + + RowCount - Supplies the number of rows to process. + + VectorOffset - Supplies the byte offset from matrix B to fetch elements. + + Shuffle - Supplies the shuffle mask to extract the element from matrix A. + +Implicit Arguments: + + a1 - Supplies the address into the matrix B data. + + vr0-vr1 - Supplies up to four elements loaded from matrix A and matrix A + plus one row. + + vr8-vr15 - Supplies the block accumulators. + +--*/ + + .macro ComputeBlockSseBy16 RowCount, VectorOffset, Shuffle + vld $vr4, $a1, \VectorOffset + vld $vr5, $a1, \VectorOffset + 16 + vreplvei.w $vr2, $vr0, \Shuffle +.if \RowCount\() == 2 + vreplvei.w $vr3, $vr1, \Shuffle + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.s $vr8, $vr4, $vr2, $vr8 + vfmadd.s $vr9, $vr5, $vr2, $vr9 +.if \RowCount\() == 2 + vfmadd.s $vr12, $vr6, $vr3, $vr12 + vfmadd.s $vr13, $vr7, $vr3, $vr13 +.endif + vld $vr4, $a1, \VectorOffset + 32 + vld $vr5, $a1, \VectorOffset + 48 +.if \RowCount\() == 2 + vmove $vr6, $vr4 + vmove $vr7, $vr5 +.endif + vfmadd.s $vr10, $vr4, $vr2, $vr10 + vfmadd.s $vr11, $vr5, $vr2, $vr11 +.if \RowCount\() == 2 + vfmadd.s $vr14, $vr6, $vr3, $vr14 + vfmadd.s $vr15, $vr7, $vr3, $vr15 +.endif + .endm + + +/*++ + +Macro Description: + + This macro generates code to compute matrix multiplication for a fixed set + of rows. + +Arguments: + + RowCount - Supplies the number of rows to process. + + Fallthrough - Supplies a non-blank value if the macro may fall through to + the ExitKernel label. + +Implicit Arguments: + + a0 - Supplies the address of matrix A. + + a1 - Supplies the address of matrix B. + + t8 - Supplies the address of matrix A. + + a5 - Supplies the number of columns from matrix B and matrix C to iterate + over. + + a2 - Supplies the address of matrix C. + + a3 - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + t7 - Supplies the length in bytes of a row from matrix A. + + t5 - Supplies the length in bytes of a row from matrix C. + + s3 - Stores the ZeroMode argument from the stack frame. 
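+
+    For reference, the strip computed here corresponds to the scalar C++ sketch
+    below (illustrative only; hypothetical names, element strides rather than
+    the byte strides listed above, and no 16-wide vector blocking):
+
+        #include <cstddef>
+
+        // C[r][n] = alpha * sum_k A[r][k] * B[k][n], accumulating into the
+        // existing C values unless ZeroMode is set.
+        void SgemmStripSketch(const float* A, const float* B, float* C,
+                              size_t RowCount, size_t CountK, size_t CountN,
+                              size_t lda, size_t ldb, size_t ldc,
+                              float alpha, bool ZeroMode)
+        {
+            for (size_t r = 0; r < RowCount; r++) {
+                for (size_t n = 0; n < CountN; n++) {
+                    float acc = 0.0f;
+                    for (size_t k = 0; k < CountK; k++) {
+                        acc += A[r * lda + k] * B[k * ldb + n];
+                    }
+                    acc *= alpha;
+                    C[r * ldc + n] = ZeroMode ? acc : C[r * ldc + n] + acc;
+                }
+            }
+        }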
+ +--*/ + + .macro ProcessCountM RowCount, Fallthrough +.LProcessNextColumnLoop16xN\@: + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8, $vr8,$vr8" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9, $vr9,$vr9" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10, $vr10,$vr10" + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11, $vr11,$vr11" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12, $vr12,$vr12" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13, $vr13,$vr13" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14, $vr14,$vr14" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15, $vr15,$vr15" + move $t8, $a3 + li.d $s0, 4 + blt $t8, $s0, .LProcessRemaining16xNBlocks\@ +.LCompute16xNBlockBy4Loop\@: + EmitIfCountGE \RowCount\(), 1, "vld $vr0, $a0, 0" + EmitIfCountGE \RowCount\(), 2, "vldx $vr1, $a0, $t0" #second line of A + ComputeBlockSseBy16 2, 0, 0x0 + ComputeBlockSseBy16 2, 16*4, 0x1 + addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns + ComputeBlockSseBy16 2, 0, 0x2 + ComputeBlockSseBy16 2, 16*4, 0x3 + addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns + addi.d $a0, $a0, 4*4 # advance matrix A by 4 columns + addi.d $t8, $t8, -4 + li.d $s0, 4 #check matrix A remaining less than 4 + bge $t8, $s0, .LCompute16xNBlockBy4Loop\@ + +.LProcessRemaining16xNBlocks\@: + beqz $t8, .LOutput16xNBlock\@ + +.LCompute16xNBlockBy1Loop\@: + EmitIfCountGE \RowCount\(), 1, "ld.w $s0, $a0, 0" + EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.w $vr0, $s0, 0" + EmitIfCountGE \RowCount\(), 2, "ldx.w $s0,$a0, $t0" + EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.w $vr1,$s0, 0" + ComputeBlockSseBy16 2, 0, 0x00 + addi.d $a1, $a1, 16*4 #advance matrix B by 16 columns + addi.d $a0, $a0, 1*4 #advance matrix A by 1 column + addi.d $t8, $t8, -1 + bnez $t8, .LCompute16xNBlockBy1Loop\@ + +.LOutput16xNBlock\@: + movfr2gr.s $s0, $f24 + vreplgr2vr.w $vr2, $s0 + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr8,$vr8,$vr2" + # multiply by alpha + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr9,$vr9,$vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr10,$vr10,$vr2" + EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr11,$vr11,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr12,$vr12,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr13,$vr13,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr14,$vr14,$vr2" + EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr15,$vr15,$vr2" + li.d $s0, 16 + blt $a5, $s0, .LOutputPartial16xNBlock\@ + sub.d $a5, $a5, $s0 + AccumulateAndStoreBlock \RowCount\(), 4 + addi.d $a2, $a2, 16*4 # advance matrix C by 16 columns + move $a0, $t1 # reload matrix A + bnez $a5, .LProcessNextColumnLoop16xN\@ + b .LExitKernel + +// +// Output a partial 16xN block to the matrix. 
+// + +.LOutputPartial16xNBlock\@: + li.d $s0, 4 + blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ + li.d $s0, 8 + blt $a5, $s0, .LOutputPartialLessThan8xNBlock\@ + li.d $s0, 12 + blt $a5, $s0, .LOutputPartialLessThan12xNBlock\@ + AccumulateAndStoreBlock \RowCount\(), 3 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr11" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr15" + addi.d $a2, $a2,12*4 # advance matrix C by 12 columns + b .LOutputPartialLessThan4xNBlock\@ + +.LOutputPartialLessThan12xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 2 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr10" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr14" + addi.d $a2, $a2,8*4 # advance matrix C by 8 columns + b .LOutputPartialLessThan4xNBlock\@ + +.LOutputPartialLessThan8xNBlock\@: + AccumulateAndStoreBlock \RowCount\(), 1 + andi $a5, $a5, 3 + beqz $a5, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr9" + # shift remaining elements down + EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr13" + addi.d $a2, $a2, 4*4 # advance matrix C by 4 columns + +.LOutputPartialLessThan4xNBlock\@: + andi $s0, $a5, 2 + beqz $s0, .LOutputPartial1xNBlock\@ + and $s0, $t5, $t5 # ZeroMode? + bnez $s0, .LSkipAccumulateOutput2xN\@ + EmitIfCountGE \RowCount\(), 1, "vxor.v $vr0, $vr0, $vr0" + EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.d $vr0, $s0, 0" + EmitIfCountGE \RowCount\(), 2, "vxor.v $vr1, $vr1, $vr1" + EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.d $vr1, $s0, 0" + EmitIfCountGE \RowCount\(), 1, "vfadd.s $vr8, $vr8, $vr0" + EmitIfCountGE \RowCount\(), 2, "vfadd.s $vr12, $vr12, $vr1" + +.LSkipAccumulateOutput2xN\@: + EmitIfCountGE \RowCount\(), 1, "vstelm.d $vr8, $a2, 0, 0" + EmitIfCountGE \RowCount\(), 2, "vpickve2gr.d $s0, $vr12, 0" + EmitIfCountGE \RowCount\(), 2, "stx.d $s0, $a2, $t6" + andi $s0, $a5, 1 + beqz $s0, .LExitKernel + EmitIfCountGE \RowCount\(), 1, "vpermi.w $vr8, $vr8, 0xee" + # shift third element down + EmitIfCountGE \RowCount\(), 2, "vpermi.w $vr12, $vr12, 0xee" + addi.d $a2, $a2, 2*4 # advance matrix C by 2 columns + +.LOutputPartial1xNBlock\@: + and $s0, $t5, $t5 # ZeroMode? + bnez $s0, .LSkipAccumulateOutput1xN\@ + + EmitIfCountGE \RowCount\(), 1, "fld.s $f16, $a2, 0" + EmitIfCountGE \RowCount\(), 1, "fadd.s $f8, $f16, $f8" + EmitIfCountGE \RowCount\(), 2, "fldx.s $f17, $a2, $t6" + EmitIfCountGE \RowCount\(), 2, "fadd.s $f12, $f12, $f17" + +.LSkipAccumulateOutput1xN\@: + EmitIfCountGE \RowCount\(), 1, "fst.s $f8, $a2, 0" + EmitIfCountGE \RowCount\(), 2, "fstx.s $f12, $a2, $t6" +.ifb \Fallthrough\() + b .LExitKernel +.endif + .endm + +// +// Generate the GEMM kernel. +// + +FgemmKernelLsxFunction MlasGemmFloatKernelLSX + + .end diff --git a/third_party/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S b/third_party/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S new file mode 100644 index 0000000000..cd1747745d --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S @@ -0,0 +1,89 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmTransposePackB16x4LSX.s + +Abstract: + + This module implements routines for packing buffers for the single precision + matrix/matrix multiply operation (SGEMM). 
+ + This implementation uses Lsx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Routine Description: + + This routine transposes elements from the source matrix to the destination + packed buffer. + + 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + rows in the destination packed buffer. + +Arguments: + + D (a0) - Supplies the address of the destination packed buffer. + + B (a1) - Supplies the address of the source matrix. + + ldb (a2) - Supplies the number of elements per row of the source matrix. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasSgemmTransposePackB16x4LSX + addi.d $sp, $sp, -64 + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + slli.d $a2, $a2, 2 # convert ldb to bytes + ori $a3, $zero, 4 # transpose four 4x4 blocks + vxor.v $vr7, $vr7, $vr7 +.LTransposeBlockLoop: + slli.d $s0, $a2, 1 + add.d $s1, $a1, $s0 + vld $vr0, $a1, 0 + vldx $vr1, $a1, $a2 + vld $vr2, $s1, 0 + vldx $vr3, $s1, $a2 + + vor.v $vr4, $vr0, $vr7 + vilvl.w $vr4, $vr1, $vr4 + vilvh.w $vr0, $vr1, $vr0 + vor.v $vr5, $vr2, $vr7 + vilvl.w $vr5, $vr3, $vr5 + vilvh.w $vr2, $vr3, $vr2 + vor.v $vr1, $vr4, $vr7 + vilvl.d $vr1, $vr5, $vr1 + vilvh.d $vr4, $vr5, $vr4 + vor.v $vr3, $vr0, $vr7 + vilvl.d $vr3, $vr2, $vr3 + vilvh.d $vr0, $vr2, $vr0 + vst $vr1, $a0, 0 + vst $vr4, $a0, 0x40 + vst $vr3, $a0, 0x80 + vst $vr0, $a0, 0xc0 + addi.d $a0, $a0, 0x10 + slli.d $s0, $a2, 1 + add.d $a1, $s0, $s1 + addi.d $a3, $a3, -1 + bnez $a3, .LTransposeBlockLoop + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + addi.d $sp, $sp, 64 + jr $ra + + .end diff --git a/third_party/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S b/third_party/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S new file mode 100644 index 0000000000..e617419989 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S @@ -0,0 +1,126 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmTransposePackB16x4Lasx.s + +Abstract: + + This module implements routines for packing buffers for the single precision + matrix/matrix multiply operation (SGEMM). + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Macro Description: + + 4 columns of 8 rows from the source matrix are transposed to 8 columns of 4 + rows in the destination packed buffer. + +Arguments: + + StoreOffset - Supplies the relative byte offset into the destination packed + buffer. + +Implicit Arguments: + + a0 - Supplies the address of the destination packed buffer. + + a1 - Supplies the address of the source matrix. + + a2 - Supplies the number of elements per row of the source matrix. + +--*/ + + .macro TransposePackB8x4BlockLasx StoreOffset + +// +// Load 4 columns from 8 rows of the source matrix into the lower and upper +// halves of 4 XR registers. +// + + add.d $t0, $a2, $a2 + add.d $t6, $a1, $t0 + vld $vr0, $a1, 0 + vldx $vr1, $a1, $a2 + add.d $t0, $a2, $a2 + add.d $a1, $t6, $t0 + vld $vr2, $t6, 0 + vldx $vr3, $t6, $a2 + add.d $t0, $a2, $a2 + add.d $t6, $a1, $t0 + + vld $vr4, $a1, 0 + xvpermi.q $xr0, $xr4, 0x2 + vldx $vr5, $a1, $a2 + xvpermi.q $xr1, $xr5, 0x2 + vld $vr4, $t6, 0 + xvpermi.q $xr2, $xr4, 0x2 + vldx $vr5, $t6, $a2 + xvpermi.q $xr3, $xr5, 0x2 + +// +// Transpose the lower and upper halves of the 4 XR registers as two 4x4 +// matrices and store the output to the destination packed buffer. 
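+//
+// Conceptually, with rows r0 = {a0 a1 a2 a3}, r1 = {b0 b1 b2 b3},
+// r2 = {c0 c1 c2 c3} and r3 = {d0 d1 d2 d3}, interleaving 32-bit lanes of the
+// row pairs produces {a0 b0 a1 b1}, {a2 b2 a3 b3}, {c0 d0 c1 d1} and
+// {c2 d2 c3 d3}; interleaving 64-bit lanes of those results then yields the
+// columns {a0 b0 c0 d0} through {a3 b3 c3 d3}, which is the 4x4 transpose
+// stored below. The same shuffles operate on the upper 128-bit half of each
+// XR register for the second group of four source rows.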
+// + + xvilvl.w $xr4, $xr1, $xr0 + xvilvh.w $xr5, $xr1, $xr0 + xvilvl.w $xr0, $xr3, $xr2 + xvilvh.w $xr1, $xr3, $xr2 + xvilvl.d $xr2, $xr0, $xr4 + xvilvh.d $xr3, $xr0, $xr4 + xvst $xr2, $a0, \StoreOffset\() + xvst $xr3, $a0, 0x40+\StoreOffset\() + xvilvl.d $xr0, $xr1, $xr5 + xvilvh.d $xr4, $xr1, $xr5 + xvst $xr0, $a0, 0x80+\StoreOffset\() + xvst $xr4, $a0, 0xc0+\StoreOffset\() + + .endm + +/*++ + +Routine Description: + + This routine transposes elements from the source matrix to the destination + packed buffer. + + 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 + rows in the destination packed buffer. + +Arguments: + + D (a0) - Supplies the address of the destination packed buffer. + + B (a1) - Supplies the address of the source matrix. + + ldb (a2) - Supplies the number of elements per row of the source matrix. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasSgemmTransposePackB16x4Lasx + + slli.d $a2, $a2, 2 # convert ldb to bytes + TransposePackB8x4BlockLasx 0*4 + add.d $t0, $a2, $a2 + add.d $a1, $t0, $t6 + TransposePackB8x4BlockLasx 8*4 + jr $ra + + .end diff --git a/third_party/mlas/lib/loongarch64/SoftmaxKernelLasx.S b/third_party/mlas/lib/loongarch64/SoftmaxKernelLasx.S new file mode 100644 index 0000000000..aaaa3cbf91 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SoftmaxKernelLasx.S @@ -0,0 +1,357 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SoftmaxKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision softmax + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" + + .text + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to find the maximum value of + the supplied buffer. + +Arguments: + + Input (a0) - Supplies the input buffer. + + N (a1) - Supplies the number of elements to process. + +Return Value: + + Returns the maximum value of the supplied buffer. 
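+
+    For reference, a scalar model of this kernel (illustrative only, the name
+    is hypothetical):
+
+        #include <cfloat>
+        #include <cstddef>
+
+        float ReduceMaximumSketch(const float* Input, size_t N)
+        {
+            float Maximum = -FLT_MAX;   // same seed as the broadcast MlasMinimumF32Value
+            for (size_t i = 0; i < N; i++) {
+                if (Input[i] > Maximum) {
+                    Maximum = Input[i];
+                }
+            }
+            return Maximum;
+        }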
+ +--*/ + + FUNCTION_ENTRY MlasReduceMaximumF32KernelLasx + addi.d $sp, $sp, -32 + + la.global $t0, MlasMinimumF32Value + ld.w $t0, $t0, 0 + xvreplgr2vr.w $xr0, $t0 + beqz $a1, .LReduceMaximum.ExitKernel + ori $t0, $zero, 8 + bltu $a1, $t0, .LReduceMaximum.ProcessRemainingCountBy1 + ori $t1, $zero, 32 + bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy8 + xvreplgr2vr.w $xr16, $zero + xvor.v $xr1, $xr0, $xr16 + xvor.v $xr2, $xr0, $xr16 + xvor.v $xr3, $xr0, $xr16 + +.LReduceMaximum.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfmax.s $xr0, $xr0, $xr16 + xvld $xr16, $a0, 8*4 + xvfmax.s $xr1, $xr1, $xr16 + addi.d $a1, $a1, -0x20 + xvld $xr16, $a0, 16*4 + xvfmax.s $xr2, $xr2, $xr16 + xvld $xr16, $a0, 24*4 + xvfmax.s $xr3, $xr3, $xr16 + addi.d $a0, $a0, 32*4 # advance input by 32 elements + ori $t1, $zero, 32 + bgeu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy32 + xvfmax.s $xr0, $xr0, $xr1 + xvfmax.s $xr2, $xr2, $xr3 + xvfmax.s $xr0, $xr0, $xr2 + +.LReduceMaximum.ProcessRemainingCountBy8: + ori $t1, $zero, 8 + bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfmax.s $xr0, $xr0, $xr16 + addi.d $a1, $a1, -8 + addi.d $a0, $a0, 8*4 + b .LReduceMaximum.ProcessRemainingCountBy8 + +.LReduceMaximum.ProcessRemainingCountLessThan8: + xvst $xr0, $sp, 0 + vld $vr1, $sp, 0x10 + vld $vr0, $sp, 0 + vfmax.s $vr0, $vr0, $vr1 + vshuf4i.w $vr1, $vr0, 0xee + vfmax.s $vr0, $vr0, $vr1 + vshuf4i.w $vr1, $vr0, 0x55 + vfmax.s $vr0, $vr0, $vr1 + beqz $a1, .LReduceMaximum.ExitKernel + +.LReduceMaximum.ProcessRemainingCountBy1: + vld $vr16, $a0, 0 + vfmax.s $vr0, $vr0, $vr16 + addi.d $a0, $a0, 4 # advance input by 1 element + addi.d $a1, $a1, -1 + bnez $a1, .LReduceMaximum.ProcessRemainingCountBy1 + +.LReduceMaximum.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + addi.d $sp, $sp, 32 + jr $ra + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to produce the final output for + the softmax operation. + +Arguments: + + Output (a0) - Supplies the output buffer. + + N (a1) - Supplies the number of elements to process. + + Parameters (a2) - Supplies an array containing the scale value. + +Return Value: + + None. 
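+
+    The scale value is typically the reciprocal of the accumulated exponential
+    sum, so the kernel behaves like the scalar sketch below (illustrative only,
+    hypothetical name):
+
+        #include <cstddef>
+
+        void SoftmaxOutputSketch(float* Output, size_t N, const float* Parameters)
+        {
+            const float Scale = Parameters[0];
+            for (size_t i = 0; i < N; i++) {
+                Output[i] *= Scale;
+            }
+        }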
+ +--*/ + + FUNCTION_ENTRY MlasComputeSoftmaxOutputF32KernelLasx + + ld.w $t0, $a2, 0 + xvreplgr2vr.w $xr4, $t0 + ori $t1, $zero, 0x20 + bltu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeSoftmaxOutput.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfmul.s $xr0, $xr4, $xr16 + xvld $xr16, $a0, 8*4 + xvfmul.s $xr1, $xr4, $xr16 + addi.d $a1, $a1, -0x20 + xvld $xr16, $a0, 16*4 + xvfmul.s $xr2, $xr4, $xr16 + xvld $xr16, $a0, 24*4 + xvfmul.s $xr3, $xr4, $xr16 + xvst $xr0, $a0, 0 + xvst $xr1, $a0, 8*4 + xvst $xr2, $a0, 16*4 + xvst $xr3, $a0, 24*4 + addi.d $a0, $a0, 0x80 # advance output by 32 elements + bgeu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy32 + +.LComputeSoftmaxOutput.ProcessRemainingCountBy8: + ori $t2, $zero, 8 + bltu $a1, $t2, .LComputeSoftmaxOutput.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfmul.s $xr0, $xr4, $xr16 + addi.d $a1, $a1, -8 + xvst $xr0, $a0, 0 + addi.d $a0, $a0, 8*4 # advance output by 8 elements + b .LComputeSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeSoftmaxOutput.ProcessRemainingCountLessThan8: + beqz $a1, .LComputeSoftmaxOutput.ExitKernel + +.LComputeSoftmaxOutput.ProcessRemainingCountBy1: + fld.s $f16, $a0, 0 + fmul.s $f0, $f4, $f16 + fst.s $f0, $a0, 0 + addi.d $a0, $a0, 4 # advance output by 1 element + addi.d $a1, $a1, -1 + bnez $a1, .LComputeSoftmaxOutput.ProcessRemainingCountBy1 + +.LComputeSoftmaxOutput.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + jr $ra + +/*++ + +Routine Description: + + This routine implements a vectorized kernel to produce the final output for + the log softmax operation. + +Arguments: + + Input (a0) - Supplies the output buffer. + + Output (a1) - Supplies the output buffer. + + N (a2) - Supplies the number of elements to process. + + Parameters (a3) - Supplies an array containing the negative maximum and + logarithm values. + +Return Value: + + None. 
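+
+    For reference, the scalar equivalent is shown below (illustrative only,
+    hypothetical name). The two separate additions mirror the "two steps for
+    numeric stability" noted in the vector code:
+
+        #include <cstddef>
+
+        void LogSoftmaxOutputSketch(const float* Input, float* Output, size_t N,
+                                    const float* Parameters)
+        {
+            const float NegativeMaximum = Parameters[0];
+            const float LogarithmOfSum = Parameters[1];
+            for (size_t i = 0; i < N; i++) {
+                Output[i] = (Input[i] + NegativeMaximum) - LogarithmOfSum;
+            }
+        }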
+ +--*/ + + FUNCTION_ENTRY MlasComputeLogSoftmaxOutputF32KernelLasx + + ld.w $t0, $a3, 0 + ld.w $t1, $a3, 4 + ori $t2, $zero, 0x20 + xvreplgr2vr.w $xr4, $t0 # broadcast negative minimum value + xvreplgr2vr.w $xr5, $t1 # broadcast log(SumExp) + bltu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy32: + xvld $xr16, $a0, 0 + xvfadd.s $xr0, $xr4, $xr16 + xvld $xr16, $a0, 0x20 + xvfadd.s $xr1, $xr4, $xr16 + addi.d $a2, $a2, -0x20 + xvld $xr16, $a0, 0x40 + xvfadd.s $xr2, $xr4, $xr16 + xvld $xr16, $a0, 0x60 + xvfadd.s $xr3, $xr4, $xr16 + addi.d $a0, $a0, 0x80 # advance input by 32 elements + xvfsub.s $xr0, $xr0, $xr5 # do as two steps for numeric stability + xvfsub.s $xr1, $xr1, $xr5 # do as two steps for numeric stability + xvfsub.s $xr2, $xr2, $xr5 # do as two steps for numeric stability + xvfsub.s $xr3, $xr3, $xr5 # do as two steps for numeric stability + xvst $xr0, $a1, 0 + xvst $xr1, $a1, 0x20 + xvst $xr2, $a1, 0x40 + xvst $xr3, $a1, 0x60 + addi.d $a1, $a1, 0x80 # advance output by 32 elements + bgeu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy32 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy8: + ori $t3, $zero, 8 + bltu $a2, $t3, .LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8 + xvld $xr16, $a0, 0 + xvfadd.s $xr0, $xr4, $xr16 + addi.d $a0, $a0, 0x20 + xvfsub.s $xr0, $xr0, $xr5 + addi.d $a2, $a2, -8 + xvst $xr0, $a1, 0 + addi.d $a1, $a1, 0x20 # advance output by 8 elements + b .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8: + beqz $a2, .LComputeLogSoftmaxOutput.ExitKernel + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy1: + fld.s $f16, $a0, 0 + fadd.s $f0, $f4, $f16 + + addi.d $a0, $a0, 4 + fsub.s $f0, $f0, $f5 + fst.s $f0, $a1, 0 + + addi.d $a1, $a1, 4 + addi.d $a2, $a2, -1 + bnez $a2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy1 + +.LComputeLogSoftmaxOutput.ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + jr $ra + + .end diff --git a/third_party/mlas/lib/loongarch64/SpoolKernelLSX.S b/third_party/mlas/lib/loongarch64/SpoolKernelLSX.S new file mode 100644 index 0000000000..96bda3bb12 --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SpoolKernelLSX.S @@ -0,0 +1,460 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelLSX.s + +Abstract: + + This module implements the kernels for the single precision pooling + operation. + + This implementation uses LSX instructions. 
+ +--*/ + +#define SP_SIZE 32*8 +#define InputBase_arg SP_SIZE+0*8 +#define InputWidth_arg SP_SIZE+1*8 +#define DilatedInputWidth_arg SP_SIZE+2*8 +#define OutputCountLeftPad_arg SP_SIZE+3*8 +#define OutputCount_arg SP_SIZE+4*8 +#define OutputCountRightPad_arg SP_SIZE+5*8 + + .macro FUNCTION_ENTRY FunctionName + + .p2align 4 + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): + + .endm + + + .text + +/*++ + +Macro Description: + + This macro generates code to initialize registers used across the kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro InitializeKernel PoolingType + +.ifeqs "\PoolingType\()","Maximum" + li.w $s0, 0xFF7FFFFF + vreplgr2vr.w $vr5, $s0 +.endif + +.ifeqs "\PoolingType\()","AverageIncludePad" + vreplgr2vr.w $vr5, $a5 + vffint.s.w $vr5, $vr5 +.endif + + .endm +/*++ + +Macro Description: + + This macro generates the common prologue code for the pooling kernels. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro SpoolKernelEntry PoolingType + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0*8 + st.d $s1, $sp, 1*8 + st.d $s2, $sp, 2*8 + st.d $s3, $sp, 3*8 + st.d $s4, $sp, 4*8 + st.d $ra, $sp, 5*8 + fst.d $f24,$sp, 6*8 + + InitializeKernel \PoolingType\() + # move InputStride to s8 + or $t8, $a4, $r0 + # move StrideWidth to a4 + or $a4, $a2, $r0 + # move DilationWidth to a5 + or $a5, $a3, $r0 + # move Output to a2 + or $a2, $a1, $r0 + + .endm + +/*++ + +Macro Description: + + This macro generates the common epilogue code for the pooling kernels. + +Arguments: + + None. + +--*/ + + .macro SpoolKernelExit + + ld.d $s0, $sp, 0*8 + ld.d $s1, $sp, 1*8 + ld.d $s2, $sp, 2*8 + ld.d $s3, $sp, 3*8 + ld.d $s4, $sp, 4*8 + ld.d $ra, $sp, 5*8 + fld.d $f24,$sp, 6*8 + + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + + +/*++ + +Macro Description: + + This macro generates code to clear the pooling intermediates. + + For PoolingType==Maximum, the pooling intermediates are set to the minimum + float value. Otherwise, the pooling intermediates are cleared to zero. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + vr0-vr1 - Supplies the pooling intermediates. + + vr2 - Supplies a vector containing the minimum float value broadcasted, + if PoolingType==Maximum. + +--*/ + + .macro ClearBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + vor.v $vr0, $vr5, $vr5 + vor.v $vr1, $vr5, $vr5 +.else + vxor.v $vr0, $vr0, $vr0 + vxor.v $vr1, $vr1, $vr1 +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" + xor $a1, $a1, $a1 # reset valid block counter +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to sample the input buffer and update the pooling + intermediates as appropriate. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + a4 - Supplies the StrideWidth parameter (see function description). + + vr0-vr1 - Supplies the pooling intermediates. 
+ +--*/ + + .macro ComputeBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + vld $vr24, $a3, 0 + vfmax.s $vr0, $vr0, $vr24 + vld $vr24, $a3, 16 + vfmax.s $vr1, $vr1, $vr24 +.else + vld $vr24, $a3, 0 + vfadd.s $vr0, $vr0, $vr24 + vld $vr24, $a3, 16 + vfadd.s $vr1, $vr1, $vr24 +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" + # increment valid block counter + addi.d $a1, $a1, 1 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to process and store the pooling intermediates. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a2 - Supplies the address of the output buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + vr0-vr1 - Supplies the pooling intermediates. + + vr5 - Supplies the kernel size computed by InitializeKernel, if + PoolingType=AverageExcludePad, else the actual kernel size, if + PoolingType=AverageIncludePad. + +--*/ + + .macro PostProcessBlock PoolingType, OutputCount + +// +// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding +// blocks. +// + +.ifeqs "\PoolingType\()","AverageExcludePad" + # convert valid block counter + vreplgr2vr.w $vr4, $a1 + vffint.s.w $vr4, $vr4 + vfdiv.s $vr0, $vr0, $vr4 + vfdiv.s $vr1, $vr1, $vr4 +.endif + +// +// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. +// + +.ifeqs "\PoolingType\()","AverageIncludePad" + vfdiv.s $vr0, $vr0, $vr5 + vfdiv.s $vr1, $vr1, $vr5 +.endif + +// +// Store the output block in the output buffer. +// + + vst $vr0, $a2, 0 + vst $vr1, $a2, 16 + # advance output by 1 nchw8c block + addi.d $a2, $a2, 8*4 + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute pooling for a vector of input blocks + to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a2 - Supplies the address of the output buffer. + + a4 - Supplies the StrideWidth parameter (see function description). + + a5 - Supplies the DilationWidth parameter (see function description). + + s8 - Supplies the InputStride parameter (see function description). + +--*/ + + .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount + + move $a3, $a0 + move $t1, $a6 + move $t2, $a7 +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $r0, $t3 # keep negative for lea usage below +.endif + ClearBlock \PoolingType\(), \OutputCount\() + beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing + +.L\PoolingType\().\OutputCount\().ProcessNextRow: + or $t6, $t2, $t2 + +.L\PoolingType\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + # (Input - InputBase) >= InputWidth? 
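+        # Note: $t3 was loaded as the negated InputBase, so the add below forms
+        # (Input - InputBase). The unsigned compare then treats addresses to the
+        # left of InputBase (a negative difference that wraps to a large unsigned
+        # value) and addresses at or beyond InputBase + InputWidth the same way,
+        # i.e. as padding to skip:
+        #   if ((uintptr_t)(Input - InputBase) >= (uintptr_t)InputWidth) skip;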
+ add.d $t7, $a3, $t3 + bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding +.endif + ComputeBlock \PoolingType\(), \OutputCount\() + +.L\PoolingType\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $a5 # advance input by dilation width + # decrement columns remaining + addi.d $t6, $t6, -1 + bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t8 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + # advance input base to next row + sub.d $t3, $t3, $s0 +.endif + addi.d $t1, $t1, -1 + bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow + +.L\PoolingType\().\OutputCount\().HandlePostProcessing: + PostProcessBlock \PoolingType\(), \OutputCount\() + + .endm +/*++ + +Macro Description: + + This macro generates code for the inner pooling kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SpoolKernelFunction PoolingType, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute pooling for the elements of an + output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Output (a1) - Supplies the address of the output buffer. + + StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a3) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a4) - Supplies the length in bytes to advance the input buffer to + the next input row. + + ActualKernelSize (a5) - Supplies the size of the kernel based on the original + kernel dimensions, used for PoolingType=AverageIncludePad. + + KernelHeight (a6) - Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7) - Supplies the width of the kernel to apply. + + InputBase (0)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (1*8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (2*8)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (3*8)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (4*8)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (5*8)- Supplies the number of output elements that include + one or more padding elements from the right edge. + +Return Value: + + None. 
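+
+    For one lane of an nchw8c output block the computation reduces to the
+    scalar sketch below (illustrative only; hypothetical names, and Samples is
+    assumed to hold the values from the kernel positions that fall inside the
+    valid input region):
+
+        #include <algorithm>
+        #include <cfloat>
+        #include <cstddef>
+
+        enum class Pooling { Maximum, AverageExcludePad, AverageIncludePad };
+
+        float PoolOneLaneSketch(const float* Samples, size_t ValidCount,
+                                size_t ActualKernelSize, Pooling Type)
+        {
+            float Value = (Type == Pooling::Maximum) ? -FLT_MAX : 0.0f;
+            for (size_t i = 0; i < ValidCount; i++) {
+                Value = (Type == Pooling::Maximum) ? std::max(Value, Samples[i])
+                                                   : Value + Samples[i];
+            }
+            if (Type == Pooling::AverageExcludePad) {
+                Value /= float(ValidCount);
+            } else if (Type == Pooling::AverageIncludePad) {
+                Value /= float(ActualKernelSize);
+            }
+            return Value;
+        }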
+ +--*/ + + FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() + SpoolKernelEntry \PoolingType\() + + ld.d $s0, $sp, OutputCountLeftPad_arg + ld.d $s1, $sp, OutputCount_arg + add.d $t0, $s0, $s1 + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\PoolingType\().ExitKernel + +.L\PoolingType\().ProcessNextOutputCount: + ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 1 + add.d $a0, $a0, $a4 + addi.d $t0, $t0, -1 + bnez $t0, .L\PoolingType\().ProcessNextOutputCount + +.L\PoolingType\().ExitKernel: + SpoolKernelExit + + .endm + +// +// Generate the pooling kernels. +// + + SpoolKernelFunction Maximum, LSX + SpoolKernelFunction AverageExcludePad, LSX + SpoolKernelFunction AverageIncludePad, LSX + + .end diff --git a/third_party/mlas/lib/loongarch64/SpoolKernelLasx.S b/third_party/mlas/lib/loongarch64/SpoolKernelLasx.S new file mode 100644 index 0000000000..6e5f0136cd --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SpoolKernelLasx.S @@ -0,0 +1,238 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelLasx.s + +Abstract: + + This module implements the kernels for the single precision pooling + operation. + + This implementation uses Lasx instructions. + +--*/ + +#include "asmmacro.h" +#include "SpoolKernelLasxCommon.h" + + .text + +/*++ + +Macro Description: + + This macro generates code to initialize registers used across the kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + +Implicit Arguments: + + a5 - Supplies the ActualKernelSize parameter (see function description). + +--*/ + + .macro InitializeKernel PoolingType + +.ifeqs "\PoolingType\()","Maximum" + li.w $s0, 0xFF7FFFFF + xvreplgr2vr.w $xr5, $s0 +.else + xvxor.v $xr5, $xr5, $xr5 +.ifeqs "\PoolingType\()","AverageExcludePad" + move $t6, $a6 + mul.d $t6, $t6, $a7 + xvreplgr2vr.w $xr5, $t6 +.else + xvreplgr2vr.w $xr5, $a5 +.endif + xvffint.s.w $xr5, $xr5 +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to clear the pooling intermediates. + + For PoolingType==Maximum, the pooling intermediates are set to the minimum + float value. Otherwise, the pooling intermediates are cleared to zero. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + xr0-xr2 - Supplies the pooling intermediates. + + xr5 - Supplies a vector containing the minimum float value broadcasted, + if PoolingType==Maximum. + +--*/ + + .macro ClearBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + EmitIfCountGE \OutputCount\(), 1, "xvor.v $xr0, $xr5, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvor.v $xr1, $xr5, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvor.v $xr2, $xr5, $xr5" +.else + EmitIfCountGE \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" + EmitIfCountGE \OutputCount\(), 2, "xvxor.v $xr1, $xr1, $xr1" + EmitIfCountGE \OutputCount\(), 3, "xvxor.v $xr2, $xr2, $xr2" +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + xor $a1, $a1, $a1 # reset valid block counter +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to sample the input buffer and update the pooling + intermediates as appropriate. + +Arguments: + + PoolingType - Supplies the pooling type string. 
+ + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a3 - Supplies the address of the input buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + a4 - Supplies the StrideWidth parameter (see function description). + + xr0-xr2 - Supplies the pooling intermediates. + +--*/ + + .macro ComputeBlock PoolingType, OutputCount + +.ifeqs "\PoolingType\()","Maximum" + EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfmax.s $xr0, $xr0, $xr16" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" + EmitIfCountGE \OutputCount\(), 2, "xvfmax.s $xr1, $xr1, $xr16" + EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" + EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" + EmitIfCountGE \OutputCount\(), 3, "xvfmax.s $xr2, $xr2, $xr16" +.else + EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" + EmitIfCountGE \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" + EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" + EmitIfCountGE \OutputCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" + EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" + EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" + EmitIfCountGE \OutputCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" +.endif + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + addi.d $a1, $a1, 1 # increment valid block counter +.endif +.endif + + .endm + +/*++ + +Macro Description: + + This macro generates code to process and store the pooling intermediates. + +Arguments: + + PoolingType - Supplies the pooling type string. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a2 - Supplies the address of the output buffer. + + a1 - Supplies the number of blocks accessed by ComputeBlock, if + PoolingType=AverageExcludePad and OutputCount=1. + + xr0-xr2 - Supplies the pooling intermediates. + + xr5 - Supplies the kernel size computed by InitializeKernel, if + PoolingType=AverageExcludePad, else the actual kernel size, if + PoolingType=AverageIncludePad. + +--*/ + + .macro PostProcessBlock PoolingType, OutputCount + +// +// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding +// blocks. OutputCount=1 generates code to count the number of blocks accessed by +// ComputeBlock. Other cases use the kernel size computed by InitializeKernel. +// + +.ifeqs "\PoolingType\()","AverageExcludePad" +.if \OutputCount\() == 1 + xvxor.v $xr4, $xr4, $xr4 + xvreplgr2vr.w $xr4, $a1 + xvffint.s.w $xr4, $xr4 + xvfdiv.s $xr0, $xr0, $xr4 +.else + EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" +.endif +.endif + +// +// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. +// + +.ifeqs "\PoolingType\()","AverageIncludePad" + EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" + EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" + EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" +.endif + +// +// Store the output block in the output buffer. 
+// + + EmitIfCountGE \OutputCount\(), 1, "xvst $xr0, $a2, 0" + EmitIfCountGE \OutputCount\(), 2, "xvst $xr1, $a2, 0x20" + EmitIfCountGE \OutputCount\(), 3, "xvst $xr2, $a2, 0x40" + add_immed $a2,\OutputCount\()*8*4 # advance output by N nchw8c blocks + + .endm + +// +// Generate the pooling kernels. +// + + SpoolKernelFunction Maximum, Lasx + SpoolKernelFunction AverageExcludePad, Lasx + SpoolKernelFunction AverageIncludePad, Lasx + + .end diff --git a/third_party/mlas/lib/loongarch64/SpoolKernelLasxCommon.h b/third_party/mlas/lib/loongarch64/SpoolKernelLasxCommon.h new file mode 100644 index 0000000000..066c75d34f --- /dev/null +++ b/third_party/mlas/lib/loongarch64/SpoolKernelLasxCommon.h @@ -0,0 +1,311 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SpoolKernelasxCommon.h + +Abstract: + + This module contains common kernel macros and structures for the single + precision pooling operation for the Lasx kernels. + +--*/ + +// +// Stack frame layout for the pooling kernels. +// + +#define SP_SIZE 8*8 +#define InputBase_arg SP_SIZE+0*8 +#define InputWidth_arg SP_SIZE+1*8 +#define DilatedInputWidth_arg SP_SIZE+2*8 +#define OutputCountLeftPad_arg SP_SIZE+3*8 +#define OutputCount_arg SP_SIZE+4*8 +#define OutputCountRightPad_arg SP_SIZE+5*8 +/*++ + +Macro Description: + + This macro generates the common prologue code for the pooling kernels. + +Arguments: + + PoolingType - Supplies the pooling type string. + +--*/ + + .macro SpoolKernelEntry PoolingType + + addi.d $sp, $sp, -SP_SIZE + st.d $s0, $sp, 0 + st.d $s1, $sp, 1*8 + fst.d $f16, $sp, 2*8 + st.d $ra, $sp, 5*8 + + InitializeKernel \PoolingType\() + move $t8, $a4 + move $a4, $a2 + move $a5, $a3 + move $a2, $a1 + + .endm + +/*++ + +Macro Description: + + This macro generates the common epilogue code for the pooling kernels. + +Arguments: + + None. + +--*/ + + .macro SpoolKernelExit + + ld.d $s0, $sp, 0 + ld.d $s1, $sp, 1*8 + fld.d $f16, $sp, 2*8 + ld.d $ra, $sp, 5*8 + addi.d $sp, $sp, SP_SIZE + jr $ra + + .endm + +/*++ + +Macro Description: + + This macro generates code to compute pooling for a vector of input blocks + to produce a matrix of output blocks. + + OutputCount=1 generates special case code to handle padding blocks. All + other output counts assume no padding. + +Arguments: + + KernelFrame - Supplies the symbol name to access the convolution kernel + stack. + + OutputCount - Supplies the number of output blocks to produce. + +Implicit Arguments: + + a0 - Supplies the address of the input buffer. + + a2 - Supplies the address of the output buffer. + + a4 - Supplies the StrideWidth parameter (see function description). + + a5 - Supplies the DilationWidth parameter (see function description). + + t8 - Supplies the InputStride parameter (see function description). + +--*/ + + .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount + + move $a3, $a0 + move $t1, $a6 + move $t2, $a7 +.if \OutputCount\() == 1 + ld.d $t3, $sp, InputBase_arg + ld.d $t4, $sp, InputWidth_arg + sub.d $t3, $zero, $t3 +.endif + ClearBlock \PoolingType\(), \OutputCount\() + beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing + +.L\PoolingType\().\OutputCount\().ProcessNextRow: + move $t6, $t2 + +.L\PoolingType\().\OutputCount\().ProcessNextColumn: +.if \OutputCount\() == 1 + add.d $t7, $a3, $t3 # compute (Input - InputBase) + # (Input - InputBase) >= InputWidth? 
+ bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding +.endif + ComputeBlock \PoolingType\(), \OutputCount\() + +.L\PoolingType\().\OutputCount\().SkipOverPadding: + add.d $a3, $a3, $a5 # advance input by dilation width + addi.d $t6, $t6, -1 # decrement columns remaining + bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn + add.d $a3, $a3, $t8 # advance input to next row +.if \OutputCount\() == 1 + ld.d $s0, $sp, DilatedInputWidth_arg + sub.d $t3, $t3, $s0 + # advance input base to next row +.endif + addi.d $t1, $t1, -1 + bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow + +.L\PoolingType\().\OutputCount\().HandlePostProcessing: + PostProcessBlock \PoolingType\(), \OutputCount\() + + .endm +/*++ + +Macro Description: + + This macro generates code for the inner pooling kernel. + +Arguments: + + PoolingType - Supplies the pooling type string. + + Isa - Supplies the instruction set architecture string for function tags. + +--*/ + + .macro SpoolKernelFunction PoolingType, Isa + +/*++ + +Routine Description: + + This routine is the inner kernel to compute pooling for the elements of an + output row for a set of filter rows. + +Arguments: + + Input (a0) - Supplies the address of the input buffer. + + The address is biased to include padding blocks for the left width + dimension. The address is not biased to include padding rows for the + left height dimension these are accounted for in the outer kernel. + + Output (a1) - Supplies the address of the output buffer. + + StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. + + DilationWidth (a3) - Supplies the length in bytes of the blocked dilation + width. + + InputStride (a4) - Supplies the length in bytes to advance the input buffer to + the next input row. + + ActualKernelSize (a5) - Supplies the size of the kernel based on the original + kernel dimensions, used for PoolingType=AverageIncludePad. + + KernelHeight (a6) - Supplies the height of the kernel to apply. This height may + be less than the original kernel height after removing any padding + rows. + + KernelWidth (a7)- Supplies the width of the kernel to apply. + + InputBase (sp + 0)- Supplies the address of the valid input buffer. + + This parameter is similar to the Input parameter, but does not include + the padding blocks for the left width dimension. This parameter is used + with the following InputWidth parameter in order to validate that the + current input buffer address in bounds and not in the left or right + width padding region. + + InputWidth (sp + 0x8)- Supplies the length in bytes of the blocked input width. + + DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base + buffer to the next input row including dilation. + + OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include + one or more padding elements from the left edge. + + OutputCount (sp + 0x20)- Supplies the number of output elements that do not include + any padding elements. + + OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include + one or more padding elements from the right edge. + +Return Value: + + None. 
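+
+    The function body that follows organizes the work roughly as in this driver
+    sketch (illustrative only; the helper routines are hypothetical stand-ins
+    for the generated single-output and triple-output code paths):
+
+        #include <cstddef>
+
+        void PoolKernelDriverSketch(size_t OutputCountLeftPad, size_t OutputCount,
+                                    size_t OutputCountRightPad,
+                                    void (*ProcessSingleOutput)(),
+                                    void (*ProcessThreeOutputs)())
+        {
+            // Left-padded outputs go through the single-output helper, which
+            // performs the per-position padding checks.
+            for (size_t i = 0; i < OutputCountLeftPad; i++) {
+                ProcessSingleOutput();
+            }
+
+            // The interior is unrolled three output blocks per iteration.
+            size_t Remaining = OutputCount;
+            while (Remaining >= 3) {
+                ProcessThreeOutputs();
+                Remaining -= 3;
+            }
+
+            // Any interior remainder plus the right-padded outputs fall back to
+            // the single-output helper.
+            for (size_t i = 0; i < Remaining + OutputCountRightPad; i++) {
+                ProcessSingleOutput();
+            }
+        }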
+ +--*/ + + FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() + + SpoolKernelEntry \PoolingType\() + +.L\PoolingType\().ProcessOutputCountLeftPad: + ld.d $t0, $sp, OutputCountLeftPad_arg + + beqz $t0, .L\PoolingType\().ProcessOutputCount + bl MlasPool\PoolingType\()FloatSingle\Isa\() + +.L\PoolingType\().ProcessOutputCount: + ld.d $t0, $sp, OutputCount_arg + li.d $s0, 3 + bltu $t0, $s0, .L\PoolingType\().ProcessRemainingOutputCount + +.L\PoolingType\().ProcessNextOutputCountBy3: + ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 3 + slli.d $s0, $a4, 1 + add.d $t6, $s0, $a4 + add.d $a0, $a0, $t6 # advance input by 3 elements + addi.d $t0, $t0, -3 + li.d $s0, 3 + bgeu $t0, $s0, .L\PoolingType\().ProcessNextOutputCountBy3 + +.L\PoolingType\().ProcessRemainingOutputCount: + +.L\PoolingType\().ProcessOutputCountRightPad: + ld.d $s0, $sp, OutputCountRightPad_arg + add.d $t0, $t0, $s0 + beqz $t0, .L\PoolingType\().ExitKernel + bl MlasPool\PoolingType\()FloatSingle\Isa\() + +.L\PoolingType\().ExitKernel: + xvinsgr2vr.d $xr0, $zero, 2 + xvinsgr2vr.d $xr0, $zero, 3 + xvinsgr2vr.d $xr1, $zero, 2 + xvinsgr2vr.d $xr1, $zero, 3 + xvinsgr2vr.d $xr2, $zero, 2 + xvinsgr2vr.d $xr2, $zero, 3 + xvinsgr2vr.d $xr3, $zero, 2 + xvinsgr2vr.d $xr3, $zero, 3 + xvinsgr2vr.d $xr4, $zero, 2 + xvinsgr2vr.d $xr4, $zero, 3 + xvinsgr2vr.d $xr5, $zero, 2 + xvinsgr2vr.d $xr5, $zero, 3 + xvinsgr2vr.d $xr6, $zero, 2 + xvinsgr2vr.d $xr6, $zero, 3 + xvinsgr2vr.d $xr7, $zero, 2 + xvinsgr2vr.d $xr7, $zero, 3 + xvinsgr2vr.d $xr8, $zero, 2 + xvinsgr2vr.d $xr8, $zero, 3 + xvinsgr2vr.d $xr9, $zero, 2 + xvinsgr2vr.d $xr9, $zero, 3 + xvinsgr2vr.d $xr10, $zero, 2 + xvinsgr2vr.d $xr10, $zero, 3 + xvinsgr2vr.d $xr11, $zero, 2 + xvinsgr2vr.d $xr11, $zero, 3 + xvinsgr2vr.d $xr12, $zero, 2 + xvinsgr2vr.d $xr12, $zero, 3 + xvinsgr2vr.d $xr13, $zero, 2 + xvinsgr2vr.d $xr13, $zero, 3 + xvinsgr2vr.d $xr14, $zero, 2 + xvinsgr2vr.d $xr14, $zero, 3 + xvinsgr2vr.d $xr15, $zero, 2 + xvinsgr2vr.d $xr15, $zero, 3 + SpoolKernelExit + +// +// Generate out-of-band helpers for handling output blocks involving padding. +// + +MlasPool\PoolingType\()FloatSingle\Isa\(): + st.d $ra, $sp, 6*8 +loopMlasPool\PoolingType\()FloatSingle\Isa\(): + ProcessOutputCountN .LSpoolKernelSingleFrame, \PoolingType\(), 1 + add.d $a0, $a0, $a4 # advance input by 1 element + addi.d $t0, $t0, -1 # decrement output count remaining + bnez $t0, loopMlasPool\PoolingType\()FloatSingle\Isa\() + ld.d $ra, $sp, 6*8 + jr $ra + + .endm diff --git a/third_party/mlas/lib/loongarch64/asmmacro.h b/third_party/mlas/lib/loongarch64/asmmacro.h new file mode 100644 index 0000000000..837aca77dd --- /dev/null +++ b/third_party/mlas/lib/loongarch64/asmmacro.h @@ -0,0 +1,144 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + asmmacro.h + +Abstract: + + This module implements common macros for the assembly modules. + +--*/ + +#define C_UNDERSCORE(symbol) symbol + +.macro vmove dst src + vand.v \dst, \src, \src +.endm + +/*++ + +Macro Description: + + This macro emits the assembler directives to annotate a new function. + +Arguments: + + FunctionName - Supplies the name of the function. 
+ +--*/ + + .macro FUNCTION_ENTRY FunctionName + .align 2 + .globl \FunctionName\() + .type \FunctionName\(),@function +\FunctionName\(): + + .endm + +/*++ + +Macro Description: + + This macro generates an optimization for "add reg,128" which can instead + be encoded as "sub reg,-128" to reduce code size by using a signed 8-bit + value. + +Arguments: + + Register - Supplies the register to be added to. + + Immediate - Supplies the immediate to add to the register. + +--*/ + + .macro add_immed Register, Immediate + +.if (\Immediate\() != 128) + addi.d \Register\(),\Register\(),\Immediate\() +.else + addi.d \Register\(),\Register\(),\Immediate\() # smaller encoding +.endif + + .endm + +/*++ + +Macro Description: + + This macro conditionally emits the statement if Count is greater than or + equal to Value. + +Arguments: + + Count - Supplies the variable used in the comparison. + + Value - Supplies the static used in the comparison. + + Statement - Supplies the statement to conditionally emit. + +--*/ + + .macro EmitIfCountGE Count1, Value1, Statement + +.if (\Count1\() >= \Value1\()) + \Statement\() +.endif + + .endm + +/*++ + +Macro Description: + + This macro conditionally emits the statement if Count1 is greater than or + equal to Value1 and Count2 is greater than or equal to Value2. + +Arguments: + + Count1 - Supplies the variable used in the comparison. + + Value1 - Supplies the static used in the comparison. + + Count2 - Supplies the variable used in the comparison. + + Value2 - Supplies the static used in the comparison. + + Statement - Supplies the statement to conditionally emit. + +--*/ + + .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement + +.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) + \Statement\() +.endif + + .endm + +/*++ + +Macro Description: + + This macro emits the statement for each register listed in the register + list. The statement can use RegItem to access the current register. + +Arguments: + + RegList - Supplies the list of registers. + + Statement - Supplies the statement to emit. + +--*/ + + .macro EmitForEachRegister RegList, Statement + + .irp RegItem, \RegList\() + \Statement\() + .endr + + .endm diff --git a/third_party/mlas/lib/mlasi.h b/third_party/mlas/lib/mlasi.h index cd06fc63e9..fe4e74e932 100644 --- a/third_party/mlas/lib/mlasi.h +++ b/third_party/mlas/lib/mlasi.h @@ -24,7 +24,6 @@ Module Name: #include #include #include -#include #include #ifdef MLAS_NO_EXCEPTION @@ -51,9 +50,18 @@ Module Name: #include #endif #if defined(__x86_64__) || defined(__i386__) +#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)//workaround for Bug 96238 - [i386] cpuid.h header needs include guards #include +#endif +#if defined(__GNUC__) && __GNUC__ >= 12 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h. +#include +#pragma GCC diagnostic pop +#else #include #endif +#endif #if defined(__VSX__) #include // Undefine unwanted aliases from altivec.h. 
@@ -61,6 +69,9 @@ Module Name: #undef pixel #undef bool #endif +#if defined(__loongarch64) +#include +#endif #if defined(MLAS_TARGET_WASM_SIMD) #include #endif @@ -178,11 +189,20 @@ class MLASCPUIDInfo bool IsCurrentCoreArmv8NarrowLd() const { return false; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -305,7 +325,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the prototypes of the platform optimized routines. // -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ + defined(MLAS_TARGET_LARCH64) typedef size_t @@ -339,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -627,6 +662,42 @@ void int8_t ZeroPoint ); +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_U16_KERNEL)( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_S16_KERNEL)( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_U4_KERNEL)( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_QUANTIZE_LINEAR_S4_KERNEL)( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -664,9 +735,37 @@ extern "C" { MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelVSX; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelVSX; +#elif defined(MLAS_TARGET_LARCH64) + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLSX; + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLasx; + MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLSX; + MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLasx; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLSX; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLSX; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLSX; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLSX; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLasx; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLasx; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLasx; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLSX; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLasx; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLasx; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4LSX; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Lasx; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL 
MlasReduceMaximumF32KernelLasx; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelLasx; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelLasx; #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -743,6 +842,10 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; @@ -763,6 +866,7 @@ extern "C" { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel; #if defined(MLAS_TARGET_AMD64) MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx512F; MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx; #endif @@ -783,7 +887,7 @@ extern "C" { // value. // -#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 32 +#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 64 // // Define the target number of per-thread multiplies before using another @@ -794,6 +898,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // @@ -822,16 +930,17 @@ MlasSgemmOperation( struct MLAS_GEMM_QUANT_DISPATCH; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; -#ifdef MLAS_AMX_SUPPORTED extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx; -#endif extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; @@ -858,6 +967,36 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; +// +// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure. 
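The extern kernel and dispatch declarations above, together with the MLAS_PLATFORM members added in the hunks below, all feed one mechanism: a per-process platform object is filled in once with pointers to the best routines the CPU supports, and the public entry points call through those pointers. A schematic C++ sketch of that pattern (the names PlatformSketch, GemmFloatKernelFn, and CpuHasWideSimd are illustrative stand-ins, not the real MLAS types or probes):

#include <cstddef>

// One function-pointer slot per specialized routine; the real MLAS_PLATFORM
// has many such slots plus dispatch-table pointers for the quantized GEMM paths.
using GemmFloatKernelFn = std::size_t (*)(const float* A, const float* B, float* C,
                                          std::size_t CountK, std::size_t CountM,
                                          std::size_t CountN);

std::size_t GemmFloatKernelGeneric(const float*, const float*, float*,
                                   std::size_t, std::size_t, std::size_t) { return 0; }
std::size_t GemmFloatKernelWideSimd(const float*, const float*, float*,
                                    std::size_t, std::size_t, std::size_t) { return 0; }

bool CpuHasWideSimd() { return false; }  // stand-in for the cpuid/getauxval probes

struct PlatformSketch {
    GemmFloatKernelFn GemmFloatKernel = GemmFloatKernelGeneric;  // safe baseline

    PlatformSketch()
    {
        if (CpuHasWideSimd()) {
            GemmFloatKernel = GemmFloatKernelWideSimd;  // upgrade when supported
        }
    }
};

int main()
{
    PlatformSketch platform;  // mirrors what the MLAS_PLATFORM constructor does once at startup
    float a[1] = {0}, b[1] = {0}, c[1] = {0};
    return static_cast<int>(platform.GemmFloatKernel(a, b, c, 1, 1, 1));
}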
+// + +struct MLAS_Q8Q4GEMM_DISPATCH; + +extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; + +// +// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + +struct MLAS_FPQ4GEMM_DISPATCH; + +extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; + +// +// Float/quantized n-bit integer matrix/matrix multiply dispatch structure. +// + +struct MLAS_SQNBIT_GEMM_DISPATCH; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; + // // Quantized depthwise convolution kernels. // @@ -923,12 +1062,29 @@ struct MLAS_PLATFORM { #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif - +#if defined(MLAS_TARGET_LARCH64) + const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; + MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; + MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; + uint32_t NchwcBlockSize; +#endif #if defined(MLAS_TARGET_AMD64_IX86) const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; #elif defined(MLAS_TARGET_ARM64) - const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; + const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; #endif const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; @@ -945,6 +1101,10 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; #endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; @@ -972,6 +1132,10 @@ struct MLAS_PLATFORM { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; + MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; + MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; + MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; + MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; @@ -981,6 +1145,10 @@ struct MLAS_PLATFORM { static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; #endif + const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; + const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; + + const 
MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; }; inline @@ -1029,6 +1197,23 @@ MlasTrySimpleParallel( const std::function& Work ); + +/** + * @brief Distribute many iterations of work over a thread pool if supported. + * This function is for small workloads in non-performance critical situation. + * + * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP + * @param Iterations [IN] Total number of iterations + * @param Work [IN] Logic for computing a range of iterations [begin, end) + */ +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work + ); + + inline ptrdiff_t MlasGetMaximumThreadCount( @@ -1173,6 +1358,8 @@ MlasConvDepthwiseFloat_CHW( #endif #elif defined(MLAS_TARGET_WASM_SIMD) #define MLAS_WASM_SIMD_INTRINSICS +#elif defined(MLAS_TARGET_LARCH64) +#define MLAS_LSX_INTRINSICS #endif #if defined(MLAS_NEON_INTRINSICS) @@ -1188,6 +1375,9 @@ typedef __vector unsigned MLAS_UINT32X4; #elif defined(MLAS_WASM_SIMD_INTRINSICS) typedef v128_t MLAS_FLOAT32X4; typedef v128_t MLAS_INT32X4; +#elif defined(MLAS_LSX_INTRINSICS) +typedef __m128 MLAS_FLOAT32X4; +typedef __m128i MLAS_INT32X4; #else typedef float MLAS_FLOAT32X4 __attribute__ ((vector_size(16))); typedef int32_t MLAS_INT32X4 __attribute__ ((vector_size(16))); @@ -1201,6 +1391,8 @@ MlasReinterpretAsInt32x4(MLAS_FLOAT32X4 Vector) return vreinterpretq_s32_f32(Vector); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_castps_si128(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_INT32X4)Vector; #else return MLAS_INT32X4(Vector); #endif @@ -1216,6 +1408,8 @@ MlasCastToInt32x4(MLAS_FLOAT32X4 Vector) return _mm_cvttps_epi32(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_cts(Vector, 0); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vftint_w_s(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return (MLAS_INT32X4)__builtin_convertvector((__f32x4)Vector, __i32x4); #else @@ -1235,6 +1429,8 @@ MlasCastToFloat32x4(MLAS_INT32X4 Vector) return vec_ctf(Vector, 0); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_convert_i32x4(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vffint_s_w(Vector); #else return MLAS_FLOAT32X4{float(Vector[0]), float(Vector[1]), float(Vector[2]), float(Vector[3])}; #endif @@ -1252,6 +1448,8 @@ MlasBroadcastInt32x4(int32_t Value) return wasm_i32x4_splat(Value); #elif defined(MLAS_VSX_INTRINSICS) return vec_splats(Value); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vreplgr2vr_w(Value); #else return MLAS_INT32X4{Value, Value, Value, Value}; #endif @@ -1269,6 +1467,8 @@ MlasLoadInt32x4(const int32_t* Buffer) return vec_vsx_ld(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vld((const MLAS_INT32X4*)Buffer, 0); #else return *((MLAS_INT32X4*)Buffer); #endif @@ -1286,6 +1486,8 @@ MlasStoreInt32x4(int32_t* Buffer, MLAS_INT32X4 Vector) vec_vsx_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + __lsx_vst(Vector, (MLAS_INT32X4 *)Buffer, 0); #else *((MLAS_INT32X4*)Buffer) = Vector; #endif @@ -1303,6 +1505,8 @@ MlasAddInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return wasm_i32x4_add(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_add(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vadd_w(Vector1, Vector2); #else return Vector1 + Vector2; #endif @@ -1318,6 +1522,8 @@ 
MlasSubtractInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_sub_epi32(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_sub(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vsub_w(Vector1, Vector2); #else return Vector1 - Vector2; #endif @@ -1333,6 +1539,8 @@ MlasAndInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_and_si128(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_and(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vand_v(Vector1, Vector2); #else return Vector1 & Vector2; #endif @@ -1348,6 +1556,8 @@ MlasOrInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_or_si128(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_or(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vor_v(Vector1, Vector2); #else return Vector1 | Vector2; #endif @@ -1363,6 +1573,8 @@ MlasAndNotInt32x4(MLAS_INT32X4 VectorNot, MLAS_INT32X4 Vector) return _mm_andnot_si128(VectorNot, Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_andnot(Vector, VectorNot); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vandn_v(VectorNot, Vector); #else return (~VectorNot) & Vector; #endif @@ -1380,6 +1592,8 @@ MlasXorInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return wasm_v128_xor(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_xor(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vxor_v(Vector1, Vector2); #else return Vector1 ^ Vector2; #endif @@ -1403,6 +1617,8 @@ MlasShiftLeftInt32x4(MLAS_INT32X4 Vector) return _mm_slli_epi32(Vector, ShiftCount); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_shl(Vector, ShiftCount); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vslli_w(Vector, ShiftCount); #else return Vector << ShiftCount; #endif @@ -1422,6 +1638,8 @@ MlasMaximumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return vec_vmaxsw(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_max(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vmax_w(Vector1, Vector2); #else return MlasBlendInt32x4(Vector2, Vector1, Vector1 > Vector2); #endif @@ -1441,6 +1659,8 @@ MlasMinimumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return vec_vminsw(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_min(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vmin_w(Vector1, Vector2); #else return MlasBlendInt32x4(Vector2, Vector1, Vector2 > Vector1); #endif @@ -1454,6 +1674,8 @@ MlasReinterpretAsFloat32x4(MLAS_INT32X4 Vector) return vreinterpretq_f32_s32(Vector); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_castsi128_ps(Vector); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4(Vector); #else return MLAS_FLOAT32X4(Vector); #endif @@ -1473,6 +1695,8 @@ MlasBroadcastFloat32x4(float Value) // Suppress wrong GCC warnings MLAS_UNREFERENCED_PARAMETER(Value); return vec_splats(Value); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4{Value, Value, Value, Value}; #else return MLAS_FLOAT32X4{Value, Value, Value, Value}; #endif @@ -1490,6 +1714,8 @@ MlasBroadcastFloat32x4(const float* Value) return wasm_v128_load32_splat(Value); #elif defined(MLAS_VSX_INTRINSICS) return vec_splats(*Value); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; #else return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; #endif @@ -1505,6 +1731,8 @@ 
MlasZeroFloat32x4(void) return _mm_setzero_ps(); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_const(0.0f, 0.0f, 0.0f, 0.0f); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBroadcastFloat32x4(0.0f); #else return MlasBroadcastFloat32x4(0.0f); #endif @@ -1522,6 +1750,9 @@ MlasLoadFloat32x4(const float* Buffer) return vec_vsx_ld(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + // return MlasReinterpretAsFloat32x4(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); + return (MLAS_FLOAT32X4)__lsx_vld((const MLAS_INT32X4 *)Buffer, 0); #else return *((MLAS_FLOAT32X4*)Buffer); #endif @@ -1539,6 +1770,8 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vec_vsx_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + __lsx_vst(MlasReinterpretAsInt32x4(Vector), Buffer, 0); #else *((MLAS_FLOAT32X4*)Buffer) = Vector; #endif @@ -1559,6 +1792,8 @@ MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) vec_st(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); +#elif defined(MLAS_LSX_INTRINSICS) + MlasStoreFloat32x4(Buffer, Vector); #else MlasStoreFloat32x4(Buffer, Vector); #endif @@ -1577,6 +1812,8 @@ MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_store_ss(Buffer, _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); #elif defined(MLAS_WASM_SIMD_INTRINSICS) *Buffer = ((__f32x4)(Vector))[Lane]; +#elif defined(MLAS_LSX_INTRINSICS) + *Buffer = Vector[Lane]; #else *Buffer = Vector[Lane]; #endif @@ -1592,6 +1829,9 @@ MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_storel_pi((__m64*)Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) *((long long*)Buffer) = ((__vector long long)Vector)[0]; +#elif defined(MLAS_LSX_INTRINSICS) + MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); + MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); #else MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); @@ -1609,6 +1849,8 @@ MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector) return _mm_cvtss_f32(_mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_extract_lane(Vector, Lane); +#elif defined(MLAS_LSX_INTRINSICS) + return Vector[Lane]; #else return Vector[Lane]; #endif @@ -1653,6 +1895,9 @@ MlasShuffleFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_i32x4_shuffle(Vector1, Vector2, Index0, Index1, Index2, Index3); #elif defined(__clang__) return __builtin_shufflevector(Vector1, Vector2, Index0, Index1, Index2, Index3); +#elif defined(MLAS_LSX_INTRINSICS) + typedef int32_t GEN_INT32X4 __attribute__ ((vector_size(16))); + return __builtin_shuffle(Vector1, Vector2, GEN_INT32X4{Index0, Index1, Index2, Index3}); #else return __builtin_shuffle(Vector1, Vector2, MLAS_INT32X4{Index0, Index1, Index2, Index3}); #endif @@ -1681,6 +1926,8 @@ MlasInterleaveLowFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_unpacklo_ps(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_mergeh(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vilvl_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); #else return MlasShuffleFloat32x4<0, 4, 1, 5>(Vector1, Vector2); #endif @@ -1699,6 +1946,8 @@ MlasInterleaveHighFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return 
_mm_unpackhi_ps(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_mergel(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vilvh_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); #else return MlasShuffleFloat32x4<2, 6, 3, 7>(Vector1, Vector2); #endif @@ -1716,6 +1965,8 @@ MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_add(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_add(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfadd_s(Vector1, Vector2); #else return Vector1 + Vector2; #endif @@ -1733,6 +1984,8 @@ MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_sub(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return vec_sub(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfsub_s(Vector1, Vector2); #else return Vector1 - Vector2; #endif @@ -1753,6 +2006,8 @@ MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) MLAS_UNREFERENCED_PARAMETER(Vector1); MLAS_UNREFERENCED_PARAMETER(Vector2); return vec_mul(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmul_s(Vector1, Vector2); #else return Vector1 * Vector2; #endif @@ -1772,6 +2027,8 @@ MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FL return vec_madd(Vector1, Vector2, Vector3); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_add(wasm_f32x4_mul(Vector1, Vector2), Vector3); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmadd_s(Vector1, Vector2, Vector3); #else return Vector1 * Vector2 + Vector3; #endif @@ -1807,6 +2064,8 @@ MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_div_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_div(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfdiv_s(Vector1, Vector2); #else return Vector1 / Vector2; #endif @@ -1824,6 +2083,8 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return wasm_f32x4_gt(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vfcmp_clt_s(Vector2, Vector1); #else return Vector1 > Vector2; #endif @@ -1837,6 +2098,8 @@ MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_and_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_and(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1850,6 +2113,8 @@ MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_or_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_or(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1863,6 +2128,8 @@ MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) return _mm_andnot_ps(VectorNot, Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_andnot(Vector, 
VectorNot); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); #else return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); #endif @@ -1876,6 +2143,8 @@ MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_xor_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_xor(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #else return MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); #endif @@ -1901,6 +2170,8 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_max(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmax_s(Vector1, Vector2); #else return MlasBlendFloat32x4(Vector2, Vector1, Vector1 > Vector2); #endif @@ -1919,6 +2190,8 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_min(Vector1, Vector2); +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmin_s(Vector1, Vector2); #else return MlasBlendFloat32x4(Vector2, Vector1, Vector2 > Vector1); #endif @@ -2025,6 +2298,8 @@ MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) typedef __m128d MLAS_FLOAT64X2; #elif defined(MLAS_VSX_INTRINSICS) typedef __vector double MLAS_FLOAT64X2; +#elif defined(MLAS_LSX_INTRINSICS) +typedef __m128d MLAS_FLOAT64X2; #else #define MLAS_FLOAT64X2_UNSUPPORTED #endif @@ -2046,6 +2321,27 @@ MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FL return vec_madd(Vector1, Vector2, Vector3); } +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasBroadcastFloat64x2(const double *Value) +{ + return MLAS_FLOAT64X2{*Value, *Value}; +} +#elif defined(MLAS_LSX_INTRINSICS) +template +MLAS_FORCEINLINE +double +MlasExtractLaneFloat64x2(MLAS_FLOAT64X2 Vector) +{ + return Vector[Lane]; +} +MLAS_FORCEINLINE +MLAS_FLOAT64X2 +MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FLOAT64X2 Vector3) +{ + return __lsx_vfmadd_d(Vector1, Vector2, Vector3); +} + MLAS_FORCEINLINE MLAS_FLOAT64X2 MlasBroadcastFloat64x2(const double *Value) @@ -2061,6 +2357,8 @@ MlasBroadcastFloat64x2(double Value) return _mm_set1_pd(Value); #elif defined(MLAS_VSX_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT64X2{Value, Value}; #endif } @@ -2072,6 +2370,8 @@ MlasZeroFloat64x2(void) return _mm_setzero_pd(); #elif defined(MLAS_VSX_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); +#elif defined(MLAS_LSX_INTRINSICS) + return MlasBroadcastFloat64x2(0.0f); #endif } @@ -2083,6 +2383,8 @@ MlasLoadFloat64x2(const double* Buffer) return _mm_loadu_pd(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + return MLAS_FLOAT64X2(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); #endif } @@ -2094,6 +2396,8 @@ MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_storeu_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_LSX_INTRINSICS) + 
(__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif } @@ -2105,6 +2409,8 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_store_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) *((MLAS_FLOAT64X2*)Buffer) = Vector; +#elif defined(MLAS_LSX_INTRINSICS) + (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif } @@ -2116,6 +2422,8 @@ MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) return _mm_mul_pd(Vector1, Vector2); #elif defined(MLAS_VSX_INTRINSICS) return Vector1 * Vector2; +#elif defined(MLAS_LSX_INTRINSICS) + return __lsx_vfmul_d(Vector1, Vector2); #endif } @@ -2150,6 +2458,17 @@ MlasReadTimeStampCounter(void) ); return ((uint64_t)edx << 32) | eax; +#elif defined(MLAS_TARGET_LARCH64) + uint64_t time_cnt, id; + + __asm__ __volatile__ + ( + "rdtime.d %0, %1\n\t" + : "=r" (time_cnt), "=r" (id) + :: + ); + + return time_cnt; #else return 0; #endif @@ -2162,7 +2481,7 @@ MlasReadTimeStampCounter(void) constexpr size_t ThreadedBufAlignment = 64; -//extern thread_local size_t ThreadedBufSize; +extern thread_local size_t ThreadedBufSize; #ifdef _MSC_VER extern thread_local std::unique_ptr ThreadedBufHolder; #else @@ -2173,8 +2492,8 @@ MLAS_FORCEINLINE constexpr size_t UpAlignSize(size_t size) { - //size = (size + ThreadedBufAlignment - 1) / ThreadedBufAlignment; - return ((size + ThreadedBufAlignment - 1) / ThreadedBufAlignment) * ThreadedBufAlignment; + size = (size + ThreadedBufAlignment - 1) / ThreadedBufAlignment; + return size * ThreadedBufAlignment; } @@ -2182,7 +2501,6 @@ MLAS_FORCEINLINE void MlasThreadedBufAlloc(size_t size) { -/* if (size > ThreadedBufSize) { #ifdef _MSC_VER ThreadedBufHolder.reset( @@ -2202,5 +2520,52 @@ MlasThreadedBufAlloc(size_t size) ThreadedBufSize = size; } -*/ } + +// +// Utilities for INT4 quantization. 
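The INT4 helpers introduced just below (Int4Traits, MlasSetInt4Element, MlasPackInt4Elements) and the POWER kernel later in this patch share one scalar recipe: divide by the scale, round to nearest, add the zero point, clamp to the 4-bit range, and pack two results per byte with even element indices in the low nibble. A plain-C++ reference sketch of that recipe, covering only the signed variant (range -8..7 from Int4Traits with signed = true); the function and variable names are illustrative:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Scalar reference for signed INT4 quantization followed by nibble packing.
// Mirrors the tail loop of MlasQuantizeLinearInt4Kernel plus MlasSetInt4Element:
// q = clamp(round(x / Scale) + ZeroPoint, -8, 7), two q values per output byte.
void QuantizeS4Reference(const float* input, uint8_t* packed, size_t n,
                         float scale, int8_t zeroPoint)
{
    for (size_t i = 0; i < n; ++i) {
        float q = std::nearbyintf(input[i] / scale) + static_cast<float>(zeroPoint);
        q = std::min(std::max(q, -8.0f), 7.0f);             // Int4Traits<true> range
        const int8_t value = static_cast<int8_t>(q);

        const size_t byteIndex = i >> 1;                    // two elements per byte
        const unsigned shift = (i & 1) ? 4u : 0u;           // odd index -> high nibble
        packed[byteIndex] &= static_cast<uint8_t>(~(0xFu << shift));                        // clear lane
        packed[byteIndex] |= static_cast<uint8_t>((static_cast<unsigned>(value) & 0xFu) << shift);  // set lane
    }
}

int main()
{
    const float input[4] = {-1.5f, 0.0f, 2.0f, 100.0f};
    uint8_t packed[2] = {0, 0};
    QuantizeS4Reference(input, packed, 4, 0.5f, 0);
    assert(packed[0] == 0x0D);  // -3 -> 0xD in the low nibble, 0 -> 0x0 in the high nibble
    assert(packed[1] == 0x74);  //  4 -> low nibble, clamp(200) = 7 -> high nibble
    return 0;
}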
+// + +template +struct Int4Traits; + +template<> +struct Int4Traits { + using UnpackedType = int8_t; + static constexpr int8_t Min = -8; + static constexpr int8_t Max = 7; +}; + +template<> +struct Int4Traits { + using UnpackedType = uint8_t; + static constexpr int8_t Min = 0; + static constexpr int8_t Max = 15; +}; + +template +MLAS_FORCEINLINE +void +MlasSetInt4Element(uint8_t* Output, size_t ElemIndex, UnpackedType Value) +{ + static_assert(std::is_same_v || std::is_same_v); + + const size_t OutputIndex = ElemIndex >> 1; // which byte + const size_t NibbleIndex = ElemIndex & 0x1; // which 4-bit elem in the byte + const uint8_t Shift = static_cast(NibbleIndex << 2); // Either 0 or 4 + const uint8_t Mask = static_cast(0xF0 >> Shift); + uint8_t* Dst = &Output[OutputIndex]; + + *Dst &= Mask; // Clear 4-bit lane + *Dst |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane +} + +template +MLAS_FORCEINLINE +void +MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueHigh) +{ + static_assert(std::is_same_v || std::is_same_v); + *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); +} + diff --git a/third_party/mlas/lib/platform.cpp b/third_party/mlas/lib/platform.cpp index 25020316e7..d690345bc2 100644 --- a/third_party/mlas/lib/platform.cpp +++ b/third_party/mlas/lib/platform.cpp @@ -52,13 +52,30 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { - has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); // raw hack! Need CPUIDInfo implementation for more precise detection has_fp16_ = has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif @@ -112,6 +129,14 @@ MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[ #define _XCR_XFEATURE_ENABLED_MASK 0 #endif +#if !defined(XFEATURE_MASK_XTILE) +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#endif + inline uint64_t MlasReadExtendedControlRegister( @@ -143,11 +168,6 @@ bool MlasInitAMX() { #if defined(__linux__) -#define XFEATURE_XTILECFG 17 -#define XFEATURE_XTILEDATA 18 -#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) -#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) -#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) #define ARCH_GET_XCOMP_PERM 0x1022 #define ARCH_REQ_XCOMP_PERM 0x1023 @@ -172,6 +192,28 @@ MlasInitAMX() #endif // MLAS_TARGET_AMD64_IX86 +#ifdef MLAS_TARGET_LARCH64 + +#if defined(__linux__) +#include +#include +#endif +// +// Stores a vector to build a conditional load/store mask for vmaskmovps. +// + +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveLasx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 }; + +// +// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks. 
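The comment above is carried over from the x86 kernels; the table defined just below is the usual partial-vector mask table: eight all-ones words followed by eight zero words, so reading eight consecutive words starting at offset (8 - n) yields a mask whose first n lanes are all-ones, which a masked load/store can use to touch only the n remaining columns of a row. How the LoongArch kernels consume it is not shown in this hunk, so the sketch below only illustrates the indexing trick with plain scalars (names are illustrative):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Same layout as the table below: 8 all-ones lanes, then 8 zero lanes.
static const uint32_t MaskTable[16] = {
    0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
    0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
    0, 0, 0, 0, 0, 0, 0, 0,
};

int main()
{
    const size_t n = 3;                        // want to touch only 3 lanes
    const uint32_t* mask = &MaskTable[8 - n];  // 3 ones followed by 5 zeros
    for (size_t lane = 0; lane < 8; ++lane) {
        assert((mask[lane] != 0) == (lane < n));
    }
    return 0;
}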
+// + +MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableLasx[16], 32) = { + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, +}; + +#endif MLAS_PLATFORM::MLAS_PLATFORM( void ) @@ -231,6 +273,10 @@ Return Value: this->QLinearAddU8Kernel = MlasQLinearAddU8Kernel; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -330,6 +376,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; // // Check if the processor supports Hybrid core architecture. @@ -378,6 +425,7 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F; this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx512F; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelAvx512F; this->NchwcBlockSize = 16; @@ -394,6 +442,8 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; + this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; // // Check if the processor supports AVX512VNNI. @@ -405,22 +455,26 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; + this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } -#ifdef MLAS_AMX_SUPPORTED +#ifndef __APPLE__ // // Check if the processor supports AMX-TILE and AMX-INT8 // features. 
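On the AMX gate that follows: the new condition checks not only the CPUID feature bits but also that the OS has enabled the AMX tile state components in XCR0, since the tile registers are XSAVE-managed and unusable otherwise (MlasInitAMX still has to request the per-process permission from the Linux kernel afterwards). A hedged sketch of the same three-way test; ReadXcr0 is a stand-in for MlasReadExtendedControlRegister, and the bit positions are the architectural ones used in the hunk:

#include <cstdint>

// CPUID.(EAX=7,ECX=0):EDX bit 24 = AMX-TILE, bit 25 = AMX-INT8; XCR0 bits 17/18
// = XTILECFG/XTILEDATA state enabled by the OS. All three must hold before the
// AMX GEMM dispatch is selected.
constexpr uint64_t kXtileCfgMask  = 1ull << 17;
constexpr uint64_t kXtileDataMask = 1ull << 18;
constexpr uint64_t kXtileMask     = kXtileCfgMask | kXtileDataMask;

uint64_t ReadXcr0() { return 0; }  // stand-in; real code issues xgetbv via MlasReadExtendedControlRegister

bool AmxUsable(uint32_t cpuid7_edx)
{
    const bool hasAmxTile = (cpuid7_edx & (1u << 24)) != 0;
    const bool hasAmxInt8 = (cpuid7_edx & (1u << 25)) != 0;
    const bool osEnabled  = (ReadXcr0() & kXtileMask) == kXtileMask;
    return hasAmxTile && hasAmxInt8 && osEnabled;
}

int main() { return AmxUsable(0) ? 1 : 0; }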
// - if ((Cpuid7[3] & 0b1 << 24) != 0 && (Cpuid7[3] & 0b1 << 25) != 0) { + if ((Cpuid7[3] & 0b1 << 24) != 0 && + (Cpuid7[3] & 0b1 << 25) != 0 && + (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { if (MlasInitAMX()) { this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; } } -#endif // MLAS_AMX_SUPPORTED +#endif // __APPLE__ #endif // ORT_MINIMAL_BUILD @@ -435,7 +489,9 @@ Return Value: #if defined(MLAS_TARGET_ARM64) - this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchNeon; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchNeon; + this->GemmU8S8Dispatch = &MlasGemmX8S8DispatchNeon; + this->GemmS8S8Dispatch = &MlasGemmX8S8DispatchNeon; this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; @@ -448,18 +504,41 @@ Return Value: #if defined(_WIN32) HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#elif defined(__linux__) - HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #else - HasDotProductInstructions = false; + // Use the cpuinfo value which is read from sysctl and has some additional special cases. + // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips + // as well as failing on other ARM chips as it is an EL1 level register that requires extra + // privileges to read. + // + // uint64_t isar0_el1; + // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); + // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; + HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #endif if (HasDotProductInstructions) { - this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchUdot; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUdot; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUdot; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSdot; this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; + + // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + } + +#if defined(__linux__) + // + // Check if the processor supports ASIMD I8MM instructions. + // + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; } +#endif #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) @@ -467,6 +546,10 @@ Return Value: this->GemmDoubleKernel = MlasDgemmKernel; this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; #if defined(__linux__) unsigned long hwcap2 = getauxval(AT_HWCAP2); @@ -492,6 +575,63 @@ Return Value: #endif // __linux__ #endif // MLAS_TARGET_POWER +#if defined(MLAS_TARGET_LARCH64) + + // + // Default to the baseline LSX support. 
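The LoongArch initialization that follows reduces to a three-tier fallback driven by getauxval: prefer the 256-bit LASX kernels, fall back to the 128-bit LSX kernels, and otherwise keep the portable C++ routines. A condensed, Linux-only sketch of that selection; the enum and function names are illustrative, and the HWCAP_LOONGARCH_* bits are assumed to come from the kernel's asm/hwcap.h header as in the hunk below:

#include <sys/auxv.h>

// Query the kernel's hardware-capability bits once, then pick the widest
// SIMD tier available; everything else keeps the portable defaults.
enum class LoongArchSimdTier { Scalar, Lsx, Lasx };

LoongArchSimdTier DetectLoongArchSimd()
{
    const unsigned long hwcap = getauxval(AT_HWCAP);
#if defined(HWCAP_LOONGARCH_LASX) && defined(HWCAP_LOONGARCH_LSX)
    if (hwcap & HWCAP_LOONGARCH_LASX) {
        return LoongArchSimdTier::Lasx;   // 256-bit kernels (Lasx suffix)
    }
    if (hwcap & HWCAP_LOONGARCH_LSX) {
        return LoongArchSimdTier::Lsx;    // 128-bit kernels (LSX suffix)
    }
#else
    (void)hwcap;                          // bits unavailable in this kernel's headers
#endif
    return LoongArchSimdTier::Scalar;     // portable C++ fallbacks
}

int main() { return static_cast<int>(DetectLoongArchSimd()); }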
+ // + + int hwcap = getauxval(AT_HWCAP); + bool cap_lasx = hwcap & HWCAP_LOONGARCH_LASX; + bool cap_lsx = hwcap & HWCAP_LOONGARCH_LSX; + + if( cap_lasx ){ + this->GemmFloatKernel = MlasGemmFloatKernelLasx; + this->GemmDoubleKernel = MlasGemmDoubleKernelLasx; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLasx; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLasx; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLasx; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLasx; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLasx; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLasx; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLasx; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelLasx; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelLasx; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; + this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; + + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; + }else if( cap_lsx ){ + this->GemmFloatKernel = MlasGemmFloatKernelLSX; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; + this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4LSX; + this->GemmDoubleKernel = MlasGemmDoubleKernelLSX; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLSX; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLSX; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLSX; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLSX; + + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLSX; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLSX; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLSX; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + }else{ + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + } + + this->NchwcBlockSize = 8; + // this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; + + // this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; + +#endif // MLAS_TARGET_LARCH64 + } size_t diff --git a/third_party/mlas/lib/pooling.cpp b/third_party/mlas/lib/pooling.cpp index 12128f6c70..50dcf19224 100644 --- a/third_party/mlas/lib/pooling.cpp +++ b/third_party/mlas/lib/pooling.cpp @@ -1569,6 +1569,96 @@ Return Value: c -= 16; } +#elif defined(MLAS_LSX_INTRINSICS) + uint32_t val = 0x80808080; + const __m128i BitFlipVector = __lsx_vreplgr2vr_w(val); + if constexpr (std::is_unsigned::value) { + MLAS_UNREFERENCED_PARAMETER(BitFlipVector); + } + + while (c >= 32) { + + __m128i MaximumVector0 = __lsx_vldi(0); + __m128i MaximumVector1 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + __m128i InputVector1 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset + 16], 0); + + if constexpr (std::is_signed::value) { + InputVector0 = __lsx_vxor_v(InputVector0, 
BitFlipVector); + InputVector1 = __lsx_vxor_v(InputVector1, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + MaximumVector1 = __lsx_vmax_bu(MaximumVector1, InputVector1); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + MaximumVector1 = __lsx_vxor_v(MaximumVector1, BitFlipVector); + } + + __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); + __lsx_vst(MaximumVector1, (__m128i*)&Output[16], 0); + Output += 32; + + ChannelOffset += 32; + c -= 32; + } + + while (c >= 16) { + + __m128i MaximumVector0 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + + if constexpr (std::is_signed::value){ + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + } + + __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); + Output += 16; + + ChannelOffset += 16; + c -= 16; + } + + if (c >= 8) { + + __m128i MaximumVector0 = __lsx_vldi(0); + + for (size_t k = 0; k < KernelSize; k++) { + + __m128i InputVector0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0), 0, 1); + + if constexpr (std::is_signed::value){ + InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); + } + + MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); + } + + if constexpr (std::is_signed::value) { + MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); + } + + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i*)&Output[0] , 0), __lsx_vpickve2gr_d(MaximumVector0, 0), 0), (__m128i*)&Output[0], 0); + Output += 8; + + ChannelOffset += 8; + c -= 8; + } #endif while (c > 0) { diff --git a/third_party/mlas/lib/power/QuantizePower.cpp b/third_party/mlas/lib/power/QuantizePower.cpp index 0d38288c6d..2d4d791c3a 100644 --- a/third_party/mlas/lib/power/QuantizePower.cpp +++ b/third_party/mlas/lib/power/QuantizePower.cpp @@ -1,6 +1,10 @@ +#include #include "mlasi.h" #include +// NOTE: Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions. +// ONNX Runtime CI pipelines do not build with all compiler versions. + template void MLASCALL @@ -82,8 +86,15 @@ Return Value: auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); - auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, (int8_t *) Output); + + if constexpr (std::is_same_v || std::is_same_v) { + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *)Output); + } else { + static_assert(std::is_same_v || std::is_same_v); + vec_xst(ShortVector0, 0, (int16_t *)Output); + vec_xst(ShortVector1, 0, (int16_t *)&Output[8]); + } Output += 16; Input += 16; @@ -99,6 +110,119 @@ Return Value: } } +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer as int4 using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. Contains packed 4-bit elements. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. 
+ + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values. + UnpackedType TmpOutput[16] = {}; + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = vec_div(FloatVector0, ScaleVector); + FloatVector1 = vec_div(FloatVector1, ScaleVector); + FloatVector2 = vec_div(FloatVector2, ScaleVector); + FloatVector3 = vec_div(FloatVector3, ScaleVector); + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = vec_add(FloatVector0, ZeroPointVector); + FloatVector1 = vec_add(FloatVector1, ZeroPointVector); + FloatVector2 = vec_add(FloatVector2, ZeroPointVector); + FloatVector3 = vec_add(FloatVector3, ZeroPointVector); + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0])); + + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); + MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); + MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]); + MlasPackInt4Elements(Output++, TmpOutput[8], TmpOutput[9]); + MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]); + MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]); + MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]); + + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + void MLASCALL MlasQuantizeLinearU8Kernel( @@ -124,3 +248,56 @@ MlasQuantizeLinearS8Kernel( { MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } + +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + 
MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + diff --git a/third_party/mlas/lib/power/qgemm_kernel_power10.cpp b/third_party/mlas/lib/power/qgemm_kernel_power10.cpp index 633349e800..a67be1dbfa 100644 --- a/third_party/mlas/lib/power/qgemm_kernel_power10.cpp +++ b/third_party/mlas/lib/power/qgemm_kernel_power10.cpp @@ -67,7 +67,7 @@ MlasGemmQuantFixupZeroPointB( } -template +template void MlasGemmQuantCopyPackA8x8( MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType* D, @@ -75,11 +75,10 @@ MlasGemmQuantCopyPackA8x8( size_t lda, size_t CountM, size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned + int32_t* RowSumBuffer ) { - const uint8_t Flip = (AIsSigned ? 0 : 0x80); + constexpr uint8_t Flip = (AIsSigned ? 0 : 0x80); Vtype vmask = reinterpret_cast(vec_splats(Flip)); typedef __vector signed char vec_t; @@ -106,66 +105,74 @@ MlasGemmQuantCopyPackA8x8( Vtype a3 = *reinterpret_cast(&a[lda * 2]); Vtype a4 = *reinterpret_cast(&a[lda * 3]); Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); Vtype vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi (vx, vx1, 0); - Vtype vx5 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi (vx, vx1, 3); - Vtype vx7 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); a1 = *reinterpret_cast(&a[lda*4]); a2 = *reinterpret_cast(&a[lda*5]); a3 = *reinterpret_cast(&a[lda*6]); a4 = *reinterpret_cast(&a[lda*7]); vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx8 = vec_xxpermdi (vx, vx1, 0); - Vtype 
vx9 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx10 = vec_xxpermdi (vx, vx1, 3); - Vtype vx11 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx8 = vec_xxpermdi(vx, vx1, 0); + Vtype vx9 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx10 = vec_xxpermdi(vx, vx1, 3); + Vtype vx11 = vec_xxpermdi(vx2, vx3, 3); vec_t vxx = - reinterpret_cast(vec_sub (vx4, vmask)); - vsum = vec_sum4s (vxx, vsum); + AIsSigned ? reinterpret_cast(vx4) : + reinterpret_cast(vec_sub(vx4, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[0]) = vxx; - vxx = reinterpret_cast(vec_sub (vx5, vmask)); - vsum = vec_sum4s (vxx, vsum); + vxx = AIsSigned ? reinterpret_cast(vx5) : + reinterpret_cast(vec_sub(vx5, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[16]) = vxx; - vxx = reinterpret_cast(vec_sub (vx6, vmask)); - vsum = vec_sum4s (vxx, vsum); + vxx = AIsSigned ? reinterpret_cast(vx6) : + reinterpret_cast(vec_sub(vx6, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[32]) = vxx; - vxx = reinterpret_cast(vec_sub (vx7, vmask)); - vsum = vec_sum4s (vxx, vsum); + vxx = AIsSigned ? reinterpret_cast(vx7) : + reinterpret_cast(vec_sub(vx7, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[48]) = vxx; - vxx = reinterpret_cast(vec_sub (vx8, vmask)); + vxx = AIsSigned ? reinterpret_cast(vx8) : + reinterpret_cast(vec_sub(vx8, vmask)); *reinterpret_cast(&D[64]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); - vxx = reinterpret_cast(vec_sub (vx9, vmask)); + vsum2 = vec_sum4s(vxx, vsum2); + vxx = AIsSigned ? reinterpret_cast(vx9) : + reinterpret_cast(vec_sub(vx9, vmask)); *reinterpret_cast(&D[80]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); - vxx = reinterpret_cast(vec_sub (vx10, vmask)); + vsum2 = vec_sum4s(vxx, vsum2); + vxx = AIsSigned ? reinterpret_cast(vx10) : + reinterpret_cast(vec_sub(vx10, vmask)); *reinterpret_cast(&D[96]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); - vxx = reinterpret_cast(vec_sub (vx11, vmask)); + vsum2 = vec_sum4s(vxx, vsum2); + vxx = AIsSigned ? reinterpret_cast(vx11) : + reinterpret_cast(vec_sub(vx11, vmask)); *reinterpret_cast(&D[112]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); + vsum2 = vec_sum4s(vxx, vsum2); D += 16 * 8; a += 16; y -= 16; @@ -179,16 +186,18 @@ MlasGemmQuantCopyPackA8x8( int a4 = *reinterpret_cast(&a[lda*3]); __vector int vx1 = { a1, a2, a3, a4}; vec_t vx = - reinterpret_cast(vec_sub (reinterpret_cast(vx1), vmask)); - vsum = vec_sum4s (vx, vsum); + AIsSigned ? reinterpret_cast(vx1) : + reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); + vsum = vec_sum4s(vx, vsum); *reinterpret_cast(&D[0]) = vx; a1 = *reinterpret_cast(&a[lda*4]); a2 = *reinterpret_cast(&a[lda*5]); a3 = *reinterpret_cast(&a[lda*6]); a4 = *reinterpret_cast(&a[lda*7]); __vector int vx2 = { a1, a2, a3, a4}; - vx = reinterpret_cast(vec_sub (reinterpret_cast(vx2), vmask)); - vsum2 = vec_sum4s (vx, vsum2); + vx = AIsSigned ? 
reinterpret_cast(vx2) : + reinterpret_cast(vec_sub(reinterpret_cast(vx2), vmask)); + vsum2 = vec_sum4s(vx, vsum2); if (CountK & 3) { if (yval >= 12) { *reinterpret_cast(&D[64]) = vx; @@ -225,10 +234,10 @@ MlasGemmQuantCopyPackA8x8( } if (y >= 1) { - Vtype a1 = reinterpret_cast(vec_splats(Flip)); - Vtype a2 = reinterpret_cast(vec_splats(Flip)); - Vtype a3 = reinterpret_cast(vec_splats(Flip)); - Vtype a4 = reinterpret_cast(vec_splats(Flip)); + Vtype a1 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; + Vtype a4 = vmask; a1[0] = a[0]; a2[0] = a[lda]; a3[0] = a[lda * 2]; @@ -246,20 +255,21 @@ MlasGemmQuantCopyPackA8x8( a4[2] = a[lda * 3 + 2]; } Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx2 = vec_xxpermdi (vx, vx1, 0); + Vtype vx2 = vec_xxpermdi(vx, vx1, 0); vec_t vx3 = - reinterpret_cast(vec_sub (vx2, vmask)); - vsum = vec_sum4s (vx3, vsum); + AIsSigned ? reinterpret_cast(vx2) : + reinterpret_cast(vec_sub(vx2, vmask)); + vsum = vec_sum4s(vx3, vsum); *reinterpret_cast(&D[0]) = vx3; - a1 = reinterpret_cast(vec_splats(Flip)); - a2 = reinterpret_cast(vec_splats(Flip)); - a3 = reinterpret_cast(vec_splats(Flip)); - a4 = reinterpret_cast(vec_splats(Flip)); + a1 = vmask; + a2 = vmask; + a3 = vmask; + a4 = vmask; a1[0] = a[lda * 4]; a2[0] = a[lda * 5]; a3[0] = a[lda * 6]; @@ -277,14 +287,15 @@ MlasGemmQuantCopyPackA8x8( a4[2] = a[lda * 7 + 2]; } vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - vx2 = vec_xxpermdi (vx, vx1, 0); - vx3 = reinterpret_cast(vec_sub (vx2, vmask)); - vsum2 = vec_sum4s (vx3, vsum2); + vx2 = vec_xxpermdi(vx, vx1, 0); + vx3 = AIsSigned ? 
reinterpret_cast(vx2) : + reinterpret_cast(vec_sub(vx2, vmask)); + vsum2 = vec_sum4s(vx3, vsum2); if (CountK % 16 >= 12) { *reinterpret_cast(&D[64]) = vx3; D += 80; @@ -327,34 +338,38 @@ MlasGemmQuantCopyPackA8x8( Vtype a3 = *reinterpret_cast(&a[lda * 2]); Vtype a4 = *reinterpret_cast(&a[lda * 3]); Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); Vtype vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi (vx, vx1, 0); - Vtype vx5 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi (vx, vx1, 3); - Vtype vx7 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); vec_t vx0 = - reinterpret_cast(vec_sub (vx4, vmask)); + AIsSigned ? reinterpret_cast(vx4) : + reinterpret_cast(vec_sub(vx4, vmask)); *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx5, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx5) : + reinterpret_cast(vec_sub(vx5, vmask)); *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx6, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx6) : + reinterpret_cast(vec_sub(vx6, vmask)); *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx7, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx7) : + reinterpret_cast(vec_sub(vx7, vmask)); *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s (vx0, vsum); + vsum = vec_sum4s(vx0, vsum); D += 16 * 4; a += 16; y -= 16; @@ -367,16 +382,17 @@ MlasGemmQuantCopyPackA8x8( int a4 = *reinterpret_cast(&a[lda*3]); __vector int vx1 = { a1, a2, a3, a4}; vec_t vx = - reinterpret_cast(vec_sub (reinterpret_cast(vx1), vmask)); + AIsSigned ? reinterpret_cast(vx1) : + reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s (vx, vsum); + vsum = vec_sum4s(vx, vsum); D += 16; a += 4; y -= 4; } if (y >= 1) { - Vtype vx = reinterpret_cast(vec_splats(Flip)); + Vtype vx = vmask; vx[0] = a[0]; vx[4] = a[lda]; vx[8] = a[lda * 2]; @@ -394,9 +410,10 @@ MlasGemmQuantCopyPackA8x8( vx[14] = a[lda * 3 + 2]; } vec_t vx1 = - reinterpret_cast(vec_sub (vx, vmask)); + AIsSigned ? 
reinterpret_cast(vx) : + reinterpret_cast(vec_sub(vx, vmask)); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; a += 16; } @@ -416,9 +433,9 @@ MlasGemmQuantCopyPackA8x8( __vector signed int vsum = { 0 }; while (y >= 16) { - Vtype a4 = reinterpret_cast(vec_splats(Flip)); - Vtype a2 = reinterpret_cast(vec_splats(Flip)); - Vtype a3 = reinterpret_cast(vec_splats(Flip)); + Vtype a4 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; Vtype a1 = *reinterpret_cast(&a[0]); if (CountM == 3) { a3 = *reinterpret_cast(&a[lda * 2]); @@ -427,53 +444,58 @@ MlasGemmQuantCopyPackA8x8( a2 = *reinterpret_cast(&a[lda]); } Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); Vtype vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi (vx, vx1, 0); - Vtype vx5 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi (vx, vx1, 3); - Vtype vx7 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); vec_t vx0 = - reinterpret_cast(vec_sub (vx4, vmask)); + AIsSigned ? reinterpret_cast(vx4) : + reinterpret_cast(vec_sub(vx4, vmask)); *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx5, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx5) : + reinterpret_cast(vec_sub(vx5, vmask)); *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx6, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx6) : + reinterpret_cast(vec_sub(vx6, vmask)); *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx7, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx7) : + reinterpret_cast(vec_sub(vx7, vmask)); *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s (vx0, vsum); + vsum = vec_sum4s(vx0, vsum); D += 16 * 4; a += 16; y -= 16; } while (y >= 4) { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; __vector int vx1 = reinterpret_cast<__vector int>(vb); vx1[0] = *reinterpret_cast(&a[0]); - if(CountM >= 2) { + if (CountM >= 2) { vx1[1] = *reinterpret_cast(&a[lda]); } - if(CountM >= 3) { + if (CountM >= 3) { vx1[2] = *reinterpret_cast(&a[lda*2]); } vec_t vx = - reinterpret_cast(vec_sub (reinterpret_cast(vx1), vmask)); + AIsSigned ? 
reinterpret_cast(vx1) : + reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s (vx, vsum); + vsum = vec_sum4s(vx, vsum); D += 16; a += 4; y -= 4; @@ -508,7 +530,7 @@ MlasGemmQuantCopyPackA8x8( } } *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s (vx, vsum); + vsum = vec_sum4s(vx, vsum); D += 16; } *RowSumBuffer++ = vsum[0]; @@ -521,7 +543,7 @@ MlasGemmQuantCopyPackA8x8( } } -template +template void MlasGemmQuantCopyPackB8x8( MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType* D, @@ -529,29 +551,128 @@ MlasGemmQuantCopyPackB8x8( size_t ldb, size_t CountN, size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned + int32_t* ColumnSumBuffer ) { - const uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); + [[maybe_unused]] constexpr uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); typedef __vector unsigned char vec_t; Vtype vmask = reinterpret_cast(vec_splats(BitFlipValue)); vec_t mask = {0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15}; - const int8_t Flip = (BIsSigned ? -128 : 0); - // Process 4 columns of matrix B in a loop. - // // Copy columns from matrix B to the packed buffer. Signed buffers are // converted to unsigned buffers in order to share a common kernel. // // If CountK is not aligned to a multiple of four, then the packed buffer // is padded with zero vectors. - while (CountN >= 4) { + // Process 16 columns of matrix B in a loop. + // + size_t PackedK = ((CountK + 4 - 1) / 4) * 16; + size_t k2 = PackedK; + size_t k3 = PackedK*2; + size_t k4 = PackedK*3; + + while (CountN >= 16) { const uint8_t* b = B; __vector unsigned int vsum = {0}; + __vector unsigned int vsum2 = {0}; + __vector unsigned int vsum3 = {0}; + __vector unsigned int vsum4 = {0}; size_t y = CountK; - if(y >= 4) { + if (y >= 4) { + do { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = *reinterpret_cast(&b[ldb]); + Vtype b3 = *reinterpret_cast(&b[ldb*2]); + Vtype b4 = *reinterpret_cast(&b[ldb*3]); + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(b1, vmask)) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(vec_add(b2, vmask)) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? reinterpret_cast(vec_add(b3, vmask)) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? reinterpret_cast(vec_add(b4, vmask)) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum = vec_sum4s(vx1, vsum); + vsum2 = vec_sum4s(vx2, vsum2); + vsum3 = vec_sum4s(vx3, vsum3); + vsum4 = vec_sum4s(vx4, vsum4); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = (y >= 2) ? *reinterpret_cast(&b[ldb]) : vmask; + Vtype b3 = (y >= 3) ? *reinterpret_cast(&b[ldb*2]) : vmask; + Vtype b4 = vmask; + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(b1, vmask)) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(vec_add(b2, vmask)) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? 
reinterpret_cast(vec_add(b3, vmask)) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? reinterpret_cast(vec_add(b4, vmask)) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum = vec_sum4s(vx1, vsum); + vsum2 = vec_sum4s(vx2, vsum2); + vsum3 = vec_sum4s(vx3, vsum3); + vsum4 = vec_sum4s(vx4, vsum4); + D += 16; + } + *ColumnSumBuffer++ = vsum[0]; + *ColumnSumBuffer++ = vsum[1]; + *ColumnSumBuffer++ = vsum[2]; + *ColumnSumBuffer++ = vsum[3]; + *ColumnSumBuffer++ = vsum2[0]; + *ColumnSumBuffer++ = vsum2[1]; + *ColumnSumBuffer++ = vsum2[2]; + *ColumnSumBuffer++ = vsum2[3]; + *ColumnSumBuffer++ = vsum3[0]; + *ColumnSumBuffer++ = vsum3[1]; + *ColumnSumBuffer++ = vsum3[2]; + *ColumnSumBuffer++ = vsum3[3]; + *ColumnSumBuffer++ = vsum4[0]; + *ColumnSumBuffer++ = vsum4[1]; + *ColumnSumBuffer++ = vsum4[2]; + *ColumnSumBuffer++ = vsum4[3]; + B += 16; + CountN -= 16; + D += k4; + } + + // Process four columns of matrix B in a loop. + // + while (CountN >= 4) { + const uint8_t* b = B; + __vector unsigned int vsum = {0}; + size_t y = CountK; + if (y >= 4) { do { int b1 = *reinterpret_cast(&b[0]); int b2 = *reinterpret_cast(&b[ldb]); @@ -559,28 +680,30 @@ MlasGemmQuantCopyPackB8x8( int b4 = *reinterpret_cast(&b[ldb*3]); __vector int vb = {b1, b2, b3, b4}; Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; b += ldb*4; y -= 4; } while (y >= 4); } if (y >= 1) { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; __vector int vb1 = reinterpret_cast<__vector int>(vb); vb1[0] = *reinterpret_cast(&b[0]); - if( y >= 2) { + if (y >= 2) { vb1[1] = *reinterpret_cast(&b[ldb]); } - if( y >= 3) { + if (y >= 3) { vb1[2] = *reinterpret_cast(&b[ldb*2]); } Vtype vx = vec_perm(reinterpret_cast(vb1), reinterpret_cast(vb1), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; } *ColumnSumBuffer++ = vsum[0]; @@ -600,7 +723,7 @@ MlasGemmQuantCopyPackB8x8( size_t y = CountK; if (y >= 4) { do { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; if (CountN == 1) { vb[0] = b[0]; vb[4] = b[ldb]; @@ -632,16 +755,17 @@ MlasGemmQuantCopyPackB8x8( vb[14] = b[ldb*3+2]; } Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; b += ldb*4; y -= 4; } while (y >= 4); } if (y >= 1) { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; if (CountN == 1) { vb[0]= b[0]; if (y >= 2) { @@ -679,9 +803,10 @@ MlasGemmQuantCopyPackB8x8( } } Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? 
reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; } *ColumnSumBuffer++ = vsum[0]; @@ -707,9 +832,9 @@ MlasGemmQuantCopyPackA( ) { if (AIsSigned) { - MlasGemmQuantCopyPackA8x8<__vector signed char>(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned); + MlasGemmQuantCopyPackA8x8<__vector signed char, true>(D, A, lda, CountM, CountK, RowSumBuffer); } else { - MlasGemmQuantCopyPackA8x8<__vector unsigned char>(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned); + MlasGemmQuantCopyPackA8x8<__vector unsigned char, false>(D, A, lda, CountM, CountK, RowSumBuffer); } } template<> @@ -725,9 +850,9 @@ MlasGemmQuantCopyPackB( ) { if (BIsSigned) { - MlasGemmQuantCopyPackB8x8<__vector signed char>(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); + MlasGemmQuantCopyPackB8x8<__vector signed char, true>(D, B, ldb, CountN, CountK, ColumnSumBuffer); } else { - MlasGemmQuantCopyPackB8x8< __vector unsigned char>(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); + MlasGemmQuantCopyPackB8x8< __vector unsigned char, false>(D, B, ldb, CountN, CountK, ColumnSumBuffer); } } @@ -747,46 +872,93 @@ MlasQgemmStoreVectorMMA int pos ) { - __vector int *rowC; - __vector signed int vsum = {0}; + size_t RowCount; + __vector signed int vsum0, vsum1, vsum2, vsum3; + __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); + C += VectorCount; if (ZeroPointB != nullptr) { + __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); if (ZeroMode) { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] * ZeroPointB[pos] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] * ZeroPointB[pos+1] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] * ZeroPointB[pos+2] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] * ZeroPointB[pos+3] + ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] = *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } else { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] * ZeroPointB[pos] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] * ZeroPointB[pos+1] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] * ZeroPointB[pos+2] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] * ZeroPointB[pos+3] + 
ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] += *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } } else { if (ZeroMode) { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] = *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } else { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] += *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + 
*reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } } @@ -846,36 +1018,36 @@ MlasQgemmComputeMMA( ) { if (CountK == 16) { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp (acc0, va[1], vb[1]); - __builtin_mma_xvi8ger4pp (acc0, va[2], vb[2]); - __builtin_mma_xvi8ger4pp (acc0, va[3], vb[3]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); + __builtin_mma_xvi8ger4pp(acc0, va[2], vb[2]); + __builtin_mma_xvi8ger4pp(acc0, va[3], vb[3]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[4], vb[0]); - __builtin_mma_xvi8ger4pp (acc1, va[5], vb[1]); - __builtin_mma_xvi8ger4pp (acc1, va[6], vb[2]); - __builtin_mma_xvi8ger4pp (acc1, va[7], vb[3]); + __builtin_mma_xvi8ger4pp(acc1, va[4], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[5], vb[1]); + __builtin_mma_xvi8ger4pp(acc1, va[6], vb[2]); + __builtin_mma_xvi8ger4pp(acc1, va[7], vb[3]); } } else if (CountK == 12) { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp (acc0, va[1], vb[1]); - __builtin_mma_xvi8ger4pp (acc0, va[2], vb[2]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); + __builtin_mma_xvi8ger4pp(acc0, va[2], vb[2]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[3], vb[0]); - __builtin_mma_xvi8ger4pp (acc1, va[4], vb[1]); - __builtin_mma_xvi8ger4pp (acc1, va[5], vb[2]); + __builtin_mma_xvi8ger4pp(acc1, va[3], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[4], vb[1]); + __builtin_mma_xvi8ger4pp(acc1, va[5], vb[2]); } } else if (CountK == 8) { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp (acc0, va[1], vb[1]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[2], vb[0]); - __builtin_mma_xvi8ger4pp (acc1, va[3], vb[1]); + __builtin_mma_xvi8ger4pp(acc1, va[2], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[3], vb[1]); } } else { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[1], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[1], vb[0]); } } }; @@ -902,7 +1074,7 @@ MlasGemmQuantKernel( if (Mval >= 8) { Mval = 4; } - while(CountN > 0) { + while (CountN > 0) { const int8_t *a = A; typedef __vector unsigned char vec_t; const uint8_t *b = B; @@ -1057,23 +1229,23 @@ MlasGemmQuantKernel( } // Store matrix C with accumulator result. 
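// For reference, the correction MlasQgemmStoreVectorMMA applies to each 4-wide
// output tile is, element-wise (a sketch only; '=' is used when ZeroMode is set,
// '+=' otherwise):
//   with per-column zero points:  C[i][j] (+)= result[i][j] + RowSum[i] * ZeroPointB[j] + ColSum[j]
//   with ZeroPointB == nullptr:   C[i][j] (+)= result[i][j] + RowSum[i] + ColSum[j]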
if (CountN >=16) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc2); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); MlasQgemmStoreVectorMMA<8>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc3); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc3); MlasQgemmStoreVectorMMA<12>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc6); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc6); MlasQgemmStoreVectorMMA<8>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc7); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc7); MlasQgemmStoreVectorMMA<12>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 12); } INC_BUFFER(16); @@ -1082,72 +1254,72 @@ MlasGemmQuantKernel( C += 16; } else { if (CountN >=12 ) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc2); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); MlasQgemmStoreVectorMMA<8>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc6); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc6); MlasQgemmStoreVectorMMA<8>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); } INC_BUFFER(12); 
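// A minimal, self-contained illustration of the MMA accumulate/extract pattern
// used by this kernel (POWER10, compile with -mcpu=power10). Kept as a comment
// so it does not interfere with the kernel; the helper name is hypothetical:
//
//   #include <altivec.h>
//   typedef __vector unsigned char vec_t;
//   void Int8Tile4x4(vec_t va, vec_t vb, int32_t out[4][4]) {
//       __vector_quad acc;
//       __builtin_mma_xxsetaccz(&acc);            // zero the 4x4 int32 accumulator
//       __builtin_mma_xvi8ger4pp(&acc, va, vb);   // int8 rank-4 outer-product update
//       __builtin_mma_disassemble_acc(out, &acc); // copy out the four accumulator rows
//   }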
if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc3); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc3); if (CountM >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc7); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc7); } } CountN -= 12; C += 12; } else if (CountN >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); } INC_BUFFER(8); if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc2); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); if (CountM >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc6); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc6); } } CountN -= 8; C += 8; } else if (CountN >= 4) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc5); } } INC_BUFFER(4); if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); } CountN -= 4; C += 4; } else { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); if (CountM >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc4); } } CountN &= 3; diff --git a/third_party/mlas/lib/q4_dq.cpp b/third_party/mlas/lib/q4_dq.cpp new file mode 100644 index 0000000000..7df6852f84 --- /dev/null +++ b/third_party/mlas/lib/q4_dq.cpp @@ -0,0 +1,1684 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4_dq.cpp + +Abstract: + + This module contains the data structures and implementations + for blocked int4 quantization and dequantization. + + Int4 block quantization is used to compress weight tensors of large + language models. 
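    As an illustration of the blob layout produced below: each block of BlkLen
    weights stores its float scale (plus, for the MLAS_Q4TYPE_BLK1 variant, a
    zero point) followed by BlkLen/2 bytes of packed weights, where byte l of
    every 32-value group holds value l in the low nibble and value l + 16 in
    the high nibble.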
+ +--*/ + +#include "q4common.h" +#include + +template +constexpr size_t BlkQ4BufSize(size_t N, size_t K) { + const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); + return N * KBlocks * T::BlobSize; +} + +size_t MLASCALL MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K) { + if (GetMlasPlatform().FpQ4GemmDispatch == nullptr) { + return 0; + } + + switch (QType) { + case BlkQ4Sym: + return BlkQ4BufSize(N, K); + case BlkQ4Sym64: + return BlkQ4BufSize(N, K); + case BlkQ4Sym128: + return BlkQ4BufSize(N, K); + default: + return BlkQ4BufSize(N, K); + } +} + +template +MLAS_FORCEINLINE void MlasQ4GemmPackBImpl(void *PackedBuf, const float *FpData, size_t N, size_t K, size_t ldb) { + auto *dst_ptr = reinterpret_cast(PackedBuf); + + for (size_t n = 0; n < N; n++) { + const float *src = FpData; // starting from top of the column + + for (size_t k = 0; k < K; k += T::BlkLen) { + size_t klen = std::min(size_t(T::BlkLen), K - k); + float amax = 0.0f; // abs(max) + float max = 0.0f; + + for (size_t l = 0; l < klen; l++) { + const float v = src[ldb * l]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float scale = max / (-8); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + MlasQ4BlkScale(dst_ptr) = scale; + uint8_t *data = MlasQ4BlkData(dst_ptr); + + for (size_t kk = 0; kk < klen; kk += 32) { + size_t kklen = std::min((size_t)32, klen - kk); + for (size_t l = 0; l < 16; l++) { + const float v0 = l < kklen ? src[ldb * (kk + l)] * reciprocal_scale : 0; + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, v0 + 8.5f)); + + const size_t l1 = l + 16; + const float v1 = (l1 < kklen) ? src[ldb * (kk + l1)] * reciprocal_scale : 0; + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, v1 + 8.5f)); + + data[l] = vi0 | (vi1 << 4); + } + data += 16; + } + + // Move to next block of values in this column + dst_ptr += T::BlobSize; + src += ldb * klen; + } + + FpData++; // move to next column + } +} + +template <> +MLAS_FORCEINLINE void MlasQ4GemmPackBImpl(void *PackedBuf, const float *FpData, size_t N, size_t K, size_t ldb) { + auto *dst_ptr = reinterpret_cast(PackedBuf); + + for (size_t n = 0; n < N; n++) { + const float *src = FpData; // starting from top of the column + + for (size_t k = 0; k < K; k += MLAS_Q4TYPE_BLK1::BlkLen) { + size_t klen = std::min(MLAS_Q4TYPE_BLK1::BlkLen, K - k); + float min = std::numeric_limits::max(); + float max = -min; + + for (size_t l = 0; l < klen; l++) { + const float v = src[ldb * l]; + if (v < min) + min = v; + if (v > max) + max = v; + } + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + const float scale = (max - min) / ((1 << 4) - 1); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + float zero_point_fp = min; + if (scale != 0.0f) { + zero_point_fp = 0.f - min / scale; + } + + // Handle any clamping + uint8_t &zp = MlasQ4BlkZeroPoint(dst_ptr); + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > 15.0f) { + zp = 15; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + MlasQ4BlkScale(dst_ptr) = scale; + uint8_t *data = MlasQ4BlkData(dst_ptr); + + for (size_t kk = 0; kk < klen; kk += 32) { + size_t kklen = std::min((size_t)32, klen - kk); + for (size_t l = 0; l < 16; l++) { + const float v0 = l < kklen ? src[ldb * (kk + l)] : 0; + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); + + const size_t l1 = l + 16; + const float v1 = (l1 < kklen) ? 
src[ldb * (kk + l1)] : 0; + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); + + data[l] = vi0 | (vi1 << 4); + } + data += 16; + } + // move to next block of values in this column + dst_ptr += MLAS_Q4TYPE_BLK1::BlobSize; + src += ldb * klen; + } + FpData++; // move to next column + } +} + +void MLASCALL MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void *PackedBuf, const float *FpData, size_t N, size_t K, size_t ldb) { + switch (QType) { + case BlkQ4Sym: + return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); + case BlkQ4Sym64: + return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); + case BlkQ4Sym128: + return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); + default: + return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); + } +} + +template +MLAS_FORCEINLINE void MlasQ4GemmUnPackBImpl(float *FpData, const void *PackedBuf, size_t N, size_t K, size_t ldb) { + const auto *src = reinterpret_cast(PackedBuf); + for (size_t n = 0; n < N; n++) { + for (size_t k = 0; k < K; k += T::BlkLen) { + size_t CountK = std::min(K - k, T::BlkLen); + + float *dest = FpData + ldb * k + n; + const float scale = MlasQ4BlkScale(src); + const uint8_t *data = MlasQ4BlkData(src); + + for (size_t kk = 0; kk < CountK; kk += 32) { + size_t kklen = std::min((size_t)32, CountK - kk); + for (size_t l = 0; l < 16; l++) { + const uint8_t vi = data[l]; + + if (l < kklen) { + const int vi0 = (vi & 0x0F) - 8; + const float v0 = vi0 * scale; + dest[ldb * (kk + l)] = v0; + } + + const size_t l1 = l + 16; + if (l1 < kklen) { + const int vi1 = (vi >> 4) - 8; + const float v1 = vi1 * scale; + dest[ldb * (kk + l1)] = v1; + } + } + data += 16; + } + src += T::BlobSize; + } + } +} + +template <> +MLAS_FORCEINLINE void MlasQ4GemmUnPackBImpl(float *FpData, const void *PackedBuf, size_t N, size_t K, size_t ldb) { + const auto *src = reinterpret_cast(PackedBuf); + for (size_t n = 0; n < N; n++) { + for (size_t k = 0; k < K; k += MLAS_Q4TYPE_BLK1::BlkLen) { + size_t CountK = std::min(K - k, MLAS_Q4TYPE_BLK1::BlkLen); + + float *dest = FpData + ldb * k + n; + const float s = MlasQ4BlkScale(src); + const uint8_t z = MlasQ4BlkZeroPoint(src); + const uint8_t *pp = MlasQ4BlkData(src); + + for (size_t kk = 0; kk < CountK; kk += 32) { + size_t kklen = std::min((size_t)32, CountK - kk); + for (size_t l = 0; l < 16; l++) { + const uint8_t vi = pp[l]; + + if (l < kklen) { + const int8_t vi0 = vi & 0x0F; + const float v0 = (vi0 - z) * s; + dest[ldb * (kk + l)] = v0; + } + + size_t l1 = l + 16; + if (l1 < kklen) { + const int8_t vi1 = vi >> 4; + const float v1 = (vi1 - z) * s; + dest[ldb * (kk + l1)] = v1; + } + } + pp += 16; + } + src += MLAS_Q4TYPE_BLK1::BlobSize; + } + } +} + +void MLASCALL MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float *FpData, const void *PackedBuf, size_t N, size_t K, size_t ldb) { + switch (QType) { + case BlkQ4Sym: + return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); + case BlkQ4Sym64: + return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); + case BlkQ4Sym128: + return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); + default: + return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); + } +} + +/*************************************************************** + * The quantization format that pack data and quantization + * parameters into separate buffers. 
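 *
 * As a worked illustration of the asymmetric mapping used by this format:
 * given block values spanning [min, max] (widened to include 0),
 *     scale      = (max - min) / 15
 *     zero_point = clamp(round(-min / scale), 0, 15)
 * and each weight v is stored as clamp(round(v / scale) + zero_point, 0, 15),
 * which dequantizes back to (q - zero_point) * scale. The symmetric mapping
 * instead uses scale = m / -8, where m is the block value with the largest
 * magnitude, and no zero point.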
+ */ + +template +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +template +struct BitsTraits { + static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); + + static constexpr int kBits = qbits; + static constexpr int kMax = (1 << qbits) - 1; + static constexpr int kMid = 1 << (qbits - 1); + static constexpr float kMaxFp = static_cast(kMax); + + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); +}; + +/** + * @brief Rectify min/max from a set of weights, and convert to scale and zero point + * for quantization + * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @param[in] min + * @param[in] max + * @param[out] scale + * @param[out] zp + */ +template +MLAS_FORCEINLINE void range2scalezp(float min, float max, ScaleT &scale, uint8_t &zp) { + constexpr int zp_max = BitsTraits::kMax; + constexpr float zp_max_fp = BitsTraits::kMaxFp; + + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + float scale_f = (max - min) / zp_max; + + float zero_point_fp = min; + if (scale_f != 0.0f) { + zero_point_fp = 0.f - min / scale_f; + } + + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > zp_max_fp) { + zp = zp_max; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + scale = ScaleT(scale_f); +} + +template +MLAS_FORCEINLINE void range2scale(float min, float max, ScaleT &scale) { + constexpr int mid_v = BitsTraits::kMid; + constexpr float mid_fp = static_cast(-mid_v); + + max = fabsf(max) > fabsf(min) ? max : min; + + scale = ScaleT(max / mid_fp); +}; + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, e.g. 
fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template +struct BlockwiseQuantizer { + // To support other qbits, need to add bit packing code for + // storing to dst and zero points + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + + static MLAS_FORCEINLINE void quantizeMetaShape(int rows, int columns, int &meta_rows, int &meta_cols) { + meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + } + + static MLAS_FORCEINLINE void quantizedShape(int rows, int columns, int &q_rows, int &q_cols) { + int meta_rows; + int meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + + // quantized matrix is stored in column major, packed by column + q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; + q_cols = meta_cols * QuantBlk::kColumn; + } + + static MLAS_FORCEINLINE void + quantizedBufferSizes(int rows, int columns, size_t &data_bytes, size_t &scale_num_elements, size_t *zero_point_bytes) { + int meta_rows, meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + data_bytes = q_rows * q_cols; + scale_num_elements = meta_rows * meta_cols; + + if (zero_point_bytes) { + // this works for qbits == 4 but may need to be updated for other qbits values + *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; + } + } + + /** + * @brief Quantized a Matrix shape [rows, columns], resulting quantized + * and packed data are stored in column major (transposed) + * @param[out] dst pointer to the quantized weights, column major: [columns, rows] + * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] + * @param[out] zero_points pointer to the zero points, same shape as scale + * @param[in] src pointer to the source matrix, row major: [rows, columns] + * @param rows + * @param columns + * @param leadingDimension stride of the source matrix, i.e. 
distance from one row to the next + */ + static void quantizeAndTranspose(uint8_t *dst, + ElementT *scales, + uint8_t *zero_points, + const ElementT *src, + int32_t rows, + int32_t columns, + int32_t leadingDimension, + MLAS_THREADPOOL *thread_pool) { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel(thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + const int32_t r = r_blk_idx * ThreadBlk::kRow; + const int32_t c = c_blk_idx * ThreadBlk::kColumn; + + const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + const int meta_row = r / QuantBlk::kRow; + const int meta_col = c / QuantBlk::kColumn; + + // compute scale and zero point + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + + // scan a single block to extract range [min, max] + float min = std::numeric_limits::max(); + float max = -min; + const int row_start = r + kpack * QuantBlk::kRow; + const int row_end = std::min(row_start + QuantBlk::kRow, r_end); + for (int i = row_start; i < row_end; ++i) { + for (int j = c; j < c_end; ++j) { + const float v = static_cast(src[i * leadingDimension + j]); + if (v < min) + min = v; + if (v > max) + max = v; + } + } + + // store scale and zero point at quant parameter matrix position + if (row_start < row_end) { + const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + } + } + } + + // !! 4b specific code as we need to pack 2 4b numbers into one byte + if (zero_points != nullptr) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_c = j / QuantBlk::kColumn; + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_r = i / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int8_t zp = zp_bytes[meta_r & 1]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), 0.0f, BitsTraits::kMaxFp); + + uint8_t vi1 = (uint8_t)zp; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, BitsTraits::kMaxFp); + } + + // !! 
4b specific code + dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + } + } + }); + } + + /** + * @brief Dequantize a column major quantized matrix, and store the result in a column major + * matrix for use in GEMM + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] weights pointer to the quantized weights, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major layout + * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major + * scales + * @param[in] rows + * @param[in] columns + */ + static void dequantize(ElementT *dst, + const uint8_t *weights, + const ElementT *scales, + const uint8_t *zero_points, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL *thread_pool) { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + MlasTryBatchParallel(thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { + int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_col = j / QuantBlk::kColumn; + + // !! 4b specific code + // the whole loop is 4b specific due to sub 8 bit packing + // and unpacking. We can potentially make this qbits generic + // by wraping the packing/unpacking code like cutlass::Array + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_row = i / QuantBlk::kRow; + + const float scale0 = static_cast(scales[meta_col * row_blks + meta_row]); + + const int zp_pair = (zero_points == nullptr) ? 0x88 : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; + const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } + } + }); + } +}; + +/** + * @brief Blockwise quantization methods for QDQ format. Input tensor is quantized along column + * or row. Scales and zeros are calculated. Based on qbits, consecutive quantized elements + * in memory are packed together, which means the packing is along the row. Quantized data + * are stored in row major, so the output tensor reserves same shape, in terms of qbits type, + * as the input tensor. + * @tparam Tin source data type, e.g. 
fp32/fp16 + * @tparam qbits number of bits in each quantized element + */ +template +struct BlockwiseQDQQuantizer; + +template +struct BlockwiseQDQQuantizer { + static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) { return (val >> (idx << 2)) & 0xF; } + + static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) { + auto shift = idx << 2; + return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } + + static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1) { return (v0 & 0xF) | ((v1 & 0xF) << 4); } + + // If src is row major, then dst is column major. Transpose: + // | src0: low 4 bit | src0: high 4 bit | + // | src1: low 4 bit | src1: high 4 bit | + // --> + // | dst0: low 4 bit | dst1: low 4 bit | + // | dst0: high 4 bit| dst1: high 4 bit | + // If src is column major, then dst is row major. Transpose: + // | src0: low 4 bit | src1: low 4 bit | + // | src0: high 4 bit| src1: high 4 bit | + // --> + // | dst0: low 4 bit | dst0: high 4 bit | + // | dst1: low 4 bit | dst1: high 4 bit | + static MLAS_FORCEINLINE void Transpose(uint8_t src0, uint8_t src1, uint8_t &dst0, uint8_t &dst1) { + dst0 = (src0 & 0xF) | ((src1 & 0xF) << 4); + dst1 = ((src0 & 0xF0) >> 4) | (src1 & 0xF0); + } + + static MLAS_FORCEINLINE uint8_t QuantizeV(Tin src, float reciprocal_scale, uint8_t zero_point) { + return static_cast( + std::clamp(static_cast(std::roundf(static_cast(src) * reciprocal_scale)) + static_cast(zero_point), + 0, + BitsTraits<4>::kMax)); + } + + /** + * @brief Quantize a matrix shape [rows, columns] row-wise. Scales and zero points are calculated. + * Quantized data are packed row-wise based on qbits. Quantized data are stored in row + * major, so the output tensor reserves the shape, in terms output type. + * Thread block is [1, quant_block_size * 2]. + * @param src the source matrix, row major: [rows * columns] + * @param scales the scales of quantized blocks, row major layout with shape: + * [rows * ceil(columns / quant_block_size)] + * @param zero_points the zero points of quantized blocks, packed. Same shape as scales + * in terms of output type. In terms of uint8_t, the shape is: + * [ceil(rows * ceil(columns / quant_block_size) * qbits / 8)] + * @param dst the quantized weights, row major: [rows * columns] in terms of + * output type. In terms of uint8_t, the shape is: [ceil(rows * columns * qbits / 8] + * @param rows number of rows in the source matrix + * @param columns number of columns in the source matrix, must satisfy + * ceil(columns / quant_block_size) % 2 == 0, so in each thread block, + * zero points are packed into one byte. + * @param quant_block_size number of elements quantized together. + * @param thread_pool thread pool for parallel processing + */ + static void QuantizeRowWise(const Tin *src, + Tin *scales, + uint8_t *zero_points, + uint8_t *dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + MLAS_UNREFERENCED_PARAMETER(src); + MLAS_UNREFERENCED_PARAMETER(scales); + MLAS_UNREFERENCED_PARAMETER(zero_points); + MLAS_UNREFERENCED_PARAMETER(dst); + MLAS_UNREFERENCED_PARAMETER(rows); + MLAS_UNREFERENCED_PARAMETER(columns); + MLAS_UNREFERENCED_PARAMETER(quant_block_size); + MLAS_UNREFERENCED_PARAMETER(thread_pool); + throw std::runtime_error("BlockwiseQDQQuantizer::BlockwiseQDQQuantizer is not implemented"); + } + + /** + * @brief Quantize a matrix shape [rows, columns] column-wise. Scales and zero points are calculated. + * Quantized data are packed row-wise based on qbits. 
Quantized data are stored in row major + * so the output tensor reserves the shape, in terms output type. + * @param src the source matrix, row major: [rows * columns] + * @param scales the scales of quantized blocks, row major with shape: + * [ceil(rows/quant_block_size) * columns] + * @param zero_points the zero points of quantized blocks, packed. Same shape as scales in terms + * of output type. In uint8_t, the shape is: + * [ceil(columns * ceil(rows / quant_block_size) * qbits / 8)] + * @param dst the quantized weights, row major: [rows * columns] in terms of output type. + * In uint8_t, the shape is: [ceil(rows * columns * qbits / 8] + * @param rows number of rows in the source matrix + * @param columns number of columns in the source matrix. + * @param quant_block_size number of rows/columns quantized together + * @param thread_pool thread pool for parallel processing + */ + static void QuantizeColumnWise(const Tin *src, + Tin *scales, + uint8_t *zero_points, + uint8_t *dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + // Must avoid multiple thread write to a single byte, which means the starting index + // of a thread block must be even. To achieve that, we need to customize the thread + // block size based on the parity of columns. + if (columns & 1) { + QuantizeColumnWisePackUnaligned(src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool); + } else { + QuantizeColumnWisePackAligned(src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool); + } + } + + /** + * @brief Transpose quantized tensors, which has been column-wise quantized, for use in MatMulNbits. + * Since both src tensor and dst tensor are packed, it's not needed to consider sign + * during the unpacking/packing in transpose. + * @param src_weights The quantized weights, row major: [rows, columns] in qbits type. + * In uint8_t, size of [ceil(rows * columns * qbits / 8)]. + * @param src_scales [ceil(rows / quant_block_size), columns] + * @param src_zero_points [ceil(rows / quant_block_size), columns] in qbits type. In uint8_t, size of + * [ceil(ceil(rows / quant_block_size) * columns * qbits / 8 )]. + * @param dst_weights the transposed quantized weights, column major. In uint8_t, the shape is + * [columns, ceil(rows / quant_block_size), ceil(quant_block_size * qbits / 8)] + * @param dst_scales [columns, ceil(rows / quant_block_size)] + * @param dst_zero_points [columns, ceil(ceil(rows / quant_block_size) * qbits / 8)] in uint8_t. + * @param rows number of src rows in qbits type. + * @param columns number of src columns in qbits type. + * @param quant_block_size number of elements quantized together + * @param thread_pool thread pool for parallel processing + */ + static void TransposeColumnWiseQuantized(const uint8_t *src_weights, + const Tin *src_scales, + const uint8_t *src_zero_points, + uint8_t *dst_weights, + Tin *dst_scales, + uint8_t *dst_zero_points, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + // Must avoid multiple thread write to a single byte, which means the starting index + // of a thread block must be even. To achieve that, we need to customize the thread + // block size based on the parity of columns. 
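        // Nibble conventions assumed throughout (see GetElem/SetElem/Pack above);
        // a quick illustration with concrete bytes:
        //   Pack(0x3, 0xA)        == 0xA3  // v0 -> low nibble, v1 -> high nibble
        //   GetElem(0xA3, 0)      == 0x3
        //   GetElem(0xA3, 1)      == 0xA
        //   SetElem(0xF, 1, 0x03) == 0xF3  // rewrite only the high nibble of dst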
+ if (columns & 1) { + TransposeColumnWiseQuantizedPackUnaligned(src_weights, + src_scales, + src_zero_points, + dst_weights, + dst_scales, + dst_zero_points, + rows, + columns, + quant_block_size, + thread_pool); + } else { + TransposeColumnWiseQuantizedPackAligned(src_weights, + src_scales, + src_zero_points, + dst_weights, + dst_scales, + dst_zero_points, + rows, + columns, + quant_block_size, + thread_pool); + } + } + +private: + static void QuantizeColumnWisePackAligned(const Tin *src, + Tin *scales, + uint8_t *zero_points, + uint8_t *dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + if (columns % 2 != 0) { + throw std::runtime_error("Columns must be multiple of 2."); + } + // Thread block is [quant_block_size, thread_blk_size]. thread_blk_size % 2 == 0. + constexpr int32_t thread_blk_size = 128; + const auto num_row_thread_blk = (rows + quant_block_size - 1) / quant_block_size; + const auto num_col_thread_blk = (columns + thread_blk_size - 1) / thread_blk_size; + const auto num_thread_blk = num_row_thread_blk * num_col_thread_blk; + constexpr auto minf = std::numeric_limits::lowest(); + constexpr auto maxf = std::numeric_limits::max(); + + MlasTryBatchParallel(thread_pool, static_cast(num_thread_blk), [&](ptrdiff_t thread_blk_idx) { + // !!warning!!: buffering the whole thread block + constexpr int32_t buffer_size = 128; + if (buffer_size != thread_blk_size) { + throw std::runtime_error("buffer size must be equal to thread block size."); + } + float reciprocal_scale_t[buffer_size]; + uint8_t zp_t[buffer_size]; + float vmin_t[buffer_size]; + float vmax_t[buffer_size]; + + const int32_t row_thread_blk_idx = static_cast(thread_blk_idx / num_col_thread_blk); + const int32_t col_thread_blk_idx = static_cast(thread_blk_idx % num_col_thread_blk); + const int32_t row_idx = row_thread_blk_idx * quant_block_size; + const int32_t col_idx = col_thread_blk_idx * buffer_size; + const int32_t row_size = std::min(quant_block_size, rows - row_idx); + const int32_t col_size = std::min(buffer_size, columns - col_idx); + // input_idx, scale_idx, col_size are aligned to 2 + auto input_idx = row_idx * columns + col_idx; + auto scale_idx = row_thread_blk_idx * columns + col_idx; + + Tin scale0_tt, scale1_tt; + uint8_t v0_tt, v1_tt; + + std::fill_n(vmin_t, buffer_size, maxf); + std::fill_n(vmax_t, buffer_size, minf); + + // calculate min/max + for (int32_t j = 0, input_idx_t = input_idx; j < row_size; ++j, input_idx_t += columns) { + // TODO(fajin): use SIMD + for (int32_t i = 0; i < col_size; i += 2) { + auto v0 = static_cast(src[input_idx_t + i]); + auto v1 = static_cast(src[input_idx_t + i + 1]); + vmin_t[i] = std::min(vmin_t[i], v0); + vmax_t[i] = std::max(vmax_t[i], v0); + vmin_t[i + 1] = std::min(vmin_t[i + 1], v1); + vmax_t[i + 1] = std::max(vmax_t[i + 1], v1); + } + } + + // calculate scale and zero point, and store + for (int32_t i = 0; i < col_size; i += 2) { + v0_tt = v1_tt = BitsTraits<4>::kMid; + + if (zero_points) { + range2scalezp(vmin_t[i], vmax_t[i], scale0_tt, v0_tt); + range2scalezp(vmin_t[i + 1], vmax_t[i + 1], scale1_tt, v1_tt); + zero_points[(scale_idx + i) >> 1] = Pack(v0_tt, v1_tt); + } else { + range2scale(vmin_t[i], vmax_t[i], scale0_tt); + range2scale(vmin_t[i + 1], vmax_t[i + 1], scale1_tt); + } + + scales[scale_idx + i] = scale0_tt; + scales[scale_idx + i + 1] = scale1_tt; + + float scalef0 = static_cast(scale0_tt); + reciprocal_scale_t[i] = scalef0 ? 
1.0f / scalef0 : 0.0f; + zp_t[i] = v0_tt; + + float scalef1 = static_cast(scale1_tt); + reciprocal_scale_t[i + 1] = scalef1 ? 1.0f / scalef1 : 0.0f; + zp_t[i + 1] = v1_tt; + } + + // quantize and pack + for (int32_t j = 0, input_idx_t = input_idx; j < row_size; ++j, input_idx_t += columns) { + // TODO(fajin): use SIMD + for (int32_t i = 0; i < col_size; i += 2) { + v0_tt = QuantizeV(src[input_idx_t + i], reciprocal_scale_t[i], zp_t[i]); + v1_tt = QuantizeV(src[input_idx_t + i + 1], reciprocal_scale_t[i + 1], zp_t[i + 1]); + dst[(input_idx_t + i) >> 1] = Pack(v0_tt, v1_tt); + } + } + }); + } + + static void QuantizeColumnWisePackUnaligned(const Tin *src, + Tin *scales, + uint8_t *zero_points, + uint8_t *dst, + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + // Thread block is [quant_block_size * 2, columns], so the packed bytes do not cross threads. + constexpr auto minf = std::numeric_limits::lowest(); + constexpr auto maxf = std::numeric_limits::max(); + auto row_thread_blk_size = quant_block_size * 2; + auto num_row_thread_blk = (rows + row_thread_blk_size - 1) / (row_thread_blk_size); + + MlasTryBatchParallel(thread_pool, static_cast(num_row_thread_blk), [&](ptrdiff_t thread_blk_idx) { + constexpr int32_t buffer_size = 128; + float reciprocal_scale_t[buffer_size]; + uint8_t zp_t[buffer_size]; + float vmin_t[buffer_size]; + float vmax_t[buffer_size]; + + auto row_thread_blk_idx = static_cast(thread_blk_idx); + int32_t row_idx = row_thread_blk_idx * row_thread_blk_size; + int32_t row_idx_end = std::min(row_thread_blk_size + row_idx, rows); + auto input_idx = row_idx * columns; + auto scale_idx = row_thread_blk_idx * 2 * columns; + Tin scale0_tt, scale1_tt; + uint8_t v0_tt, v1_tt; + + for (; row_idx < row_idx_end; row_idx += quant_block_size) { + // per quant block row + auto quant_row_size = std::min(quant_block_size, row_idx_end - row_idx); + auto input_buffer_idx = input_idx; + auto scale_buffer_idx = scale_idx; + for (int32_t buffer_idx = 0; buffer_idx < columns; buffer_idx += buffer_size) { + // per buffer column + auto buffer_col_size = std::min(buffer_size, columns - buffer_idx); + + std::fill_n(vmin_t, buffer_size, maxf); + std::fill_n(vmax_t, buffer_size, minf); + // calculate min/max of [quant block, buffer] + auto input_idx_t = input_buffer_idx; + for (int32_t j = 0; j < quant_row_size; ++j, input_idx_t += columns) { + // TODO(fajin): use SIMD + for (int32_t i = 0; i < buffer_col_size; ++i) { + auto v = static_cast(src[input_idx_t + i]); + vmin_t[i] = std::min(vmin_t[i], v); + vmax_t[i] = std::max(vmax_t[i], v); + } + } + + // calculate scale and zero point + auto scale_buffer_idx_end = scale_buffer_idx + buffer_col_size; + int32_t col_idx = 0; + // leading unailgned zero points + if (scale_buffer_idx & 1) { + v0_tt = BitsTraits<4>::kMid; + if (zero_points) { + range2scalezp(vmin_t[0], vmax_t[0], scale0_tt, v0_tt); + zero_points[scale_buffer_idx >> 1] = SetElem(v0_tt, 1, zero_points[scale_buffer_idx >> 1]); + } else { + range2scale(vmin_t[0], vmax_t[0], scale0_tt); + } + + scales[scale_buffer_idx] = scale0_tt; + + float scalef = static_cast(scale0_tt); + reciprocal_scale_t[0] = scalef ? 
1.0f / scalef : 0.0f; + zp_t[0] = v0_tt; + + ++col_idx; + ++scale_buffer_idx; + } + // aligned zero points + for (; scale_buffer_idx < scale_buffer_idx_end - 1; col_idx += 2, scale_buffer_idx += 2) { + v0_tt = v1_tt = BitsTraits<4>::kMid; + if (zero_points) { + range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); + range2scalezp(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt, v1_tt); + zero_points[scale_buffer_idx >> 1] = Pack(v0_tt, v1_tt); + } else { + range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); + range2scale(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt); + } + + scales[scale_buffer_idx] = scale0_tt; + scales[scale_buffer_idx + 1] = scale1_tt; + + float scalef0 = static_cast(scale0_tt); + reciprocal_scale_t[col_idx] = scalef0 ? 1.0f / scalef0 : 0.0f; + zp_t[col_idx] = v0_tt; + + float scalef1 = static_cast(scale1_tt); + reciprocal_scale_t[col_idx + 1] = scalef1 ? 1.0f / scalef1 : 0.0f; + zp_t[col_idx + 1] = v1_tt; + } + // tailing unaligned elements + if (scale_buffer_idx < scale_buffer_idx_end) { + v0_tt = BitsTraits<4>::kMid; + if (zero_points) { + range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); + zero_points[scale_buffer_idx >> 1] = SetElem(v0_tt, 0, zero_points[scale_buffer_idx >> 1]); + } else { + range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); + } + + scales[scale_buffer_idx] = scale0_tt; + + float scalef = static_cast(scale0_tt); + reciprocal_scale_t[col_idx] = scalef ? 1.0f / scalef : 0.0f; + zp_t[col_idx] = v0_tt; + + ++scale_buffer_idx; + } + + // quantize and pack + input_idx_t = input_buffer_idx; + for (int32_t j = 0; j < quant_row_size; ++j, input_idx_t += columns) { + auto input_idx_t_start = input_idx_t; + auto input_idx_t_end = input_idx_t + buffer_col_size; + col_idx = 0; + // leading unaligned output + if (input_idx_t_start & 1) { + v1_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); + dst[input_idx_t_start >> 1] = SetElem(v1_tt, 1, dst[input_idx_t_start >> 1]); + + ++col_idx; + ++input_idx_t_start; + } + // aligned output + // TODO(fajin): use SIMD + for (; input_idx_t_start < input_idx_t_end - 1; col_idx += 2, input_idx_t_start += 2) { + v0_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); + v1_tt = QuantizeV(src[input_idx_t_start + 1], reciprocal_scale_t[col_idx + 1], zp_t[col_idx + 1]); + + dst[input_idx_t_start >> 1] = Pack(v0_tt, v1_tt); + } + // tailing unaligned output + if (input_idx_t_start < input_idx_t_end) { + v0_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); + dst[input_idx_t_start >> 1] = SetElem(v0_tt, 0, dst[input_idx_t_start >> 1]); + } + } + + input_buffer_idx += buffer_size; + } + + input_idx += quant_block_size * columns; + scale_idx += columns; + } + }); + } + + static void TransposeColumnWiseQuantizedPackAligned(const uint8_t *src_weights, // [rows, columns / 2] + const Tin *src_scales, // [ceil(rows / quant_block_size), columns] + const uint8_t *src_zero_points, // [ceil(rows / quant_block_size), columns / 2] + uint8_t *dst_weights, // [columns, ceil(rows / quant_block_size), ceil(quant_block_size / 2)] + Tin *dst_scales, // [columns, ceil(rows / quant_block_size)] + uint8_t *dst_zero_points, // [columns, ceil(ceil(rows / quant_block_size) / 2)] + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + if (columns % 2 != 0) { + throw std::runtime_error("Columns must be multiple of 2"); + } + + auto row_quant_blk_num = (rows + 
quant_block_size - 1) / quant_block_size; + auto dst_bytes_per_quant_blk = (quant_block_size * 4 + 7) / 8; + // number of rows in transposed dst + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + auto packed_col_size = columns / 2; + + // weight transpose thread block is [dst_bytes_per_quant_blk, 2] on dst_Transpose. + // Map to src it is [quant_block_size, 1]. Both in uint8_t. + auto num_thread_blk = row_quant_blk_num * packed_col_size; + MlasTryBatchParallel(thread_pool, static_cast(num_thread_blk), [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + uint8_t dst0_t, dst1_t; + + auto row_thread_blk_idx = static_cast(thread_blk_idx / packed_col_size); + auto col_thread_blk_idx = static_cast(thread_blk_idx % packed_col_size); + + auto dstT_row_idx = row_thread_blk_idx * dst_bytes_per_quant_blk; + auto dstT_col_idx = col_thread_blk_idx * 2; + auto dst_idx = dstT_col_idx * dstT_num_row + dstT_row_idx; + + auto src_row_idx = row_thread_blk_idx * quant_block_size; + auto src_row_end_idx = std::min(src_row_idx + quant_block_size, rows); + auto src_col_idx = col_thread_blk_idx; + auto src_idx = src_row_idx * packed_col_size + src_col_idx; + auto src_end_idx = src_row_end_idx * packed_col_size + src_col_idx; + + for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { + src0_t = src_weights[src_idx]; + src1_t = src_weights[src_idx + packed_col_size]; + src_idx += packed_col_size + packed_col_size; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_weights[dst_idx] = dst0_t; + dst_weights[dst_idx + dstT_num_row] = dst1_t; + } + + if (src_idx < src_end_idx) { + src0_t = src_weights[src_idx]; + src1_t = 0; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_weights[dst_idx] = dst0_t; + dst_weights[dst_idx + dstT_num_row] = dst1_t; + } + }); + + // Transpose scales. Thread block is [row_quant_blk_num, 1] on dst_Transpose. + MlasTryBatchParallel(thread_pool, static_cast(columns), [&](ptrdiff_t thread_blk_idx) { + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + }); + + if (src_zero_points) { + // Transpose zero points. Thread block is [ceil(row_quant_blk_num / 2), 2] + // on dst_Transpose. Map to src it is [row_quant_blk_num, 1]. Both in uint8_t. 
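
// Transpose() above swaps a 2x2 tile of 4-bit values held in two bytes: src0
// holds columns {c, c+1} of row r, src1 the same two columns of row r+1, and
// the outputs hold rows {r, r+1} of column c and column c+1 respectively.
// A plausible implementation is sketched below, assuming the low-nibble-first
// packing used elsewhere in this file; TransposeNibblePair is an illustrative
// name, and the real Transpose helper (defined outside this hunk) may differ.
#include <cstdint>

inline void TransposeNibblePair(uint8_t src0, uint8_t src1,
                                uint8_t& dst0, uint8_t& dst1) {
    dst0 = static_cast<uint8_t>((src0 & 0x0f) | ((src1 & 0x0f) << 4));        // column c: rows r, r+1
    dst1 = static_cast<uint8_t>(((src0 >> 4) & 0x0f) | (src1 & 0xf0));        // column c+1: rows r, r+1
}
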
+ auto dst_zp_row_num = (row_quant_blk_num + 1) / 2; + MlasTryBatchParallel(thread_pool, static_cast(packed_col_size), [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + uint8_t dst0_t, dst1_t; + + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto src_end_idx = row_quant_blk_num * packed_col_size + col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * 2 * dst_zp_row_num; + + for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { + src0_t = src_zero_points[src_idx]; + src1_t = src_zero_points[src_idx + packed_col_size]; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_zero_points[dst_idx] = dst0_t; + dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; + src_idx += packed_col_size + packed_col_size; + } + + if (src_idx < src_end_idx) { + src0_t = src_zero_points[src_idx]; + src1_t = 0; + Transpose(src0_t, src1_t, dst0_t, dst1_t); + dst_zero_points[dst_idx] = dst0_t; + dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; + } + }); + } + } + + static void + TransposeColumnWiseQuantizedPackUnaligned(const uint8_t *src_weights, // size of [ceil(rows * columns / 2)] + const Tin *src_scales, // [ceil(rows / quant_block_size), columns] + const uint8_t *src_zero_points, // size of [ceil(ceil(rows / quant_block_size) * columns / 2)] + uint8_t *dst_weights, // [columns, ceil(rows / quant_block_size), ceil(quant_block_size / 2)] + Tin *dst_scales, // [columns, ceil(rows / quant_block_size)] + uint8_t *dst_zero_points, // [columns, ceil(ceil(rows / quant_block_size) / 2)] + int32_t rows, + int32_t columns, + int32_t quant_block_size, + MLAS_THREADPOOL *thread_pool) { + auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; + auto dst_bytes_per_quant_blk = (quant_block_size * 4 + 7) / 8; + // number of rows in transposed dst + auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; + + // weight transpose thread block is [dst_bytes_per_quant_blk, 1] on dst_Transpose in uint8_t. + // Map to src it is [quant_block_size, 1] in int4. + auto num_thread_blk = row_quant_blk_num * columns; + MlasTryBatchParallel(thread_pool, static_cast(num_thread_blk), [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + + auto row_thread_blk_idx = static_cast(thread_blk_idx / columns); + auto col_thread_blk_idx = static_cast(thread_blk_idx % columns); + + auto dstT_row_idx = row_thread_blk_idx * dst_bytes_per_quant_blk; + auto dst_idx = col_thread_blk_idx * dstT_num_row + dstT_row_idx; + + auto src_row_idx = row_thread_blk_idx * quant_block_size; + auto src_row_end_idx = std::min(src_row_idx + quant_block_size, rows); + auto src_idx = src_row_idx * columns + col_thread_blk_idx; + auto src_end_idx = src_row_end_idx * columns + col_thread_blk_idx; + + for (; src_idx < src_end_idx - columns; ++dst_idx) { + src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); + src1_t = GetElem(src_weights[(src_idx + columns) >> 1], (src_idx + columns) & 1); + dst_weights[dst_idx] = (src0_t & 0xf) | ((src1_t & 0xf) << 4); + src_idx += columns + columns; + } + + if (src_idx < src_end_idx) { + src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); + dst_weights[dst_idx] = src0_t & 0xf; + } + }); + + // Transpose scales. Thread block is [row_quant_blk_num, 1] on dst_Transpose. 
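
// In the unaligned paths above, the i-th 4-bit element of a packed stream
// lives in byte (i >> 1), and (i & 1) selects the nibble, low nibble first.
// A minimal sketch of the read/write pattern the GetElem/SetElem calls rely
// on follows; ReadInt4/WriteInt4 are illustrative names, not library APIs.
#include <cstddef>
#include <cstdint>

inline uint8_t ReadInt4(const uint8_t* packed, size_t i) {
    return static_cast<uint8_t>((packed[i >> 1] >> ((i & 1) * 4)) & 0x0f);
}

inline void WriteInt4(uint8_t* packed, size_t i, uint8_t v) {
    const int shift = static_cast<int>(i & 1) * 4;
    packed[i >> 1] = static_cast<uint8_t>(
        (packed[i >> 1] & ~(0x0f << shift)) | ((v & 0x0f) << shift));
}
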
+ MlasTryBatchParallel(thread_pool, static_cast(columns), [&](ptrdiff_t thread_blk_idx) { + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * row_quant_blk_num; + for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { + dst_scales[dst_idx] = src_scales[src_idx]; + } + }); + + if (src_zero_points) { + // Transpose zero points. Thread block is [ceil(row_quant_blk_num / 2), 1] on dst_Transpose in uint8_t. + // Map to src it is [row_quant_blk_num, 1] in int4. + auto dst_zp_row_num = (row_quant_blk_num + 1) / 2; + MlasTryBatchParallel(thread_pool, static_cast(columns), [&](ptrdiff_t thread_blk_idx) { + uint8_t src0_t, src1_t; + + auto col_thread_blk_idx = static_cast(thread_blk_idx); + auto src_idx = col_thread_blk_idx; + auto src_end_idx = row_quant_blk_num * columns + col_thread_blk_idx; + auto dst_idx = col_thread_blk_idx * dst_zp_row_num; + + for (; src_idx < src_end_idx - columns; ++dst_idx) { + src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); + src1_t = GetElem(src_zero_points[(src_idx + columns) >> 1], (src_idx + columns) & 1); + dst_zero_points[dst_idx] = (src0_t & 0xf) | ((src1_t & 0xf) << 4); + src_idx += columns + columns; + } + + if (src_idx < src_end_idx) { + src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); + dst_zero_points[dst_idx] = src0_t & 0xf; + } + }); + } + } +}; + +template +void MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int columns, int &meta_rows, int &meta_cols) { + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + default: + meta_rows = 0; + meta_cols = 0; + break; + } +} + +template +void MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int columns, int &q_rows, int &q_cols) { + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 128: { + if (columnwise) { + 
BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } else { + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + } + break; + } + default: + q_rows = 0; + q_cols = 0; + break; + } +} + +template void MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int columns, int &meta_rows, int &meta_cols); + +template void MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int columns, int &meta_rows, int &meta_cols); + +template void MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int columns, int &q_rows, int &q_cols); + +template void MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int columns, int &q_rows, int &q_cols); + +void MLASCALL MlasBlockwiseQuantizedBufferSizes(int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t &q_data_size_in_bytes, + size_t &q_scale_num_elements, + size_t *q_zero_point_size_in_bytes) { + q_data_size_in_bytes = q_scale_num_elements = 0; + if (q_zero_point_size_in_bytes) { + *q_zero_point_size_in_bytes = 0; + } + + if (qbits == 4) { + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } else { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } else { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } else { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } else { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } else { + BlockwiseQuantizer::quantizedBufferSizes(rows, + columns, + q_data_size_in_bytes, + q_scale_num_elements, + q_zero_point_size_in_bytes); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
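
// Typical caller flow, step 1 of 3: query the packed-buffer sizes before
// allocating. This is a usage sketch only; it assumes the blockwise-quant
// declarations come from mlas_q4.h (as included elsewhere in this patch),
// and SizeExample is an illustrative name. Steps 2 and 3 continue after the
// quantize/dequantize dispatchers below.
#include "mlas_q4.h"
#include <cstddef>
#include <cstdint>
#include <vector>

void SizeExample(int rows, int columns) {
    size_t q_data_bytes = 0, q_scale_elems = 0, q_zp_bytes = 0;
    MlasBlockwiseQuantizedBufferSizes(/*qbits*/ 4, /*block_size*/ 32,
                                      /*columnwise*/ true, rows, columns,
                                      q_data_bytes, q_scale_elems, &q_zp_bytes);
    std::vector<uint8_t> q_data(q_data_bytes);
    std::vector<float> q_scales(q_scale_elems);
    std::vector<uint8_t> q_zps(q_zp_bytes);
    // ... pass these buffers to MlasQuantizeBlockwise<float> (next sketch).
}
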
+ break; + } + } +} + +template +void MlasQuantizeBlockwise(uint8_t *dst, + T *scales, + uint8_t *zero_points, + const T *src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL *thread_pool) { + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose(dst, + scales, + zero_points, + src, + rows, + columns, + leading_dimension, + thread_pool); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
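
// Typical caller flow, step 2 of 3: quantize a row-major fp32 weight matrix
// into the buffers sized in the previous sketch. leading_dimension is the
// stride between rows of src, which for a dense row-major matrix is
// `columns`; a null thread pool runs single-threaded. A usage sketch under
// the same mlas_q4.h assumption; QuantizeExample is an illustrative name.
#include "mlas_q4.h"
#include <cstdint>

void QuantizeExample(const float* weights, int rows, int columns,
                     uint8_t* q_data, float* q_scales, uint8_t* q_zps) {
    MlasQuantizeBlockwise<float>(q_data, q_scales, q_zps, weights,
                                 /*block_size*/ 32, /*columnwise*/ true,
                                 rows, columns,
                                 /*leading_dimension*/ columns,
                                 /*thread_pool*/ nullptr);
}
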
+ break; + } +} + +template void MlasQuantizeBlockwise(uint8_t *dst, + float *scales, + uint8_t *zero_points, + const float *src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL *thread_pool); + +template void MlasQuantizeBlockwise(uint8_t *dst, + MLAS_FP16 *scales, + uint8_t *zero_points, + const MLAS_FP16 *src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL *thread_pool); + +template +void MlasDequantizeBlockwise(T *dst, + const uint8_t *src, + const T *scales, + const uint8_t *zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL *thread_pool) { + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } + break; + case 32: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } + break; + case 64: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } + break; + case 128: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } + break; + case 256: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, columns, thread_pool); + } + break; + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
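
// Typical caller flow, step 3 of 3: reverse the operation, reconstructing an
// fp32 approximation of the original rows x columns matrix from the packed
// data and quantization parameters. Same assumptions as the previous two
// sketches; DequantizeExample is an illustrative name.
#include "mlas_q4.h"
#include <cstdint>

void DequantizeExample(float* recovered, int rows, int columns,
                       const uint8_t* q_data, const float* q_scales,
                       const uint8_t* q_zps) {
    MlasDequantizeBlockwise<float>(recovered, q_data, q_scales, q_zps,
                                   /*block_size*/ 32, /*columnwise*/ true,
                                   rows, columns, /*thread_pool*/ nullptr);
}
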
+ break; + } +} + +template void MlasDequantizeBlockwise(float *dst, + const uint8_t *src, + const float *scales, + const uint8_t *zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL *thread_pool); + +template +void MlasQDQQuantizeBlockwise(const Tin *src, + Tin *scales, + uint8_t *zero_points, + uint8_t *dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL *thread_pool) { + if (columnwise) { + BlockwiseQDQQuantizer::QuantizeColumnWise(src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool); + } else { + BlockwiseQDQQuantizer::QuantizeRowWise(src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool); + } +} + +template void MlasQDQQuantizeBlockwise(const float *src, + float *scales, + uint8_t *zero_points, + uint8_t *dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL *thread_pool); + +template void MlasQDQQuantizeBlockwise(const MLAS_FP16 *src, + MLAS_FP16 *scales, + uint8_t *zero_points, + uint8_t *dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL *thread_pool); + +template +void MlasQDQTransposeBlockwiseQuantized(const uint8_t *src_weights, + const Tin *src_scales, + const uint8_t *src_zero_points, + uint8_t *dst_weights, + Tin *dst_scales, + uint8_t *dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL *thread_pool) { + if (columnwise) { + BlockwiseQDQQuantizer::TransposeColumnWiseQuantized(src_weights, + src_scales, + src_zero_points, + dst_weights, + dst_scales, + dst_zero_points, + rows, + columns, + quant_block_size, + thread_pool); + } else { + throw std::runtime_error("Row-wise MlasQDQTransposeBlockwiseQuantized is not implemented"); + } +} + +template void MlasQDQTransposeBlockwiseQuantized(const uint8_t *src_weights, + const float *src_scales, + const uint8_t *src_zero_points, + uint8_t *dst_weights, + float *dst_scales, + uint8_t *dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL *thread_pool); + +template void MlasQDQTransposeBlockwiseQuantized(const uint8_t *src_weights, + const MLAS_FP16 *src_scales, + const uint8_t *src_zero_points, + uint8_t *dst_weights, + MLAS_FP16 *dst_scales, + uint8_t *dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL *thread_pool); diff --git a/third_party/mlas/lib/q4_dq_cli.cpp b/third_party/mlas/lib/q4_dq_cli.cpp new file mode 100644 index 0000000000..9c330b9eaf --- /dev/null +++ b/third_party/mlas/lib/q4_dq_cli.cpp @@ -0,0 +1,304 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4_dq_cli.cpp + +Abstract: + + This module implements a command line tool that quantize fp32 into int4, + or reverse this process.. + +--*/ + +#include "mlas_q4.h" + +#include +#include +#include +#include +#include +#include + +char* +getCmdOption(char** begin, char** end, const std::string& option) +{ + char** itr = std::find(begin, end, option); + if (itr != end && ++itr != end) { + return *itr; + } + return nullptr; +} + +void +usage(const char* cli) +{ + std::cout << std::endl; + std::cout << "This utility performs int4 quantize and dequantize of a matrix, usage: " << std::endl; + std::cout << " " << cli << " ACTION NUM_ROWS NUM_COLS [OPTIONS]" << std::endl; + std::cout << " ACTION: can be either q (quantize) or dq (de-quantize)." 
<< std::endl; + std::cout << " NUM_ROWS: number of rows in the matrix." << std::endl; + std::cout << " NUM_COLS: number of columns in the matrix." << std::endl; + std::cout << "options:" << std::endl; + std::cout << " --quant_type {0, 1}." << std::endl; + std::cout << " Type of the block quantization." << std::endl; + std::cout << " 0: Symmetric block quant, with fp32 scale." << std::endl; + std::cout << " 1: (default) Block quant with fp32 scale and int8 zero-point." << std::endl; + std::cout << " --input_file {PATH}." << std::endl; + std::cout << " Path to the input file." << std::endl; + std::cout << " --input_offset {N}." << std::endl; + std::cout << " Skip the first N bytes when reading the input file." << std::endl; + std::cout << " Ignored when read from std in." << std::endl; + std::cout << " --output_file {PATH}." << std::endl; + std::cout << " Path to the output file. Write to std out when missing" << std::endl; + std::cout << " --output_format {txt,bin}" << std::endl; + std::cout << " txt: (default) text format: space separated numbers." << std::endl; + std::cout << " bin: Binary format, can not be output to std out." << std::endl; + std::cout << std::endl; +} + + +// +// Variable for commands +// +struct Cli { + bool dqmode = false; // false -> quantize, true -> dequantize + + size_t num_rows = 0; + size_t num_cols = 0; + + MLAS_BLK_QUANT_TYPE quant_type = BlkQ4Zp8; + + char* input_file = nullptr; + size_t input_offset = 0; + + char* output_file = nullptr; + bool output_bin = false; // false -> csv, true -> binary +}; + + +bool +parseArgs(int argc, char* argv[], Cli& cli) +{ + if (argc < 4) { + return false; + } + + if (strncmp(argv[1], "q", 2) == 0) { + cli.dqmode = false; + } else if (strncmp(argv[1], "dq", 3) == 0) { + cli.dqmode = true; + } else { + return false; + } + + errno = 0; + cli.num_rows = (size_t)strtoul(argv[2], nullptr, 0); + if (cli.num_rows == 0 || errno != 0) { + return false; + } + cli.num_cols = (size_t)strtoul(argv[3], nullptr, 0); + if (cli.num_cols == 0 || errno != 0) { + return false; + } + + char* quant_t = getCmdOption(argv + 4, argv + argc, "--quant_type"); + if (quant_t) { + if (strncmp(quant_t, "0", 2) == 0) { + cli.quant_type = BlkQ4Sym; + } + } + + cli.input_file = getCmdOption(argv + 4, argv + argc, "--input_file"); + char* offset_str = getCmdOption(argv + 4, argv + argc, "--input_offset"); + if (offset_str != nullptr) { + errno = 0; + cli.input_offset = (size_t)strtoul(offset_str, nullptr, 0); + if (errno != 0) { + return false; + } + } + + cli.output_file = getCmdOption(argv + 4, argv + argc, "--output_file"); + char* output_format_str = getCmdOption(argv + 4, argv + argc, "--output_format"); + if (output_format_str != nullptr) { + if (strncmp(output_format_str, "csv", 4) == 0) { + cli.output_bin = false; + } else if (strncmp(output_format_str, "bin", 4) == 0) { + cli.output_bin = true; + if (!cli.output_file) { + // can't dump binary file to std-out + return false; + } + } else { + return false; + } + } + return true; +} + + +void +readBinFile(const char* filename, size_t start, size_t expected_size, std::vector& buf) +{ + // open the file: + std::streampos fileSize; + std::ifstream file(filename, std::ios::binary); + + // get its size: + file.seekg(0, std::ios::end); + fileSize = file.tellg(); + file.seekg(0, std::ios::beg); + + file.seekg(start); + fileSize -= start; + if ((size_t)fileSize < expected_size) { + return; + } + + // read the data: + buf.resize(expected_size); + file.read((char*)buf.data(), expected_size); +} + + +void 
+writeUint8Txt(std::ostream& out, const uint8_t* data, size_t len) +{ + for (size_t i = 0; i < len; i++) { + out << (int)data[i] << " "; + if (((i+1) % 21 == 0)) { + out << std::endl; + } + } + out << std::endl; +} + + +int +quantize(const Cli& cli) +{ + std::vector srcbuf; + readBinFile(cli.input_file, cli.input_offset, cli.num_rows * cli.num_cols * sizeof(float), srcbuf); + if (srcbuf.size() == 0) { + std::cerr << "Failed to read expected amount of data from file " << cli.input_file + << std::endl; + return -1; + } + + size_t qsize = MlasQ4GemmPackBSize(cli.quant_type, cli.num_cols, cli.num_rows); + if (qsize == 0) { + std::cerr << "Int4 Quantization not yet supported on this platform!"; + return -1; + } + std::vector dstbuf(qsize); + MlasQ4GemmPackB(cli.quant_type, dstbuf.data(), (const float*)srcbuf.data(), cli.num_cols, + cli.num_rows, cli.num_cols); + + if (cli.output_bin) { + std::ofstream out(cli.output_file, std::ios::out | std::ios::binary); + if (!out) { + std::cerr << "Cannot open output file " << cli.output_file << std::endl; + return -1; + } + out.write((const char*)dstbuf.data(), dstbuf.size()); + } else { + std::streambuf* buf; + if (cli.output_file) { + std::ofstream out(cli.output_file, std::ios::out); + if (!out) { + std::cerr << "Cannot open output file " << cli.output_file << std::endl; + return -1; + } + buf = out.rdbuf(); + } else { + buf = std::cout.rdbuf(); + } +#if defined(__GNUC__) && __GNUC__ >= 12 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored \ + "-Wdangling-pointer" // TODO: suppress warning about dangling pointer until we have a fix + std::ostream stream(buf); +#pragma GCC diagnostic pop +#else + std::ostream stream(buf); +#endif + + writeUint8Txt(stream, dstbuf.data(), dstbuf.size()); + } + return 0; +} + +int +dequantize(const Cli& cli) +{ + size_t qsize = MlasQ4GemmPackBSize(cli.quant_type, cli.num_cols, cli.num_rows); + if (qsize == 0) { + std::cerr << "Int4 Quantization not yet supported on this platform!"; + return -1; + } + std::vector srcbuf; + readBinFile(cli.input_file, cli.input_offset, qsize, srcbuf); + if (srcbuf.size() == 0) { + std::cerr << "Failed to read expected amount of data from file " << cli.input_file + << std::endl; + return -1; + } + + std::vector dstbuf(cli.num_rows * cli.num_cols); + MlasQ4GemmUnPackB(cli.quant_type, dstbuf.data(), srcbuf.data(), cli.num_cols, cli.num_rows, + cli.num_cols); + + if (cli.output_bin) { + std::ofstream out(cli.output_file, std::ios::out | std::ios::binary); + if (!out) { + std::cerr << "Cannot open output file " << cli.output_file << std::endl; + return -1; + } + out.write((const char*)dstbuf.data(), std::streamsize(dstbuf.size()) * sizeof(float)); + } else { + std::streambuf* buf; + std::ofstream file_output_stream; + if (cli.output_file) { + file_output_stream.open(cli.output_file, std::ios::out); + if (file_output_stream.fail()) { + std::cerr << "Cannot open output file " << cli.output_file << std::endl; + return -1; + } + buf = file_output_stream.rdbuf(); + } else { + buf = std::cout.rdbuf(); + } + std::ostream stream(buf); + size_t lcount = 0; + for (float v : dstbuf) { + stream << v << " "; + if (++lcount >= 16) { + stream << std::endl; + lcount = 0; + } + } + stream << std::endl; + } + return 0; +} + + +int +main(int argc, char* argv[]) +{ + Cli cli; + if (!parseArgs(argc, argv, cli)) { + usage(argv[0]); + return -1; + } + if (cli.dqmode) { + return dequantize(cli); + } else { + return quantize(cli); + } +} diff --git a/third_party/mlas/lib/q4common.h 
b/third_party/mlas/lib/q4common.h new file mode 100644 index 0000000000..532437797a --- /dev/null +++ b/third_party/mlas/lib/q4common.h @@ -0,0 +1,154 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4common.h + +Abstract: + + Define int4 block quantization types. + + Int4 block quantization is used to compress weight tensors of large + language models. It takes a number (must be multiple of 32) of floating + point values, calculates their quantization parameters, and saves + the parameters and the quantized data in a blob. +--*/ + +#include "mlas_q4.h" +#include "mlasi.h" + +#include +#include + +// +// Functions for locating data from a quantized blob +// +template +MLAS_FORCEINLINE +float& +MlasQ4BlkScale(uint8_t* BlkPtr) +{ + return *reinterpret_cast(BlkPtr); +} + +template +MLAS_FORCEINLINE +float +MlasQ4BlkScale(const uint8_t* BlkPtr) +{ + return *reinterpret_cast(BlkPtr); +} + +template +uint8_t& +MlasQ4BlkZeroPoint(uint8_t* BlkPtr); + +template +uint8_t +MlasQ4BlkZeroPoint(const uint8_t* BlkPtr); + +template +MLAS_FORCEINLINE +uint8_t* +MlasQ4BlkData(uint8_t* BlkPtr) +{ + return BlkPtr + sizeof(float); +} + +template +MLAS_FORCEINLINE +const uint8_t* +MlasQ4BlkData(const uint8_t* BlkPtr) +{ + return BlkPtr + sizeof(float); +} + +/** + * @brief Every block quantization type, its block size (BlkLen) + * Must be multiple of 32! + */ +constexpr size_t MLAS_QUANT4_BLK_UNIT = 32; + +/** + * @brief Representing int4 quantize type, block quant type 0: + * + * Block size 32, use 32 fp32 numbers to find quantization parameter: + * scale (fp 32) and no zero point, then quantize the numbers + * into int4. The resulting blob takes 16 + 4 = 20 bytes. + */ +struct MLAS_Q4TYPE_BLK0 { + static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT; + static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float); +}; + +/** + * @brief Representing int4 quantize type, block quant type 1: + * + * Block size 32, use 32 fp32 numbers to find quantization parameter: + * scale (fp 32) and zero point (int8), and then quantize the numbers + * into int4. The resulting blob takes 16 + 5 = 21 bytes. + * + * So far this is the only type that includes a zero-point value. + * Maybe we should consider store the quantization parameters seperatedly. + */ +struct MLAS_Q4TYPE_BLK1 { + static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT; + static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float) + sizeof(uint8_t); +}; + +template<> +inline uint8_t& +MlasQ4BlkZeroPoint(uint8_t* BlkPtr) +{ + return *(BlkPtr + sizeof(float)); +} + +template<> +inline uint8_t +MlasQ4BlkZeroPoint(const uint8_t* BlkPtr) +{ + return *(BlkPtr + sizeof(float)); +} + +template<> +inline uint8_t* +MlasQ4BlkData(uint8_t* BlkPtr) +{ + return BlkPtr + sizeof(float) + sizeof(uint8_t); +} + +template<> +inline const uint8_t* +MlasQ4BlkData(const uint8_t* BlkPtr) +{ + return BlkPtr + sizeof(float) + sizeof(uint8_t); +} + +/** + * @brief Representing int4 quantize type, block quant type 2: + * + * Block size 64, use 64 fp32 numbers to find quantization parameter: + * scale (fp 32) and no zero point, then quantize the numbers + * into int4. The resulting blob takes 32 + 4 = 36 bytes. 
+ */ +struct MLAS_Q4TYPE_BLK2 { + static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT * 2; + static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float); +}; + + +/** + * @brief Representing int4 quantize type, block quant type 4: + * + * Block size 128, use 128 fp32 numbers to find quantization parameter: + * scale (fp 32) and no zero point, then quantize the numbers + * into int4. The resulting blob takes 32 + 4 = 36 bytes. + */ +struct MLAS_Q4TYPE_BLK4 { + static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT * 4; + static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float); +}; diff --git a/third_party/mlas/lib/q4gemm.cpp b/third_party/mlas/lib/q4gemm.cpp new file mode 100644 index 0000000000..a734f53432 --- /dev/null +++ b/third_party/mlas/lib/q4gemm.cpp @@ -0,0 +1,179 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4gemm.cpp + +Abstract: + + This module implements the fp32 matrix multiplication with compressed + weight tensor (right hand side). The assumption is the right hand side + tensor can be pre-packed and compressed using int-4 quantization to save + memory. +--*/ + +#include "q4gemm.h" + + +size_t +MLASCALL +MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K) +{ + if (GetMlasPlatform().Q8Q4GemmDispatch == nullptr) { + return 0; + } + switch (QType) { + case BlkQ4Zp8: + return MlasQ80BlkQuantSizeImpl(M, K); + case BlkQ4Sym64: + return MlasQ80BlkQuantSizeImpl(M, K); + case BlkQ4Sym128: + return MlasQ80BlkQuantSizeImpl(M, K); + default: + return MlasQ80BlkQuantSizeImpl(M, K); + } +} + + +void +MLASCALL +MlasQ80BlkQuant( + MLAS_BLK_QUANT_TYPE QType, + void* Qblob, + const float* A, + size_t M, + size_t K, + size_t lda, + MLAS_THREADPOOL* ThreadPool + ) +{ + auto* dispatch = GetMlasPlatform().Q8Q4GemmDispatch; + dispatch->Quants[QType](Qblob, A, M, K, lda, ThreadPool); +} + + +template +MLAS_FORCEINLINE +void +MlasQ4GemmBatchDriver( + MLAS_BLK_QUANT_TYPE QType, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const ParamBlockType* DataParams, + MLAS_THREADPOOL* ThreadPool + ) +{ + //const MLAS_Q4GEMM_DISPATCH* dispatch = MlasQ4GemmGetDispatch(); + //MLAS_Q4GEMM_OPERATION* operation = dispatch->Operation; + void (*operation)(const size_t, const ParamBlockType*, const size_t, const size_t, const size_t, + const size_t) = nullptr; + + if constexpr (std::is_same_v) + { + operation = GetMlasPlatform().FpQ4GemmDispatch->Operations[QType]; + } + else { + operation = GetMlasPlatform().Q8Q4GemmDispatch->Operations[QType]; + } + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + operation(K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. 
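
// The activation-side Q80 blob quantizes rows of A into int8 blocks, each
// carrying its own fp32 scale, so one block of BlkLen values occupies
// BlkLen + sizeof(float) bytes (see Q8BlobUnitSize in q4gemm.h). The worked
// calculation below mirrors MlasQ80BlkQuantSizeImpl, assuming BlkLen == 32
// as used by BlkQ4Zp8; Q80BlobBytes is an illustrative name.
#include <cstddef>

constexpr size_t Q80BlobBytes(size_t M, size_t K, size_t BlkLen = 32) {
    const size_t KBlocks = (K + BlkLen - 1) / BlkLen;   // blocks per row of A
    return M * KBlocks * (BlkLen + sizeof(float));      // int8 data + fp32 scale
}
static_assert(Q80BlobBytes(1, 4096) == 4608, "128 blocks * 36 bytes per block");
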
+ // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min(nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + + +void +MLASCALL +MlasQ4GemmBatch( + MLAS_BLK_QUANT_TYPE QType, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); +} + +void +MLASCALL +MlasQ8Q4GemmBatch( + MLAS_BLK_QUANT_TYPE QType, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); +} diff --git a/third_party/mlas/lib/q4gemm.h b/third_party/mlas/lib/q4gemm.h new file mode 100644 index 0000000000..d16798eb89 --- /dev/null +++ b/third_party/mlas/lib/q4gemm.h @@ -0,0 +1,288 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4gemm.h + +Abstract: + + int4 block quantization gemm kernel template declarations. + + Int4 block quantization is used to compress weight tensors of large + language models. It takes a number (must be multiple of 32) of floating + point values, calculates their quantization parameters, and saves + the parameters and the quantized data in a blob. 
+--*/ + +#include "q4common.h" + + +template +MLAS_FORCEINLINE +size_t +MlasQ4GemmKernel( + const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias +); + +template +MLAS_FORCEINLINE +void +MlasBlkQ4DequantB(float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb); + + +template +MLAS_FORCEINLINE void +AddBiasAvx(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc); + + + +template +void MLASCALL +MlasQ4GemmOperation( + const size_t K, + const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, Q4TYPE::BlkLen); + const size_t ldb = k_blks * Q4TYPE::BlobSize; + const float* A = DataParams->A + RangeStartM * lda; + const uint8_t* PackedB = (const uint8_t*)DataParams->B; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* Bias = DataParams->Bias; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, (size_t)128); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* bias = (Bias == nullptr) ? nullptr : Bias + RangeStartN + n; + const uint8_t* b_col = PackedB + (RangeStartN + n) * ldb; + float* c_blk = C + n; + const float* a_row = A; + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + auto RowsHandled = MlasQ4GemmKernel( + a_row, b_col, c_blk, RowsRemaining, CountN, K, lda, ldb, ldc, bias); + + if (DataParams->OutputProcessor != nullptr) { + DataParams->OutputProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * Q4TYPE::BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, (size_t)StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + RangeStartN + n; + const uint8_t* b_col = PackedB + (RangeStartN + n) * ldb; + float* c_blk = C + n; + const float* a_row = A; + + MlasBlkQ4DequantB(dequant_b, b_col, CountN, K, ldb); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, + CountN, lda, ldc, 1.f); +#endif + + if (bias) { + AddBiasAvx(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->OutputProcessor != nullptr) { + DataParams->OutputProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +typedef +void +(MLAS_Q4GEMM_OPERATION)( + const size_t K, + const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); + +struct MLAS_FPQ4GEMM_DISPATCH { + MLAS_Q4GEMM_OPERATION** Operations; +}; + +/** + * @brief Compute the size of a quantized block, one byte per value + fp32 scale + * @tparam QType + * @return + */ +template +constexpr size_t +Q8BlobUnitSize() +{ + return (QType::BlkLen + sizeof(float)); +} + +template +constexpr size_t +MlasQ80BlkQuantSizeImpl(size_t M, size_t K) +{ + const size_t KBlocks = MlasDivRoundup(K, QType::BlkLen); + + const size_t NumBlocks = M * KBlocks; + + return NumBlocks * Q8BlobUnitSize(); +} + +typedef +void +(MLAS_Q80_BLKQUANT)( + void* Qblob, + const float* A, + size_t M, + size_t K, + size_t lda, + MLAS_THREADPOOL* ThreadPool + ); + +template +MLAS_FORCEINLINE +size_t +MlasQ8Q4GemmKernel( + const int8_t* QuantA, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ); + + +template +void MLASCALL +MlasQ8Q4GemmOperation( + const size_t K, + const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const size_t k_blks = MlasDivRoundup(K, Q4TYPE::BlkLen); + const size_t ldb = k_blks * Q4TYPE::BlobSize; + const size_t lda = k_blks * Q8BlobUnitSize(); + const size_t ldc = DataParams->ldc; + + const int8_t* A = reinterpret_cast(DataParams->A) + RangeStartM * lda; + const uint8_t* PackedB = (const uint8_t*)DataParams->B; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* Bias = DataParams->Bias; + + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, (size_t)128); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + RangeStartN + n; + const uint8_t* b_col = PackedB + (RangeStartN + n) * ldb; + float* c_blk = C + n; + const int8_t* a_row = A; + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + auto RowsHandled = MlasQ8Q4GemmKernel( + a_row, b_col, c_blk, RowsRemaining, CountN, K, lda, ldb, ldc, bias); + + if (DataParams->OutputProcessor != nullptr) { + DataParams->OutputProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, DataParams->ldc); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +typedef +void +(MLAS_Q8Q4GEMM_OPERATION)( + const size_t K, + const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); + +struct MLAS_Q8Q4GEMM_DISPATCH { + MLAS_Q80_BLKQUANT** Quants; + MLAS_Q8Q4GEMM_OPERATION** Operations; +}; diff --git a/third_party/mlas/lib/q4gemm_avx512.cpp b/third_party/mlas/lib/q4gemm_avx512.cpp new file mode 100644 index 0000000000..f7af82ed12 --- /dev/null +++ b/third_party/mlas/lib/q4gemm_avx512.cpp @@ -0,0 +1,1509 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4gemm_avx512.cpp + +Abstract: + + This module implements the fp32 matrix multiplication with compressed + weight tensor (right hand side). The assumption is the right hand side + tensor can be pre-packed and compressed using int-4 quantization to save + memory. + Specificially on x64 avx512 +--*/ + +#include "q4gemm.h" + +#include +#include + +struct MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI { + static constexpr size_t StrideM = 256; +}; + +/** + * @brief Horizontally sum 4 vectors and store + * the results in the returned vector + */ +static MLAS_FORCEINLINE __m128 +FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3) +{ + __m512 acc_lo01 = _mm512_unpacklo_ps(acc0, acc1); + __m512 acc_hi01 = _mm512_unpackhi_ps(acc0, acc1); + __m512 acc_lo23 = _mm512_unpacklo_ps(acc2, acc3); + __m512 acc_hi23 = _mm512_unpackhi_ps(acc2, acc3); + + __m512 acc_lo0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23))); + __m512 acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23))); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23))); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23))); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + + __m256 acc_y = + _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); + return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); +} + + +template +MLAS_FORCEINLINE +size_t +MlasQ4GemmKernelAvx512f( + const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + // We process 32 quantized values in a batch. 
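
// The AVX-512 loop below expands one 32-value block of packed int4 weights,
// subtracts the zero point (or 8 for the symmetric types), applies the block
// scale and accumulates against a row of A. A scalar reference of that inner
// step follows, assuming the nibble layout implied by the _mm256_set_m128i
// expansion (low nibbles hold elements 0..15, high nibbles elements 16..31);
// DotOneQ4Block is an illustrative name.
#include <cstddef>
#include <cstdint>

inline float DotOneQ4Block(const float* a, const uint8_t* packed /* 16 bytes */,
                           float scale, int zero_point /* 8 for symmetric */,
                           size_t len /* <= 32 */) {
    float acc = 0.0f;
    for (size_t i = 0; i < len; ++i) {
        const uint8_t byte = packed[i & 15];
        const int q = (i < 16) ? (byte & 0x0f) : (byte >> 4);
        acc += a[i] * scale * static_cast<float>(q - zero_point);
    }
    return acc;
}
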
+ static_assert(MLAS_QUANT4_BLK_UNIT == 32); + static_assert(Q4Type::BlkLen % MLAS_QUANT4_BLK_UNIT == 0); + + const __m256i lowMask = _mm256_set1_epi8(0xF); + + for (size_t m = 0; m < CountM; m++) { + const auto* b_col = PackedB; + auto* sum_ptr = C; + const auto* bias_ptr = Bias; + + int64_t nblk = (int64_t)(CountN) - 4; + while (nblk >= 0) { + __m512 acc_lo0 = _mm512_setzero_ps(); + __m512 acc_lo1 = _mm512_setzero_ps(); + __m512 acc_lo2 = _mm512_setzero_ps(); + __m512 acc_lo3 = _mm512_setzero_ps(); + const auto* b = b_col; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + size_t ck = std::min(CountK - k, Q4Type::BlkLen); + + const float scale_v0 = MlasQ4BlkScale(b); + const float scale_v1 = MlasQ4BlkScale(b + ldb); + const float scale_v2 = MlasQ4BlkScale(b + ldb * 2); + const float scale_v3 = MlasQ4BlkScale(b + ldb * 3); + + const __m128i* b0ptr = (const __m128i*)MlasQ4BlkData(b); + const __m128i* b1ptr = (const __m128i*)MlasQ4BlkData(b + ldb); + const __m128i* b2ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 2); + const __m128i* b3ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 3); + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); + + // Load A row vectors + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 ? _mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); + + // Load B col vectors + const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); + const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); + const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); + const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); + + // expand 4b into byte array + __m256i bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); + __m256i bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); + __m256i bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); + __m256i bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + + // Subtract zero-point from the integers + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + bytes0 = _mm256_sub_epi8( + bytes0, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b))); + bytes1 = _mm256_sub_epi8( + bytes1, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb))); + bytes2 = _mm256_sub_epi8( + bytes2, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 2))); + bytes3 = _mm256_sub_epi8( + bytes3, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 3))); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo0 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + const __m256i vx16_hi0 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + const __m256i vx16_lo1 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + const __m256i vx16_hi1 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + const __m256i vx16_lo2 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + const __m256i vx16_hi2 = + 
_mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + const __m256i vx16_lo3 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + const __m256i vx16_hi3 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + __m512 s = _mm512_set1_ps(scale_v0); + bvf_lo0 = _mm512_mul_ps(bvf_lo0, s); + bvf_hi0 = _mm512_mul_ps(bvf_hi0, s); + s = _mm512_set1_ps(scale_v1); + bvf_lo1 = _mm512_mul_ps(bvf_lo1, s); + bvf_hi1 = _mm512_mul_ps(bvf_hi1, s); + s = _mm512_set1_ps(scale_v2); + bvf_lo2 = _mm512_mul_ps(bvf_lo2, s); + bvf_hi2 = _mm512_mul_ps(bvf_hi2, s); + s = _mm512_set1_ps(scale_v3); + bvf_lo3 = _mm512_mul_ps(bvf_lo3, s); + bvf_hi3 = _mm512_mul_ps(bvf_hi3, s); + + acc_lo0 = _mm512_fmadd_ps(bvf_lo0, av_lo, acc_lo0); + acc_lo0 = _mm512_fmadd_ps(bvf_hi0, av_hi, acc_lo0); + acc_lo1 = _mm512_fmadd_ps(bvf_lo1, av_lo, acc_lo1); + acc_lo1 = _mm512_fmadd_ps(bvf_hi1, av_hi, acc_lo1); + acc_lo2 = _mm512_fmadd_ps(bvf_lo2, av_lo, acc_lo2); + acc_lo2 = _mm512_fmadd_ps(bvf_hi2, av_hi, acc_lo2); + acc_lo3 = _mm512_fmadd_ps(bvf_lo3, av_lo, acc_lo3); + acc_lo3 = _mm512_fmadd_ps(bvf_hi3, av_hi, acc_lo3); + } + + b += Q4Type::BlobSize; + } + + __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); + if (Bias != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + + // move to next 4 columns + b_col += 4 * ldb; + sum_ptr += 4; + bias_ptr += 4; + nblk -= 4; + } + + // left over columns less than 4 ? + nblk += 4; + if (nblk > 0) { + __m512 acc_lo[4]{}; + const auto* b = b_col; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + size_t ck = std::min(CountK - k, Q4Type::BlkLen); + + float scale_v[4]; + const __m128i* b_ptr[4]; + for (int64_t nn = 0; nn < nblk; nn++) { + const auto* bb = b + ldb * nn; + scale_v[nn] = MlasQ4BlkScale(bb); + b_ptr[nn] = (const __m128i*)MlasQ4BlkData(bb); + } + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); + + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 + ? 
_mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); + + for (int64_t nn = 0; nn < nblk; nn++) { + const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); + __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + bytes = _mm256_and_si256(lowMask, bytes); + + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + const auto* bb = b + ldb * nn; + const uint8_t zp = MlasQ4BlkZeroPoint(bb); + bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes = _mm256_sub_epi8(bytes, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); + const __m256i vx16_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); + __m512 bvf_hi = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); + __m512 s = _mm512_set1_ps(scale_v[nn]); + bvf_lo = _mm512_mul_ps(bvf_lo, s); + bvf_hi = _mm512_mul_ps(bvf_hi, s); + + acc_lo[nn] = _mm512_fmadd_ps(bvf_lo, av_lo, acc_lo[nn]); + acc_lo[nn] = _mm512_fmadd_ps(bvf_hi, av_hi, acc_lo[nn]); + } + } + b += Q4Type::BlobSize; + } + + for (int64_t nn = 0; nn < nblk; nn++) { + sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); + sum_ptr[nn] += Bias == nullptr ? 0.0f : bias_ptr[nn]; + } + } + + // Prepare pointers for the next row + C += ldc; + A += lda; + } + return CountM; +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ4GemmKernel( + const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, + ldb, ldc, Bias); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ4GemmKernel( + const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, + ldb, ldc, Bias); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ4GemmKernel( + const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, + ldb, ldc, Bias); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ4GemmKernel( + const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, + ldb, ldc, Bias); +} + + +MLAS_FORCEINLINE +void +Transpose16x16Avx512( + float* dest, + __m512i r0, + __m512i r1, + __m512i r2, + __m512i r3, + __m512i r4, + __m512i r5, + __m512i r6, + __m512i r7, + __m512i r8, + __m512i r9, + __m512i ra, + __m512i rb, + __m512i rc, + __m512i rd, + __m512i re, + __m512i rf) +{ + + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + + t0 = _mm512_unpacklo_epi32( + r0, r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32( + r0, r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2, r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2, r3); // 34 50 35 51 ... 
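+    // The same 32-bit interleave continues below for the remaining row pairs
+    // (t4..tf); the 64-bit unpacks and the two _mm512_shuffle_i32x4 passes that
+    // follow then complete the 16x16 transpose before the rows are stored to dest.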
+ t4 = _mm512_unpacklo_epi32(r4, r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4, r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6, r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6, r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8, r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8, r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra, rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra, rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc, rd); // 196 ... + td = _mm512_unpackhi_epi32(rc, rd); // 198 ... + te = _mm512_unpacklo_epi32(re, rf); // 228 ... + tf = _mm512_unpackhi_epi32(re, rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0, t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0, t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1, t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1, t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4, t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4, t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5, t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5, t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8, ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8, ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9, tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9, tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc, te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc, te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td, tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td, tf); // 195 211 231 243 ... + + t0 = + _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... + t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... 
+ rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 + + _mm512_storeu_si512(dest, r0); + dest += 16; + _mm512_storeu_si512(dest, r1); + dest += 16; + _mm512_storeu_si512(dest, r2); + dest += 16; + _mm512_storeu_si512(dest, r3); + dest += 16; + _mm512_storeu_si512(dest, r4); + dest += 16; + _mm512_storeu_si512(dest, r5); + dest += 16; + _mm512_storeu_si512(dest, r6); + dest += 16; + _mm512_storeu_si512(dest, r7); + dest += 16; + _mm512_storeu_si512(dest, r8); + dest += 16; + _mm512_storeu_si512(dest, r9); + dest += 16; + _mm512_storeu_si512(dest, ra); + dest += 16; + _mm512_storeu_si512(dest, rb); + dest += 16; + _mm512_storeu_si512(dest, rc); + dest += 16; + _mm512_storeu_si512(dest, rd); + dest += 16; + _mm512_storeu_si512(dest, re); + dest += 16; + _mm512_storeu_si512(dest, rf); + dest += 16; +} + + +template +MLAS_FORCEINLINE +void +BlkQ4DequantBAvx512f( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + const __m256i lowMask = _mm256_set1_epi8(0xF); + + const auto* b_col = PackedB; + + int64_t nblk = (int64_t)(CountN)-16; + while (nblk >= 0) { + const auto* b = b_col; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + size_t ck = std::min(CountK - k, Q4Type::BlkLen); + + const float scale_v0 = MlasQ4BlkScale(b); + const float scale_v1 = MlasQ4BlkScale(b + ldb); + const float scale_v2 = MlasQ4BlkScale(b + ldb * 2); + const float scale_v3 = MlasQ4BlkScale(b + ldb * 3); + const float scale_v4 = MlasQ4BlkScale(b + ldb * 4); + const float scale_v5 = MlasQ4BlkScale(b + ldb * 5); + const float scale_v6 = MlasQ4BlkScale(b + ldb * 6); + const float scale_v7 = MlasQ4BlkScale(b + ldb * 7); + const float scale_v8 = MlasQ4BlkScale(b + ldb * 8); + const float scale_v9 = MlasQ4BlkScale(b + ldb * 9); + const float scale_va = MlasQ4BlkScale(b + ldb * 10); + const float scale_vb = MlasQ4BlkScale(b + ldb * 11); + const float scale_vc = MlasQ4BlkScale(b + ldb * 12); + const float scale_vd = MlasQ4BlkScale(b + ldb * 13); + const float scale_ve = MlasQ4BlkScale(b + ldb * 14); + const float scale_vf = MlasQ4BlkScale(b + ldb * 15); + + const __m128i* b0ptr = (const __m128i*)MlasQ4BlkData(b); + const __m128i* b1ptr = (const __m128i*)MlasQ4BlkData(b + ldb); + const __m128i* b2ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 2); + const __m128i* b3ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 3); + const __m128i* b4ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 4); + const __m128i* b5ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 5); + const __m128i* b6ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 6); + const __m128i* b7ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 7); + const __m128i* b8ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 8); + const __m128i* b9ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 9); + const __m128i* baptr = (const __m128i*)MlasQ4BlkData(b + ldb * 10); + const __m128i* bbptr = (const __m128i*)MlasQ4BlkData(b + ldb * 11); + const __m128i* bcptr = (const __m128i*)MlasQ4BlkData(b + ldb * 12); + const __m128i* bdptr = (const __m128i*)MlasQ4BlkData(b + ldb * 13); + const __m128i* beptr = (const __m128i*)MlasQ4BlkData(b + ldb * 14); + const __m128i* bfptr = (const __m128i*)MlasQ4BlkData(b + ldb * 15); + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); + + // Load B col vectors + const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); + const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); + const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); + const __m128i bvi4_3 = 
_mm_loadu_si128(b3ptr++); + + // expand 4b into byte array + __m256i bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); + __m256i bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); + __m256i bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); + __m256i bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + + // Subtract zero-point from the integers + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + bytes0 = _mm256_sub_epi8( + bytes0, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b))); + bytes1 = _mm256_sub_epi8( + bytes1, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb))); + bytes2 = _mm256_sub_epi8( + bytes2, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 2))); + bytes3 = _mm256_sub_epi8( + bytes3, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 3))); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + __m256i vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + __m256i vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + __m256i vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + __m256i vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + __m256i vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + __m256i vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + __m256i vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + __m256i vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + __m512 s = _mm512_set1_ps(scale_v0); + bvf_lo0 = _mm512_mul_ps(bvf_lo0, s); + bvf_hi0 = _mm512_mul_ps(bvf_hi0, s); + s = _mm512_set1_ps(scale_v1); + bvf_lo1 = _mm512_mul_ps(bvf_lo1, s); + bvf_hi1 = _mm512_mul_ps(bvf_hi1, s); + s = _mm512_set1_ps(scale_v2); + bvf_lo2 = _mm512_mul_ps(bvf_lo2, s); + bvf_hi2 = _mm512_mul_ps(bvf_hi2, s); + s = _mm512_set1_ps(scale_v3); + bvf_lo3 = _mm512_mul_ps(bvf_lo3, s); + bvf_hi3 = _mm512_mul_ps(bvf_hi3, s); + + // Load B col vectors + const __m128i bvi4_4 = _mm_loadu_si128(b4ptr++); + const __m128i bvi4_5 = _mm_loadu_si128(b5ptr++); + const __m128i bvi4_6 = _mm_loadu_si128(b6ptr++); + const __m128i bvi4_7 = _mm_loadu_si128(b7ptr++); + + // expand 4b into byte array + bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_4, 4), bvi4_4); + bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_5, 4), bvi4_5); + bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_6, 4), bvi4_6); + bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_7, 4), bvi4_7); + bytes0 = 
_mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + + // Subtract zero-point from the integers + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + bytes0 = _mm256_sub_epi8( + bytes0, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 4))); + bytes1 = _mm256_sub_epi8( + bytes1, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 5))); + bytes2 = _mm256_sub_epi8( + bytes2, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 6))); + bytes3 = _mm256_sub_epi8( + bytes3, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 7))); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo4 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hi4 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lo5 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hi5 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_lo6 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hi6 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lo7 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hi7 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + s = _mm512_set1_ps(scale_v4); + bvf_lo4 = _mm512_mul_ps(bvf_lo4, s); + bvf_hi4 = _mm512_mul_ps(bvf_hi4, s); + s = _mm512_set1_ps(scale_v5); + bvf_lo5 = _mm512_mul_ps(bvf_lo5, s); + bvf_hi5 = _mm512_mul_ps(bvf_hi5, s); + s = _mm512_set1_ps(scale_v6); + bvf_lo6 = _mm512_mul_ps(bvf_lo6, s); + bvf_hi6 = _mm512_mul_ps(bvf_hi6, s); + s = _mm512_set1_ps(scale_v7); + bvf_lo7 = _mm512_mul_ps(bvf_lo7, s); + bvf_hi7 = _mm512_mul_ps(bvf_hi7, s); + + // Load B col vectors + const __m128i bvi4_8 = _mm_loadu_si128(b8ptr++); + const __m128i bvi4_9 = _mm_loadu_si128(b9ptr++); + const __m128i bvi4_a = _mm_loadu_si128(baptr++); + const __m128i bvi4_b = _mm_loadu_si128(bbptr++); + + // expand 4b into byte array + bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_8, 4), bvi4_8); + bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_9, 4), bvi4_9); + bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_a, 4), bvi4_a); + bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_b, 4), bvi4_b); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + + // Subtract zero-point from the integers + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + bytes0 = _mm256_sub_epi8( + bytes0, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 8))); + bytes1 = _mm256_sub_epi8( + bytes1, + 
_mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 9))); + bytes2 = _mm256_sub_epi8( + bytes2, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 10))); + bytes3 = _mm256_sub_epi8( + bytes3, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 11))); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo8 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hi8 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lo9 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hi9 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_loa = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hia = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lob = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hib = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + s = _mm512_set1_ps(scale_v8); + bvf_lo8 = _mm512_mul_ps(bvf_lo8, s); + bvf_hi8 = _mm512_mul_ps(bvf_hi8, s); + s = _mm512_set1_ps(scale_v9); + bvf_lo9 = _mm512_mul_ps(bvf_lo9, s); + bvf_hi9 = _mm512_mul_ps(bvf_hi9, s); + s = _mm512_set1_ps(scale_va); + bvf_loa = _mm512_mul_ps(bvf_loa, s); + bvf_hia = _mm512_mul_ps(bvf_hia, s); + s = _mm512_set1_ps(scale_vb); + bvf_lob = _mm512_mul_ps(bvf_lob, s); + bvf_hib = _mm512_mul_ps(bvf_hib, s); + + // Load B col vectors + const __m128i bvi4_c = _mm_loadu_si128(bcptr++); + const __m128i bvi4_d = _mm_loadu_si128(bdptr++); + const __m128i bvi4_e = _mm_loadu_si128(beptr++); + const __m128i bvi4_f = _mm_loadu_si128(bfptr++); + + // expand 4b into byte array + bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_c, 4), bvi4_c); + bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_d, 4), bvi4_d); + bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_e, 4), bvi4_e); + bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_f, 4), bvi4_f); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + + // Subtract zero-point from the integers + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + bytes0 = _mm256_sub_epi8( + bytes0, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 12))); + bytes1 = _mm256_sub_epi8( + bytes1, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 13))); + bytes2 = _mm256_sub_epi8( + bytes2, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 14))); + bytes3 = _mm256_sub_epi8( + bytes3, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 15))); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = 
_mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_loc = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hic = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lod = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hid = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_loe = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hie = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lof = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hif = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + s = _mm512_set1_ps(scale_vc); + bvf_loc = _mm512_mul_ps(bvf_loc, s); + bvf_hic = _mm512_mul_ps(bvf_hic, s); + s = _mm512_set1_ps(scale_vd); + bvf_lod = _mm512_mul_ps(bvf_lod, s); + bvf_hid = _mm512_mul_ps(bvf_hid, s); + s = _mm512_set1_ps(scale_ve); + bvf_loe = _mm512_mul_ps(bvf_loe, s); + bvf_hie = _mm512_mul_ps(bvf_hie, s); + s = _mm512_set1_ps(scale_vf); + bvf_lof = _mm512_mul_ps(bvf_lof, s); + bvf_hif = _mm512_mul_ps(bvf_hif, s); + Transpose16x16Avx512(FpData, _mm512_castps_si512(bvf_lo0), + _mm512_castps_si512(bvf_lo1), _mm512_castps_si512(bvf_lo2), + _mm512_castps_si512(bvf_lo3), _mm512_castps_si512(bvf_lo4), + _mm512_castps_si512(bvf_lo5), _mm512_castps_si512(bvf_lo6), + _mm512_castps_si512(bvf_lo7), _mm512_castps_si512(bvf_lo8), + _mm512_castps_si512(bvf_lo9), _mm512_castps_si512(bvf_loa), + _mm512_castps_si512(bvf_lob), _mm512_castps_si512(bvf_loc), + _mm512_castps_si512(bvf_lod), _mm512_castps_si512(bvf_loe), + _mm512_castps_si512(bvf_lof)); + if (kklen > 16) { + Transpose16x16Avx512(FpData + 16 * 16, _mm512_castps_si512(bvf_hi0), + _mm512_castps_si512(bvf_hi1), _mm512_castps_si512(bvf_hi2), + _mm512_castps_si512(bvf_hi3), _mm512_castps_si512(bvf_hi4), + _mm512_castps_si512(bvf_hi5), _mm512_castps_si512(bvf_hi6), + _mm512_castps_si512(bvf_hi7), _mm512_castps_si512(bvf_hi8), + _mm512_castps_si512(bvf_hi9), _mm512_castps_si512(bvf_hia), + _mm512_castps_si512(bvf_hib), _mm512_castps_si512(bvf_hic), + _mm512_castps_si512(bvf_hid), _mm512_castps_si512(bvf_hie), + _mm512_castps_si512(bvf_hif)); + } + FpData += 16 * kklen; + } + + b += Q4Type::BlobSize; + } + + // move to next 16 columns + b_col += 16 * ldb; + nblk -= 16; + } + + // left over columns less than 16 ? 
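+    // Fewer than 16 columns remain: the loop below dequantizes only the nblk
+    // live columns, zero-fills the unused bvf_lo/bvf_hi lanes, and reuses the
+    // same 16x16 transpose to write the tail out.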
+ nblk += 16; + if (nblk > 0) { + const auto* b = b_col; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + size_t ck = std::min(CountK - k, Q4Type::BlkLen); + + float scale_v[16]; + const __m128i* b_ptr[16]; + for (int64_t nn = 0; nn < nblk; nn++) { + const auto* bb = b + ldb * nn; + scale_v[nn] = MlasQ4BlkScale(bb); + b_ptr[nn] = (const __m128i*)MlasQ4BlkData(bb); + } + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); + __m512 bvf_lo[16]; + __m512 bvf_hi[16]; + for (int64_t nn = 0; nn < nblk; nn++) { + const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); + __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + bytes = _mm256_and_si256(lowMask, bytes); + + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + const auto* bb = b + ldb * nn; + const uint8_t zp = MlasQ4BlkZeroPoint(bb); + bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes = _mm256_sub_epi8(bytes, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); + const __m256i vx16_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); + + // Convert to 32-bit int -> float 32 + bvf_lo[nn] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); + bvf_hi[nn] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); + const __m512 s = _mm512_set1_ps(scale_v[nn]); + bvf_lo[nn] = _mm512_mul_ps(bvf_lo[nn], s); + bvf_hi[nn] = _mm512_mul_ps(bvf_hi[nn], s); + } + for (int64_t nn = nblk; nn < 16; nn++) { + bvf_lo[nn] = _mm512_setzero_ps(); + bvf_hi[nn] = _mm512_setzero_ps(); + } + Transpose16x16Avx512( + FpData, _mm512_castps_si512(bvf_lo[0]), _mm512_castps_si512(bvf_lo[1]), + _mm512_castps_si512(bvf_lo[2]), _mm512_castps_si512(bvf_lo[3]), + _mm512_castps_si512(bvf_lo[4]), _mm512_castps_si512(bvf_lo[5]), + _mm512_castps_si512(bvf_lo[6]), _mm512_castps_si512(bvf_lo[7]), + _mm512_castps_si512(bvf_lo[8]), _mm512_castps_si512(bvf_lo[9]), + _mm512_castps_si512(bvf_lo[10]), _mm512_castps_si512(bvf_lo[11]), + _mm512_castps_si512(bvf_lo[12]), _mm512_castps_si512(bvf_lo[13]), + _mm512_castps_si512(bvf_lo[14]), _mm512_castps_si512(bvf_lo[15])); + if (kklen > 16) { + Transpose16x16Avx512( + FpData + 16 * 16, _mm512_castps_si512(bvf_hi[0]), + _mm512_castps_si512(bvf_hi[1]), _mm512_castps_si512(bvf_hi[2]), + _mm512_castps_si512(bvf_hi[3]), _mm512_castps_si512(bvf_hi[4]), + _mm512_castps_si512(bvf_hi[5]), _mm512_castps_si512(bvf_hi[6]), + _mm512_castps_si512(bvf_hi[7]), _mm512_castps_si512(bvf_hi[8]), + _mm512_castps_si512(bvf_hi[9]), _mm512_castps_si512(bvf_hi[10]), + _mm512_castps_si512(bvf_hi[11]), _mm512_castps_si512(bvf_hi[12]), + _mm512_castps_si512(bvf_hi[13]), _mm512_castps_si512(bvf_hi[14]), + _mm512_castps_si512(bvf_hi[15])); + } + FpData += 16 * kklen; + } + b += Q4Type::BlobSize; + } + } +} + + +template<> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); +} + +template <> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); +} + +template <> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + 
BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); +} + +template <> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); +} + +/** + * @brief For testing purpose, + * Dequantize the data intp fp32, and then pack them for use + * in sgemm kernel. equivalent to MlasQ4GemmUnPackB and then + * MlasSgemmCopyPackB + * @param QType + * @param FpData + * @param PackedB + * @param CountN + * @param CountK + * @param ldb + */ +void +MlasBlkQ4DequantSgemmPackB( + MLAS_BLK_QUANT_TYPE QType, + float* FpData, + const uint8_t* PackedB, + size_t CountN, + size_t CountK, + size_t ldb) +{ + switch (QType) { + case BlkQ4Zp8: + return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); + case BlkQ4Sym64: + return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); + case BlkQ4Sym128: + return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); + default: + return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); + } +} + +template<> +MLAS_FORCEINLINE +void +AddBiasAvx( + const float* Bias, + float* C, + size_t CountM, + size_t CountN, + size_t ldc + ) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + __m128 acc_x = _mm_loadu_ps(sum); + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias)); + _mm_storeu_ps(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + + +static MLAS_Q4GEMM_OPERATION* Q4Operations_avx512vnni[] = { + MlasQ4GemmOperation, + MlasQ4GemmOperation, + MlasQ4GemmOperation, + nullptr, + MlasQ4GemmOperation +}; + +const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512 = { + Q4Operations_avx512vnni +}; + + +//////////////////////////////////////////////////////////// +// Block int8 quantization, currently we only +// implement symmetric quant, with no zero-point + +template +MLAS_FORCEINLINE void +MlasQ80BlkQuantRow(const float* A, void* Qblob, size_t size) +{ + static_assert(QType::BlkLen % 16 == 0); + const __m512 signBit = _mm512_set1_ps(-0.0f); + int8_t* blob = reinterpret_cast(Qblob); + for (size_t k = 0; k < size; k += QType::BlkLen) { + const size_t step = std::min(QType::BlkLen, size - k); + + __m512 maxAbs = _mm512_setzero_ps(); + for (size_t kk = 0; kk < step; kk += 16) { + const size_t klen = std::min(size_t(16), step - kk); + + uint32_t mask = 0xffff >> (16 - klen); + __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + // Compute max(abs(e)) for the block + maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0)); + } + + __m256 max8 = + _mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0)); + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + const float maxScalar = _mm_cvtss_f32(max4); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *reinterpret_cast(blob) = scale; + blob += sizeof(float); + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const __m512 mul = _mm512_set1_ps(inverse_scale); + __m128i* dst = reinterpret_cast<__m128i*>(blob); + + for (size_t kk = 0; kk < step; kk += 16) { + const size_t klen = std::min(size_t(16), step - kk); + + uint32_t mask = 0xffff >> (16 - klen); + __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + v0 = _mm512_mul_ps(v0, mul); + + // Round to nearest integer + v0 = _mm512_roundscale_ps(v0, _MM_ROUND_NEAREST); + + // Convert floats to integers + __m512i i0 = _mm512_cvtps_epi32(v0); + + // Convert int32 to int8 + _mm_storeu_si128(dst++, _mm512_cvtepi32_epi8(i0)); + } + if (step < QType::BlkLen) { + memset(blob + step, 0, QType::BlkLen - step); + } + blob += QType::BlkLen; + } +} + +template +void +Q80BlkQuant(void* Qblob, const float* A, size_t M, size_t K, size_t lda, MLAS_THREADPOOL* ThreadPool) +{ + const size_t parts = (size_t)ceil(double(M) * K / (16.0 * 1024)); + const size_t TargetThreadCnt = + std::max(std::min(parts, (size_t)MlasGetMaximumThreadCount(ThreadPool)), (size_t)1); + const size_t linesize = MlasQ80BlkQuantSizeImpl(1, K); + + size_t M_stride = MlasDivRoundup(M, TargetThreadCnt); + size_t threads = MlasDivRoundup(M, M_stride); + MlasTrySimpleParallel(ThreadPool, threads, [&](ptrdiff_t tid) { + const size_t m = tid * M_stride; + const float* src = A + lda * m; + uint8_t* dst = reinterpret_cast(Qblob) + m * linesize; + for (size_t i = 0; i < std::min(M_stride, M-m); i++) { + MlasQ80BlkQuantRow(src, dst, K); + src += lda; + dst += linesize; + } + }); +} + +static MLAS_Q80_BLKQUANT* Q80Quant_avx512vnni[] = { + Q80BlkQuant, + Q80BlkQuant, + Q80BlkQuant, + nullptr, + Q80BlkQuant +}; + + +static +MLAS_FORCEINLINE +__m128 +FoldAccumulators( + const __m256& acc0, + const __m256& acc1, + const __m256& acc2, + const __m256& acc3 + ) +{ + __m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1); + __m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1); + __m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3); + __m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3); + + __m256 acc_lo0123 = _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23))); + __m256 acc_hi0123 = _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23))); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23))); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23))); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + + return _mm_add_ps(_mm256_extractf32x4_ps(acc_lo0123, 0), _mm256_extractf32x4_ps(acc_lo0123, 1)); +} + +static inline float +mm256_reduce_add_ps(__m256& x) +{ + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); +} + + +template +MLAS_FORCEINLINE +size_t +MlasQ8Q4GemmKernelAvx512f( + const int8_t* QuantA, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + // We process 32 quantized values in a batch. 
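+    // Each block of QuantA (produced by MlasQ80BlkQuantRow above) stores a float
+    // scale followed by the block's int8 values; the loops below read that scale
+    // first and then consume 32 quantized A values per MLAS_QUANT4_BLK_UNIT step.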
+ static_assert(MLAS_QUANT4_BLK_UNIT == 32); + static_assert(Q4Type::BlkLen % MLAS_QUANT4_BLK_UNIT == 0); + + const __m256i zero = _mm256_setzero_si256(); + const __m256i lowMask = _mm256_set1_epi8(0xF); + + for (size_t m = 0; m < CountM; m++) { + const uint8_t* b_col = PackedB; + auto* sum_ptr = C; + auto* bias_ptr = Bias; + + int64_t nblk = (int64_t)(CountN) - 4; + while (nblk >= 0) { + __m256 acc_lo0 = _mm256_setzero_ps(); + __m256 acc_lo1 = _mm256_setzero_ps(); + __m256 acc_lo2 = _mm256_setzero_ps(); + __m256 acc_lo3 = _mm256_setzero_ps(); + const int8_t* ablob = QuantA; + const auto* b = b_col; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + const float a_scale = *reinterpret_cast(ablob); + ablob += sizeof(float); + const float scale_v0 = MlasQ4BlkScale(b) * a_scale; + const float scale_v1 = MlasQ4BlkScale(b + ldb) * a_scale; + const float scale_v2 = MlasQ4BlkScale(b + ldb * 2) * a_scale; + const float scale_v3 = MlasQ4BlkScale(b + ldb * 3) * a_scale; + + const __m128i* b0ptr = (const __m128i*)MlasQ4BlkData(b); + const __m128i* b1ptr = (const __m128i*)MlasQ4BlkData(b + ldb); + const __m128i* b2ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 2); + const __m128i* b3ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 3); + + for (size_t kk = 0; kk < Q4Type::BlkLen; kk += MLAS_QUANT4_BLK_UNIT) { + // Load A row vector + const __m256i a_bytes = _mm256_loadu_si256((const __m256i*)ablob); + ablob += MLAS_QUANT4_BLK_UNIT; + + // Load 4 B column vectors (quantized to int4 blobs) + const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); + const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); + const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); + const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); + + // expand 4b into byte array + __m256i bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); + __m256i bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); + __m256i bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); + __m256i bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + + // Subtract zero-point from the integers + if constexpr (std::is_same_v) { + bytes0 = _mm256_sub_epi8( + bytes0, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b))); + bytes1 = _mm256_sub_epi8( + bytes1, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb))); + bytes2 = _mm256_sub_epi8( + bytes2, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 2))); + bytes3 = _mm256_sub_epi8( + bytes3, + _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 3))); + } else { + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // to use vnni unsigned x signed int, negate all negative + // b vals to make it all positive, and then also negate the + // corresponding a vals to compensate + const __m256i summed_pairs0 = _mm256_dpbusd_epi32( + zero, _mm256_sign_epi8(bytes0, bytes0), _mm256_sign_epi8(a_bytes, bytes0)); + const __m256i summed_pairs1 = _mm256_dpbusd_epi32( + zero, _mm256_sign_epi8(bytes1, bytes1), _mm256_sign_epi8(a_bytes, bytes1)); + const __m256i summed_pairs2 = _mm256_dpbusd_epi32( + zero, _mm256_sign_epi8(bytes2, bytes2), _mm256_sign_epi8(a_bytes, bytes2)); + const __m256i summed_pairs3 = _mm256_dpbusd_epi32( + zero, _mm256_sign_epi8(bytes3, bytes3), _mm256_sign_epi8(a_bytes, 
bytes3)); + + const __m256 sums0 = _mm256_cvtepi32_ps(summed_pairs0); + const __m256 sums1 = _mm256_cvtepi32_ps(summed_pairs1); + const __m256 sums2 = _mm256_cvtepi32_ps(summed_pairs2); + const __m256 sums3 = _mm256_cvtepi32_ps(summed_pairs3); + acc_lo0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sums0, acc_lo0); + acc_lo1 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v1), sums1, acc_lo1); + acc_lo2 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v2), sums2, acc_lo2); + acc_lo3 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v3), sums3, acc_lo3); + } + b += Q4Type::BlobSize; + } + + __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); + if (Bias != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + + // move to next 4 columns + b_col += 4 * ldb; + sum_ptr += 4; + bias_ptr += 4; + nblk -= 4; + } + + // left over columns less than 4 ? + nblk += 4; + if (nblk > 0) { + __m256 acc_lo[4]{}; + const int8_t* ablob = QuantA; + const auto* b = b_col; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + const float a_scale = *reinterpret_cast(ablob); + ablob += sizeof(float); + + float scale_v[4]; + const __m128i* b_ptr[4]; + for (int64_t nn = 0; nn < nblk; nn++) { + const auto* bb = b + ldb * nn; + scale_v[nn] = MlasQ4BlkScale(bb) * a_scale; + b_ptr[nn] = (const __m128i*)MlasQ4BlkData(bb); + } + + for (size_t kk = 0; kk < Q4Type::BlkLen; kk += MLAS_QUANT4_BLK_UNIT) { + const __m256i a_bytes = _mm256_loadu_si256((const __m256i*)ablob); + ablob += MLAS_QUANT4_BLK_UNIT; + + for (int64_t nn = 0; nn < nblk; nn++) { + const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); + __m256i b_bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + b_bytes = _mm256_and_si256(lowMask, b_bytes); + + if constexpr (std::is_same_v) { + // Subtract zero-point from the integers + const auto* bb = b + ldb * nn; + const uint8_t zp = MlasQ4BlkZeroPoint(bb); + b_bytes = _mm256_sub_epi8(b_bytes, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + b_bytes = _mm256_sub_epi8(b_bytes, eight); + } + + // to use vnni unsigned x signed int, negate all negative + // b vals to make it all positive, + const __m256i ax = _mm256_sign_epi8(b_bytes, b_bytes); + // and then also negate the corresponding a vals to compensate + const __m256i sy = _mm256_sign_epi8(a_bytes, b_bytes); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + const __m256 sum = _mm256_cvtepi32_ps(summed_pairs); + acc_lo[nn] = _mm256_fmadd_ps(_mm256_set1_ps(scale_v[nn]), sum, acc_lo[nn]); + } + } + b += Q4Type::BlobSize; + } + + for (int64_t nn = 0; nn < nblk; nn++) { + sum_ptr[nn] = mm256_reduce_add_ps(acc_lo[nn]); + sum_ptr[nn] += Bias == nullptr ? 
0.0f : bias_ptr[nn]; + } + } + + // Prepare pointers for the next row + C += ldc; + QuantA += lda; + } + return CountM; +} + + +template<> +MLAS_FORCEINLINE +size_t +MlasQ8Q4GemmKernel( + const int8_t* QuantA, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, + lda, ldb, ldc, Bias); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ8Q4GemmKernel( + const int8_t* QuantA, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, + lda, ldb, ldc, Bias); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ8Q4GemmKernel( + const int8_t* QuantA, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, + lda, ldb, ldc, Bias); +} + +template<> +MLAS_FORCEINLINE +size_t +MlasQ8Q4GemmKernel( + const int8_t* QuantA, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias + ) +{ + return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, + lda, ldb, ldc, Bias); +} + + +static MLAS_Q8Q4GEMM_OPERATION* Q8Q4Operations_avx512vnni[] = { + MlasQ8Q4GemmOperation, + MlasQ8Q4GemmOperation, + MlasQ8Q4GemmOperation, + nullptr, + MlasQ8Q4GemmOperation +}; + + +const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni = { + Q80Quant_avx512vnni, + Q8Q4Operations_avx512vnni +}; diff --git a/third_party/mlas/lib/qdwconv.cpp b/third_party/mlas/lib/qdwconv.cpp index 924009ab5c..59f6877f70 100644 --- a/third_party/mlas/lib/qdwconv.cpp +++ b/third_party/mlas/lib/qdwconv.cpp @@ -41,6 +41,10 @@ MlasConvDepthwiseKernel( #elif defined(MLAS_NEON_INTRINSICS) const uint8x8_t InputZeroPointVector = vdup_n_u8(uint8_t(InputZeroPoint)); const uint8x8_t FilterZeroPointVector = vdup_n_u8(uint8_t(FilterZeroPoint)); +#elif defined(MLAS_LSX_INTRINSICS) + const __m128i ZeroVector = __lsx_vldi(0); + const __m128i InputZeroPointVector = __lsx_vreplgr2vr_h(InputZeroPoint); + const __m128i FilterZeroPointVector = __lsx_vreplgr2vr_h(FilterZeroPoint); #endif while (OutputCount > 0) { @@ -141,6 +145,54 @@ MlasConvDepthwiseKernel( vst1q_s32(&Output[4], Accumulator1); Output += 8; + ChannelOffset += 8; + c -= 8; + } +#elif defined(MLAS_LSX_INTRINSICS) + + while (c >= 8) { + __m128i Accumulator0 = __lsx_vldi(0); + __m128i Accumulator1 = __lsx_vldi(0); + size_t ChannelKernelOffset = ChannelOffset; + + for (size_t k = 0; k < KernelSize; k++) { + __m128i InputVector = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); + __lsx_vinsgr2vr_d(InputVector, 0, 1); + __m128i FilterVector = + __lsx_vld((const __m128i*)&Filter[ChannelKernelOffset], 0); + __lsx_vinsgr2vr_d(FilterVector, 0, 1); + + if (std::is_signed::value) { + InputVector = __lsx_vsrai_h(__lsx_vilvl_b(InputVector, ZeroVector), 8); + } else { + InputVector = __lsx_vilvl_b(ZeroVector, InputVector ); + } + + if (std::is_signed::value) { + FilterVector = __lsx_vsrai_h(__lsx_vilvl_b(FilterVector, ZeroVector), 8); + } else { + FilterVector = __lsx_vilvl_b(ZeroVector, FilterVector); + } + + InputVector = __lsx_vsub_h(InputVector, 
InputZeroPointVector); + FilterVector = __lsx_vsub_h(FilterVector, FilterZeroPointVector); + + // N.B. Emulate PMULLD functionality on LSX by computing the low + // and high parts of the result and interleaving the results. + __m128i MultiplyLowWords = __lsx_vmul_h(InputVector, FilterVector); + __m128i MultiplyHighWords = __lsx_vmuh_h(InputVector, FilterVector); + __m128i Multiply0 = __lsx_vilvl_h(MultiplyHighWords, MultiplyLowWords); + __m128i Multiply1 = __lsx_vilvh_h(MultiplyHighWords, MultiplyLowWords); + + Accumulator0 = __lsx_vadd_w(Accumulator0, Multiply0); + Accumulator1 = __lsx_vadd_w(Accumulator1, Multiply1); + ChannelKernelOffset += Channels; + } + + __lsx_vst(Accumulator0, (__m128i*)&Output[0], 0); + __lsx_vst(Accumulator1, (__m128i*)&Output[4], 0); + Output += 8; + ChannelOffset += 8; c -= 8; } @@ -322,4 +374,4 @@ Return Value: ); } } -} \ No newline at end of file +} diff --git a/third_party/mlas/lib/qgemm.h b/third_party/mlas/lib/qgemm.h index 53bc74c9e4..75c17a6b5a 100644 --- a/third_party/mlas/lib/qgemm.h +++ b/third_party/mlas/lib/qgemm.h @@ -241,8 +241,7 @@ MlasGemmQuantThreadInit() constexpr size_t packedASize = UpAlignSize(PackedStrides.M * PackedStrides.K * sizeof(typename KernelType::PackedAType)); - //constexpr size_t bufsize = std::max(packASize + packBSize, packedASize) + rowSumSize + colSumSize + zpbSize; - constexpr size_t bufsize = std::max(packASize + packBSize, packedASize) + rowSumSize + colSumSize + zpbSize; + constexpr size_t bufsize = std::max(packASize + packBSize, packedASize) + rowSumSize + colSumSize + zpbSize; MlasThreadedBufAlloc(bufsize); } @@ -872,7 +871,7 @@ MlasGemmQuantGetDispatch( GemmQuantDispatch = &MlasGemmQuantDispatchDefault; } -#if defined(MLAS_TARGET_AMD64_IX86) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) if (!AIsSigned) { if (BIsSigned) { GemmQuantDispatch = GetMlasPlatform().GemmU8S8Dispatch; @@ -883,13 +882,9 @@ MlasGemmQuantGetDispatch( } #elif defined(MLAS_TARGET_ARM64) if(BIsSigned) { - if(GetMlasPlatform().GemmU8X8Dispatch == &MlasGemmU8X8DispatchNeon) { - GemmQuantDispatch = &MlasGemmX8S8DispatchNeon; - } else { - GemmQuantDispatch = AIsSigned? &MlasGemmS8S8DispatchSdot : &MlasGemmU8X8DispatchUdot; - } + GemmQuantDispatch = AIsSigned ? 
GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmU8S8Dispatch; } else if(!AIsSigned) { - GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; + GemmQuantDispatch = GetMlasPlatform().GemmU8U8Dispatch; } #elif defined(MLAS_TARGET_ARM64EC) || (defined(MLAS_TARGET_ARM) && !defined(_MSC_VER)) if(BIsSigned || !AIsSigned) { diff --git a/third_party/mlas/lib/qgemm_kernel_amx.cpp b/third_party/mlas/lib/qgemm_kernel_amx.cpp index 7c8743026b..479a82e712 100644 --- a/third_party/mlas/lib/qgemm_kernel_amx.cpp +++ b/third_party/mlas/lib/qgemm_kernel_amx.cpp @@ -16,6 +16,7 @@ Module Name: #include "mlasi.h" #include "qgemm.h" +#include "amx_common.h" #define TMM0 0 @@ -202,7 +203,7 @@ MlasGemmQuantThreadInit() static thread_local struct tileconfig_t tc = {0}; struct tileconfig_t current_tc = {0}; - _tile_storeconfig(¤t_tc); + tile_storeconfig(¤t_tc); if (tc.palette_id == 0 || (std::memcmp(¤t_tc.colb, &tc.colb, sizeof(uint16_t) * 8) != 0 && std::memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { @@ -212,7 +213,8 @@ MlasGemmQuantThreadInit() tc.rows[t] = 16; tc.colb[t] = 64; } - _tile_loadconfig(&tc); + + tile_loadconfig(&tc); } } @@ -238,14 +240,14 @@ InitHalfTileWithRowColSums( row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6])); row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7])); if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7)); + row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); + row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); + row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); + row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); + row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); + row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); + row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); + row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); } _mm512_storeu_si512(Tile, row0); _mm512_storeu_si512(Tile+16, row1); @@ -290,14 +292,14 @@ InitHalfTileWithRowColSumsZeroPoints( row6 = _mm512_add_epi32(colsum, row6); row7 = _mm512_add_epi32(colsum, row7); if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_epi32(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_epi32(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_epi32(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_epi32(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_epi32(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_epi32(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_epi32(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_epi32(c_ptr+ldc*7)); + row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); + row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); + row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); + row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); + row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); + row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); + row6 = 
_mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); + row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); } _mm512_storeu_si512(Tile, row0); _mm512_storeu_si512(Tile+16, row1); @@ -435,58 +437,58 @@ MlasGemmQuantKernel( size_t n = CountN; for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( Tile4, m0, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSumsZeroPoints( Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile4, m0, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSums( Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } - colsum = _mm512_loadu_epi32(col_sum_ptr); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; if (ZeroPointB != nullptr) { - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( Tile6, m0, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSumsZeroPoints( Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile6, m0, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 != 0){ InitTileWithRowColSums( Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } @@ -494,33 +496,36 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM4, TMM2, TMM0); - _tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + + tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); if (m1 > 0){ - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + 
tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM7, TMM3, TMM1); } b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } if (m0 == TILE_M) { - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + MoveTile(Tile4, m0, FullMask, c_blk, ldc); MoveTile(Tile6, m0, FullMask, c_blk + TILE_N, ldc); } if (m1 != 0){ - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); MoveTile(Tile5, m1, FullMask, c16_blk, ldc); - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile7, m1, FullMask, c16_blk + TILE_N, ldc); } c_blk += 2 * TILE_N; @@ -539,23 +544,23 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSumsZeroPoints( Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSums( Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } } if (nmask_high != 0){ @@ -565,23 +570,23 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile6, m0, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSumsZeroPoints( Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } else { InitTileWithRowColSums( Tile6, m0, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); if (m1 > 0){ InitTileWithRowColSums( Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } } @@ -589,18 +594,19 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, 
static_cast(PackedCountK)); + + tile_dpbusd(TMM4, TMM2, TMM0); if (m1 > 0){ - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM5, TMM3, TMM0); } if (nmask_high != 0){ - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); if (m1 > 0){ - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); } } b_blk += TILE_N * TILE_K; @@ -608,20 +614,20 @@ MlasGemmQuantKernel( a_next_blk += TILE_K; } if ((static_cast(nmasks) & 0x8000) != 0 && m0 == TILE_M){ - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); MoveTile(Tile4, m0, static_cast(nmasks), c_blk, ldc); } if (m1 > 0){ - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); MoveTile(Tile5, m1, static_cast(nmasks), c16_blk, ldc); } if (nmask_high != 0){ - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); MoveTile(Tile6, m0, nmask_high, c_blk + TILE_N, ldc); if (m1 > 0){ - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile7, m1, nmask_high, c16_blk + TILE_N, ldc); } } @@ -643,76 +649,78 @@ MlasGemmQuantKernel( const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; if (ZeroPointB != nullptr){ - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - __m512i zeropoint = _mm512_loadu_epi32(zp_ptr); + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - _tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM0, b_blk, TILE_K); InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_epi32(col_sum_ptr); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - zeropoint = _mm512_loadu_epi32(zp_ptr); + zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM6, Tile6, TILE_N * 
sizeof(int32_t)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM4, TMM2, TMM0); InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode); InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); } else { - __m512i colsum = _mm512_loadu_epi32(col_sum_ptr); + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; - _tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM0, b_blk, TILE_K); InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_epi32(col_sum_ptr); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + colsum = _mm512_loadu_si512(col_sum_ptr); col_sum_ptr += TILE_N; InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM4, TMM2, TMM0); InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode); InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); } - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) { b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM7, TMM3, TMM1); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM4, TMM2, TMM0); } - _tile_dpbusd(TMM5, TMM3, TMM0); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); + b_blk += PackedCountK * TILE_N + TILE_N * TILE_K; - _tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM6, 
(void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + c_blk += 2 * TILE_N; - _tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c16_blk += 2 * TILE_N; } @@ -726,20 +734,20 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitTileWithRowColSumsZeroPoints( Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } else { InitTileWithRowColSums( Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); - _tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); InitTileWithRowColSums( Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); - _tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); } if (nmask_high != 0){ colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); @@ -748,52 +756,58 @@ MlasGemmQuantKernel( InitTileWithRowColSumsZeroPoints( Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); InitTileWithRowColSumsZeroPoints( Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } else { InitTileWithRowColSums( Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); - _tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); InitTileWithRowColSums( Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); - _tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } } const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - _tile_loadd(TMM0, b_blk, TILE_K); - _tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - _tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - _tile_dpbusd(TMM4, TMM2, TMM0); - _tile_dpbusd(TMM5, TMM3, TMM0); + tile_loadd(TMM0, b_blk, TILE_K); + tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + + tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM5, TMM3, TMM0); + if (nmask_high != 0){ - _tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - _tile_dpbusd(TMM6, TMM2, TMM1); - _tile_dpbusd(TMM7, TMM3, TMM1); + tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); + } b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } if ((static_cast(nmasks) & 0x8000) != 0){ - 
_tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - _tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + } else { - _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + MoveTile(Tile4, TILE_M, static_cast(nmasks), c_blk, ldc); MoveTile(Tile5, TILE_M, static_cast(nmasks), c16_blk, ldc); } if (nmask_high != 0){ - _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc); MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc); } diff --git a/third_party/mlas/lib/qgemm_kernel_default.cpp b/third_party/mlas/lib/qgemm_kernel_default.cpp index adf3992089..8f4baaa0ff 100644 --- a/third_party/mlas/lib/qgemm_kernel_default.cpp +++ b/third_party/mlas/lib/qgemm_kernel_default.cpp @@ -41,12 +41,11 @@ MlasGemmQuantFixupZeroPointA( bool AIsSigned ) { - //if (AIsSigned) { - // ZeroPointA = (uint8_t)(ZeroPointA ^ 0x80); - //} + if (AIsSigned) { + ZeroPointA = (uint8_t)(ZeroPointA ^ 0x80); + } - //return ZeroPointA; - return AIsSigned ? (uint8_t)(ZeroPointA ^ 0x80):ZeroPointA; + return ZeroPointA; } template<> @@ -57,12 +56,11 @@ MlasGemmQuantFixupZeroPointB( bool BIsSigned ) { -// if (BIsSigned) { -// ZeroPointB = MLAS_GEMM_QUANT_KERNEL_DEFAULT::OffsetBType(ZeroPointB ^ 0x80); -// } + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_QUANT_KERNEL_DEFAULT::OffsetBType(ZeroPointB ^ 0x80); + } -// return ZeroPointB; - return BIsSigned ? MLAS_GEMM_QUANT_KERNEL_DEFAULT::OffsetBType(ZeroPointB ^ 0x80):ZeroPointB; + return ZeroPointB; } template<> diff --git a/third_party/mlas/lib/qgemm_kernel_lsx.cpp b/third_party/mlas/lib/qgemm_kernel_lsx.cpp new file mode 100644 index 0000000000..7d5817335b --- /dev/null +++ b/third_party/mlas/lib/qgemm_kernel_lsx.cpp @@ -0,0 +1,531 @@ +/*++ + +Copyright (C) 2023 Loongson Technology Corporation Limited. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_lsx.cpp + +Abstract: + + This module implements QGEMM kernels for LSX. 
+ +--*/ + +#include "mlasi.h" +#include "qgemm.h" +#include + +struct MLAS_GEMM_U8X8_KERNEL_LSX +{ + typedef int16_t PackedAType; + typedef int16_t PackedBType; + typedef uint8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 2; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_LSX::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_LSX::Strides; + +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_LSX::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + const __m128i ZeroVector = __lsx_vrepli_d(0); + uint16_t val = 1; + const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + __m128i ReductionVector = ZeroVector; + + // + // Zero extend the source bytes to 16-bits and write to the packed + // buffer. + // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 2 to maintain 32-bit + // alignment. All extra bytes are zero-padded. + // + // These 16-bit values are also accumulated into an intermediate per-row + // accumulator. CountK cannot be greater than 128 to avoid overflowing + // these signed 16-bit accumulators. + // + + while (k >= 8) { + + __m128i Bytes = __lsx_vld((const __m128i*) & a[0], 0); + __lsx_vinsgr2vr_d(Bytes, 0, 1); + __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); + + ReductionVector = __lsx_vadd_h(ReductionVector, Words); + + __lsx_vst(Words, (__m128i*) & D[0], 0); + + a += 8; + D += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + __m128i Bytes = __lsx_vld((__m128i*)PaddedMatrixAData, 0); + __lsx_vinsgr2vr_d(Bytes, 0, 1); + __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); + + ReductionVector = __lsx_vadd_h(ReductionVector, Words); + + // + // Copy pairs of 16-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. + // + + for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { + __lsx_vstelm_w(Words, (int32_t*)D, 0 , 0); + D += 2; + Words = __lsx_vshuf4i_w(Words, 0x39); //(0, 3, 2, 1) + } + } + + // + // Reduce the partial accumulators. 
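[Editor's note] For reference, the LSX A-packing above amounts to the following scalar routine: each uint8_t is zero-extended to int16_t, CountK is rounded up to a multiple of PackedK (2) with zero padding, and a per-row sum of the packed values is written out. This is a minimal sketch of the intended semantics under the same parameter meanings, not the vectorized implementation; the helper name is made up.

#include <cstddef>
#include <cstdint>

// Scalar sketch of the LSX MlasGemmQuantCopyPackA path above: zero-extend A to
// 16 bits, pad K to a multiple of 2, and emit per-row sums of the packed values.
static void PackAReferenceLsx(int16_t* D, const uint8_t* A, size_t lda,
                              size_t CountM, size_t CountK, int32_t* RowSumBuffer)
{
    const size_t PackedK = 2;
    const size_t AlignedK = (CountK + PackedK - 1) / PackedK * PackedK;

    for (size_t m = 0; m < CountM; m++) {
        int32_t RowSum = 0;
        for (size_t k = 0; k < AlignedK; k++) {
            // Zero-pad the tail so each packed row stays 32-bit aligned.
            int16_t Value = (k < CountK) ? static_cast<int16_t>(A[m * lda + k]) : 0;
            *D++ = Value;
            RowSum += Value;   // the real kernel uses 16-bit accumulators, hence CountK <= 128
        }
        RowSumBuffer[m] = RowSum;
    }
}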
+ // + __m128i tmp1 = ZeroVector, tmp2 = ZeroVector; + tmp1 = __lsx_vmaddwev_w_h(tmp1, ReductionVector, OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ReductionVector, OnesWordBroadcast); + ReductionVector = __lsx_vadd_w(tmp1, tmp2); + ReductionVector = __lsx_vadd_w(ReductionVector, + __lsx_vshuf4i_w(ReductionVector, 0xee)); + ReductionVector = __lsx_vadd_w(ReductionVector, + __lsx_vshuf4i_w(ReductionVector, 0x11)); + + __lsx_vstelm_w(ReductionVector, RowSumBuffer++, 0 , 0); + + A += lda; + CountM -= 1; + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessLSX( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, + __m128i BytesRow0, + __m128i BytesRow1, + __m128i BitFlipVector, + __m128i ColumnSums[2] +) +{ + __m128i BytesInterleaved = __lsx_vilvl_b(BytesRow1, BytesRow0); + + BytesInterleaved = __lsx_vxor_v(BytesInterleaved, BitFlipVector); + + __m128i WordsInterleaved0 = __lsx_vsrai_h(__lsx_vilvl_b(BytesInterleaved, BytesInterleaved), 8); + __m128i WordsInterleaved1 = __lsx_vsrai_h(__lsx_vilvh_b(BytesInterleaved, BytesInterleaved), 8); + + ColumnSums[0] = __lsx_vadd_h(ColumnSums[0], WordsInterleaved0); + ColumnSums[1] = __lsx_vadd_h(ColumnSums[1], WordsInterleaved1); + + __lsx_vst(WordsInterleaved0, (__m128i*) & D[0], 0); + __lsx_vst(WordsInterleaved1, (__m128i*) & D[8], 0); +} + +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + uint16_t val = 1; + const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); + const __m128i BitFlipVector = __lsx_vreplgr2vr_w(BIsSigned ? 0 : 0x80808080); + + // + // Process 8 columns of matrix B in a loop. + // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSums[2]; + + ColumnSums[0] = __lsx_vldi(0); + ColumnSums[1] = __lsx_vldi(0); + + // + // Interleave rows of matrix B and write to the packed buffer. + // + // These values are also zero-extended and accumulated into an + // intermediate per-column accumulator. CountK cannot be greater than + // 128 to avoid overflowing these signed 16-bit accumulators. + // + + while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { + + __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + __m128i BytesRow1 = __lsx_vld((const __m128i*) & b[ldb], 0); + __lsx_vinsgr2vr_d(BytesRow1, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); + + D += 16; + } + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); + ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); + ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); + + __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); + __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. 
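[Editor's note] A scalar sketch of what the B-packing above appears to produce for one slice of at most eight columns: pairs of K rows are interleaved per column, unsigned B data is shifted into the signed domain by XOR 0x80 (the BitFlipVector), values are sign-extended to int16_t, and per-column sums are accumulated. Illustration only, assuming the same parameter meanings; the helper name is made up.

#include <cstddef>
#include <cstdint>

// Scalar sketch of the LSX MlasGemmQuantCopyPackB path above for CountN <= 8 columns.
static void PackBReferenceLsx(int16_t* D, const uint8_t* B, size_t ldb,
                              size_t CountN, size_t CountK,
                              int32_t* ColumnSumBuffer, bool BIsSigned)
{
    const uint8_t Flip = BIsSigned ? 0x00 : 0x80;   // map unsigned bytes into the signed domain
    int32_t ColumnSums[8] = {0};

    for (size_t k = 0; k < CountK; k += 2) {
        for (size_t n = 0; n < 8; n++) {
            // Columns beyond CountN and the odd-K padding row pack as zero.
            int16_t v0 = (n < CountN)
                ? static_cast<int16_t>(static_cast<int8_t>(B[k * ldb + n] ^ Flip)) : 0;
            int16_t v1 = (n < CountN && k + 1 < CountK)
                ? static_cast<int16_t>(static_cast<int8_t>(B[(k + 1) * ldb + n] ^ Flip)) : 0;
            *D++ = v0;          // interleaved pair for column n: row k, then row k+1
            *D++ = v1;
            ColumnSums[n] += v0 + v1;
        }
    }

    for (size_t n = 0; n < 8; n++) {
        ColumnSumBuffer[n] = ColumnSums[n];
    }
}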
+ // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + __m128i ColumnSums[2]; + uint8_t PaddedMatrixBData[16]; + + __lsx_vst(BitFlipVector, (__m128i*)PaddedMatrixBData, 0); + + ColumnSums[0] = __lsx_vldi(0); + ColumnSums[1] = __lsx_vldi(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + __m128i BytesRow1 = __lsx_vld((__m128i*) & PaddedMatrixBData[8], 0); + __lsx_vinsgr2vr_d(BytesRow1, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); + + b += ldb * 2; + D += 16; + k -= 2; + } + + if (k > 0) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + + __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); + __lsx_vinsgr2vr_d(BytesRow0, 0, 1); + + MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); + } + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); + ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); + ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); + + __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); + __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8MultiplyAccumulateRowLSX( + __m128i ABroadcast, + const int16_t* B, + __m128i Accumulators[2] +) +{ + __m128i BElements0 = __lsx_vld((__m128i*) & B[0], 0); + __m128i BElements1 = __lsx_vld((__m128i*) & B[8], 0); + + __m128i tmp1, tmp2; + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements0, ABroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements0, ABroadcast); + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vadd_w(tmp1, tmp2)); + tmp1 = tmp2 = __lsx_vldi(0); + tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements1, ABroadcast); + tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements1, ABroadcast); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vadd_w(tmp1, tmp2)); +} + +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + while (CountN > 0) { + + __m128i Accumulators[2]; + + // + // Initialize the accumulators with the row and column sums. 
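[Editor's note] The kernel that follows initializes each column accumulator with the row sum (scaled by the per-column ZeroPointB when one is supplied) plus the column sum, then accumulates pairwise products of the packed A and B values. The row and column sums carry the zero-point corrections implied by sum_k (a - za)(b - zb) = sum_k a*b - zb*sum_k a - za*sum_k b + K*za*zb; the exact sign/scaling conventions are set up by the surrounding MLAS driver and are not restated here. A scalar sketch of the per-column computation, assuming the packing produced by the routines above; the function name is made up.

#include <cstddef>
#include <cstdint>

// Scalar sketch of one row x 8-column block of the LSX QGEMM kernel below.
// A holds pairs of int16 (PackedK = 2); B holds one 2x8 int16 block per K pair.
static void QGemmKernelReferenceLsx(const int16_t* A, const int16_t* B, int32_t* C,
                                    size_t PackedCountK, int32_t RowSum,
                                    const int32_t* ColumnSums,
                                    const int32_t* ZeroPointB, bool ZeroMode)
{
    for (size_t n = 0; n < 8; n++) {
        // Row/column sums carry the zero-point corrections precomputed by the caller.
        int32_t Acc = (ZeroPointB != nullptr) ? RowSum * ZeroPointB[n] : RowSum;
        Acc += ColumnSums[n];

        for (size_t k = 0; k < PackedCountK; k++) {
            // Each step multiplies a pair of A values with the matching pair for column n.
            Acc += A[2 * k + 0] * B[16 * k + 2 * n + 0];
            Acc += A[2 * k + 1] * B[16 * k + 2 * n + 1];
        }

        C[n] = ZeroMode ? Acc : C[n] + Acc;
    }
}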
+ // + + int32_t RowSumValue = RowSumBuffer[0]; + + if (ZeroPointB != nullptr) { + + int32_t ScaledRowSumBuffer[8]; + + for (size_t i = 0; i < 8; i++) { + ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; + } + + ZeroPointB += 8; + + Accumulators[0] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[0], 0); + Accumulators[1] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[4], 0); + + } + else { + + Accumulators[0] = __lsx_vreplgr2vr_w(RowSumValue); + Accumulators[1] = Accumulators[0]; + } + + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((const __m128i*) & ColumnSumBuffer[0], 0)); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((const __m128i*) & ColumnSumBuffer[4], 0)); + ColumnSumBuffer += 8; + + // + // Broadcast each pair of 16-bit values from the matrix A and multiply + // with the pair of 16-bit values from matrix B, and add the 32-bit + // intermediate into the accumulator registers. + // + + const int16_t* a = A; + size_t k = PackedCountK; + + while (k >= 4) { + + __m128i AElements = __lsx_vld((__m128i*)a, 0); + __m128i ABroadcast; + + ABroadcast = __lsx_vreplvei_w(AElements, 0); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 1); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[16], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 2); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[32], Accumulators); + + ABroadcast = __lsx_vreplvei_w(AElements, 3); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[48], Accumulators); + + a += 4 * 2; + B += 4 * 16; + k -= 4; + } + + while (k > 0) { + + __m128i ABroadcast = __lsx_vldrepl_w((int32_t*)a, 0); + MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); + + a += 2; + B += 16; + k -= 1; + } + + // + // Output the accumulator block after optionally accumulating the values + // from matrix C. + // + + if (CountN >= 8) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); + Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((__m128i*) & C[4], 0)); + } + + __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); + __lsx_vst(Accumulators[1], (__m128i*) & C[4], 0); + + C += 8; + CountN -= 8; + + } + else { + + // + // Output the remaining partial output block. 
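[Editor's note] The epilogue below writes the remaining CountN < 8 columns in 4-, 2- and 1-column steps, accumulating into the existing C values when ZeroMode is false. A scalar equivalent of that tail handling, shown only to make the ZeroMode semantics explicit; the function name is made up.

#include <cstddef>
#include <cstdint>

// Scalar equivalent of the partial-column store below: write CountN (< 8)
// accumulators to C, adding the existing C values unless ZeroMode is set.
static void StorePartialColumns(int32_t* C, const int32_t* Accumulators,
                                size_t CountN, bool ZeroMode)
{
    for (size_t n = 0; n < CountN; n++) {
        C[n] = ZeroMode ? Accumulators[n] : C[n] + Accumulators[n];
    }
}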
+ // + + if ((CountN & 4) != 0) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); + } + + __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); + C += 4; + + Accumulators[0] = Accumulators[1]; + } + + if ((CountN & 2) != 0) { + + if (!ZeroMode) { + Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vinsgr2vr_d(__lsx_vld((__m128i*) & C[0], 0), 0, 1)); + } + + *((uint64_t *)&C[0]) = __lsx_vpickve2gr_d(Accumulators[0], 0); + C += 2; + + Accumulators[0] = __lsx_vshuf4i_w(Accumulators[0], 0xee); + } + + if ((CountN & 1) != 0) { + + int32_t AccumulatorValue = __lsx_vpickve2gr_w(Accumulators[0], 0); + + if (!ZeroMode) { + AccumulatorValue += C[0]; + } + + C[0] = AccumulatorValue; + } + + CountN = 0; + } + } + + return 1; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX = { + MlasGemmQuantOperation, + nullptr, + nullptr, + MLAS_GEMM_U8X8_KERNEL_LSX::PackedK, + 0, + 1 // aLSXmbly kernel M stride +}; diff --git a/third_party/mlas/lib/qgemm_kernel_smmla.cpp b/third_party/mlas/lib/qgemm_kernel_smmla.cpp new file mode 100644 index 0000000000..c41f43ca22 --- /dev/null +++ b/third_party/mlas/lib/qgemm_kernel_smmla.cpp @@ -0,0 +1,964 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_smmla.cpp + +Abstract: + + This module implements smmla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON SMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_S8S8_KERNEL_SMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + int8_t* D = reinterpret_cast(D_uint8_t); + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + int8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. 
So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a0 + lda * 2; + const int8_t* a3 = a0 + lda * 3; + const int8_t* a4 = a0 + lda * 4; + const int8_t* a5 = a0 + lda * 5; + const int8_t* a6 = a0 + lda * 6; + const int8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + int32x4_t RowSums0 = vmovq_n_s32(0); + int32x4_t RowSums1 = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); + a4 += 16; + int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); + a5 += 16; + int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); + a6 += 16; + int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); + a7 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + int64x2_t z4 = vzip1q_s64(v4, v5); + int64x2_t z5 = vzip2q_s64(v4, v5); + int64x2_t z6 = vzip1q_s64(v6, v7); + int64x2_t z7 = vzip2q_s64(v6, v7); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); + vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); + vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); + vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = 
{vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + int64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + int64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + int64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + int64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + vst1q_s8(&d[32], vmovq_n_s8(0)); + vst1q_s8(&d[48], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = 
*a7++; + d += 1; + k -= 1; + } + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, RowSums0); + vst1q_s32(&RowSumBuffer[4], RowSums1); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
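[Editor's note] The 8-row path above and the 4- and 2-row paths that follow all use the interleaving described in these comments: for every group of eight K values, each row of the bundle contributes eight consecutive bytes (zero-padded past CountK), and per-row byte sums are emitted. A scalar sketch of that layout for an arbitrary bundle size, using signed A as in this SMMLA path; the helper name is made up and this is not the vectorized implementation.

#include <cstddef>
#include <cstdint>

// Scalar sketch of the interleaved A packing used by the MMLA copy routines:
// RowsPerBundle is 8, 4 or 2 depending on the path; K is padded to a multiple of 8.
static void PackABundleReference(int8_t* D, const int8_t* A, size_t lda,
                                 size_t RowsPerBundle, size_t CountK,
                                 int32_t* RowSumBuffer)
{
    const size_t AlignedK = (CountK + 7) / 8 * 8;

    for (size_t m = 0; m < RowsPerBundle; m++) {
        RowSumBuffer[m] = 0;
    }

    for (size_t k = 0; k < AlignedK; k += 8) {
        for (size_t m = 0; m < RowsPerBundle; m++) {
            for (size_t i = 0; i < 8; i++) {
                // Zero-pad columns beyond CountK so every bundle is 8 wide.
                int8_t Value = (k + i < CountK) ? A[m * lda + k + i] : 0;
                *D++ = Value;
                RowSumBuffer[m] += Value;
            }
        }
    }
}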
+ // + + if (CountM >= 4) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a1 + lda; + const int8_t* a3 = a2 + lda; + + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. 
+ // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, RowSums); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + if (CountM >= 2) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + + size_t k = CountK; + int32x2_t RowSums = vmov_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + int8_t* d = PaddedMatrixAData; + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, RowSums); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const int8_t* a = reinterpret_cast(A); + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int8x16_t v = vld1q_s8(a); + a += 16; + + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. 
+ // + + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + int8x16_t v = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmS8S8CopyPackBProcessSmmla(int8_t* D, int8x8_t BytesRow[8], int32x4_t ColumnSums[2]) +{ + int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); + int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); + + int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); + int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); + + int8x16x2_t zw1 = vzipq_s8(v02, v13); + int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); + + int8x16x2_t zw2 = vzipq_s8(v46, v57); + int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); + + int32x4x2_t zd3 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); + int32x4x2_t zd4 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); + + vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); + vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); + vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); + vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); + + int32x4_t ColSums0L_pada = vmovq_n_s32(0); + ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); + int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); + int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); + int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; + + int32x4_t ColSums0H_pada = vmovq_n_s32(0); + ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); + int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); + int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); + int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); + + int32x4_t ColSums1L_pada = vmovq_n_s32(0); + ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); + int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); + int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); + int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; + + int32x4_t ColSums1H_pada = vmovq_n_s32(0); + ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); + int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); + int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); + int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + int8_t* D = reinterpret_cast(Dst); + const int8x16_t ZeroVector = vmovq_n_s8(0); + int8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. 
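[Editor's note] Tracing the zips in MlasGemmS8S8CopyPackBProcessSmmla above, the packed B appears to store, per group of eight K values, eight consecutive bytes for each of the eight columns in order, with ColumnSumBuffer receiving the per-column sums over all of K (zero-padded). A scalar sketch of that layout and of the column-sum semantics for one slice of at most eight columns; illustration only, the helper name is made up.

#include <cstddef>
#include <cstdint>

// Scalar sketch of the SMMLA B packing above for one <=8-column slice:
// each group of eight K values stores eight consecutive bytes per column,
// and ColumnSumBuffer receives the per-column sums over all of K.
static void PackBBundleReferenceSmmla(int8_t* D, const int8_t* B, size_t ldb,
                                      size_t CountN, size_t CountK,
                                      int32_t* ColumnSumBuffer)
{
    const size_t AlignedK = (CountK + 7) / 8 * 8;
    int32_t ColumnSums[8] = {0};

    for (size_t k = 0; k < AlignedK; k += 8) {
        for (size_t n = 0; n < 8; n++) {
            for (size_t i = 0; i < 8; i++) {
                // Rows past CountK and columns past CountN pack as zero.
                int8_t Value = (k + i < CountK && n < CountN) ? B[(k + i) * ldb + n] : 0;
                *D++ = Value;
                ColumnSums[n] += Value;
            }
        }
    }

    for (size_t n = 0; n < 8; n++) {
        ColumnSumBuffer[n] = ColumnSums[n];
    }
}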
+ // + // + while (CountN >= 8) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + while (k >= 8) { + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = vld1_s8(&b[ldb * 1]); + BytesRow[2] = vld1_s8(&b[ldb * 2]); + BytesRow[3] = vld1_s8(&b[ldb * 3]); + BytesRow[4] = vld1_s8(&b[ldb * 4]); + BytesRow[5] = vld1_s8(&b[ldb * 5]); + BytesRow[6] = vld1_s8(&b[ldb * 6]); + BytesRow[7] = vld1_s8(&b[ldb * 7]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); + BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); + BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); + BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); + BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); + BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); + BytesRow[7] = vget_low_s8(ZeroVector); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int8_t PaddedMatrixBData[64]; + int32x4_t ColumnSums[2]; + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const int8_t* bcopy0 = &b[ldb * 0]; + const int8_t* bcopy1 = &b[ldb * 1]; + const int8_t* bcopy2 = &b[ldb * 2]; + const int8_t* bcopy3 = &b[ldb * 3]; + const int8_t* bcopy4 = &b[ldb * 4]; + const int8_t* bcopy5 = &b[ldb * 5]; + const int8_t* bcopy6 = &b[ldb * 6]; + const int8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A, + const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K, + 8}; diff --git a/third_party/mlas/lib/qgemm_kernel_ummla.cpp b/third_party/mlas/lib/qgemm_kernel_ummla.cpp new file mode 100644 index 0000000000..3936154432 --- /dev/null +++ b/third_party/mlas/lib/qgemm_kernel_ummla.cpp @@ -0,0 +1,967 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_ummla.cpp + +Abstract: + + This module implements ummla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON UMMLA routines written in assembly. 
+// + +extern "C" { + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_U8X8_KERNEL_UMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetAType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UMMLA::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + uint8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
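[Editor's note] The UMMLA path keeps the packed data unsigned; a signed B matrix is handled by the MlasGemmQuantFixupZeroPointB specialization above, which XORs the zero point with 0x80 so it matches data that has been re-biased into the unsigned range. Flipping the sign bit adds 128 to every value, so as long as both the data and its zero point are flipped, the (b - zb) differences the GEMM depends on are unchanged. A small standalone self-check of that identity (not MLAS code):

#include <cassert>
#include <cstdint>

// XOR with 0x80 maps int8 [-128, 127] onto uint8 [0, 255] by adding 128,
// so differences against an equally flipped zero point are preserved.
int main()
{
    for (int b = -128; b <= 127; b++) {
        for (int zb = -128; zb <= 127; zb++) {
            uint8_t bu = static_cast<uint8_t>(static_cast<int8_t>(b)) ^ 0x80;
            uint8_t zu = static_cast<uint8_t>(static_cast<int8_t>(zb)) ^ 0x80;
            assert(static_cast<int>(bu) - static_cast<int>(zu) == b - zb);
        }
    }
    return 0;
}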
+ // + + while (CountM >= 8) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a0 + lda * 2; + const uint8_t* a3 = a0 + lda * 3; + const uint8_t* a4 = a0 + lda * 4; + const uint8_t* a5 = a0 + lda * 5; + const uint8_t* a6 = a0 + lda * 6; + const uint8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + uint32x4_t RowSums0 = vmovq_n_u32(0); + uint32x4_t RowSums1 = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + uint64x2_t v4 = vld1q_u64(reinterpret_cast(a4)); + a4 += 16; + uint64x2_t v5 = vld1q_u64(reinterpret_cast(a5)); + a5 += 16; + uint64x2_t v6 = vld1q_u64(reinterpret_cast(a6)); + a6 += 16; + uint64x2_t v7 = vld1q_u64(reinterpret_cast(a7)); + a7 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + uint64x2_t z4 = vzip1q_u64(v4, v5); + uint64x2_t z5 = vzip2q_u64(v4, v5); + uint64x2_t z6 = vzip1q_u64(v6, v7); + uint64x2_t z7 = vzip2q_u64(v6, v7); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z4)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z6)); + vst1q_u8(&D[64], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[80], vreinterpretq_u8_u64(z3)); + vst1q_u8(&D[96], vreinterpretq_u8_u64(z5)); + vst1q_u8(&D[112], vreinterpretq_u8_u64(z7)); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z4))); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z5))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z6))); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z7))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + 
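+            // At this point RowSums1L holds the byte totals of rows 4 and 5 and
+            // RowSums1H those of rows 6 and 7 for this 16-column slice; fold them
+            // into the running per-row sums.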
RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + uint64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + uint64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + uint64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + uint64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + vst1q_u8(&d[32], vmovq_n_u8(0)); + vst1q_u8(&d[48], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + 
uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); + vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
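+    //
+    // For example, each 32-byte block is D[0..7] = row 0, D[8..15] = row 1,
+    // D[16..23] = row 2 and D[24..31] = row 3 for one group of eight columns.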
+ // + + if (CountM >= 4) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a1 + lda; + const uint8_t* a3 = a2 + lda; + + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z3)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. 
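+            // The 32-byte destination block is cleared first and the tail bytes of
+            // each row are written into their own zero-padded 8-byte lane of the
+            // block, which is then reloaded below only to update the row sums.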
+ // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
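+    //
+    // The row sums accumulated while packing are written to RowSumBuffer so the
+    // kernels can apply the per-row ZeroPointB correction without reading A again.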
+ // + + if (CountM >= 2) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + + size_t k = CountK; + uint32x2_t RowSums = vmov_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z1)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + uint8_t* d = PaddedMatrixAData; + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const uint8_t* a = A; + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint8x16_t v = vld1q_u8(a); + a += 16; + + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. 
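+            // Only the first 16 bytes of the 64-byte PaddedMatrixAData scratch
+            // buffer are used here; zeroing them keeps the padded tail columns
+            // from contributing to the row sum.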
+ // + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + uint8x16_t v = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessUmmla(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + uint8x8_t BytesRow[8], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2]) +{ + uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); + uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); + + uint8x16_t v46 = veorq_u8(vcombine_u8(BytesRow[4], BytesRow[6]), BitFlipVector); + uint8x16_t v57 = veorq_u8(vcombine_u8(BytesRow[5], BytesRow[7]), BitFlipVector); + + uint8x16x2_t zw1 = vzipq_u8(v02, v13); + uint16x8x2_t zd1 = + vzipq_u16(vreinterpretq_u16_u8(zw1.val[0]), vreinterpretq_u16_u8(zw1.val[1])); + + uint8x16x2_t zw2 = vzipq_u8(v46, v57); + uint16x8x2_t zd2 = + vzipq_u16(vreinterpretq_u16_u8(zw2.val[0]), vreinterpretq_u16_u8(zw2.val[1])); + + uint32x4x2_t zd3 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[0]), vreinterpretq_u32_u16(zd2.val[0])); + uint32x4x2_t zd4 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[1]), vreinterpretq_u32_u16(zd2.val[1])); + + vst1q_u8(&D[0], vreinterpretq_u8_u32(zd3.val[0])); + vst1q_u8(&D[16], vreinterpretq_u8_u32(zd3.val[1])); + vst1q_u8(&D[32], vreinterpretq_u8_u32(zd4.val[0])); + vst1q_u8(&D[48], vreinterpretq_u8_u32(zd4.val[1])); + + uint32x4_t ColSums0L_pada = vmovq_n_u32(0); + ColSums0L_pada = vpadalq_u16(ColSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[0]))); + uint32x4_t ColSums0L_ext = vextq_u32(ColSums0L_pada, ColSums0L_pada, 1); + uint32x4_t ColSums0L_add = vaddq_u32(ColSums0L_pada, ColSums0L_ext); + uint32x2_t ColSums0L = {vdups_laneq_u32(ColSums0L_add, 0), vdups_laneq_u32(ColSums0L_add, 2)}; + + uint32x4_t ColSums0H_pada = vmovq_n_u32(0); + ColSums0H_pada = vpadalq_u16(ColSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[1]))); + uint32x4_t ColSums0H_ext = vextq_u32(ColSums0H_pada, ColSums0H_pada, 1); + uint32x4_t ColSums0H_add = vaddq_u32(ColSums0H_pada, ColSums0H_ext); + uint32x2_t ColSums0H = {vdups_laneq_u32(ColSums0H_add, 0), vdups_laneq_u32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_u32(ColumnSums[0], vcombine_u32(ColSums0L, ColSums0H)); + + uint32x4_t ColSums1L_pada = vmovq_n_u32(0); + ColSums1L_pada = vpadalq_u16(ColSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[0]))); + uint32x4_t ColSums1L_ext = vextq_u32(ColSums1L_pada, ColSums1L_pada, 1); + uint32x4_t ColSums1L_add = vaddq_u32(ColSums1L_pada, ColSums1L_ext); + uint32x2_t ColSums1L = {vdups_laneq_u32(ColSums1L_add, 0), vdups_laneq_u32(ColSums1L_add, 2)}; + + uint32x4_t ColSums1H_pada = vmovq_n_u32(0); + ColSums1H_pada = vpadalq_u16(ColSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[1]))); + uint32x4_t ColSums1H_ext = vextq_u32(ColSums1H_pada, ColSums1H_pada, 1); + uint32x4_t ColSums1H_add = vaddq_u32(ColSums1H_pada, ColSums1H_ext); + uint32x2_t ColSums1H = {vdups_laneq_u32(ColSums1H_add, 0), vdups_laneq_u32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_u32(ColumnSums[1], vcombine_u32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 
0x80 : 0); + uint8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const uint8_t* b = B; + size_t k = CountK; + uint32x4_t ColumnSums[2]; + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + while (k >= 8) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = vld1_u8(&b[ldb * 1]); + BytesRow[2] = vld1_u8(&b[ldb * 2]); + BytesRow[3] = vld1_u8(&b[ldb * 3]); + BytesRow[4] = vld1_u8(&b[ldb * 4]); + BytesRow[5] = vld1_u8(&b[ldb * 5]); + BytesRow[6] = vld1_u8(&b[ldb * 6]); + BytesRow[7] = vld1_u8(&b[ldb * 7]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); + BytesRow[2] = (k >= 3) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); + BytesRow[3] = (k >= 4) ? vld1_u8(&b[ldb * 3]) : vget_low_u8(BitFlipVector); + BytesRow[4] = (k >= 5) ? vld1_u8(&b[ldb * 4]) : vget_low_u8(BitFlipVector); + BytesRow[5] = (k >= 6) ? vld1_u8(&b[ldb * 5]) : vget_low_u8(BitFlipVector); + BytesRow[6] = (k >= 7) ? vld1_u8(&b[ldb * 6]) : vget_low_u8(BitFlipVector); + BytesRow[7] = vget_low_u8(BitFlipVector); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const uint8_t* b = B; + size_t k = CountK; + uint8_t PaddedMatrixBData[64]; + uint32x4_t ColumnSums[2]; + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const uint8_t* bcopy0 = &b[ldb * 0]; + const uint8_t* bcopy1 = &b[ldb * 1]; + const uint8_t* bcopy2 = &b[ldb * 2]; + const uint8_t* bcopy3 = &b[ldb * 3]; + const uint8_t* bcopy4 = &b[ldb * 4]; + const uint8_t* bcopy5 = &b[ldb * 5]; + const uint8_t* bcopy6 = &b[ldb * 6]; + const uint8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_u8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_u8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_u8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_u8(&PaddedMatrixBData[56]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmU8X8KernelUmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmU8X8KernelUmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides.K, + 8}; diff --git a/third_party/mlas/lib/qladd.cpp b/third_party/mlas/lib/qladd.cpp index 971ea0161d..5dafa17c2a 100644 --- a/third_party/mlas/lib/qladd.cpp +++ b/third_party/mlas/lib/qladd.cpp @@ -552,6 +552,119 @@ MlasQLinearAddKernelHelper( InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } +#elif defined(MLAS_LSX_INTRINSICS) + +template +static +void +MlasQLinearAddKernelHelper( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const float ScaleRatio_AC = ScaleA / ScaleC; + const float ScaleRatio_BC = ScaleB / ScaleC; + const auto VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); + const auto VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); + auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); + + MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi; + if (IsScalarB) { + float tmp_f = (float)*InputB; + uint32_t *tmp_p = (uint32_t *)&tmp_f; + vb_lo = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w(*tmp_p)); + VectorFixedPart = __lsx_vfmadd_s(vb_lo, VectorScaleRatio_BC, VectorFixedPart); + } + + __m128i tmp, tmp1; + + while (N >= 8) { + const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputA, 0), 0 ,1); + const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); + InputA += 8; + va_lo = 
__lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); + va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); + + if (!IsScalarB) { + const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputB, 0), 0 ,1); + const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); + InputB += 8; + vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); + vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); + } + + MLAS_INT32X4 r_lo, r_hi; + if (IsScalarB) { + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); + } else { + r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); + r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); + } + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); + + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + N -= 8; + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((MLAS_INT32X4*)OutputC, 0), __lsx_vpickve2gr_d(vc, 0), 0), (MLAS_INT32X4*)OutputC, 0); + OutputC += 8; + } + + if (N > 0) { + uint8_t TailData[8] = { 0 }; + + MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N); + const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); + const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); + va_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); + va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); + + if (!IsScalarB) { + MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N); + const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); + const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); + vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); + vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); + } + + MLAS_INT32X4 r_lo, r_hi; + if (IsScalarB) { + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); + } else { + r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); + r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); + } + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); + + MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); + + if (N & 4) { + __lsx_vstelm_w(vc, (int*)OutputC, 0, 0); + N -= 4; + OutputC += 4; + vc = __lsx_vshuf4i_w(vc, 0x39); //_MM_SHUFFLE(0, 3, 2, 1) + } + + uint32_t PackedValueC = (uint32_t)__lsx_vpickve2gr_w(vc, 0); + for (size_t i = 0; i < N; ++i) { + *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; + PackedValueC >>= 8; + } + } +} #else template diff --git a/third_party/mlas/lib/qladd.h b/third_party/mlas/lib/qladd.h index 8c05a61853..94568941a5 100644 --- a/third_party/mlas/lib/qladd.h +++ b/third_party/mlas/lib/qladd.h @@ -463,5 +463,132 @@ MlasPackS16_128( { 
return reinterpret_cast(vec_packs(a, b)); } +#elif defined(MLAS_LSX_INTRINSICS) +#define LSX_DBG 1 +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ); + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); + return __lsx_vsra_w(v, imm_v); +#else + return __lsx_vsrai_w(v, imm); +#endif +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt32( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); + return __lsx_vsrl_w(v, imm_v); +#else + return __lsx_vsrli_w(v, imm); +#endif +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ); + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); + return __lsx_vsra_h(v, imm_v); +#else + return __lsx_vsrai_h(v, imm); +#endif +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasShiftRightInt16( + MLAS_INT32X4 v, + int imm + ) +{ +#if LSX_DBG + MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); + return __lsx_vsrl_h(v, imm_v); +#else + return __lsx_vsrli_h(v, imm); +#endif +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ) +{ + // return _mm_packus_epi16(a, b); + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, a); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, b); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); + +} + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + MLAS_INT32X4 a, + MLAS_INT32X4 b + ) +{ + // return _mm_packs_epi16(a, b); + __m128i tmp, tmp1; + + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); + +} #endif diff --git a/third_party/mlas/lib/qlgavgpool.cpp b/third_party/mlas/lib/qlgavgpool.cpp index 1c2be0a833..e44d7ad25c 100644 --- a/third_party/mlas/lib/qlgavgpool.cpp +++ b/third_party/mlas/lib/qlgavgpool.cpp @@ -689,6 +689,316 @@ MlasQLinearGlobalAveragePoolNhwcSingleBatch( Output_zero_point, 0, 0, 1, Channels); } +#elif defined(MLAS_LSX_INTRINSICS) + +template +void MLASCALL +MlasQLinearGlobalAveragePoolNchw( + const T8Bits* Input, + float ScaleInput, + int32_t ZeroPointInput, + T8Bits* Output, + float ScaleOutput, + int32_t ZeroPointOutput, + size_t Channels, + size_t ImageSize, + int32_t* AccumulateBuffer + ) +{ + float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); + const int32_t bias[] = {-ZeroPointInput * static_cast(ImageSize), 0, 0, 0}; + const auto vbias = __lsx_vld((const __m128i*)&bias, 0); + const auto vzero = __lsx_vldi(0); + uint8_t buffer[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + int32_t* sum_buffer = AccumulateBuffer; + for (size_t c = Channels; c > 0; c--) { + + __m128i vacc_lo = vbias; + __m128i vacc_hi = vzero; + auto Len = ImageSize; + for (; Len >= 32; Len -= 32) { + + const __m128i vi0 = __lsx_vld((const __m128i*)Input, 0); + __lsx_vinsgr2vr_d(vi0, 0, 1); + const __m128i vi1 = __lsx_vld((const __m128i*)(Input + 8), 0); + __lsx_vinsgr2vr_d(vi1, 0, 1); + const __m128i vi2 = __lsx_vld((const __m128i*)(Input + 16), 0); + __lsx_vinsgr2vr_d(vi2, 0, 1); + const __m128i vi3 = __lsx_vld((const __m128i*)(Input + 24), 0); + __lsx_vinsgr2vr_d(vi3, 0, 1); + 
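+            // Each load reads 16 bytes, but only the low 8 bytes of vi0..vi3 feed
+            // the low-half interleaves below, which widen them to 16-bit lanes with
+            // the appropriate sign or zero extension.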
+ if constexpr (std::is_signed::value) { + + const __m128i vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); + const __m128i vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); + const __m128i vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); + const __m128i vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); + const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), + __lsx_vadd_h(vxi2, vxi3)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vxi0 = __lsx_vilvl_b(vzero, vi0); + const __m128i vxi1 = __lsx_vilvl_b(vzero, vi1); + const __m128i vxi2 = __lsx_vilvl_b(vzero, vi2); + const __m128i vxi3 = __lsx_vilvl_b(vzero, vi3); + const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), + __lsx_vadd_h(vxi2, vxi3)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += 32; + } + for (; Len >= 8; Len -= 8) { + + if constexpr (std::is_signed::value) { + + const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1), vzero), 8); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += 8; + } + if (Len > 0) { + + memcpy(buffer, Input, Len); + + if constexpr (std::is_signed::value) { + + const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1), vzero), 8); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); + } else { + + const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1)); + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); + } + + Input += Len; + } + + __m128i vacc = __lsx_vadd_w(vacc_lo, vacc_hi); // [ D C | B A ] + __m128i vshuf = __lsx_vshuf4i_w(vacc, 0xb1); // [ C D | A B ] _MM_SHUFFLE(2, 3, 0, 1) + __m128i vsums = __lsx_vadd_w(vacc, vshuf); // [ D+C C+D | B+A A+B ] + vshuf = __lsx_vshuf4i_w(vsums, 0x4e); // [ B+A A+B | D+C C+D ] _MM_SHUFFLE(1, 0, 3, 2) + vsums = __lsx_vadd_w(vsums, vshuf); + __lsx_vstelm_w(vsums, sum_buffer++, 0 , 0); + } + + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, + static_cast(ZeroPointOutput), 0, 0, 1, Channels); +} + +template +MLAS_FORCEINLINE +void +MlasQLinearGlobalAveragePoolNhwcSingleBatch( + const T8Bits* Input, + T8Bits* Output, + const T8Bits* LastOf8, + size_t ImageSize, + size_t Channels, + size_t Stride, + int32_t Bias, + float Scale, + T8Bits Output_zero_point, + int32_t* AccumulateBuffer, + const T8Bits* ZeroBuffer + ) +{ + + constexpr size_t PixelsPerIteration = 7; +#define LOAD_FULL_CHANNELS() \ + const __m128i vi0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i0, 0), 0 , 1); \ + i0 += 8; \ + const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i1, 0), 0 , 1); \ + i1 += 8; \ + const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i2, 0), 0 , 1); \ + 
i2 += 8; \ + const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i3, 0), 0 , 1); \ + i3 += 8; \ + const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i4, 0), 0 , 1); \ + i4 += 8; \ + const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i5, 0), 0 , 1); \ + i5 += 8; \ + const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i6, 0), 0 , 1); \ + i6 += 8 + +#define CALCULATE_ACCUMULATE_VECTORS() \ + __m128i vacc_lo = finish_one_pass ? __lsx_vld((__m128i*)acc, 0) : vbias; \ + __m128i vacc_hi = finish_one_pass ? __lsx_vld(((__m128i*)acc) + 1, 0) : vbias; \ + __m128i vxi0; \ + __m128i vxi1; \ + __m128i vxi2; \ + __m128i vxi3; \ + __m128i vxi4; \ + __m128i vxi5; \ + __m128i vxi6; \ + if constexpr (std::is_signed::value) { \ + vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); \ + vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); \ + vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); \ + vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); \ + vxi4 = __lsx_vsrai_h(__lsx_vilvl_b(vi4, vzero), 8); \ + vxi5 = __lsx_vsrai_h(__lsx_vilvl_b(vi5, vzero), 8); \ + vxi6 = __lsx_vsrai_h(__lsx_vilvl_b(vi6, vzero), 8); \ + } else { \ + vxi0 = __lsx_vilvl_b(vzero, vi0); \ + vxi1 = __lsx_vilvl_b(vzero, vi1); \ + vxi2 = __lsx_vilvl_b(vzero, vi2); \ + vxi3 = __lsx_vilvl_b(vzero, vi3); \ + vxi4 = __lsx_vilvl_b(vzero, vi4); \ + vxi5 = __lsx_vilvl_b(vzero, vi5); \ + vxi6 = __lsx_vilvl_b(vzero, vi6); \ + } \ + const __m128i vsum01 = __lsx_vadd_h(vxi0, vxi1); \ + const __m128i vsum23 = __lsx_vadd_h(vxi2, vxi3); \ + const __m128i vsum45 = __lsx_vadd_h(vxi4, vxi5); \ + const __m128i vsum016 = __lsx_vadd_h(vsum01, vxi6); \ + const __m128i vsum2345 = __lsx_vadd_h(vsum23, vsum45); \ + const __m128i vsum = __lsx_vadd_h(vsum016, vsum2345); \ + if constexpr (std::is_signed::value) { \ + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); \ + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); \ + } else { \ + vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); \ + vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); \ + } + + + T8Bits tail[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + bool finish_one_pass = false; + const __m128i vbias = __lsx_vreplgr2vr_w(Bias); + const __m128i vzero = __lsx_vldi(0); + size_t step_next_group = PixelsPerIteration * Stride - (Channels & ~size_t{7}); + + const T8Bits* i0 = Input; + const T8Bits* i1 = i0 + Stride; + const T8Bits* i2 = i1 + Stride; + const T8Bits* i3 = i2 + Stride; + const T8Bits* i4 = i0 + Stride * 4; + const T8Bits* i5 = i4 + Stride; + const T8Bits* i6 = i5 + Stride; + + for (; ImageSize > PixelsPerIteration; ImageSize -= PixelsPerIteration) { + + int32_t* acc = AccumulateBuffer; + size_t c = Channels; + for (; c >= 8; c -= 8) { + + LOAD_FULL_CHANNELS(); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + acc += 8; + } + if (c > 0) { + const __m128i vi0 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); + const __m128i vi1 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0 ,1); + const __m128i vi2 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0 ,1); + const __m128i vi3 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0 ,1); + const __m128i vi4 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i4 >= LastOf8 ? 
memcpy(tail, i4, c) : i4), 0), 0 ,1); + const __m128i vi5 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0 ,1); + const __m128i vi6 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i6 >= LastOf8 ? memcpy(tail, i6, c) : i6), 0), 0 ,1); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + } + finish_one_pass = true; + + i0 += step_next_group; + i1 += step_next_group; + i2 += step_next_group; + i3 += step_next_group; + i4 += step_next_group; + i5 += step_next_group; + i6 += step_next_group; + } + + if (ImageSize > 0) { + switch (ImageSize) { + case 1: + i1 = ZeroBuffer; + [[fallthrough]]; + case 2: + i2 = ZeroBuffer; + [[fallthrough]]; + case 3: + i3 = ZeroBuffer; + [[fallthrough]]; + case 4: + i4 = ZeroBuffer; + [[fallthrough]]; + case 5: + i5 = ZeroBuffer; + [[fallthrough]]; + case 6: + i6 = ZeroBuffer; + [[fallthrough]]; + default: + break; + } + + int32_t* acc = AccumulateBuffer; + size_t c = Channels; + for (; c >= 8; c -= 8) { + + LOAD_FULL_CHANNELS(); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + acc += 8; + } + + if (c > 0) { + const __m128i vi0 = + __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); + const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(1 < ImageSize && i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0, 1); + const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(2 < ImageSize && i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0, 1); + const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(3 < ImageSize && i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0, 1); + const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(4 < ImageSize && i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0, 1); + const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(5 < ImageSize && i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0, 1); + const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld( + (const __m128i*)(6 < ImageSize && i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0, 1); + + CALCULATE_ACCUMULATE_VECTORS(); + + __lsx_vst(vacc_lo, (__m128i*)acc, 0); + __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); + } + } + MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, + Output_zero_point, 0, 0, 1, Channels); +} + #else // Pure C++ Implementation @@ -771,7 +1081,7 @@ MlasQLinearGlobalAveragePoolNhwc( #endif -#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) template void diff --git a/third_party/mlas/lib/qlmul.cpp b/third_party/mlas/lib/qlmul.cpp index 4b8537f2b3..38818e1190 100644 --- a/third_party/mlas/lib/qlmul.cpp +++ b/third_party/mlas/lib/qlmul.cpp @@ -377,6 +377,170 @@ MlasQLinearMulKernel( MLAS_UNREFERENCED_PARAMETER(ValueBVector); } +#elif defined(MLAS_LSX_INTRINSICS) + +template +MLAS_FORCEINLINE +static +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ); + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + return __lsx_vilvl_b(ZeroVector, Int8Vector); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + return __lsx_vilvh_b(ZeroVector, Int8Vector); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + MLAS_UNREFERENCED_PARAMETER(ZeroVector); + return __lsx_vsrai_h(__lsx_vilvl_b(Int8Vector, Int8Vector), 8); +} + +template <> +MLAS_FORCEINLINE +__m128i +MlasExtendToS16( + __m128i Int8Vector, + __m128i ZeroVector + ) +{ + MLAS_UNREFERENCED_PARAMETER(ZeroVector); + return __lsx_vsrai_h(__lsx_vilvh_b(Int8Vector, Int8Vector), 8); +} + +template +MLAS_FORCEINLINE +static +__m128i +MlasExtendToS16Debias( + __m128i Int8Vector, + __m128i ZeroVector, + __m128i VectorBias + ) +{ + return __lsx_vsub_h(MlasExtendToS16(Int8Vector, ZeroVector), VectorBias); +} + +MLAS_FORCEINLINE +static +__m128i +MlasQLinearMulVectorS16( + __m128i va_s16x8, + __m128i vb_s16x8, + __m128 VectorScaleRatio, + __m128 VectorZeroPointC + ) +{ + __m128i tmp, tmp1; + + const auto ab_lo = __lsx_vmul_h(va_s16x8, vb_s16x8); + const auto ab_hi = __lsx_vmuh_h(va_s16x8, vb_s16x8); + auto r_lo = __lsx_vilvl_h(ab_hi, ab_lo); + auto r_hi = __lsx_vilvh_h(ab_hi, ab_lo); + r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_lo), VectorScaleRatio, VectorZeroPointC)); + r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_hi), VectorScaleRatio, VectorZeroPointC)); + + tmp = __lsx_vsat_w(r_lo, 15); + tmp1 = __lsx_vsat_w(r_hi, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +template +static +void +MlasQLinearMulKernel( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const auto VectorZeroPointA = __lsx_vreplgr2vr_h((int16_t)ZeroPointA); + const auto VectorZeroPointB = __lsx_vreplgr2vr_h((int16_t)ZeroPointB); + const auto VectorZeroPointC = MlasBroadcastFloat32x4((float)ZeroPointC); + const auto VectorScaleRatio = MlasBroadcastFloat32x4(ScaleA * ScaleB / ScaleC); + const auto ZeroVector = __lsx_vldi(0); + + uint8_t TailDataA[16] = { 0 }; + uint8_t TailDataB[16] = { 0 }; + __m128i vb_lo_s16x8, vb_hi_s16x8; + + if (IsScalarB) { + vb_lo_s16x8 = __lsx_vsub_h(__lsx_vreplgr2vr_h((int16_t)*InputB), VectorZeroPointB); + vb_hi_s16x8 = vb_lo_s16x8; + } + + while (N > 0) 
{ + if (N < 16) { + MlasCopyTailBytes(TailDataA, (const uint8_t*)InputA, N); + InputA = (const DataType*)TailDataA; + if (!IsScalarB) { + MlasCopyTailBytes(TailDataB, (const uint8_t*)InputB, N); + InputB = (const DataType*)TailDataB; + } + } + + const auto va_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputA, 0); + InputA += 16; + const auto va_lo_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); + const auto va_hi_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); + + if (!IsScalarB) { + const auto vb_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputB, 0); + InputB += 16; + vb_lo_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); + vb_hi_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); + } + + const auto vc_lo_s16x8 = MlasQLinearMulVectorS16(va_lo_s16x8, vb_lo_s16x8, VectorScaleRatio, VectorZeroPointC); + const auto vc_hi_s16x8 = MlasQLinearMulVectorS16(va_hi_s16x8, vb_hi_s16x8, VectorScaleRatio, VectorZeroPointC); + auto vc = MlasPackS16_128(vc_lo_s16x8, vc_hi_s16x8); + + if (N >= 16) { + __lsx_vst(vc, (__m128i*)OutputC, 0); + OutputC += 16; + N -= 16; + } else { + __lsx_vst(vc, (__m128i*)TailDataA, 0); + MlasCopyTailBytes((uint8_t*)OutputC, TailDataA, N); + N = 0; + } + } +} + + #else // Pure C++ implementation. diff --git a/third_party/mlas/lib/quantize.cpp b/third_party/mlas/lib/quantize.cpp index c6e8af38c0..ae638fafee 100644 --- a/third_party/mlas/lib/quantize.cpp +++ b/third_party/mlas/lib/quantize.cpp @@ -20,7 +20,10 @@ Module Name: #include "mlasi.h" -#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) +#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || \ + defined(MLAS_LSX_INTRINSICS) + +#include // // QuantizeLinear implementation using NEON or SSE2 intrinsics. @@ -48,6 +51,9 @@ MlasQuantizeLinearVector( // is a NaN. FloatVector = vmaxnmq_f32(FloatVector, MinimumValueVector); FloatVector = vminnmq_f32(FloatVector, MaximumValueVector); +#elif defined(MLAS_LSX_INTRINSICS) + FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); + FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); #else // N.B. MINPS and MAXPS returns the value from the second vector if the // value from the first vector is a NaN. @@ -63,6 +69,9 @@ MlasQuantizeLinearVector( #if defined(MLAS_NEON64_INTRINSICS) auto IntegerVector = vcvtnq_s32_f32(FloatVector); IntegerVector = vaddq_s32(IntegerVector, ZeroPointVector); +#elif defined(MLAS_LSX_INTRINSICS) + auto IntegerVector = __lsx_vftint_w_s(FloatVector); + IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); #else // N.B. Assumes MXCSR has been configured with the default rounding mode of // "round to nearest even". @@ -79,6 +88,20 @@ MlasQuantizeLinearPackBytes( MLAS_INT32X4 IntegerVector ); +template +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ); + +template +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ); + #if defined(MLAS_NEON64_INTRINSICS) template @@ -100,6 +123,219 @@ MlasQuantizeLinearPackBytes( return vreinterpretq_s32_u8(ByteVector); } +template<> +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + // + // Swizzle the least significant u16 from each int32_t element to the + // bottom eight bytes of the vector register. 
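+    //
+    // For example, int32 lanes [a, b, c, d] become halfwords [a, b, c, d, a, b, c, d],
+    // of which only the low eight bytes are later stored.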
+ // + + uint16x8_t WordVector = vreinterpretq_u16_s32(IntegerVector); + WordVector = vuzp1q_u16(WordVector, WordVector); + return vreinterpretq_s32_u16(WordVector); +} + +template<> +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + // + // Swizzle the least significant u16 from each int32_t element to the + // bottom eight bytes of the vector register. + // + + int16x8_t WordVector = vreinterpretq_s16_s32(IntegerVector); + WordVector = vuzp1q_s16(WordVector, WordVector); + return vreinterpretq_s32_s16(WordVector); +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). + + if constexpr (std::is_same_v || std::is_same_v) { + vst1q_lane_s32(reinterpret_cast(Output), IntegerVector, 0); + } else { + static_assert(std::is_same_v || std::is_same_v); + vst1q_lane_s64(reinterpret_cast(Output), vreinterpretq_s64_s32(IntegerVector), 0); + } +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + uint8_t* Output + ) +{ + // Copies the lower 8-bit element of the vector into memory (Output). + vst1q_lane_u8(Output, vreinterpretq_u8_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + int8_t* Output + ) +{ + // Copies the lower 8-bit element of the vector into memory (Output). + vst1q_lane_s8(Output, vreinterpretq_s8_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + uint16_t* Output + ) +{ + // Copies the lower 16-bit element of the vector into memory (Output). + vst1q_lane_u16(Output, vreinterpretq_u16_s32(IntegerVector), 0); +} + +template <> +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + int16_t* Output + ) +{ + // Copies the lower 16-bit element of the vector into memory (Output). + vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0); +} + +#elif defined(MLAS_LSX_INTRINSICS) +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 integervector + ) +{ + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2; + + tmp = __lsx_vmax_h(integervector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + integervector = __lsx_vpickev_b(tmp2, tmp2); + + + tmp = __lsx_vmax_h(integervector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + integervector = __lsx_vpickev_b(tmp2, tmp2); + return integervector; +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 integervector + ) +{ + + __m128i tmp, tmp1; + + tmp = __lsx_vsat_h(integervector, 7); + tmp1 = __lsx_vsat_h(integervector, 7); + integervector = __lsx_vpickev_b(tmp1, tmp); + + tmp = __lsx_vsat_h(integervector, 7); + tmp1 = __lsx_vsat_h(integervector, 7); + integervector = __lsx_vpickev_b(tmp1, tmp); + return integervector; +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). 
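+    // For 8-bit outputs the four packed values live in the low 32 bits, so a
+    // single int32 lane store suffices; for 16-bit outputs they occupy the low
+    // 64 bits, hence the int64 lane store.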
+ + if constexpr (std::is_same_v || std::is_same_v) { + __lsx_vstelm_w(IntegerVector, reinterpret_cast(Output), 0, 0); + } else { + static_assert(std::is_same_v || std::is_same_v); + + __lsx_vstelm_d(IntegerVector, reinterpret_cast(Output), 0, 0); + } +} + + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + + // Copies the lower element of the vector into memory (Output). + // Expects that the 32-bit element in lane 0 is already within the valid numerical + // range of the OutputType. + *Output = static_cast(__lsx_vpickve2gr_w(IntegerVector, 0)); +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2; + + tmp = __lsx_vmax_w(IntegerVector, zero); + tmp2 = __lsx_vsat_wu(tmp, 15); + + IntegerVector = __lsx_vpickev_h(tmp2, tmp2); + return IntegerVector; +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + __m128i tmp, tmp1; + + tmp = __lsx_vsat_w(IntegerVector, 15); + tmp1 = __lsx_vsat_w(IntegerVector, 15); + IntegerVector = __lsx_vpickev_h(tmp1, tmp); + return IntegerVector; +} #else template<> @@ -128,6 +364,86 @@ MlasQuantizeLinearPackBytes( return IntegerVector; } +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ +#if defined(MLAS_SSE41_INTRINSICS) + IntegerVector = _mm_packus_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. +#else + // Cannot use _mm_packus_epi32 because that was not available until SSE4.1. + // Instead, emulate by sign-extending the first 16-bits of each packed 32-bit element. + // Afterwards, can use _mm_packs_epi32, which is available on SSE2. + // See: https://stackoverflow.com/a/11028244 + + IntegerVector = _mm_slli_epi32(IntegerVector, 16); + IntegerVector = _mm_srai_epi32(IntegerVector, 16); // Sign-extend: undo left shift with right arithmetic shift + IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. +#endif // defined(MLAS_SSE41_INTRINSICS) + + return IntegerVector; +} + +template<> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasQuantizeLinearPackBytes( + MLAS_INT32X4 IntegerVector + ) +{ + IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. + + return IntegerVector; +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStore4PackedValues( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + // Copies the lower 4 packed elements of the vector into memory (Output). + + if constexpr (std::is_same_v || std::is_same_v) { + *(reinterpret_cast(Output)) = _mm_cvtsi128_si32(IntegerVector); + } else { + static_assert(std::is_same_v || std::is_same_v); + +#if defined(MLAS_TARGET_IX86) + // x86 does not support _mm_cvtsi128_si64, so use _mm_maskmoveu_si128 instead. 
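+        // _mm_maskmoveu_si128 stores only the bytes whose mask byte has its high
+        // bit set, so the mask below selects exactly the low eight bytes.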
+ constexpr uint32_t bytes_high_bit = 0x80808080; + const __m128i first_8_bytes_mask = _mm_set_epi32(0, 0, bytes_high_bit, bytes_high_bit); + _mm_maskmoveu_si128(IntegerVector, first_8_bytes_mask, reinterpret_cast(Output)); +#else + *(reinterpret_cast(Output)) = _mm_cvtsi128_si64(IntegerVector); +#endif // defined(MLAS_TARGET_IX86) + } +} + +template +MLAS_FORCEINLINE +void +MlasQuantizeLinearStoreSingleValue( + MLAS_INT32X4 IntegerVector, + OutputType* Output + ) +{ + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + + // Copies the lower element of the vector into memory (Output). + // Expects that the 32-bit element in lane 0 is already within the valid numerical + // range of the OutputType. + *Output = static_cast(_mm_cvtsi128_si32(IntegerVector)); +} + #endif template @@ -180,12 +496,7 @@ Return Value: MinimumValueVector, MaximumValueVector, ZeroPointVector); IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - -#if defined(MLAS_NEON64_INTRINSICS) - vst1q_lane_s32((int32_t*)Output, IntegerVector, 0); -#else - *((int32_t*)Output) = _mm_cvtsi128_si32(IntegerVector); -#endif + MlasQuantizeLinearStore4PackedValues(IntegerVector, Output); Input += 4; Output += 4; @@ -196,20 +507,99 @@ Return Value: #if defined(MLAS_NEON64_INTRINSICS) auto FloatVector = vld1q_dup_f32(Input + n); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); #else auto FloatVector = _mm_load_ss(Input + n); #endif auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, MinimumValueVector, MaximumValueVector, ZeroPointVector); + MlasQuantizeLinearStoreSingleValue(IntegerVector, &Output[n]); + } +} + +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + auto ScaleVector = MlasBroadcastFloat32x4(Scale); + auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); + auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast(MaximumValue - ZeroPoint)); + auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); + + // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. 
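+    // Two 4-bit values share one output byte; with the packing helper used
+    // below the even-indexed element is expected to land in the low nibble and
+    // the odd-indexed element in the high nibble, e.g. values {1, 2} pack to
+    // the byte 0x21.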
+ UnpackedType TmpOutput[4] = {}; + + while (N >= 4) { + + auto FloatVector = MlasLoadFloat32x4(Input); + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); + MlasQuantizeLinearStore4PackedValues(IntegerVector, &TmpOutput[0]); + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); + + Input += 4; + N -= 4; + } + + for (size_t n = 0; n < N; n++) { + #if defined(MLAS_NEON64_INTRINSICS) - vst1q_lane_u8((uint8_t*)Output + n, vreinterpretq_u8_s32(IntegerVector), 0); + auto FloatVector = vld1q_dup_f32(Input + n); +#elif defined(MLAS_LSX_INTRINSICS) + MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); #else - *((uint8_t*)Output + n) = (uint8_t)_mm_cvtsi128_si32(IntegerVector); + auto FloatVector = _mm_load_ss(Input + n); #endif + auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, + MinimumValueVector, MaximumValueVector, ZeroPointVector); + + MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); + MlasSetInt4Element(Output, n, TmpOutput[0]); } } +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + void MLASCALL MlasQuantizeLinearS8Kernel( @@ -236,6 +626,68 @@ MlasQuantizeLinearU8Kernel( MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); } +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint +) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint +) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearS4Kernel( +#else + MlasQuantizeLinearS4Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU4Kernel( +#else + MlasQuantizeLinearU4Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + template<> void MLASCALL @@ -260,52 +712,143 @@ void MLASCALL MlasQuantizeLinear( const float* Input, - uint8_t* Output, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU8Kernel( +#else + MlasQuantizeLinearU8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearU16Kernel( +#else + MlasQuantizeLinearU16Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* 
Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().QuantizeLinearS16Kernel( +#else + MlasQuantizeLinearS16Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +#else + +#if defined(MLAS_TARGET_POWER) + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + GetMlasPlatform().QuantizeLinearS16Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + uint16_t ZeroPoint ) { -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearU8Kernel( -#else - MlasQuantizeLinearU8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); + GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint); } -#else - -#if defined(MLAS_TARGET_POWER) - -template<> void MLASCALL -MlasQuantizeLinear( +MlasQuantizeLinearS4( const float* Input, - int8_t* Output, + uint8_t* Output, size_t N, float Scale, int8_t ZeroPoint ) { - GetMlasPlatform().QuantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); + GetMlasPlatform().QuantizeLinearS4Kernel(Input, Output, N, Scale, ZeroPoint); } -template<> void MLASCALL -MlasQuantizeLinear( +MlasQuantizeLinearU4( const float* Input, uint8_t* Output, size_t N, float Scale, - uint8_t ZeroPoint + int8_t ZeroPoint ) { - GetMlasPlatform().QuantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); + GetMlasPlatform().QuantizeLinearU4Kernel(Input, Output, N, Scale, ZeroPoint); } - #endif // @@ -381,6 +924,81 @@ MlasQuantizeLinear( float Scale, uint8_t ZeroPoint ); + +template +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ); + +template +void +MLASCALL +MlasQuantizeLinear( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ); + +template +void +MLASCALL +MlasQuantizeLinearInt4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + +// QuantizeLinear INT4 implementation using the C++ runtime. +void +MLASCALL +MlasQuantizeLinearS4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); +} + +// QuantizeLinear UINT4 implementation using the C++ runtime. 
+void +MLASCALL +MlasQuantizeLinearU4( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); +} #endif #endif @@ -1063,6 +1681,286 @@ MlasRequantizeOutput( } } +#elif defined(MLAS_LSX_INTRINSICS) + +template +void +MlasRequantizeOutput( + const int32_t* Input, + size_t InputLeadingDimension, + OutputType* Output, + size_t OutputLeadingDimension, + const int32_t* Bias, + const float* Scale, + bool PerColumnScale, + OutputType ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN + ) +{ + //TO BE CHECK + float min_f = float(std::numeric_limits::lowest() - ZeroPoint); + float max_f = float(std::numeric_limits::max() - ZeroPoint); + const __m128 PerMatrixScaleVector = PerColumnScale ? MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0)); + const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&min_f))); + const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&max_f))); + const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint); + + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + // + // Step through each row of the output matrix. + // + + while (CountM-- > 0) { + + const int32_t* bias = Bias; + const float* scale = PerColumnScale ? Scale : nullptr; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; + + // + // Process 16 columns of the matrices at a time. + // + + while (n >= 16) { + + // + // Load the input data and optionally add the per-column bias. + // + + __m128i IntegerVector0 = __lsx_vld((const __m128i*)&RowInput[0], 0); + __m128i IntegerVector1 = __lsx_vld((const __m128i*)&RowInput[4], 0); + __m128i IntegerVector2 = __lsx_vld((const __m128i*)&RowInput[8], 0); + __m128i IntegerVector3 = __lsx_vld((const __m128i*)&RowInput[12], 0); + RowInput += 16; + + if (bias != nullptr) { + IntegerVector0 = __lsx_vadd_w(IntegerVector0, __lsx_vld((const __m128i *)&bias[0], 0)); + IntegerVector1 = __lsx_vadd_w(IntegerVector1, __lsx_vld((const __m128i *)&bias[4], 0)); + IntegerVector2 = __lsx_vadd_w(IntegerVector2, __lsx_vld((const __m128i *)&bias[8], 0)); + IntegerVector3 = __lsx_vadd_w(IntegerVector3, __lsx_vld((const __m128i *)&bias[12], 0)); + bias += 16; + } + + // + // Convert to integer values to float and apply the per-tensor or + // per-column scaling. 
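+            // Conceptually each int32 accumulator is requantized as
+            //   out = clamp(round(acc * scale), MIN - zp, MAX - zp) + zp
+            // where MIN/MAX are the numeric limits of the output type. The
+            // clamp is applied in the float domain before rounding, so the
+            // saturating narrows further below cannot change the value.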
+ // + + __m128 FloatVector0 = __lsx_vffint_s_w(IntegerVector0); + __m128 FloatVector1 = __lsx_vffint_s_w(IntegerVector1); + __m128 FloatVector2 = __lsx_vffint_s_w(IntegerVector2); + __m128 FloatVector3 = __lsx_vffint_s_w(IntegerVector3); + + if (scale != nullptr) { + + FloatVector0 = __lsx_vfmul_s(FloatVector0, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[0], 0))); + FloatVector1 = __lsx_vfmul_s(FloatVector1, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[4], 0))); + FloatVector2 = __lsx_vfmul_s(FloatVector2, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[8], 0))); + FloatVector3 = __lsx_vfmul_s(FloatVector3, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[12], 0))); + scale += 16; + + } else { + + FloatVector0 = __lsx_vfmul_s(FloatVector0, PerMatrixScaleVector); + FloatVector1 = __lsx_vfmul_s(FloatVector1, PerMatrixScaleVector); + FloatVector2 = __lsx_vfmul_s(FloatVector2, PerMatrixScaleVector); + FloatVector3 = __lsx_vfmul_s(FloatVector3, PerMatrixScaleVector); + } + FloatVector0 = __lsx_vfmax_s(FloatVector0, MinimumValueVector); + FloatVector1 = __lsx_vfmax_s(FloatVector1, MinimumValueVector); + FloatVector2 = __lsx_vfmax_s(FloatVector2, MinimumValueVector); + FloatVector3 = __lsx_vfmax_s(FloatVector3, MinimumValueVector); + + FloatVector0 = __lsx_vfmin_s(FloatVector0, MaximumValueVector); + FloatVector1 = __lsx_vfmin_s(FloatVector1, MaximumValueVector); + FloatVector2 = __lsx_vfmin_s(FloatVector2, MaximumValueVector); + FloatVector3 = __lsx_vfmin_s(FloatVector3, MaximumValueVector); + + IntegerVector0 = __lsx_vftint_w_s(FloatVector0); + IntegerVector1 = __lsx_vftint_w_s(FloatVector1); + IntegerVector2 = __lsx_vftint_w_s(FloatVector2); + IntegerVector3 = __lsx_vftint_w_s(FloatVector3); + + IntegerVector0 = __lsx_vadd_w(IntegerVector0, ZeroPointVector); + IntegerVector1 = __lsx_vadd_w(IntegerVector1, ZeroPointVector); + IntegerVector2 = __lsx_vadd_w(IntegerVector2, ZeroPointVector); + IntegerVector3 = __lsx_vadd_w(IntegerVector3, ZeroPointVector); + + __m128i WordVector0; + __m128i WordVector1; + __m128i ByteVector; + + if (std::is_signed::value) { + + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(IntegerVector0, 15); + tmp1 = __lsx_vsat_w(IntegerVector1, 15); + WordVector0 = __lsx_vpickev_h(tmp1, tmp); + + tmp = __lsx_vsat_w(IntegerVector2, 15); + tmp1 = __lsx_vsat_w(IntegerVector3, 15); + WordVector1 = __lsx_vpickev_h(tmp1, tmp); + + tmp = __lsx_vsat_h(WordVector0, 7); + tmp1 = __lsx_vsat_h(WordVector1, 7); + ByteVector = __lsx_vpickev_b(tmp1, tmp); + + + } else { + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(IntegerVector0, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(IntegerVector1, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + WordVector0 = __lsx_vpickev_b(tmp3, tmp2); + + tmp = __lsx_vmax_h(IntegerVector2, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(IntegerVector3, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + WordVector1 = __lsx_vpickev_b(tmp3, tmp2); + + tmp = __lsx_vmax_h(WordVector0, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(WordVector1, zero); + tmp3 = __lsx_vsat_hu(tmp, 7); + ByteVector = __lsx_vpickev_b(tmp3, tmp2); + + } + + __lsx_vst(ByteVector, (__m128i*)RowOutput, 0); + RowOutput += 16; + + n -= 16; + } + + // + // Process the remaining columns of the matrices. + // + + while (n > 0) { + + // + // Load the input data and optionally add the per-column bias. 
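+        // The tail handles four columns at a time and then single columns; a
+        // lone scalar is broadcast into a vector lane so the same vector
+        // requantization path is reused for every case.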
+ // + + __m128i IntegerVector; + + if (n >= 4) { + + IntegerVector = __lsx_vld((const __m128i*)&RowInput[0], 0); + RowInput += 4; + + if (bias != nullptr) { + IntegerVector = __lsx_vadd_w(IntegerVector, __lsx_vld((const __m128i*)&bias[0], 0)); + bias += 4; + } + + } else { + + int32_t IntegerValue = *RowInput++; + + if (bias != nullptr) { + IntegerValue += *bias++; + } + IntegerVector = __lsx_vldrepl_w(&IntegerValue, 0); + } + + // + // Convert to integer values to float and apply the per-tensor or + // per-column scaling. + // + __m128 FloatVector = __lsx_vffint_s_w(IntegerVector); + __m128 ScaleVector; + + if (scale != nullptr) { + + if (n >= 4) { + ScaleVector = MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)scale, 0)); + scale += 4; + } else { + ScaleVector = (__m128)__lsx_vldrepl_w(scale, 0); + scale += 1; + } + + } else { + ScaleVector = PerMatrixScaleVector; + } + FloatVector = __lsx_vfmul_s(FloatVector, ScaleVector); + + FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); + FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); + + IntegerVector = __lsx_vftint_w_s(FloatVector); + IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); + + if (std::is_signed::value) { + + __m128i tmp; + tmp = __lsx_vsat_w(IntegerVector, 15); + IntegerVector = __lsx_vpickev_h(tmp, tmp); + + tmp = __lsx_vsat_h(IntegerVector, 7); + IntegerVector = __lsx_vpickev_b(tmp, tmp); + + } else { + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2; + + tmp = __lsx_vmax_h(IntegerVector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + IntegerVector = __lsx_vpickev_b(tmp2, tmp2); + + tmp = __lsx_vmax_h(IntegerVector, zero); + tmp2 = __lsx_vsat_hu(tmp, 7); + IntegerVector = __lsx_vpickev_b(tmp2, tmp2); + + } + + uint32_t OutputValue = uint32_t(__lsx_vpickve2gr_w(IntegerVector, 0)); + + if (n >= 4) { + + *reinterpret_cast(RowOutput) = OutputValue; + RowOutput += 4; + + n -= 4; + + } else { + + *RowOutput = uint8_t(OutputValue); + RowOutput += 1; + + n -= 1; + } + } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; + } +} + #else template diff --git a/third_party/mlas/lib/reorder.cpp b/third_party/mlas/lib/reorder.cpp index 99c1dbac3b..b329ea2ffb 100644 --- a/third_party/mlas/lib/reorder.cpp +++ b/third_party/mlas/lib/reorder.cpp @@ -180,6 +180,31 @@ Return Value: v[2] = _mm_movelh_ps(t[2], t[3]); v[3] = _mm_movehl_ps(t[3], t[2]); + MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); + MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); + MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); + MlasStoreFloat32x4(&D[ScatterStride * 3], v[3]); +#elif defined(MLAS_LSX_INTRINSICS) + + MLAS_FLOAT32X4 v[4]; + MLAS_FLOAT32X4 t[4]; + + v[0] = MlasLoadFloat32x4(&S[GatherStride * 0]); + v[1] = MlasLoadFloat32x4(&S[GatherStride * 1]); + v[2] = MlasLoadFloat32x4(&S[GatherStride * 2]); + v[3] = MlasLoadFloat32x4(&S[GatherStride * 3]); + + t[0] = (__m128)__lsx_vilvl_w((__m128i)v[1], (__m128i)v[0]); + t[2] = (__m128)__lsx_vilvh_w((__m128i)v[1], (__m128i)v[0]); + t[1] = (__m128)__lsx_vilvl_w((__m128i)v[3], (__m128i)v[2]); + t[3] = (__m128)__lsx_vilvh_w((__m128i)v[3], (__m128i)v[2]); + + + v[0] = (__m128)__lsx_vpickev_d((__m128i) t[1],(__m128i) t[0]); + v[1] = (__m128)__lsx_vpickod_d((__m128i) t[1],(__m128i) t[0]); + v[2] = (__m128)__lsx_vpickev_d((__m128i) t[3],(__m128i) t[2]); + v[3] = (__m128)__lsx_vpickod_d((__m128i) t[3],(__m128i) t[2]); + MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); 
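The LSX sequence in the hunk above mirrors the SSE path next to it: the vilvl_w/vilvh_w word interleaves followed by the vpickev_d/vpickod_d double-word picks amount to a 4x4 transpose of the gathered rows before they are scattered out. A minimal scalar sketch of that data movement (illustration only; the library itself relies on the vector intrinsics):

#include <array>
#include <cstdio>

using Block4x4 = std::array<std::array<float, 4>, 4>;

// Scalar reference for the 4x4 transpose performed by the interleave/pick
// intrinsic sequence: row r, column c of the source becomes row c, column r
// of the destination.
static Block4x4 Transpose4x4(const Block4x4& s)
{
    Block4x4 d{};
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            d[c][r] = s[r][c];
        }
    }
    return d;
}

int main()
{
    const Block4x4 s{{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}};
    const Block4x4 d = Transpose4x4(s);
    for (const auto& row : d) {
        std::printf("%g %g %g %g\n", row[0], row[1], row[2], row[3]);
    }
    return 0;
}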
@@ -456,7 +481,6 @@ Return Value: &TaskStart, &TasksRemaining); size_t TaskEnd = TaskStart + TasksRemaining; - // // Rebase the pointers to the source and destination buffers for this thread. // @@ -567,18 +591,17 @@ Return Value: WorkBlock.S = S; WorkBlock.D = D; - WorkBlock.OutputChannels = size_t(OutputShape[1]); WorkBlock.OutputSize = size_t(OutputShape[2]) * size_t(OutputShape[3]); const size_t BlockSize = MlasNchwcGetBlockSize(); const size_t TasksPerBatch = size_t(ceil(((float)WorkBlock.OutputChannels) / BlockSize)); const size_t BatchCount = size_t(OutputShape[0]); - const size_t TasksCount = BatchCount * TasksPerBatch; + const size_t TasksCount = BatchCount * TasksPerBatch; WorkBlock.TasksCount = TasksCount; // - // Schedule the operation across a set of worker threads if the output + // Schedule the operation across a set of worker threads if the output // tensor is sufficienly large. Limit the number of threads to at least // the number of available tasks. // @@ -590,7 +613,7 @@ Return Value: if (size_t(TargetThreadCount) > TasksCount) { TargetThreadCount = ptrdiff_t(TasksCount); } - } + } WorkBlock.TargetThreadCount = TargetThreadCount; MlasExecuteThreaded(MlasReorderOutputNchwThreaded, &WorkBlock, TargetThreadCount, ThreadPool); diff --git a/third_party/mlas/lib/sbgemm.h b/third_party/mlas/lib/sbgemm.h new file mode 100644 index 0000000000..de7fd72fad --- /dev/null +++ b/third_party/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. 
+ + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? 
pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. 
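+    // (For example, with N = 100, two column partitions and assuming
+    // MLAS_SGEMM_STRIDEN_THREAD_ALIGN == 16, BlockedN is 7 blocks; the second
+    // partition receives blocks [4, 7), i.e. columns [64, 100) once
+    // RangeCountN is clipped back to the true N.)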
+ // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. 
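+    // The heuristic below mirrors the SGEMM path: the multiply-add count
+    // M * N * K is compared against a per-thread complexity threshold, and
+    // small problems keep the single-threaded path.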
+ // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/third_party/mlas/lib/sbgemm_kernel_neon.cpp b/third_party/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000..a6a73996c5 --- /dev/null +++ b/third_party/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. 
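+    // For each group of four source rows, adjacent column pairs are
+    // interleaved so that a 4-row-by-2-column block of bf16 values lands
+    // contiguously in the packed buffer: D holds B[k..k+3][n] followed by
+    // B[k..k+3][n+1], then the next column pair.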
+ // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. 
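+    // The tail is processed in 4-, 2- and 1-column steps; rows beyond CountK
+    // stay in zero-initialized vectors so the packed 4x2 layout and its
+    // padding are preserved.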
+ // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + 
vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/third_party/mlas/lib/sgemm.cpp b/third_party/mlas/lib/sgemm.cpp index 1ce64712d6..4d7a1ceb4e 100644 --- a/third_party/mlas/lib/sgemm.cpp +++ b/third_party/mlas/lib/sgemm.cpp @@ -472,7 +472,7 @@ Return Value: const float* b = B; size_t x = CountX; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* SgemmTransposePackB16x4Routine = GetMlasPlatform().TransposePackB16x4Routine; @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/third_party/mlas/lib/snchwc.cpp b/third_party/mlas/lib/snchwc.cpp index 74d65f934a..f9cf160578 100644 --- a/third_party/mlas/lib/snchwc.cpp +++ b/third_party/mlas/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if 
defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/third_party/mlas/lib/sqnbitgemm.cpp b/third_party/mlas/lib/sqnbitgemm.cpp new file mode 100644 index 0000000000..4b852be951 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm.cpp @@ -0,0 +1,669 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + as well as some SQNBitGemm-related query functions. +--*/ + +#include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +#include + +namespace +{ + +enum SQNBitGemmVariant { + SQNBitGemmVariantInvalid = -1, + + // Valid variants + + SQNBitGemmVariant_BitWidth4_CompFp32 = 0, + SQNBitGemmVariant_BitWidth4_CompInt8, + + // End of valid variants + + // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Its value is used as an array size. 
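+    // (OperationMap later in this file is a std::array indexed by
+    // SQNBitGemmVariant, so this count must stay in sync with the list of
+    // valid variants above.)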
+ SQNBitGemmVariantCount, +}; + +SQNBitGemmVariant +GetSQNBitGemmVariant( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + if (BlkBitWidth == 4 && + (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (ComputeType == CompFp32 || + ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + return SQNBitGemmVariant_BitWidth4_CompFp32; + } else if (ComputeType == CompInt8) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } + } + + return SQNBitGemmVariantInvalid; +} + +} // namespace + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return false; + } + + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: { + return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && + Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case SQNBitGemmVariant_BitWidth4_CompInt8: { + return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && + Dispatch->QuantizeARow_CompInt8 != nullptr; + } + default: { + return false; + } + } +} + +namespace +{ + +size_t +SQNBitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + } + + return 0; +} + +size_t +SQNBitGemmPerGemmWorkspaceAlignment( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 1; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { + return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + } + + return 1; +} + +size_t +SQNBitGemmPerGemmWorkspaceStride( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + return MlasDivRoundup(Size, Alignment) * Alignment; +} + +} // namespace + +size_t MLASCALL +MlasSQNBitGemmBatchWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + if (PerGemmWorkspaceStride == 0) { + return 0; + } + + const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + + const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; + + return WorkspaceSize + Alignment - 1; +} + +size_t MLASCALL +MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; + } + + if (BlkBitWidth == 4 && 
Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->SQ4BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); + } + + return 0; +} + +void MLASCALL +MlasSQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool +) +{ + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBData), + ThreadPool + ); + return; + } +} + +namespace +{ + +MLAS_FORCEINLINE void +AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +typedef void(SQNBitGemmFn)( + size_t BlkLen, + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* PerGemmWorkspace, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +void +SQ4BitGemm_CompFp32( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + BlkLen, + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +void +SQ4BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ +#ifdef MLAS_TARGET_AMD64_IX86 + if (RangeCountM != 1) { + // perf experiment shows fp32 is faster than int8 in M > 1 cases. + // route to fp32 compute before int8 compute is improved. + SQ4BitGemm_CompFp32( + BlkLen, + K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN + ); + return; + } +#endif + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? 
nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. + // TODO Replace it with an optimized implementation. + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } + } +} + +typedef void(InitializeWorkspaceFn)( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +); + +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); +} + +struct Operations { + InitializeWorkspaceFn* InitializeWorkspace = nullptr; + SQNBitGemmFn* SQNBitGemm = nullptr; +}; + +constexpr auto OperationMap = []() { + std::array ops; + + ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; + + ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = 
InitializeWorkspace_CompInt8; + ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; + + return ops; +}(); + +} // namespace + +void MLASCALL +MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +) +{ + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + assert(Variant != SQNBitGemmVariantInvalid); + + // + // Ensure `Workspace` has correct alignment. + // + if (Workspace != nullptr) { + const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); + Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) + ); + } + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + InitializeWorkspaceOperation != nullptr) { + InitializeWorkspaceOperation( + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + ); + } + + const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. 
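+    // Illustrative sketch of the partitioning that follows (constants such as
+    // MLAS_QGEMM_THREAD_COMPLEXITY and MLAS_QGEMM_STRIDEN_THREAD_ALIGN are defined
+    // elsewhere in MLAS; the numbers below are only an example):
+    //
+    //   M = 256, N = 4096, ThreadsPerGemm = 8, StrideM = 128
+    //   BlockedM = ceil(256 / 128)    = 2 row blocks
+    //   max_nc   = ceil(4096 * 2 / 8) = 1024 (already a multiple of the N-stride alignment)
+    //   StrideN  = 1024
+    //   => ThreadCountM = 2, ThreadCountN = 4, i.e. 8 tiles per GEMM, one per thread.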
+ // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN + ); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = reinterpret_cast( + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride + ); + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} diff --git a/third_party/mlas/lib/sqnbitgemm.h b/third_party/mlas/lib/sqnbitgemm.h new file mode 100644 index 0000000000..effb59b250 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm.h @@ -0,0 +1,233 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing SQNBitGemm. + + SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float + matrix and B is a n-bit quantized integer matrix. B is block quantized, + meaning values of B are divided into blocks and each block has its own + scale and optional zero point. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) +{ + return BlkLen * BlkBitWidth / 8; +} + +template +constexpr MLAS_FORCEINLINE size_t +MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) +{ + if constexpr (BlkBitWidth <= 4) { + return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte + } else { + return BlkCount; + } +} + +// +// Kernel dispatch structure. +// + +struct MLAS_SQNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). 
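+     *
+     *  For the 4-bit AVX implementations added in this change this works out to
+     *      N * MlasDivRoundup(K, BlkLen) * (BlkLen / 2) bytes,
+     *  i.e. the same footprint as the unpacked quantized data, since packing only
+     *  reorders the nibbles within each block.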
*/ + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ + typedef void(SQ4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + + // + // Workspace size calculation function prototypes. + // + + /** + * @brief Gets the required size in bytes of the per-GEMM intermediate workspace. + * Returns a size of zero if no intermediate workspace is needed. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + + /** + * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. + * + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + + // + // CompFp32 kernel function prototypes. + // + + /** + * @brief Multiply float matrix A with quantized 4-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. 
+ * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + // + // CompInt8 kernel function prototypes. + // + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ + typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + + /** + * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeARow_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA + ); + + QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; +}; diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/third_party/mlas/lib/sqnbitgemm_kernel_avx2.cpp new file mode 100644 index 0000000000..be573381c3 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -0,0 +1,1116 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx2.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. 
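+
+    The CompFp32 path pairs SQ4BitGemmM1Kernel_CompFp32_avx2 (M == 1) with
+    Q4BitBlkDequantBForSgemm_CompFp32_avx2 (dequantize B for the float SGEMM
+    kernels when M > 1); the CompInt8 path pairs QuantizeARow_CompInt8_avx2
+    with SQ4BitGemmM1Kernel_CompInt8_avx2.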
+ +--*/ + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx_common_int8.h" + +MLAS_FORCEINLINE +__m256 +load_float_n_avx2(const float* data, int n) +{ + assert(n <= 8); + if (n <= 0) { + return _mm256_setzero_ps(); + } + static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + const __m256i load_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - n)); + return _mm256_maskload_ps(data, load_mask); +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_avx2( + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + + constexpr size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + // TODO: constexpr use temaplte parameter + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + const __m128i low_mask = _mm_set1_epi8(0xF); + for (size_t col = 0; col < CountN; col += NCols8) { + const int cols = std::min((int)NCols8, (int)CountN - (int)col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + float* dst_ptr = FpData + (tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16; + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + const std::byte* b_data_ptr = QuantBData + col * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + const float* scale_ptr = QuantBScale + col * BlockCountK + k; + const std::byte* zp_ptr = QuantBZeroPoint + col * zp_col_stride_in_bytes + k / 2; + bool is_lower = (k % 2) == 0; + + __m256i weight_16_epi16[NCols8]; + __m256 scale_8_ps[NCols8]; + UnrolledLoop([&](size_t col_) { + if ((int)col_ < cols) { + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + __m128i bvi = _mm_loadl_epi64((__m128i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); + const __m128i lower = _mm_and_si128(bvi, low_mask); + const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi, 4), low_mask), 8); + __m128i weight_16_epi8 = _mm_add_epi8(upper, lower); + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + weight_16_epi8 = _mm_sub_epi8(weight_16_epi8, _mm_set1_epi8(zp)); + } else { + const __m128i eight = _mm_set1_epi8(8); + weight_16_epi8 = _mm_sub_epi8(weight_16_epi8, eight); + } + weight_16_epi16[col_] = _mm256_cvtepi8_epi16(weight_16_epi8); + scale_8_ps[col_] = _mm256_set1_ps(*(scale_ptr + col_ * BlockCountK)); + } else { + weight_16_epi16[col_] = _mm256_setzero_si256(); + scale_8_ps[col_] = _mm256_setzero_ps(); + } + }); + for (int i_of_2 = 0; i_of_2 < 2; i_of_2++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if ((int)col_ < cols) { + if (i_of_2 == 0) { + __m256i weight_i_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_16_epi16[col_], 0)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_8_epi32), scale_8_ps[col_]); + } else { + __m256i weight_i_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_16_epi16[col_], 1)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = _mm256_setzero_ps(); + } + } + // transpose and store + __m256 a0 = _mm256_unpacklo_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a1 = _mm256_unpackhi_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a2 = _mm256_unpacklo_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a3 = _mm256_unpackhi_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a4 = _mm256_unpacklo_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a5 = _mm256_unpackhi_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a6 = _mm256_unpacklo_ps(weight_8_ps[6], weight_8_ps[7]); + __m256 a7 = _mm256_unpackhi_ps(weight_8_ps[6], weight_8_ps[7]); + + __m256 b0 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b1 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b2 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b3 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b4 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b5 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b6 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b7 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(3, 2, 3, 2)); + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_2 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x31); + _mm256_storeu_ps(dst_ptr + 
ij_offset_in_k + 7 * GemmFloatKernelWidth16, weight_transposed_8_ps); + } + } + } +} + +template +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + constexpr size_t SubblkLen32 = 32; // process SubblkLen32 rows of QuantB at a time + + const size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t subblk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubblkLen32); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + // TODO: constexpr use temaplte parameter + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] int count_half_4 = 0; + + const __m256i low_mask = _mm256_set1_epi8(0xF); + for (size_t col = 0; col < CountN; col += NCols8) { + // TODO: handle last tile with cols < NCols8 + const size_t cols = std::min(NCols8, CountN - col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + float* dst_ptr = FpData + (tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16; + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + const std::byte* b_data_ptr = QuantBData + col * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + const float* scale_ptr = QuantBScale + col * BlockCountK + k; + const std::byte* zp_ptr = QuantBZeroPoint + col * zp_col_stride_in_bytes + k / 2; + bool is_lower = (k % 2) == 0; + + for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) { + __m256i weight_32_epi8[NCols8]; + __m256 scale_8_ps[NCols8]; + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (subblk % 2); + } + UnrolledLoop([&](size_t col_) { + if (col_ < cols) { + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // at the end of subblk loop, increment b_data_ptr by 2 * subblk_data_size_in_bytes if subblk % 2 == 1 + // so that all v0-64 of the pack are dequantized. + const __m256i bvi = _mm256_loadu_si256((__m256i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); + weight_32_epi8[col_] = _mm256_and_si256(_mm256_srli_epi16(bvi, count_half_4), low_mask); + } else { + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + __m128i bvi = _mm_loadu_si128((__m128i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); + __m128i lower = _mm_and_si128(bvi, _mm256_castsi256_si128(low_mask)); + __m128i upper = _mm_and_si128(_mm_srli_epi16(bvi, 4), _mm256_castsi256_si128(low_mask)); + weight_32_epi8[col_] = _mm256_set_m128i(upper, lower); + } + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + weight_32_epi8[col_] = _mm256_sub_epi8(weight_32_epi8[col_], _mm256_set1_epi8(zp)); + } else { + const __m256i eight = _mm256_set1_epi8(8); + weight_32_epi8[col_] = _mm256_sub_epi8(weight_32_epi8[col_], eight); + } + + scale_8_ps[col_] = _mm256_set1_ps(*(scale_ptr + col_ * BlockCountK)); + } else { + weight_32_epi8[col_] = _mm256_setzero_si256(); + scale_8_ps[col_] = _mm256_setzero_ps(); + } + }); + for (int i_of_4 = 0; i_of_4 < 4; i_of_4++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if (col_ < cols) { + if (i_of_4 == 0) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 0)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 0)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 1) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 0)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 1)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 2) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 1)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 0)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 3) { + __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 1)); + __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 1)); + weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = _mm256_setzero_ps(); + } + } + // transpose and store + __m256 a0 = _mm256_unpacklo_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a1 = _mm256_unpackhi_ps(weight_8_ps[0], weight_8_ps[1]); + __m256 a2 = _mm256_unpacklo_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a3 = _mm256_unpackhi_ps(weight_8_ps[2], weight_8_ps[3]); + __m256 a4 = _mm256_unpacklo_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a5 = _mm256_unpackhi_ps(weight_8_ps[4], weight_8_ps[5]); + __m256 a6 = _mm256_unpacklo_ps(weight_8_ps[6], weight_8_ps[7]); + __m256 a7 = _mm256_unpackhi_ps(weight_8_ps[6], weight_8_ps[7]); + + __m256 b0 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b1 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b2 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b3 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b4 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b5 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 b6 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 b7 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(3, 2, 3, 2)); + + const size_t ij_offset_in_k = i_of_4 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x20); + _mm256_storeu_ps(dst_ptr + 
ij_offset_in_k + 2 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x20); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, weight_transposed_8_ps); + weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x31); + _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, weight_transposed_8_ps); + } + dst_ptr += SubblkLen32 * GemmFloatKernelWidth16; + if constexpr (IsBlkLen64Layout) { + b_data_ptr += (subblk % 2) * 2 * subblk_data_size_in_bytes; + } else { + b_data_ptr += subblk_data_size_in_bytes; + } + } // subblk + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemm_CompFp32_avx2( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +) +{ + if (BlkLen == 16) { + Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_avx2( + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else if (BlkLen == 32) { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } else { + constexpr bool HasZeroPoint = false; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen16_CompFp32_avx2( + 
size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen16 = 16; + constexpr size_t SubBlkStep8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen16); + static_assert(SubBlkStep8 == 8); // 16 * 4 / 8 + + __m256 acc[NCols]; + UnrolledLoop([&](size_t i) { + acc[i] = _mm256_setzero_ps(); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = *(s + StrideQuantBScale * i); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen16) { + int kklen = std::min((int)SubBlkLen16, (int)(ck - kk)); + + // Load A row vectors + int n_to_read = std::min(kklen, 8); + __m256 av_lo = load_float_n_avx2(ARowPtr + k + kk, n_to_read); + n_to_read = std::min(kklen - 8, 8); + __m256 av_hi = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); + + UnrolledLoop([&](size_t i) { + // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // Load B col vectors. 
get SubBlkLen(16) 4 bits quantized features from each column + __m128i bvi4 = _mm_loadl_epi64((__m128i const*)(b_blk_data_col_ptr[i])); + b_blk_data_col_ptr[i] += SubBlkStep8; + + // TODO: avoid _mm_set1_epi8 + //__m128i lower_mask_epi8 = _mm_cmpeq_epi16(bvi4, bvi4); // can use any __m128i + // lower_mask_epi8 = _mm_srli_epi16(lower_mask_epi8, 13); + // lower_mask_epi8 = _mm_packus_epi16(lower_mask_epi8, lower_mask_epi8); + __m128i lower_mask_epi8 = _mm_set1_epi8(0x0F); // Mask to isolate the lower 4 bits + + const __m128i lower = _mm_and_si128(bvi4, lower_mask_epi8); + const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4, 4), lower_mask_epi8), 8); + __m256i bv_epi16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper, lower)); // unpacked 16 weights of epi16 + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + __m256i zp = _mm256_set1_epi16(offset[i]); + bv_epi16 = _mm256_sub_epi16(bv_epi16, zp); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi16(8); + bv_epi16 = _mm256_sub_epi16(bv_epi16, eight); + } + + // Convert to 16 epi16 to 16 float32 + const __m128i bv_lo = _mm256_extractf128_si256(bv_epi16, 0); + const __m128i bv_hi = _mm256_extractf128_si256(bv_epi16, 1); + + __m256 bvf_lo = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(bv_lo)); + __m256 bvf_hi = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(bv_hi)); + + // multiply by scale + __m256 scale_ps = _mm256_set1_ps(scale_v[i]); + bvf_lo = _mm256_mul_ps(bvf_lo, scale_ps); + bvf_hi = _mm256_mul_ps(bvf_hi, scale_ps); + + // c[m,n] += a[m,k] * b[k,n] + acc[i] = _mm256_fmadd_ps(bvf_lo, av_lo, acc[i]); + acc[i] = _mm256_fmadd_ps(bvf_hi, av_hi, acc[i]); + }); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + } else { + UnrolledLoop([&](size_t i) { + __m128 vlow = _mm256_castps256_ps128(acc[i]); + __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit + + // Add the two 128-bit vectors together + __m128 vsum = _mm_add_ps(vlow, vhigh); + // Horizontally add the elements of the resulting 128-bit vector + vsum = _mm_hadd_ps(vsum, vsum); + vsum = _mm_hadd_ps(vsum, vsum); + + _mm_store_ss(&sum_ptr[i], vsum); + sum_ptr[i] += bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + }); + } +} + +// TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf +template +void +SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + ComputeDotProducts_BlkLen16_CompFp32_avx2( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than `NCols`? + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkLen16_CompFp32_avx2<1, HasZeroPoint>( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +// TODO: flow MlasQ4GemmKernelBlkLen32PlusAvx512f to improve perf +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen32 = 32; + constexpr size_t SubBlkStep16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen32); + static_assert(SubBlkStep16 == 16); // 32 * 4 / 8 + + __m256i lowMask = _mm256_set1_epi8(0x0F); + + __m256 acc[NCols]; + UnrolledLoop([&](size_t i) { + acc[i] = _mm256_setzero_ps(); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] int count_half_4 = 0; + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = *(s + StrideQuantBScale * i); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen32) { + int kklen = std::min((int)SubBlkLen32, (int)(ck - kk)); + + // Load 4 float8 from A + int n_to_read = std::min(kklen, 8); + __m256 av0_8_ps = load_float_n_avx2(ARowPtr + k + kk, n_to_read); + + n_to_read = std::min(kklen - 8, 8); + __m256 av1_8_ps = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); + + n_to_read = std::min(kklen - 16, 8); + __m256 av2_8_ps = load_float_n_avx2(ARowPtr + k + kk + 16, n_to_read); + + n_to_read = std::min(kklen - 24, 8); + __m256 av3_8_ps = load_float_n_avx2(ARowPtr + k + kk + 24, n_to_read); + + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); + } + UnrolledLoop([&](size_t i) { + // Load B col vectors. get SubBlkLen32 4b quantized weights from each column + __m256i bv_32_epi8; + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // increment b_data_ptr by 2 * SubBlkStep16 if kk % (2 * SubBlkLen32) == 1 + // so that all v0-63 of the pack are processed. + const __m256i bvi4 = _mm256_loadu_si256((__m256i const*)(b_blk_data_col_ptr[i])); + bv_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bvi4, count_half_4), lowMask); + b_blk_data_col_ptr[i] += count_half_4 / 2 * SubBlkStep16; + } else { + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + __m128i bvi4 = _mm_loadu_si128((const __m128i*)(b_blk_data_col_ptr[i])); + b_blk_data_col_ptr[i] += SubBlkStep16; + + bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + bv_32_epi8 = _mm256_and_si256(lowMask, bv_32_epi8); + } + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + __m256i zp = _mm256_set1_epi8(offset[i]); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, zp); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, eight); + } + + // Convert to 16 float32 + const __m256i bv0_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bv_32_epi8, 0)); + const __m256i bv1_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bv_32_epi8, 1)); + + __m256 bv0_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv0_16_epi16, 0))); + __m256 bv1_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv0_16_epi16, 1))); + __m256 bv2_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv1_16_epi16, 0))); + __m256 bv3_8_ps = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv1_16_epi16, 1))); + + // multiply by scale + __m256 scale_ps = _mm256_set1_ps(scale_v[i]); + bv0_8_ps = _mm256_mul_ps(bv0_8_ps, scale_ps); + bv1_8_ps = _mm256_mul_ps(bv1_8_ps, scale_ps); + bv2_8_ps = _mm256_mul_ps(bv2_8_ps, scale_ps); + bv3_8_ps = _mm256_mul_ps(bv3_8_ps, scale_ps); + + // c[m,n] += a[m,k] * b[k,n] + acc[i] = _mm256_fmadd_ps(bv0_8_ps, av0_8_ps, acc[i]); + acc[i] = _mm256_fmadd_ps(bv1_8_ps, av1_8_ps, acc[i]); + acc[i] = _mm256_fmadd_ps(bv2_8_ps, av2_8_ps, acc[i]); + acc[i] = _mm256_fmadd_ps(bv3_8_ps, av3_8_ps, acc[i]); + }); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + } else { + UnrolledLoop([&](size_t i) { + __m128 vlow = _mm256_castps256_ps128(acc[i]); + __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit + + // Add the two 128-bit vectors together + __m128 vsum = _mm_add_ps(vlow, vhigh); + // Horizontally add the elements of the resulting 128-bit vector + vsum = _mm_hadd_ps(vsum, vsum); + vsum = _mm_hadd_ps(vsum, vsum); + + _mm_store_ss(&sum_ptr[i], vsum); + sum_ptr[i] += bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + }); + } +} + +// TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf +template +void +SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + while (nblk >= 0) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + // move to next `NCols` columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than `NCols`? + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, true>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, false>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_avx2( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8(__m256 v0, __m256 v1) +{ + __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); + __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); + + __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); + __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); + + return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); +} + +void MLASCALL +QuantizeARow_CompInt8_avx2( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + // port from MlasQ80BlkQuantRow + assert(BlkLen % 16 == 0); + const __m256 signBit = _mm256_set1_ps(-0.0f); + int8_t* blob = reinterpret_cast(QuantA); + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t step = std::min(BlkLen, CountK - k); + + __m256 maxAbs = _mm256_setzero_ps(); + for (size_t kk = 0; kk < step; kk += 8) { + const int klen = std::min(8, (int)(step - kk)); + + __m256 v0 = load_float_n_avx2(A + k + kk, klen); + + // Compute max(abs(e)) for the block + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, v0)); + } + + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(maxAbs, 1), _mm256_castps256_ps128(maxAbs)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_shuffle_ps(max4, max4, 1)); + const float maxScalar = _mm_cvtss_f32(max4); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *reinterpret_cast(blob) = scale; + blob += sizeof(float); + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps(inverse_scale); + __m128i* dst = reinterpret_cast<__m128i*>(blob); + + for (size_t kk = 0; kk < step; kk += 16) { + const int klen = std::min(16, (int)(step - kk)); + + int n_to_read = std::min(klen, 8); + __m256 v0 = load_float_n_avx2(A + k + kk, n_to_read); + v0 = _mm256_mul_ps(v0, mul); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + + __m256 v1; + n_to_read = std::min(klen - 8, 8); + if (n_to_read <= 0) { + v1 = _mm256_setzero_ps(); + } else { + v1 = load_float_n_avx2(A + k + kk + 8, n_to_read); + v1 = _mm256_mul_ps(v1, mul); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + } + + __m128i i_8 = convert_2_ps_to_epi8(v0, v1); + _mm_storeu_si128(dst++, i_8); + } + if (step < BlkLen) { + memset(blob + step, 0, BlkLen - step); + } + blob += BlkLen; + } +} + +// +// Kernel dispatch structure definition. +// +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; + + return d; +}(); diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/third_party/mlas/lib/sqnbitgemm_kernel_avx512.cpp new file mode 100644 index 0000000000..0099b61d81 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -0,0 +1,246 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx512.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx512. + +--*/ + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx_common_int8.h" + +// +// CompFp32 kernel implementation. 
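+// The M == 1 fp32 kernel below routes by block length to the existing AVX-512F
+// Q4 GEMM kernels: BlkLen == 16 uses MlasQ4GemmKernelBlkLen16Avx512f and
+// BlkLen >= 32 uses MlasQ4GemmKernelBlkLen32PlusAvx512f, with separate
+// instantiations for the zero-point and no-zero-point cases.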
+// + +#include "sqnbitgemm_kernel_avx_common_fp32.h" + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_avx512( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else if (BlkLen == 32) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else /*if (BlkLen >= 64)*/ { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } +} + +// +// CompInt8 kernel implementation. +// + +void MLASCALL +MlasQ80BlkQuantRow_avx512( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + // port from MlasQ80BlkQuantRow + assert(BlkLen % 16 == 0); + const __m512 signBit = _mm512_set1_ps(-0.0f); + int8_t* blob = reinterpret_cast(QuantA); + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t step = std::min(BlkLen, CountK - k); + + __m512 maxAbs = _mm512_setzero_ps(); + for (size_t kk = 0; kk < step; kk += 16) { + const size_t klen = std::min(size_t(16), step - kk); + + uint32_t mask = 0xffff >> (16 - klen); + __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + // Compute max(abs(e)) for the block + maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0)); + } + + __m256 max8 = + _mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0)); + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + const float maxScalar = _mm_cvtss_f32(max4); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *reinterpret_cast(blob) = scale; + blob += sizeof(float); + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const __m512 mul = _mm512_set1_ps(inverse_scale); + __m128i* dst = reinterpret_cast<__m128i*>(blob); + + for (size_t kk = 0; kk < step; kk += 16) { + const size_t klen = std::min(size_t(16), step - kk); + + uint32_t mask = 0xffff >> (16 - klen); + __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + v0 = _mm512_mul_ps(v0, mul); + + // Round to nearest integer + v0 = _mm512_roundscale_ps(v0, _MM_ROUND_NEAREST); + + // Convert floats to integers + __m512i i0 = _mm512_cvtps_epi32(v0); + + // Convert int32 to int8 + __m128i i0_8 = _mm512_cvtepi32_epi8(i0); + _mm_storeu_si128(dst++, i0_8); + } + if (step < BlkLen) { + memset(blob + step, 0, BlkLen - step); + } + blob += BlkLen; + } +} + +void MLASCALL +QuantizeARow_CompInt8_avx512( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + MlasQ80BlkQuantRow_avx512(BlkLen, A, CountK, QuantA); +} + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; + + return d; +}(); diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/third_party/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp new file mode 100644 index 0000000000..27310d8253 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -0,0 +1,267 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx512.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx512vnni. 
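+
+    The CompFp32 path mirrors the AVX-512F implementation and A-row quantization
+    reuses MlasQ80BlkQuantRow_avx512; SQ4BitGemmM1Kernel_CompInt8_avx512vnni is
+    the kernel expected to take advantage of the VNNI instructions.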
+ +--*/ + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx_common_fp32.h" +#include "sqnbitgemm_kernel_avx_common_int8.h" + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen16Avx512f( + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else if (BlkLen == 32) { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } else /*if (BlkLen >= 64)*/ { + if (QuantBZeroPoint != nullptr) { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } else { + MlasQ4GemmKernelBlkLen32PlusAvx512f( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + 1, + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0, + 0 + ); + } + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8_avx512vnni( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } else { + constexpr bool HasZeroPoint = false; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +void MLASCALL +MlasQ80BlkQuantRow_avx512( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); + +// +// Kernel dispatch structure definition. 
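+//
+// Only SQ4BitGemmM1Kernel_CompInt8_avx512vnni below requires AVX512VNNI. The
+// packing and workspace helpers come from sqnbitgemm_kernel_avx_common.h, the
+// CompFp32 M1 kernel is built on the shared AVX-512F block kernels, CompFp32 B
+// dequantization reuses the AVX2 routine, and A-row quantization reuses
+// MlasQ80BlkQuantRow_avx512 declared above and defined with the AVX-512 kernels.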
+// +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni; + d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; + + return d; +}(); diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_avx_common.h b/third_party/mlas/lib/sqnbitgemm_kernel_avx_common.h new file mode 100644 index 0000000000..cfc0564cd0 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -0,0 +1,418 @@ +#pragma once +#include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +// +// Quantized B data packing function implementation. +// + +static size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +static void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + // + // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // => + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +// +// Workspace size calculation function implementation. +// + +static size_t +SQ4BitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + switch(ComputeType) { + case CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } +} + +static size_t +SQ4BitGemmPerGemmWorkspaceAlignment( + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(BlkLen); + + switch (ComputeType) { + case CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + +void +Q4BitBlkDequantBForSgemm_CompFp32_avx2( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +); + +void +SQ4BitGemmM1Kernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +// +// General helpers. +// + +namespace +{ + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +// this function is used to dot product 2 pairs of 32 epi8s. it is used with Int8 precision +// and blklen >= 64. In this case, 64 of 4b weights are filled with one load. 
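+// Both variants rely on the same sign trick: _mm256_dpbusd_epi32 (and
+// _mm256_maddubs_epi16) multiply an unsigned byte operand by a signed one, so the
+// helpers pass |b| via _mm256_sign_epi8(bv, bv) and move b's sign onto a via
+// _mm256_sign_epi8(av, bv). The per-byte products are then the desired signed
+// products: e.g. a = -3, b = -5 gives |b| = 5 and sign-adjusted a = 3, and
+// 5 * 3 = 15 = (-3) * (-5).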
+static MLAS_FORCEINLINE __m256 +dot_quad_avx512vnni( + const __m256i bv0_32_epi8, const __m256i bv1_32_epi8, const __m256i av0_32_epi8, const __m256i av1_32_epi8 +) +{ + const __m256i zero = _mm256_setzero_si256(); + __m256i sum_8_epi32 = _mm256_dpbusd_epi32(zero, _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + sum_8_epi32 = _mm256_dpbusd_epi32(sum_8_epi32, _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + return _mm256_cvtepi32_ps(sum_8_epi32); +} + +static MLAS_FORCEINLINE __m256 +dot_quad_avx2( + const __m256i b0, const __m256i b1, const __m256i a0, const __m256i a1 +) +{ + // Perform multiplication and create 16-bit values + const __m256i ones = _mm256_set1_epi16(1); + __m256i sum_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(b0, b0), _mm256_sign_epi8(a0, b0)); + __m256i summed_pair_epi32 = _mm256_madd_epi16(ones, sum_epi16); + + sum_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(b1, b1), _mm256_sign_epi8(a1, b1)); + summed_pair_epi32 = _mm256_add_epi32(_mm256_madd_epi16(ones, sum_epi16), summed_pair_epi32); + return _mm256_cvtepi32_ps(summed_pair_epi32); +} + +// TODO: refactor load_and_mul_sum_s8_quads_with_zp_avx512vnni, load_and_mul_sum_s8_quads_with_zp_avx2 +// and accumulate_mul_sum_avx512vnni, accumulate_mul_sum_avx2 +static MLAS_FORCEINLINE void +load_and_mul_sum_s8_quads_with_zp_avx512vnni( + const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const int8_t zp, const __m256 scale0, __m256& acc0 +) +{ + // load B + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + + // supprisingly this code that works with __m128i is 2-3% faster than the blobk below with __m256i + // to unpack bv_packed0. Also passing in low_mask is faster than creating it here by 2%. + // const __m128i low_mask = _mm_set1_epi8(15); + const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... + const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... + __m256i bv_0_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); + + //__m256i bv_0_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + // const __m256i low_mask = _mm256_set1_epi8(15); + // bv_0_epi8 = _mm256_and_si256(low_mask, bv_0_epi8); + + const __m256i bzp0 = _mm256_set1_epi8(zp); + bv_0_epi8 = _mm256_sub_epi8(bv_0_epi8, bzp0); + // quantized dot product + __m256i dot_0_epi32 = _mm256_dpbusd_epi32( + zero, _mm256_sign_epi8(bv_0_epi8, bv_0_epi8), _mm256_sign_epi8(av_0_epi8, bv_0_epi8) + ); + const __m256 sum_ps = _mm256_cvtepi32_ps(dot_0_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); +} + +static MLAS_FORCEINLINE void +load_and_mul_sum_s8_quads_with_zp_avx2( + const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i, const int8_t zp, const __m256 scale0, __m256& acc0 +) +{ + // load B + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + + // supprisingly this code that works with __m128i is 2-3% faster than the blobk below with __m256i + // to unpack bv_packed0. Also passing in low_mask is faster than creating it here by 2%. + // const __m128i low_mask = _mm_set1_epi8(15); + const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... 
+ const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... + __m256i bv_0_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); + + //__m256i bv_0_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + // const __m256i low_mask = _mm256_set1_epi8(15); + // bv_0_epi8 = _mm256_and_si256(low_mask, bv_0_epi8); + + const __m256i bzp0 = _mm256_set1_epi8(zp); + bv_0_epi8 = _mm256_sub_epi8(bv_0_epi8, bzp0); + // quantized dot product + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_0_epi8, bv_0_epi8), _mm256_sign_epi8(av_0_epi8, bv_0_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); +} + +template +int8_t MLAS_FORCEINLINE +get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointPtr; + } + + if constexpr (HasZeroPoint) { + return is_lower_half_byte_zp ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : std::to_integer((*QuantBZeroPointPtr) >> 4); + } else { + return 8; + } +} + +// this function load and unpack 32 4b weights (packed for BlkLen32) and dot product it with 32 +// epi8 input. dot products are accumulated into acc0. +// This function is called for Int8 precision with BlkLen = 32. +template +using AccumulateFunctionType = void (*)( + const __m256i, const __m128i*, const __m128i, const __m256i, const std::byte*, bool, const float, __m256& +); + +template +static MLAS_FORCEINLINE void +accumulate_mul_sum_avx512vnni( + const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const std::byte* QuantBZeroPointPtr, bool is_lower_half_byte_zp, const float combined_scale, __m256& acc0 +) +{ + const __m256 scale0 = _mm256_set1_ps(combined_scale); + const int8_t zp = get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr); + load_and_mul_sum_s8_quads_with_zp_avx512vnni( + av_0_epi8, reinterpret_cast(QuantBDataPtr), + low_mask, zero, + zp, scale0, acc0 + ); +} + +template +static MLAS_FORCEINLINE void +accumulate_mul_sum_avx2( + const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const std::byte* QuantBZeroPointPtr, bool is_lower_half_byte_zp, const float combined_scale, __m256& acc0 +) +{ + const __m256 scale0 = _mm256_set1_ps(combined_scale); + const int8_t zp = get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr); + load_and_mul_sum_s8_quads_with_zp_avx2( + av_0_epi8, reinterpret_cast(QuantBDataPtr), + low_mask, zero, + zp, scale0, acc0 + ); +} + +/** + * @brief Horizontally sum 4 vectors and store + * the results in the returned vector + */ +static MLAS_FORCEINLINE __m128 +FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3) +{ + __m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1); + __m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1); + __m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3); + __m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3); + + __m256 acc_lo0123 = _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)) + ); + __m256 acc_hi0123 = _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)) + ); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)) + ); + 
acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)) + ); + acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); + + __m128 acc_y = + _mm_add_ps(_mm256_extractf128_ps(acc_lo0123, 0), _mm256_extractf128_ps(acc_lo0123, 1)); + return acc_y; +} + +static inline float +hsum_float_8(const __m256 x) +{ + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +/** + * @brief Horizontally sum 4 vectors and store + * the results in the returned vector + */ +static MLAS_FORCEINLINE __m128 +FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3) +{ + __m512 acc_lo01 = _mm512_unpacklo_ps(acc0, acc1); + __m512 acc_hi01 = _mm512_unpackhi_ps(acc0, acc1); + __m512 acc_lo23 = _mm512_unpacklo_ps(acc2, acc3); + __m512 acc_hi23 = _mm512_unpackhi_ps(acc2, acc3); + + __m512 acc_lo0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23)) + ); + __m512 acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23)) + ); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23)) + ); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + acc_hi0123 = _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23)) + ); + acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); + + __m256 acc_y = + _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); + return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); +} +} // namespace diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h b/third_party/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h new file mode 100644 index 0000000000..5cd380e591 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -0,0 +1,639 @@ +#pragma once +#include "sqnbitgemm.h" + +template +MLAS_FORCEINLINE + size_t + MlasQ4GemmKernelBlkLen16Avx512f( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + // We process 32 quantized values in a batch. 
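+    // Structure of this kernel: for every row of A, B columns are consumed four
+    // at a time (NCols == 4) with a tail loop for the last 1-3 columns. Each
+    // 16-value block of a column is widened from 4 bits to fp32 as
+    // scale * (q - zp), with zp fixed at 8 when B carries no zero points, and
+    // fused-multiply-added against the masked 16-float slice of the A row.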
+ // assert(BlkLen % 32 == 0) + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols = 4; + constexpr size_t BlkLen16 = 16; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m128i lowMask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + //*// + ////const float* BiasPtr = Bias; + + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + ////float* SumPtr = CRowPtr; + //*// + + auto* sum_ptr = C; + const auto* bias_ptr = Bias; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + __m512 acc_lo0 = _mm512_setzero_ps(); + __m512 acc_lo1 = _mm512_setzero_ps(); + __m512 acc_lo2 = _mm512_setzero_ps(); + __m512 acc_lo3 = _mm512_setzero_ps(); + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen16) { + size_t kklen = std::min(CountK - k, BlkLen16); + + const float scale_v0 = *(s); + const float scale_v1 = *(s + StrideQuantBScale * 1); + const float scale_v2 = *(s + StrideQuantBScale * 2); + const float scale_v3 = *(s + StrideQuantBScale * 3); + + const __m128i* b0ptr = (const __m128i*)(b_blk_data_ptr); + const __m128i* b1ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 1); + const __m128i* b2ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 2); + const __m128i* b3ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 3); + + // Load A row vector of 16 floats + uint32_t mask = 0xffff >> (BlkLen16 - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k); + + // Load B col vectors of 16 of 4b + // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + const __m128i bvi4_0 = _mm_loadl_epi64(b0ptr++); + const __m128i bvi4_1 = _mm_loadl_epi64(b1ptr++); + const __m128i bvi4_2 = _mm_loadl_epi64(b2ptr++); + const __m128i bvi4_3 = _mm_loadl_epi64(b3ptr++); + + // expand 4b into byte array + __m128i lower = _mm_and_si128(bvi4_0, lowMask); + __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_0, 4), lowMask), 8); + __m128i bytes0 = _mm_add_epi8(upper, lower); + + lower = _mm_and_si128(bvi4_1, lowMask); + upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_1, 4), lowMask), 8); + __m128i bytes1 = _mm_add_epi8(upper, lower); + + lower = _mm_and_si128(bvi4_2, lowMask); + upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_2, 4), lowMask), 8); + __m128i bytes2 = _mm_add_epi8(upper, lower); + + lower = _mm_and_si128(bvi4_3, lowMask); + upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_3, 4), lowMask), 8); + __m128i bytes3 = _mm_add_epi8(upper, lower); + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + + bytes0 = _mm_sub_epi8(bytes0, _mm_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes1 = _mm_sub_epi8(bytes1, _mm_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes2 = _mm_sub_epi8(bytes2, _mm_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes3 = _mm_sub_epi8(bytes3, _mm_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m128i eight = _mm_set1_epi8(8); + bytes0 = _mm_sub_epi8(bytes0, eight); + bytes1 = _mm_sub_epi8(bytes1, eight); + bytes2 = _mm_sub_epi8(bytes2, eight); + bytes3 = _mm_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + const __m256i vx16_0 = _mm256_cvtepi8_epi16(bytes0); + const __m256i vx16_1 = _mm256_cvtepi8_epi16(bytes1); + const __m256i vx16_2 = _mm256_cvtepi8_epi16(bytes2); + const __m256i vx16_3 = _mm256_cvtepi8_epi16(bytes3); + + // Convert to 32-bit int -> float 32 + __m512 bvf_0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_0)); + __m512 bvf_1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_1)); + __m512 bvf_2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_2)); + __m512 bvf_3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_3)); + + __m512 scale_ps = _mm512_set1_ps(scale_v0); + bvf_0 = _mm512_mul_ps(bvf_0, scale_ps); + scale_ps = _mm512_set1_ps(scale_v1); + bvf_1 = _mm512_mul_ps(bvf_1, scale_ps); + scale_ps = _mm512_set1_ps(scale_v2); + bvf_2 = _mm512_mul_ps(bvf_2, scale_ps); + scale_ps = _mm512_set1_ps(scale_v3); + bvf_3 = _mm512_mul_ps(bvf_3, scale_ps); + + acc_lo0 = _mm512_fmadd_ps(bvf_0, av_lo, acc_lo0); + acc_lo1 = _mm512_fmadd_ps(bvf_1, av_lo, acc_lo1); + acc_lo2 = _mm512_fmadd_ps(bvf_2, av_lo, acc_lo2); + acc_lo3 = _mm512_fmadd_ps(bvf_3, av_lo, acc_lo3); + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + + } // k + + __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); + if (Bias != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + + // move to next 4 columns + sum_ptr += 4; + bias_ptr += 4; + nblk -= 4; + + //*// + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + ////BiasPtr += BiasPtr != nullptr ? NCols : 0; + ////SumPtr += NCols; + + ////nblk -= NCols; + //*// + } + + // left over columns less than 4 ? 
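+        // nblk went negative in the 4-wide loop above, so adding 4 back yields
+        // the 1-3 remaining columns. They are handled below with one accumulator
+        // per column and a final _mm512_reduce_add_ps instead of FoldAccumulators.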
+ nblk += 4; + if (nblk > 0) { + __m512 acc_lo[4]{}; + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen16) { + size_t klen = std::min(CountK - k, BlkLen16); + + float scale_v[4]; + const __m128i* b_ptr[4]; + for (int64_t nn = 0; nn < nblk; nn++) { + //*// + scale_v[nn] = *(s + StrideQuantBScale * nn); + b_ptr[nn] = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * nn); + //*// + } + + uint32_t mask = 0xffff >> (BlkLen16 - klen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k); + + for (int64_t nn = 0; nn < nblk; nn++) { + // Load B col vectors of 16 of 4b + // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + const __m128i bvi4_0 = _mm_loadl_epi64(b_ptr[nn]++); + + // expand 4b into byte array + __m128i lower = _mm_and_si128(bvi4_0, lowMask); + __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_0, 4), lowMask), 8); + __m128i bytes = _mm_add_epi8(upper, lower); + + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[nn * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes = _mm_sub_epi8(bytes, _mm_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m128i eight = _mm_set1_epi8(8); + bytes = _mm_sub_epi8(bytes, eight); + } + + // Convert to 16-bit int + const __m256i vx16 = _mm256_cvtepi8_epi16(bytes); + + // Convert to 32-bit int -> float 32 + __m512 bvf = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16)); + __m512 scale_16_ps = _mm512_set1_ps(scale_v[nn]); + bvf = _mm512_mul_ps(bvf, scale_16_ps); + + acc_lo[nn] = _mm512_fmadd_ps(bvf, av_lo, acc_lo[nn]); + } + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + } // k + + for (int64_t nn = 0; nn < nblk; nn++) { + sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); + sum_ptr[nn] += Bias == nullptr ? 0.0f : bias_ptr[nn]; + } + } + + // Prepare pointers for the next row + C += ldc; + A += lda; + } + return CountM; +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4GemmKernelBlkLen32PlusAvx512f( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + // We process 32 quantized values in a batch. 
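+    // This kernel handles BlkLen >= 32. Each block is walked in 32-value
+    // sub-units (MLAS_QUANT4_BLK_UNIT32); when IsBlkLen64Layout is set, the
+    // packed B data interleaves v0..v31 with v32..v63, so one 32-byte load
+    // provides two consecutive sub-units (low nibbles first, then high).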
+ // assert(BlkLen % 32 == 0) + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols = 4; + constexpr size_t MLAS_QUANT4_BLK_UNIT32 = 32; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i lowMask = _mm256_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + //*// + ////const float* BiasPtr = Bias; + + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + ////float* SumPtr = CRowPtr; + //*// + + auto* sum_ptr = C; + const auto* bias_ptr = Bias; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + __m512 acc_lo0 = _mm512_setzero_ps(); + __m512 acc_lo1 = _mm512_setzero_ps(); + __m512 acc_lo2 = _mm512_setzero_ps(); + __m512 acc_lo3 = _mm512_setzero_ps(); + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + const float scale_v0 = *(s); + const float scale_v1 = *(s + StrideQuantBScale * 1); + const float scale_v2 = *(s + StrideQuantBScale * 2); + const float scale_v3 = *(s + StrideQuantBScale * 3); + + const __m128i* b0ptr = (const __m128i*)(b_blk_data_ptr); + const __m128i* b1ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 1); + const __m128i* b2ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 2); + const __m128i* b3ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 3); + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT32) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT32, ck - kk); + + // Load A row vectors + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT32 - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 ? _mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); + + // Load B col vectors + __m256i bytes0, bytes1, bytes2, bytes3; + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk is even, otherwise get v32 - v63 + // increment b_data_ptr by 2 * MLAS_QUANT4_BLK_UNIT32 if subblk is odd + // so that all v0-63 of the pack are processed. 
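+                        // Example: within a 64-value pack, the sub-unit at kk % 64 == 0 uses
+                        // count_half_4 == 0 (keep the low nibbles, pointer not advanced), and
+                        // the sub-unit at kk % 64 == 32 uses count_half_4 == 4 (shift down to
+                        // the high nibbles, then advance the pointer 32 bytes past the pack).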
+ const __m256i bvi4_0 = _mm256_loadu_si256((__m256i const*)(b0ptr)); + const __m256i bvi4_1 = _mm256_loadu_si256((__m256i const*)(b1ptr)); + const __m256i bvi4_2 = _mm256_loadu_si256((__m256i const*)(b2ptr)); + const __m256i bvi4_3 = _mm256_loadu_si256((__m256i const*)(b3ptr)); + const int count_half_4 = + 4 * ((kk % (2 * MLAS_QUANT4_BLK_UNIT32)) / MLAS_QUANT4_BLK_UNIT32); + bytes0 = _mm256_and_si256(_mm256_srli_epi16(bvi4_0, count_half_4), lowMask); + bytes1 = _mm256_and_si256(_mm256_srli_epi16(bvi4_1, count_half_4), lowMask); + bytes2 = _mm256_and_si256(_mm256_srli_epi16(bvi4_2, count_half_4), lowMask); + bytes3 = _mm256_and_si256(_mm256_srli_epi16(bvi4_3, count_half_4), lowMask); + b0ptr += count_half_4 / 2; + b1ptr += count_half_4 / 2; + b2ptr += count_half_4 / 2; + b3ptr += count_half_4 / 2; + } else { + const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); + const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); + const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); + const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); + + // expand 4b into byte array + bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); + bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); + bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); + bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); + bytes0 = _mm256_and_si256(lowMask, bytes0); + bytes1 = _mm256_and_si256(lowMask, bytes1); + bytes2 = _mm256_and_si256(lowMask, bytes2); + bytes3 = _mm256_and_si256(lowMask, bytes3); + } + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + + bytes0 = _mm256_sub_epi8(bytes0, _mm256_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes1 = _mm256_sub_epi8(bytes1, _mm256_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes2 = _mm256_sub_epi8(bytes2, _mm256_set1_epi8(zp)); + + zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes3 = _mm256_sub_epi8(bytes3, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes0 = _mm256_sub_epi8(bytes0, eight); + bytes1 = _mm256_sub_epi8(bytes1, eight); + bytes2 = _mm256_sub_epi8(bytes2, eight); + bytes3 = _mm256_sub_epi8(bytes3, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo0 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); + const __m256i vx16_hi0 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); + const __m256i vx16_lo1 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); + const __m256i vx16_hi1 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); + const __m256i vx16_lo2 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); + const __m256i vx16_hi2 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); + const __m256i vx16_lo3 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); + const __m256i vx16_hi3 = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); + __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); + __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); + __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); + __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); + __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); + __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); + __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); + + __m512 scale_ps = _mm512_set1_ps(scale_v0); + bvf_lo0 = _mm512_mul_ps(bvf_lo0, scale_ps); + bvf_hi0 = _mm512_mul_ps(bvf_hi0, scale_ps); + scale_ps = _mm512_set1_ps(scale_v1); + bvf_lo1 = _mm512_mul_ps(bvf_lo1, scale_ps); + bvf_hi1 = _mm512_mul_ps(bvf_hi1, scale_ps); + scale_ps = _mm512_set1_ps(scale_v2); + bvf_lo2 = _mm512_mul_ps(bvf_lo2, scale_ps); + bvf_hi2 = _mm512_mul_ps(bvf_hi2, scale_ps); + scale_ps = _mm512_set1_ps(scale_v3); + bvf_lo3 = _mm512_mul_ps(bvf_lo3, scale_ps); + bvf_hi3 = _mm512_mul_ps(bvf_hi3, scale_ps); + + acc_lo0 = _mm512_fmadd_ps(bvf_lo0, av_lo, acc_lo0); + acc_lo0 = _mm512_fmadd_ps(bvf_hi0, av_hi, acc_lo0); + acc_lo1 = _mm512_fmadd_ps(bvf_lo1, av_lo, acc_lo1); + acc_lo1 = _mm512_fmadd_ps(bvf_hi1, av_hi, acc_lo1); + acc_lo2 = _mm512_fmadd_ps(bvf_lo2, av_lo, acc_lo2); + acc_lo2 = _mm512_fmadd_ps(bvf_hi2, av_hi, acc_lo2); + acc_lo3 = _mm512_fmadd_ps(bvf_lo3, av_lo, acc_lo3); + acc_lo3 = _mm512_fmadd_ps(bvf_hi3, av_hi, acc_lo3); + } // kk + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + + } // k + + __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); + if (Bias != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); + } + _mm_storeu_ps(sum_ptr, acc_x); + + // move to next 4 columns + sum_ptr += 4; + bias_ptr += 4; + nblk -= 4; + + //*// + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + ////BiasPtr += BiasPtr != nullptr ? NCols : 0; + ////SumPtr += NCols; + + ////nblk -= NCols; + //*// + } + + // left over columns less than 4 ? 
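+        // Tail for the last 1-3 columns: same per-sub-unit unpacking as above
+        // (including the IsBlkLen64Layout nibble selection), but with one
+        // accumulator per column reduced at the end by _mm512_reduce_add_ps.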
+ nblk += 4; + if (nblk > 0) { + __m512 acc_lo[4]{}; + + //*// + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + //*// + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx = 0; + } + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[4]; + const __m128i* b_ptr[4]; + for (int64_t nn = 0; nn < nblk; nn++) { + //*// + scale_v[nn] = *(s + StrideQuantBScale * nn); + b_ptr[nn] = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * nn); + //*// + } + + for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT32) { + size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT32, ck - kk); + + uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT32 - kklen); + __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); + + mask = mask >> 16; + __m512 av_hi = mask == 0 + ? _mm512_setzero_ps() + : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); + + for (int64_t nn = 0; nn < nblk; nn++) { + __m256i bytes; + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk is even, otherwise get v32 - v63 + // increment b_data_ptr by 2 * MLAS_QUANT4_BLK_UNIT32 if subblk is odd + // so that all v0-63 of the pack are processed. + const __m256i bvi4 = _mm256_loadu_si256((__m256i const*)(b_ptr[nn])); + const int count_half_4 = + 4 * ((kk % (2 * MLAS_QUANT4_BLK_UNIT32)) / MLAS_QUANT4_BLK_UNIT32); + bytes = _mm256_and_si256(_mm256_srli_epi16(bvi4, count_half_4), lowMask); + b_ptr[nn] += count_half_4 / 2; + } else { + const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); + bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); + bytes = _mm256_and_si256(lowMask, bytes); + } + if constexpr (HasZeroPoint) { + // Subtract zero-point from the integers + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + + // TODO: void condition on is_lower + std::byte zp_packed = QuantBZeroPointColPtr[nn * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); + } else { + // Subtract 8 from the integers + const __m256i eight = _mm256_set1_epi8(8); + bytes = _mm256_sub_epi8(bytes, eight); + } + + // Convert to 16-bit int + const __m256i vx16_lo = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); + const __m256i vx16_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); + + // Convert to 32-bit int -> float 32 + __m512 bvf_lo = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); + __m512 bvf_hi = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); + __m512 scale_16_ps = _mm512_set1_ps(scale_v[nn]); + bvf_lo = _mm512_mul_ps(bvf_lo, scale_16_ps); + bvf_hi = _mm512_mul_ps(bvf_hi, scale_16_ps); + + acc_lo[nn] = _mm512_fmadd_ps(bvf_lo, av_lo, acc_lo[nn]); + acc_lo[nn] = _mm512_fmadd_ps(bvf_hi, av_hi, acc_lo[nn]); + } + } // kk + + //*// + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + s++; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + //*// + } // k + + for (int64_t nn = 0; nn < nblk; nn++) { + sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); + sum_ptr[nn] += Bias == nullptr ? 
0.0f : bias_ptr[nn]; + } + } + + // Prepare pointers for the next row + C += ldc; + A += lda; + } + return CountM; +} diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/third_party/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h new file mode 100644 index 0000000000..250ffeacd7 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -0,0 +1,745 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_q8_block.h" + +void +SQ4BitGemmM1Kernel_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + if constexpr (!HasZeroPoint) { + // Suppress unused variable warnings + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + assert(BlkLen == 16); + constexpr size_t SubBlkLen = 16; + const __m128i low_mask = _mm_set1_epi8(0xF); + + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + + __m256 acc[NCols]; + UnrolledLoop([&](size_t i) { + acc[i] = _mm256_setzero_ps(); + }); + + const std::byte* ablob = QuantARowPtr; + const auto* b = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + const float a_scale = Q8BlkScale(ablob); + ablob += sizeof(float); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + scale_v[i] = (*(s + StrideQuantBScale * i)) * a_scale; + }); + + std::byte* bptr[NCols]; + UnrolledLoop([&](size_t i) { + bptr[i] = (std::byte*)(b + StrideQuantBData * i); + }); + + [[maybe_unused]] uint8_t offset[NCols]; + // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? 
(zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + // Load A row vector + const __m128i av_epi8 = _mm_lddqu_si128((const __m128i*)ablob); + __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); + ablob += BlkLen; + + // Load 4 B column vectors (quantized to int4 blobs) + __m128i bvi[NCols]; + UnrolledLoop([&](size_t i) { + bvi[i] = _mm_loadl_epi64((__m128i const*)bptr[i]); + bptr[i] += SubBlkStep; + }); + + // expand 4b into byte array + __m256i bv_epi16[NCols]; + UnrolledLoop([&](size_t i) { + const __m128i lower = _mm_and_si128(bvi[i], low_mask); + const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi[i], 4), low_mask), 8); + bv_epi16[i] = _mm256_cvtepi8_epi16(_mm_add_epi8(upper, lower)); + }); + + // Subtract zero-point from the integers + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + bv_epi16[i] = _mm256_sub_epi16(bv_epi16[i], _mm256_set1_epi16(offset[i])); + }); + } else { + const __m256i eight = _mm256_set1_epi16(8); + UnrolledLoop([&](size_t i) { + bv_epi16[i] = _mm256_sub_epi16(bv_epi16[i], eight); + }); + } + + UnrolledLoop([&](size_t i) { + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_epi16[i], av_epi16); + + const __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc[i] = _mm256_fmadd_ps(_mm256_set1_ps(scale_v[i]), prod_8_ps, acc[i]); + }); + + b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + s++; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } else { + UnrolledLoop([&](size_t i) { + __m128 vlow = _mm256_castps256_ps128(acc[i]); + __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit + + // Add the two 128-bit vectors together + __m128 vsum = _mm_add_ps(vlow, vhigh); + // Horizontally add the elements of the resulting 128-bit vector + vsum = _mm_hadd_ps(vsum, vsum); + vsum = _mm_hadd_ps(vsum, vsum); + + _mm_store_ss(&SumPtr[i], vsum); + SumPtr[i] += BiasPtr == nullptr ? 
0.0f : BiasPtr[i]; + }); + } +} + +template +void +SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t BlkLen16 = 16; + + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( + BlkLen16, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, + SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than `NCols`? + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16<1, HasZeroPoint>( + BlkLen16, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, + SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template accumulator> +void +SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + // port from neon implementation + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + const size_t NCols = 4; + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 + acc0 = _mm256_setzero_ps(), + acc1 = _mm256_setzero_ps(), + acc2 = _mm256_setzero_ps(), + acc3 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, 
zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + } + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + + // move to next NCols columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols : 0; + SumPtr += NCols; + nblk -= NCols; + } + + nblk += NCols; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +using DotQuadFunctionType = __m256 (*)( + const __m256i, const __m256i, const __m256i, const __m256i +); + +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + // TODO: make it work with all BlkLens + assert(BlkLen >= 64); + constexpr size_t SubBlkLen64 = 64; + // const __m256i zero = _mm256_setzero_si256(); + const __m256i low_mask = _mm256_set1_epi8(0xF); + + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen64); + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(), acc2 = _mm256_setzero_ps(), acc3 = _mm256_setzero_ps(); + + const std::byte* ablob = QuantARowPtr; + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* blk_scale_ptr = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + const float a_scale = Q8BlkScale(ablob); + ablob += sizeof(float); + + float + scale_v0 = (*(blk_scale_ptr + StrideQuantBScale * 0)) * a_scale, + scale_v1 = (*(blk_scale_ptr + StrideQuantBScale * 1)) * a_scale, + scale_v2 = (*(blk_scale_ptr + StrideQuantBScale * 2)) * a_scale, + scale_v3 = (*(blk_scale_ptr + StrideQuantBScale * 3)) * a_scale; + + const std::byte* bptr0 = (b_blk_data_ptr + StrideQuantBData * 0); + const std::byte* bptr1 = (b_blk_data_ptr + StrideQuantBData * 1); + const std::byte* bptr2 = (b_blk_data_ptr + StrideQuantBData * 2); + const std::byte* bptr3 = (b_blk_data_ptr + StrideQuantBData * 3); + + uint8_t zp0, zp1, zp2, zp3; + if constexpr (HasZeroPoint) { + // TODO: this block causes near 30% of the computation. + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp0 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp1 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp2 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp3 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + } else { + zp0 = 8; + zp1 = 8; + zp2 = 8; + zp3 = 8; + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen64) { + // Load A row vector + const __m256i av0_32_epi8 = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + const __m256i av1_32_epi8 = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + + // Load B column vectors (quantized to int4 blobs) + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + __m256i bv = _mm256_loadu_si256((__m256i const*)bptr0); + bptr0 += SubBlkStep; + __m256i bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + __m256i zp_epi8 = _mm256_set1_epi8(zp0); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + __m256 sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sum_ps, acc0); + + bv = _mm256_loadu_si256((__m256i const*)bptr1); + bptr1 += SubBlkStep; + bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + zp_epi8 = _mm256_set1_epi8(zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v1), sum_ps, acc1); + + bv = _mm256_loadu_si256((__m256i const*)bptr2); + bptr2 += SubBlkStep; + bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + zp_epi8 = _mm256_set1_epi8(zp2); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v2), sum_ps, acc2); + + bv = _mm256_loadu_si256((__m256i const*)bptr3); + bptr3 += SubBlkStep; + bv0_32_epi8 = _mm256_and_si256(bv, low_mask); + bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + zp_epi8 = _mm256_set1_epi8(zp3); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); + sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); + acc3 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v3), sum_ps, acc3); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + blk_scale_ptr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); +} + +// TODO: is this able to be inlined if DotQuadFunctionType is a function pointer? 
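// For reference, a scalar model of the contract that a dot_quad candidate (the
// DotQuadFunctionType template parameter above) is assumed to satisfy: the eight float
// lanes it returns only need to sum to the 64-element int8 dot product, because the
// caller folds them with hsum_float_8 / FoldAccumulators after applying the combined
// scale. The function below is an illustrative sketch, not part of this patch.
#include <cstdint>

inline float DotQuadReferenceSum(const int8_t b_lo[32], const int8_t b_hi[32],
                                 const int8_t a_lo[32], const int8_t a_hi[32])
{
    int32_t sum = 0;
    for (int i = 0; i < 32; ++i) {
        sum += int32_t{a_lo[i]} * b_lo[i];  // a[0..31]  paired with low-nibble values v0..v31
        sum += int32_t{a_hi[i]} * b_hi[i];  // a[32..63] paired with high-nibble values v32..v63
    }
    return static_cast<float>(sum);  // equals the sum of the 8 lanes a real dot_quad returns
}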
+template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1( + size_t BlkLen, + const std::byte* QuantARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + // TODO: make it work with all BlkLens + assert(BlkLen >= 64); + constexpr size_t SubBlkLen64 = 64; + // const __m256i zero = _mm256_setzero_si256(); + const __m256i low_mask = _mm256_set1_epi8(0xF); + + constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen64); + + __m256 acc0 = _mm256_setzero_ps(); + + const std::byte* ablob = QuantARowPtr; + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* blk_scale_ptr = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + const float a_scale = Q8BlkScale(ablob); + ablob += sizeof(float); + + float scale_v0 = (*(blk_scale_ptr + StrideQuantBScale * 0)) * a_scale; + + const std::byte* bptr0 = (b_blk_data_ptr + StrideQuantBData * 0); + + uint8_t zp0; + if constexpr (HasZeroPoint) { + // TODO: this block causes near 30% of the computation. + // The solution proposed by @yufenglee is to factor out scaleB * zp + // while packing A. Will do this in next PR. + bool is_lower = (QuantBZeroPointIdx & 1) == 0; + std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + zp0 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + } else { + zp0 = 8; + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen64) { + // Load A row vector + const __m256i a_byte_lo = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + const __m256i a_byte_hi = _mm256_loadu_si256((const __m256i*)ablob); + ablob += 32; + + // Load B column vectors (quantized to int4 blobs) + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + __m256i bv = _mm256_loadu_si256((__m256i const*)bptr0); + bptr0 += SubBlkStep; + __m256i bv_lo_epi8 = _mm256_and_si256(bv, low_mask); + __m256i bv_hi_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); + __m256i zp_epi8 = _mm256_set1_epi8(zp0); + bv_lo_epi8 = _mm256_sub_epi8(bv_lo_epi8, zp_epi8); + bv_hi_epi8 = _mm256_sub_epi8(bv_hi_epi8, zp_epi8); + __m256 sum_ps = dot_quad(bv_lo_epi8, bv_hi_epi8, a_byte_lo, a_byte_hi); + //__m256 sum_ps = mul_sum_s8_quads_float_avx2(bv_lo_epi8, bv_hi_epi8, a_byte_lo, a_byte_hi); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sum_ps, acc0); + } // kk + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + blk_scale_ptr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } // k + + *SumPtr = hsum_float_8(acc0); + *SumPtr += BiasPtr == nullptr ? 
0.0f : *BiasPtr; +} + +template +void +SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const size_t NCols = 4; + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4( + BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, + SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} diff --git a/third_party/mlas/lib/sqnbitgemm_kernel_neon.cpp b/third_party/mlas/lib/sqnbitgemm_kernel_neon.cpp new file mode 100644 index 0000000000..6d1864794f --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -0,0 +1,1501 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON. + +--*/ + +#include + +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +// +// Quantized B data packing function implementation. 
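// For illustration, the repacking performed by SQ4BitGemmPackQuantBData below can be
// written against a plain byte array as follows (SubBlkLen == 32 case; the array and
// function names here are illustrative only). Source byte i holds nibbles
// (v[2i], v[2i+1]); destination byte j holds (v[j], v[j+16]).
#include <cstddef>
#include <cstdint>

inline void RepackSubBlk32(const uint8_t src[16], uint8_t dst[16])
{
    for (size_t j = 0; j < 8; ++j) {
        const uint8_t s0 = src[j];      // nibbles v[2j], v[2j+1]
        const uint8_t s1 = src[j + 8];  // nibbles v[2j+16], v[2j+17]
        dst[2 * j]     = static_cast<uint8_t>((s0 & 0x0F) | ((s1 & 0x0F) << 4));  // v[2j],   v[2j+16]
        dst[2 * j + 1] = static_cast<uint8_t>((s0 >> 4) | ((s1 >> 4) << 4));      // v[2j+1], v[2j+17]
    }
}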
+// + +namespace +{ + +size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +// +// Workspace size calculation function implementation. +// + +size_t +SQ4BitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + switch(ComputeType) { + case CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } +} + +size_t +SQ4BitGemmPerGemmWorkspaceAlignment( + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(BlkLen); + + switch (ComputeType) { + case CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + +} // namespace + +// +// General helpers. 
+// + +namespace +{ + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +MLAS_FORCEINLINE void +Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + Transpose4x4(a0, a1, a2, a3); + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +template +MLAS_FORCEINLINE void +LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +} // namespace + +// +// CompFp32 kernel implementation. +// + +namespace +{ + +namespace fp32_conversion +{ + +// Manual conversion to float takes place in two steps: +// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. +// This target float range is convenient because the 4-bit source values can be placed directly into the +// target float bits. +// 2. Subtract the conversion offset of 16 from the float result. + +// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. 
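// Worked example of the conversion described above (a standalone sketch under the
// assumption of little-endian IEEE 754 floats; the function name is illustrative).
// A 4-bit value v lands in the top mantissa bits of a float whose exponent encodes
// 2^4, so the float equals 16.0f * (1 + v/16) = 16.0f + v; subtracting the offset of
// 16.0f recovers v. The kernels below additionally fold the zero point into the same
// subtraction.
#include <cstdint>
#include <cstring>

inline float Fp32FromNibble(uint8_t v)  // v in [0, 15]
{
    const uint32_t bits = (uint32_t{0b0'10000011'0000000} << 16) | (uint32_t{v} << 19);
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f - 16.0f;  // e.g. v = 5: the bits give 21.0f, and 21.0f - 16.0f = 5.0f
}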
+constexpr uint16_t float_high_half_template = 0b0'10000011'0000000;
+// sign|exponent|partial mantissa
+//  +|131: 2^4|~~~~ <- 4 bits go here
+
+const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template);
+
+constexpr float offset = 16.0f;
+
+}  // namespace fp32_conversion
+
+template <size_t NCols, bool HasZeroPoint>
+MLAS_FORCEINLINE void
+ComputeDotProducts_BlkBitWidth4_CompFp32(
+    size_t BlkLen,
+    const float* ARowPtr,
+    const std::byte* QuantBDataColPtr,
+    const float* QuantBScaleColPtr,
+    const std::byte* QuantBZeroPointColPtr,
+    float* SumPtr,
+    size_t CountK,
+    size_t StrideQuantBData,
+    size_t StrideQuantBScale,
+    size_t StrideQuantBZeroPoint,
+    const float* BiasPtr
+)
+{
+    constexpr size_t BlkBitWidth = 4;
+    constexpr size_t SubBlkLen = 16;
+
+    static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
+
+    assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);
+
+    const uint8x8_t LowMask = vdup_n_u8(0x0F);
+
+    float32x4_t acc[NCols]{};
+
+    const std::byte* QuantBData = QuantBDataColPtr;
+    const float* QuantBScale = QuantBScaleColPtr;
+    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
+                                                     // only used if HasZeroPoint is true
+
+    for (size_t k = 0; k < CountK; k += BlkLen) {
+        const size_t k_blk_len = std::min(CountK - k, BlkLen);
+
+        float scale[NCols];
+        UnrolledLoop<NCols>(
+            [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; }
+        );
+
+        [[maybe_unused]] float offset[NCols];  // Includes zero point and float conversion offset.
+                                               // only used if HasZeroPoint is true
+        if constexpr (HasZeroPoint) {
+            UnrolledLoop<NCols>([&](size_t i) {
+                const std::byte zp_packed =
+                    QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
+                const std::byte zp = ((QuantBZeroPointIdx & 1) == 1)
+                                         ?
(zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = fp32_conversion::offset + std::to_integer(zp); + }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +template +void +SQ4BitGemmM1Kernel_CompFp32_Impl( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, 
+ const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + +// Block dequantize a 16 x NCols section of B from column major source to row major destination. 
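// Scalar model of the per-element dequantization that the routine below performs
// (illustrative only): conceptually fp = scale * (q - zp), with zp defaulting to 8
// when B carries no zero points; the NEON path realizes this as
// (fp32_from_nibble(q) - (fp32_conversion::offset + zp)) * scale.
#include <cstdint>

inline float DequantizeInt4(uint8_t q /* in [0, 15] */, float scale, uint8_t zp = 8)
{
    return scale * (static_cast<int>(q) - static_cast<int>(zp));
}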
+template +MLAS_FORCEINLINE void +Q4BitBlkDequantB_16xNCols( + const std::byte* QuantBDataPtr, + size_t StrideQuantBData, + const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns + [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns + // only used if HasZeroPoint is true + float* DstColPtr +) +{ + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData + ); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset and zero point + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // write, transposed, 16 x NCols values + if constexpr (NCols == 4) { + UnrolledLoop<4>([&](size_t j) { + Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); + + vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); + vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); + vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); + vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); + }); + } else { + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { + DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); + DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); + DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); + DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); + }); + }); + } +} + +template +void +Q4BitBlkDequantBForSgemm_CompFp32_Impl( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + constexpr size_t BlkBitWidth = 4; + + float* Dst = FpData; + + const std::byte* QuantBDataCol = QuantBData; + 
const float* QuantBScaleCol = QuantBScale; + [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true + + const size_t StrideQuantBData = BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true + MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); + + // + // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. + // + + // scales of blocks from 16 adjacent columns + float scale[16]; + // float conversion offsets (including zero point) of blocks from 16 adjacent columns + [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true + + size_t n_cols_remaining = CountN; + while (n_cols_remaining > 15) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < 16; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + constexpr size_t NCols = 4; + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < 16; nn += NCols) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += NCols; + if constexpr (HasZeroPoint) { + OffsetPtr += NCols; + } + DstColPtr += NCols; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + + n_cols_remaining -= 16; + + QuantBDataCol += 16 * StrideQuantBData; + QuantBScaleCol += 16 * BlockStrideQuantB; + if constexpr (HasZeroPoint) { + QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; + } + } + + if (n_cols_remaining > 0) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + scale[nn] = QuantBScaleCol[nn * BlockStrideQuantB + k_blk_idx]; + + if constexpr (HasZeroPoint) { + const std::byte zp_packed = + QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; + const std::byte zp = ((k_blk_idx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[nn] = fp32_conversion::offset + std::to_integer(zp); + } + } + + const size_t kklen = std::min(CountK - k, BlkLen); + + for (size_t kk = 0; kk < kklen; kk += 16) { + // zero out the 16x16 block in Dst first to ensure zero padding + const float32x4_t zero_v = vdupq_n_f32(0.0f); + UnrolledLoop<16 * 4>([&](size_t i) { + vst1q_f32(Dst + 4 * i, zero_v); + }); + + const float* ScalePtr = &scale[0]; + const float* OffsetPtr = HasZeroPoint ? 
&offset[0] : nullptr; + + float* DstColPtr = Dst; + + for (size_t nn = 0; nn < n_cols_remaining; ++nn) { + const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; + + Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( + QuantBDataPtr, + StrideQuantBData, + ScalePtr, + OffsetPtr, + DstColPtr + ); + + ScalePtr += 1; + if constexpr (HasZeroPoint) { + OffsetPtr += 1; + } + DstColPtr += 1; + } + + Dst += 16 * std::min(kklen - kk, size_t{16}); + } + } + } +} + +void +Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + if (QuantBZeroPoint != nullptr) { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemm_CompFp32_Impl( + BlkLen, + FpData, + QuantBData, + QuantBScale, + QuantBZeroPoint, + CountN, + CountK, + BlockStrideQuantB + ); + } +} + +// +// CompInt8 kernel implementation. +// + +template +MLAS_FORCEINLINE void +QuantizeBlock( + size_t BlkLen, + const float* A, + size_t ElementCount, + std::byte* QuantA +) +{ + static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); + + assert(BlkLen % SubBlkLen == 0); + + // + // Scan block values first to determine scale. + // + + float amax = 0.0f; // max of absolute values of A block + + size_t k; + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } + } + + // update existing amax + amax = std::max(amax, vmaxvq_f32(abs_a[0])); + } + + constexpr float range_max = (1 << 7) - 1; + const float scale = amax / range_max; + const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; + + Q8BlkScale(QuantA) = scale; + + // + // Compute quantized block values. + // + + int8_t* QuantAData = Q8BlkData(QuantA); + + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + + float32x4_t a[SubBlkLen / 4]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + + UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + + UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); + QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); + }); + } + + // + // Zero out any remaining sub-block elements. 
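// Putting the steps above (and the zero padding below) together, a scalar model of
// this block quantization is roughly the following; the function name is illustrative
// and std::lround mirrors vcvtaq's round-to-nearest, ties-away-from-zero behavior.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

inline void QuantizeBlockReference(const float* a, size_t count, size_t blk_len,
                                   float& scale, int8_t* q)
{
    float amax = 0.0f;
    for (size_t i = 0; i < count; ++i) {
        amax = std::max(amax, std::fabs(a[i]));
    }
    scale = amax / 127.0f;
    const float recip = scale != 0.0f ? 1.0f / scale : 0.0f;
    for (size_t i = 0; i < blk_len; ++i) {
        q[i] = i < count ? static_cast<int8_t>(std::lround(a[i] * recip)) : 0;
    }
}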
+ // + + for (; k < BlkLen; k += SubBlkLen) { + const int8x16_t Zeros = vdupq_n_s8(0); + UnrolledLoop([&](size_t i) { + vst1q_s8(QuantAData + k + i * 16, Zeros); + }); + } +} + +void +QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +) +{ + const float* ADataBlkPtr = A; + std::byte* QuantABlkPtr = QuantA; + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); + + ADataBlkPtr += BlkLen; + QuantABlkPtr += Q8BlkSize(BlkLen); + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 16; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? 
std::to_integer(QuantBZeroPointPtr[0] >> 4) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); + + // load B + const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); + const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); + int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + bv1 = vsubq_s8(bv1, bzp1); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 8 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); + + // load B + const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); + + const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); + const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); + + int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp0); + + // quantized dot product + const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); + const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + const int8x16_t bzp1 = vdupq_n_s8( + HasZeroPoint ? 
std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); + const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + bv_lo1 = vsubq_s8(bv_lo1, bzp1); + bv_hi1 = vsubq_s8(bv_hi1, bzp1); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); + + // increment block pointers + + QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + + // compute combined scale + const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp0 = vdupq_n_s8( + HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 + ); + + // load A + const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); + const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + + int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + + // subtract B zero point + bv_lo0 = vsubq_s8(bv_lo0, bzp0); + bv_hi0 = vsubq_s8(bv_hi0, bzp0); + + // quantized dot product + int32x4_t dot0{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen > 32); + assert(BlkLen % 32 == 0); + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); + + // process blocks in 32-element sub-blocks + const size_t SubBlksPerBlk = BlkLen / 32; + + for (size_t n = 0; n < CountN; ++n) { + const std::byte* QuantAPtr = QuantA; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + float32x4_t acc0{}, acc1{}; + + for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { + // compute combined scale + const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr)); + + // load B zero point + const int8x16_t bzp = [&]() -> int8x16_t { + if constexpr (HasZeroPoint) { + return vdupq_n_s8( + ((k_blk_idx & 1) == 0) ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) + : std::to_integer((*QuantBZeroPointPtr) >> 4) + ); + } else { + return vdupq_n_s8(8); + } + }(); + + const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); + + for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { + // load A + const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); + const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); + const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); + const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); + + // load B + const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); + const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); + + int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); + int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); + int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); + int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); + + // subtract B zero point + bv0 = vsubq_s8(bv0, bzp); + bv1 = vsubq_s8(bv1, bzp); + bv2 = vsubq_s8(bv2, bzp); + bv3 = vsubq_s8(bv3, bzp); + + // quantized dot product + int32x4_t dot0{}, dot1{}; + dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); + dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); + + // convert to float + const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); + const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); + + // multiply by scale and update accumulator + acc0 = vfmaq_f32(acc0, dot_f32_0, scale); + acc1 = vfmaq_f32(acc1, dot_f32_1, scale); + + // increment block data pointers to next sub-block + QuantADataPtr += 16 * 4; + QuantBDataPtr += 16 * 2; + } + + // increment other block pointers + + QuantAPtr += Q8BlkSize(BlkLen); + QuantBScalePtr += 1; + + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 
0 : 1; + } + } + + *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } +} + +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t /*CountK*/, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } +} + +} // namespace + +// +// Kernel dispatch structure definition. +// + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + + return d; +}(); diff --git a/third_party/mlas/lib/sqnbitgemm_q8_block.h b/third_party/mlas/lib/sqnbitgemm_q8_block.h new file mode 100644 index 0000000000..80af2f4679 --- /dev/null +++ b/third_party/mlas/lib/sqnbitgemm_q8_block.h @@ -0,0 +1,70 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_q8_block.h + +Abstract: + + This module includes helper functions for manipulating blocks of quantized + int8 (Q8) values. 
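    For reference, the layout implied by the accessors below is a 4-byte float scale
    followed immediately by BlkLen int8 values:

        [ float scale | int8 x BlkLen ]

    so Q8BlkSize(BlkLen) == sizeof(float) + BlkLen, e.g. 36 bytes for BlkLen == 32.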
+
+--*/
+
+#pragma once
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "mlasi.h"
+
+MLAS_FORCEINLINE
+const float&
+Q8BlkScale(const std::byte* BlkPtr)
+{
+    return *reinterpret_cast<const float*>(BlkPtr);
+}
+
+MLAS_FORCEINLINE
+float&
+Q8BlkScale(std::byte* BlkPtr)
+{
+    return *reinterpret_cast<float*>(BlkPtr);
+}
+
+MLAS_FORCEINLINE
+const int8_t*
+Q8BlkData(const std::byte* BlkPtr)
+{
+    return reinterpret_cast<const int8_t*>(BlkPtr + sizeof(float));
+}
+
+MLAS_FORCEINLINE
+int8_t*
+Q8BlkData(std::byte* BlkPtr)
+{
+    return reinterpret_cast<int8_t*>(BlkPtr + sizeof(float));
+}
+
+MLAS_FORCEINLINE
+constexpr size_t
+Q8BlkSize(size_t BlkLen)
+{
+    const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t);
+    // Currently, the strictest alignment requirement of a block is for a float.
+    // Ensure contiguous blocks are suitably aligned.
+    assert(BlkSize % alignof(float) == 0);
+    return BlkSize;
+}
+
+MLAS_FORCEINLINE
+constexpr size_t
+Q8BlkAlignment()
+{
+    return alignof(float);
+}
diff --git a/third_party/mlas/lib/threading.cpp b/third_party/mlas/lib/threading.cpp
index ecdc5250eb..dc5daf998d 100644
--- a/third_party/mlas/lib/threading.cpp
+++ b/third_party/mlas/lib/threading.cpp
@@ -93,3 +93,41 @@ MlasTrySimpleParallel(
     MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work);
 #endif
 }
+
+
+void
+MlasTryBatchParallel(
+    MLAS_THREADPOOL * ThreadPool,
+    const std::ptrdiff_t Iterations,
+    const std::function<void(std::ptrdiff_t)>& Work)
+{
+    //
+    // Execute the routine directly if only one iteration is specified.
+    //
+    if (Iterations == 1) {
+        Work(0);
+        return;
+    }
+
+#if defined(BUILD_MLAS_NO_ONNXRUNTIME)
+    MLAS_UNREFERENCED_PARAMETER(ThreadPool);
+
+    //
+    // Fallback to OpenMP or a serialized implementation.
+    //
+
+    //
+    // Execute the routine for the specified number of iterations.
+    //
+    for (ptrdiff_t tid = 0; tid < Iterations; tid++) {
+        Work(tid);
+    }
+#else
+    //
+    // Schedule the threaded iterations using the thread pool object.
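// For reference, a typical call site hands this routine a lambda over independent
// work items (hypothetical usage sketch, not part of this patch):
//
//     MlasTryBatchParallel(ThreadPool, static_cast<std::ptrdiff_t>(BatchCount),
//                          [&](std::ptrdiff_t batch) {
//                              ProcessOneGemm(batch);  // illustrative callee
//                          });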
+ // + + MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0); +#endif + +} \ No newline at end of file diff --git a/third_party/mlas/lib/transpose.cpp b/third_party/mlas/lib/transpose.cpp index 86b0897bb9..a758a0e59f 100644 --- a/third_party/mlas/lib/transpose.cpp +++ b/third_party/mlas/lib/transpose.cpp @@ -371,6 +371,121 @@ MlasTranspose16x16Block( vec_vsx_st(e0, 0, &Output[OutputStride * 14]); vec_vsx_st(e1, 0, &Output[OutputStride * 15]); } + +#elif defined(MLAS_LSX_INTRINSICS) + +MLAS_FORCEINLINE +void +MlasTranspose4x4Block( + const uint32_t* Input, + size_t InputStride, + uint32_t* Output, + size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + + __m128i b0 = __lsx_vilvl_w(a2, a0); + __m128i b1 = __lsx_vilvh_w(a2, a0); + __m128i b2 = __lsx_vilvl_w(a3, a1); + __m128i b3 = __lsx_vilvh_w(a3, a1); + __m128i c0 = __lsx_vilvl_w(b2, b0); + __m128i c1 = __lsx_vilvh_w(b2, b0); + __m128i c2 = __lsx_vilvl_w(b3, b1); + __m128i c3 = __lsx_vilvh_w(b3, b1); + + __lsx_vst(c0, (__m128i*)&Output[OutputStride * 0], 0); + __lsx_vst(c1, (__m128i*)&Output[OutputStride * 1], 0); + __lsx_vst(c2, (__m128i*)&Output[OutputStride * 2], 0); + __lsx_vst(c3, (__m128i*)&Output[OutputStride * 3], 0); +} + +MLAS_FORCEINLINE +void +MlasTranspose4x4Block( + const uint16_t* Input, + size_t InputStride, + uint16_t* Output, + size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __lsx_vinsgr2vr_d(a0, 0 , 1); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __lsx_vinsgr2vr_d(a1, 0 , 1); + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __lsx_vinsgr2vr_d(a2, 0 , 1); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + __lsx_vinsgr2vr_d(a3, 0 , 1); + + __m128i b0 = __lsx_vilvl_h(a2, a0); + __m128i b1 = __lsx_vilvl_h(a3, a1); + __m128i c0 = __lsx_vilvl_h(b1, b0); + __m128i c1 = __lsx_vilvh_h(b1, b0); + + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(c0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(c0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(c1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(c1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); +} + +MLAS_FORCEINLINE +void +MlasTranspose8x8Block( + const uint8_t* Input, + size_t InputStride, + uint8_t* Output, + size_t OutputStride + ) +{ + __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); + __lsx_vinsgr2vr_d(a0, 0, 1); + __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); + __lsx_vinsgr2vr_d(a1, 0, 1); + __m128i b0 = __lsx_vilvl_b(a1, a0); + + __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); + __lsx_vinsgr2vr_d(a2, 0, 1); + __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); + __lsx_vinsgr2vr_d(a3, 0, 1); + __m128i b1 = __lsx_vilvl_b(a3, a2); + + __m128i a4 = __lsx_vld((const __m128i*)&Input[InputStride * 4], 0); + __lsx_vinsgr2vr_d(a4, 0, 1); + __m128i a5 = 
__lsx_vld((const __m128i*)&Input[InputStride * 5], 0); + __lsx_vinsgr2vr_d(a5, 0, 1); + __m128i b2 = __lsx_vilvl_b(a5, a4); + + __m128i a6 = __lsx_vld((const __m128i*)&Input[InputStride * 6], 0); + __lsx_vinsgr2vr_d(a6, 0, 1); + __m128i a7 = __lsx_vld((const __m128i*)&Input[InputStride * 7], 0); + __lsx_vinsgr2vr_d(a7, 0, 1); + __m128i b3 = __lsx_vilvl_b(a7, a6); + __m128i c0 = __lsx_vilvl_h(b1, b0); + __m128i c1 = __lsx_vilvh_h(b1, b0); + __m128i c2 = __lsx_vilvl_h(b3, b2); + __m128i c3 = __lsx_vilvh_h(b3, b2); + + __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + + __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); + + __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); + + __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); + __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); +} + #endif template @@ -472,7 +587,8 @@ Return Value: uint32_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ + defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -597,7 +713,7 @@ Return Value: uint16_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -734,7 +850,7 @@ Return Value: uint8_t* d = Output; size_t m = M; -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) while (m >= 8) { diff --git a/third_party/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/third_party/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp index 955b7c5dee..43a12b37e4 100644 --- a/third_party/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp +++ b/third_party/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp @@ -171,11 +171,9 @@ Return Value: if (k > 0) { Row0AElements0 = a[0]; - Row0AElements1 = a[1]; if (ProcessTwoRows) { Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; } BElements0 = MlasLoadFloat32x4(B + 0); diff --git a/third_party/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S b/third_party/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S new file mode 100644 index 0000000000..7d042e2d8f --- /dev/null +++ b/third_party/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S @@ 
diff --git a/third_party/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S b/third_party/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S
new file mode 100644
index 0000000000..7d042e2d8f
--- /dev/null
+++ b/third_party/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S
@@ -0,0 +1,234 @@
+/*++
+
+Copyright (c) 2023 Intel Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    AssembleAmx.h
+
+Abstract:
+
+    This module contains macros to build AMX instructions for toolchains that
+    do not natively support this newer instruction set extension.
+
+--*/
+
+//
+// Map friendly register names to the encoded register index.
+//
+
+        .equ    .LTmmIndex_tmm0, 0
+        .equ    .LTmmIndex_tmm1, 1
+        .equ    .LTmmIndex_tmm2, 2
+        .equ    .LTmmIndex_tmm3, 3
+        .equ    .LTmmIndex_tmm4, 4
+        .equ    .LTmmIndex_tmm5, 5
+        .equ    .LTmmIndex_tmm6, 6
+        .equ    .LTmmIndex_tmm7, 7
+
+/*++
+
+Macro Description:
+
+    This macro builds an AMX instruction of the form:
+
+        instr tmm1,tmm2,tmm3
+
+Arguments:
+
+    prefix - Specifies the opcode for the AMX instruction.
+
+    DestReg - Specifies the destination AMX tile.
+
+    Src1Reg - Specifies the first source AMX tile.
+
+    Src2Reg - Specifies the second source AMX tile.
+
+--*/
+
+        .macro DPTmmTmmTmm prefix, DestReg, Src1Reg, Src2Reg
+
+        .set Payload0, 0x02                 # "0F 38" prefix
+        .set Payload0, Payload0 + ((((.LTmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7)
+        .set Payload0, Payload0 + (1 << 6)
+        .set Payload0, Payload0 + ((((.LTmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5)
+
+        .set Payload1, \prefix\()
+        .set Payload1, Payload1 + (((.LTmmIndex_\Src2Reg\() & 15) ^ 15) << 3)
+
+        .set ModRMByte, 0xC0                # register form
+        .set ModRMByte, ModRMByte + ((.LTmmIndex_\DestReg\() & 7) << 3)
+        .set ModRMByte, ModRMByte + (.LTmmIndex_\Src1Reg\() & 7)
+
+        .byte 0xC4, Payload0, Payload1, 0x5E, ModRMByte
+
+        .endm
+
+
+        .macro TdpbssdTmmTmmTmm DestReg, Src1Reg, Src2Reg
+
+        DPTmmTmmTmm 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\()
+
+        .endm
+
+
+        .macro TdpbsudTmmTmmTmm DestReg, Src1Reg, Src2Reg
+
+        DPTmmTmmTmm 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\()
+
+        .endm
+
+
+        .macro TdpbusdTmmTmmTmm DestReg, Src1Reg, Src2Reg
+
+        DPTmmTmmTmm 0x01, \DestReg\(), \Src1Reg\(), \Src2Reg\()
+
+        .endm
+
+
+        .macro TdpbuudTmmTmmTmm DestReg, Src1Reg, Src2Reg
+
+        DPTmmTmmTmm 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\()
+
+        .endm
+
+/*++
+
+Macro Description:
+
+    This macro builds an AMX tile release instruction.
+
+Arguments:
+
+    None.
+
+--*/
+
+//      .macro TileReleaseMacro
+
+//      .byte 0xC4, 0xE2, 0x78, 0x49, 0xC0
+
+//      .endm
+
+
+/*++
+
+Macro Description:
+
+    This macro builds an AMX tile zero instruction of the form:
+
+        instr tmm1
+
+Arguments:
+
+    SrcReg - Specifies the source AMX tile.
+
+--*/
+
+        .macro TileZeroMacro SrcReg
+
+        .set ModRMByte, 0xC0                # register form
+        .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3)
+        .byte 0xC4, 0xE2, 0x7B, 0x49, ModRMByte
+
+        .endm
+
+/*++
+
+Macro Description:
+
+    This macro builds an AMX memory instruction of the form:
+
+        instr tmm, base, stride
+
+Arguments:
+
+    instr - Specifies the opcode for the AMX instruction.
+
+    SrcReg - Specifies the target AMX tile.
+
+    BaseReg - Specifies the base address of the memory location.
+
+    Stride - Specifies the stride for the memory instruction.
+
+--*/
+
+        .macro TileLoadMacro instr, SrcReg, BaseReg, Stride
+
+        .set Payload0, 0x02                 # "0F 38" prefix
+        .set Payload0, Payload0 + ((((.LTmmIndex_\SrcReg\() >> 3) & 1) ^ 1) << 7)
+        .set Payload0, Payload0 + ((((3 >> 3) & 1) ^ 1) << 6)
+        .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5)
+
+        .set ModRMByte, 0x00                # memory form
+        .set ModRMByte, ModRMByte + (1 << 2)            # SibByte required
+        .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3)
+
+        .set SibByte, 0x00                  # scale factor 1(SS)
+        .set SibByte, SibByte + ((3 & 7) << 3)
+        .set SibByte, SibByte + (0 & 7)
+
+        .byte 0xC4, Payload0, \instr\(), 0x4B, ModRMByte, SibByte
+
+        .endm
+
+
+        .macro TileloaddTmmMem DstReg, BaseReg, Stride
+        TileLoadMacro 0x7B, \DstReg\(), \BaseReg\(), \Stride\()
+        .endm
+
+        .macro TileloaddT1TmmMem DstReg, BaseReg, Stride
+        TileLoadMacro 0x79, \DstReg\(), \BaseReg\(), \Stride\()
+        .endm
+
+
+        .macro TileStoredMemTmm SrcReg, BaseReg, Stride
+        TileLoadMacro 0x7A, \SrcReg\(), \BaseReg\(), \Stride\()
+        .endm
+
+
+/*++
+
+Macro Description:
+
+    This macro builds an AMX tile configuration instruction of the form:
+
+        instr base
+
+Arguments:
+
+    instr - Specifies the opcode for the AMX instruction.
+
+    BaseReg - Specifies the memory address of the tile configuration.
+
+--*/
+
+        .macro tilecfgMacro instr, BaseReg
+
+        .set Payload0, 0x02                 # "0F 38" prefix
+        .set Payload0, Payload0 + (1 << 7)
+        .set Payload0, Payload0 + (1 << 6)
+        .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5)
+
+        .set ModRMByte, 0x00                # memory form & no reg
+        .set ModRMByte, ModRMByte + (0 & 7)
+
+        .byte 0xC4, Payload0, \instr\(), 0x49, ModRMByte
+
+        .endm
+
+
+        .macro ldtilecfgMacro BaseReg
+
+        tilecfgMacro 0x78, \BaseReg\()
+
+        .endm
+
+
+        .macro sttilecfgMacro BaseReg
+
+        tilecfgMacro 0x79, \BaseReg\()
+
+        .endm
+
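
QgemmU8S8KernelAmxCommon.S hand-assembles the AMX dot-product instructions because older assemblers cannot emit them; the payload arithmetic in DPTmmTmmTmm is the 3-byte VEX encoding spelled out as shifts and XORs. The small C++ check below mirrors that arithmetic byte for byte so the output can be compared against a disassembler. The function name and the comments are invented here for illustration; this is a sanity-check sketch, not part of MLAS.

// Mirror of the byte arithmetic in the DPTmmTmmTmm macro above. Illustrative only.
#include <array>
#include <cstdint>
#include <cstdio>

std::array<uint8_t, 5> encode_amx_dp(uint8_t pp, uint8_t dst, uint8_t src1, uint8_t src2)
{
    uint8_t payload0 = 0x02;                           // select the 0F 38 opcode map
    payload0 |= ((((dst  >> 3) & 1) ^ 1) << 7);        // inverted high bit of dst (VEX.R)
    payload0 |= 1 << 6;                                // VEX.X fixed to 1
    payload0 |= ((((src2 >> 3) & 1) ^ 1) << 5);        // inverted high bit of src2 (VEX.B)

    uint8_t payload1 = pp;                             // pp: 0x00/0x01/0x02/0x03 -> uud/usd/sud/ssd
    payload1 |= ((src2 & 15) ^ 15) << 3;               // VEX.vvvv = ~src2

    uint8_t modrm = 0xC0 | ((dst & 7) << 3) | (src1 & 7);   // register form
    return {0xC4, payload0, payload1, 0x5E, modrm};
}

int main()
{
    // Equivalent of "TdpbusdTmmTmmTmm tmm0, tmm1, tmm2" (pp = 0x01 in the macros above).
    for (uint8_t byte : encode_amx_dp(0x01, /*dst=*/0, /*src1=*/1, /*src2=*/2))
        std::printf("%02X ", byte);
    std::printf("\n");                                 // prints: C4 E2 69 5E C1
}

The printed bytes should match what a disassembler reports for tdpbusd tmm0,tmm1,tmm2, which is a quick way to convince yourself the hand-rolled encoding is doing the right thing.
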
diff --git a/third_party/mlas/lib/x86_64/SoftmaxKernelAvx512F.S b/third_party/mlas/lib/x86_64/SoftmaxKernelAvx512F.S
new file mode 100644
index 0000000000..db97286046
--- /dev/null
+++ b/third_party/mlas/lib/x86_64/SoftmaxKernelAvx512F.S
@@ -0,0 +1,101 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    SoftmaxKernelAvx512F.s
+
+Abstract:
+
+    This module implements the kernels for the single precision softmax
+    operation.
+
+    This implementation uses AVX512F instructions.
+
+--*/
+
+#include "asmmacro.h"
+
+        .intel_syntax noprefix
+
+        .text
+
+/*++
+
+Routine Description:
+
+    This routine implements a vectorized kernel to find the maximum value of
+    the supplied buffer.
+
+Arguments:
+
+    Input (rdi) - Supplies the input buffer.
+
+    N (rsi) - Supplies the number of elements to process.
+
+Return Value:
+
+    Returns the maximum value of the supplied buffer.
+
+--*/
+
+        FUNCTION_ENTRY MlasReduceMaximumF32KernelAvx512F
+
+        vbroadcastss zmm0,DWORD PTR C_UNDERSCORE(MlasMinimumF32Value)[rip]
+        test    rsi,rsi
+        jz      .LReduceMaximum.ExitKernel
+        cmp     rsi,16
+        jb      .LReduceMaximum.ProcessRemainingCountBy1
+        cmp     rsi,64
+        jb      .LReduceMaximum.ProcessRemainingCountBy16
+        vmovaps zmm1,zmm0
+        vmovaps zmm2,zmm0
+        vmovaps zmm3,zmm0
+
+.LReduceMaximum.ProcessRemainingCountBy64:
+        vmaxps  zmm0,zmm0,ZMMWORD PTR [rdi]
+        vmaxps  zmm1,zmm1,ZMMWORD PTR [rdi+16*4]
+        sub     rsi,64
+        vmaxps  zmm2,zmm2,ZMMWORD PTR [rdi+32*4]
+        vmaxps  zmm3,zmm3,ZMMWORD PTR [rdi+48*4]
+        add     rdi,64*4                    # advance input by 64 elements
+        cmp     rsi,64
+        jae     .LReduceMaximum.ProcessRemainingCountBy64
+        vmaxps  zmm0,zmm0,zmm1              # reduce to single vector
+        vmaxps  zmm2,zmm2,zmm3
+        vmaxps  zmm0,zmm0,zmm2
+
+.LReduceMaximum.ProcessRemainingCountBy16:
+        cmp     rsi,16
+        jb      .LReduceMaximum.ProcessRemainingCountLessThan16
+        vmaxps  zmm0,zmm0,ZMMWORD PTR [rdi]
+        sub     rsi,16
+        add     rdi,16*4                    # advance input by 16 elements
+        jmp     .LReduceMaximum.ProcessRemainingCountBy16
+
+.LReduceMaximum.ProcessRemainingCountLessThan16:
+        vextractf32x8 ymm1,zmm0,1           # reduce to single scalar
+        vmaxps  ymm0,ymm0,ymm1
+        vextractf128 xmm1,ymm0,1
+        vmaxps  xmm0,xmm0,xmm1
+        vshufps xmm1,xmm0,xmm0,0xEE
+        vmaxps  xmm0,xmm0,xmm1
+        vshufps xmm1,xmm0,xmm0,0x55
+        vmaxss  xmm0,xmm0,xmm1
+        test    rsi,rsi
+        jz      .LReduceMaximum.ExitKernel
+
+.LReduceMaximum.ProcessRemainingCountBy1:
+        vmaxss  xmm0,xmm0,DWORD PTR [rdi]
+        add     rdi,4                       # advance input by 1 element
+        dec     esi
+        jnz     .LReduceMaximum.ProcessRemainingCountBy1
+
+.LReduceMaximum.ExitKernel:
+        vzeroupper
+        ret
+
+        .end
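
MlasReduceMaximumF32KernelAvx512F follows a common reduction layout: four independent ZMM accumulators while 64 or more elements remain (so the vmaxps dependency chains do not serialize), a single-register loop for blocks of 16, a lane-by-lane horizontal fold, and a one-element tail. Below is a scalar C++ sketch of the same blocking; it models the 16-lane registers with inner loops and uses a stand-in seed value, so it is illustrative only, not the kernel itself.

// Scalar model of the blocking in MlasReduceMaximumF32KernelAvx512F. Illustrative only.
#include <algorithm>
#include <cstddef>
#include <limits>

float reduce_maximum_model(const float* input, size_t n)
{
    float acc0 = std::numeric_limits<float>::lowest();  // stand-in for MlasMinimumF32Value
    float acc1 = acc0, acc2 = acc0, acc3 = acc0;

    while (n >= 64) {                     // ProcessRemainingCountBy64: 4 accumulators x 16 lanes
        for (size_t i = 0; i < 16; ++i) {
            acc0 = std::max(acc0, input[i]);
            acc1 = std::max(acc1, input[16 + i]);
            acc2 = std::max(acc2, input[32 + i]);
            acc3 = std::max(acc3, input[48 + i]);
        }
        input += 64;
        n -= 64;
    }
    acc0 = std::max(std::max(acc0, acc1), std::max(acc2, acc3));   // fold the accumulators

    while (n >= 16) {                     // ProcessRemainingCountBy16
        for (size_t i = 0; i < 16; ++i)
            acc0 = std::max(acc0, input[i]);
        input += 16;
        n -= 16;
    }

    while (n > 0) {                       // ProcessRemainingCountBy1: scalar tail
        acc0 = std::max(acc0, *input++);
        --n;
    }
    return acc0;
}

The four-accumulator main loop is a latency-hiding choice: each vmaxps only depends on its own accumulator, so the loads and maxima from consecutive 64-element blocks can overlap.
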
diff --git a/third_party/versions b/third_party/versions
index 2a0b1c9dae..1c7f3a9a71 100644
--- a/third_party/versions
+++ b/third_party/versions
@@ -6,7 +6,7 @@ simdjson v3.1.8
 tlx v0.6.1
 nlohmann 3.11.2
 zsv commit-id:5c22aae4363fdcd433079d2a9b48007a7c6fbbdf
-mlas onnxruntime:v1.14.1
+mlas onnxruntime:2c53b4a534a9b64466e435d384c91f0b684ea58a
 cppjieba v5.1.0
 thrift v0.19.0
 xor_singleheader v1.0.3
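
A closing note on why a "find the maximum" kernel lives in the softmax module at all: the maximum is subtracted before exponentiation so exp() stays in range. The deliberately naive scalar sketch below shows the end-to-end pattern; only the max-reduction kernel is added by this diff, and MLAS's exponent and normalization passes live elsewhere, so the code here is a reference model, not the library implementation.

// Naive softmax reference showing where the max-reduction kernel fits. Illustrative only.
#include <algorithm>
#include <cmath>
#include <cstddef>

void softmax_reference(const float* input, float* output, size_t n)
{
    if (n == 0) return;

    float maximum = input[0];                       // the step the AVX512F kernel above vectorizes
    for (size_t i = 1; i < n; ++i)
        maximum = std::max(maximum, input[i]);

    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        output[i] = std::exp(input[i] - maximum);   // shifted exponent is <= 0, so exp cannot overflow
        sum += output[i];
    }
    for (size_t i = 0; i < n; ++i)
        output[i] /= sum;
}
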