diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 15b5e42b1f2e2..91fe3794e1b43 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1695,7 +1695,7 @@ endif() #Now the 'onnxruntime_EXTERNAL_LIBRARIES' variable should be sealed. It will be used in onnxruntime.cmake which will be included in the next. #The order of the following targets matters. Right depends on left. If target A appears before target B. Then A.cmake can not use variables defined in B.cmake. -set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME}) +set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME}) if (onnxruntime_USE_WINML) # WINML uses and depends on the shared lib. Note: You can build WINML without DML and you will get a diff --git a/cmake/deps.txt b/cmake/deps.txt index 2aec0e35e1d7f..ddbf3e619ee12 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -33,6 +33,7 @@ googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 +microsoft_mlas;https://github.com/microsoft/mlas/archive/98eade39dc87f043c0406c216e31985768a7e1d4.zip;2f44f085b9e7b57f9d426d9f99c5c1ae82331ecf mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.1.zip;2eb9198bb352757d5ff13977cbe0634898e0837c diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index a69d2649ad832..d6c6d2c02ebba 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -575,7 +575,6 @@ if (onnxruntime_RUN_ONNX_TESTS) add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS) endif() - if(onnxruntime_ENABLE_ATEN) message(STATUS "Aten fallback is enabled.") FetchContent_Declare( diff --git a/cmake/linux_arm32_crosscompile_toolchain.cmake b/cmake/linux_arm32_crosscompile_toolchain.cmake index 0183262a8875e..1184efd4a5fa3 100644 --- a/cmake/linux_arm32_crosscompile_toolchain.cmake +++ b/cmake/linux_arm32_crosscompile_toolchain.cmake @@ -1,6 +1,7 @@ #This file is just a sample. You may need to modify it before using. SET(CMAKE_SYSTEM_NAME Linux) SET(CMAKE_SYSTEM_VERSION 1) + SET(CMAKE_SYSTEM_PROCESSOR armv7l) SET(CMAKE_C_COMPILER arm-none-linux-gnueabihf-gcc) SET(CMAKE_CXX_COMPILER arm-none-linux-gnueabihf-g++) SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) diff --git a/cmake/linux_arm64_crosscompile_toolchain.cmake b/cmake/linux_arm64_crosscompile_toolchain.cmake index 1a492bbc269e7..fd37ca1bf1689 100644 --- a/cmake/linux_arm64_crosscompile_toolchain.cmake +++ b/cmake/linux_arm64_crosscompile_toolchain.cmake @@ -1,6 +1,7 @@ #This file is just a sample. 
You may need to modify it before using. SET(CMAKE_SYSTEM_NAME Linux) SET(CMAKE_SYSTEM_VERSION 1) + SET(CMAKE_SYSTEM_PROCESSOR aarch64) SET(CMAKE_C_COMPILER aarch64-none-linux-gnu-gcc) SET(CMAKE_CXX_COMPILER aarch64-none-linux-gnu-g++) SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 896379d743441..c89d9f313803d 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -224,3 +224,24 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + + +set(MLAS_ENABLE_WEBASSEMBLY_THREADS ${onnxruntime_ENABLE_WEBASSEMBLY_THREADS}) +set(MLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING ${onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING}) + +FetchContent_Declare( + microsoft_mlas + URL ${DEP_URL_microsoft_mlas} + URL_HASH SHA1=${DEP_SHA1_microsoft_mlas} +) +onnxruntime_fetchcontent_makeavailable(microsoft_mlas) +include_directories(${microsoft_mlas_SOURCE_DIR}/include) + +set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) +if(TARGET onnxruntime_mlas_arm64) + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) +endif() +if(TARGET onnxruntime_mlas_x86_64) + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) +endif() +message("ONNXRUNTIME_MLAS_LIBS: ${ONNXRUNTIME_MLAS_LIBS}") diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake deleted file mode 100644 index 20bb1fb772189..0000000000000 --- a/cmake/onnxruntime_mlas.cmake +++ /dev/null @@ -1,768 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas) -set(MLAS_SRC_DIR ${MLAS_ROOT}/lib) -set(MLAS_INC_DIR ${MLAS_ROOT}/inc) - -# -# All hardware agnostic source files here -# hardware specific files would cause trouble in -# multi-target build -# -onnxruntime_add_static_library(onnxruntime_mlas - ${MLAS_SRC_DIR}/mlasi.h - ${MLAS_SRC_DIR}/platform.cpp - ${MLAS_SRC_DIR}/threading.cpp - ${MLAS_SRC_DIR}/sgemm.cpp - ${MLAS_SRC_DIR}/halfgemm.cpp - ${MLAS_SRC_DIR}/qgemm.cpp - ${MLAS_SRC_DIR}/qdwconv.cpp - ${MLAS_SRC_DIR}/convolve.cpp - ${MLAS_SRC_DIR}/convsym.cpp - ${MLAS_SRC_DIR}/pooling.cpp - ${MLAS_SRC_DIR}/transpose.cpp - ${MLAS_SRC_DIR}/reorder.cpp - ${MLAS_SRC_DIR}/snchwc.cpp - ${MLAS_SRC_DIR}/activate.cpp - ${MLAS_SRC_DIR}/logistic.cpp - ${MLAS_SRC_DIR}/tanh.cpp - ${MLAS_SRC_DIR}/erf.cpp - ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/quantize.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp - ${MLAS_SRC_DIR}/qladd.cpp - ${MLAS_SRC_DIR}/qlmul.cpp - ${MLAS_SRC_DIR}/qpostprocessor.cpp - ${MLAS_SRC_DIR}/qlgavgpool.cpp - ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/sqnbitgemm.h - ${MLAS_SRC_DIR}/sqnbitgemm.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h - ${MLAS_SRC_DIR}/flashattn.cpp - ${MLAS_SRC_DIR}/cast.cpp -) - -target_sources(onnxruntime_mlas PRIVATE - ${MLAS_INC_DIR}/mlas_float16.h - ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h - ${MLAS_INC_DIR}/mlas_q4.h - ${MLAS_INC_DIR}/mlas_qnbit.h - ${MLAS_INC_DIR}/mlas.h -) - -if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4_dq.cpp - ${MLAS_SRC_DIR}/q4gemm.cpp - ) -endif() - -set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) - -#TODO: set MASM flags properly -function(setup_mlas_source_for_windows) - - # - # Sources common for all platforms. 
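# With the microsoft_mlas FetchContent block added to onnxruntime_common.cmake above,
# MLAS headers are resolved from ${microsoft_mlas_SOURCE_DIR}/include (which is why the
# later hunks in this diff shorten "core/mlas/inc/mlas.h" to "mlas.h"), and consumers link
# whichever targets the fetched project provides via ${ONNXRUNTIME_MLAS_LIBS}. A minimal
# sketch of a consuming target follows; the onnxruntime_example_kernels name and source
# file are hypothetical, the real consumers being the existing onnxruntime_* targets:
add_library(onnxruntime_example_kernels STATIC example_kernel.cc)  # example_kernel.cc can #include "mlas.h" directly
target_link_libraries(onnxruntime_example_kernels PRIVATE ${ONNXRUNTIME_MLAS_LIBS})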
- # - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ) - - #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake - #Don't use it for other platforms. - if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) - set(PREPROCESS_ARMASM_FLAGS "") - set(ARMASM_FLAGS "") - - if(onnxruntime_target_platform STREQUAL "ARM64") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm - ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm - ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm - ) - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm - ) - - string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") - string(APPEND ARMASM_FLAGS " -machine ARM64EC") - endif() - - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - string(APPEND ARMASM_FLAGS " -g") - endif() - - # Remove double quotes from flag strings. - separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") - separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") - - # Run the C precompiler on each input before the assembler. 
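# For a single ARM64EC input such as arm64ec/SgemmKernelNeon.asm, the rule generated by the
# foreach below amounts to the following two steps (paths and the Debug-only -g flag omitted;
# the exact flags come from PREPROCESS_ARMASM_FLAGS and ARMASM_FLAGS at configure time):
#   cl.exe /arm64EC /P SgemmKernelNeon.asm /FiSgemmKernelNeon.i
#   armasm64.exe -machine ARM64EC SgemmKernelNeon.i SgemmKernelNeon.obj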
- foreach(asm_filename ${mlas_platform_preprocess_srcs}) - get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) - set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) - set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) - add_custom_command( - OUTPUT ${obj_filename} - COMMAND - cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} - COMMAND - armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} - DEPENDS ${asm_filename} - BYPRODUCTS ${preprocess_filename} - ) - target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) - endforeach() - elseif(onnxruntime_target_platform STREQUAL "ARM") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ) - elseif(onnxruntime_target_platform STREQUAL "x64") - - file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") - - file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") - - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/dgemm.cpp - ${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/sgemma.asm - ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm - ) - if(MSVC_VERSION GREATER_EQUAL 1933) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm - ) - endif() - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - 
target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - endif() - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm - ) - endif() -endfunction() - -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/wasm_simd/*.cpp" - ) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp - ) - else() - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp" - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -elseif(MSVC) - setup_mlas_source_for_windows() -else() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - - if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) - set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - endif() - foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) - if (OSX_ARCH STREQUAL "arm64") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm64e") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm") - set(ARM TRUE) - elseif (OSX_ARCH STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (OSX_ARCH STREQUAL "i386") - set(X86 TRUE) - endif() - endforeach() - elseif(ANDROID) - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - #Linux/FreeBSD/PowerPC/... - #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` - #Example values: - #arm64v8/ubuntu -> aarch64 - #arm32v6/alpine -> armv7l - #arm32v7/centos -> armv7l - #ppc64le/debian -> ppc64le - #s390x/ubuntu -> s390x - #ppc64le/busybox -> ppc64le - #arm64v8/ubuntu -> aarch64 - #Android: armv7-a aarch64 i686 x86_64 - #chasun: I don't think anyone uses 'arm64' - if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") - set(ARM TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") - set(POWER TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") - set(LOONGARCH64 TRUE) - endif() - endif() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) - set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) - endif() - #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below - #and split MLAS to multiple static libraries. - #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() - set(MLAS_SOURCE_IS_NOT_SET 1) - if(ARM) - enable_language(ASM) - - set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) - enable_language(ASM) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") - if (NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S - ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp - ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - 
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) - set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(POWER AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp - ${MLAS_SRC_DIR}/power/QuantizePower.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") - - check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) - if (HAS_POWER9) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") - endif() - - check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) - if(HAS_POWER10) - set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") - check_cxx_source_compiles(" - #include - int main() { - __vector_quad acc0; - __builtin_mma_xxsetaccz (&acc0); - return 0; - }" - COMPILES_P10 - ) - if(COMPILES_P10) - check_cxx_source_compiles(" - #ifdef _AIX - #define POWER_10 0x40000 - #define POWER_10_ANDUP (POWER_10) - #include - #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) - int main() { - bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); - return 0; - } - #else - #include - int main() { - unsigned long hwcap2 = getauxval(AT_HWCAP2); - bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); - return 0; - } - } - #endif" - HAS_P10_RUNTIME - ) - if (HAS_P10_RUNTIME) - set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - endif() - set(mlas_platform_srcs_power10 - ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") - set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") - set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_power10} - ) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S - ) - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ) - - # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own - # implementation to avoid external dependency. 
- if(ANDROID) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S - ) - endif() - - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - # Forward the flags for the minimum target platform version from the C - # compiler to the assembler. This works around CMakeASMCompiler.cmake.in - # not including the logic to set this flag for the assembler. - set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") - - # The LLVM assembler does not support the .arch directive to enable instruction - # set extensions and also doesn't support AVX-512F instructions without - # turning on support via command-line option. Group the sources by the - # instruction set extension and explicitly set the compiler flag as appropriate. - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S - ) - if(NOT APPLE) - set(mlas_platform_srcs_sse2 - ${mlas_platform_srcs_sse2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S - ) - endif() - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S - ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs_avx2 - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S - ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp - ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ) - if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) - set(mlas_platform_srcs_avx2 - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S - ) - endif() - -message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") -message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") - -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") - message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") -else() - message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") -endif() - set(mlas_platform_srcs_avx512f - 
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - - set(mlas_platform_srcs_avx512core - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") - - set(mlas_platform_srcs_avx512vnni - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${mlas_platform_srcs_avx512f} - ${mlas_platform_srcs_avx512core} - ${mlas_platform_srcs_avx512vnni} - ) - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - if(NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) - set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S - ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S - ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) - 
file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp") - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -endif() - -foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) - - set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") -endforeach() - -if (WIN32) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") - if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") - endif() -endif() - -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - -if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_mlas - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -# set up source group for MLAS source files -block() - set(source_group_srcs) - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - get_target_property(mlas_target_srcs ${mlas_target} SOURCES) - foreach(mlas_target_src ${mlas_target_srcs}) - cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) - if(in_mlas_root) - list(APPEND source_group_srcs ${mlas_target_src}) - endif() - endforeach() - endforeach() - source_group(TREE ${MLAS_ROOT} FILES ${source_group_srcs}) -endblock() - - -if (NOT onnxruntime_ORT_MINIMAL_BUILD) - - # - # Command line tool for quantization and de-quantization of 2-D fp32 tensors - # based on block-wise quantization of int4 - # - - onnxruntime_add_executable(onnxruntime_mlas_q4dq - ${MLAS_SRC_DIR}/q4_dq_cli.cpp - ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") - - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) - endif() - if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${CMAKE_DL_LIBS}) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE Threads::Threads) - - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") - endif() - endif() - -endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 67e5a9c0aa08b..aa0cda722d10e 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1165,7 +1165,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) 
${BENCHMARK_DIR}/gelu.cc ${BENCHMARK_DIR}/activation.cc ${BENCHMARK_DIR}/quantize.cc - ${BENCHMARK_DIR}/reduceminmax.cc ${BENCHMARK_DIR}/layer_normalization.cc) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) @@ -1190,25 +1189,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) add_dependencies(onnxruntime_benchmark ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_benchmark PROPERTIES FOLDER "ONNXRuntimeTest") - SET(MLAS_BENCH_DIR ${TEST_SRC_DIR}/mlas/bench) - file(GLOB_RECURSE MLAS_BENCH_SOURCE_FILES "${MLAS_BENCH_DIR}/*.cpp" "${MLAS_BENCH_DIR}/*.h") - onnxruntime_add_executable(onnxruntime_mlas_benchmark ${MLAS_BENCH_SOURCE_FILES}) - target_include_directories(onnxruntime_mlas_benchmark PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc) - target_link_libraries(onnxruntime_mlas_benchmark PRIVATE benchmark::benchmark onnxruntime_util onnxruntime_framework ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) - target_compile_definitions(onnxruntime_mlas_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) - if(WIN32) - target_link_libraries(onnxruntime_mlas_benchmark PRIVATE debug Dbghelp) - # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. - target_compile_options(onnxruntime_mlas_benchmark PRIVATE /wd26409) - # "Global initializer calls a non-constexpr function." BENCHMARK_CAPTURE macro needs this. - target_compile_options(onnxruntime_mlas_benchmark PRIVATE /wd26426) - else() - target_link_libraries(onnxruntime_mlas_benchmark PRIVATE ${CMAKE_DL_LIBS}) - endif() - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_benchmark PRIVATE cpuinfo) - endif() - set_target_properties(onnxruntime_mlas_benchmark PROPERTIES FOLDER "ONNXRuntimeTest") endif() if(WIN32) @@ -1463,55 +1443,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32) endif() - if(NOT onnxruntime_target_platform STREQUAL "ARM64EC") - file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS - "${TEST_SRC_DIR}/mlas/unittest/*.h" - "${TEST_SRC_DIR}/mlas/unittest/*.cpp" - ) - onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) - if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" - "$<$>:/wd26409>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" - "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" - "$<$>:/wd26426>") - endif() - if(IOS) - set_target_properties(onnxruntime_mlas_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" - ) - endif() - target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} - ${CMAKE_CURRENT_BINARY_DIR}) - target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) - endif() - if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE ${CMAKE_DL_LIBS}) - endif() - if 
(CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) - endif() - if(WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) - set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") - endif() - endif() -endif() + # Training API Tests # Disabling training_api_test_trainer. CXXOPT generates a ton of warnings because of which nuget pipeline is failing. # TODO(askhade): Fix the warnings. diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index 7e64235d3fc3d..7630ed38e5874 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -7,7 +7,7 @@ #include "core/common/narrow.h" #include "core/framework/op_kernel.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 4d435f71cc195..29939540497d2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -9,7 +9,7 @@ #include "core/common/safeint.h" #include "core/platform/threadpool.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc b/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc index a7fa8f111d47e..74a1ec92f5b4e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc +++ b/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc @@ -10,7 +10,7 @@ #include "core/platform/threadpool.h" #include "core/providers/common.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" using onnxruntime::narrow; namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index ca818f09c4b1e..85a0443601ea1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -12,7 +12,7 @@ #include "core/common/safeint.h" #include "core/platform/env_var_utils.h" #include "core/platform/threadpool.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include #include diff --git a/onnxruntime/contrib_ops/cpu/cdist.cc b/onnxruntime/contrib_ops/cpu/cdist.cc index 736dbcfede2fc..527c5ebffb4d4 100644 --- a/onnxruntime/contrib_ops/cpu/cdist.cc +++ b/onnxruntime/contrib_ops/cpu/cdist.cc @@ -7,7 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" using onnxruntime::narrow; namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc 
index c742cd1e95bdd..1e3cd3e88aa9f 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -4,7 +4,7 @@ #include "contrib_ops/cpu/cpu_contrib_kernels.h" #include "core/graph/constants.h" #include "core/framework/int4.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/fused_activation.h b/onnxruntime/contrib_ops/cpu/fused_activation.h index 0121a2038e1cb..40f0f16aa4dd2 100644 --- a/onnxruntime/contrib_ops/cpu/fused_activation.h +++ b/onnxruntime/contrib_ops/cpu/fused_activation.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/matmul_fpq4.cc b/onnxruntime/contrib_ops/cpu/matmul_fpq4.cc index 9bccdf2fe2090..a35f39993ae73 100644 --- a/onnxruntime/contrib_ops/cpu/matmul_fpq4.cc +++ b/onnxruntime/contrib_ops/cpu/matmul_fpq4.cc @@ -13,7 +13,7 @@ #include "core/framework/op_kernel.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index 13748b43b1ae6..0f1ec77574619 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -4,7 +4,7 @@ #include "nchwc_ops.h" #include "core/common/narrow.h" #include "core/common/safeint.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { using ConvPadVector = ConvAttributes::ConvPadVector; diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 2c897f183164f..e0559ab896751 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -9,7 +9,7 @@ #include "core/util/math_cpuonly.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 69eabcfe2654a..34b83968c6fd5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -3,7 +3,7 @@ #include "core/common/narrow.h" #include "core/common/safeint.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/cpu/math/element_wise_ops.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/cpu/quantization/matmul_integer_base.h" diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc index b898c956b6e6a..25a55275ebb16 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -6,7 +6,7 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" #include "dequantize_blockwise_bnb4.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc 
b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 89e96543c4729..642b9960d2655 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -10,9 +10,9 @@ #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" -#include "core/mlas/inc/mlas.h" -#include "core/mlas/inc/mlas_qnbit.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas.h" +#include "mlas_qnbit.h" +#include "mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" diff --git a/onnxruntime/contrib_ops/cpu/quantization/nhwc_max_pool.cc b/onnxruntime/contrib_ops/cpu/quantization/nhwc_max_pool.cc index f3fe781c2b76a..d49bc55500a7b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/nhwc_max_pool.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/nhwc_max_pool.cc @@ -6,7 +6,7 @@ #include "core/providers/cpu/nn/pool_attributes.h" #include "core/common/safeint.h" #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_activations.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_activations.cc index a5ffeaa0d1c3b..eaf6eec52023f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_activations.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_activations.cc @@ -5,7 +5,7 @@ #include "qlinear_lookup_table.h" #include "core/common/narrow.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_binary_op.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_binary_op.cc index c4c738960bf66..feb2064e0e9e2 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_binary_op.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_binary_op.cc @@ -4,7 +4,7 @@ #include "qlinear_binary_op.h" #include "core/providers/cpu/math/element_wise_ops.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc index af163b6be702b..7ef82b32a183f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc @@ -7,7 +7,7 @@ #include "core/common/narrow.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc index e9924bf616eb5..a31011d96be7a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_global_average_pool.cc @@ -7,7 +7,7 @@ #include "core/providers/common.h" #include "core/platform/threadpool.h" #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_lookup_table.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_lookup_table.cc index 694102882ac71..f78519202462b 100644 --- 
a/onnxruntime/contrib_ops/cpu/quantization/qlinear_lookup_table.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_lookup_table.cc @@ -3,7 +3,7 @@ #include "qlinear_lookup_table.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/common.h" namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_pool.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_pool.cc index e0dd9a6ac1009..2c64a69951591 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_pool.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_pool.cc @@ -10,7 +10,7 @@ #include "core/providers/common.h" #include "core/platform/threadpool.h" #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc index de1798e54874f..3db8deffb9cc9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc @@ -13,7 +13,7 @@ #include "core/providers/common.h" #include "core/providers/cpu/tensor/transpose.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_where.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_where.cc index 9c72f65e9c402..c1564af37f751 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_where.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_where.cc @@ -2,7 +2,7 @@ #include "qlinear_lookup_table.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include "core/providers/cpu/math/element_wise_ops.h" diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 67b4950af73bf..4b9a564eaf622 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#include "core/framework/tensor.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/util/math_cpuonly.h" #include "core/providers/common.h" #include "core/platform/threadpool.h" diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index 6c5a02903db3c..7a5127bf81750 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -6,7 +6,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/core/codegen/mti/nn/pool_ops.cc b/onnxruntime/core/codegen/mti/nn/pool_ops.cc index 868a14748cabc..917411b71974f 100644 --- a/onnxruntime/core/codegen/mti/nn/pool_ops.cc +++ b/onnxruntime/core/codegen/mti/nn/pool_ops.cc @@ -4,7 +4,7 @@ #include "core/codegen/mti/nn/pool_ops.h" #include "core/codegen/mti/mti_tvm_utils.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/cpu/nn/pool_attributes.h" #include diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index b6dc8ad56f257..fe676162b4568 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -3,7 +3,7 @@ #include "core/common/safeint.h" #include "core/framework/allocator.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/framework/utils.h" #include "core/session/ort_apis.h" #include diff --git a/onnxruntime/core/framework/transpose_helper.cc b/onnxruntime/core/framework/transpose_helper.cc index 38f68215a0484..5d5692674828e 100644 --- a/onnxruntime/core/framework/transpose_helper.cc +++ b/onnxruntime/core/framework/transpose_helper.cc @@ -4,7 +4,7 @@ #include "core/framework/copy.h" #include "core/framework/element_type_lists.h" #include "core/framework/transpose_helper.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/cpu/tensor/utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9eed0249711f9..55c9ecae32c00 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -18,7 +18,7 @@ #include "core/framework/session_state.h" #include "core/framework/sequential_executor.h" #include "core/framework/tensorprotoutils.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/framework/TensorSeq.h" #include "core/framework/run_options.h" #include "core/session/onnxruntime_run_options_config_keys.h" diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 09a4a77780916..b6065a0ca080a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -12,8 +12,8 @@ #include "core/graph/contrib_ops/attn_lstm_schema_defs.h" #include "core/graph/contrib_ops/range_schema_defs.h" #include "core/graph/op.h" -#include "core/mlas/inc/mlas.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas.h" +#include "mlas_q4.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" #include "onnx/defs/function.h" diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc index 065e2912b566d..742328e4c44d8 100644 --- a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc +++ 
b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc @@ -4,7 +4,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/constants.h" #include "core/graph/contrib_ops/contrib_defs.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace ONNX_NAMESPACE { void convPoolShapeInference( diff --git a/onnxruntime/core/mlas/.clang-format b/onnxruntime/core/mlas/.clang-format deleted file mode 100644 index 16ad8bd8a7234..0000000000000 --- a/onnxruntime/core/mlas/.clang-format +++ /dev/null @@ -1,13 +0,0 @@ ---- - -BasedOnStyle: Google -IndentWidth: 4 -# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. -# Developers are responsible for adhering to the 120 character maximum. -ColumnLimit: 0 -AlignAfterOpenBracket: BlockIndent -AlwaysBreakAfterReturnType: TopLevel -AlwaysBreakTemplateDeclarations: Yes -BinPackParameters: false -BreakBeforeBraces: Linux -... diff --git a/onnxruntime/core/mlas/README.md b/onnxruntime/core/mlas/README.md deleted file mode 100644 index 072795c0fd61c..0000000000000 --- a/onnxruntime/core/mlas/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# About MLAS -MLAS is a compute library containing processor optimized GEMM kernels and platform specific threading code. - -# Unit tests for MLAS -Unit tests for the SGEMM kernels are available under onnxruntime\test\mlas. These tests run over a range of inputs that then execute the various special cases for aligned and unaligned outputs. The tests have failed if any "mismatch" strings are printed. - diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h deleted file mode 100644 index 28ae64c4d5b3e..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas.h +++ /dev/null @@ -1,1871 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas.h - -Abstract: - - This module contains the public data structures and procedure prototypes - for the Microsoft Machine Learning algebra subprogram library. - ---*/ - -#pragma once - -#include -#include -#include -#include - -// -// Define the calling convention for Windows targets. -// - -#if (_MSC_VER >= 800) || defined(_STDCALL_SUPPORTED) -#define MLASCALL __stdcall -#else -#define MLASCALL -#endif - -// -// Define the target architecture. -// - -#if (defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__) -#define MLAS_TARGET_AMD64 -#endif -#if defined(_M_IX86) || defined(__i386__) -#define MLAS_TARGET_IX86 -#endif -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_IX86) -#define MLAS_TARGET_AMD64_IX86 -#endif -#if defined(_M_ARM64) || defined(__aarch64__) -#define MLAS_TARGET_ARM64 -#endif -#if defined(_M_ARM64EC) -#define MLAS_TARGET_ARM64EC -#endif -#if defined(_M_ARM) || defined(__arm__) -#define MLAS_TARGET_ARM -#endif -#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) || defined(MLAS_TARGET_ARM) -#define MLAS_TARGET_ARM_ANY -#endif - -#if defined(__VSX__) -#define MLAS_TARGET_POWER -#endif -#if defined(__wasm__) -#define MLAS_TARGET_WASM -#if defined(__wasm_simd128__) -#define MLAS_TARGET_WASM_SIMD -#else -#define MLAS_TARGET_WASM_SCALAR -#endif -#endif - -#if defined(__loongarch64) -#define MLAS_TARGET_LARCH64 -#endif -// -// Define the support levels for the target architecture. 
-// - -#if defined(MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) -#define MLAS_SUPPORTS_GEMM_DOUBLE -#endif - -#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) -#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) -#if !defined(__APPLE__) -// Had to temporary disable fp16 under APPLE ARM64, as compiling -// the source files require a hardware specific compilation flag. -// When building an universial binary for APPLE, this flag would -// cause trouble for x64 target. - -#define MLAS_F16VEC_INTRINSICS_SUPPORTED - -#endif // -#endif // ARM64 -#endif // Visual Studio 16 or earlier does not support fp16 intrinsic - -// -// Basic Linear Algebra Subprograms (BLAS) types. -// - -#ifndef CBLAS_ENUM_DEFINED_H -#define CBLAS_ENUM_DEFINED_H -typedef enum { CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113 } CBLAS_TRANSPOSE; -typedef enum { CblasUpper=121, CblasLower=122 } CBLAS_UPLO; -typedef enum { CblasNonUnit=131, CblasUnit=132 } CBLAS_DIAG; -typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE; -#endif - -// -// Forward declare the thread pool implementation class and half precision floating point. -// -// N.B. Avoid including ONNX Runtime headers here to keep the dependencies for -// standalone MLAS test executables smaller. -// - -namespace onnxruntime { - namespace concurrency { - class ThreadPool; - }; - struct MLFloat16; -}; // namespace onnxruntime - -using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool; - - -// -// Platform routines. -// - -size_t -MLASCALL -MlasGetPreferredBufferAlignment( - void - ); - -#ifdef MLAS_TARGET_AMD64_IX86 - -/** - * @brief Return whether the current CPU has over saturation problem - * when computing u8s8 matrix multiplication - * https://www.intel.com/content/www/us/en/develop/documentation/onednn-developer-guide-and-reference/top/advanced-topics/nuances-of-int8-computations.html -*/ -bool -MLASCALL -MlasPlatformU8S8Overflow( - void - ); - -#endif - - -// -// Activation routines. -// - -enum MLAS_ACTIVATION_KIND { - MlasIdentityActivation, - MlasReluActivation, - MlasLeakyReluActivation, - MlasTanhActivation, - MlasLogisticActivation, - MlasClipActivation, - MlasHardSigmoidActivation, - MlasActivationKindCount, -}; - -struct MLAS_ACTIVATION { - MLAS_ACTIVATION_KIND ActivationKind; - union { - struct { - float alpha; - } LeakyRelu; - struct { - float minimum; - float maximum; - } Clip; - struct { - float alpha; - float beta; - } HardSigmoid; - float Values[2]; - } Parameters; -}; - -void -MLASCALL -MlasActivation( - const MLAS_ACTIVATION* Activation, - float* Buffer, - const float* Bias, - size_t M, - size_t N, - size_t ldc - ); - -// -// Matrix/matrix multiply routines. -// C := alpha * op(A) * op(B) + beta * C -// op(X) = X or op(X) = transpose(X) or op(X) = conjg(transpose(X)) -// - -/** - * @brief Supply matrices data information to single precision gemm functions - */ -struct MLAS_SGEMM_DATA_PARAMS { - const float* A = nullptr; /**< Supplies the address of matrix A */ - size_t lda = 0; /**< Supplies the first dimension of matrix A. */ - const float* B = nullptr; /**< Supplies the address of matrix B */ - size_t ldb = 0; /**< Supplies the first dimension of matrix B. */ - float* C = nullptr; /**< Supplies the address of matrix C */ - size_t ldc = 0; /**< Supplies the first dimension of matrix C. 
*/ - float alpha = 1.0f; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */ - float beta = 0.0f; /**< Supplies the scalar beta multiplier (see SGEMM definition) */ - bool BIsPacked = false; /**< Whether B is pre-packed */ -}; - -/** - * @brief Batched single precision matrix/matrix multiply operation (SGEMM) - * - * @param TransA Supplies the transpose operation for matrix A. - * @param TransB Supplies the transpose operation for matrix B. - * @param M Supplies the number of rows of matrix A and matrix C. - * @param N Supplies the number of columns of matrix B and matrix C. - * @param K Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param Data A array of matrices data parameters - * @param BatchSize Supplies number of multiplications in this batch - * @param ThreadPool Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - */ -void -MLASCALL -MlasGemmBatch( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_SGEMM_DATA_PARAMS* Data, - size_t BatchSize, - MLAS_THREADPOOL* ThreadPool - ); - -/** - * @brief Single precision matrix/matrix multiply operation (SGEMM) - * - * @param TransA Supplies the transpose operation for matrix A. - * @param TransB Supplies the transpose operation for matrix B. - * @param M Supplies the number of rows of matrix A and matrix C. - * @param N Supplies the number of columns of matrix B and matrix C. - * @param K Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param Data Supplies the matrices data parameters - * @param ThreadPool Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - */ -inline -void -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_SGEMM_DATA_PARAMS& Data, - MLAS_THREADPOOL* ThreadPool - ) -{ - MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); -} - -/** - * @brief Single precision matrix/matrix multiply operation (SGEMM) - * - * @param TransA Supplies the transpose operation for matrix A. - * @param TransB Supplies the transpose operation for matrix B. - * @param M Supplies the number of rows of matrix A and matrix C. - * @param N Supplies the number of columns of matrix B and matrix C. - * @param K Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) - * @param A Supplies the address of matrix A - * @param lda Supplies the first dimension of matrix A. - * @param B Supplies the address of matrix B - * @param ldb Supplies the first dimension of matrix B. - * @param beta Supplies the scalar beta multiplier (see SGEMM definition) - * @param C Supplies the address of matrix C - * @param ldc Supplies the first dimension of matrix C. - * @param ThreadPool Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. 
- */ -inline -void -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -{ - MLAS_SGEMM_DATA_PARAMS Data; - Data.alpha = alpha; - Data.A = A; - Data.lda = lda; - Data.B = B; - Data.ldb = ldb; - Data.beta = beta; - Data.C = C; - Data.ldc = ldc; - - MlasGemm(TransA, TransB, M, N, K, Data, ThreadPool); -} - -/** - * @brief the single precision matrix/matrix multiply operation (SGEMM) with pre-packed B - * - * @param TransA - Supplies the transpose operation for matrix A. - * @param M - Supplies the number of rows of matrix A and matrix C. - * @param N - Supplies the number of columns of matrix B and matrix C. - * @param K - Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - * @param A - Supplies the address of matrix A. - * @param lda - Supplies the first dimension of matrix A. - * @param PackedB - Supplies the address of packed matrix B. - * @param beta - Supplies the scalar beta multiplier (see SGEMM definition). - * @param C - Supplies the address of matrix C. - * @param ldc - Supplies the first dimension of matrix C. - * @param ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - */ -inline -void -MlasGemm( - CBLAS_TRANSPOSE TransA, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const void* PackedB, - float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -{ - MLAS_SGEMM_DATA_PARAMS DataParams; - DataParams.A = A; - DataParams.lda = lda; - DataParams.B = static_cast(PackedB); - DataParams.ldb = 0; - DataParams.C = C; - DataParams.ldc = ldc; - DataParams.alpha = alpha; - DataParams.beta = beta; - DataParams.BIsPacked = true; - - MlasGemmBatch(TransA, - CblasTrans, // deos not matter when B is packed - M, N, K, &DataParams, 1, ThreadPool); -} - -/** - * @brief Supply matrices data information to double precision gemm functions - */ -struct MLAS_DGEMM_DATA_PARAMS { - const double* A = nullptr; /**< Supplies the address of matrix A */ - size_t lda = 0; /**< Supplies the first dimension of matrix A. */ - const double* B = nullptr; /**< Supplies the address of matrix B */ - size_t ldb = 0; /**< Supplies the first dimension of matrix B. */ - double* C = nullptr; /**< Supplies the address of matrix C */ - size_t ldc = 0; /**< Supplies the first dimension of matrix C. */ - double alpha = 1.0; /**< Supplies the scalar alpha multiplier (see SGEMM definition) */ - double beta = 0.0; /**< Supplies the scalar beta multiplier (see SGEMM definition) */ -}; - -/** - * @brief Batched double precision matrix/matrix multiply operation (DGEMM) - * - * @param TransA Supplies the transpose operation for matrix A. - * @param TransB Supplies the transpose operation for matrix B. - * @param M Supplies the number of rows of matrix A and matrix C. - * @param N Supplies the number of columns of matrix B and matrix C. - * @param K Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param Data A array of matrices data parameters - * @param BatchSize Supplies number of multiplications in this batch - * @param ThreadPool Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. 
- */ -void -MLASCALL -MlasGemmBatch( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_DGEMM_DATA_PARAMS* Data, - size_t BatchSize, - MLAS_THREADPOOL* ThreadPool - ); - -/** - * @brief Double precision matrix/matrix multiply operation (DGEMM) - * - * @param TransA Supplies the transpose operation for matrix A. - * @param TransB Supplies the transpose operation for matrix B. - * @param M Supplies the number of rows of matrix A and matrix C. - * @param N Supplies the number of columns of matrix B and matrix C. - * @param K Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param Data Supplies the matrices data parameters - * @param ThreadPool Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - */ -inline -void -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_DGEMM_DATA_PARAMS& Data, - MLAS_THREADPOOL* ThreadPool - ) -{ - MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); -} - -/** - * @brief Double precision matrix/matrix multiply operation (DGEMM) - * - * @param TransA Supplies the transpose operation for matrix A. - * @param TransB Supplies the transpose operation for matrix B. - * @param M Supplies the number of rows of matrix A and matrix C. - * @param N Supplies the number of columns of matrix B and matrix C. - * @param K Supplies the number of columns of matrix A and the number - of rows of matrix B. - * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) - * @param A Supplies the address of matrix A - * @param lda Supplies the first dimension of matrix A. - * @param B Supplies the address of matrix B - * @param ldb Supplies the first dimension of matrix B. - * @param beta Supplies the scalar beta multiplier (see SGEMM definition) - * @param C Supplies the address of matrix C - * @param ldc Supplies the first dimension of matrix C. - * @param ThreadPool Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. 
- */ -inline -void -MlasGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - double alpha, - const double* A, - size_t lda, - const double* B, - size_t ldb, - double beta, - double* C, - size_t ldc, - MLAS_THREADPOOL* ThreadPool - ) -{ - MLAS_DGEMM_DATA_PARAMS Data; - Data.alpha = alpha; - Data.A = A; - Data.lda = lda; - Data.B = B; - Data.ldb = ldb; - Data.beta = beta; - Data.C = C; - Data.ldc = ldc; - MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); -} - -enum class MLAS_QUANTIZATION_GRANULARITY { - PerMatrix, - PerColumn, -}; - -enum class MLAS_QGEMM_OUTPUT_MODE { - ZeroMode, // overwrite the output buffer - AccumulateMode, // accumulate to the output buffer -}; - -class MLAS_QGEMM_OUTPUT_PROCESSOR { -public: - virtual - void - Process( - const int32_t*, // Supplies the address of matrix to process - size_t, // Supplies the start row index of matrix - size_t, // Supplies the start col index of matrix - size_t, // Supplies the element count per row to process - size_t, // Supplies the element count per col to process - size_t // Supplies the leading dimension of matrix - ) const = 0; - - virtual ~MLAS_QGEMM_OUTPUT_PROCESSOR() {} -}; - -class MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR { -public: - MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR( - float* Output, - size_t LeadingDimensionOutput, - const float* Scale, - const float* Bias, - MLAS_QGEMM_OUTPUT_MODE Mode = MLAS_QGEMM_OUTPUT_MODE::ZeroMode, - MLAS_QUANTIZATION_GRANULARITY QuantGran = MLAS_QUANTIZATION_GRANULARITY::PerMatrix) : - Output_(Output), - LeadingDimensionOutput_(LeadingDimensionOutput), - Scale_(Scale), - Bias_(Bias), - OutputMode_(Mode), - QuantGran_(QuantGran) - { - } - - void - Process( - const int32_t* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const override; - -private: - template - inline - void - ProcessImpl( - const int32_t* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const; - -private: - float* Output_; - size_t LeadingDimensionOutput_; - const float* Scale_; - const float* Bias_; - MLAS_QGEMM_OUTPUT_MODE OutputMode_; - MLAS_QUANTIZATION_GRANULARITY QuantGran_; -}; - -/** - * @brief Supply matrices shape and data type information to quantized gemm functions - * - ** NOTE: AIsSigned == true is not supported on non-ARM devices for now. - ** AIsSigned == true is supported on ARM devices when BIsSigned is also true. - * -*/ -struct MLAS_GEMM_QUANT_SHAPE_PARAMS { - size_t M = 0; /**< Supplies the row size of matrix A */ - size_t N = 0; /**< Supplies the column size of matrix B */ - size_t K = 0; /**< Supplies the column size of matrix A and row size of matrix B */ - bool AIsSigned = false; /**< Indicates whether type of A is int8_t or uint8_t.*/ - bool BIsSigned = false; /**< Indicates whether type of B is int8_t or uint8_t */ - bool IsAccumulateMode = false; /**< Indicates whether to accumulate to matrix C or override matrix C */ -}; - -struct MLAS_GEMM_QUANT_DATA_PARAMS { - const uint8_t* A = nullptr; - size_t lda = 0; - uint8_t ZeroPointA = 0; - const void* B = 0; - size_t ldb = 0; - const uint8_t* ZeroPointB = nullptr; - bool BIsPacked = false; - bool PerColumnZeroPoints = false; - int32_t* C = nullptr; - size_t ldc = 0; - const MLAS_QGEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr; -}; - -/** - * @brief Batched GEMM, for multiplying multiple pairs of matrices. 
- * Note: We only support uniform batching, so shapes and types of the - * input must be same: M, N, K, BIsSigned must be the - * same across all parameter blocks. - * - * @param [IN] Shape A single shape descriptor for all the multiplications - * @param [IN] DataParams Array of data descriptors for the matrices. - * @param [IN] BatchN Size of the parameters array, also number of multiplications to perform - * @param [IN] ThreadPool optional thread pool for parallel processing - */ -void -MLASCALL -MlasGemmBatch( - const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, - MLAS_THREADPOOL* ThreadPool - ); - -inline -void -MlasGemm( - const MLAS_GEMM_QUANT_SHAPE_PARAMS &Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS &DataParams, - MLAS_THREADPOOL *ThreadPool) -{ - MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); -} - -// -// Symmetric QGEMM has limited buffer overrun. -// Currently only supported in ARM64 -// -#if defined(MLAS_TARGET_ARM64) -constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30; -#else -constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0; -#endif - -/** - * @brief Supply data parameters for symmetric quantized GEMM. - * B matrix zero point must be zero, and it must be - * pre-packed, with column sums scaled by (-ZeroPointA) -*/ -struct MLAS_SYMM_QGEMM_DATA_PARAMS { - const void* A = nullptr; - size_t lda = 0; - const void* B = 0; - void* C = nullptr; - size_t ldc = 0; - // TODO!! add re-quantization parameters -}; - -/** - * @brief Batched QGEMM. Similar to MlasGemmBatch, but right hand side matrix - * must be symmetrically quantized and prepacked. - * - * @param [IN] Shape A single shape descriptor for all multiplicatons. - Currently A and B must be signed, and accumulation - mode not supported - * @param [IN] DataParams Array of data descriptors, one for each multiplication - * B must be prepacked - * @param [IN] BatchN Number of multiplications - * @param [IN] ThreadPool -*/ -void -MLASCALL -MlasSymmQgemmBatch( - const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape, - const MLAS_SYMM_QGEMM_DATA_PARAMS* DataParams, - const size_t BatchN, - MLAS_THREADPOOL* ThreadPool - ); - - -// -// Buffer packing routines. -// - -size_t -MLASCALL -MlasGemmPackBSize( - size_t N, - size_t K - ); - -void -MLASCALL -MlasGemmPackB( - CBLAS_TRANSPOSE TransB, - size_t N, - size_t K, - const float* B, - size_t ldb, - void* PackedB - ); - -size_t -MLASCALL -MlasGemmPackBSize( - size_t N, - size_t K, - bool AIsSigned, - bool BIsSigned - ); - -void -MLASCALL -MlasGemmPackB( - size_t N, - size_t K, - const uint8_t* B, - size_t ldb, - bool AIsSigned, - bool BIsSigned, - void* PackedB - ); - -/** - * @brief For symmetric quantized GEMM, returns size of the - * packing buffer needed for right hand side - * @param N Number of columns - * @param K Number of rows - * @param AIsSigned Whether left hand size is signed int8_t - * @return size of the packing buffer, - * 0 if operation not supported -*/ -size_t -MLASCALL -MlasSymmQgemmPackBSize( - size_t N, - size_t K, - bool AIsSigned - ); - -void -MLASCALL -MlasSymmQgemmPackB( - size_t N, - size_t K, - const int8_t* B, - size_t ldb, - bool AIsSigned, - int32_t ZeroPointA, - void* PackedB - ); - -// -// Convolution routines. 
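The float packing routines above pair with the packed-B MlasGemm overload shown earlier. Below is a minimal sketch of that flow, assuming row-major matrices and ordinary heap allocation for the packed buffer (real callers may prefer a buffer aligned to MlasGetPreferredBufferAlignment()):

#include <cstddef>
#include <vector>
#include "mlas.h"

void RunPackedSgemm(const float* A, const float* B, float* C,
                    size_t M, size_t N, size_t K,
                    MLAS_THREADPOOL* threadpool)
{
    // Query how many bytes the packed form of B requires.
    size_t packed_size = MlasGemmPackBSize(N, K);
    std::vector<std::byte> packed_b(packed_size);

    // Pack B (row major K x N, leading dimension N, not transposed).
    MlasGemmPackB(CblasNoTrans, N, K, B, /*ldb*/ N, packed_b.data());

    // C = 1.0 * A * packed(B) + 0.0 * C, using the packed-B MlasGemm overload.
    MlasGemm(CblasNoTrans, M, N, K, /*alpha*/ 1.0f, A, /*lda*/ K,
             packed_b.data(), /*beta*/ 0.0f, C, /*ldc*/ N, threadpool);
}

Packing B once amortizes the layout transformation across repeated GEMM calls that reuse the same weights.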
-// - -enum MLAS_CONV_ALGORITHM { - MlasConvAlgorithmGemmDirect, - MlasConvAlgorithmExpandThenGemm, - MlasConvAlgorithmExpandThenGemmSegmented, -#if defined(MLAS_TARGET_WASM_SCALAR) - MlasConvAlgorithmDepthwise, -#endif -}; - -struct MLAS_CONV_PARAMETERS { - const MLAS_ACTIVATION* Activation; - size_t Dimensions; - size_t BatchCount; - size_t GroupCount; - size_t InputChannels; - size_t InputShape[3]; - size_t KernelShape[3]; - size_t DilationShape[3]; - size_t Padding[6]; - size_t StrideShape[3]; - size_t FilterCount; - size_t OutputShape[3]; - size_t InputSize; - size_t OutputSize; - size_t K; - float Beta; - MLAS_CONV_ALGORITHM Algorithm; - ptrdiff_t ThreadCount; - union { - struct { - CBLAS_TRANSPOSE TransB; - size_t ldb; - } GemmDirect; - struct { - size_t ThreadStrideN; - } ExpandThenGemmSegmented; - } u; -}; - -void MLASCALL -MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, - size_t Dimensions, - size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - size_t FilterCount, - const MLAS_ACTIVATION* Activation, - size_t* WorkingBufferSize, - float Beta, - MLAS_THREADPOOL* ThreadPool); - -void -MLASCALL -MlasConv( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - const float* Bias, - float* WorkingBuffer, - float* Output, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasConvDepthwise( - const void* const* Input, - int32_t InputZeroPoint, - bool InputIsSigned, - const void* Filter, - int32_t FilterZeroPoint, - bool FilterIsSigned, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -// -// Symmetric quantized integer convolution routines. -// - -size_t -MlasConvSymPackWSize( - size_t GroupCount, - size_t InputChannels, - size_t OutputChannels, - size_t KernelSize, - bool InputIsSigned - ); - -void -MlasConvSymPackW( - size_t GroupCount, - size_t InputChannels, - size_t OutputChannels, - size_t KernelSize, - const int8_t* W, - int8_t* PackedW, - size_t PackedWSize, - bool InputIsSigned - ); - -int32_t -MlasConvSymFixupInputZeroPoint( - int32_t zero_point_value, - bool InputIsSigned - ); - -// -// Convolution operators (or maybe others in the future) need to do their -// own job partition. Since filters (right hand side B matrix) is usually -// small in size, activations are divided horizontally. We need to provide -// kernel stride units to facilitate the divide. -// - -int32_t -MlasConvSymGetKernelOutputCount( - bool InputIsSigned - ); - -int32_t -MlasConvSymDepthwiseGetKernelOutputCnt( - bool InputIsSigned - ); - -/** - * @brief Returns the stride M of depthwise conv kernel - * - * Most optimized path is Symmetric conv. See - * MlasConvSymDepthwiseGetKernelOutputCnt(bool) - * - * These kernels are implemented in qdwconv.cpp using - * intrincic, all of them with stride val 1. We use - * a slightly bigger value to improve cache reuse. - * - * This needs to be changed if we optimize depthwise - * kernels. 
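A minimal sketch of the two-step float convolution flow declared above: MlasConvPrepare fills the parameter block and reports the working buffer size, then MlasConv executes the convolution. The shapes, the padding layout (begin/end pads per spatial dimension), and the working-buffer unit (float elements) are assumptions for illustration:

#include <cstdint>
#include <vector>
#include "mlas.h"

void Run2dConv(const float* input, const float* filter, const float* bias,
               float* output, MLAS_THREADPOOL* threadpool)
{
    const int64_t input_shape[] = {32, 32};   // H, W
    const int64_t kernel_shape[] = {3, 3};
    const int64_t dilation[] = {1, 1};
    const int64_t padding[] = {1, 1, 1, 1};   // assumed: begin/end pad for H and W
    const int64_t stride[] = {1, 1};
    const int64_t output_shape[] = {32, 32};  // same spatial size with pad=1, stride=1

    MLAS_ACTIVATION activation;
    activation.ActivationKind = MlasIdentityActivation;

    MLAS_CONV_PARAMETERS params;
    size_t working_buffer_size = 0;  // assumed to be a float element count
    MlasConvPrepare(&params, /*Dimensions*/ 2, /*BatchCount*/ 1, /*GroupCount*/ 1,
                    /*InputChannels*/ 16, input_shape, kernel_shape, dilation, padding,
                    stride, output_shape, /*FilterCount*/ 8, &activation,
                    &working_buffer_size, /*Beta*/ 0.0f, threadpool);

    std::vector<float> working_buffer(working_buffer_size);
    MlasConv(&params, input, filter, bias, working_buffer.data(), output, threadpool);
}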
- * - * @return -*/ -inline -int32_t -MlasConvDepthwiseGetKernelOutputCnt() -{ - return 4; -} - -int32_t -MlasSymmQgemmGetKernelOutputCnt(); - -int32_t -MlasQgemmGetKernelOutputCnt( - bool AIsSigned, - bool BIsSigned - ); - - -struct MLAS_CONV_SYM_PARAMS { - const void* InputDirect; - const void* const* InputIndirection; - const void* Filter; - void* Output; - size_t InputChannels; - size_t OutputChannels; - size_t OutputCount; - size_t KernelSize; - const int32_t* Bias; - const float* Scale; - bool PerChannelScale; - int32_t OutputZeroPoint; - bool InputIsSigned; -}; - -void -MlasConvSym( - const MLAS_CONV_SYM_PARAMS& Params - ); - -void -MlasConvSymDepthwise( - const MLAS_CONV_SYM_PARAMS& Params - ); - -// -// Pooling routines. -// - -enum MLAS_POOLING_KIND { - MlasMaximumPooling, - MlasAveragePoolingExcludePad, - MlasAveragePoolingIncludePad, - MlasPoolingKindCount, -}; - -void -MLASCALL -MlasPool( - MLAS_POOLING_KIND PoolingKind, - size_t Dimensions, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - const float* Input, - float* Output, - MLAS_THREADPOOL* ThreadPool - ); - -template -void -MLASCALL -MlasMaximumPool( - const T8Bits* const* Input, - T8Bits* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -// -// Miscellaneous compute routines. -// - -void -MLASCALL -MlasComputeErf( - const float* Input, - float* Output, - size_t N - ); - -void -MLASCALL -MlasComputeExp( - const float* Input, - float* Output, - size_t N - ); - -void -MLASCALL -MlasComputeLogistic( - const float* Input, - float* Output, - size_t N - ); - -void -MLASCALL -MlasComputeSoftmax( - const float* Input, - float* Output, - size_t N, - size_t D, - bool LogSoftmax, - bool SmoothSoftmax, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasComputeTanh( - const float* Input, - float* Output, - size_t N - ); - -// -// Transpose routines. -// - -void -MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const int8_t* Input, - int8_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const float* Input, - float* Output, - size_t M, - size_t N - ); - -// -// Buffer reordering routines. -// - -void -MLASCALL -MlasReorderInputNchw( - const float* S, - float* D, - size_t InputChannels, - size_t InputSize - ); - -void -MLASCALL -MlasReorderInputNhwc( - const float* S, - float* D, - size_t InputChannels, - size_t RowCount, - size_t FullRowCount - ); - -void -MLASCALL -MlasReorderOutputNchw( - const int64_t* OutputShape, - const float* S, - float* D, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasReorderOutputNhwc( - const int64_t* OutputShape, - const float* S, - float* D - ); - -void -MLASCALL -MlasReorderFilterOIHWBiBo( - const int64_t* FilterShape, - const float* S, - float* D - ); - -void -MLASCALL -MlasReorderFilterOIHWBo( - const int64_t* FilterShape, - const float* S, - float* D - ); - -// -// Single precision NCHWc routines. 
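As a small usage sketch of the compute routines above, the call below applies softmax row by row over a contiguous buffer; reading N and D as "N independent rows of length D" is an assumption, since the header comment does not spell it out:

#include <cstddef>
#include "mlas.h"

void SoftmaxRows(const float* logits, float* probs, size_t rows, size_t row_len,
                 MLAS_THREADPOOL* threadpool)
{
    // Plain softmax: no log-softmax, no smoothing term.
    MlasComputeSoftmax(logits, probs, rows, row_len,
                       /*LogSoftmax*/ false, /*SmoothSoftmax*/ false, threadpool);
}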
-// - -size_t -MLASCALL -MlasNchwcGetBlockSize( - void - ); - -void -MLASCALL -MlasNchwcConv( - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - size_t GroupCount, - const float* Input, - const float* Filter, - const float* Bias, - float* Output, - const MLAS_ACTIVATION* Activation, - bool ZeroMode, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasNchwcPool( - MLAS_POOLING_KIND PoolingKind, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - const float* Input, - float* Output, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasNchwcUpsampleNearest( - const int64_t* InputShape, - const int64_t* Scales, - const float* Input, - float* Output - ); - -void -MLASCALL -MlasNchwcUpsampleLinear( - size_t InputHeight, - size_t InputWidth, - size_t OutputWidth, - float InterpolationHeight, - const float* InterpolationWidth, - const float* Input, - float* Output - ); - -// -// Linear quantization routines. -// - -template -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - OutputType* Output, - size_t N, - float Scale, - OutputType ZeroPoint - ); - -void -MLASCALL -MlasQuantizeLinearU4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ); - -void -MLASCALL -MlasQuantizeLinearS4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ); - -/** - * @brief Requantize a block of the intermediate buffer to the output buffer, - * optionally adding the supplied bias - * - * @param Input Input matrix - * @param InputLeadingDimension Input matrix leading dimension - * @param Output Output matrix - * @param OutputLeadingDimension Output matrix leading dimension - * @param Bias Optional bias vector, to be added - to the input before quantization - * @param Scale Quantization scale - * @param PerColumnScale true if scale is per-column - * @param ZeroPoint quantization zero point value - * @param StartM - * @param StartN - * @param CountM - * @param CountN - * @return -*/ -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - OutputType* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - OutputType ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ); - -class MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR : public MLAS_QGEMM_OUTPUT_PROCESSOR -{ - public: - MLAS_QGEMM_REQUANT_OUTPUT_PROCESSOR( - void* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - int32_t ZeroPoint, - bool OutputIsSigned) - : Output_(Output), - OutputLeadingDimension_(OutputLeadingDimension), - Bias_(Bias), - Scale_(Scale), - PerColumnScale_(PerColumnScale), - ZeroPoint_(ZeroPoint), - OutputIsSigned_(OutputIsSigned) - { - } - - void Process(const int32_t* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc) const override - { - if(OutputIsSigned_){ - MlasRequantizeOutput(C, ldc, reinterpret_cast(Output_), OutputLeadingDimension_, - Bias_, Scale_, PerColumnScale_, static_cast(ZeroPoint_), - StartM, StartN, CountM, CountN); - } else { - MlasRequantizeOutput(C, ldc, reinterpret_cast(Output_), OutputLeadingDimension_, - Bias_, Scale_, PerColumnScale_, static_cast(ZeroPoint_), - 
StartM, StartN, CountM, CountN); - } - } - - - private: - void* Output_; - size_t OutputLeadingDimension_; - const int32_t* Bias_; - const float* Scale_; - bool PerColumnScale_; - int32_t ZeroPoint_; - bool OutputIsSigned_; -}; - - -void -MLASCALL -MlasFindMinMaxElement( - const float* Input, - float* Min, - float* Max, - size_t N - ); - -size_t -MLASCALL -MlasQLinearSafePaddingElementCount( - size_t ElementSize, - size_t ElementCount - ); - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - int32_t* AccumulateBuffer - ); - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNhwc( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Batch, - size_t ImageSize, - size_t Stride, - size_t Channels, - int32_t* AccumulateBuffer, - const T8Bits* ZeroBuffer - ); - -// -// InputA is of size N, -// Input B is of size 1 if IsScalarB == true, otherwise it is of size N -// -template -void -MLASCALL -MlasQLinearAdd( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N, - bool IsScalarB - ); - -template -void -MLASCALL -MlasQLinearMul( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N, - bool IsScalarB - ); - -// -// Half precision routines -// - -// Any type with size=2 should work -using MLAS_FP16 = onnxruntime::MLFloat16; - -constexpr size_t FP16_SIZE = sizeof(uint16_t); - -// -// Half-precision floating-point routines. -// - -void -MLASCALL -MlasConvertHalfToFloatBuffer( - const MLAS_FP16* Source, - float* Destination, - size_t Count -); - -void -MLASCALL -MlasConvertFloatToHalfBuffer( -const float* Source, -MLAS_FP16* Destination, -size_t Count -); - - /** - * @brief Whether current CPU supports FP16 acceleration. -*/ -bool MLASCALL -MlasFp16AccelerationSupported(); - -/** - * @brief Interface for half gemm post processors. - * - * Example implementation of this interface includes activations, - * conversion from half precision to single precision, etc. - * - * Half GEMM is computed tile by tile. When a tile of result matrix - * is produced, the method Process() is called to process this tile. - * Parameters of this method describe the location and shape of the - * tile. -*/ -class MLAS_HALF_GEMM_POSTPROCESSOR { -public: - virtual - void - Process( - MLAS_FP16*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {} -}; - -/** - * @brief Half precision activation functions, with optional sum tensor. - * Supplied sum tensor must be the same layout as the GEMM output tensor. - * And the supplied sum tensor will be added to the tensor before activation. 
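A minimal sketch of the half-precision buffer conversion routines above, round-tripping a float vector through fp16; buffer handling is illustrative:

#include <vector>
#include "mlas.h"

void RoundTripFp16(const std::vector<float>& src)
{
    std::vector<MLAS_FP16> half(src.size());      // MLAS_FP16 is onnxruntime::MLFloat16
    std::vector<float> restored(src.size());

    MlasConvertFloatToHalfBuffer(src.data(), half.data(), src.size());
    MlasConvertHalfToFloatBuffer(half.data(), restored.data(), half.size());
    // 'restored' now holds the input values rounded through half precision.
}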
-*/ -class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR -{ - public: - MLAS_HALF_GEMM_ACTIVATION_PROCESSOR( - const MLAS_ACTIVATION& Activation, - const MLAS_FP16* SumBuf = nullptr) - : Activation_(Activation), SumBuf_(SumBuf) - {} - - void Process( - MLAS_FP16* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const override; - - private: - const MLAS_ACTIVATION& Activation_; - const MLAS_FP16* SumBuf_; -}; - -inline -void -MlasFp16Activation( - const MLAS_ACTIVATION* Activation, - MLAS_FP16* Buffer, - size_t M, - size_t N, - size_t ldc - ) -{ - MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(*Activation); - proc.Process(Buffer, 0, 0, M, N, ldc); -} - - -/** - * @brief Convert half gemm result matrix to single precision float matrix -*/ -class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { -public: - MLAS_HALF_GEMM_2FLOAT_PROCESSOR( - const MLAS_ACTIVATION& Activation, - float* Output, /**< address of the output matrix, row major */ - size_t RowStride /**< row stride of the output matrix */ - ) : Activation_(Activation), - Output_(Output), - RowStride_(RowStride) - {} - - void - Process( - MLAS_FP16* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const override; - -private: - const MLAS_ACTIVATION& Activation_; - float* Output_; - const size_t RowStride_; -}; - - -/** - * @brief Data parameters for half precision GEMM routine - * All except C are [in] parameters -*/ -struct MLAS_HALF_GEMM_DATA_PARAMS { - const void* A = nullptr; /**< address of A */ - const void* B = nullptr; /**< address of B */ - const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */ - MLAS_FP16* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ - size_t ldc = 0; /**< leading dimension of C*/ - const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; - bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ - bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/ -}; - -/** - * @brief Half precision Batched GEMM: C = A * B + Bias - * Either A or B can be fp32 or fp16 - * - * Note: We only support uniform batching, so shapes and types of the - * input must be same across all parameter blocks. 
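A minimal sketch of a single half-precision GEMM through the batch entry point described here, with BatchN = 1 and both inputs supplied as fp32 for internal conversion; dense row-major leading dimensions are assumed:

#include <cstddef>
#include "mlas.h"

void HalfGemmFromFloat(const float* A, const float* B, MLAS_FP16* C,
                       size_t M, size_t N, size_t K, MLAS_THREADPOOL* threadpool)
{
    MLAS_HALF_GEMM_DATA_PARAMS params;
    params.A = A;
    params.lda = K;
    params.B = B;
    params.ldb = N;
    params.C = C;
    params.ldc = N;
    params.AIsfp32 = true;   // A supplied as float, converted to fp16 internally
    params.BIsfp32 = true;   // B supplied as float, converted to fp16 internally

    MlasHalfGemmBatch(M, N, K, /*BatchN*/ 1, &params, threadpool);
}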
- * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] ThreadPool - * @return -*/ -void -MLASCALL -MlasHalfGemmBatch( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool = nullptr - ); - -/** - * @brief For half precision GEMM, returns size of the - * packing buffer needed for right hand side - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] float2half Whether the input is float that - * needs to be converted to half precision - * @return size of the packing buffer, - * 0 if operation not supported -*/ -size_t -MLASCALL -MlasHalfGemmPackBSize( - size_t N, - size_t K, - bool float2half - ); - -/** - * @brief For half precision GEMM, pack the right hand - * side matrix B - * - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B - * @param[out] PackedB Address of the packed matrix -*/ -void -MLASCALL -MlasHalfGemmPackB( - size_t N, - size_t K, - const MLAS_FP16* B, - size_t ldb, - void* PackedB - ); - -/** - * @brief For half precision GEMM, convert the float matrix B - * to half precision and pack it into a packing buffer - * - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B - * @param[out] PackedB Address of the packed matrix -*/ -void -MLASCALL -MlasHalfGemmConvertPackB( - size_t N, - size_t K, - const float* B, - size_t ldb, - void* PackedB - ); - -#if defined(__aarch64__) && defined(__linux__) -/** - * @brief Whether current CPU supports Bfloat16(bf16) acceleration. - */ -bool MLASCALL -MlasBf16AccelerationSupported(); - -/** - * @brief Interface for bf16 gemm post processors. - * - * Example implementation of this interface includes activations, - * conversion from single precision to precision, etc. - * - * SBGEMM is computed tile by tile. When a tile of result matrix - * is produced, the method Process() is called to process this tile. - * Parameters of this method describe the location and shape of the - * tile. - */ -class MLAS_SBGEMM_POSTPROCESSOR -{ - public: - virtual void Process(float*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} -}; - -/** - * @brief bfloat16 precision activation functions, with optional sum tensor. - * Supplied sum tensor must be the same layout as the GEMM output tensor. - * And the supplied sum tensor will be added to the tensor before activation. 
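A minimal sketch of preparing a pre-packed B for the half-precision GEMM above, using the size query plus the convert-and-pack routine; buffer handling is illustrative, and a zero size is treated as "packing unsupported". Per the MLAS_HALF_GEMM_DATA_PARAMS comment, a caller would then pass the packed buffer as B with ldb = 0:

#include <cstddef>
#include <vector>
#include "mlas.h"

std::vector<std::byte> PackHalfGemmB(const float* B, size_t N, size_t K)
{
    // float2half = true: the fp32 source is converted to fp16 while packing.
    size_t bytes = MlasHalfGemmPackBSize(N, K, /*float2half*/ true);
    std::vector<std::byte> packed(bytes);
    if (bytes != 0) {
        MlasHalfGemmConvertPackB(N, K, B, /*ldb*/ N, packed.data());
    }
    return packed;
}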
- */ -class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR -{ - public: - MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) - : Activation_(Activation), SumBuf_(SumBuf) - { - } - - void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) - const override; - - private: - const MLAS_ACTIVATION& Activation_; - const float* SumBuf_; -}; - -/** - * @brief Data parameters for bfloat16 precision GEMM routine - * All except C are [in] parameters - */ -struct MLAS_SBGEMM_DATA_PARAMS { - const void* A = nullptr; /**< address of A */ - const void* B = nullptr; /**< address of B */ - const float* Bias = nullptr; /**< address of Bias, vector size N */ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ - size_t ldc = 0; /**< leading dimension of C*/ - const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; - bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ - bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ -}; - -/** - * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias - * Either B can be either fp32 or bf16 - * - * Note: We only support uniform batching, so shapes and types of the - * input must be same across all parameter blocks. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); - -/** - * @brief For bfloat16 precision GEMM, returns size of the - * packing buffer needed for right hand side - * @param[in] N Number of columns - * @param[in] K Number of rows - * @return size of the packing buffer, - * 0 if operation not supported - */ -size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K); - -/** - * @brief For bfloat16 precision GEMM, convert the float matrix B - * to blfoat16 precision and pack it into a packing buffer - * - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B - * @param[out] PackedB Address of the packed matrix - */ -void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); -#endif - -/** - * @brief Indirect Depthwise convolution for fp16 - * @param Input Supplies the indirect buffer for NHWC input - * @param Filter Supplies the address for filter tensor - * @param Bias Supplies the address for 1D bias tensor B, has size of M - * @param Output Supplies the address for the result tensor - * @param Channels # of input channels - * @param OutputCount # of output pixels - * @param KernelSize # kernel size - * @return -*/ -void -MLASCALL -MlasConvDepthwise( - const MLAS_FP16* const* Input, - const MLAS_FP16* Filter, - const MLAS_FP16* Bias, - MLAS_FP16* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize, - MLAS_HALF_GEMM_POSTPROCESSOR* PostProc - ); - - -inline -void -MlasTranspose( - const MLAS_FP16* Input, - MLAS_FP16* Output, - size_t M, - size_t N - ) -{ - MlasTranspose( - 
reinterpret_cast(Input), - reinterpret_cast(Output), - M, N); -} - - -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -/** - * @brief Max Pooling for fp16 NHWC - * @param Input Indirect buffer to activations - * @param Output Address of the result tensor - * @param Channels C in NHWC - * @param OutputCount Number of output pixels - * @param KernelSize Size of the kernel - * @return -*/ -void -MLASCALL -MlasNhwcMaxPool( - const MLAS_FP16* const* Input, - MLAS_FP16* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -/** - * @brief Avg Pooling for fp16 nhwc - * @param Input Indirect buffer to activations - * @param Output Address of the output data - * @param Channels C in NHWC - * @param OutputCount Number of output pixels - * @param KernelSize size of the kernel - * @return -*/ -void -MLASCALL -MlasNhwcAvgPool( - const MLAS_FP16* const* Input, - MLAS_FP16* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -#endif - -struct MlasFlashAttentionThreadedArgs { - int batch_size; - int num_heads; - int q_sequence_length; - int kv_sequence_length; - int qk_head_size; - int v_head_size; - int q_block_size; - int kv_block_size; - float scale; - int thread_count; - float* buffer; - size_t buffer_size_per_thread; - const float* query; - const float* key; - const float* value; - float* output; -}; - -/** - * @brief Per-thread worker function for fp32 Flash Attention - * @param thread_id Thread index - * @param args Arguments - * @return -*/ -void -MLASCALL -MlasFlashAttention( - MlasFlashAttentionThreadedArgs* args, - MLAS_THREADPOOL* ThreadPool -); diff --git a/onnxruntime/core/mlas/inc/mlas_float16.h b/onnxruntime/core/mlas/inc/mlas_float16.h deleted file mode 100644 index 33227ea90d6be..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas_float16.h +++ /dev/null @@ -1,115 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas_float16.h - -Abstract: - - Utilities for half precision floating type conversions. Used internally - by MLAS on platforms without half precision support. Provided here as - convenience for tests or other client libraries/apps. - ---*/ - -#pragma once - -#include -#include -#include - - -using _mlas_fp16_ = uint16_t; - -union fp32_bits { - uint32_t u; - float f; -}; - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) - -/*PreFast told us to convert them to constexpr but the compiler says we can't.*/ -#pragma warning(disable : 26497) - -/*Added whole bunch of casts, still can't get rid of these overflow warnings.*/ -#pragma warning(disable : 26450) -#pragma warning(disable : 26451) -#endif - -inline -_mlas_fp16_ -MLAS_Float2Half(float ff) -{ - constexpr fp32_bits f32infty = {255 << 23}; - constexpr fp32_bits f16max = {(127 + 16) << 23}; - constexpr fp32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; - constexpr uint32_t sign_mask = 0x80000000u; - - auto val = static_cast(0x0u); - fp32_bits f; - f.f = ff; - - uint32_t sign = f.u & sign_mask; - f.u ^= sign; - - if (f.u >= f16max.u) { - // Inf or NaN (all exponent bits set) - val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf - } else { - if (f.u < (113 << 23)) { - // Subnormal or zero - // use a magic value to align our 10 mantissa bits at the bottom of - // the float. as long as FP addition is round-to-nearest-even this - // just works. - f.f += denorm_magic.f; - - // and one integer subtract of the bias later, we have our final float! 
- val = static_cast(f.u - denorm_magic.u); - } else { - uint32_t mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd - - // update exponent, rounding bias part 1 - f.u += ((uint32_t)(15 - 127) << 23) + 0xfff; - // rounding bias part 2 - f.u += mant_odd; - // take the bits! - val = static_cast(f.u >> 13); - } - } - - val |= static_cast(sign >> 16); - return val; -} - -inline -float -MLAS_Half2Float(_mlas_fp16_ val) -{ - constexpr fp32_bits magic = {113 << 23}; - constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift - fp32_bits o; - - o.u = (val & 0x7fff) << 13; // exponent/mantissa bits - uint32_t exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust - - // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // renormalize - } - - o.u |= (val & 0x8000) << 16; // sign bit - return o.f; -} - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h deleted file mode 100644 index 7ea29eb091318..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h +++ /dev/null @@ -1,33 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas_gemm_postprocessor.h - -Abstract: - - This module contains a base class for custom postprocessing following a - GEMM. - ---*/ - -#pragma once - -template -class MLAS_GEMM_POSTPROCESSOR -{ - public: - virtual void Process(T* C, /**< the address of matrix to process */ - size_t RangeStartM, /**< the start row index of matrix */ - size_t RangeStartN, /**< the start col index of matrix */ - size_t RangeCountM, /**< the element count per row to process */ - size_t RangeCountN, /**< the element count per col to process */ - size_t ldc /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_GEMM_POSTPROCESSOR() {} -}; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h deleted file mode 100644 index aec14070ffd55..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ /dev/null @@ -1,437 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas_q4.h - -Abstract: - - This module contains the public data structures and procedure prototypes - for blocked int4 quantization and dequantization. - - Int4 block quantization is used to compress weight tensors of large - language models. - ---*/ - -#pragma once - -#include "mlas.h" -#include "mlas_gemm_postprocessor.h" - -#include -#include - -/** - * @brief Define types of block quantization - */ -typedef enum { - BlkQ4Sym = 0, /*!< int4 Symmetric Block Quantization, zero_point = 0 */ - BlkQ4Zp8 = 1, /*!< int4 Block Quantization, zero_point is int8 type */ - BlkQ4Sym64 = 2, /*!< int4 Symmetric Block Quantization, 64 values per block*/ - BlkQ4Sym128 = 4 /*!< int4 Symmetric Block Quantization, 128 values per block*/ -} MLAS_BLK_QUANT_TYPE; - -/** - * @brief Computes the number of bytes required to pack and int4-quantize - * a weight matrix - * @param QType type of block quantization - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. 
- * @return size of the packing buffer, 0 if the operation is not yet supported. -*/ -size_t -MLASCALL -MlasQ4GemmPackBSize( - MLAS_BLK_QUANT_TYPE QType, - size_t N, - size_t K - ); - -/** - * @brief Prepack and Quantize fp32 weight tensor to int4 blocks - * - * @param QType type of block quantization - * @param PackedBuf destination buffer - * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B -*/ -void -MLASCALL -MlasQ4GemmPackB( - MLAS_BLK_QUANT_TYPE QType, - void* PackedBuf, - const float* FpData, - size_t N, - size_t K, - size_t ldb - ); - - -/** - * @brief Unpack and dequantize from int4 to fp32, reverse operation of - * MlasQ4GemmPackB - * @param QType type of block quantization - * @param FpData destination buffer, the fp32 matrix - * @param PackedBuf int4 quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - */ -void -MLASCALL -MlasQ4GemmUnPackB( - MLAS_BLK_QUANT_TYPE QType, - float* FpData, - const void* PackedBuf, - size_t N, - size_t K, - size_t ldb - ); - - -/** - * @brief Data parameters for Q4 GEMM routine - * C = A * B + Bias - * A must be a float32 matrix - * B must be a quantized and packed int4 blob - * All except C are [in] parameters - */ -struct MLAS_Q4_GEMM_DATA_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/ - const float* Bias = nullptr; /**< address of Bias, vector size N */ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ - const MLAS_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; -}; - -/** - * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix - * B must be a quantized and packed int4 blob - * - * @param[in] QType type of block quantization used in B - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasQ4GemmBatch( - MLAS_BLK_QUANT_TYPE QType, - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool = nullptr - ); - - -/** - * @brief Calculate the buffer size needed for int8 block quantize - * @param[in] QType Type of block quantization used - * @param[in] M Number of rows of the input matrix - * @param[in] K Number of columns of the input matrix - * @return buffer size (in bytes) needed, 0 if not yet supported on current hardware -*/ -size_t -MLASCALL -MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K); - -/** - * @brief Given an input float 2-D matrix, perform blocked int8 quantize - * - * @param QType Type of block quantization used - * @param Qblob Pointer to the output buffer - * @param A Pointer to the float matrix - * @param M Number of rows of the input matrix - * @param K Number of columns of the input matrix - * @param lda leading dimension of the input matrix - * @param ThreadPool -*/ -void -MLASCALL -MlasQ80BlkQuant( - MLAS_BLK_QUANT_TYPE QType, - void* Qblob, - const float* A, - size_t M, - size_t K, - size_t lda, - 
MLAS_THREADPOOL* ThreadPool - ); - - -/** - * @brief Data parameters for Q8Q4 GEMM routine - * C = A * B + Bias - * A must be a block quantized int8 matrix - * B must be a block quantized and packed int4 blob - * All except C are [in] parameters - */ -struct MLAS_Q8Q4_GEMM_DATA_PARAMS { - const void* A = nullptr; /**< address of A (quantized int8 blob)*/ - const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/ - const float* Bias = nullptr; /**< address of Bias, vector size N */ - float* C = nullptr; /**< address of result matrix */ - size_t ldc = 0; /**< leading dimension of C*/ - const MLAS_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; -}; - -/** - * @brief Batched GEMM: C = A * B + Bias - * A must be a quantized int8 blob - * B must be a quantized and packed int4 blob - * - * @param[in] QType type of block quantization used in B - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasQ8Q4GemmBatch( - MLAS_BLK_QUANT_TYPE QType, - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool - ); - - -//////////////////////////////////////////////////////////// -// Blockwise quantization and dequantization where quantization -// parameters are packed into separate buffers. -// - -/** - * @brief For quantization type , and - * matrix shape [rows, columns], compute the shape of the - * quantization parameter matrix [meta_rows, meta_cols] -*/ -template -void -MlasBlockwiseQuantMetaShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& meta_rows, - int& meta_cols - ); - -/** - * @brief For quantization type , and - * matrix shape [rows, columns], compute the shape of the - * quantized matrix [q_rows, q_cols]. The quantized matrix - * is in column major layout, with bits packed on the column. - * - * @tparam T - * @tparam qbits - * @param block_size - * @param columnwise - * @param rows - * @param columns - * @param q_rows - * @param q_cols -*/ -template -void -MlasBlockwiseQuantizedShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& q_rows, - int& q_cols - ); - -/** - * @brief Compute the sizes of the quantized data and quantization parameter buffers. - * - * @param qbits The bit width of each quantized value. - * @param block_size The number of quantized values in a block. - * @param columnwise Whether a block contains values from a matrix column (true) or row (false). - * @param rows Number of matrix rows. - * @param columns Number of matrix columns. - * @param[out] q_data_size_in_bytes The size in bytes of the quantized data. - * @param[out] q_scale_num_elements The size in elements of the scale quantization parameters. - * @param[out] q_zero_point_size_in_bytes The size in bytes of the zero point quantization parameters. Optional. - * - * If the qbits or block_size values are unsupported the output sizes will be zero. 
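A minimal sketch of the fp32 x int4 GEMM flow declared earlier in this header: size and pack B once with MlasQ4GemmPackB, then run a single-batch MlasQ4GemmBatch; dimensions and buffer handling are illustrative:

#include <cstddef>
#include <vector>
#include "mlas.h"

void RunQ4Gemm(const float* A, const float* Bf32, float* C,
               size_t M, size_t N, size_t K, MLAS_THREADPOOL* threadpool)
{
    const MLAS_BLK_QUANT_TYPE qtype = BlkQ4Sym;

    size_t packed_bytes = MlasQ4GemmPackBSize(qtype, N, K);
    if (packed_bytes == 0) {
        return;  // this block-quant type is not supported on the current platform
    }
    std::vector<std::byte> packed_b(packed_bytes);
    MlasQ4GemmPackB(qtype, packed_b.data(), Bf32, N, K, /*ldb*/ N);

    MLAS_Q4_GEMM_DATA_PARAMS params;
    params.A = A;
    params.lda = K;
    params.B = packed_b.data();
    params.C = C;
    params.ldc = N;

    MlasQ4GemmBatch(qtype, M, N, K, /*BatchN*/ 1, &params, threadpool);
}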
- */ -void MLASCALL -MlasBlockwiseQuantizedBufferSizes( - int qbits, - int block_size, - bool columnwise, - int rows, - int columns, - size_t& q_data_size_in_bytes, - size_t& q_scale_num_elements, - size_t* q_zero_point_size_in_bytes -); - - -/** - * @brief Blockwise 4 bits quantization, resulting elements and quantization - * parameters (scales, zero points) are packed into separate matrices - * all in column major layout for faster access during subsequent matrix - * multiplication. - * - * @tparam ElementT type of the input matrix element, usually floating point - * @tparam qbits number of bits used for quantization, 4 for int4 - * - * @param dst points to the quantized matrix, shape [rows, columns] column major - * @param scales points to the scales matrix, column major - * @param zero_points points to the zero_points matrix, column major - * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] - * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point - * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row - * @param rows - * @param columns - * @param leading_dimension - * @param thread_pool -*/ -template -void -MlasQuantizeBlockwise( - uint8_t* dst, - ElementT* scales, - uint8_t* zero_points, - const ElementT* src, - int block_size, - bool columnwise, - int rows, - int columns, - int leading_dimension, - MLAS_THREADPOOL* thread_pool - ); - - -/** - * @brief Blockwise 4 bits dequantization, quantized elements and quantization - * parameters (scales, zero points) are from separate matrices packed - * in column major layout. Output is a floating point matrix in column - * major layout for faster access during subsequent matrix multiplication. - * - * @tparam ElementT type of the dequantized matrix element, usually floating point - * @tparam qbits number of bits used for quantization, 4 for int4 - * - * @param dst points to dequantized matrix shape [rows, columns] column major - * @param src points to quantized matrix, column major - * @param scales points to quantization scales, column major - * @param zero_points points to quantization zero points, column major - * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point - * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row - * @param rows - * @param columns - * @param thread_pool -*/ -template -void -MlasDequantizeBlockwise( - ElementT* dst, - const uint8_t* src, - const ElementT* scales, - const uint8_t* zero_points, - int block_size, - bool columnwise, - int rows, - int columns, - MLAS_THREADPOOL* thread_pool - ); - -/** - * @brief Blockwise 4 bits quantization. After quantization, the weights and zero points - * are packed row-wise. If zero_points is null, quantized type is int4 with default - * zero point 0, to align with DQ schema. Otherwise, quantized type is uint4. - * In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales. 
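A minimal sketch of sizing and running columnwise 4-bit block quantization with the two routines above; the explicit template arguments <float, 4> (element type, then bit width) are an assumption, since the template parameter lists are elided in this excerpt:

#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas.h"

void QuantizeWeights4Bit(const float* weights, int rows, int columns,
                         MLAS_THREADPOOL* threadpool)
{
    const int block_size = 32;

    // Ask MLAS how large the quantized data, scales, and zero points must be.
    size_t data_bytes = 0, scale_elems = 0, zp_bytes = 0;
    MlasBlockwiseQuantizedBufferSizes(/*qbits*/ 4, block_size, /*columnwise*/ true,
                                      rows, columns, data_bytes, scale_elems, &zp_bytes);

    std::vector<uint8_t> q_data(data_bytes);
    std::vector<float> scales(scale_elems);
    std::vector<uint8_t> zero_points(zp_bytes);

    // Quantize the row-major source into column-major packed 4-bit blocks.
    MlasQuantizeBlockwise<float, 4>(q_data.data(), scales.data(), zero_points.data(),
                                    weights, block_size, /*columnwise*/ true,
                                    rows, columns, /*leading_dimension*/ columns, threadpool);
}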
- * @tparam Tin - * @tparam qbits number of bits used for quantization, only 4 is supported - * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] - * @param scales points to the scales matrix, row major - * @param zero_points points to the zero_points matrix, row major - * @param dst points to the quantized matrix, shape [rows, columns] row major in qbits type. - * In uint8_t type, shape is [rows, columns * qbits / 8]. - * @param columnwise true when quantize elements in a column, false when quantize elements in a row. - * @param rows - * @param columns - * @param quant_block_size number of elements in a quantize block - * @param thread_pool - * @return the quantized type is signed. - */ -template -bool -MlasQDQQuantizeBlockwise( - const Tin* src, - Tin* scales, - uint8_t* zero_points, - uint8_t* dst, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero - * points are packed row-wise. The dst tensors are column major. dst weights and zero points - * are packed column-wise. - * dst_weights and dst_zero_points are in uint4. - * If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are - * converted to uint4 by adding 8. - * If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8. - * src_zero_points is 0 and dst_zero_points is 8. - * If src_weights is uint4 and has src_zero_points, just transpose. - * If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with - * 0 values. Otherwise exception is thrown. - * @tparam Tin - * @tparam qbits number of bits used for quantization, only 4 is supported - * @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned - * @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type. - * In uint8_t type, shape is [rows, columns * qbits / 8]. - * @param src_scales points to the scales matrix, row major - * @param src_zero_points points to the zero_points matrix, row major. Packed row-wise. - * @param dst_weights points to the quantized matrix, column major. Packed column-wise. - * @param dst_scales points to the scales matrix, column major - * @param dst_zero_points points to the zero_points matrix, column major. Packed column-wise. - * @param columnwise true when quantize elements in a column, false when quantize elements in a row. - * @param rows - * @param columns - * @param quant_block_size number of elements in a quantize block - * @param thread_pool - */ -template -void -MlasQDQTransposeBlockwiseQuantized( - const uint8_t* src_weights, - const Tin* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - Tin* dst_scales, - uint8_t* dst_zero_points, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h deleted file mode 100644 index 232bf2261ef4c..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ /dev/null @@ -1,201 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas_qnbit.h - -Abstract: - - This module contains the public data structures and procedure prototypes - for blocked n-bit quantized GEMM. 
- - N-bit block quantization is used to compress weight tensors of large - language models. - ---*/ - -#pragma once - -#include "mlas.h" -#include "mlas_gemm_postprocessor.h" - -/** - * @brief Define compute types of block quantization, in order of decreasing accuracy. - */ -typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ - - // special values that should be the first and last actual values - - CompMostAccurate = CompUndef, - CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; - -/** - * @brief Data parameters for float/n-bit quantized int GEMM routine. - */ -struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) - const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C - - ///< optional post processing to apply to result matrix - MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; -}; - -/** - * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix - * B must be a quantized and packed n-bit int matrix - * - * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. - * - * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether - * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with - * MlasSQNBitGemmPackQuantBData(). - * - * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should - * point to an intermediate workspace buffer. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] Workspace Address of intermediate workspace buffer. - If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a - buffer with at least that many bytes. Otherwise, it may be nullptr. - * @param[in] ThreadPool optional thread pool to use - */ -void MLASCALL -MlasSQNBitGemmBatch( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* Workspace, - MLAS_THREADPOOL* ThreadPool = nullptr -); - -/** - * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. 
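Because MLAS_SQNBIT_GEMM_COMPUTE_TYPE above is ordered by decreasing accuracy, one reasonable pattern is to probe from the most to the least accurate type, skipping the CompUndef placeholder, and keep the first one the platform supports. This is a sketch under those assumptions, not code from this tree; the block width and length are examples only.

    // Illustrative selection loop over the compute-type enum declared above.
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE PickComputeType(size_t BlkBitWidth, size_t BlkLen) {
        for (int t = CompMostAccurate + 1; t <= CompLeastAccurate; ++t) {
            const auto type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(t);
            if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, type)) {
                return type;  // first supported entry in decreasing-accuracy order
            }
        }
        return CompUndef;  // no n-bit kernel available for this configuration
    }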
- * - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - */ -bool MLASCALL -MlasIsSQNBitGemmAvailable( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -); - -/** - * @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM - * implementation. If zero, no intermediate workspace is required. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - */ -size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -); - -/** - * @brief Gets the size in bytes of the packed quantized B data. - * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of - * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). - * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to - * MlasSQNBitGemmBatch(). - * - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - */ -size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -); - -/** - * @brief Packs the quantized B data in a format that the kernel expects. - * - * If the function is called without QuantBScale and QuantBZeroPoint, - * it just packs QuantBData into PackedQuantBDataAndOrBlkSum. - * - * If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint - * additional BlkSum (Scale * zeropoint) is computed and stored at the second part of PackedQuantBDataAndOrBlkSum. - * - * Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData, - * QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData, - * and then QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale without QuantBZeroPoint, - * BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum. - * If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint. 
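Putting the routines above together with the packing entry point declared just below, the intended call order is: check availability, size and pack the quantized B data, size the workspace, then run the batch. The sketch below assumes float A/C, 4-bit blocks of length 32, the CompInt8 path, and quantized B tensors already in the blocked layout; all local names are illustrative.

    #include <cstddef>
    #include <vector>
    #include "mlas_qnbit.h"  // assumed include for the declarations above

    void SQNBitGemmSketch(size_t M, size_t N, size_t K,
                          const float* a, float* c,
                          const void* quant_b_data, const float* quant_b_scales,
                          const void* quant_b_zero_points,
                          MLAS_THREADPOOL* thread_pool) {
        const size_t BlkBitWidth = 4, BlkLen = 32;
        const MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompInt8;

        if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
            return;  // a caller would fall back to another GEMM path here
        }

        const size_t packed_b_size =
            MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
        if (packed_b_size == 0) {
            return;  // this sketch only covers the packed-B path described above
        }
        std::vector<std::byte> packed_b(packed_b_size);
        MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                     quant_b_data, packed_b.data(),
                                     quant_b_scales, /*has_zp_input*/ true,
                                     quant_b_zero_points, thread_pool);

        const size_t workspace_size =
            MlasSQNBitGemmBatchWorkspaceSize(M, N, K, /*BatchN*/ 1,
                                             BlkBitWidth, BlkLen, ComputeType);
        std::vector<std::byte> workspace(workspace_size);

        MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
        params.A = a;
        params.lda = K;                        // row-major A, M x K
        params.PackedQuantBData = packed_b.data();
        params.QuantBScale = quant_b_scales;
        params.QuantBZeroPoint = quant_b_zero_points;
        params.C = c;
        params.ldc = N;                        // row-major C, M x N

        MlasSQNBitGemmBatch(M, N, K, /*BatchN*/ 1, BlkBitWidth, BlkLen, ComputeType,
                            &params, workspace_size > 0 ? workspace.data() : nullptr,
                            thread_pool);
    }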
- * - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - * @param[in] QuantBData quantized B data - * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum - * @param[in] QuantBScale quantized B scale - * @param[in] has_zp_input whether QuantBZeroPoint is provided - * @param[in] QuantBZeroPoint quantized B zero point - * @param[in] ThreadPool thread pool to use (no parallel if nullptr) - */ -void MLASCALL -MlasSQNBitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const void* QuantBData, - void* PackedQuantBDataAndOrBlkSum, - const void* QuantBScale, - bool has_zp_input, - const void* QuantBZeroPoint, - MLAS_THREADPOOL* ThreadPool -); diff --git a/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S deleted file mode 100644 index fc7f482b36e54..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch32/QgemmU8X8KernelNeon.S +++ /dev/null @@ -1,657 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelNeon.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "asmmacro.h" - - .syntax unified - .arch armv7-a - .thumb - -// -// Stack frame layout for the U8X8 kernel. -// - - .equ .LGemmU8X8KernelFrame_SavedGeneralRegisters, (7 * 4) - .equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (8 * 8) - .equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedGeneralRegisters + .LGemmU8X8KernelFrame_SavedNeonRegisters - .equ .LGemmU8X8KernelFrame_CountM, 0 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_CountN, 4 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_ldc, 8 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_RowSumBuffer, 12 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 16 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_ZeroPointB, 20 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_ZeroMode, 24 + .LGemmU8X8KernelFrame_SavedRegisters - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (r0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (r1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (r2) - Supplies the address of matrix C. - - PackedCountK (r3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for matrix - A and matrix C. The actual number of rows handled for this invocation - depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to iterate - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A. 
These values have - been pre-scaled by the zero point offset of matrix B if the offset is - per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmU8X8KernelNeon - -// -// Register usage: -// -// q0-q1 (d0-d3) matrix B data -// q2-q3 (d4-d7) matrix A data -// q4 (d8-d9) packed matrix A data -// q5 (d10-d11) RowSumBuffer data -// q6-q7 (d12-d15) ColumnSumBuffer data -// q8-q15 accumulators[4][2] -// - - push {r4,r5,r6,r7,r8,r9,r10} - vpush {d8-d15} - ldr r4,[sp,#.LGemmU8X8KernelFrame_CountM] - ldr r5,[sp,#.LGemmU8X8KernelFrame_ZeroMode] - ldr r7,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] - ldr r8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] - ldr r9,[sp,#.LGemmU8X8KernelFrame_RowSumBuffer] - ldr r10,[sp,#.LGemmU8X8KernelFrame_ldc] - ldr r12,[sp,#.LGemmU8X8KernelFrame_CountN] - vld1.32 {d10-d11},[r9] // load RowSumBuffer - mov r6,r0 - mov r9,r3 - cmp r4,#1 // CountM == 1? - beq .LGemmU8X8.M1.ProcessNextColumnLoop - cmp r4,#4 // CountM < 4? - blo .LGemmU8X8.M2.ProcessNextColumnLoop - -// -// Process 4 rows of the matrices. -// - -.LGemmU8X8.M4.ProcessNextColumnLoop: - vldr d0,[r1] // load packed B0 - mov r0,r6 // reload matrix A - vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer - mov r3,r9 // reload PackedCountK - vmovl.u8 q0,d0 - vdup.32 q9,d10[0] - vdup.32 q11,d10[1] - vdup.32 q13,d11[0] - vdup.32 q15,d11[1] - cbz r7,.LGemmU8X8.M4.SkipScaleByZeroPointB - vld1.32 {d8-d9},[r7]! // load ZeroPointB0 - vmul.u32 q8,q9,q4 - vmul.u32 q10,q11,q4 - vmul.u32 q12,q13,q4 - vmul.u32 q14,q15,q4 - vld1.32 {d8-d9},[r7]! 
// load ZeroPointB1 - vmul.u32 q9,q9,q4 - vmul.u32 q11,q11,q4 - vmul.u32 q13,q13,q4 - vmul.u32 q15,q15,q4 - vldr d8,[r0] // load first packed A0 - vadd.u32 q8,q8,q6 - vadd.u32 q9,q9,q7 - vadd.u32 q10,q10,q6 - vadd.u32 q11,q11,q7 - vldr d9,[r0,#8] // load first packed A1 - vadd.u32 q12,q12,q6 - vadd.u32 q13,q13,q7 - vadd.u32 q14,q14,q6 - vadd.u32 q15,q15,q7 - b .LGemmU8X8.M4.ComputeBlockLoop - -.LGemmU8X8.M4.SkipScaleByZeroPointB: - vldr d8,[r0] // load first packed A0 - vadd.u32 q8,q9,q6 - vadd.u32 q9,q9,q7 - vadd.u32 q10,q11,q6 - vadd.u32 q11,q11,q7 - vldr d9,[r0,#8] // load first packed A1 - vadd.u32 q12,q13,q6 - vadd.u32 q13,q13,q7 - vadd.u32 q14,q15,q6 - vadd.u32 q15,q15,q7 - -.LGemmU8X8.M4.ComputeBlockLoop: - vmovl.u8 q2,d8 - add r0,#16 - vmovl.u8 q3,d9 - vldr d2,[r1,#8] // load packed B1 - vmlal.u16 q8,d0,d4[0] - vmlal.u16 q9,d1,d4[0] - vmlal.u16 q10,d0,d5[0] - vmlal.u16 q11,d1,d5[0] - vmovl.u8 q1,d2 - vmlal.u16 q12,d0,d6[0] - vmlal.u16 q13,d1,d6[0] - vmlal.u16 q14,d0,d7[0] - vmlal.u16 q15,d1,d7[0] - vldr d0,[r1,#16] // load packed B2 - vmlal.u16 q8,d2,d4[1] - vmlal.u16 q9,d3,d4[1] - vmlal.u16 q10,d2,d5[1] - vmlal.u16 q11,d3,d5[1] - vmovl.u8 q0,d0 - vmlal.u16 q12,d2,d6[1] - vmlal.u16 q13,d3,d6[1] - vmlal.u16 q14,d2,d7[1] - vmlal.u16 q15,d3,d7[1] - vldr d2,[r1,#24] // load packed B3 - add r1,#32 - subs r3,#1 - beq .LGemmU8X8.M4.ComputeBlockLoopFinish - vmlal.u16 q8,d0,d4[2] - vmlal.u16 q9,d1,d4[2] - vmlal.u16 q10,d0,d5[2] - vmlal.u16 q11,d1,d5[2] - vmovl.u8 q1,d2 - vldr d8,[r0] // load next packed A0 - vmlal.u16 q12,d0,d6[2] - vmlal.u16 q13,d1,d6[2] - vmlal.u16 q14,d0,d7[2] - vmlal.u16 q15,d1,d7[2] - vldr d0,[r1] // load packed B0 - vmlal.u16 q8,d2,d4[3] - vmlal.u16 q9,d3,d4[3] - vmlal.u16 q10,d2,d5[3] - vmlal.u16 q11,d3,d5[3] - vmovl.u8 q0,d0 - vldr d9,[r0,#8] // load next packed A1 - vmlal.u16 q12,d2,d6[3] - vmlal.u16 q13,d3,d6[3] - vmlal.u16 q14,d2,d7[3] - vmlal.u16 q15,d3,d7[3] - b .LGemmU8X8.M4.ComputeBlockLoop - -.LGemmU8X8.M4.ComputeBlockLoopFinish: - vmlal.u16 q8,d0,d4[2] // finish computing tail vectors - vmlal.u16 q9,d1,d4[2] - add r0,r2,r10,lsl #2 // compute output row 2 - vmlal.u16 q10,d0,d5[2] - vmlal.u16 q11,d1,d5[2] - vmovl.u8 q1,d2 - vmlal.u16 q12,d0,d6[2] - vmlal.u16 q13,d1,d6[2] - vmlal.u16 q14,d0,d7[2] - vmlal.u16 q15,d1,d7[2] - add r3,r0,r10,lsl #2 // compute output row 3 - vmlal.u16 q8,d2,d4[3] - vmlal.u16 q9,d3,d4[3] - vmlal.u16 q10,d2,d5[3] - vmlal.u16 q11,d3,d5[3] - vmlal.u16 q12,d2,d6[3] - vmlal.u16 q13,d3,d6[3] - add r4,r3,r10,lsl #2 // compute output row 4 - vmlal.u16 q14,d2,d7[3] - vmlal.u16 q15,d3,d7[3] - subs r12,#8 // adjust CountN remaining - blo .LGemmU8X8.M4.StoreOutputPartial - cbnz r5,.LGemmU8X8.M4.SkipAccumulateOutput - vld1.32 {d0-d3},[r2] - vld1.32 {d4-d7},[r0] - vadd.u32 q8,q8,q0 - vadd.u32 q9,q9,q1 - vld1.32 {d0-d3},[r3] - vadd.u32 q10,q10,q2 - vadd.u32 q11,q11,q3 - vld1.32 {d4-d7},[r4] - vadd.u32 q12,q12,q0 - vadd.u32 q13,q13,q1 - vadd.u32 q14,q14,q2 - vadd.u32 q15,q15,q3 - -.LGemmU8X8.M4.SkipAccumulateOutput: - vst1.32 {d16-d19},[r2]! - vst1.32 {d20-d23},[r0] - vst1.32 {d24-d27},[r3] - vst1.32 {d28-d31},[r4] - cmp r12,#0 - bne .LGemmU8X8.M4.ProcessNextColumnLoop - -.LGemmU8X8.M4.ExitKernel: - mov r0,#4 // return number of rows handled - vpop {d8-d15} - pop {r4,r5,r6,r7,r8,r9,r10} - bx lr - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -.LGemmU8X8.M4.StoreOutputPartial: - cbz r5,.LGemmU8X8.M4.StoreOutputPartial.AddMode - -.LGemmU8X8.M4.StoreOutputPartial.ZeroMode: - tst r12,#4 - beq .LGemmU8X8.M4.StoreOutputPartial2.ZeroMode - vst1.32 {d16-d17},[r2]! - vmov q8,q9 // shift remaining elements down - vst1.32 {d20-d21},[r0]! - vmov q10,q11 - vst1.32 {d24-d25},[r3]! - vmov q12,q13 - vst1.32 {d28-d29},[r4]! - vmov q14,q15 - -.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode: - tst r12,#2 - beq .LGemmU8X8.M4.StoreOutputPartial1.ZeroMode - vst1.32 {d16},[r2]! - vmov d16,d17 // shift remaining elements down - vst1.32 {d20},[r0]! - vmov d20,d21 - vst1.32 {d24},[r3]! - vmov d24,d25 - vst1.32 {d28},[r4]! - vmov d28,d29 - -.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode: - tst r12,#1 - beq .LGemmU8X8.M4.ExitKernel - vst1.32 d16[0],[r2] - vst1.32 d20[0],[r0] - vst1.32 d24[0],[r3] - vst1.32 d28[0],[r4] - b .LGemmU8X8.M4.ExitKernel - -.LGemmU8X8.M4.StoreOutputPartial.AddMode: - tst r12,#4 - beq .LGemmU8X8.M4.StoreOutputPartial2.AddMode - vld1.32 {d0-d1},[r2] - vld1.32 {d4-d5},[r0] - vadd.u32 q8,q8,q0 - vld1.32 {d0-d1},[r3] - vadd.u32 q10,q10,q2 - vld1.32 {d4-d5},[r4] - vadd.u32 q12,q12,q0 - vadd.u32 q14,q14,q2 - vst1.32 {d16-d17},[r2]! - vmov q8,q9 // shift remaining elements down - vst1.32 {d20-d21},[r0]! - vmov q10,q11 - vst1.32 {d24-d25},[r3]! - vmov q12,q13 - vst1.32 {d28-d29},[r4]! - vmov q14,q15 - -.LGemmU8X8.M4.StoreOutputPartial2.AddMode: - tst r12,#2 - beq .LGemmU8X8.M4.StoreOutputPartial1.AddMode - vld1.32 {d0},[r2] - vld1.32 {d4},[r0] - vadd.u32 d16,d16,d0 - vld1.32 {d0},[r3] - vadd.u32 d20,d20,d4 - vld1.32 {d4},[r4] - vadd.u32 d24,d24,d0 - vadd.u32 d28,d28,d4 - vst1.32 {d16},[r2]! - vmov d16,d17 // shift remaining elements down - vst1.32 {d20},[r0]! - vmov d20,d21 - vst1.32 {d24},[r3]! - vmov d24,d25 - vst1.32 {d28},[r4]! - vmov d28,d29 - -.LGemmU8X8.M4.StoreOutputPartial1.AddMode: - tst r12,#1 - beq .LGemmU8X8.M4.ExitKernel - vld1.32 d0[0],[r2] - vld1.32 d4[0],[r0] - vadd.u32 d16,d16,d0 - vld1.32 d0[0],[r3] - vadd.u32 d20,d20,d4 - vld1.32 d4[0],[r4] - vadd.u32 d24,d24,d0 - vadd.u32 d28,d28,d4 - vst1.32 d16[0],[r2] - vst1.32 d20[0],[r0] - vst1.32 d24[0],[r3] - vst1.32 d28[0],[r4] - b .LGemmU8X8.M4.ExitKernel - -// -// Process 2 rows of the matrices. -// - -.LGemmU8X8.M2.ProcessNextColumnLoop: - vldr d0,[r1] // load packed B0 - mov r0,r6 // reload matrix A - vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer - mov r3,r9 // reload PackedCountK - vmovl.u8 q0,d0 - vdup.32 q9,d10[0] - vdup.32 q11,d10[1] - cbz r7,.LGemmU8X8.M2.SkipScaleByZeroPointB - vld1.32 {d28-d31},[r7]! // load ZeroPointB - vmul.u32 q8,q9,q14 - vmul.u32 q9,q9,q15 - vmul.u32 q10,q11,q14 - vmul.u32 q11,q11,q15 - vld1.32 d8,[r0]! // load first packed A0 - vadd.u32 q8,q8,q6 - vadd.u32 q9,q9,q7 - vadd.u32 q10,q10,q6 - vadd.u32 q11,q11,q7 - b .LGemmU8X8.M2.ComputeBlockLoop - -.LGemmU8X8.M2.SkipScaleByZeroPointB: - vld1.32 d8,[r0]! 
// load first packed A0 - vadd.u32 q8,q9,q6 - vadd.u32 q9,q9,q7 - vadd.u32 q10,q11,q6 - vadd.u32 q11,q11,q7 - -.LGemmU8X8.M2.ComputeBlockLoop: - vmovl.u8 q2,d8 - vldr d2,[r1,#8] // load packed B1 - vmlal.u16 q8,d0,d4[0] - vmlal.u16 q9,d1,d4[0] - vmlal.u16 q10,d0,d5[0] - vmlal.u16 q11,d1,d5[0] - vmovl.u8 q1,d2 - vldr d0,[r1,#16] // load packed B2 - vmlal.u16 q8,d2,d4[1] - vmlal.u16 q9,d3,d4[1] - vmlal.u16 q10,d2,d5[1] - vmlal.u16 q11,d3,d5[1] - vmovl.u8 q0,d0 - vldr d2,[r1,#24] // load packed B3 - add r1,#32 - subs r3,#1 - beq .LGemmU8X8.M2.ComputeBlockLoopFinish - vmlal.u16 q8,d0,d4[2] - vmlal.u16 q9,d1,d4[2] - vmlal.u16 q10,d0,d5[2] - vmlal.u16 q11,d1,d5[2] - vmovl.u8 q1,d2 - vld1.32 d8,[r0]! // load next packed A0 - vldr d0,[r1] // load packed B0 - vmlal.u16 q8,d2,d4[3] - vmlal.u16 q9,d3,d4[3] - vmlal.u16 q10,d2,d5[3] - vmlal.u16 q11,d3,d5[3] - vmovl.u8 q0,d0 - b .LGemmU8X8.M2.ComputeBlockLoop - -.LGemmU8X8.M2.ComputeBlockLoopFinish: - vmlal.u16 q8,d0,d4[2] // finish computing tail vectors - vmlal.u16 q9,d1,d4[2] - add r0,r2,r10,lsl #2 // compute output row 2 - vmlal.u16 q10,d0,d5[2] - vmlal.u16 q11,d1,d5[2] - vmovl.u8 q1,d2 - vmlal.u16 q8,d2,d4[3] - vmlal.u16 q9,d3,d4[3] - vmlal.u16 q10,d2,d5[3] - vmlal.u16 q11,d3,d5[3] - subs r12,#8 // adjust CountN remaining - blo .LGemmU8X8.M2.StoreOutputPartial - cbnz r5,.LGemmU8X8.M2.SkipAccumulateOutput - vld1.32 {d0-d3},[r2] - vld1.32 {d4-d7},[r0] - vadd.u32 q8,q8,q0 - vadd.u32 q9,q9,q1 - vadd.u32 q10,q10,q2 - vadd.u32 q11,q11,q3 - -.LGemmU8X8.M2.SkipAccumulateOutput: - vst1.32 {d16-d19},[r2]! - vst1.32 {d20-d23},[r0] - cmp r12,#0 - bne .LGemmU8X8.M2.ProcessNextColumnLoop - -.LGemmU8X8.M2.ExitKernel: - mov r0,#2 // return number of rows handled - vpop {d8-d15} - pop {r4,r5,r6,r7,r8,r9,r10} - bx lr - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmU8X8.M2.StoreOutputPartial: - cbz r5,.LGemmU8X8.M2.StoreOutputPartial.AddMode - -.LGemmU8X8.M2.StoreOutputPartial.ZeroMode: - tst r12,#4 - beq .LGemmU8X8.M2.StoreOutputPartial2.ZeroMode - vst1.32 {d16-d17},[r2]! - vmov q8,q9 // shift remaining elements down - vst1.32 {d20-d21},[r0]! - vmov q10,q11 - -.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode: - tst r12,#2 - beq .LGemmU8X8.M2.StoreOutputPartial1.ZeroMode - vst1.32 {d16},[r2]! - vmov d16,d17 // shift remaining elements down - vst1.32 {d20},[r0]! - vmov d20,d21 - -.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode: - tst r12,#1 - beq .LGemmU8X8.M2.ExitKernel - vst1.32 d16[0],[r2] - vst1.32 d20[0],[r0] - b .LGemmU8X8.M2.ExitKernel - -.LGemmU8X8.M2.StoreOutputPartial.AddMode: - tst r12,#4 - beq .LGemmU8X8.M2.StoreOutputPartial2.AddMode - vld1.32 {d0-d1},[r2] - vld1.32 {d4-d5},[r0] - vadd.u32 q8,q8,q0 - vadd.u32 q10,q10,q2 - vst1.32 {d16-d17},[r2]! - vmov q8,q9 // shift remaining elements down - vst1.32 {d20-d21},[r0]! - vmov q10,q11 - -.LGemmU8X8.M2.StoreOutputPartial2.AddMode: - tst r12,#2 - beq .LGemmU8X8.M2.StoreOutputPartial1.AddMode - vld1.32 {d0},[r2] - vld1.32 {d4},[r0] - vadd.u32 d16,d16,d0 - vadd.u32 d20,d20,d4 - vst1.32 {d16},[r2]! - vmov d16,d17 // shift remaining elements down - vst1.32 {d20},[r0]! - vmov d20,d21 - -.LGemmU8X8.M2.StoreOutputPartial1.AddMode: - tst r12,#1 - beq .LGemmU8X8.M2.ExitKernel - vld1.32 d0[0],[r2] - vld1.32 d4[0],[r0] - vadd.u32 d16,d16,d0 - vadd.u32 d20,d20,d4 - vst1.32 d16[0],[r2] - vst1.32 d20[0],[r0] - b .LGemmU8X8.M2.ExitKernel - -// -// Process 1 row of the matrices. 
-// - -.LGemmU8X8.M1.ProcessNextColumnLoop: - vldr d0,[r1] // load packed B0 - mov r0,r6 // reload matrix A - vld1.32 {d12-d15},[r8]! // load ColumnSumBuffer - mov r3,r9 // reload PackedCountK - vmovl.u8 q0,d0 - vdup.32 q9,d10[0] - cbz r7,.LGemmU8X8.M1.SkipScaleByZeroPointB - vld1.32 {d28-d31},[r7]! // load ZeroPointB - vmul.u32 q8,q9,q14 - vmul.u32 q9,q9,q15 - vld1.32 d8[0],[r0]! // load first packed A0 - vadd.u32 q8,q8,q6 - vadd.u32 q9,q9,q7 - b .LGemmU8X8.M1.ComputeBlockLoop - -.LGemmU8X8.M1.SkipScaleByZeroPointB: - vld1.32 d8[0],[r0]! // load first packed A0 - vadd.u32 q8,q9,q6 - vadd.u32 q9,q9,q7 - -.LGemmU8X8.M1.ComputeBlockLoop: - vmovl.u8 q2,d8 - vldr d2,[r1,#8] // load packed B1 - vmlal.u16 q8,d0,d4[0] - vmlal.u16 q9,d1,d4[0] - vmovl.u8 q1,d2 - vldr d0,[r1,#16] // load packed B2 - vmlal.u16 q8,d2,d4[1] - vmlal.u16 q9,d3,d4[1] - vmovl.u8 q0,d0 - vldr d2,[r1,#24] // load packed B3 - add r1,#32 - subs r3,#1 - beq .LGemmU8X8.M1.ComputeBlockLoopFinish - vmlal.u16 q8,d0,d4[2] - vmlal.u16 q9,d1,d4[2] - vmovl.u8 q1,d2 - vld1.32 d8[0],[r0]! // load next packed A0 - vldr d0,[r1] // load packed B0 - vmlal.u16 q8,d2,d4[3] - vmlal.u16 q9,d3,d4[3] - vmovl.u8 q0,d0 - b .LGemmU8X8.M1.ComputeBlockLoop - -.LGemmU8X8.M1.ComputeBlockLoopFinish: - vmlal.u16 q8,d0,d4[2] // finish computing tail vectors - vmlal.u16 q9,d1,d4[2] - vmovl.u8 q1,d2 - vmlal.u16 q8,d2,d4[3] - vmlal.u16 q9,d3,d4[3] - subs r12,#8 // adjust CountN remaining - blo .LGemmU8X8.M1.StoreOutputPartial - cbnz r5,.LGemmU8X8.M1.SkipAccumulateOutput - vld1.32 {d0-d3},[r2] - vadd.u32 q8,q8,q0 - vadd.u32 q9,q9,q1 - -.LGemmU8X8.M1.SkipAccumulateOutput: - vst1.32 {d16-d19},[r2]! - cmp r12,#0 - bne .LGemmU8X8.M1.ProcessNextColumnLoop - -.LGemmU8X8.M1.ExitKernel: - mov r0,#1 // return number of rows handled - vpop {d8-d15} - pop {r4,r5,r6,r7,r8,r9,r10} - bx lr - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmU8X8.M1.StoreOutputPartial: - cbz r5,.LGemmU8X8.M1.StoreOutputPartial.AddMode - -.LGemmU8X8.M1.StoreOutputPartial.ZeroMode: - tst r12,#4 - beq .LGemmU8X8.M1.StoreOutputPartial2.ZeroMode - vst1.32 {d16-d17},[r2]! - vmov q8,q9 // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode: - tst r12,#2 - beq .LGemmU8X8.M1.StoreOutputPartial1.ZeroMode - vst1.32 {d16},[r2]! - vmov d16,d17 // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode: - tst r12,#1 - beq .LGemmU8X8.M1.ExitKernel - vst1.32 d16[0],[r2] - b .LGemmU8X8.M1.ExitKernel - -.LGemmU8X8.M1.StoreOutputPartial.AddMode: - tst r12,#4 - beq .LGemmU8X8.M1.StoreOutputPartial2.AddMode - vld1.32 {d0-d1},[r2] - vadd.u32 q8,q8,q0 - vst1.32 {d16-d17},[r2]! - vmov q8,q9 // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial2.AddMode: - tst r12,#2 - beq .LGemmU8X8.M1.StoreOutputPartial1.AddMode - vld1.32 {d0},[r2] - vadd.u32 d16,d16,d0 - vst1.32 {d16},[r2]! - vmov d16,d17 // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial1.AddMode: - tst r12,#1 - beq .LGemmU8X8.M1.ExitKernel - vld1.32 d0[0],[r2] - vadd.u32 d16,d16,d0 - vst1.32 d16[0],[r2] - b .LGemmU8X8.M1.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch32/asmmacro.h b/onnxruntime/core/mlas/lib/aarch32/asmmacro.h deleted file mode 100644 index 72982db00352f..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch32/asmmacro.h +++ /dev/null @@ -1,95 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
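The RowSumBuffer/ColumnSumBuffer arguments of the QGEMM kernel above come from expanding the zero-point-adjusted product so the per-row and per-column corrections can be precomputed. A scalar reference of the same arithmetic, assuming per-tensor zero points (the real kernel consumes precomputed sums with the signs and scaling folded in by the packing code):

    #include <cstddef>
    #include <cstdint>

    // C = (A - ZeroPointA) * (B - ZeroPointB), expanded into a raw u8 dot
    // product plus row/column corrections, as the NEON kernel above applies.
    void QGemmU8X8Reference(const uint8_t* A, const uint8_t* B, int32_t* C,
                            size_t M, size_t N, size_t K,
                            int32_t ZeroPointA, int32_t ZeroPointB) {
        for (size_t m = 0; m < M; ++m) {
            for (size_t n = 0; n < N; ++n) {
                int32_t acc = 0, rowSumA = 0, colSumB = 0;
                for (size_t k = 0; k < K; ++k) {
                    acc += int32_t(A[m * K + k]) * int32_t(B[k * N + n]);
                    rowSumA += int32_t(A[m * K + k]);
                    colSumB += int32_t(B[k * N + n]);
                }
                // sum((A - za)(B - zb)) == acc - zb*rowSumA - za*colSumB + K*za*zb
                C[m * N + n] = acc - ZeroPointB * rowSumA - ZeroPointA * colSumB +
                               int32_t(K) * ZeroPointA * ZeroPointB;
            }
        }
    }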
- -Licensed under the MIT License. - -Module Name: - - asmmacro.h - -Abstract: - - This module implements common macros for the assembly modules. - ---*/ - -/*++ - -Macro Description: - - This macro emits the assembler directives to annotate a new function. - -Arguments: - - FunctionName - Supplies the name of the function. - ---*/ - - .macro FUNCTION_ENTRY FunctionName - - .p2align 2 -#if defined(__APPLE__) - .globl _\FunctionName\() -_\FunctionName\(): -#else - .globl \FunctionName\() - .type \FunctionName\(),%function -\FunctionName\(): -#endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count is greater than or - equal to Value. - -Arguments: - - Count - Supplies the variable used in the comparison. - - Value - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCountGE Count1, Value1, Statement - -.if (\Count1\() >= \Value1\()) - \Statement\() -.endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count1 is greater than or - equal to Value1 and Count2 is greater than or equal to Value2. - -Arguments: - - Count1 - Supplies the variable used in the comparison. - - Value1 - Supplies the static used in the comparison. - - Count2 - Supplies the variable used in the comparison. - - Value2 - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement - -.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) - \Statement\() -.endif - - .endm diff --git a/onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h b/onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h deleted file mode 100644 index 3af76acddbf87..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/AssembleDotProduct.h +++ /dev/null @@ -1,86 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - AssembleDotProduct.h - -Abstract: - - This module contains macros to build Advanced SIMD dot product instructions - for toolchains that do not natively support this newer instruction set - extension. - - This implementation uses ARM v8.4 dot product instructions. - ---*/ - -/*++ - -Macro Description: - - This macro builds a SDOT instruction of the form: - - SDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index] - -Arguments: - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - - Index - Specifies the element index of the second source register. - ---*/ - - .macro SdotByElement DestReg, Src1Reg, Src2Reg, Index - - .set Instruction, 0x4F80E000 - .set Instruction, Instruction + (\DestReg\() << 0) - .set Instruction, Instruction + (\Src1Reg\() << 5) - .set Instruction, Instruction + (\Src2Reg\() << 16) - .set Instruction, Instruction + ((\Index\() & 2) << 10) - .set Instruction, Instruction + ((\Index\() & 1) << 21) - - .inst Instruction - - .endm - -/*++ - -Macro Description: - - This macro builds a UDOT instruction of the form: - - UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index] - -Arguments: - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - - Index - Specifies the element index of the second source register. 
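The SdotByElement macro above (and the UdotByElement macro that follows) hand-assembles the by-element dot-product instruction for toolchains without native support. A host-side sketch of the same bit placement, mirroring the .set arithmetic shown here (0x4F80E000 base for SDOT, 0x6F80E000 for UDOT):

    #include <cstdint>

    // Mirrors the macro arithmetic: SDOT/UDOT Vd.4s, Vn.16b, Vm.4b[Index].
    uint32_t EncodeDotByElement(uint32_t DestReg, uint32_t Src1Reg,
                                uint32_t Src2Reg, uint32_t Index, bool Signed) {
        uint32_t insn = Signed ? 0x4F80E000u : 0x6F80E000u;
        insn += DestReg << 0;         // matches (DestReg << 0) in the macro
        insn += Src1Reg << 5;         // matches (Src1Reg << 5)
        insn += Src2Reg << 16;        // matches (Src2Reg << 16)
        insn += (Index & 2u) << 10;   // matches ((Index & 2) << 10)
        insn += (Index & 1u) << 21;   // matches ((Index & 1) << 21)
        return insn;                  // the value the macro emits via .inst
    }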
- ---*/ - - .macro UdotByElement DestReg, Src1Reg, Src2Reg, Index - - .set Instruction, 0x6F80E000 - .set Instruction, Instruction + (\DestReg\() << 0) - .set Instruction, Instruction + (\Src1Reg\() << 5) - .set Instruction, Instruction + (\Src2Reg\() << 16) - .set Instruction, Instruction + ((\Index\() & 2) << 10) - .set Instruction, Instruction + ((\Index\() & 1) << 21) - - .inst Instruction - - .endm - diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S deleted file mode 100644 index 30b7276340254..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S +++ /dev/null @@ -1,575 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymS8KernelDot.S - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - - .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// - .equ .LConvSymFrame_SavedRegisters, (6 * 8) - .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters - .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters - - .equ .LConvSymPostProcessParams_Bias, 0 - .equ .LConvSymPostProcessParams_Scale, 8 - .equ .LConvSymPostProcessParams_Min, 16 - .equ .LConvSymPostProcessParams_Max, 20 - .equ .LConvSymPostProcessParams_ZeroPoint, 24 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Points to the indirection buffer. Every pointer in the indirection - buffer points at a InputChannels length vector (either from the input tensor - or a vector of padding values). These are grouped in batches of length - KernelSize. These batches are then repeated OutputCount times. - - Filter (x1) - Points to the filter buffer. - - Output (x2) - Points the output buffer. - - KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). - Must be > 1 - - InputChannels (x4/x7) - Number of input channels. - - OutputChannels (x5) - Number of output channels. - - ChannelCount (x6) - Number of output channels this iteration produces. - - OutputCount (x7) - Number of output elements this iteration produces. - - This implementation requires the count to be no larger than 4. - - PostProcessParams (x8) - Points to the post process parameter block. - - KernelFlags - (w10) Additional flags controlling the operation. - -Return Value: - - None. - ---*/ - FUNCTION_ENTRY MlasConvSymS8KernelDot - - stp d8,d9,[sp,#-.LConvSymFrame_SavedRegisters]! - ldr x8,[sp,#.LConvSymFrame_PostProcessParams] - str d10,[sp,#16] - cmp x7,2 // OutputCount < 2 ? - str d11,[sp,#24] - add x16,x2,x5 // x16 -> C1 - str x19,[sp,#32] - lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) - csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 - add x4,x4,3 // InputChannels align to 4 - add x17,x16,x5 // x17 -> C2 - ldr x11,[x8,#.LConvSymPostProcessParams_Bias] - csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 - bic x4,x4,3 - cmp x7,4 // OutputCount < 4 ? - ldr w10,[sp,#.LConvSymFrame_KernelFlags] - add x5,x17,x5 // x5 -> C3 - ldr x19,[x8,#.LConvSymPostProcessParams_Scale] - csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 - - // TODO!! 
tiptoe around loading biases if we need to support - // output channels none divisible by 16 -OutputChannelLoop: - ldp q16,q20,[x11],32 // Init accumulators with biases - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldp q24,q28,[x11],32 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - mov v30.16b,v28.16b - mov v31.16b,v28.16b - mov x9,x3 // restore KernelSize * sizeof(int8_t*) - -KernelSizeLoop: - ldr x12,[x0] // x12 -> A0 - cmp x16,x2 - b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 - cmp x17,x16 - lsl x14,x3,#1 - ldr x13,[x0,x3] // x13 -> A1 - b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 - cmp x5,x17 - add x15,x3,x3,lsl#1 - ldr x14,[x0,x14] // x14 -> A2 - b.eq SkipLoadA3 // C3==C2 -> A2=A3 - ldr x15,[x0,x15] // x15 -> A3 - b FinishLoadAPtr -SkipLoadA1: - mov x13,x12 -SkipLoadA2: - mov x14,x13 -SkipLoadA3: - mov x15,x14 - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... | -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 -// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 -// ------------------ ---------------------------------------------------------------------------- - -FinishLoadAPtr: - subs x7,x4,16 // Need 16 input channels for loop - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - b.lo InChannels8 - - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - subs x7,x7,16 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - ldr q6,[x1],16 - ldr q7,[x1],16 - b.lo InChLoopEpilogue // Need 32 input channels for main loop - -InputChannelLoop: - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldr d8,[x12],8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x13],8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x14],8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x15],8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,0 - SdotByElement 17, 4, 9,0 - ldr d0,[x12],8 - SdotByElement 18, 4,10,0 - 
SdotByElement 19, 4,11,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - ldr d1,[x13],8 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - ldr d2,[x14],8 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - ldr d3,[x15],8 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - subs x7,x7,16 // InputChannels -= 16 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - ldr q7,[x1],16 - b.hs InputChannelLoop - -InChLoopEpilogue: - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldr d8,[x12],8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x13],8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x14],8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x15],8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,0 - SdotByElement 17, 4, 9,0 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - tst x7,15 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - b.ne InChannels8 // 4 ~ 12 InputChannels - subs x9,x9,8 // KernelSize-=1 - b.hi KernelSizeLoop - -Requantize: - tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ldr w13,[x8,#.LConvSymPostProcessParams_ZeroPoint] - beq BroadcastScaleValue - ldp q0,q1,[x19],32 // load scale vector - ldp q2,q3,[x19],32 - b AccumulatorsToFloat - -BroadcastScaleValue: - ld1r {v0.4s},[x19] // load scale Value - mov v1.16b, v0.16b - mov v2.16b, v0.16b - mov v3.16b, v0.16b - -AccumulatorsToFloat: - scvtf 
v16.4s,v16.4s // convert to float - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - scvtf v24.4s,v24.4s - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - fmul v16.4s,v16.4s,v0.4s // multiply by scale - fmul v17.4s,v17.4s,v0.4s - fmul v18.4s,v18.4s,v0.4s - fmul v19.4s,v19.4s,v0.4s - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v1.4s - fmul v22.4s,v22.4s,v1.4s - fmul v23.4s,v23.4s,v1.4s - fmul v24.4s,v24.4s,v2.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v2.4s - fmul v27.4s,v27.4s,v2.4s - fmul v28.4s,v28.4s,v3.4s - fmul v29.4s,v29.4s,v3.4s - fmul v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v3.4s - fcvtns v16.4s,v16.4s // convert to int - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - fcvtns v24.4s,v24.4s - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - - sqxtn v16.4h,v16.4s - sqxtn v17.4h,v17.4s - sqxtn v18.4h,v18.4s - sqxtn v19.4h,v19.4s - sqxtn v24.4h,v24.4s - sqxtn v25.4h,v25.4s - sqxtn v26.4h,v26.4s - sqxtn v27.4h,v27.4s - dup v4.8h,w13 // zero point - sqxtn2 v16.8h,v20.4s - sqxtn2 v17.8h,v21.4s - sqxtn2 v18.8h,v22.4s - sqxtn2 v19.8h,v23.4s - sqxtn2 v24.8h,v28.4s - sqxtn2 v25.8h,v29.4s - sqxtn2 v26.8h,v30.4s - sqxtn2 v27.8h,v31.4s - sqadd v16.8h,v16.8h,v4.8h - sqadd v17.8h,v17.8h,v4.8h - sqadd v18.8h,v18.8h,v4.8h - sqadd v19.8h,v19.8h,v4.8h - sqadd v24.8h,v24.8h,v4.8h - sqadd v25.8h,v25.8h,v4.8h - sqadd v26.8h,v26.8h,v4.8h - sqadd v27.8h,v27.8h,v4.8h - sqxtn v0.8b,v16.8h - sqxtn v1.8b,v17.8h - sqxtn v2.8b,v18.8h - sqxtn v3.8b,v19.8h - sqxtn2 v0.16b,v24.8h - sqxtn2 v1.16b,v25.8h - subs x6,x6,16 // processed 16 output channels - sqxtn2 v2.16b,v26.8h - sqxtn2 v3.16b,v27.8h - b.lo PartialStore - - st1 {v3.16b},[x5],16 // Store full 4 x 16 - st1 {v2.16b},[x17],16 - sub x0,x0,x3 // Restore pointer to A: a -= ks - st1 {v1.16b},[x16],16 - st1 {v0.16b},[x2],16 - b.hi OutputChannelLoop - -ExitKernel: - ldr x19,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LConvSymFrame_SavedRegisters - ret - -InChannels8: - tbz x7,3,InChannels4 - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldp q4, q5, [x1], 32 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - tbz x7,2,SkipInCh4 - -InChannels4: - ldr s0,[x12],4 - ldr q4,[x1],16 - ldr s1,[x13],4 - ldr s2,[x14],4 - ldr 
s3,[x15],4 - ldr q5,[x1],16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - -SkipInCh4: - subs x9,x9,8 // ks -= 1 - b.hi KernelSizeLoop - b Requantize - -PartialStore: - tbz x6,3,LT8Store - str d3,[x5],8 // no less than 8 channels - str d2,[x17],8 - dup d3,v3.d[1] - dup d2,v2.d[1] - str d1,[x16],8 - str d0,[x2],8 - dup d1,v1.d[1] - dup d0,v0.d[1] -LT8Store: - tbz x6,2,LT4Store - str s3,[x5],4 - str s2,[x17],4 - dup s3,v3.s[1] - dup s2,v2.s[1] - str s1,[x16],4 - str s0,[x2],4 - dup s1,v1.s[1] - dup s0,v0.s[1] -LT4Store: - tbz x6,1, LT2Store - str h3,[x5],2 - str h2,[x17],2 - dup h3,v3.h[1] - dup h2,v2.h[1] - str h1,[x16],2 - str h0,[x2],2 - dup h1,v1.h[1] - dup h0,v0.h[1] -LT2Store: - tbz x6,0,ExitKernel - str b3,[x5] - str b2,[x17] - str b1,[x16] - str b0,[x2] - b ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDotLd64.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDotLd64.S deleted file mode 100644 index 3e03ff7b42ec4..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDotLd64.S +++ /dev/null @@ -1,653 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymS8KernelDotLd64.S - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - - .equ .LMLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 - .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// - .equ .LConvSymFrame_SavedRegisters, (10 * 8) - .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters - .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters - - .equ .LConvSymPostProcessParams_Bias, 0 - .equ .LConvSymPostProcessParams_Scale, 8 - .equ .LConvSymPostProcessParams_Min, 16 - .equ .LConvSymPostProcessParams_Max, 20 - .equ .LConvSymPostProcessParams_ZeroPoint, 24 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Points to the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Points to the filter buffer. - - Output (x2) - Points the output buffer. - - KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (x4/x7) - Number of input channels. - - OutputChannels (x5) - Number of output channels. - - ChannelCount (x6) - Number of output channels this iteration produces. 
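The indirection buffer consumed by this kernel (and by MlasConvSymS8KernelDot above) is laid out as OutputCount batches of KernelSize pointers, each pointing at an InputChannels-long slice of the input tensor or at a shared padding vector. A conceptual walk over that table, for orientation only; the real kernel multiplies the slices by filter values and accumulates per output channel.

    #include <cstddef>
    #include <cstdint>

    int64_t WalkIndirectionBuffer(const int8_t* const* indirection,
                                  size_t OutputCount, size_t KernelSize,
                                  size_t InputChannels) {
        int64_t sum = 0;
        for (size_t out = 0; out < OutputCount; ++out) {
            for (size_t k = 0; k < KernelSize; ++k) {
                // One pointer per kernel position, repeated for every output pixel.
                const int8_t* slice = indirection[out * KernelSize + k];
                for (size_t c = 0; c < InputChannels; ++c) {
                    sum += slice[c];
                }
            }
        }
        return sum;
    }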
- - OutputCount (x7) - Number of output elements this iteration produces. - - This implementation requires the count to be no larger than 4. - - PostProcessParams (x8) - Points to the post process parameter block. - - KernelFlags - (w10) Additional flags controlling the operation. - -Return Value: - - None. - ---*/ - FUNCTION_ENTRY MlasConvSymS8KernelDotLd64 - - stp d8,d9,[sp,#-.LConvSymFrame_SavedRegisters]! - ldr x8,[sp,#.LConvSymFrame_PostProcessParams] - str d10,[sp,#16] - cmp x7,2 // OutputCount < 2 ? - str d11,[sp,#24] - add x16,x2,x5 // x16 -> C1 - str x19,[sp,#32] - lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) - str x20,[sp,#40] - csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 - str x21,[sp,#48] - add x4,x4,3 // InputChannels align to 4 - str x22,[sp,#56] - add x17,x16,x5 // x17 -> C2 - str x23,[sp,#64] - ldr x11,[x8,#.LConvSymPostProcessParams_Bias] - csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 - bic x4,x4,3 - cmp x7,4 // OutputCount < 4 ? - ldr w10,[sp,#.LConvSymFrame_KernelFlags] - add x5,x17,x5 // x5 -> C3 - ldr x19,[x8,#.LConvSymPostProcessParams_Scale] - csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 - - // TODO!! tiptoe around loading biases if we need to support - // output channels none divisible by 16 -OutputChannelLoop: - ldp q16,q20,[x11],32 // Init accumulators with biases - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldp q24,q28,[x11],32 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - mov v30.16b,v28.16b - mov v31.16b,v28.16b - mov x9,x3 // restore KernelSize * sizeof(int8_t*) - -KernelSizeLoop: - ldr x12,[x0] // x12 -> A0 - cmp x16,x2 - b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 - cmp x17,x16 - lsl x14,x3,#1 - ldr x13,[x0,x3] // x13 -> A1 - b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 - cmp x5,x17 - add x15,x3,x3,lsl#1 - ldr x14,[x0,x14] // x14 -> A2 - b.eq SkipLoadA3 // C3==C2 -> A2=A3 - ldr x15,[x0,x15] // x15 -> A3 - b FinishLoadAPtr -SkipLoadA1: - mov x13,x12 -SkipLoadA2: - mov x14,x13 -SkipLoadA3: - mov x15,x14 - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 -// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 -// ------------------ ---------------------------------------------------------------------------- - -FinishLoadAPtr: - subs x7,x4,16 // Need 16 input channels for loop - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - b.lo InChannels8 - - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - subs x7,x7,16 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr d5,[x1],#8 - ldr x21,[x1],#8 - ldr d6,[x1],#8 - ldr x22,[x1],#8 - ldr d7,[x1],#8 - b.lo InChLoopEpilogue // Need 32 input channels for main loop - -InputChannelLoop: - SdotByElement 16, 4, 0,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,0 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,0 - ldr d8,[x12],8 - SdotByElement 19, 4, 3,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,0 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,0 - ldr d9,[x13],8 - SdotByElement 23, 5, 3,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,0 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,0 - ldr d10,[x14],8 - SdotByElement 27, 6, 3,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,0 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,0 - ldr d11,[x15],8 - SdotByElement 31, 7, 3,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,1 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,1 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,1 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,1 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,1 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,1 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,1 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,0 - ins v5.d[1],x21 - SdotByElement 18, 4,10,0 - ldr d0,[x12],8 - SdotByElement 19, 4,11,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 9,0 - ins v6.d[1],x22 - SdotByElement 22, 5,10,0 - ldr d1,[x13],8 - SdotByElement 23, 5,11,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 9,0 - ins v7.d[1],x23 - SdotByElement 26, 6,10,0 - ldr d2,[x14],8 - SdotByElement 27, 6,11,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 9,0 - ins v4.d[1],x20 - SdotByElement 30, 7,10,0 - ldr d3,[x15],8 - SdotByElement 31, 7,11,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,1 - ins v5.d[1],x21 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,1 - ldr x20,[x1],#8 - SdotByElement 21, 5, 9,1 - ins v6.d[1],x22 - 
SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,1 - ldr x21,[x1],#8 - SdotByElement 25, 6, 9,1 - ins v7.d[1],x23 - SdotByElement 26, 6,10,1 - subs x7,x7,16 // InputChannels -= 16 - SdotByElement 27, 6,11,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,1 - ldr x22,[x1],#8 - SdotByElement 29, 7, 9,1 - ins v4.d[1],x20 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - ldr d7,[x1],#8 - b.hs InputChannelLoop - -InChLoopEpilogue: - SdotByElement 16, 4, 0,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,0 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,0 - ldr d8,[x12],8 - SdotByElement 19, 4, 3,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,0 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,0 - ldr d9,[x13],8 - SdotByElement 23, 5, 3,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,0 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,0 - ldr d10,[x14],8 - SdotByElement 27, 6, 3,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,0 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,0 - ldr d11,[x15],8 - SdotByElement 31, 7, 3,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,1 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,1 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,1 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,1 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,1 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,1 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,1 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,0 - ins v5.d[1],x21 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 9,0 - ins v6.d[1],x22 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 9,0 - ins v7.d[1],x23 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 9,0 - ins v4.d[1],x20 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,1 - ins v5.d[1],x21 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - ins v6.d[1],x22 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - ins v7.d[1],x23 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - - tst x7,15 - b.ne InChannels8 // 4 ~ 12 InputChannels - - subs x9,x9,8 // KernelSize-=1 - b.hi KernelSizeLoop - -Requantize: - tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ldr w13,[x8,#.LConvSymPostProcessParams_ZeroPoint] - beq BroadcastScaleValue - ldp q0,q1,[x19],32 // load scale vector - ldp q2,q3,[x19],32 - b AccumulatorsToFloat - -BroadcastScaleValue: - ld1r {v0.4s},[x19] // load scale Value - mov v1.16b, v0.16b - mov v2.16b, v0.16b - mov v3.16b, v0.16b - -AccumulatorsToFloat: - scvtf v16.4s,v16.4s // convert to float - scvtf v17.4s,v17.4s - scvtf 
v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - scvtf v24.4s,v24.4s - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - fmul v16.4s,v16.4s,v0.4s // multiply by scale - fmul v17.4s,v17.4s,v0.4s - fmul v18.4s,v18.4s,v0.4s - fmul v19.4s,v19.4s,v0.4s - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v1.4s - fmul v22.4s,v22.4s,v1.4s - fmul v23.4s,v23.4s,v1.4s - fmul v24.4s,v24.4s,v2.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v2.4s - fmul v27.4s,v27.4s,v2.4s - fmul v28.4s,v28.4s,v3.4s - fmul v29.4s,v29.4s,v3.4s - fmul v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v3.4s - fcvtns v16.4s,v16.4s // convert to int - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - fcvtns v24.4s,v24.4s - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - - sqxtn v16.4h,v16.4s - sqxtn v17.4h,v17.4s - sqxtn v18.4h,v18.4s - sqxtn v19.4h,v19.4s - sqxtn v24.4h,v24.4s - sqxtn v25.4h,v25.4s - sqxtn v26.4h,v26.4s - sqxtn v27.4h,v27.4s - dup v4.8h,w13 // zero point - sqxtn2 v16.8h,v20.4s - sqxtn2 v17.8h,v21.4s - sqxtn2 v18.8h,v22.4s - sqxtn2 v19.8h,v23.4s - sqxtn2 v24.8h,v28.4s - sqxtn2 v25.8h,v29.4s - sqxtn2 v26.8h,v30.4s - sqxtn2 v27.8h,v31.4s - sqadd v16.8h,v16.8h,v4.8h - sqadd v17.8h,v17.8h,v4.8h - sqadd v18.8h,v18.8h,v4.8h - sqadd v19.8h,v19.8h,v4.8h - sqadd v24.8h,v24.8h,v4.8h - sqadd v25.8h,v25.8h,v4.8h - sqadd v26.8h,v26.8h,v4.8h - sqadd v27.8h,v27.8h,v4.8h - sqxtn v0.8b,v16.8h - sqxtn v1.8b,v17.8h - sqxtn v2.8b,v18.8h - sqxtn v3.8b,v19.8h - sqxtn2 v0.16b,v24.8h - sqxtn2 v1.16b,v25.8h - subs x6,x6,16 // processed 16 output channels - sqxtn2 v2.16b,v26.8h - sqxtn2 v3.16b,v27.8h - b.lo PartialStore - - st1 {v3.16b},[x5],16 // Store full 4 x 16 - st1 {v2.16b},[x17],16 - sub x0,x0,x3 // Restore pointer to A: a -= ks - st1 {v1.16b},[x16],16 - st1 {v0.16b},[x2],16 - b.hi OutputChannelLoop - -ExitKernel: - ldr x23,[sp,#64] - ldp x21,x22,[sp,#48] - ldp x19,x20,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LConvSymFrame_SavedRegisters - ret - -InChannels8: - tbz x7,3,InChannels4 - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldp q4, q5, [x1], 32 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - tbz x7,2,SkipInCh4 - -InChannels4: - ldr s0,[x12],4 - ldr q4,[x1],16 - ldr s1,[x13],4 - ldr s2,[x14],4 - ldr s3,[x15],4 - ldr q5, 
[x1], 16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - -SkipInCh4: - subs x9,x9,8 // ks -= 1 - b.hi KernelSizeLoop - b Requantize - -PartialStore: - tbz x6,3,LT8Store - str d3,[x5],8 // no less than 8 channels - str d2,[x17],8 - dup d3,v3.d[1] - dup d2,v2.d[1] - str d1,[x16],8 - str d0,[x2],8 - dup d1,v1.d[1] - dup d0,v0.d[1] -LT8Store: - tbz x6,2,LT4Store - str s3,[x5],4 - str s2,[x17],4 - dup s3,v3.s[1] - dup s2,v2.s[1] - str s1,[x16],4 - str s0,[x2],4 - dup s1,v1.s[1] - dup s0,v0.s[1] -LT4Store: - tbz x6,1, LT2Store - str h3,[x5],2 - str h2,[x17],2 - dup h3,v3.h[1] - dup h2,v2.h[1] - str h1,[x16],2 - str h0,[x2],2 - dup h1,v1.h[1] - dup h0,v0.h[1] -LT2Store: - tbz x6,0,ExitKernel - str b3,[x5] - str b2,[x17] - str b1,[x16] - str b0,[x2] - b ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S deleted file mode 100644 index 9f623ee7b27d8..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S +++ /dev/null @@ -1,423 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymS8KernelNeon.S - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "asmmacro.h" - - .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// - .equ .LConvSymFrame_SavedNeonRegisters, (8 * 8) - .equ .LConvSymFrame_SavedRegisters, .LConvSymFrame_SavedNeonRegisters - .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters - .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters - - .equ .LConvSymPostProcessParams_Bias, 0 - .equ .LConvSymPostProcessParams_Scale, 8 - .equ .LConvSymPostProcessParams_Min, 16 - .equ .LConvSymPostProcessParams_Max, 20 - .equ .LConvSymPostProcessParams_ZeroPoint, 24 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the indirect buffer. Every pointer in - the indirection buffer points at a InputChannels length vector (either - from the input tensor or a vector of padding values). These are grouped - in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. Must be > 1 - - InputChannels (x4) - Supplies the number of input channels. - - This implementation requires the count to be a multiple of 8. - - OutputChannels (x5) - Supplies the number of output channels. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 8. - - OutputCount (x7) - Supplies the number of output elements this iteration produces. 
- - This implementation requires the count to be 1 or 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - FUNCTION_ENTRY MlasConvSymS8KernelNeon - - stp d8,d9,[sp,#-.LConvSymFrame_SavedRegisters]! - ldr x8,[sp,#.LConvSymFrame_PostProcessParams] - ldrb w10,[sp,#.LConvSymFrame_KernelFlags] - stp d10,d11,[sp,#16] - stp d12,d13,[sp,#32] - stp d14,d15,[sp,#48] - mov x9,x3 // save kernel size - ldr x11,[x8,#.LConvSymPostProcessParams_Bias] - mov x16,x4 // save input channels - ldr x12,[x8,#.LConvSymPostProcessParams_Scale] - cmp x7,2 // if OutputCount < 2 - add x5,x2,x5 // c1 = c0 + ldc - add x4,x4,7 // kc = (kc + 7) & ~7 - csel x5,x2,x5,lo // if OutputCount < 2 c1 = c0 - bic x4,x4,7 - ldp s16,s18,[x11],8 // init accumulators with bias - ldp s20,s22,[x11],8 - ldp s24,s26,[x11],8 - ldp s28,s30,[x11],8 - mov v17.16b,v16.16b - mov v19.16b,v18.16b - mov v21.16b,v20.16b - mov v23.16b,v22.16b - mov v25.16b,v24.16b - mov v27.16b,v26.16b - mov v29.16b,v28.16b - mov v31.16b,v30.16b - -// Nested loops, inner loop: input channel; outter loop: kernel size -// Each inner iteration processes 8 input channels, 2 output pixels, 8 output channels. -// -// B 8x8 -// ------------------------------------------------------------------ -// |v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] | -// | ... ... ... ... ... ... ... ... | -// |v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] | -// A 2x8 ------------------------------------------------------------------ -// ------------------ ------------------------------------------------------------------ -// x13-> |v0.b[0]..v0.b[7]| |v16.4s v18.4s v20.4s v22.4s v24.4s v26.4s v28.4s v30.4s | -// x15-> |v1.b[0]..v1.b[7]| |v17.4s v19.4s v21.4s v23.4s v25.4s v27.4s v29.4s v31.4s | -// ------------------ ------------------------------------------------------------------ -// When Input Channels greater than 16, unroll: -// A registers v6 v7, -// B registers v8 v9 -// - -.LConvSym.KernelSizeLoop: - - # Load next 2 A pointers - cmp x7,2 // test if OutputCount < 2 - ldr x13,[x0] // x13 -> A0 - bhs .LConvSym.LoadA1 - ldr x15,[x0],#8 // x15 -> A0 - b .LConvSym.BlockLoopPrologue -.LConvSym.LoadA1: - ldr x15,[x0,x3,lsl#3] // x15 -> A1 - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop -.LConvSym.BlockLoopPrologue: - ldr d4,[x1] - subs x14,x4,16 // input channel - 16 - ldr d5,[x1,8] - blo .LConvSym.8InputChannels // less than 16 deep, no unroll - - ldr d0,[x13],8 - ldr d1,[x15],8 - ldr d8,[x1,64] - ldr d9,[x1,72] - ldr d6,[x13],8 - subs x14,x14,16 // input channel - 16 - ldr d7,[x15],8 - blo .LConvSym.BlockLoopEpilogue // need 32 input channel for full unrolled loop - -.LConvSym.Blockloop: - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d5,[x1,40] - sadalp v19.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull 
v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,56] - sadalp v23.4s, v15.8h - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - add x1,x1,128 - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1] // Read B - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - ldr d0,[x13],8 // Read A0 - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d1,[x15],8 // Read A1 - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - ldr d5,[x1,8] // Read B - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,64] // Read B - smlal v14.8h,v9.8b,v6.8b - ldr d6,[x13],8 // Read A0 - smlal v15.8h,v9.8b,v7.8b - ldr d7,[x15],8 // Read A1 - sadalp v28.4s,v12.8h - ldr d9,[x1,72] // Read B - sadalp v29.4s,v13.8h - subs x14,x14,16 - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - b.hs .LConvSym.Blockloop - -.LConvSym.BlockLoopEpilogue: // remaining 16 input channels - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v19.4s,v11.8h - ldr d5,[x1,40] - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - ldr d5,[x1,56] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - add x1,x1,128 - - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - tbnz x14,3,.LConvSym.8InputChannels - - subs x9,x9,1 - b.hi .LConvSym.KernelSizeLoop - -.LConvSym.Requantize: - ldr w11, [x8, #.LConvSymPostProcessParams_ZeroPoint] - tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - beq .LConvSym.BroadcastScaleValue - ld1 {v4.4s,v5.4s},[x12] // load scale vector - b .LConvSym.AccumulatorsToFloat - -.LConvSym.BroadcastScaleValue: - ld1r {v4.4s},[x12] // load scale Value - mov v5.16b, v4.16b - -.LConvSym.AccumulatorsToFloat: - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - addp v17.4s,v17.4s,v19.4s - addp v21.4s,v21.4s,v23.4s - addp v25.4s,v25.4s,v27.4s - addp v29.4s,v29.4s,v31.4s - addp v0.4s,v16.4s,v20.4s - addp v1.4s,v24.4s,v28.4s - addp v2.4s,v17.4s,v21.4s - addp v3.4s,v25.4s,v29.4s - scvtf v0.4s,v0.4s // convert to float - scvtf v1.4s,v1.4s - scvtf v2.4s,v2.4s - scvtf v3.4s,v3.4s - fmul v0.4s,v0.4s,v4.4s // multiply by scale - fmul v1.4s,v1.4s,v5.4s - fmul v2.4s,v2.4s,v4.4s - fmul v3.4s,v3.4s,v5.4s - fcvtns v0.4s,v0.4s // convert to 
int - fcvtns v1.4s,v1.4s - dup v9.8h,w11 - fcvtns v2.4s,v2.4s - fcvtns v3.4s,v3.4s - sqxtn v0.4h,v0.4s - sqxtn2 v0.8h,v1.4s - sqxtn v2.4h,v2.4s - sqxtn2 v2.8h,v3.4s - subs x6, x6, 8 - sqadd v0.8h,v0.8h,v9.8h - sqadd v2.8h,v2.8h,v9.8h - sqxtn v0.8b,v0.8h // shorten to int8 - sqxtn2 v0.16b,v2.8h - b.lo .LConvSym.PartialStore - - st1 {v0.d}[1],[x5] // full 2x8 store to c - st1 {v0.8b},[x2] - -.LConvSym.ExitKernel: - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#64 - ret - -.LConvSym.8InputChannels: - ldr d0,[x13] - ldr d1,[x15] - ldr d4,[x1] - ldr d5,[x1,8] - ldr d6,[x1,16] - ldr d7,[x1,24] - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,32] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,40] - smull v12.8h,v6.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - ldr d6,[x1,48] - sadalp v17.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - ldr d7,[x1,56] - sadalp v19.4s,v11.8h - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - smull v12.8h,v6.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - sadalp v27.4s,v11.8h - add x1,x1,64 - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - - # ks loop - subs x9,x9,1 - b.hi .LConvSym.KernelSizeLoop - b .LConvSym.Requantize - -.LConvSym.PartialStore: - tbz x6,2,.LConvSym.Store2 - st1 {v0.s}[2],[x5],4 - str s0,[x2],4 - EXT v0.16b,v0.16b,v0.16b,4 - -.LConvSym.Store2: - tbz x6, 1, .LConvSym.Store1 - st1 {v0.h}[4], [x5], 2 - str h0, [x2], 2 - EXT v0.16b,v0.16b,v0.16b,2 -.LConvSym.Store1: - tbz x6,0,.LConvSym.ExitKernel - st1 {v0.b}[8],[x5] - str b0,[x2] - b .LConvSym.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S deleted file mode 100644 index dc11ad93c6349..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S +++ /dev/null @@ -1,628 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymKernelNeonDot.S - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - - .equ .LMLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 - .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// - .equ .LConvSymFrame_SavedRegisters, (8 * 8) - .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters - .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters - - .equ .LConvSymPostProcessParams_Bias, 0 - .equ .LConvSymPostProcessParams_Scale, 8 - .equ .LConvSymPostProcessParams_Min, 16 - .equ .LConvSymPostProcessParams_Max, 20 - .equ .LConvSymPostProcessParams_ZeroPoint, 24 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Points to the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. 
- - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Points to the filter buffer. - - Output (x2) - Points the output buffer. - - KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (x4/x7) - Number of input channels. - - OutputChannels (x5) - Number of output channels. - - ChannelCount (x6) - Number of output channels this iteration produces. - - OutputCount (x7) - Number of output elements this iteration produces. - - This implementation requires the count to be no larger than 4. - - PostProcessParams (x8) - Points to the post process parameter block. - - KernelFlags - (w10) Additional flags controlling the operation. - -Return Value: - - None. - ---*/ - FUNCTION_ENTRY MlasConvSymU8KernelDot - - stp d8,d9,[sp,#-.LConvSymFrame_SavedRegisters]! - ldr x8,[sp,#.LConvSymFrame_PostProcessParams] - ldr w10,[sp,#.LConvSymFrame_KernelFlags] - stp d10,d11,[sp,#16] - stp d12,d13,[sp,#32] - stp x19,x20,[sp,#48] - - cmp x7,2 // OutputCount < 2 ? - add x16,x2,x5 // x16 -> C1 - lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) - csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 - mov x20,x4 - add x4,x4,3 // InputChannels align to 4 - add x17,x16,x5 // x17 -> C2 - ldr x11,[x8,#.LConvSymPostProcessParams_Bias] - csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 - bic x4,x4,3 - cmp x7,4 // OutputCount < 4 ? - add x5,x17,x5 // x5 -> C3 - ldr x19,[x8,#.LConvSymPostProcessParams_Scale] - csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 - movi v12.16b,128 // for top bit flipping - -OutputChannelLoop: - ldp q16,q20,[x11],32 // Init accumulators with biases - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldp q24,q28,[x11],32 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - mov v30.16b,v28.16b - mov v31.16b,v28.16b - mov x9,x3 // restore KernelSize * sizeof(int8_t*) - -KernelSizeLoop: - tst w10,#.LMLAS_CONV_SYM_FLAG_INPUT_DIRECT - beq InputIndirection - -InputDirect: - cmp x16,x2 - mov x12,x0 // x12 -> A0 - add x13,x0,x20 // x13 -> A1 = A0 + input channels - csel x13,x0,x13,eq - cmp x17,x16 - add x14,x0,x20,lsl#1 // x14 -> A2 - csel x14,x13,x14,eq - cmp x5,x17 - add x15,x13,x20,lsl#1 // x15 -> A3 - csel x15,x14,x15,eq - b FinishLoadAPtr - -InputIndirection: - ldr x12,[x0] // x12 -> A0 - cmp x16,x2 - b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 - cmp x17,x16 - lsl x14,x3,#1 - ldr x13,[x0,x3] // x13 -> A1 - b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 - cmp x5,x17 - add x15,x3,x3,lsl#1 - ldr x14,[x0,x14] // x14 -> A2 - b.eq SkipLoadA3 // C3==C2 -> A2=A3 - ldr x15,[x0,x15] // x15 -> A3 - b FinishLoadAPtr -SkipLoadA1: - mov x13,x12 -SkipLoadA2: - mov x14,x13 -SkipLoadA3: - mov x15,x14 - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 -// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 -// ------------------ ---------------------------------------------------------------------------- - -FinishLoadAPtr: - subs x7,x4,16 // Need 16 input channels for loop - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - b.lo InChannels8 - - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - subs x7,x7,16 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - ldr q6,[x1],16 - ldr q7,[x1],16 - b.lo InChLoopEpilogue // Need 32 input channels for main loop - -InputChannelLoop: - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - SdotByElement 16, 4, 0,0 - eor v2.8b,v2.8b,v12.8b - SdotByElement 17, 4, 1,0 - eor v3.8b,v3.8b,v12.8b - ldr d8,[x12],8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x13],8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x14],8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x15],8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - eor v8.8b,v8.8b,v12.8b - ldr q7,[x1],16 - eor v9.8b,v9.8b,v12.8b - SdotByElement 16, 4, 8,0 - eor v10.8b,v10.8b,v12.8b - SdotByElement 17, 4, 9,0 - ldr d0,[x12],8 - eor v11.8b,v11.8b,v12.8b - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - ldr d1,[x13],8 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - ldr d2,[x14],8 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - ldr d3,[x15],8 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - subs x7,x7,16 // InputChannels -= 16 - SdotByElement 30, 7,10,1 - SdotByElement 31, 
7,11,1 - ldr q7,[x1],16 - b.hs InputChannelLoop - -InChLoopEpilogue: - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - SdotByElement 16, 4, 0,0 - eor v2.8b,v2.8b,v12.8b - SdotByElement 17, 4, 1,0 - eor v3.8b,v3.8b,v12.8b - ldr d8,[x12],8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x13],8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x14],8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x15],8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - eor v8.8b,v8.8b,v12.8b - ldr q7,[x1],16 - eor v9.8b,v9.8b,v12.8b - SdotByElement 16, 4, 8,0 - eor v10.8b,v10.8b,v12.8b - SdotByElement 17, 4, 9,0 - eor v11.8b,v11.8b,v12.8b - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - - tst x7,15 - b.ne InChannels8 // 4 ~ 12 InputChannels - - subs x9,x9,8 // KernelSize-=1 - b.hi KernelSizeLoop - -Requantize: - tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ldr w13,[x8,#.LConvSymPostProcessParams_ZeroPoint] - beq BroadcastScaleValue - ldp q0,q1,[x19],32 // load scale vector - ldp q2,q3,[x19],32 - b AccumulatorsToFloat - -BroadcastScaleValue: - ld1r {v0.4s},[x19] // load scale Value - mov v1.16b, v0.16b - mov v2.16b, v0.16b - mov v3.16b, v0.16b - -AccumulatorsToFloat: - scvtf v16.4s,v16.4s // convert to float - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - scvtf v24.4s,v24.4s - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - fmul v16.4s,v16.4s,v0.4s // multiply by scale - fmul v17.4s,v17.4s,v0.4s - fmul v18.4s,v18.4s,v0.4s - fmul v19.4s,v19.4s,v0.4s - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v1.4s - fmul v22.4s,v22.4s,v1.4s - fmul v23.4s,v23.4s,v1.4s - fmul v24.4s,v24.4s,v2.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v2.4s - fmul v27.4s,v27.4s,v2.4s - fmul v28.4s,v28.4s,v3.4s - fmul v29.4s,v29.4s,v3.4s - fmul 
v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v3.4s - fcvtns v16.4s,v16.4s // convert to int - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - fcvtns v24.4s,v24.4s - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - - sqxtn v16.4h,v16.4s - sqxtn v17.4h,v17.4s - sqxtn v18.4h,v18.4s - sqxtn v19.4h,v19.4s - sqxtn v24.4h,v24.4s - sqxtn v25.4h,v25.4s - sqxtn v26.4h,v26.4s - sqxtn v27.4h,v27.4s - dup v4.8h,w13 // zero point - sqxtn2 v16.8h,v20.4s - sqxtn2 v17.8h,v21.4s - sqxtn2 v18.8h,v22.4s - sqxtn2 v19.8h,v23.4s - sqxtn2 v24.8h,v28.4s - sqxtn2 v25.8h,v29.4s - sqxtn2 v26.8h,v30.4s - sqxtn2 v27.8h,v31.4s - sqadd v16.8h,v16.8h,v4.8h - sqadd v17.8h,v17.8h,v4.8h - sqadd v18.8h,v18.8h,v4.8h - sqadd v19.8h,v19.8h,v4.8h - sqadd v24.8h,v24.8h,v4.8h - sqadd v25.8h,v25.8h,v4.8h - sqadd v26.8h,v26.8h,v4.8h - sqadd v27.8h,v27.8h,v4.8h - sqxtun v0.8b,v16.8h - sqxtun v1.8b,v17.8h - sqxtun v2.8b,v18.8h - sqxtun v3.8b,v19.8h - sqxtun2 v0.16b,v24.8h - sqxtun2 v1.16b,v25.8h - subs x6,x6,16 // processed 16 output channels - sqxtun2 v2.16b,v26.8h - sqxtun2 v3.16b,v27.8h - b.lo PartialStore - - st1 {v3.16b},[x5],16 // Store full 4 x 16 - st1 {v2.16b},[x17],16 - sub x0,x0,x3 // Restore pointer to A: a -= ks - st1 {v1.16b},[x16],16 - st1 {v0.16b},[x2],16 - b.hi OutputChannelLoop - -ExitKernel: - ldp x19,x20,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LConvSymFrame_SavedRegisters - ret - -InChannels8: - tbz x7,3,InChannels4 - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - ldr d2,[x14],8 - ldr d3,[x15],8 - eor v0.8b,v0.8b,v12.8b - ldr q5,[x1],16 - eor v1.8b,v1.8b,v12.8b - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - eor v2.8b,v2.8b,v12.8b - ldp q6, q7, [x1], 32 - eor v3.8b,v3.8b,v12.8b - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldp q4, q5, [x1], 32 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - tbz x7,2,SkipInCh4 - -InChannels4: - ldr s0,[x12],4 - ldr q4,[x1],16 - ldr s1,[x13],4 - ldr s2,[x14],4 - ldr s3,[x15],4 - eor v0.8b,v0.8b,v12.8b - ldr q5, [x1], 16 - eor v1.8b,v1.8b,v12.8b - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - eor v2.8b,v2.8b,v12.8b - ldp q6, q7, [x1], 32 - eor v3.8b,v3.8b,v12.8b - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - -SkipInCh4: - subs x9,x9,8 // ks -= 1 - b.hi 
KernelSizeLoop - b Requantize - -PartialStore: - tbz x6,3,LT8Store - str d3,[x5],8 // no less than 8 channels - str d2,[x17],8 - dup d3,v3.d[1] - dup d2,v2.d[1] - str d1,[x16],8 - str d0,[x2],8 - dup d1,v1.d[1] - dup d0,v0.d[1] -LT8Store: - tbz x6,2,LT4Store - str s3,[x5],4 - str s2,[x17],4 - dup s3,v3.s[1] - dup s2,v2.s[1] - str s1,[x16],4 - str s0,[x2],4 - dup s1,v1.s[1] - dup s0,v0.s[1] -LT4Store: - tbz x6,1, LT2Store - str h3,[x5],2 - str h2,[x17],2 - dup h3,v3.h[1] - dup h2,v2.h[1] - str h1,[x16],2 - str h0,[x2],2 - dup h1,v1.h[1] - dup h0,v0.h[1] -LT2Store: - tbz x6,0,ExitKernel - str b3,[x5] - str b2,[x17] - str b1,[x16] - str b0,[x2] - b ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S deleted file mode 100644 index fd16b2cbae2cd..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S +++ /dev/null @@ -1,454 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymU8KernelNeon.S - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "asmmacro.h" - - .equ .LMLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 - .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// - .equ .LConvSymFrame_SavedNeonRegisters, (8 * 8) - .equ .LConvSymFrame_SavedRegisters, .LConvSymFrame_SavedNeonRegisters - .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters - .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters - - .equ .LConvSymPostProcessParams_Bias, 0 - .equ .LConvSymPostProcessParams_Scale, 8 - .equ .LConvSymPostProcessParams_Min, 16 - .equ .LConvSymPostProcessParams_Max, 20 - .equ .LConvSymPostProcessParams_ZeroPoint, 24 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (x4) - Supplies the number of input channels. - - This implementation requires the count to be a multiple of 8. - - OutputChannels (x5) - Supplies the number of output channels. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 8. - - OutputCount (x7) - Supplies the number of output elements this iteration produces. - - This implementation requires the count to be 1 or 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. 
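All of the symmetric convolution kernels removed above share the same arithmetic: int32 accumulators are initialized from the bias, signed 8-bit products are accumulated over KernelSize x InputChannels (the U8 variants first XOR the unsigned input with 0x80, i.e. subtract 128, so the signed NEON multiplies can be reused), and the results are requantized by converting to float, multiplying by the scale, rounding, adding the output zero point, and saturating to 8 bits. A minimal scalar C++ sketch of that flow is shown below for reference only; the function and parameter names are illustrative, and the bias is assumed to already carry whatever adjustment the caller makes for the -128 input shift.

```cpp
// Scalar reference for one output value of the symmetric quantized
// convolution kernels deleted in this diff. Illustrative sketch only:
// the names are invented here, and Bias is assumed to already include
// any correction the caller makes for the -128 shift of U8 inputs.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

int8_t ConvSymReferenceOutput(
    const uint8_t* const* Indirection,  // KernelSize pointers, each to InputChannels input bytes
    const int8_t* Filter,               // KernelSize * InputChannels filter bytes for one output channel
    size_t KernelSize,
    size_t InputChannels,
    int32_t Bias,                       // accumulator start value ("init accumulators with bias")
    float Scale,                        // per-tensor or per-channel requantization scale
    int32_t OutputZeroPoint)
{
    int32_t Accumulator = Bias;

    for (size_t k = 0; k < KernelSize; k++) {
        const uint8_t* Input = Indirection[k];
        for (size_t c = 0; c < InputChannels; c++) {
            // The U8 kernels implement this shift by flipping the top bit
            // (XOR 0x80 == subtract 128) so the signed NEON multiplies can
            // be reused; the S8 kernels consume the input as-is.
            int32_t InputValue = static_cast<int32_t>(Input[c]) - 128;
            int32_t FilterValue = Filter[k * InputChannels + c];
            Accumulator += InputValue * FilterValue;
        }
    }

    // Requantize: convert to float, multiply by scale, round to nearest,
    // add the output zero point, and saturate to the 8-bit output range.
    float Scaled = static_cast<float>(Accumulator) * Scale;
    int32_t Requantized = static_cast<int32_t>(std::lrintf(Scaled)) + OutputZeroPoint;
    Requantized = std::clamp(Requantized, -128, 127);
    return static_cast<int8_t>(Requantized);
}
```

When MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE is set, the kernels load one scale per output channel instead of broadcasting a single value; the rest of the requantization flow is unchanged, and the U8 output variants saturate to [0, 255] with sqxtun rather than the signed narrowing shown here.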
- ---*/ - FUNCTION_ENTRY MlasConvSymU8KernelNeon - - stp d8,d9,[sp,#-64]! - ldr x8,[sp,#.LConvSymFrame_PostProcessParams] - ldrb w10,[sp,#.LConvSymFrame_KernelFlags] - stp d10,d11,[sp,#16] - stp d12,d13,[sp,#32] - stp d14,d15,[sp,#48] - mov x9,x3 // save kernel size - ldr x11,[x8,#.LConvSymPostProcessParams_Bias] - mov x16,x4 // save input channels - ldr x12,[x8,#.LConvSymPostProcessParams_Scale] - cmp x7,2 // if OutputCount < 2 - add x5,x2,x5 // c1 = c0 + ldc - add x4,x4,7 // kc = (kc + 7) & ~7 - csel x5,x2,x5,lo // if OutputCount < 2 c1 = c0 - bic x4,x4,7 - ldp s16,s18,[x11],8 // init accumulators with bias - ldp s20,s22,[x11],8 - ldp s24,s26,[x11],8 - ldp s28,s30,[x11],8 - mov v17.16b,v16.16b - mov v19.16b,v18.16b - mov v21.16b,v20.16b - mov v23.16b,v22.16b - mov v25.16b,v24.16b - mov v27.16b,v26.16b - mov v29.16b,v28.16b - mov v31.16b,v30.16b - -// Nested loops, inner loop: input channel; outter loop: kernel size -// Each inner iteration processes 8 input channels, 2 output pixels, 8 output channels. -// -// B 8x8 -// ------------------------------------------------------------------ -// |v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] | -// | ... ... ... ... ... ... ... ... | -// |v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] | -// A 2x8 ------------------------------------------------------------------ -// ------------------ ------------------------------------------------------------------ -// x13-> |v0.b[0]..v0.b[7]| |v16.4s v18.4s v20.4s v22.4s v24.4s v26.4s v28.4s v30.4s | -// x15-> |v1.b[0]..v1.b[7]| |v17.4s v19.4s v21.4s v23.4s v25.4s v27.4s v29.4s v31.4s | -// ------------------ ------------------------------------------------------------------ -// When Input Channels greater than 16, unroll: -// A registers v6 v7, -// B registers v8 v9 -// - -.LConvSym.KernelSizeLoop: - - # Load next 2 A pointers - tst w10,#.LMLAS_CONV_SYM_FLAG_INPUT_DIRECT - ldr d4,[x1] - ldr d5,[x1,8] - beq .LConvSym.InputIndirection - -.LConvSym.InputDirect: - mov x13,x0 // x13 -> A0 - add x15,x0,x16 // x15 -> A1 = A0 + input channels - b .LConvSym.BlockLoopPrologue - -.LConvSym.InputIndirection: - cmp x7,2 // test if OutputCount < 2 - ldr x13,[x0] // x13 -> A0 - blo .LConvSym.SkipLoadA1 - ldr x15,[x0,x3,lsl#3] // x15 -> A1 -.LConvSym.SkipLoadA1: - -.LConvSym.BlockLoopPrologue: - cmp x7,2 // test if OutputCount < 2 - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - csel x15,x13,x15,lo // if OutputCount < 2 x15 -> A0 - subs x14,x4,16 // input channel - 16 - movi v12.8b,128 - blo .LConvSym.8InputChannels // less than 16 deep, no unroll - - ldr d0,[x13],8 - ldr d1,[x15],8 - ldr d8,[x1,64] - ldr d9,[x1,72] - ldr d6,[x13],8 - subs x14,x14,16 // input channel - 16 - ldr d7,[x15],8 - blo .LConvSym.BlockLoopEpilogue // need 32 input channel for full unrolled loop - -.LConvSym.Blockloop: - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - eor v6.8b,v6.8b,v12.8b - eor v7.8b,v7.8b,v12.8b - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d5,[x1,40] - sadalp v19.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr 
d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,56] - sadalp v23.4s, v15.8h - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - add x1,x1,128 - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1] // Read B - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - ldr d0,[x13],8 // Read A0 - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d1,[x15],8 // Read A1 - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - ldr d5,[x1,8] // Read B - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,64] // Read B - smlal v14.8h,v9.8b,v6.8b - ldr d6,[x13],8 // Read A0 - smlal v15.8h,v9.8b,v7.8b - ldr d7,[x15],8 // Read A1 - sadalp v28.4s,v12.8h - ldr d9,[x1,72] // Read B - sadalp v29.4s,v13.8h - subs x14,x14,16 - sadalp v30.4s,v14.8h - movi v12.8b,128 - sadalp v31.4s,v15.8h - b.hs .LConvSym.Blockloop - -.LConvSym.BlockLoopEpilogue: // remaining 16 input channels - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - eor v6.8b,v6.8b,v12.8b - eor v7.8b,v7.8b,v12.8b - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v19.4s,v11.8h - ldr d5,[x1,40] - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - ldr d5,[x1,56] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - add x1,x1,128 - - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - movi v12.8b,128 - tbnz x14,3,.LConvSym.8InputChannels - - subs x9,x9,1 - b.hi .LConvSym.KernelSizeLoop - -.LConvSym.Requantize: - ldr w11, [x8, #.LConvSymPostProcessParams_ZeroPoint] - tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - beq .LConvSym.BroadcastScaleValue - ld1 {v4.4s,v5.4s},[x12] // load scale vector - b .LConvSym.AccumulatorsToFloat - -.LConvSym.BroadcastScaleValue: - ld1r {v4.4s},[x12] // load scale Value - mov v5.16b, v4.16b - -.LConvSym.AccumulatorsToFloat: - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - addp v17.4s,v17.4s,v19.4s - addp v21.4s,v21.4s,v23.4s - addp v25.4s,v25.4s,v27.4s - addp v29.4s,v29.4s,v31.4s - addp v0.4s,v16.4s,v20.4s - addp v1.4s,v24.4s,v28.4s - addp v2.4s,v17.4s,v21.4s - addp 
v3.4s,v25.4s,v29.4s - scvtf v0.4s,v0.4s // convert to float - scvtf v1.4s,v1.4s - scvtf v2.4s,v2.4s - scvtf v3.4s,v3.4s - fmul v0.4s,v0.4s,v4.4s // multiply by scale - fmul v1.4s,v1.4s,v5.4s - fmul v2.4s,v2.4s,v4.4s - fmul v3.4s,v3.4s,v5.4s - fcvtns v0.4s,v0.4s // convert to int - fcvtns v1.4s,v1.4s - dup v9.8h,w11 - fcvtns v2.4s,v2.4s - fcvtns v3.4s,v3.4s - sqxtn v0.4h,v0.4s - sqxtn2 v0.8h,v1.4s - sqxtn v2.4h,v2.4s - sqxtn2 v2.8h,v3.4s - subs x6, x6, 8 - sqadd v0.8h,v0.8h,v9.8h - sqadd v2.8h,v2.8h,v9.8h - sqxtun v0.8b,v0.8h // shorten to int8 - sqxtun2 v0.16b,v2.8h - b.lo .LConvSym.PartialStore - - st1 {v0.d}[1],[x5] // full 2x8 store to c - st1 {v0.8b},[x2] - -.LConvSym.ExitKernel: - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#64 - ret - -.LConvSym.8InputChannels: - ldr d0,[x13] - ldr d1,[x15] - ldr d4,[x1] - ldr d5,[x1,8] - ldr d6,[x1,16] - ldr d7,[x1,24] - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,32] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,40] - smull v12.8h,v6.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - ldr d6,[x1,48] - sadalp v17.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - ldr d7,[x1,56] - sadalp v19.4s,v11.8h - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - smull v12.8h,v6.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - sadalp v27.4s,v11.8h - add x1,x1,64 - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - - # ks loop - subs x9,x9,1 - b.hi .LConvSym.KernelSizeLoop - b .LConvSym.Requantize - -.LConvSym.PartialStore: - tbz x6,2,.LConvSym.Store2 - st1 {v0.s}[2],[x5],4 - str s0,[x2],4 - EXT v0.16b,v0.16b,v0.16b,4 - -.LConvSym.Store2: - tbz x6, 1, .LConvSym.Store1 - st1 {v0.h}[4], [x5], 2 - str h0, [x2], 2 - EXT v0.16b,v0.16b,v0.16b,2 -.LConvSym.Store1: - tbz x6,0,.LConvSym.ExitKernel - st1 {v0.b}[8],[x5] - str b0,[x2] - b .LConvSym.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvKernelSize9Neon.S b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvKernelSize9Neon.S deleted file mode 100644 index 2be4b17a17f34..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvKernelSize9Neon.S +++ /dev/null @@ -1,656 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DepthwiseQConvKernelSize9Neon.asm - -Abstract: - - This module implements the routine for the depthwise convolution - operation with symmetrically quantized integer values for kernel - size 9. ie, 3x3, 1x9, 9x1 - ---*/ - -#include "asmmacro.h" - - - .equ .LConvSymDepthwisePostProcessParams_Bias, 0 - .equ .LConvSymDepthwisePostProcessParams_Scale, 8 - .equ .LConvSymDepthwisePostProcessParams_ZeroPoint, 24 - - .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, 1 - -// -// Stack frame layout for the depthwise conv kernel. 
d8-d15, x19-x30 need save -// - .equ .LMlasConvSymDepthwiseKernelSize9_backup_x19_x20, 0 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_x21_x22, 16 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_x23_x24, 32 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_x25_x26, 48 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_x27_x28, 64 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_d8_d9, 80 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_d10_d11, 96 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_d12_d13, 112 - .equ .LMlasConvSymDepthwiseKernelSize9_backup_d14_d15, 128 - .equ .LMlasConvSymDepthwiseKernelSize9_SavedRegisters, 144 - .equ .LMlasConvSymDepthwiseKernelSize9_SavedRegisters_Neg, -144 - - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise quantized convolution - on kernel size 9 for u8s8 - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Channels (x2) - Supplies the number of input and output channels. - - Output (x3) - Supplies the address of the output buffer. - - OutputCount (x4)- Supplies the number of image pixels. - - PostProcessParams (x5) - Supplies the address of the post process parameter block. - - KernelFlags (x6) - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymDepthwiseKernelSize9Arm64U8S8 - - stp x19, x20, [sp, #.LMlasConvSymDepthwiseKernelSize9_SavedRegisters_Neg]! - stp x21, x22, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x21_x22] - stp x23, x24, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x23_x24] - stp x25, x26, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x25_x26] - stp x27, x28, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x27_x28] - stp d8, d9, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d8_d9] - stp d10, d11, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d10_d11] - stp d12, d13, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d12_d13] - stp d14, d15, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d14_d15] - - ldr x9, [x5, #.LConvSymDepthwisePostProcessParams_Bias] - ldr x8, [x5, #.LConvSymDepthwisePostProcessParams_Scale] - add x5, x5, #.LConvSymDepthwisePostProcessParams_ZeroPoint - ins v12.d[0], x1 // Filter - ins v13.d[0], x9 // Bias - ins v13.d[1], x8 // Scale - ld1r {v0.8h}, [x5] // zero point - movi v5.16b, #0x80 // flip 0x80 - - tbnz x6, #.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, .LMlasConvSymDepthwiseKernelSize9_SkipPerTensorScaleInit - ld1r {v1.4s}, [x8] // load scale value - mov v2.16b, v1.16b - mov v3.16b, v1.16b - mov v4.16b, v1.16b - -.LMlasConvSymDepthwiseKernelSize9_SkipPerTensorScaleInit: - - add x9, x3, x2 // x9 <---- Ouput1 - cbz x4, .LMlasConvSymDepthwiseKernelSize9_Exit - -.LMlasConvSymDepthwiseKernelSize9_OutputLoop: - ldp x20, x21, [x0], #72 // input ptrs for Output0 - ldp x22, x23, [x0, #-56] - sub x4, x4, #1 - ldp x24, x25, [x0, #-40] - ldp x26, x27, [x0, #-24] - ldur x28, [x0, #-8] - - cbz x4, .LMlasConvSymDepthwiseKernelSize9_Dup_Inputs - ldp x10, x11, [x0], #72 // input ptrs for Output0 - ldp x12, x13, [x0, #-56] - sub x4, x4, #1 - ldp x14, x15, [x0, #-40] - ldp x16, x17, [x0, #-24] - ldur x19, [x0, #-8] - b .LMlasConvSymDepthwiseKernelSize9_Loaded_Input - -.LMlasConvSymDepthwiseKernelSize9_Dup_Inputs: - mov x9, x3 // Output1 <-- Output0 - mov x10, x20 - mov x11, x21 - mov x12, x22 - mov x13, x23 - mov x14, x24 - mov x15, x25 - mov x16, x26 - mov x17, x27 - mov x19, x28 - 
-.LMlasConvSymDepthwiseKernelSize9_Loaded_Input: - - eor x8, x8, x8 // Processed channels - umov x1, v12.D[0] // filter - umov x5, v13.D[0] // bias - umov x7, v13.D[1] // scale - - cmp x8, x2 // Save one register by not using count down to zero here - bhs .LMlasConvSymDepthwiseKernelSize9_Finish_Channels16_Loop - -.LMlasConvSymDepthwiseKernelSize9_Channels16_Loop: - ld1 {v10.16b}, [x1], x2 // vk0 - ldr q16, [x20, x8] // out0 vi0 - ldr q17, [x10, x8] // out1 vi0 - ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [x5], #64 // bias vacc 0-15 for outs - ld1 {v11.16b}, [x1], x2 // vk1 - ldr q18, [x21, x8] // out0 vi1 - ldr q19, [x11, x8] // out1 vi1 - - eor v16.16b, v16.16b, v5.16b // -128 to signed int8 - eor v17.16b, v17.16b, v5.16b - ld1 {v14.16b}, [x1], x2 // vk2 - eor v18.16b, v18.16b, v5.16b - eor v19.16b, v19.16b, v5.16b - - ldr q20, [x22, x8] // out0 vi2 - smull v24.8h, v10.8b, v16.8b - smull2 v25.8h, v10.16b, v16.16b - ldr q21, [x12, x8] // out1 vi2 - smull v26.8h, v10.8b, v17.8b - ld1 {v15.16b}, [x1], x2 // vk3 - smull2 v27.8h, v10.16b, v17.16b - ldr q22, [x23, x8] // out0 vi3 - smull v28.8h, v11.8b, v18.8b - smull2 v29.8h, v11.16b, v18.16b - ldr q23, [x13, x8] // out1 vi3 - smull v30.8h, v11.8b, v19.8b - smull2 v31.8h, v11.16b, v19.16b - - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - eor v22.16b, v22.16b, v5.16b - eor v23.16b, v23.16b, v5.16b - ld1 {v10.16b}, [x1], x2 // vk4 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - ld1 {v11.16b}, [x1], x2 // vk5 - - saddw v16.4s, v6.4s, v24.4h // dup acc for out1 - saddw2 v17.4s, v7.4s, v24.8h - saddw v18.4s, v8.4s, v25.4h - saddw2 v19.4s, v9.4s, v25.8h - - ldr q20, [x24, x8] // out0 vi4 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x14, x8] // out1 vi4 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - ldr q22, [x25, x8] // out0 vi5 - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - ldr q23, [x15, x8] // out1 vi5 - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - ld1 {v14.16b}, [x1], x2 // vk6 - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - eor v22.16b, v22.16b, v5.16b - eor v23.16b, v23.16b, v5.16b - ld1 {v15.16b}, [x1], x2 // vk7 - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - smull v28.8h, v11.8b, v22.8b - smull2 v29.8h, v11.16b, v22.16b - smull v30.8h, v11.8b, v23.8b - smull2 v31.8h, v11.16b, v23.16b - - ldr q20, [x26, x8] // out0 vi6 - ldr q21, [x16, x8] // out1 vi6 - ldr q22, [x27, x8] // out0 vi7 - ldr q23, [x17, x8] // out1 vi7 - - tbz x6, #.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, .LDonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9 - ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x7], #64 // scales 0-15 for outs - -.LDonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9: - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - eor v22.16b, v22.16b, v5.16b - eor v23.16b, v23.16b, v5.16b - ldr q10, [x1] // vk8 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, 
v23.8b - smlal2 v31.8h, v15.16b, v23.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - ldr q20, [x28, x8] // out0 vi8 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x19, x8] // out1 vi8 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - scvtf v16.4s, v16.4s // Requantize - scvtf v17.4s, v17.4s - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v16.4s, v16.4s, v1.4s - fmul v17.4s, v17.4s, v2.4s - fmul v18.4s, v18.4s, v3.4s - fmul v19.4s, v19.4s, v4.4s - fmul v6.4s, v6.4s, v1.4s - fmul v7.4s, v7.4s, v2.4s - fmul v8.4s, v8.4s, v3.4s - fmul v9.4s, v9.4s, v4.4s - - fcvtns v16.4s, v16.4s - fcvtns v17.4s, v17.4s - fcvtns v18.4s, v18.4s - fcvtns v19.4s, v19.4s - fcvtns v6.4s, v6.4s - fcvtns v7.4s, v7.4s - fcvtns v8.4s, v8.4s - fcvtns v9.4s, v9.4s - - sqxtn v16.4h, v16.4s // +zp, narrow and combine - sqxtn v18.4h, v18.4s - sqxtn v6.4h, v6.4s - sqxtn v8.4h, v8.4s - sqxtn2 v16.8h, v17.4s - sqxtn2 v18.8h, v19.4s - sqxtn2 v6.8h, v7.4s - sqxtn2 v8.8h, v9.4s - sqadd v16.8h, v16.8h, v0.8h - sqadd v18.8h, v18.8h, v0.8h - sqadd v6.8h, v6.8h, v0.8h - sqadd v8.8h, v8.8h, v0.8h - sqxtun v16.8b, v16.8h - sqxtun2 v16.16b, v18.8h - sqxtun v6.8b, v6.8h - sqxtun2 v6.16b, v8.8h - - str q16, [x3, x8] - str q6, [x9, x8] - add x8, x8, #16 - umov x1, v12.D[0] // filter - cmp x8, x2 - add x1, x1, x8 - blo .LMlasConvSymDepthwiseKernelSize9_Channels16_Loop - -.LMlasConvSymDepthwiseKernelSize9_Finish_Channels16_Loop: - add x3, x3, x2, LSL #1 - add x9, x9, x2, LSL #1 - cbnz x4, .LMlasConvSymDepthwiseKernelSize9_OutputLoop - -.LMlasConvSymDepthwiseKernelSize9_Exit: - ldp d14, d15, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d14_d15] - ldp d12, d13, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d12_d13] - ldp d10, d11, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d10_d11] - ldp d8, d9, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d8_d9] - ldp x27, x28, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x27_x28] - ldp x25, x26, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x25_x26] - ldp x23, x24, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x23_x24] - ldp x21, x22, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x21_x22] - ldp x19, x20, [sp], #.LMlasConvSymDepthwiseKernelSize9_SavedRegisters - ret - - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise quantized convolution - on kernel size 9 for s8s8 - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Channels (x2) - Supplies the number of input and output channels. - - Output (x3) - Supplies the address of the output buffer. 
- - OutputCount (x4)- Supplies the number of image pixels. - - PostProcessParams (x5) - Supplies the address of the post process parameter block. - - KernelFlags (x6) - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymDepthwiseKernelSize9Arm64S8S8 - - stp x19, x20, [sp, #.LMlasConvSymDepthwiseKernelSize9_SavedRegisters_Neg]! - stp x21, x22, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x21_x22] - stp x23, x24, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x23_x24] - stp x25, x26, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x25_x26] - stp x27, x28, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x27_x28] - stp d8, d9, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d8_d9] - stp d10, d11, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d10_d11] - stp d12, d13, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d12_d13] - stp d14, d15, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d14_d15] - - ldr x9, [x5, #.LConvSymDepthwisePostProcessParams_Bias] - ldr x8, [x5, #.LConvSymDepthwisePostProcessParams_Scale] - add x5, x5, #.LConvSymDepthwisePostProcessParams_ZeroPoint - ins v12.d[0], x1 // Filter - ins v13.d[0], x9 // Bias - ins v13.d[1], x8 // Scale - ld1r {v0.8h}, [x5] // zero point - - tbnz x6, #.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, .LMlasConvSymDepthwiseKernelSize9S8S8_SkipPerTensorScaleInit - ld1r {v1.4s}, [x8] // load scale value - mov v2.16b, v1.16b - mov v3.16b, v1.16b - mov v4.16b, v1.16b - -.LMlasConvSymDepthwiseKernelSize9S8S8_SkipPerTensorScaleInit: - - add x9, x3, x2 // x9 <---- Ouput1 - cbz x4, .LMlasConvSymDepthwiseKernelSize9S8S8_Exit - -.LMlasConvSymDepthwiseKernelSize9S8S8_OutputLoop: - ldp x20, x21, [x0], #72 // input ptrs for Output0 - ldp x22, x23, [x0, #-56] - sub x4, x4, #1 - ldp x24, x25, [x0, #-40] - ldp x26, x27, [x0, #-24] - ldur x28, [x0, #-8] - - cbz x4, .LMlasConvSymDepthwiseKernelSize9S8S8_Dup_Inputs - ldp x10, x11, [x0], #72 // input ptrs for Output0 - ldp x12, x13, [x0, #-56] - sub x4, x4, #1 - ldp x14, x15, [x0, #-40] - ldp x16, x17, [x0, #-24] - ldur x19, [x0, #-8] - b .LMlasConvSymDepthwiseKernelSize9S8S8_Loaded_Input - -.LMlasConvSymDepthwiseKernelSize9S8S8_Dup_Inputs: - mov x9, x3 // Output1 <-- Output0 - mov x10, x20 - mov x11, x21 - mov x12, x22 - mov x13, x23 - mov x14, x24 - mov x15, x25 - mov x16, x26 - mov x17, x27 - mov x19, x28 - -.LMlasConvSymDepthwiseKernelSize9S8S8_Loaded_Input: - - eor x8, x8, x8 // Processed channels - umov x1, v12.D[0] // filter - umov x5, v13.D[0] // bias - umov x7, v13.D[1] // scale - - cmp x8, x2 // Save one register by not using count down to zero here - bhs .LMlasConvSymDepthwiseKernelSize9S8S8_Finish_Channels16_Loop - -.LMlasConvSymDepthwiseKernelSize9S8S8_Channels16_Loop: - ld1 {v10.16b}, [x1], x2 // vk0 - ldr q16, [x20, x8] // out0 vi0 - ldr q17, [x10, x8] // out1 vi0 - ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [x5], #64 // bias vacc 0-15 for outs - ld1 {v11.16b}, [x1], x2 // vk1 - ldr q18, [x21, x8] // out0 vi1 - ldr q19, [x11, x8] // out1 vi1 - - ld1 {v14.16b}, [x1], x2 // vk2 - - ldr q20, [x22, x8] // out0 vi2 - smull v24.8h, v10.8b, v16.8b - smull2 v25.8h, v10.16b, v16.16b - ldr q21, [x12, x8] // out1 vi2 - smull v26.8h, v10.8b, v17.8b - ld1 {v15.16b}, [x1], x2 // vk3 - smull2 v27.8h, v10.16b, v17.16b - ldr q22, [x23, x8] // out0 vi3 - smull v28.8h, v11.8b, v18.8b - smull2 v29.8h, v11.16b, v18.16b - ldr q23, [x13, x8] // out1 vi3 - smull v30.8h, v11.8b, v19.8b - smull2 v31.8h, v11.16b, v19.16b - - ld1 {v10.16b}, [x1], x2 // vk4 - - smlal 
v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - ld1 {v11.16b}, [x1], x2 // vk5 - - saddw v16.4s, v6.4s, v24.4h // dup acc for out1 - saddw2 v17.4s, v7.4s, v24.8h - saddw v18.4s, v8.4s, v25.4h - saddw2 v19.4s, v9.4s, v25.8h - - ldr q20, [x24, x8] // out0 vi4 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x14, x8] // out1 vi4 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - ldr q22, [x25, x8] // out0 vi5 - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - ldr q23, [x15, x8] // out1 vi5 - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - ld1 {v14.16b}, [x1], x2 // vk6 - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - ld1 {v15.16b}, [x1], x2 // vk7 - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - smull v28.8h, v11.8b, v22.8b - smull2 v29.8h, v11.16b, v22.16b - smull v30.8h, v11.8b, v23.8b - smull2 v31.8h, v11.16b, v23.16b - - ldr q20, [x26, x8] // out0 vi6 - ldr q21, [x16, x8] // out1 vi6 - ldr q22, [x27, x8] // out0 vi7 - ldr q23, [x17, x8] // out1 vi7 - - tbz x6, #.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, .LDonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9S8S8 - ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x7], #64 // scales 0-15 for outs - -.LDonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9S8S8: - ldr q10, [x1] // vk8 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - ldr q20, [x28, x8] // out0 vi8 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x19, x8] // out1 vi8 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - scvtf v16.4s, v16.4s // Requantize - scvtf v17.4s, v17.4s - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v16.4s, v16.4s, v1.4s - fmul v17.4s, v17.4s, v2.4s - fmul v18.4s, v18.4s, v3.4s - fmul v19.4s, v19.4s, v4.4s - fmul v6.4s, v6.4s, v1.4s - fmul v7.4s, v7.4s, v2.4s - fmul v8.4s, v8.4s, v3.4s - fmul v9.4s, v9.4s, v4.4s - - fcvtns v16.4s, v16.4s - fcvtns v17.4s, v17.4s - fcvtns v18.4s, v18.4s - fcvtns v19.4s, v19.4s - fcvtns v6.4s, v6.4s - fcvtns v7.4s, v7.4s - fcvtns v8.4s, v8.4s - fcvtns v9.4s, v9.4s - - sqxtn v16.4h, v16.4s // +zp, narrow and combine - sqxtn v18.4h, v18.4s 
- sqxtn v6.4h, v6.4s - sqxtn v8.4h, v8.4s - sqxtn2 v16.8h, v17.4s - sqxtn2 v18.8h, v19.4s - sqxtn2 v6.8h, v7.4s - sqxtn2 v8.8h, v9.4s - sqadd v16.8h, v16.8h, v0.8h - sqadd v18.8h, v18.8h, v0.8h - sqadd v6.8h, v6.8h, v0.8h - sqadd v8.8h, v8.8h, v0.8h - sqxtn v16.8b, v16.8h - sqxtn2 v16.16b, v18.8h - sqxtn v6.8b, v6.8h - sqxtn2 v6.16b, v8.8h - - str q16, [x3, x8] - str q6, [x9, x8] - add x8, x8, #16 - umov x1, v12.D[0] // filter - cmp x8, x2 - add x1, x1, x8 - blo .LMlasConvSymDepthwiseKernelSize9S8S8_Channels16_Loop - -.LMlasConvSymDepthwiseKernelSize9S8S8_Finish_Channels16_Loop: - add x3, x3, x2, LSL #1 - add x9, x9, x2, LSL #1 - cbnz x4, .LMlasConvSymDepthwiseKernelSize9S8S8_OutputLoop - -.LMlasConvSymDepthwiseKernelSize9S8S8_Exit: - ldp d14, d15, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d14_d15] - ldp d12, d13, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d12_d13] - ldp d10, d11, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d10_d11] - ldp d8, d9, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_d8_d9] - ldp x27, x28, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x27_x28] - ldp x25, x26, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x25_x26] - ldp x23, x24, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x23_x24] - ldp x21, x22, [sp, #.LMlasConvSymDepthwiseKernelSize9_backup_x21_x22] - ldp x19, x20, [sp], #.LMlasConvSymDepthwiseKernelSize9_SavedRegisters - ret - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S deleted file mode 100644 index 7a27b9c92a3e5..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S +++ /dev/null @@ -1,692 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DepthwiseQConvSymS8KernelNeon.S - -Abstract: - - This module implements the kernels for the depthwise convolution - operation with symmetrically quantized integer values - ---*/ - -#include "asmmacro.h" - -// -// Stack frame layout for the depthwise conv kernel. -// d8-d15, x19-x30 need to be preserved if used -// - - .equ .LConvSymDepthwiseKernelFrame_SavedRegisters, (4 * 8) - .equ .LConvSymDepthwiseKernelFrame_PostProcessParams, 0 + .LConvSymDepthwiseKernelFrame_SavedRegisters - .equ .LConvSymDepthwiseKernelFrame_KernelFlags, 8 + .LConvSymDepthwiseKernelFrame_SavedRegisters - - .equ .LConvSymDepthwisePostProcessParams_Bias, 0 - .equ .LConvSymDepthwisePostProcessParams_Scale, 8 - .equ .LConvSymDepthwisePostProcessParams_Min, 16 - .equ .LConvSymDepthwisePostProcessParams_Max, 20 - .equ .LConvSymDepthwisePostProcessParams_ZeroPoint, 24 - - .equ MLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 - .equ MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise convolution for the - elements of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. - - Channels (x4) - Supplies the number of input and output channels. - - ChannelOffset (x5) - Supplies the byte offset from the indirection buffer base - address for this iteration. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. 
- - This implementation requires the count to be 16 or 8 - - OutputCount (x7)- Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymDepthwiseS8KernelNeon - - stp d12,d13,[sp,#-.LConvSymDepthwiseKernelFrame_SavedRegisters]! - ldr x8,[sp,#.LConvSymDepthwiseKernelFrame_PostProcessParams] - stp d14,d15,[sp,#16] - cmp x7,2 - add x9,x0,x3,lsl#3 // x9 -> &A1 - add x14,x0,x3,lsl#4 // x14 -> &A2 - add x15,x9,x3,lsl#4 // x15 -> &A3 - ldr x16,[x8,#.LConvSymDepthwisePostProcessParams_Bias] - csel x9,x0,x9,lo // x9 -> &A0 if OutputCount < 2 - csel x14,x0,x14,ls // x14 -> &A0 if OutputCount <= 2 - ldr x11,[x9],#8 // x11 -> A1 iter 0 - cmp x7,4 - ldp q24,q25,[x16],#32 // init accumulators with bias - csel x15,x0,x15,lo // x15 -> &A0 if OutputCount < 4 - cmp x6,16 - ldr x10,[x0],#8 // x10 -> A0 iter 0 - b.lo .LProcess8Channels - -// -// Process an input block of length Channels for each element of the kernel. -// -// Filter: v0, -// v1 // unroll -// Input: -// x0 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x9 -> x11 -> v6 -// -> x13 -> v3 // unroll -// x14 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x15 -> x11 -> v6 -// -> x13 -> v3 // unroll -// - -.LProcess16Channels: - cmp x3,1 - ldp q26,q27,[x16] - b.eq .LProcC16P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1],x4 // filter iter 0 - ld1 {v1.16b},[x1],x4 // filter iter 1 - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] // A0 iter 0 - mov v20.16b,v24.16b - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v21.16b,v25.16b - ldr q6,[x11,x5] // A1 iter 0 - mov v30.16b,v26.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v31.16b,v27.16b - ldr q2,[x12,x5] // A0 iter 1 - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - mov v19.16b,v27.16b - ldr q3,[x13,x5] // A1 iter 1 - mov v22.16b,v26.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v23.16b,v27.16b - -.LBlockLoopC16: - - // - // Process 2 pixels, and load next two pixels - // - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr q4,[x10,x5] // A2 iter 0 - b.eq .LEpilogueC16P2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x0],#8 // x10 -> A0 iter 2 - smull2 v15.8h,v0.16b,v6.16b - cmp x3,1 - ldr q6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - b.eq .LEpilogueC16P3 // 3 pixel remains - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x0],#8 // x12 -> A0 iter 3 - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A1 iter 2 - ld1 {v0.16b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - smlal2 
v13.8h,v1.16b,v2.16b - ldr q2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A1 iter 3 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.16b},[x1],x4 // filter iter 3 - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b .LBlockLoopC16 - -.LEpilogueC16P2: - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - smlal v14.8h,v1.8b,v3.8b - smlal2 v15.8h,v1.16b,v3.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoad2 - ldp q4,q5,[x12],#32 // load scale vector if per channel - ldp q6,q3,[x12] -.LSkipScaleVecLoad2: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b .LDequantization - -.LProcC16P1: - // - // Channel 16 kernel size 1 - // TODO!! is this reachable at all? 
- // - ldr x12,[x14],#8 // x12 -> A2 - ldr x13,[x15],#8 // x13 -> A3 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1] - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] - mov v20.16b,v24.16b - mov v21.16b,v25.16b - ldr q6,[x11,x5] - mov v30.16b,v26.16b - mov v31.16b,v27.16b - ldr q2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - mov v19.16b,v27.16b - ldr q3,[x13,x5] - mov v22.16b,v26.16b - mov v23.16b,v27.16b - b .LEpilogueC16P1 - -.LEpilogueC16P3: - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x14],#8 // x12 -> A2 iter 2 - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ld1 {v0.16b},[x1] // filter iter 2 - ldr q6,[x11,x5] // A1 iter 2 - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - ldr q2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v3.8b - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -.LEpilogueC16P1: - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - smull v12.8h,v0.8b,v2.8b - smull2 v13.8h,v0.16b,v2.16b - smull v14.8h,v0.8b,v3.8b - smull2 v15.8h,v0.16b,v3.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoad - ldp q4,q5,[x12],#32 // load scale vector if per channel - ldp q6,q3,[x12] -.LSkipScaleVecLoad: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -.LDequantization: - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - b.ne .LSkipScaleBroadcast - mov v5.16b,v4.16b // broadcast scale val if not per channel - mov v6.16b,v4.16b - mov v3.16b,v4.16b -.LSkipScaleBroadcast: - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v5.4s - fmul v26.4s,v26.4s,v6.4s - fmul 
v27.4s,v27.4s,v3.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v5.4s - fmul v30.4s,v30.4s,v6.4s - fmul v31.4s,v31.4s,v3.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v5.4s - fmul v18.4s,v18.4s,v6.4s - fmul v19.4s,v19.4s,v3.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v5.4s - fmul v22.4s,v22.4s,v6.4s - fmul v23.4s,v23.4s,v3.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn v26.4h,v26.4s - sqxtn2 v24.8h,v25.4s - sqxtn2 v26.8h,v27.4s - sqxtn v28.4h,v28.4s - sqxtn v30.4h,v30.4s - sqxtn2 v28.8h,v29.4s - sqxtn2 v30.8h,v31.4s - dup v0.8h,w15 - sqxtn v16.4h,v16.4s - sqxtn v18.4h,v18.4s - sqxtn2 v16.8h,v17.4s - sqxtn2 v18.8h,v19.4s - sqxtn v20.4h,v20.4s - sqxtn v22.4h,v22.4s - sqxtn2 v20.8h,v21.4s - sqxtn2 v22.8h,v23.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v26.8h,v26.8h,v0.8h - sqadd v28.8h,v28.8h,v0.8h - sqadd v30.8h,v30.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v18.8h,v18.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqadd v22.8h,v22.8h,v0.8h - sqxtn v24.8b,v24.8h // shorten to int8 - sqxtn2 v24.16b,v26.8h - sqxtn v28.8b,v28.8h - sqxtn2 v28.16b,v30.8h - sqxtn v16.8b,v16.8h - sqxtn2 v16.16b,v18.8h - sqxtn v20.8b,v20.8h - sqxtn2 v20.16b,v22.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.16b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 2 - st1 {v28.16b},[x2],x4 - b.ls .LExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? - st1 {v16.16b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 4 - str q20,[x2] - -.LExitKernel: - ldp d14,d15,[sp,#16] - ldp d12,d13,[sp],#.LConvSymDepthwiseKernelFrame_SavedRegisters - ret - -.LProcess8Channels: - cmp x3,1 - b.eq .LProcC8P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - ld1 {v0.8b},[x1],x4 // filter iter 0 - ld1 {v1.8b},[x1],x4 // filter iter 1 - ldr d4,[x10,x5] // A0 iter 0 - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v28.16b,v24.16b - ldr d6,[x11,x5] // A1 iter 0 - mov v29.16b,v25.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v16.16b,v24.16b - ldr d2,[x12,x5] // A0 iter 1 - mov v17.16b,v25.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - subs x3,x3,2 // decrement input blocks remaining - ldr d3,[x13,x5] // A1 iter 1 - mov v20.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v21.16b,v25.16b - -.LBlockLoopC8: - // - // Process 2 pixels, and load next two pixels - // - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A2 iter 0 - smull v14.8h,v0.8b,v6.8b - b.eq .LEpilogueC8P2 - ldr x10,[x0],#8 // x10 -> A0 iter 2 - ldr d6,[x11,x5] // A3 iter 0 - cmp x3,1 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal v14.8h,v1.8b,v3.8b - ldr d2,[x12,x5] // A2 iter 1 - b.eq .LEpilogueC8P3 // 3 pixel remains - ldr d3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - ldr x12,[x0],#8 // x12 -> A0 iter 3 - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - ldr d6,[x11,x5] // A1 iter 2 - ld1 {v0.8b},[x1],x4 // filter iter 2 - smlal 
v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - ldr d2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - saddw v16.4s,v16.4s,v12.4h - ldr d3,[x13,x5] // A1 iter 3 - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.8b},[x1],x4 // filter iter 3 - saddw v20.4s,v20.4s,v14.4h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw2 v21.4s,v21.4s,v14.8h - b .LBlockLoopC8 - -.LEpilogueC8P2: - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - ldr d6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - ldr d3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - smull v12.8h,v0.8b,v4.8b - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v14.8h,v0.8b,v6.8b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smlal v12.8h,v1.8b,v2.8b - smlal v14.8h,v1.8b,v3.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoad2C8 - ldp q4,q5,[x12],#32 // load scale vector if per channel -.LSkipScaleVecLoad2C8: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - b .LDequantC8 - -.LProcC8P1: - // - // Channel 8 kernel size 1 - // TODO!! is this reachable at all? - // - ldr x12,[x14],#8 // x12 -> A2 - mov v28.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 - mov v29.16b,v25.16b - ld1 {v0.8b},[x1] - mov v16.16b,v24.16b - ldr d4,[x10,x5] - mov v17.16b,v25.16b - ldr d6,[x11,x5] - mov v20.16b,v24.16b - ldr d2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - ldr d3,[x13,x5] - mov v21.16b,v25.16b - b .LEpilogueC8P1 - -.LEpilogueC8P3: - // - // Loop epilogue (process 2 of last 3 pixels) - // - ldr x12,[x14],#8 // x12 -> A2 iter 2 - ldr d3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ld1 {v0.8b},[x1] // filter iter 2 - ldr d6,[x11,x5] // A1 iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v3.8b - ldr d3,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -.LEpilogueC8P1: - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - smull v12.8h,v0.8b,v2.8b - smull v14.8h,v0.8b,v3.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoadC8 - ldp q4,q5,[x12] // load scale vector if per channel -.LSkipScaleVecLoadC8: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -.LDequantC8: - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf 
v20.4s,v20.4s - scvtf v21.4s,v21.4s - b.ne .LSkipScaleBroadcastC8 - mov v5.16b,v4.16b // broadcast scale val if not per channel -.LSkipScaleBroadcastC8: - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v5.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v5.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v5.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v5.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - dup v0.8h,w15 - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn2 v24.8h,v25.4s - sqxtn v28.4h,v28.4s - sqxtn2 v28.8h,v29.4s - sqxtn v16.4h,v16.4s - sqxtn2 v16.8h,v17.4s - sqxtn v20.4h,v20.4s - sqxtn2 v20.8h,v21.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v28.8h,v28.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqxtn v24.8b,v24.8h // shorten to int8 - sqxtn v28.8b,v28.8h - sqxtn v16.8b,v16.8h - sqxtn v20.8b,v20.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.8b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 2 - st1 {v28.8b},[x2],x4 - b.ls .LExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? - st1 {v16.8b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 4 - str d20,[x2] - b .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymU8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymU8KernelNeon.S deleted file mode 100644 index 05a51253b0c8a..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymU8KernelNeon.S +++ /dev/null @@ -1,744 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DepthwiseQConvSymU8KernelNeon.S - -Abstract: - - This module implements the kernels for the depthwise convolution - operation with symmetrically quantized integer values - ---*/ - -#include "asmmacro.h" - -// -// Stack frame layout for the depthwise conv kernel. -// d8-d15, x19-x30 need to be preserved if used -// - - .equ .LConvSymDepthwiseKernelFrame_SavedNeonRegisters, (8 * 8) - .equ .LConvSymDepthwiseKernelFrame_SavedRegisters, .LConvSymDepthwiseKernelFrame_SavedNeonRegisters - .equ .LConvSymDepthwiseKernelFrame_PostProcessParams, 0 + .LConvSymDepthwiseKernelFrame_SavedRegisters - .equ .LConvSymDepthwiseKernelFrame_KernelFlags, 8 + .LConvSymDepthwiseKernelFrame_SavedRegisters - - .equ .LConvSymDepthwisePostProcessParams_Bias, 0 - .equ .LConvSymDepthwisePostProcessParams_Scale, 8 - .equ .LConvSymDepthwisePostProcessParams_Min, 16 - .equ .LConvSymDepthwisePostProcessParams_Max, 20 - .equ .LConvSymDepthwisePostProcessParams_ZeroPoint, 24 - - .equ MLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 - .equ MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 - - .text - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise convolution for the - elements of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. - - Channels (x4) - Supplies the number of input and output channels. - - ChannelOffset (x5) - Supplies the byte offset from the indirection buffer base - address for this iteration. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. 
- - This implementation requires the count to be 16 or 8 - - OutputCount (x7)- Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymDepthwiseU8KernelNeon - - stp d8,d9,[sp,#-64]! - ldr x8,[sp,#.LConvSymDepthwiseKernelFrame_PostProcessParams] - mov w10,#0x80808080 - stp d10,d11,[sp,#16] - stp d12,d13,[sp,#32] - stp d14,d15,[sp,#48] - dup v8.4s,w10 // bit flip vector - ldr x16,[x8,#.LConvSymDepthwisePostProcessParams_Bias] - cmp x7,2 - add x9,x0,x3,lsl#3 // x9 -> &A1 - add x14,x0,x3,lsl#4 // x14 -> &A2 - add x15,x9,x3,lsl#4 // x15 -> &A3 - csel x9,x0,x9,lo // x9 -> &A0 if OutputCount < 2 - csel x14,x0,x14,ls // x14 -> &A0 if OutputCount <= 2 - ldr x11,[x9],#8 // x11 -> A1 iter 0 - cmp x7,4 - ldp q24,q25,[x16],#32 // init accumulators with bias - csel x15,x0,x15,lo // x15 -> &A0 if OutputCount < 4 - cmp x6,16 - ldr x10,[x0],#8 // x10 -> A0 iter 0 - b.lo .LProcess8Channels - -// -// Process an input block of length Channels for each element of the kernel. -// -// Filter: v0, -// v1 // unroll -// Input: -// x0 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x9 -> x11 -> v6 -// -> x13 -> v10 // unroll -// x14 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x15 -> x11 -> v6 -// -> x13 -> v10 // unroll -// - -.LProcess16Channels: - cmp x3,1 - ldp q26,q27,[x16] - b.eq .LProcC16P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1],x4 // filter iter 0 - ld1 {v1.16b},[x1],x4 // filter iter 1 - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] // A0 iter 0 - mov v20.16b,v24.16b - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v21.16b,v25.16b - ldr q6,[x11,x5] // A1 iter 0 - mov v30.16b,v26.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v31.16b,v27.16b - ldr q2,[x12,x5] // A0 iter 1 - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - mov v19.16b,v27.16b - ldr q10,[x13,x5] // A1 iter 1 - mov v22.16b,v26.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v23.16b,v27.16b - -.LBlockLoopC16: - - // - // Process 2 pixels, and load next two pixels - // - eor v4.16b,v4.16b,v8.16b // fix sign bits - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr q4,[x10,x5] // A2 iter 0 - b.eq .LEpilogueC16P2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x0],#8 // x10 -> A0 iter 2 - smull2 v15.8h,v0.16b,v6.16b - eor v2.16b,v2.16b,v8.16b - cmp x3,1 - ldr q6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - b.eq .LEpilogueC16P3 // 3 pixel remains - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x0],#8 // x12 -> A0 iter 3 - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - eor v4.16b,v4.16b,v8.16b - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr q4,[x10,x5] // A0 iter 2 
- smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A1 iter 2 - eor v2.16b,v2.16b,v8.16b - ld1 {v0.16b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A1 iter 3 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.16b},[x1],x4 // filter iter 3 - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b .LBlockLoopC16 - -.LEpilogueC16P2: - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A3 iter 0 - eor v2.16b,v2.16b,v8.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.16b,v4.16b,v8.16b - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - eor v2.16b,v2.16b,v8.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - smlal v14.8h,v1.8b,v10.8b - smlal2 v15.8h,v1.16b,v10.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoad2 - ldp q4,q11,[x12],#32 // load scale vector if per channel - ldp q6,q9,[x12] -.LSkipScaleVecLoad2: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b .LDequantization - -.LProcC16P1: - // - // Channel 16 kernel size 1 - // TODO!! is this reachable at all? 
- // - ldr x12,[x14],#8 // x12 -> A2 - ldr x13,[x15],#8 // x13 -> A3 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1] - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] - mov v20.16b,v24.16b - mov v21.16b,v25.16b - ldr q6,[x11,x5] - mov v30.16b,v26.16b - mov v31.16b,v27.16b - ldr q2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - mov v19.16b,v27.16b - ldr q10,[x13,x5] - mov v22.16b,v26.16b - mov v23.16b,v27.16b - b .LEpilogueC16P1 - -.LEpilogueC16P3: - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x14],#8 // x12 -> A2 iter 2 - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - eor v4.16b,v4.16b,v8.16b - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ld1 {v0.16b},[x1] // filter iter 2 - ldr q6,[x11,x5] // A1 iter 2 - eor v2.16b,v2.16b,v8.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v10.8b - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -.LEpilogueC16P1: - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.16b,v4.16b,v8.16b - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - eor v2.16b,v2.16b,v8.16b - smull v12.8h,v0.8b,v2.8b - smull2 v13.8h,v0.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - smull v14.8h,v0.8b,v10.8b - smull2 v15.8h,v0.16b,v10.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoad - ldp q4,q11,[x12],#32 // load scale vector if per channel - ldp q6,q9,[x12] -.LSkipScaleVecLoad: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -.LDequantization: - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - b.ne 
.LSkipScaleBroadcast - mov v11.16b,v4.16b // broadcast scale val if not per channel - mov v6.16b,v4.16b - mov v9.16b,v4.16b -.LSkipScaleBroadcast: - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v11.4s - fmul v26.4s,v26.4s,v6.4s - fmul v27.4s,v27.4s,v9.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v11.4s - fmul v30.4s,v30.4s,v6.4s - fmul v31.4s,v31.4s,v9.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v11.4s - fmul v18.4s,v18.4s,v6.4s - fmul v19.4s,v19.4s,v9.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v11.4s - fmul v22.4s,v22.4s,v6.4s - fmul v23.4s,v23.4s,v9.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn v26.4h,v26.4s - sqxtn2 v24.8h,v25.4s - sqxtn2 v26.8h,v27.4s - sqxtn v28.4h,v28.4s - sqxtn v30.4h,v30.4s - sqxtn2 v28.8h,v29.4s - sqxtn2 v30.8h,v31.4s - dup v0.8h,w15 - sqxtn v16.4h,v16.4s - sqxtn v18.4h,v18.4s - sqxtn2 v16.8h,v17.4s - sqxtn2 v18.8h,v19.4s - sqxtn v20.4h,v20.4s - sqxtn v22.4h,v22.4s - sqxtn2 v20.8h,v21.4s - sqxtn2 v22.8h,v23.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v26.8h,v26.8h,v0.8h - sqadd v28.8h,v28.8h,v0.8h - sqadd v30.8h,v30.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v18.8h,v18.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqadd v22.8h,v22.8h,v0.8h - sqxtun v24.8b,v24.8h // shorten to int8 - sqxtun2 v24.16b,v26.8h - sqxtun v28.8b,v28.8h - sqxtun2 v28.16b,v30.8h - sqxtun v16.8b,v16.8h - sqxtun2 v16.16b,v18.8h - sqxtun v20.8b,v20.8h - sqxtun2 v20.16b,v22.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.16b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 2 - st1 {v28.16b},[x2],x4 - b.ls .LExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? 
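This unsigned-output variant requantizes the same way as the S8 kernel, but the final narrowing uses sqxtun/sqxtun2, so despite the "// shorten to int8" comments the values are saturated to the unsigned byte range and written out as uint8. A one-line sketch of that last clamp (the helper name is illustrative only):

    #include <algorithm>
    #include <cstdint>

    // sqxtun: saturating narrow into the unsigned range [0, 255]
    static uint8_t SaturateToU8(int32_t v) {
        return static_cast<uint8_t>(std::clamp<int32_t>(v, 0, 255));
    }
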
- st1 {v16.16b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 4 - str q20,[x2] - -.LExitKernel: - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#64 - ret - -.LProcess8Channels: - cmp x3,1 - b.eq .LProcC8P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - ld1 {v0.8b},[x1],x4 // filter iter 0 - ld1 {v1.8b},[x1],x4 // filter iter 1 - ldr d4,[x10,x5] // A0 iter 0 - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v28.16b,v24.16b - ldr d6,[x11,x5] // A1 iter 0 - mov v29.16b,v25.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v16.16b,v24.16b - ldr d2,[x12,x5] // A0 iter 1 - mov v17.16b,v25.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - subs x3,x3,2 // decrement input blocks remaining - ldr d10,[x13,x5] // A1 iter 1 - mov v20.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v21.16b,v25.16b - -.LBlockLoopC8: - // - // Process 2 pixels, and load next two pixels - // - eor v4.8b,v4.8b,v8.8b // fix sign bits - eor v6.8b,v6.8b,v8.8b - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A2 iter 0 - smull v14.8h,v0.8b,v6.8b - b.eq .LEpilogueC8P2 - ldr x10,[x0],#8 // x10 -> A0 iter 2 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - ldr d6,[x11,x5] // A3 iter 0 - cmp x3,1 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal v14.8h,v1.8b,v10.8b - ldr d2,[x12,x5] // A2 iter 1 - b.eq .LEpilogueC8P3 // 3 pixel remains - ldr d10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - ldr x12,[x0],#8 // x12 -> A0 iter 3 - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - ldr d6,[x11,x5] // A1 iter 2 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - ld1 {v0.8b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - ldr d2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - saddw v16.4s,v16.4s,v12.4h - ldr d10,[x13,x5] // A1 iter 3 - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.8b},[x1],x4 // filter iter 3 - saddw v20.4s,v20.4s,v14.4h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw2 v21.4s,v21.4s,v14.8h - b .LBlockLoopC8 - -.LEpilogueC8P2: - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - ldr d6,[x11,x5] // A3 iter 0 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - ldr d10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - smull v12.8h,v0.8b,v4.8b - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v14.8h,v0.8b,v6.8b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - smlal v12.8h,v1.8b,v2.8b - smlal v14.8h,v1.8b,v10.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoad2C8 - ldp q4,q11,[x12],#32 // load scale vector if per channel -.LSkipScaleVecLoad2C8: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - b .LDequantC8 - 
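The eor instructions tagged "// fix sign bits" in the loops above XOR each unsigned input byte with 0x80 (v8 holds the 0x80808080 bit-flip vector set up at entry). Reinterpreted as a signed byte that is exactly u - 128, which lets this kernel push uint8 activations through the same signed smull/smlal multiplies as the S8 variant; the constant 128 offset it introduces is expected to be folded into the bias/fixup values prepared by the caller. A small sketch of the identity (the helper name is illustrative):

    #include <cstdint>

    // The NEON code computes (u ^ 0x80) and then treats the byte as signed.
    // On two's-complement targets that bit pattern is exactly u - 128.
    static int8_t FlipToSigned(uint8_t u) {
        return static_cast<int8_t>(static_cast<int>(u) - 128);
    }

    // FlipToSigned(0) == -128, FlipToSigned(128) == 0, FlipToSigned(255) == 127
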
-.LProcC8P1: - // - // Channel 8 kernel size 1 - // TODO!! is this reachable at all? - // - ldr x12,[x14],#8 // x12 -> A2 - mov v28.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 - mov v29.16b,v25.16b - ld1 {v0.8b},[x1] - mov v16.16b,v24.16b - ldr d4,[x10,x5] - mov v17.16b,v25.16b - ldr d6,[x11,x5] - mov v20.16b,v24.16b - ldr d2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - ldr d10,[x13,x5] - mov v21.16b,v25.16b - b .LEpilogueC8P1 - -.LEpilogueC8P3: - // - // Loop epilogue (process 2 of last 3 pixels) - // - ldr x12,[x14],#8 // x12 -> A2 iter 2 - ldr d10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ld1 {v0.8b},[x1] // filter iter 2 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - ldr d6,[x11,x5] // A1 iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v10.8b - ldr d10,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -.LEpilogueC8P1: - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - smull v12.8h,v0.8b,v2.8b - smull v14.8h,v0.8b,v10.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq .LSkipScaleVecLoadC8 - ldp q4,q11,[x12] // load scale vector if per channel -.LSkipScaleVecLoadC8: - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -.LDequantC8: - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - b.ne .LSkipScaleBroadcastC8 - mov v11.16b,v4.16b // broadcast scale val if not per channel -.LSkipScaleBroadcastC8: - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v11.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v11.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v11.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v11.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - dup v0.8h,w15 - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn2 v24.8h,v25.4s - sqxtn v28.4h,v28.4s - sqxtn2 v28.8h,v29.4s - sqxtn v16.4h,v16.4s - sqxtn2 v16.8h,v17.4s - sqxtn v20.4h,v20.4s - sqxtn2 v20.8h,v21.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v28.8h,v28.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqxtun v24.8b,v24.8h // shorten to int8 - sqxtun v28.8b,v28.8h - sqxtun v16.8b,v16.8h - sqxtun v20.8b,v20.8h - cmp x7,2 // OutputCount < 2 ? 
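The comparison just above and the two that follow drive the store ladder for this kernel: b.lo after "cmp x7,2" exits after the first store when OutputCount < 2 (unsigned "lower"), b.ls exits after the second when OutputCount <= 2, and the later "cmp x7,4" / b.lo pair exits after the third, so up to four output positions can be written. Given that, the "range 1 to 2" wording in this routine's description (and in the matching S8 variant earlier) appears to understate what the code actually handles; the store ladder itself is the authoritative behavior.
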
- st1 {v24.8b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 2 - st1 {v28.8b},[x2],x4 - b.ls .LExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? - st1 {v16.8b},[x2],x4 - b.lo .LExitKernel // exit if OutputCount < 4 - str d20,[x2] - b .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S deleted file mode 100644 index 036928d21b8ca..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S +++ /dev/null @@ -1,550 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - HalfGemmKernelNeon.s - -Abstract: - - This module implements the kernels for the half precision matrix/matrix - multiply operation (HALF GEMM). - ---*/ - -#include "asmmacro.h" - -// -// Stack frame layout for the half gemm kernel. -// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. -// - .equ .LHGemmKernelFrame_SavedRegs, (2 * 8) - .equ .LHGemmKernelFrame_B, 0 + .LHGemmKernelFrame_SavedRegs - .equ .LHGemmKernelFrame_ldb, 8 + .LHGemmKernelFrame_SavedRegs - .equ .LHGemmKernelFrame_ZeroMode, 16 + .LHGemmKernelFrame_SavedRegs - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute 6 rows of GEMM - -Arguments: - - CountM - (x0) the number of rows for matrix A and matrix C. - only process 6 rows - - CountN - (x1) the number of columns from matrix B and matrix C - - CountK - (x2/x0) the number of columns from matrix A and the - number of rows from matrix B. - - C - (x3) the address of matrix C. - - ldc - (x4) - the first dimension of matrix C. - - Bias - (x5) - the address of the Bias vector (optional) - - A - (x6) - the address of matrix A - - lda - (x7) - the first dimension of matrix A - - B - the address of matrix B - - ldb - the first dimension of matrix B - - ZeroMode - true if the output matrix must be zero initialized, else - if the output matrix is accumulated into - ---*/ - - FUNCTION_ENTRY MlasHalfGemmKernelNeon - - str x19,[sp,#-.LHGemmKernelFrame_SavedRegs]! - ldr x9,[sp,#.LHGemmKernelFrame_ldb] - lsl x2,x2,#1 // k *= sizeof(fp16) - cmp x0,2 - add x14,x6,x7,lsl #1 // a1 = a0 + lda - add x10,x3,x4,lsl #1 // c1 = c0 + ldc - ldr x8,[sp,#.LHGemmKernelFrame_B] - csel x14,x6,x14,LO // M < 2 ? a1 = a0 - csel x10,x3,x10,LO // c1 = c0 - add x15,x14,x7,lsl #1 // a2 = a1 + lda - add x11,x10,x4,lsl #1 // c2 = c1 + ldc - csel x15,x14,x15,LS // M <= 2 ? a2 = a1 - csel x11,x10,x11,LS // c2 = c1 - cmp x0,4 - add x16,x15,x7,lsl #1 // a3 = a2 + lda - add x12,x11,x4,lsl #1 // c3 = c2 + ldc - csel x16,x15,x16,LO // M < 4 ? a3 = a2 - csel x12,x11,x12,LO // c3 = c2 - add x17,x16,x7,lsl #1 // a4 = a3 + lda - add x13,x12,x4,lsl #1 // c4 = c3 + ldc - csel x17,x16,x17,LS // M <= 4 ? a4 = a3 - csel x13,x12,x13,LS // c4 = c3 - cmp x0,6 - add x7,x17,x7,lsl #1 // a5 = a4 + lda - add x4,x13,x4,lsl #1 // c5 = c4 + ldc - csel x7,x17,x7,LO // M < 6 ? a5 = a4 - csel x4,x13,x4,LO // c5 = c4 - lsl x9,x9,#1 // ldb *= sizeof(fp16) - ldrb w19,[sp,#.LHGemmKernelFrame_ZeroMode] - sub x9,x9,16 // ldb -= 16 - -/**** -Main loop processes 6x16 tile, depth 4. 
- B 4x16 - --------------------------------------- - |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 - |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 - |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 - |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 - A 6x4 --------------------------------------- - ------------------ --------------------------------------- -x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 -x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 -x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 -x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 -x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 -x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 - ------------------ --------------------------------------- -****/ - -.LM6N16OutterLoopN: - cbz x5,.LM6N16SkipBias - ldp q20,q21,[x5],32 // Load 16 Bias values - b .LM6N16PopulateAccumulators - -.LM6N16SkipBias: - eor v20.16b,v20.16b,v20.16b // No bias, reset regs - eor v21.16b,v21.16b,v21.16b - -.LM6N16PopulateAccumulators: - mov v22.16b,v20.16b - mov v23.16b,v21.16b - mov v24.16b,v20.16b - mov v25.16b,v21.16b - mov v26.16b,v20.16b - mov v27.16b,v21.16b - mov v28.16b,v20.16b - subs x0,x2,8 // k -= 4 (8 bytes) - mov v29.16b,v21.16b - mov v30.16b,v20.16b - mov v31.16b,v21.16b - b.LO .LM6N16RemainderK123 // remaining k 1~3 - - ldr d0,[x6],8 // A0 - ldr q16,[x8],16 // B0.l - ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row - subs x0,x0,8 // over decement k -= 4 (8 bytes) - ldr d1,[x14],8 // A1 - ldr d2,[x15],8 // A2 - ldr d3,[x16],8 // A3 - b.LO .LM6N16LoopK_Epilogue // need k>=8 for main loop - -.LM6N16InnerLoopK: - fmla v20.8h,v16.8h,v0.h[0] - fmla v21.8h,v17.8h,v0.h[0] - ldr d4,[x17],8 // A4 - fmla v22.8h,v16.8h,v1.h[0] - fmla v23.8h,v17.8h,v1.h[0] - ldr d5,[x7],8 // A5 - fmla v24.8h,v16.8h,v2.h[0] - fmla v25.8h,v17.8h,v2.h[0] - ldr q18,[x8],16 // B1.low - fmla v26.8h,v16.8h,v3.h[0] - fmla v27.8h,v17.8h,v3.h[0] - ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row - fmla v28.8h,v16.8h,v4.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v31.8h,v17.8h,v5.h[0] - subs x0,x0,8 // k -= 4 - - fmla v20.8h,v18.8h,v0.h[1] - fmla v21.8h,v19.8h,v0.h[1] - ldr q16,[x8],16 // B2.low - fmla v22.8h,v18.8h,v1.h[1] - fmla v23.8h,v19.8h,v1.h[1] - ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row - fmla v24.8h,v18.8h,v2.h[1] - fmla v25.8h,v19.8h,v2.h[1] - fmla v26.8h,v18.8h,v3.h[1] - fmla v27.8h,v19.8h,v3.h[1] - fmla v28.8h,v18.8h,v4.h[1] - fmla v29.8h,v19.8h,v4.h[1] - fmla v30.8h,v18.8h,v5.h[1] - fmla v31.8h,v19.8h,v5.h[1] - - fmla v20.8h,v16.8h,v0.h[2] - fmla v21.8h,v17.8h,v0.h[2] - ldr q18,[x8],16 // B3.low - fmla v22.8h,v16.8h,v1.h[2] - fmla v23.8h,v17.8h,v1.h[2] - ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row - fmla v24.8h,v16.8h,v2.h[2] - fmla v25.8h,v17.8h,v2.h[2] - fmla v26.8h,v16.8h,v3.h[2] - fmla v27.8h,v17.8h,v3.h[2] - fmla v28.8h,v16.8h,v4.h[2] - fmla v29.8h,v17.8h,v4.h[2] - fmla v30.8h,v16.8h,v5.h[2] - fmla v31.8h,v17.8h,v5.h[2] - - ldr q16,[x8],16 // Load B0.low for next iter - fmla v20.8h,v18.8h,v0.h[3] - fmla v21.8h,v19.8h,v0.h[3] - ld1 {v17.16b},[x8],x9 // Load B0.high for next iter - fmla v22.8h,v18.8h,v1.h[3] - fmla v23.8h,v19.8h,v1.h[3] - ldr d0,[x6],8 // Load A0 for next iter - fmla v24.8h,v18.8h,v2.h[3] - fmla v25.8h,v19.8h,v2.h[3] - ldr d1,[x14],8 // Load A1 for next iter - fmla v26.8h,v18.8h,v3.h[3] - fmla v27.8h,v19.8h,v3.h[3] - ldr d2,[x15],8 // Load A2 for next iter - fmla v28.8h,v18.8h,v4.h[3] - fmla v29.8h,v19.8h,v4.h[3] - ldr 
d3,[x16],8 // Load A3 for next iter - fmla v30.8h,v18.8h,v5.h[3] - fmla v31.8h,v19.8h,v5.h[3] - b.hs .LM6N16InnerLoopK // k >= 8 for main loop - -.LM6N16LoopK_Epilogue: - // last block of k >= 4, no pre-load for next iter - fmla v20.8h,v16.8h,v0.h[0] - fmla v21.8h,v17.8h,v0.h[0] - ldr d4,[x17],8 // A4 - fmla v22.8h,v16.8h,v1.h[0] - fmla v23.8h,v17.8h,v1.h[0] - ldr d5,[x7],8 // A5 - fmla v24.8h,v16.8h,v2.h[0] - fmla v25.8h,v17.8h,v2.h[0] - ldr q18,[x8],16 // B1.low - fmla v26.8h,v16.8h,v3.h[0] - fmla v27.8h,v17.8h,v3.h[0] - ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row - fmla v28.8h,v16.8h,v4.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v31.8h,v17.8h,v5.h[0] - adds x0,x0,8 // revert k over-decrement - - fmla v20.8h,v18.8h,v0.h[1] - fmla v21.8h,v19.8h,v0.h[1] - ldr q16,[x8],16 // B2.low - fmla v22.8h,v18.8h,v1.h[1] - fmla v23.8h,v19.8h,v1.h[1] - ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row - fmla v24.8h,v18.8h,v2.h[1] - fmla v25.8h,v19.8h,v2.h[1] - fmla v26.8h,v18.8h,v3.h[1] - fmla v27.8h,v19.8h,v3.h[1] - fmla v28.8h,v18.8h,v4.h[1] - fmla v29.8h,v19.8h,v4.h[1] - fmla v30.8h,v18.8h,v5.h[1] - fmla v31.8h,v19.8h,v5.h[1] - - fmla v20.8h,v16.8h,v0.h[2] - fmla v21.8h,v17.8h,v0.h[2] - ldr q18,[x8],16 // B3.low - fmla v22.8h,v16.8h,v1.h[2] - fmla v23.8h,v17.8h,v1.h[2] - ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row - fmla v24.8h,v16.8h,v2.h[2] - fmla v25.8h,v17.8h,v2.h[2] - fmla v26.8h,v16.8h,v3.h[2] - fmla v27.8h,v17.8h,v3.h[2] - fmla v28.8h,v16.8h,v4.h[2] - fmla v29.8h,v17.8h,v4.h[2] - fmla v30.8h,v16.8h,v5.h[2] - fmla v31.8h,v17.8h,v5.h[2] - - fmla v20.8h,v18.8h,v0.h[3] - fmla v21.8h,v19.8h,v0.h[3] - fmla v22.8h,v18.8h,v1.h[3] - fmla v23.8h,v19.8h,v1.h[3] - fmla v24.8h,v18.8h,v2.h[3] - fmla v25.8h,v19.8h,v2.h[3] - fmla v26.8h,v18.8h,v3.h[3] - fmla v27.8h,v19.8h,v3.h[3] - fmla v28.8h,v18.8h,v4.h[3] - fmla v29.8h,v19.8h,v4.h[3] - fmla v30.8h,v18.8h,v5.h[3] - fmla v31.8h,v19.8h,v5.h[3] - b.NE .LM6N16RemainderK123 // remaining k 1~3 - -.LM6N16OutterLoopNTail: - subs x1,x1,16 // N -= 16 - ldr x8,[sp,#.LHGemmKernelFrame_B] - b.LO .LM6StoreRemainderN // remaining N < 16 - - cbnz x19,.LM6N16SkipAccumulateOutput - ldp q0,q1,[x3] - ldp q2,q3,[x10] - ldp q4,q5,[x11] - ldp q6,q7,[x12] - ldp q16,q17,[x13] - ldp q18,q19,[x4] - fadd v20.8h,v20.8h,v0.8h // !ZeroMode - fadd v21.8h,v21.8h,v1.8h // accumulate into C - fadd v22.8h,v22.8h,v2.8h - fadd v23.8h,v23.8h,v3.8h - fadd v24.8h,v24.8h,v4.8h - fadd v25.8h,v25.8h,v5.8h - fadd v26.8h,v26.8h,v6.8h - fadd v27.8h,v27.8h,v7.8h - fadd v28.8h,v28.8h,v16.8h - fadd v29.8h,v29.8h,v17.8h - fadd v30.8h,v30.8h,v18.8h - fadd v31.8h,v31.8h,v19.8h - -.LM6N16SkipAccumulateOutput: - st1 {v20.16b,v21.16b},[x3],32 - sub x6,x6,x2 // restore a0 - st1 {v22.16b,v23.16b},[x10],32 - sub x14,x14,x2 // restore a1 - st1 {v24.16b,v25.16b},[x11],32 - sub x15,x15,x2 // restore a2 - st1 {v26.16b,v27.16b},[x12],32 - sub x16,x16,x2 // restore a3 - st1 {v28.16b,v29.16b},[x13],32 - sub x17,x17,x2 // restore a4 - add x8,x8,32 // B <- next 16 columns - st1 {v30.16b,v31.16b},[x4],32 - sub x7,x7,x2 // restore a5 - str x8,[sp,#.LHGemmKernelFrame_B] - b.HI .LM6N16OutterLoopN - -.LExitKernel: - ldr x19,[sp],#.LHGemmKernelFrame_SavedRegs - ret - -.LM6N16RemainderK123: - tbz x0,2,.LM6N16RemainderK1 - ldr s0,[x6],4 // A0 - ldr q16,[x8],16 // B0.low - ld1 {v17.16b},[x8],x9 // B0.high - ldr s1,[x14],4 // A1 - ldr s2,[x15],4 // A2 - ldr s3,[x16],4 // A3 - ldr s4,[x17],4 // A4 - ldr s5,[x7],4 // A5 - ldr q18,[x8],16 // B1.low - ld1 {v19.16b},[x8],x9 // B2.high - fmla 
v20.8h,v16.8h,v0.h[0] - fmla v22.8h,v16.8h,v1.h[0] - fmla v24.8h,v16.8h,v2.h[0] - fmla v26.8h,v16.8h,v3.h[0] - fmla v28.8h,v16.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v21.8h,v17.8h,v0.h[0] - fmla v23.8h,v17.8h,v1.h[0] - fmla v25.8h,v17.8h,v2.h[0] - fmla v27.8h,v17.8h,v3.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v31.8h,v17.8h,v5.h[0] - - fmla v20.8h,v18.8h,v0.h[1] - fmla v22.8h,v18.8h,v1.h[1] - fmla v24.8h,v18.8h,v2.h[1] - fmla v26.8h,v18.8h,v3.h[1] - fmla v28.8h,v18.8h,v4.h[1] - fmla v30.8h,v18.8h,v5.h[1] - fmla v21.8h,v19.8h,v0.h[1] - fmla v23.8h,v19.8h,v1.h[1] - fmla v25.8h,v19.8h,v2.h[1] - fmla v27.8h,v19.8h,v3.h[1] - fmla v29.8h,v19.8h,v4.h[1] - fmla v31.8h,v19.8h,v5.h[1] - tbz x0,1,.LM6N16OutterLoopNTail - -.LM6N16RemainderK1: - ldr h0,[x6],2 // A0 - ldr q16,[x8],16 // B0.low - ld1 {v17.16b},[x8],x9 // B0.high - ldr h1,[x14],2 // A1 - ldr h2,[x15],2 // A2 - ldr h3,[x16],2 // A3 - ldr h4,[x17],2 // A4 - ldr h5,[x7],2 // A5 - fmla v20.8h,v16.8h,v0.h[0] - fmla v22.8h,v16.8h,v1.h[0] - fmla v24.8h,v16.8h,v2.h[0] - fmla v26.8h,v16.8h,v3.h[0] - fmla v28.8h,v16.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v21.8h,v17.8h,v0.h[0] - fmla v23.8h,v17.8h,v1.h[0] - fmla v25.8h,v17.8h,v2.h[0] - fmla v27.8h,v17.8h,v3.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v31.8h,v17.8h,v5.h[0] - b .LM6N16OutterLoopNTail - -.LM6StoreRemainderN: - cbnz x19,.LM6StoreRemainderNZeroMode - tbz x1,3,.LM6StoreRemainderN4 - ldr q0,[x3] - ldr q1,[x10] - ldr q2,[x11] - ldr q3,[x12] - ldr q4,[x13] - ldr q5,[x4] - fadd v20.8h,v20.8h,v0.8h - fadd v22.8h,v22.8h,v1.8h - fadd v24.8h,v24.8h,v2.8h - str q20,[x3],16 - mov v20.16b,v21.16b - str q22,[x10],16 - mov v22.16b,v23.16b - str q24,[x11],16 - mov v24.16b,v25.16b - fadd v26.8h,v26.8h,v3.8h - fadd v28.8h,v28.8h,v4.8h - fadd v30.8h,v30.8h,v5.8h - str q26,[x12],16 - mov v26.16b,v27.16b - str q28,[x13],16 - mov v28.16b,v29.16b - str q30,[x4],16 - mov v30.16b,v31.16b - -.LM6StoreRemainderN4: - tbz x1,2,.LM6StoreRemainderN2 - ldr d0,[x3] - ldr d1,[x10] - ldr d2,[x11] - ldr d3,[x12] - ldr d4,[x13] - ldr d5,[x4] - fadd v21.4h,v20.4h,v0.4h - dup d20,v20.d[1] - fadd v23.4h,v22.4h,v1.4h - dup d22,v22.d[1] - fadd v25.4h,v24.4h,v2.4h - dup d24,v24.d[1] - fadd v27.4h,v26.4h,v3.4h - dup d26,v26.d[1] - fadd v29.4h,v28.4h,v4.4h - dup d28,v28.d[1] - fadd v31.4h,v30.4h,v5.4h - dup d30,v30.d[1] - str d21,[x3],8 - str d23,[x10],8 - str d25,[x11],8 - str d27,[x12],8 - str d29,[x13],8 - str d31,[x4],8 - -.LM6StoreRemainderN2: - tbz x1,1,.LM6StoreRemainderN1 - ldr s0,[x3] - ldr s1,[x10] - ldr s2,[x11] - ldr s3,[x12] - ldr s4,[x13] - ldr s5,[x4] - fadd v21.4h,v20.4h,v0.4h - fadd v23.4h,v22.4h,v1.4h - fadd v25.4h,v24.4h,v2.4h - fadd v27.4h,v26.4h,v3.4h - fadd v29.4h,v28.4h,v4.4h - fadd v31.4h,v30.4h,v5.4h - str s21,[x3],4 - str s23,[x10],4 - dup s20,v20.s[1] - dup s22,v22.s[1] - str s25,[x11],4 - str s27,[x12],4 - dup s24,v24.s[1] - dup s26,v26.s[1] - str s29,[x13],4 - str s31,[x4],4 - dup s28,v28.s[1] - dup s30,v30.s[1] - -.LM6StoreRemainderN1: - tbz x1,0,.LExitKernel - ldr h0,[x3] - ldr h1,[x10] - ldr h2,[x11] - ldr h3,[x12] - ldr h4,[x13] - ldr h5,[x4] - fadd v20.4h,v20.4h,v0.4h - fadd v22.4h,v22.4h,v1.4h - fadd v24.4h,v24.4h,v2.4h - fadd v26.4h,v26.4h,v3.4h - fadd v28.4h,v28.4h,v4.4h - fadd v30.4h,v30.4h,v5.4h - str h20,[x3] - str h22,[x10] - str h24,[x11] - str h26,[x12] - str h28,[x13] - str h30,[x4] - b .LExitKernel - -.LM6StoreRemainderNZeroMode: - tbz x1,3,.LM6StoreRemainderN4ZeroMode - str q20,[x3],16 - mov v20.16b,v21.16b - str q22,[x10],16 - mov v22.16b,v23.16b - str q24,[x11],16 - mov 
v24.16b,v25.16b - str q26,[x12],16 - mov v26.16b,v27.16b - str q28,[x13],16 - mov v28.16b,v29.16b - str q30,[x4],16 - mov v30.16b,v31.16b - -.LM6StoreRemainderN4ZeroMode: - tbz x1,2,.LM6StoreRemainderN2ZeroMode - str d20,[x3],8 - str d22,[x10],8 - dup d20,v20.d[1] - dup d22,v22.d[1] - str d24,[x11],8 - str d26,[x12],8 - dup d24,v24.d[1] - dup d26,v26.d[1] - str d28,[x13],8 - str d30,[x4],8 - dup d28,v28.d[1] - dup d30,v30.d[1] - -.LM6StoreRemainderN2ZeroMode: - tbz x1,1,.LM6StoreRemainderN1ZeroMode - str s20,[x3],4 - str s22,[x10],4 - dup s20,v20.s[1] - dup s22,v22.s[1] - str s24,[x11],4 - str s26,[x12],4 - dup s24,v24.s[1] - dup s26,v26.s[1] - str s28,[x13],4 - str s30,[x4],4 - dup s28,v28.s[1] - dup s30,v30.s[1] - -.LM6StoreRemainderN1ZeroMode: - tbz x1,0,.LExitKernel - str h20,[x3] - str h22,[x10] - str h24,[x11] - str h26,[x12] - str h28,[x13] - str h30,[x4] - b .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelNeon.S deleted file mode 100644 index 92a94286f09ff..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelNeon.S +++ /dev/null @@ -1,691 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmS8S8KernelNeon.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "asmmacro.h" - -// -// Stack frame layout for the S8S8 kernel. -// - - .equ .LGemmS8S8KernelFrame_SavedNeonRegisters, (8 * 8) - .equ .LGemmS8S8KernelFrame_SavedRegisters, .LGemmS8S8KernelFrame_SavedNeonRegisters - .equ .LGemmS8S8KernelFrame_ColumnSumBuffer, 0 + .LGemmS8S8KernelFrame_SavedRegisters - .equ .LGemmS8S8KernelFrame_ZeroPointB, 8 + .LGemmS8S8KernelFrame_SavedRegisters - .equ .LGemmS8S8KernelFrame_ZeroMode, 16 + .LGemmS8S8KernelFrame_SavedRegisters - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. 
- - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmS8S8KernelNeon - - stp d8,d9,[sp,#-64]! - stp d10,d11,[sp,#16] - stp d12,d13,[sp,#32] - stp d14,d15,[sp,#48] - ldr x8,[sp,#.LGemmS8S8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#.LGemmS8S8KernelFrame_ZeroPointB] - ldrb w13,[sp,#.LGemmS8S8KernelFrame_ZeroMode] - mov x14,x0 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq .LGemmS8S8.M1.ProcessLoop - cmp x4,#4 // CountM < 4? - blo .LGemmS8S8.M2.ProcessLoop - -// -// Process 4 rows of the matrices. -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v8.b[0] v9.b[0] v10.b[0] v11.b[0]| -// | ... ... ... ... | -// |v8.b[7] v9.b[7] v10.b[7] v11.b[7]| -// A 4x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v24.4s v25.4s v26.4s v27.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v28.4s v29.4s v30.4s v31.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. (v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// -.LGemmS8S8.M4.ProcessNextColumnLoop: - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#64 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d4,d8,[x1],#64 // B - movi v18.4s,#0 - movi v19.4s,#0 - ldp d5,d9,[x1,#-48] - movi v20.4s,#0 - movi v21.4s,#0 - ldp d6,d10,[x1,#-32] - movi v22.4s,#0 - movi v23.4s,#0 - ldp d7,d11,[x1,#-16] - movi v24.4s,#0 - movi v25.4s,#0 - ldp d1,d3,[x0,#-48] - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - -.LGemmS8S8.M4.ComputeBlockLoop: - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x0,#-32] - sadalp v16.4s,v12.8h - sadalp v17.4s,v13.8h - sadalp v18.4s,v14.8h - sadalp v19.4s,v15.8h - sub x3,x3,#1 - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - ldp d1,d3,[x0,#-16] - sadalp v20.4s,v12.8h - sadalp v21.4s,v13.8h - sadalp v22.4s,v14.8h - sadalp v23.4s,v15.8h - cbz x3,.LGemmS8S8.M4.ComputeBlockLoopFinish - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x0],#64 - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - ldp d4,d8,[x1],#64 // B - smlal v13.8h,v3.8b,v9.8b - ldp d5,d9,[x1,#-48] - smlal v14.8h,v3.8b,v10.8b - ldp d6,d10,[x1,#-32] - smlal v15.8h,v3.8b,v11.8b - ldp d7,d11,[x1,#-16] - sadalp v28.4s,v12.8h - ldp 
d1,d3,[x0,#-48] - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - b .LGemmS8S8.M4.ComputeBlockLoop - -.LGemmS8S8.M4.ComputeBlockLoopFinish: - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - ld1 {v0.4s},[x7] - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - addp v24.4s,v24.4s,v25.4s - addp v26.4s,v26.4s,v27.4s - addp v28.4s,v28.4s,v29.4s - addp v30.4s,v30.4s,v31.4s - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - dup v8.4s,v0.s[0] // broadcast row fixups - dup v9.4s,v0.s[1] - dup v10.4s,v0.s[2] - dup v11.4s,v0.s[3] - cbz x9,.LGemmS8S8.M4.SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v30.4s},[x9],#16 // load ZeroPointB - mul v17.4s,v30.4s,v8.4s - mul v21.4s,v30.4s,v9.4s - mul v25.4s,v30.4s,v10.4s - mul v29.4s,v30.4s,v11.4s - add v16.4s,v16.4s,v17.4s - add v20.4s,v20.4s,v21.4s - add v24.4s,v24.4s,v25.4s - add v28.4s,v28.4s,v29.4s - add v16.4s,v16.4s,v2.4s - add v20.4s,v20.4s,v2.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v2.4s - b .LGemmS8S8.M4.StoreOutput - -.LGemmS8S8.M4.SkipScaleByZeroPointB: - // accumulator = row sum A + column sum B - add v16.4s,v16.4s,v8.4s - add v20.4s,v20.4s,v9.4s - add v24.4s,v24.4s,v10.4s - add v28.4s,v28.4s,v11.4s - add v16.4s,v16.4s,v2.4s - add v20.4s,v20.4s,v2.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v2.4s - -.LGemmS8S8.M4.StoreOutput: - add x10,x2,x6,lsl #2 - add x11,x10,x6,lsl #2 - add x12,x11,x6,lsl #2 - subs x5,x5,#4 // adjust CountN remaining - blo .LGemmS8S8.M4.StoreOutputPartial - cbnz x13,.LGemmS8S8.M4.SkipAccumulateOutput - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v3.4s - -.LGemmS8S8.M4.SkipAccumulateOutput: - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - st1 {v24.4s},[x11] - st1 {v28.4s},[x12] - cbnz x5,.LGemmS8S8.M4.ProcessNextColumnLoop - -.LGemmS8S8.M4.ExitKernel: - mov x0,#4 // return number of rows handled - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#64 - ret - -.LGemmS8S8.M4.StoreOutputPartial: - cbz x13,.LGemmS8S8.M4.StoreOutputPartial.AddMode - -.LGemmS8S8.M4.StoreOutputPartial.ZeroMode: - tbz x5,#1,.LGemmS8S8.M4.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - st1 {v24.2s},[x11],#8 - dup v24.4s,v24.s[2] - st1 {v28.2s},[x12],#8 - dup v28.4s,v28.s[2] - -.LGemmS8S8.M4.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmS8S8.M4.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - st1 {v24.s}[0],[x11] - st1 {v28.s}[0],[x12] - b .LGemmS8S8.M4.ExitKernel - -.LGemmS8S8.M4.StoreOutputPartial.AddMode: - tbz x5,#1,.LGemmS8S8.M4.StoreOutputPartial1.AddMode - ld1 
{v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v3.4s - st1 {v24.2s},[x11],#8 - dup v24.4s,v24.s[2] - st1 {v28.2s},[x12],#8 - dup v28.4s,v28.s[2] - -.LGemmS8S8.M4.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmS8S8.M4.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v20.4s,v20.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v24.4s,v24.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - add v28.4s,v28.4s,v3.4s - st1 {v24.s}[0],[x11] - st1 {v28.s}[0],[x12] - b .LGemmS8S8.M4.ExitKernel - -// -// Process 2 rows of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// Each row sum replicated to all 4 elements of a vector register -// v30 v31 -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 2x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. (v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) - -.LGemmS8S8.M2.ProcessLoop: - -.LGemmS8S8.M2.ProcessNextColumnLoop: - ldp d4,d24,[x1],#16 // B - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d5,d25,[x1],#16 - movi v18.4s,#0 - movi v19.4s,#0 - ldp d6,d26,[x1],#16 - movi v20.4s,#0 - movi v21.4s,#0 - ldp d7,d27,[x1],#16 - movi v22.4s,#0 - movi v23.4s,#0 - ldp d1,d3,[x0],#16 // A1 - -.LGemmS8S8.M2.ComputeBlockLoop: - - sub x3,x3,#1 - smull v28.8h,v0.8b,v4.8b - smull v29.8h,v0.8b,v5.8b - smull v30.8h,v0.8b,v6.8b - smull v31.8h,v0.8b,v7.8b - cbz x3,.LGemmS8S8.M2.ComputeBlockLoopFinish - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // A0 - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v29.8h,v3.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v30.8h,v3.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v31.8h,v3.8b,v27.8b - ldp d7,d27,[x1],#16 - sadalp v20.4s,v28.8h - ldp d1,d3,[x0],#16 // A1 - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - b .LGemmS8S8.M2.ComputeBlockLoop - -.LGemmS8S8.M2.ComputeBlockLoopFinish: - ld1 {v0.4s},[x8],#16 // load ColumnSumBuffer[0] - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - ldr d2,[x7] // load row sums - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - smlal v29.8h,v3.8b,v25.8b - smlal 
v30.8h,v3.8b,v26.8b - smlal v31.8h,v3.8b,v27.8b - sadalp v20.4s,v28.8h - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - dup v30.4s,v2.s[0] // broadcast row fixups - dup v31.4s,v2.s[1] // broadcast row fixups - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - cbz x9,.LGemmS8S8.M2.SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v18.4s},[x9],#16 // load ZeroPointB[0] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v0.4s - mul v17.4s,v18.4s,v30.4s - mul v21.4s,v18.4s,v31.4s - add v16.4s,v16.4s,v17.4s - add v20.4s,v20.4s,v21.4s - b .LGemmS8S8.M2.StoreOutput - -.LGemmS8S8.M2.SkipScaleByZeroPointB: - // accumulator = row sum A + column sum B - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v0.4s - add v16.4s,v16.4s,v30.4s - add v20.4s,v20.4s,v31.4s - -.LGemmS8S8.M2.StoreOutput: - add x10,x2,x6,lsl #2 - subs x5,x5,#4 // adjust CountN remaining - blo .LGemmS8S8.M2.StoreOutputPartial - cbnz x13,.LGemmS8S8.M2.SkipAccumulateOutput - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - -.LGemmS8S8.M2.SkipAccumulateOutput: - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - cbnz x5,.LGemmS8S8.M2.ProcessNextColumnLoop - -.LGemmS8S8.M2.ExitKernel: - mov x0,#2 // return number of rows handled - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#64 - ret - -.LGemmS8S8.M2.StoreOutputPartial: - cbz x13,.LGemmS8S8.M2.StoreOutputPartial.AddMode - -.LGemmS8S8.M2.StoreOutputPartial.ZeroMode: - tbz x5,#1,.LGemmS8S8.M2.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - -.LGemmS8S8.M2.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmS8S8.M2.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - b .LGemmS8S8.M2.ExitKernel - -.LGemmS8S8.M2.StoreOutputPartial.AddMode: - tbz x5,#1,.LGemmS8S8.M2.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - -.LGemmS8S8.M2.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmS8S8.M2.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - b .LGemmS8S8.M2.ExitKernel - -// -// Process 1 row of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// row sum replicated to all 4 elements of a vector register -// v31 -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 1x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// -.LGemmS8S8.M1.ProcessLoop: - ldr d31,[x7] - dup v31.4s,v31.s[0] // broadcast row fixups - -.LGemmS8S8.M1.ProcessNextColumnLoop: - ldp d4,d24,[x1],#16 // B - ldp d5,d25,[x1],#16 - ldp d6,d26,[x1],#16 - ldp d7,d27,[x1],#16 - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - -.LGemmS8S8.M1.ComputeBlockLoop: - sub x3,x3,#1 - smull v20.8h,v0.8b,v4.8b - smull v21.8h,v0.8b,v5.8b - cbz x3,.LGemmS8S8.M1.ComputeBlockLoopFinish - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v21.8h,v2.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v22.8h,v2.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v23.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // A0 - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - ldp d7,d27,[x1],#16 - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - b .LGemmS8S8.M1.ComputeBlockLoop - -.LGemmS8S8.M1.ComputeBlockLoopFinish: - ld1 {v4.4s},[x8],#16 // load ColumnSumBuffer[0] - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - smlal v21.8h,v2.8b,v25.8b - smlal v22.8h,v2.8b,v26.8b - smlal v23.8h,v2.8b,v27.8b - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v16.4s,v16.4s,v18.4s - cbz x9,.LGemmS8S8.M1.SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - mul v17.4s,v30.4s,v31.4s - add v16.4s,v16.4s,v17.4s - add v16.4s,v16.4s,v4.4s - b .LGemmS8S8.M1.StoreOutput -.LGemmS8S8.M1.SkipScaleByZeroPointB: - // accumulator = row sum A + column sum B - add v16.4s,v16.4s,v31.4s - add v16.4s,v16.4s,v4.4s - -.LGemmS8S8.M1.StoreOutput: - subs x5,x5,#4 // adjust CountN remaining - blo .LGemmS8S8.M1.StoreOutputPartial - cbnz x13,.LGemmS8S8.M1.SkipAccumulateOutput - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - -.LGemmS8S8.M1.SkipAccumulateOutput: - st1 {v16.4s},[x2],#16 - cbnz x5,.LGemmS8S8.M1.ProcessNextColumnLoop - -.LGemmS8S8.M1.ExitKernel: - mov x0,#1 // return number of rows handled - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#64 - ret - -.LGemmS8S8.M1.StoreOutputPartial: - cbz x13,.LGemmS8S8.M1.StoreOutputPartial.AddMode - -.LGemmS8S8.M1.StoreOutputPartial.ZeroMode: - tbz x5,#1,.LGemmS8S8.M1.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmS8S8.M1.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmS8S8.M1.ExitKernel - st1 {v16.s}[0],[x2] - b .LGemmS8S8.M1.ExitKernel - -.LGemmS8S8.M1.StoreOutputPartial.AddMode: - tbz x5,#1,.LGemmS8S8.M1.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmS8S8.M1.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmS8S8.M1.ExitKernel - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b .LGemmS8S8.M1.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSdot.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSdot.S deleted file mode 100644 index 22632d4e794d5..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSdot.S +++ /dev/null @@ -1,1056 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
- -Module Name: - - QgemmS8S8KernelSdot.S - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses ARM v8.4 dot product instructions. - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout for the S8S8 kernel. -// Defining spaces for saving 2 vector registers, and pointers to parameters -// on the stack -// - - .equ .LGemmS8S8KernelFrame_SavedNeonRegisters, (2 * 8) - .equ .LGemmS8S8KernelFrame_SavedRegisters, .LGemmS8S8KernelFrame_SavedNeonRegisters - .equ .LGemmS8S8KernelFrame_ColumnSumBuffer, 0 + .LGemmS8S8KernelFrame_SavedRegisters - .equ .LGemmS8S8KernelFrame_ZeroPointB, 8 + .LGemmS8S8KernelFrame_SavedRegisters - .equ .LGemmS8S8KernelFrame_ZeroMode, 16 + .LGemmS8S8KernelFrame_SavedRegisters - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmS8S8KernelSDot - - stp d8,d9,[sp,#-16]! - ldr x8,[sp,#.LGemmS8S8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#.LGemmS8S8KernelFrame_ZeroPointB] - ldrb w13,[sp,#.LGemmS8S8KernelFrame_ZeroMode] - mov x14,x0 - ld1 {v8.4s},[x7],#16 // load row sum 1 ~ 4 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq .LGemmS8S8.M1.ProcessLoop - cmp x4,#4 // CountM < 4? - blo .LGemmS8S8.M2.ProcessLoop - cmp x4,#8 // CountM < 8? - blo .LGemmS8S8.M4.ProcessNextColumnLoop - ld1 {v9.4s},[x7] // load row sum 5 ~ 8 - -// -// Process 8 rows of the matrices. -// Row Sums: v8 ~ v9 -// B 4x8 block -// ----------------------------------------- -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// ----------------------------------------- -// A 8x4 block -// --------------------- ----------------------------------------- -// |v4.b[0] ... v4.b[3] | |v16.s[0] .. 
v16.s[3] v17.s[0] .. v17.s[3]| -// |v4.b[4] ... v4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v4.b[8] ... v4.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v4.b[12] ... v4.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v5.b[0] ... v5.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v5.b[4] ... v5.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v5.b[8] ... v5.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v5.b[12] ... v5.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. v31.s[3]| -// --------------------- ----------------------------------------- -// -// unroll for the next 4 in k dimension -// ----------------------------------------- -// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| -// | ... ... | -// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| -// ----------------------------------------- -// --------------------- ----------------------------------------- -// |v6.b[0] ... v6.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |v6.b[4] ... v6.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v6.b[8] ... v6.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v6.b[12] ... v6.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v7.b[0] ... v7.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v7.b[4] ... v7.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v7.b[8] ... v7.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v7.b[12] ... v7.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. v31.s[3]| -// --------------------- ----------------------------------------- - -// Starting the loop: initialize accumulators with scaled combination -// of row and column sums - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - dup v25.4s,v9.s[0] - dup v27.4s,v9.s[1] - dup v29.4s,v9.s[2] - dup v31.4s,v9.s[3] - -.LGemmS8S8.M8.ProcessNextColumnLoop: - mov x0,x14 // reload matrix A - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v7.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,.LGemmS8S8.M8.SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v0.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v0.4s,v17.4s - mul v18.4s,v0.4s,v19.4s - ld1 {v1.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v0.4s,v21.4s - mul v22.4s,v0.4s,v23.4s - mul v24.4s,v0.4s,v25.4s - mul v26.4s,v0.4s,v27.4s - mul v28.4s,v0.4s,v29.4s - mul v30.4s,v0.4s,v31.4s - mul v17.4s,v1.4s,v17.4s - mul v19.4s,v1.4s,v19.4s - mul v21.4s,v1.4s,v21.4s - mul v23.4s,v1.4s,v23.4s - mul v25.4s,v1.4s,v25.4s - mul v27.4s,v1.4s,v27.4s - mul v29.4s,v1.4s,v29.4s - mul v31.4s,v1.4s,v31.4s - - // preloading mixed with accumulator inits - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v16.4s - add v18.4s,v3.4s,v18.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v20.4s - add v22.4s,v3.4s,v22.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v24.4s - add v26.4s,v3.4s,v26.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v28.4s - add v30.4s,v3.4s,v30.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - b .LGemmS8S8.M8.ComputeBlockLoop - -.LGemmS8S8.M8.SkipScaleByZeroPointB: - // accumulator = row sum A + column sum B - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v17.4s 
- add v18.4s,v3.4s,v19.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v21.4s - add v22.4s,v3.4s,v23.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v25.4s - add v26.4s,v3.4s,v27.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v29.4s - add v30.4s,v3.4s,v31.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - -.LGemmS8S8.M8.ComputeBlockLoop: - sub x3,x3,#1 - ld1 {v3.16b},[x1],#16 // load packed B1_next4k - SdotByElement 16, 0, 4, 0 - SdotByElement 18, 0, 4, 1 - ldr q7,[x0],#16 // load packed A3 - SdotByElement 20, 0, 4, 2 - SdotByElement 22, 0, 4, 3 - cbz x3,.LGemmS8S8.M8.ComputeBlockLoopFinish - SdotByElement 17, 1, 4, 0 - SdotByElement 19, 1, 4, 1 - SdotByElement 21, 1, 4, 2 - SdotByElement 23, 1, 4, 3 - ldr q4,[x0],#16 // load packed A0 for next iteration - SdotByElement 24, 0, 5, 0 - SdotByElement 26, 0, 5, 1 - SdotByElement 28, 0, 5, 2 - SdotByElement 30, 0, 5, 3 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - SdotByElement 25, 1, 5, 0 - SdotByElement 27, 1, 5, 1 - SdotByElement 29, 1, 5, 2 - SdotByElement 31, 1, 5, 3 - ld1 {v1.16b},[x1],#16 // load packed B1 for next iteration - - SdotByElement 16, 2, 6, 0 - SdotByElement 18, 2, 6, 1 - ldr q5,[x0],#16 // load packed A1 for next iteration - SdotByElement 20, 2, 6, 2 - SdotByElement 22, 2, 6, 3 - SdotByElement 17, 3, 6, 0 - SdotByElement 19, 3, 6, 1 - SdotByElement 21, 3, 6, 2 - SdotByElement 23, 3, 6, 3 - ldr q6,[x0],#16 // load packed A2 for next iteration - SdotByElement 24, 2, 7, 0 - SdotByElement 26, 2, 7, 1 - SdotByElement 28, 2, 7, 2 - SdotByElement 30, 2, 7, 3 - ld1 {v2.16b},[x1],#16 // load packed B0_next4k for next iteration - SdotByElement 25, 3, 7, 0 - SdotByElement 27, 3, 7, 1 - SdotByElement 29, 3, 7, 2 - SdotByElement 31, 3, 7, 3 - b .LGemmS8S8.M8.ComputeBlockLoop - -.LGemmS8S8.M8.ComputeBlockLoopFinish: - // postfix, compute tail values and prepare to write results - // We are either about to go to ProcessNextColumnLoopM8 - // where x0 and x3 are about to be restored, or exit - // when x0 and x3 will not be used. 
- // x4 x7 has finished their task - // so we can use x0 x3 x4 x7 as output row pointers - - SdotByElement 17, 1, 4, 0 - SdotByElement 19, 1, 4, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add x11,x10,x6,lsl #2 // compute output row 3 - SdotByElement 21, 1, 4, 2 - SdotByElement 23, 1, 4, 3 - add x12,x11,x6,lsl #2 // compute output row 4 - add x0,x12,x6,lsl #2 // compute output row 5 - SdotByElement 24, 0, 5, 0 - SdotByElement 26, 0, 5, 1 - add x3,x0,x6,lsl #2 // compute output row 6 - add x4,x3,x6,lsl #2 // compute output row 7 - SdotByElement 28, 0, 5, 2 - SdotByElement 30, 0, 5, 3 - add x7,x4,x6,lsl #2 // compute output row 8 - subs x5,x5,#8 // adjust CountN remaining - SdotByElement 25, 1, 5, 0 - SdotByElement 27, 1, 5, 1 - SdotByElement 29, 1, 5, 2 - SdotByElement 31, 1, 5, 3 - - SdotByElement 16, 2, 6, 0 - SdotByElement 18, 2, 6, 1 - SdotByElement 20, 2, 6, 2 - SdotByElement 22, 2, 6, 3 - SdotByElement 17, 3, 6, 0 - SdotByElement 19, 3, 6, 1 - SdotByElement 21, 3, 6, 2 - SdotByElement 23, 3, 6, 3 - SdotByElement 24, 2, 7, 0 - SdotByElement 26, 2, 7, 1 - SdotByElement 28, 2, 7, 2 - SdotByElement 30, 2, 7, 3 - SdotByElement 25, 3, 7, 0 - SdotByElement 27, 3, 7, 1 - SdotByElement 29, 3, 7, 2 - SdotByElement 31, 3, 7, 3 - blo .LGemmS8S8.M8.StoreOutputPartial - cbnz x13,.LGemmS8S8.M8.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - ldp q0, q1, [x0] - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - ldp q2, q3, [x3] - add v24.4s,v24.4s,v0.4s - add v25.4s,v25.4s,v1.4s - ldp q4, q5, [x4] - add v26.4s,v26.4s,v2.4s - add v27.4s,v27.4s,v3.4s - ldp q6, q7, [x7] - add v28.4s,v28.4s,v4.4s - add v29.4s,v29.4s,v5.4s - add v30.4s,v30.4s,v6.4s - add v31.4s,v31.4s,v7.4s - -.LGemmS8S8.M8.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - dup v17.4s,v8.s[0] // broadcast row sums - stp q18,q19,[x10] - dup v19.4s,v8.s[1] - stp q20,q21,[x11] - dup v21.4s,v8.s[2] - stp q22,q23,[x12] - dup v23.4s,v8.s[3] - stp q24,q25,[x0] - dup v25.4s,v9.s[0] - stp q26,q27,[x3] - dup v27.4s,v9.s[1] - stp q28,q29,[x4] - dup v29.4s,v9.s[2] - stp q30,q31,[x7] - dup v31.4s,v9.s[3] - - cbnz x5,.LGemmS8S8.M8.ProcessNextColumnLoop - -.LGemmS8S8.M8.ExitKernel: - mov x0,#8 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
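The accumulator setup above implements the fixup spelled out in the routine header: each output element starts from ColumnSumBuffer plus RowSumBuffer (the row sum scaled by the per-column ZeroPointB when one is supplied), the raw int8 dot products are added on top, and ZeroMode selects between overwriting C and accumulating into it. A minimal scalar sketch of that contract, with illustrative names and without the packed A/B layouts the real kernel consumes:

    #include <cstddef>
    #include <cstdint>

    // Illustrative scalar model of the QGEMM inner-kernel contract described in the
    // routine header above. A is CountM x K int8 and B is K x CountN int8 in plain
    // row-major form here; the real kernel reads both through MLAS packing layouts.
    // RowSums, ColumnSums and the optional per-column ZeroPointB carry the
    // quantization fixups precomputed by the packing routines.
    static void QGemmKernelScalarSketch(const int8_t* A, const int8_t* B, int32_t* C,
                                        size_t K, size_t CountM, size_t CountN,
                                        size_t ldc, const int32_t* RowSums,
                                        const int32_t* ColumnSums,
                                        const int32_t* ZeroPointB, bool ZeroMode) {
        for (size_t m = 0; m < CountM; ++m) {
            for (size_t n = 0; n < CountN; ++n) {
                // Start from the fixups: the row sum is pre-scaled when ZeroPointB is
                // per-tensor (nullptr here); otherwise scale it by the per-column value.
                int32_t acc = ColumnSums[n] +
                              (ZeroPointB ? ZeroPointB[n] * RowSums[m] : RowSums[m]);
                for (size_t k = 0; k < K; ++k) {
                    acc += int32_t(A[m * K + k]) * int32_t(B[k * CountN + n]);
                }
                // ZeroMode: overwrite C; otherwise accumulate into the existing C.
                C[m * ldc + n] = ZeroMode ? acc : C[m * ldc + n] + acc;
            }
        }
    }

Which terms carry the negated zero points is decided when those buffers are precomputed; the kernel itself only ever multiplies and adds them.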
-// - -.LGemmS8S8.M8.StoreOutputPartial: - cbz x13,.LGemmS8S8.M8.StoreOutputPartialAddMode - -.LGemmS8S8.M8.StoreOutputPartialZeroMode: - tbz x5,#2,.LGemmS8S8.M8.StoreOutputPartial2ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -.LGemmS8S8.M8.StoreOutputPartial2ZeroMode: - tbz x5,#1,.LGemmS8S8.M8.StoreOutputPartial1ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -.LGemmS8S8.M8.StoreOutputPartial1ZeroMode: - tbz x5,#0,.LGemmS8S8.M8.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - st1 {v24.s}[0],[x0] - st1 {v26.s}[0],[x3] - st1 {v28.s}[0],[x4] - st1 {v30.s}[0],[x7] - b .LGemmS8S8.M8.ExitKernel - -.LGemmS8S8.M8.StoreOutputPartialAddMode: - tbz x5,#2,.LGemmS8S8.M8.StoreOutputPartial2AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - ld1 {v4.4s},[x0] - ld1 {v5.4s},[x3] - ld1 {v6.4s},[x4] - ld1 {v7.4s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -.LGemmS8S8.M8.StoreOutputPartial2AddMode: - tbz x5,#1,.LGemmS8S8.M8.StoreOutputPartial1AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - ld1 {v4.2s},[x0] - ld1 {v5.2s},[x3] - ld1 {v6.2s},[x4] - ld1 {v7.2s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -.LGemmS8S8.M8.StoreOutputPartial1AddMode: - tbz x5,#0,.LGemmS8S8.M8.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - ld1 {v4.s}[0],[x0] - ld1 {v5.s}[0],[x3] - ld1 {v6.s}[0],[x4] - ld1 {v7.s}[0],[x7] - 
add v24.4s,v24.4s,v4.4s - st1 {v24.s}[0],[x0] - add v26.4s,v26.4s,v5.4s - st1 {v26.s}[0],[x3] - add v28.4s,v28.4s,v6.4s - st1 {v28.s}[0],[x4] - add v30.4s,v30.4s,v7.4s - st1 {v30.s}[0],[x7] - b .LGemmS8S8.M8.ExitKernel - - -// -// Process 4 rows of the matrices. -// -// -// The packing layout is setup to have a pair of four quad vectors from -// packed matrix A and a pair of eight quad vectors from packed matrix B. -// With this scheme, alternating loads from the packed matrices can be -// interleaved with the dot product instructions. -// -// One negative consequence of using four rows here is that the accumulator -// register tile is too small for processors with high out of order execution -// windows (such as the Apple M1). The dot product instructions for a given -// cell are too close to each other to avoid dependencies. To workaround this, -// the below loop uses a pair of accumulator registers that are then added -// together when the loop finishes. -// -// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for -// packed matrix A. At the time of this implementation, using a wider 128-bit -// load did not affect performance for higher end cores. -// -// B 4x8 block -// ----------------------------------------- -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// ----------------------------------------- -// A 4x4 block -// --------------------- ----------------------------------------- -// |d4.b[0] ... d4.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |d4.b[4] ... d4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |d5.b[0] ... d5.b[3] | |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |d5.b[4] ... d5.b[7] | |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// --------------------- ----------------------------------------- -// -// unroll for the next 4 in k dimension -// ----------------------------------------- -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// ----------------------------------------- -// --------------------- ----------------------------------------- -// |d6.b[0] ... d6.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |d6.b[4] ... d6.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |d7.b[0] ... d7.b[3] | |v28.s[0] .. v24.s[3] v29.s[0] .. v29.s[3]| -// |d7.b[4] ... d7.b[7] | |v30.s[0] .. v24.s[3] v31.s[0] .. 
v31.s[3]| -// --------------------- ----------------------------------------- - -.LGemmS8S8.M4.ProcessNextColumnLoop: - ld1 {v0.16b},[x1],#16 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - cbz x9,.LGemmS8S8.M4.SkipScaleByZeroPointB - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v30.4s,v17.4s - mul v18.4s,v30.4s,v19.4s - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v30.4s,v21.4s - mul v22.4s,v30.4s,v23.4s - mul v17.4s,v31.4s,v17.4s - mul v19.4s,v31.4s,v19.4s - mul v21.4s,v31.4s,v21.4s - mul v23.4s,v31.4s,v23.4s - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - add v20.4s,v2.4s,v20.4s - add v22.4s,v2.4s,v22.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - b .LGemmS8S8.M4.ComputeBlockLoopStart - -.LGemmS8S8.M4.SkipScaleByZeroPointB: - add v16.4s,v2.4s,v17.4s - add v18.4s,v2.4s,v19.4s - add v20.4s,v2.4s,v21.4s - add v22.4s,v2.4s,v23.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - -.LGemmS8S8.M4.ComputeBlockLoopStart: - ldr d4,[x0],#32 // load packed A0.l - movi v24.4s,#0 - movi v25.4s,#0 - ldur d5,[x0,#-24] // load packed A0.h - movi v26.4s,#0 - movi v27.4s,#0 - ldur d6,[x0,#-16] // load packed A1.l - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - -.LGemmS8S8.M4.ComputeBlockLoop: - ld1 {v1.16b},[x1],#16 // load packed B1 - SdotByElement 16, 0, 4, 0 - SdotByElement 18, 0, 4, 1 - ldur d7,[x0,#-8] // load packed A1.h - SdotByElement 20, 0, 5, 0 - SdotByElement 22, 0, 5, 1 - ld1 {v0.16b},[x1],#16 // load packed B0_next4k - SdotByElement 17, 1, 4, 0 - SdotByElement 19, 1, 4, 1 - sub x3,x3,#1 - cbz x3,.LGemmS8S8.M4.ComputeBlockLoopFinish - ldr d4,[x0],#32 // load packed A0.l for next iteration - SdotByElement 21, 1, 5, 0 - SdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - SdotByElement 24, 0, 6, 0 - SdotByElement 26, 0, 6, 1 - ldur d5,[x0,#-24] // load packed A0.h for next iteration - SdotByElement 28, 0, 7, 0 - SdotByElement 30, 0, 7, 1 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - SdotByElement 25, 1, 6, 0 - SdotByElement 27, 1, 6, 1 - ldur d6,[x0,#-16] // load packed A1.l for next iteration - SdotByElement 29, 1, 7, 0 - SdotByElement 31, 1, 7, 1 - b .LGemmS8S8.M4.ComputeBlockLoop - -.LGemmS8S8.M4.ComputeBlockLoopFinish: - SdotByElement 21, 1, 5, 0 - SdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - SdotByElement 24, 0, 6, 0 - SdotByElement 26, 0, 6, 1 - SdotByElement 28, 0, 7, 0 - SdotByElement 30, 0, 7, 1 - SdotByElement 25, 1, 6, 0 - SdotByElement 27, 1, 6, 1 - SdotByElement 29, 1, 7, 0 - SdotByElement 31, 1, 7, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add v16.4s,v16.4s,v24.4s // fold high results into low results - add v18.4s,v18.4s,v26.4s - add v20.4s,v20.4s,v28.4s - add v22.4s,v22.4s,v30.4s - add x11,x10,x6,lsl #2 // compute output row 3 - add v17.4s,v17.4s,v25.4s - add v19.4s,v19.4s,v27.4s - add v21.4s,v21.4s,v29.4s - add v23.4s,v23.4s,v31.4s - add x12,x11,x6,lsl #2 // compute output row 4 - subs x5,x5,#8 // adjust CountN remaining - blo .LGemmS8S8.M4.StoreOutputPartial - cbnz x13,.LGemmS8S8.M4.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - 
ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - -.LGemmS8S8.M4.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - stp q20,q21,[x11] - stp q22,q23,[x12] - cbnz x5,.LGemmS8S8.M4.ProcessNextColumnLoop - -.LGemmS8S8.M4.ExitKernel: - mov x0,#4 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmS8S8.M4.StoreOutputPartial: - cbz x13,.LGemmS8S8.M4.StoreOutputPartial.AddMode - -.LGemmS8S8.M4.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmS8S8.M4.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -.LGemmS8S8.M4.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmS8S8.M4.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -.LGemmS8S8.M4.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmS8S8.M4.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b .LGemmS8S8.M4.ExitKernel - -.LGemmS8S8.M4.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmS8S8.M4.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -.LGemmS8S8.M4.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmS8S8.M4.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -.LGemmS8S8.M4.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmS8S8.M4.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b .LGemmS8S8.M4.ExitKernel - -// -// Process 2 rows of the matrices. 
-// -.LGemmS8S8.M2.ProcessLoop: - dup v9.4s, v8.s[1] - dup v8.4s, v8.s[0] - -.LGemmS8S8.M2.ProcessNextColumnLoop: - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,.LGemmS8S8.M2.SkipScaleByZeroPointB - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v16.4s,v30.4s,v8.4s - mul v18.4s,v30.4s,v9.4s - mul v17.4s,v31.4s,v8.4s - mul v19.4s,v31.4s,v9.4s - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - b .LGemmS8S8.M2.ComputeBlockLoop - -.LGemmS8S8.M2.SkipScaleByZeroPointB: - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v8.4s - add v18.4s,v2.4s,v9.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v8.4s - add v19.4s,v3.4s,v9.4s - -.LGemmS8S8.M2.ComputeBlockLoop: - sub x3,x3,#1 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - SdotByElement 16, 0, 4, 0 - SdotByElement 17, 1, 4, 0 - SdotByElement 18, 0, 4, 1 - SdotByElement 19, 1, 4, 1 - cbz x3,.LGemmS8S8.M2.ComputeBlockLoopFinish - ldr d4,[x0],#8 // load packed A0.l for next iter - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - SdotByElement 16, 6, 5, 0 - SdotByElement 17, 7, 5, 0 - SdotByElement 18, 6, 5, 1 - SdotByElement 19, 7, 5, 1 - ldr d5,[x0],#8 // load packed A0.h for next iter - b .LGemmS8S8.M2.ComputeBlockLoop - -.LGemmS8S8.M2.ComputeBlockLoopFinish: - add x10,x2,x6,lsl #2 // compute output row 2 - subs x5,x5,#8 // adjust CountN remaining - SdotByElement 16, 6, 5, 0 - SdotByElement 17, 7, 5, 0 - SdotByElement 18, 6, 5, 1 - SdotByElement 19, 7, 5, 1 - blo .LGemmS8S8.M2.StoreOutputPartial - cbnz x13,.LGemmS8S8.M2.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - -.LGemmS8S8.M2.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - cbnz x5,.LGemmS8S8.M2.ProcessNextColumnLoop - -.LGemmS8S8.M2.ExitKernel: - mov x0,#2 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
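The partial-store paths that follow drain the 1 to 7 leftover columns by testing the low bits of the remaining column count: bit 2 stores a group of four, bit 1 a pair, bit 0 a single element, with the surviving accumulator lanes shifted down between stores (in the assembly the count has already gone negative, but its low three bits still equal the leftover count). A scalar sketch of the cascade, illustrative only:

    #include <cstdint>
    #include <cstring>

    // Illustrative model of the 1..7 leftover-column store: each set bit of the
    // leftover count drains a power-of-two group of accumulator lanes, and the
    // surviving lanes are shifted down before the next test. ZeroMode overwrites
    // the output; otherwise the existing contents of C are accumulated into.
    static void StorePartialColumns(int32_t* c, const int32_t acc[8],
                                    unsigned leftover, bool zeroMode) {
        int32_t lanes[8];
        std::memcpy(lanes, acc, sizeof(lanes));      // up to 8 accumulated columns
        if (leftover & 4) {                          // tbz x5,#2 path: store 4
            for (int i = 0; i < 4; ++i)
                c[i] = zeroMode ? lanes[i] : c[i] + lanes[i];
            c += 4;
            std::memcpy(lanes, lanes + 4, 4 * sizeof(int32_t));  // shift lanes down
        }
        if (leftover & 2) {                          // tbz x5,#1 path: store 2
            for (int i = 0; i < 2; ++i)
                c[i] = zeroMode ? lanes[i] : c[i] + lanes[i];
            c += 2;
            std::memcpy(lanes, lanes + 2, 2 * sizeof(int32_t));
        }
        if (leftover & 1) {                          // tbz x5,#0 path: store 1
            c[0] = zeroMode ? lanes[0] : c[0] + lanes[0];
        }
    }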
-// - -.LGemmS8S8.M2.StoreOutputPartial: - cbz x13,.LGemmS8S8.M2.StoreOutputPartial.AddMode - -.LGemmS8S8.M2.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmS8S8.M2.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -.LGemmS8S8.M2.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmS8S8.M2.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -.LGemmS8S8.M2.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmS8S8.M2.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b .LGemmS8S8.M2.ExitKernel - -.LGemmS8S8.M2.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmS8S8.M2.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -.LGemmS8S8.M2.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmS8S8.M2.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -.LGemmS8S8.M2.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmS8S8.M2.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b .LGemmS8S8.M2.ExitKernel - -// -// Process 1 row of the matrices. -// -.LGemmS8S8.M1.ProcessLoop: - - dup v8.4s,v8.s[0] - -.LGemmS8S8.M1.ProcessNextColumnLoop: - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - cbz x9,.LGemmS8S8.M1.SkipScaleByZeroPointB - ld1 {v30.4s},[x9],#16 // load ZeroPointB0 - ld1 {v31.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v30.4s,v8.4s - mul v17.4s,v31.4s,v8.4s - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - b .LGemmS8S8.M1.ComputeBlockLoop - -.LGemmS8S8.M1.SkipScaleByZeroPointB: - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v8.4s - add v17.4s,v3.4s,v8.4s - -.LGemmS8S8.M1.ComputeBlockLoop: - sub x3,x3,#1 - SdotByElement 16, 0, 4, 0 - SdotByElement 17, 1, 4, 0 - cbz x3,.LGemmS8S8.M1.ComputeBlockLoopFinish - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - SdotByElement 16, 6, 4, 1 - SdotByElement 17, 7, 4, 1 - ldr d4,[x0],#8 // load packed A0 for next iter - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k for next iter - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k for next iter - b .LGemmS8S8.M1.ComputeBlockLoop - -.LGemmS8S8.M1.ComputeBlockLoopFinish: - subs x5,x5,#8 // adjust CountN remaining - SdotByElement 16, 6, 4, 1 - SdotByElement 17, 7, 4, 1 - blo .LGemmS8S8.M1.StoreOutputPartial - cbnz x13,.LGemmS8S8.M1.SkipAccumulateOutput - ldp q0,q1,[x2] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - -.LGemmS8S8.M1.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - cbnz x5,.LGemmS8S8.M1.ProcessNextColumnLoop - 
-.LGemmS8S8.M1.ExitKernel: - mov x0,#1 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmS8S8.M1.StoreOutputPartial: - cbz x13,.LGemmS8S8.M1.StoreOutputPartial.AddMode - -.LGemmS8S8.M1.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmS8S8.M1.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -.LGemmS8S8.M1.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmS8S8.M1.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmS8S8.M1.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmS8S8.M1.ExitKernel - st1 {v16.s}[0],[x2] - b .LGemmS8S8.M1.ExitKernel - -.LGemmS8S8.M1.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmS8S8.M1.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -.LGemmS8S8.M1.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmS8S8.M1.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmS8S8.M1.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmS8S8.M1.ExitKernel - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b .LGemmS8S8.M1.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S deleted file mode 100644 index e18846c89030e..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S +++ /dev/null @@ -1,922 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmS8S8KernelSmmla.s - -Abstract: - - This module implements the kernels for the Int8 precision matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "asmmacro.h" - - .text - -// -// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save -// - .equ .LMlasQgemmKernel_backup_x19_x20, 0 - .equ .LMlasQgemmKernel_backup_x21_x22, 16 - .equ .LMlasQgemmKernel_backup_x23_x24, 32 - .equ .LMlasQgemmKernel_backup_x25_x26, 48 - .equ .LMlasQgemmKernel_backup_x27_x28, 64 - .equ .LMlasQgemmKernel_backup_d8_d9, 80 - .equ .LMlasQgemmKernel_backup_d10_d11, 96 - .equ .LMlasQgemmKernel_backup_d12_d13, 112 - .equ .LMlasQgemmKernel_backup_d14_d15, 128 - .equ .LMlasQgemmKernel_SavedRegisters, 144 - .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 - - -// -// Init Row Accumulators -// -// Generates the code to initialize the accumulators for a single row of the output -// block. 
-// -// -// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum -// x7 for RowSumsBuffer pointer -// x10 for ColumnSumBuffer pointer -// x11 for ZeroPointB buffer pointer -// -// v12~v13 for RowSums values -// v14~v15 for ColumnSums values -// v0~v3 for ZeroPointB values -// - .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg - - mul v7.4s, v\RowSumReg\().4s, v8.4s - mov v\Vec1Reg\().16b, v7.16b - add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s -.if \Columns\() > 2 - mul v7.4s, v\RowSumReg\().4s, v9.4s - mov v\Vec2Reg\().16b, v7.16b - add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s -.endif -.if \Columns\() > 4 - mul v7.4s, v\RowSumReg\().4s, v10.4s - mov v\Vec3Reg\().16b, v7.16b - add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s -.endif -.if \Columns\() > 6 - mul v7.4s, v\RowSumReg\().4s, v11.4s - mov v\Vec4Reg\().16b, v7.16b - add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s -.endif - - .endm - - -// -// InitBlockAccumulators -// -// Generates the code to initialize the accumulators for 8x8 output -// block. -// - .macro InitBlockAccumulators Mode, Columns, Rows - - ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] -.if \Columns\() > 4 - ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] -.endif - // v4~v7 will be set to matrixB after this, so, they can used now - dup v4.4s,v14.s[0] // broadcast column - dup v5.4s,v14.s[1] - dup v6.4s,v14.s[2] - dup v7.4s,v14.s[3] - - zip1 v0.4s, v4.4s, v5.4s - zip2 v1.4s, v6.4s, v7.4s -.if \Columns\() > 4 - dup v4.4s,v15.s[0] // broadcast column - dup v5.4s,v15.s[1] - dup v6.4s,v15.s[2] - dup v7.4s,v15.s[3] - - zip1 v2.4s, v4.4s, v5.4s - zip2 v3.4s, v6.4s, v7.4s -.endif - - // v8~v11 will anyway get set in MatrixA loading, so they are free to use now - movi v8.4s, #1 - movi v9.4s, #1 - movi v10.4s, #1 - movi v11.4s, #1 - - cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB - - ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] - ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] - - dup v6.4s, v4.s[0] - dup v7.4s, v4.s[1] - zip1 v8.4s, v6.4s, v7.4s - - dup v6.4s, v4.s[2] - dup v7.4s, v4.s[3] - zip1 v9.4s, v6.4s, v7.4s - - dup v6.4s, v5.s[0] - dup v7.4s, v5.s[1] - zip1 v10.4s, v6.4s, v7.4s - - dup v6.4s, v5.s[2] - dup v7.4s, v5.s[3] - zip1 v11.4s, v6.4s, v7.4s - -.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: - dup v4.4s, v12.s[0] //boardcast RowSums - dup v5.4s, v12.s[1] - - uzp1 v6.2d, v4.2d, v5.2d - - InitRowAccumulators \Columns\(),16,17,18,19,6 -.if \Rows\() > 2 - dup v4.4s, v12.s[2] //boardcast RowSums - dup v5.4s, v12.s[3] - - uzp1 v6.2d, v4.2d, v5.2d - - InitRowAccumulators \Columns\(),20,21,22,23,6 -.endif -.if \Rows\() > 4 - dup v4.4s,v13.s[0] // broadcast row sums - dup v5.4s,v13.s[1] - - uzp1 v6.2d, v4.2d, v5.2d - - InitRowAccumulators \Columns\(),24,25,26,27,6 -.endif -.if \Rows\() > 6 - dup v4.4s,v13.s[2] // broadcast row sums - dup v5.4s,v13.s[3] - - uzp1 v6.2d, v4.2d, v5.2d - InitRowAccumulators \Columns\(),28,29,30,31,6 -.endif - - .endm - - -// LoadPackedMatrixABy16Elements -// -// Generates the code to load 16 elements from matrix A. -// - .macro LoadPackedMatrixABy16Elements Rows -.if \Rows\() == 1 - ldr q8,[x0],#8 -.else - ldr q8,[x0],#16 - -.if \Rows\() > 2 - ldr q9,[x0],#16 -.endif - -.if \Rows\() > 4 - ldr q10,[x0],#16 -.endif - -.if \Rows\() > 6 - ldr q11,[x0],#16 -.endif -.endif - .endm - - -// -// MultiplyAccumulateRow -// -// Generates the code to multiply and accumulate a single row of the output -// block. 
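MultiplyAccumulateRow expands to smmla instructions; each smmla treats its two 16-byte operands as 2x8 int8 matrices and accumulates their 2x2 int32 product (the A block times the transpose of the B block) into the destination register. A scalar sketch of a single step, illustrative only:

    #include <cstdint>

    // Scalar model of one smmla step: vd.4s += vn.16b (2x8) * vm.16b (2x8)^T.
    // The destination holds a 2x2 int32 tile stored row-major:
    //   d = [ r0.c0, r0.c1, r1.c0, r1.c1 ]
    // where r0/r1 are the two 8-byte rows of packed A in vn and c0/c1 are the two
    // 8-byte groups of vm, i.e. two columns of B laid out contiguously by the packer.
    static void SmmlaScalarSketch(int32_t d[4], const int8_t a[16], const int8_t b[16]) {
        for (int row = 0; row < 2; ++row) {
            for (int col = 0; col < 2; ++col) {
                int32_t acc = 0;
                for (int k = 0; k < 8; ++k) {
                    acc += int32_t(a[8 * row + k]) * int32_t(b[8 * col + k]);
                }
                d[2 * row + col] += acc;
            }
        }
    }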
-// - - .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - - smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b -.if \Columns\() > 2 - smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b -.endif -.if \Columns\() > 4 - smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b -.endif -.if \Columns\() > 6 - smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b -.endif - - .endm - -// -// MultiplyAccumulateBlock -// -// Generates the code to multiply and accumulate into the output block. -// - - .macro MultiplyAccumulateBlock Columns, Rows - - MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 -.if \Rows\() > 2 - MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 -.endif -.if \Rows\() > 4 - MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 -.endif -.if \Rows\() > 6 - MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 -.endif - - .endm - -// -// ComputeBlockLoop -// -// Generates the code to loop over K entries of the input matrices to produce -// the output block. -// - - .macro ComputeBlockLoop Mode, Columns, Rows - - InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() - - sub x9,x3,#1 // block count to process - tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: - - LoadPackedMatrixABy16Elements \Rows\() - ld1 {v4.16b - v7.16b}, [x1], #64 - MultiplyAccumulateBlock \Columns\(),\Rows\() - - sub x9,x9,#1 - tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop -.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: - add x9,x9,#1 // correct for over-subtract above - cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: - LoadPackedMatrixABy16Elements \Rows\() - ld1 {v4.16b - v7.16b}, [x1], #64 - MultiplyAccumulateBlock \Columns\(),\Rows\() - -.L\Mode\().Output\Columns\().x\Rows\().Block: - - .endm - - -// -// OutputRow2Element -// OutputRow4Element -// OutputRow6Element -// OutputRow8Element -// OutputRow10Element -// OutputRow12Element -// OutputRow14Element -// OutputRow16Element -// -// Generates the code to store elements to the output block. 
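Because each accumulator holds a 2x2 tile rather than a slice of a single row, the OutputRow macros below first de-interleave pairs of accumulators with uzp1/uzp2 on their 64-bit halves, so that a full row of four columns lands in one vector before it is stored or added into C. A sketch of that reshuffle, illustrative only:

    #include <cstdint>

    // Two adjacent accumulators hold 2x2 tiles for the same pair of output rows:
    //   vec1 = [ r0c0, r0c1, r1c0, r1c1 ]   (columns 0..1)
    //   vec2 = [ r0c2, r0c3, r1c2, r1c3 ]   (columns 2..3)
    // uzp1 on the 64-bit halves gathers the even halves (row 0), uzp2 the odd
    // halves (row 1), yielding row-major vectors ready to store.
    static void DeinterleaveTiles(const int32_t vec1[4], const int32_t vec2[4],
                                  int32_t row0[4], int32_t row1[4]) {
        row0[0] = vec1[0]; row0[1] = vec1[1];   // uzp1: low 64 bits of vec1
        row0[2] = vec2[0]; row0[3] = vec2[1];   //       low 64 bits of vec2
        row1[0] = vec1[2]; row1[1] = vec1[3];   // uzp2: high 64 bits of vec1
        row1[2] = vec2[2]; row1[3] = vec2[3];   //       high 64 bits of vec2
    }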
-// - - .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr s8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr s9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - mov v8.S[2], v9.S[0] - add v8.4s,v8.4s,v\Vec1Reg\().4s - - mov w27, v8.S[0] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov w27, v8.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - mov w27, v\Vec1Reg\().S[0] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov w27, v\Vec1Reg\().S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.endif - - .endm - - - .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr d8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr d9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - - mov v8.D[1], v9.D[0] - - add v8.4s,v8.4s,v\Vec1Reg\().4s - - mov x27, v8.D[0] - mov x28, v8.D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.else - mov x27, v\Vec1Reg\().D[0] - mov x28, v\Vec1Reg\().D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.endif - - .endm - - - .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr d8,[\AddrReg1\()],#8 - ldr w28,[\AddrReg1\()],#-8 - mov v8.S[2], w28 -.if \last_row\() == 0 - ldr d9,[\AddrReg2\()],#8 - ldr w27,[\AddrReg2\()],#-8 - mov v9.S[2], w27 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - mov x27, v8.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v8.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov x27, v9.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v9.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - mov x27, v4.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v4.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov x27, v5.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v5.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.endif - - .endm - - - .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - -.endif - - .endm - - - .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr w28, [\AddrReg1\()],#-16 - -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr w27,[\AddrReg2\()],#-16 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - str 
q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - mov v8.S[0], w28 - mov v8.S[2], w27 - - add v8.4s,v8.4s,v\Vec3Reg\().4s - - mov w27, v8.S[0] - mov w28, v8.S[2] - - str w27, [\AddrReg1\()],#4 -.if \last_row\() == 0 - str w28, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - mov w27, v\Vec3Reg\().S[0] - mov w28, v\Vec3Reg\().S[2] - - str w27, [\AddrReg1\()],#4 -.if \last_row\() == 0 - str w28, [\AddrReg2\()],#4 -.endif -.endif - -.endm - - - .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr d10,[\AddrReg1\()],#-16 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr d11,[\AddrReg2\()],#-16 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - mov v11.D[0],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - - mov v10.D[1], v11.D[0] - - add v10.4s,v10.4s,v\Vec3Reg\().4s - - mov x27, v10.D[0] - mov x28, v10.D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - mov x27, v\Vec3Reg\().D[0] - mov x28, v\Vec3Reg\().D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif -.endif - - .endm - - .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr d10,[\AddrReg1\()],#8 - ldr w28, [\AddrReg1\()],#-24 - mov v10.S[2], w28 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr d11,[\AddrReg2\()],#8 - ldr w27,[\AddrReg2\()],#-24 - mov v11.S[2], w27 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - - mov v11.D[0],x27 - mov v11.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - add v10.4s,v10.4s,v6.4s - add v11.4s,v11.4s,v7.4s - - str q8,[\AddrReg1\()],#16 - - mov x27, v10.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v10.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 - mov x27, v11.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v11.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - str q4,[\AddrReg1\()],#16 - mov x27, v6.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v6.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 - mov x27, v7.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v7.S[2] - str w27, [\AddrReg2\()],#4 -.endif -.endif - - .endm - - - .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldp q8,q10,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldp q9,q11,[\AddrReg2\()],#0 -.else - mov x27,#0 - 
mov v9.D[0],x27 - mov v9.D[1],x27 - - mov v11.D[0],x27 - mov v11.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - add v10.4s,v10.4s,v6.4s - add v11.4s,v11.4s,v7.4s - - stp q8,q10,[\AddrReg1\()],#32 -.if \last_row\() == 0 - stp q9,q11,[\AddrReg2\()],#32 -.endif -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - stp q4,q6,[\AddrReg1\()],#32 -.if \last_row\() == 0 - stp q5,q7,[\AddrReg2\()],#32 -.endif -.endif - - .endm - -// -// OutputBlock -// -// Generates the code to store the output block. -// - - .macro OutputBlock Mode, Columns, Rows - - OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) - -.if \Rows\() > 2 - OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) -.endif - -.if \Rows\() > 4 - OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) -.endif - -.if \Rows\() > 6 - OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) -.endif - - .endm -// -// ProcessRows -// -// Generates the code to process a compute and store the output block for a -// fixed number of rows. -// - - .macro ProcessRows Mode, Rows - mov x4,#\Rows\() // return number of rows handled - cmp x5,#6 - ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() - -.L\Mode\().ProcessNextColumnLoop8x\Rows\(): - ComputeBlockLoop \Mode\(),8,\Rows\() - - sub x5,x5,#8 - cmp x5,#0 - blt .L\Mode\().Output14ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),16,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#6 - bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() - cbz x5,.L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop6x\Rows\(): - - cmp x5,#4 - ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() - ComputeBlockLoop \Mode\(),6,\Rows\() - sub x5,x5,#6 - cmp x5,#0 - blt .L\Mode\().Output10ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),12,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#4 - bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() - b .L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop4x\Rows\(): - cmp x5,#2 - ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() - ComputeBlockLoop \Mode\(),4,\Rows\() - sub x5,x5,#4 - cmp x5,#0 - blt .L\Mode\().Output6ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),8,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#2 - bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() - b .L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop2x\Rows\(): - ComputeBlockLoop \Mode\(),2,\Rows\() - sub x5,x5,#2 - cmp x5,#0 - blt .L\Mode\().Output2ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),4,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#2 - b .L\Mode\().ExitKernel - -.L\Mode\().Output14ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),14,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output10ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),10,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output6ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),6,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output2ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),2,\Rows\() - b .L\Mode\().ExitKernel - - .endm - - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. 
The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .macro QgemmS8S8KernelSmmlaFunction Mode - - FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() - - ldr x10,[sp, #0] - ldr x11,[sp,#8] - - stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! - stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] - stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] - stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] - stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] - stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] - stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] - stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] - stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] - - add x13,x2,x6,lsl #2 // compute matrix C plus 1 row - add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows - add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows - add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows - add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows - add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows - add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows - - mov x8,x0 // save matrix A - -// -// Process 8 rows of the matrices. -// - ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 - cmp x4,#8 - blt .L\Mode\().ProcessCountMLessThan8 - ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 - ProcessRows \Mode\(),8 - -// -// Restore non-volatile registers and return. -// - -.L\Mode\().ExitKernel: - mov x0,x4 - - ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] - ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] - ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] - ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] - ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] - ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] - ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] - ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] - ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters - - ret - -// -// Process 4 rows of the matrix. 
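The RowSumBuffer/ColumnSumBuffer/ZeroPointB arguments described in the routine header above exist because a zero-point-adjusted GEMM can be rewritten as a raw integer GEMM plus per-row and per-column corrections. A hedged scalar sketch of that identity (function and parameter names are illustrative; the real kernels fold the sign conventions and the K*za*zb term into the packing and sum buffers):

    #include <cstddef>
    #include <cstdint>

    // Reference identity: sum_k (A[m][k]-za)*(B[k][n]-zb)
    //   = sum_k A*B  -  zb * rowSumA[m]  -  za * colSumB[n]  +  K*za*zb.
    // The packed kernels accumulate only the raw products in the inner loop and
    // seed/correct the accumulators from the row and column sum buffers; exact
    // sign handling is arranged by the packing routines, so this is only a model.
    int32_t QuantGemmElement(const uint8_t* A, const uint8_t* B,
                             size_t K, size_t ldb,
                             size_t m, size_t n,
                             int32_t za, int32_t zb) {
        int32_t rawDot = 0, rowSumA = 0, colSumB = 0;
        for (size_t k = 0; k < K; ++k) {
            rawDot += int32_t(A[m * K + k]) * int32_t(B[k * ldb + n]);
            rowSumA += A[m * K + k];
            colSumB += B[k * ldb + n];
        }
        return rawDot - zb * rowSumA - za * colSumB + int32_t(K) * za * zb;
    }

This is why the assembly below never touches the zero points directly inside its compute loops.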
-// - -.L\Mode\().ProcessCountMLessThan8: - cmp x4,#4 - blt .L\Mode\().ProcessCountMLessThan4 - ProcessRows \Mode\(),4 - b .L\Mode\().ExitKernel - -// -// Process 2 row of the matrix. -// - -.L\Mode\().ProcessCountMLessThan4: - cmp x4,#2 - blt .L\Mode\().ProcessCountMLessThan2 - - ProcessRows \Mode\(),2 - b .L\Mode\().ExitKernel - - -// -// Process the last row of the matrix. -// - -.L\Mode\().ProcessCountMLessThan2: - ProcessRows \Mode\(),1 - b .L\Mode\().ExitKernel - - - .endm - - QgemmS8S8KernelSmmlaFunction Zero - QgemmS8S8KernelSmmlaFunction Add - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S deleted file mode 100644 index 2f084764ceb09..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelNeon.S +++ /dev/null @@ -1,600 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelNeon.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "asmmacro.h" - -// -// Stack frame layout for the U8X8 kernel. -// - - .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0 - .equ .LGemmU8X8KernelFrame_ZeroPointB, 8 - .equ .LGemmU8X8KernelFrame_ZeroMode, 16 - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmU8X8KernelNeon - - ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] - ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode] - mov x14,x0 - ld1 {v27.4s},[x7] - mov x15,x3 - dup v24.4s,v27.s[0] // broadcast row fixups - cmp x4,#1 // CountM == 1? - beq .LGemmU8X8.M1.ProcessNextColumnLoop - dup v25.4s,v27.s[1] - cmp x4,#4 // CountM < 4? 
- blo .LGemmU8X8.M2.ProcessNextColumnLoop - dup v26.4s,v27.s[2] - dup v27.4s,v27.s[3] - -// -// Process 4 rows of the matrices. -// - -.LGemmU8X8.M4.ProcessNextColumnLoop: - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB - ld1 {v28.4s},[x9],#16 // load ZeroPointB0 - ld1 {v29.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v24.4s,v28.4s - mul v17.4s,v24.4s,v29.4s - mul v18.4s,v25.4s,v28.4s - mul v19.4s,v25.4s,v29.4s - mul v20.4s,v26.4s,v28.4s - mul v21.4s,v26.4s,v29.4s - mul v22.4s,v27.4s,v28.4s - mul v23.4s,v27.4s,v29.4s - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - add v18.4s,v2.4s,v18.4s - add v19.4s,v3.4s,v19.4s - ld1 {v5.8b},[x0],#8 // load first packed A1 - add v20.4s,v2.4s,v20.4s - add v21.4s,v3.4s,v21.4s - add v22.4s,v2.4s,v22.4s - add v23.4s,v3.4s,v23.4s - b .LGemmU8X8.M4.ComputeBlockLoop - -.LGemmU8X8.M4.SkipScaleByZeroPointB: - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v24.4s - add v17.4s,v3.4s,v24.4s - add v18.4s,v2.4s,v25.4s - add v19.4s,v3.4s,v25.4s - ld1 {v5.8b},[x0],#8 // load first packed A1 - add v20.4s,v2.4s,v26.4s - add v21.4s,v3.4s,v26.4s - add v22.4s,v2.4s,v27.4s - add v23.4s,v3.4s,v27.4s - -.LGemmU8X8.M4.ComputeBlockLoop: - uxtl v2.8h,v4.8b - uxtl v3.8h,v5.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v16.4s,v0.4h,v2.h[0] - umlal2 v17.4s,v0.8h,v2.h[0] - umlal v18.4s,v0.4h,v2.h[4] - umlal2 v19.4s,v0.8h,v2.h[4] - uxtl v1.8h,v1.8b - umlal v20.4s,v0.4h,v3.h[0] - umlal2 v21.4s,v0.8h,v3.h[0] - umlal v22.4s,v0.4h,v3.h[4] - umlal2 v23.4s,v0.8h,v3.h[4] - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v16.4s,v1.4h,v2.h[1] - umlal2 v17.4s,v1.8h,v2.h[1] - umlal v18.4s,v1.4h,v2.h[5] - umlal2 v19.4s,v1.8h,v2.h[5] - uxtl v0.8h,v0.8b - umlal v20.4s,v1.4h,v3.h[1] - umlal2 v21.4s,v1.8h,v3.h[1] - umlal v22.4s,v1.4h,v3.h[5] - umlal2 v23.4s,v1.8h,v3.h[5] - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,.LGemmU8X8.M4.ComputeBlockLoopFinish - umlal v16.4s,v0.4h,v2.h[2] - umlal2 v17.4s,v0.8h,v2.h[2] - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - ld1 {v4.8b},[x0],#8 // load next packed A0 - umlal v20.4s,v0.4h,v3.h[2] - umlal2 v21.4s,v0.8h,v3.h[2] - umlal v22.4s,v0.4h,v3.h[6] - umlal2 v23.4s,v0.8h,v3.h[6] - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - uxtl v0.8h,v0.8b - ld1 {v5.8b},[x0],#8 // load next packed A1 - umlal v20.4s,v1.4h,v3.h[3] - umlal2 v21.4s,v1.8h,v3.h[3] - umlal v22.4s,v1.4h,v3.h[7] - umlal2 v23.4s,v1.8h,v3.h[7] - b .LGemmU8X8.M4.ComputeBlockLoop - -.LGemmU8X8.M4.ComputeBlockLoopFinish: - umlal v16.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v17.4s,v0.8h,v2.h[2] - add x10,x2,x6,lsl #2 // compute output row 2 - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - umlal v20.4s,v0.4h,v3.h[2] - umlal2 v21.4s,v0.8h,v3.h[2] - umlal v22.4s,v0.4h,v3.h[6] - umlal2 v23.4s,v0.8h,v3.h[6] - add x11,x10,x6,lsl #2 // compute output row 3 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - umlal v20.4s,v1.4h,v3.h[3] - umlal2 v21.4s,v1.8h,v3.h[3] - add x12,x11,x6,lsl #2 // compute output row 4 - umlal v22.4s,v1.4h,v3.h[7] - umlal2 
v23.4s,v1.8h,v3.h[7] - subs x5,x5,#8 // adjust CountN remaining - blo .LGemmU8X8.M4.StoreOutputPartial - cbnz x13,.LGemmU8X8.M4.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - -.LGemmU8X8.M4.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - stp q20,q21,[x11] - stp q22,q23,[x12] - cbnz x5,.LGemmU8X8.M4.ProcessNextColumnLoop - -.LGemmU8X8.M4.ExitKernel: - mov x0,#4 // return number of rows handled - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmU8X8.M4.StoreOutputPartial: - cbz x13,.LGemmU8X8.M4.StoreOutputPartial.AddMode - -.LGemmU8X8.M4.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmU8X8.M4.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b .LGemmU8X8.M4.ExitKernel - -.LGemmU8X8.M4.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -.LGemmU8X8.M4.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -.LGemmU8X8.M4.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmU8X8.M4.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b .LGemmU8X8.M4.ExitKernel - -// -// Process 2 rows of the matrices. 
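The StoreOutputPartial paths above handle a 1-to-7 column tail by testing individual bits of the remaining count: a 4-wide store if bit 2 is set, then a 2-wide store, then a single element, shifting the surviving lanes down after each store. A rough scalar equivalent of that tail logic (illustrative names only, not part of the deleted source):

    #include <cstdint>
    #include <cstring>

    // Model of the partial-store tail: 'acc' holds 8 accumulated int32 values and
    // 'remaining' is 1..7. Bit 2 selects a 4-wide store, bit 1 a 2-wide store,
    // bit 0 a single element; after each store the unwritten lanes slide down,
    // mirroring the "shift remaining elements down" register moves in the kernel.
    void StoreOutputPartial(int32_t* c, const int32_t* acc, unsigned remaining, bool accumulate) {
        int32_t lanes[8];
        std::memcpy(lanes, acc, sizeof(lanes));
        size_t offset = 0;
        for (unsigned width = 4; width >= 1; width /= 2) {
            if (remaining & width) {
                for (unsigned i = 0; i < width; ++i) {
                    c[offset + i] = accumulate ? c[offset + i] + lanes[i] : lanes[i];
                }
                // Shift the remaining lanes down for the next, narrower store.
                std::memmove(lanes, lanes + width, (8 - width) * sizeof(int32_t));
                offset += width;
            }
        }
    }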
-// - -.LGemmU8X8.M2.ProcessNextColumnLoop: - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB - ld1 {v28.4s},[x9],#16 // load ZeroPointB0 - ld1 {v29.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v24.4s,v28.4s - mul v17.4s,v24.4s,v29.4s - mul v18.4s,v25.4s,v28.4s - mul v19.4s,v25.4s,v29.4s - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - add v18.4s,v2.4s,v18.4s - add v19.4s,v3.4s,v19.4s - b .LGemmU8X8.M2.ComputeBlockLoop - -.LGemmU8X8.M2.SkipScaleByZeroPointB: - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v24.4s - add v17.4s,v3.4s,v24.4s - add v18.4s,v2.4s,v25.4s - add v19.4s,v3.4s,v25.4s - -.LGemmU8X8.M2.ComputeBlockLoop: - uxtl v2.8h,v4.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v16.4s,v0.4h,v2.h[0] - umlal2 v17.4s,v0.8h,v2.h[0] - umlal v18.4s,v0.4h,v2.h[4] - umlal2 v19.4s,v0.8h,v2.h[4] - uxtl v1.8h,v1.8b - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v16.4s,v1.4h,v2.h[1] - umlal2 v17.4s,v1.8h,v2.h[1] - umlal v18.4s,v1.4h,v2.h[5] - umlal2 v19.4s,v1.8h,v2.h[5] - uxtl v0.8h,v0.8b - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,.LGemmU8X8.M2.ComputeBlockLoopFinish - umlal v16.4s,v0.4h,v2.h[2] - umlal2 v17.4s,v0.8h,v2.h[2] - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - ld1 {v4.8b},[x0],#8 // load next packed A0 - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - uxtl v0.8h,v0.8b - b .LGemmU8X8.M2.ComputeBlockLoop - -.LGemmU8X8.M2.ComputeBlockLoopFinish: - umlal v16.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v17.4s,v0.8h,v2.h[2] - add x10,x2,x6,lsl #2 // compute output row 2 - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - subs x5,x5,#8 // adjust CountN remaining - blo .LGemmU8X8.M2.StoreOutputPartial - cbnz x13,.LGemmU8X8.M2.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - -.LGemmU8X8.M2.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - cbnz x5,.LGemmU8X8.M2.ProcessNextColumnLoop - -.LGemmU8X8.M2.ExitKernel: - mov x0,#2 // return number of rows handled - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
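Unlike the dot-product variants later in this diff, the compute loops above widen each packed byte to 16 bits (uxtl) and then use umlal/umlal2 against a broadcast A element, accumulating eight columns per step into two int32 vectors. A scalar model of one such step (illustrative names only):

    #include <cstdint>

    // Model of one paired 'umlal/umlal2 by element' step after uxtl widening:
    // eight widened B values are each multiplied by a single widened A element
    // and added to eight int32 accumulators.
    static void UmlalByElement(int32_t acc[8], const uint8_t b[8], uint8_t aElement) {
        for (int i = 0; i < 8; ++i) {
            acc[i] += int32_t(b[i]) * int32_t(aElement);
        }
    }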
-// - -.LGemmU8X8.M2.StoreOutputPartial: - cbz x13,.LGemmU8X8.M2.StoreOutputPartial.AddMode - -.LGemmU8X8.M2.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmU8X8.M2.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b .LGemmU8X8.M2.ExitKernel - -.LGemmU8X8.M2.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -.LGemmU8X8.M2.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -.LGemmU8X8.M2.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmU8X8.M2.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b .LGemmU8X8.M2.ExitKernel - -// -// Process 1 row of the matrices. -// - -.LGemmU8X8.M1.ProcessNextColumnLoop: - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB - ld1 {v28.4s},[x9],#16 // load ZeroPointB0 - ld1 {v29.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v24.4s,v28.4s - mul v17.4s,v24.4s,v29.4s - ldr s4,[x0],#4 // load first packed A0 - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - b .LGemmU8X8.M1.ComputeBlockLoop - -.LGemmU8X8.M1.SkipScaleByZeroPointB: - ldr s4,[x0],#4 // load first packed A0 - add v16.4s,v2.4s,v24.4s - add v17.4s,v3.4s,v24.4s - -.LGemmU8X8.M1.ComputeBlockLoop: - uxtl v2.8h,v4.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v16.4s,v0.4h,v2.h[0] - umlal2 v17.4s,v0.8h,v2.h[0] - uxtl v1.8h,v1.8b - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v16.4s,v1.4h,v2.h[1] - umlal2 v17.4s,v1.8h,v2.h[1] - uxtl v0.8h,v0.8b - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,.LGemmU8X8.M1.ComputeBlockLoopFinish - umlal v16.4s,v0.4h,v2.h[2] - umlal2 v17.4s,v0.8h,v2.h[2] - uxtl v1.8h,v1.8b - ldr s4,[x0],#4 // load first packed A0 - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - uxtl v0.8h,v0.8b - b .LGemmU8X8.M1.ComputeBlockLoop - -.LGemmU8X8.M1.ComputeBlockLoopFinish: - umlal v16.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v17.4s,v0.8h,v2.h[2] - uxtl v1.8h,v1.8b - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - subs x5,x5,#8 // adjust CountN remaining - blo .LGemmU8X8.M1.StoreOutputPartial - cbnz x13,.LGemmU8X8.M1.SkipAccumulateOutput - ldp q0,q1,[x2] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - -.LGemmU8X8.M1.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - cbnz x5,.LGemmU8X8.M1.ProcessNextColumnLoop - -.LGemmU8X8.M1.ExitKernel: - 
mov x0,#1 // return number of rows handled - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmU8X8.M1.StoreOutputPartial: - cbz x13,.LGemmU8X8.M1.StoreOutputPartial.AddMode - -.LGemmU8X8.M1.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmU8X8.M1.ExitKernel - st1 {v16.s}[0],[x2] - b .LGemmU8X8.M1.ExitKernel - -.LGemmU8X8.M1.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmU8X8.M1.ExitKernel - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b .LGemmU8X8.M1.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S deleted file mode 100644 index 5d4fa4d09b458..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUdot.S +++ /dev/null @@ -1,1056 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelUdot.S - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses ARM v8.4 dot product instructions. - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout for the U8X8 kernel. -// Defining spaces for saving 2 vector registers, and pointers to parameters -// on the stack -// - - .equ .LGemmU8X8KernelFrame_SavedNeonRegisters, (2 * 8) - .equ .LGemmU8X8KernelFrame_SavedRegisters, .LGemmU8X8KernelFrame_SavedNeonRegisters - .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 0 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_ZeroPointB, 8 + .LGemmU8X8KernelFrame_SavedRegisters - .equ .LGemmU8X8KernelFrame_ZeroMode, 16 + .LGemmU8X8KernelFrame_SavedRegisters - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. 
- - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmU8X8KernelUdot - - stp d8,d9,[sp,#-16]! - ldr x8,[sp,#.LGemmU8X8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#.LGemmU8X8KernelFrame_ZeroPointB] - ldrb w13,[sp,#.LGemmU8X8KernelFrame_ZeroMode] - mov x14,x0 - ld1 {v8.4s},[x7],#16 // load row sum 1 ~ 4 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq .LGemmU8X8.M1.ProcessLoop - cmp x4,#4 // CountM < 4? - blo .LGemmU8X8.M2.ProcessLoop - cmp x4,#8 // CountM < 8? - blo .LGemmU8X8.M4.ProcessNextColumnLoop - ld1 {v9.4s},[x7] // load row sum 5 ~ 8 - -// -// Process 8 rows of the matrices. -// Row Sums: v8 ~ v9 -// B 4x8 block -// ----------------------------------------- -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// ----------------------------------------- -// A 8x4 block -// --------------------- ----------------------------------------- -// |v4.b[0] ... v4.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |v4.b[4] ... v4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v4.b[8] ... v4.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v4.b[12] ... v4.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v5.b[0] ... v5.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v5.b[4] ... v5.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v5.b[8] ... v5.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v5.b[12] ... v5.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. v31.s[3]| -// --------------------- ----------------------------------------- -// -// unroll for the next 4 in k dimension -// ----------------------------------------- -// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| -// | ... ... | -// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| -// ----------------------------------------- -// --------------------- ----------------------------------------- -// |v6.b[0] ... v6.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |v6.b[4] ... v6.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v6.b[8] ... v6.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v6.b[12] ... v6.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v7.b[0] ... v7.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v7.b[4] ... v7.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v7.b[8] ... v7.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v7.b[12] ... v7.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. 
v31.s[3]| -// --------------------- ----------------------------------------- - -// Starting the loop: initialize accumulators with scaled combination -// of row and column sums - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - dup v25.4s,v9.s[0] - dup v27.4s,v9.s[1] - dup v29.4s,v9.s[2] - dup v31.4s,v9.s[3] - -.LGemmU8X8.M8.ProcessNextColumnLoop: - mov x0,x14 // reload matrix A - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v7.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,.LGemmU8X8.M8.SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v0.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v0.4s,v17.4s - mul v18.4s,v0.4s,v19.4s - ld1 {v1.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v0.4s,v21.4s - mul v22.4s,v0.4s,v23.4s - mul v24.4s,v0.4s,v25.4s - mul v26.4s,v0.4s,v27.4s - mul v28.4s,v0.4s,v29.4s - mul v30.4s,v0.4s,v31.4s - mul v17.4s,v1.4s,v17.4s - mul v19.4s,v1.4s,v19.4s - mul v21.4s,v1.4s,v21.4s - mul v23.4s,v1.4s,v23.4s - mul v25.4s,v1.4s,v25.4s - mul v27.4s,v1.4s,v27.4s - mul v29.4s,v1.4s,v29.4s - mul v31.4s,v1.4s,v31.4s - - // preloading mixed with accumulator inits - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v16.4s - add v18.4s,v3.4s,v18.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v20.4s - add v22.4s,v3.4s,v22.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v24.4s - add v26.4s,v3.4s,v26.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v28.4s - add v30.4s,v3.4s,v30.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - b .LGemmU8X8.M8.ComputeBlockLoop - -.LGemmU8X8.M8.SkipScaleByZeroPointB: - // accumulator = row sum A + column sum B - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v17.4s - add v18.4s,v3.4s,v19.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v21.4s - add v22.4s,v3.4s,v23.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v25.4s - add v26.4s,v3.4s,v27.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v29.4s - add v30.4s,v3.4s,v31.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - -.LGemmU8X8.M8.ComputeBlockLoop: - sub x3,x3,#1 - ld1 {v3.16b},[x1],#16 // load packed B1_next4k - UdotByElement 16, 0, 4, 0 - UdotByElement 18, 0, 4, 1 - ldr q7,[x0],#16 // load packed A3 - UdotByElement 20, 0, 4, 2 - UdotByElement 22, 0, 4, 3 - cbz x3,.LGemmU8X8.M8.ComputeBlockLoopFinish - UdotByElement 17, 1, 4, 0 - UdotByElement 19, 1, 4, 1 - UdotByElement 21, 1, 4, 2 - UdotByElement 23, 1, 4, 3 - ldr q4,[x0],#16 // load packed A0 for next iteration - UdotByElement 24, 0, 5, 0 - UdotByElement 26, 0, 5, 1 - UdotByElement 28, 0, 5, 2 - UdotByElement 30, 0, 5, 3 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - UdotByElement 25, 1, 5, 0 - UdotByElement 27, 1, 5, 1 - UdotByElement 29, 1, 5, 2 - UdotByElement 31, 1, 5, 3 - ld1 {v1.16b},[x1],#16 // load packed B1 for next iteration - - UdotByElement 16, 2, 6, 0 - UdotByElement 18, 2, 6, 1 - ldr q5,[x0],#16 // 
load packed A1 for next iteration - UdotByElement 20, 2, 6, 2 - UdotByElement 22, 2, 6, 3 - UdotByElement 17, 3, 6, 0 - UdotByElement 19, 3, 6, 1 - UdotByElement 21, 3, 6, 2 - UdotByElement 23, 3, 6, 3 - ldr q6,[x0],#16 // load packed A2 for next iteration - UdotByElement 24, 2, 7, 0 - UdotByElement 26, 2, 7, 1 - UdotByElement 28, 2, 7, 2 - UdotByElement 30, 2, 7, 3 - ld1 {v2.16b},[x1],#16 // load packed B0_next4k for next iteration - UdotByElement 25, 3, 7, 0 - UdotByElement 27, 3, 7, 1 - UdotByElement 29, 3, 7, 2 - UdotByElement 31, 3, 7, 3 - b .LGemmU8X8.M8.ComputeBlockLoop - -.LGemmU8X8.M8.ComputeBlockLoopFinish: - // postfix, compute tail values and prepare to write results - // We are either about to go to ProcessNextColumnLoopM8 - // where x0 and x3 are about to be restored, or exit - // when x0 and x3 will not be used. - // x4 x7 has finished their task - // so we can use x0 x3 x4 x7 as output row pointers - - UdotByElement 17, 1, 4, 0 - UdotByElement 19, 1, 4, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add x11,x10,x6,lsl #2 // compute output row 3 - UdotByElement 21, 1, 4, 2 - UdotByElement 23, 1, 4, 3 - add x12,x11,x6,lsl #2 // compute output row 4 - add x0,x12,x6,lsl #2 // compute output row 5 - UdotByElement 24, 0, 5, 0 - UdotByElement 26, 0, 5, 1 - add x3,x0,x6,lsl #2 // compute output row 6 - add x4,x3,x6,lsl #2 // compute output row 7 - UdotByElement 28, 0, 5, 2 - UdotByElement 30, 0, 5, 3 - add x7,x4,x6,lsl #2 // compute output row 8 - subs x5,x5,#8 // adjust CountN remaining - UdotByElement 25, 1, 5, 0 - UdotByElement 27, 1, 5, 1 - UdotByElement 29, 1, 5, 2 - UdotByElement 31, 1, 5, 3 - - UdotByElement 16, 2, 6, 0 - UdotByElement 18, 2, 6, 1 - UdotByElement 20, 2, 6, 2 - UdotByElement 22, 2, 6, 3 - UdotByElement 17, 3, 6, 0 - UdotByElement 19, 3, 6, 1 - UdotByElement 21, 3, 6, 2 - UdotByElement 23, 3, 6, 3 - UdotByElement 24, 2, 7, 0 - UdotByElement 26, 2, 7, 1 - UdotByElement 28, 2, 7, 2 - UdotByElement 30, 2, 7, 3 - UdotByElement 25, 3, 7, 0 - UdotByElement 27, 3, 7, 1 - UdotByElement 29, 3, 7, 2 - UdotByElement 31, 3, 7, 3 - blo .LGemmU8X8.M8.StoreOutputPartial - cbnz x13,.LGemmU8X8.M8.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - ldp q0, q1, [x0] - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - ldp q2, q3, [x3] - add v24.4s,v24.4s,v0.4s - add v25.4s,v25.4s,v1.4s - ldp q4, q5, [x4] - add v26.4s,v26.4s,v2.4s - add v27.4s,v27.4s,v3.4s - ldp q6, q7, [x7] - add v28.4s,v28.4s,v4.4s - add v29.4s,v29.4s,v5.4s - add v30.4s,v30.4s,v6.4s - add v31.4s,v31.4s,v7.4s - -.LGemmU8X8.M8.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - dup v17.4s,v8.s[0] // broadcast row sums - stp q18,q19,[x10] - dup v19.4s,v8.s[1] - stp q20,q21,[x11] - dup v21.4s,v8.s[2] - stp q22,q23,[x12] - dup v23.4s,v8.s[3] - stp q24,q25,[x0] - dup v25.4s,v9.s[0] - stp q26,q27,[x3] - dup v27.4s,v9.s[1] - stp q28,q29,[x4] - dup v29.4s,v9.s[2] - stp q30,q31,[x7] - dup v31.4s,v9.s[3] - - cbnz x5,.LGemmU8X8.M8.ProcessNextColumnLoop - -.LGemmU8X8.M8.ExitKernel: - mov x0,#8 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
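The UdotByElement macro used throughout the loop above emits the Armv8.4 udot-by-element form: each 32-bit accumulator lane gains the dot product of the corresponding four bytes of the B vector with one broadcast four-byte group of the packed A register. A scalar model of a single step (names are illustrative):

    #include <cstdint>

    // Model of 'udot vAcc.4s, vB.16b, vA.4b[lane]': acc[i] += dot(B bytes 4i..4i+3,
    // the four A bytes selected by 'lane'). The kernel issues this once per
    // accumulator column and per A lane to build the 8x8 output tile.
    static void UdotByElement(int32_t acc[4], const uint8_t b[16], const uint8_t a[16], int lane) {
        for (int i = 0; i < 4; ++i) {
            int32_t sum = 0;
            for (int k = 0; k < 4; ++k) {
                sum += int32_t(b[4 * i + k]) * int32_t(a[4 * lane + k]);
            }
            acc[i] += sum;
        }
    }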
-// - -.LGemmU8X8.M8.StoreOutputPartial: - cbz x13,.LGemmU8X8.M8.StoreOutputPartialAddMode - -.LGemmU8X8.M8.StoreOutputPartialZeroMode: - tbz x5,#2,.LGemmU8X8.M8.StoreOutputPartial2ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -.LGemmU8X8.M8.StoreOutputPartial2ZeroMode: - tbz x5,#1,.LGemmU8X8.M8.StoreOutputPartial1ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -.LGemmU8X8.M8.StoreOutputPartial1ZeroMode: - tbz x5,#0,.LGemmU8X8.M8.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - st1 {v24.s}[0],[x0] - st1 {v26.s}[0],[x3] - st1 {v28.s}[0],[x4] - st1 {v30.s}[0],[x7] - b .LGemmU8X8.M8.ExitKernel - -.LGemmU8X8.M8.StoreOutputPartialAddMode: - tbz x5,#2,.LGemmU8X8.M8.StoreOutputPartial2AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - ld1 {v4.4s},[x0] - ld1 {v5.4s},[x3] - ld1 {v6.4s},[x4] - ld1 {v7.4s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -.LGemmU8X8.M8.StoreOutputPartial2AddMode: - tbz x5,#1,.LGemmU8X8.M8.StoreOutputPartial1AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - ld1 {v4.2s},[x0] - ld1 {v5.2s},[x3] - ld1 {v6.2s},[x4] - ld1 {v7.2s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -.LGemmU8X8.M8.StoreOutputPartial1AddMode: - tbz x5,#0,.LGemmU8X8.M8.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - ld1 {v4.s}[0],[x0] - ld1 {v5.s}[0],[x3] - ld1 {v6.s}[0],[x4] - ld1 {v7.s}[0],[x7] - 
add v24.4s,v24.4s,v4.4s - st1 {v24.s}[0],[x0] - add v26.4s,v26.4s,v5.4s - st1 {v26.s}[0],[x3] - add v28.4s,v28.4s,v6.4s - st1 {v28.s}[0],[x4] - add v30.4s,v30.4s,v7.4s - st1 {v30.s}[0],[x7] - b .LGemmU8X8.M8.ExitKernel - - -// -// Process 4 rows of the matrices. -// -// -// The packing layout is setup to have a pair of four quad vectors from -// packed matrix A and a pair of eight quad vectors from packed matrix B. -// With this scheme, alternating loads from the packed matrices can be -// interleaved with the dot product instructions. -// -// One negative consequence of using four rows here is that the accumulator -// register tile is too small for processors with high out of order execution -// windows (such as the Apple M1). The dot product instructions for a given -// cell are too close to each other to avoid dependencies. To workaround this, -// the below loop uses a pair of accumulator registers that are then added -// together when the loop finishes. -// -// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for -// packed matrix A. At the time of this implementation, using a wider 128-bit -// load did not affect performance for higher end cores. -// -// B 4x8 block -// ----------------------------------------- -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// ----------------------------------------- -// A 4x4 block -// --------------------- ----------------------------------------- -// |d4.b[0] ... d4.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |d4.b[4] ... d4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |d5.b[0] ... d5.b[3] | |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |d5.b[4] ... d5.b[7] | |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// --------------------- ----------------------------------------- -// -// unroll for the next 4 in k dimension -// ----------------------------------------- -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// ----------------------------------------- -// --------------------- ----------------------------------------- -// |d6.b[0] ... d6.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |d6.b[4] ... d6.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |d7.b[0] ... d7.b[3] | |v28.s[0] .. v24.s[3] v29.s[0] .. v29.s[3]| -// |d7.b[4] ... d7.b[7] | |v30.s[0] .. v24.s[3] v31.s[0] .. 
v31.s[3]| -// --------------------- ----------------------------------------- - -.LGemmU8X8.M4.ProcessNextColumnLoop: - ld1 {v0.16b},[x1],#16 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - cbz x9,.LGemmU8X8.M4.SkipScaleByZeroPointB - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v30.4s,v17.4s - mul v18.4s,v30.4s,v19.4s - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v30.4s,v21.4s - mul v22.4s,v30.4s,v23.4s - mul v17.4s,v31.4s,v17.4s - mul v19.4s,v31.4s,v19.4s - mul v21.4s,v31.4s,v21.4s - mul v23.4s,v31.4s,v23.4s - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - add v20.4s,v2.4s,v20.4s - add v22.4s,v2.4s,v22.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - b .LGemmU8X8.M4.ComputeBlockLoopStart - -.LGemmU8X8.M4.SkipScaleByZeroPointB: - add v16.4s,v2.4s,v17.4s - add v18.4s,v2.4s,v19.4s - add v20.4s,v2.4s,v21.4s - add v22.4s,v2.4s,v23.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - -.LGemmU8X8.M4.ComputeBlockLoopStart: - ldr d4,[x0],#32 // load packed A0.l - movi v24.4s,#0 - movi v25.4s,#0 - ldur d5,[x0,#-24] // load packed A0.h - movi v26.4s,#0 - movi v27.4s,#0 - ldur d6,[x0,#-16] // load packed A1.l - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - -.LGemmU8X8.M4.ComputeBlockLoop: - ld1 {v1.16b},[x1],#16 // load packed B1 - UdotByElement 16, 0, 4, 0 - UdotByElement 18, 0, 4, 1 - ldur d7,[x0,#-8] // load packed A1.h - UdotByElement 20, 0, 5, 0 - UdotByElement 22, 0, 5, 1 - ld1 {v0.16b},[x1],#16 // load packed B0_next4k - UdotByElement 17, 1, 4, 0 - UdotByElement 19, 1, 4, 1 - sub x3,x3,#1 - cbz x3,.LGemmU8X8.M4.ComputeBlockLoopFinish - ldr d4,[x0],#32 // load packed A0.l for next iteration - UdotByElement 21, 1, 5, 0 - UdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - UdotByElement 24, 0, 6, 0 - UdotByElement 26, 0, 6, 1 - ldur d5,[x0,#-24] // load packed A0.h for next iteration - UdotByElement 28, 0, 7, 0 - UdotByElement 30, 0, 7, 1 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - UdotByElement 25, 1, 6, 0 - UdotByElement 27, 1, 6, 1 - ldur d6,[x0,#-16] // load packed A1.l for next iteration - UdotByElement 29, 1, 7, 0 - UdotByElement 31, 1, 7, 1 - b .LGemmU8X8.M4.ComputeBlockLoop - -.LGemmU8X8.M4.ComputeBlockLoopFinish: - UdotByElement 21, 1, 5, 0 - UdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - UdotByElement 24, 0, 6, 0 - UdotByElement 26, 0, 6, 1 - UdotByElement 28, 0, 7, 0 - UdotByElement 30, 0, 7, 1 - UdotByElement 25, 1, 6, 0 - UdotByElement 27, 1, 6, 1 - UdotByElement 29, 1, 7, 0 - UdotByElement 31, 1, 7, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add v16.4s,v16.4s,v24.4s // fold high results into low results - add v18.4s,v18.4s,v26.4s - add v20.4s,v20.4s,v28.4s - add v22.4s,v22.4s,v30.4s - add x11,x10,x6,lsl #2 // compute output row 3 - add v17.4s,v17.4s,v25.4s - add v19.4s,v19.4s,v27.4s - add v21.4s,v21.4s,v29.4s - add v23.4s,v23.4s,v31.4s - add x12,x11,x6,lsl #2 // compute output row 4 - subs x5,x5,#8 // adjust CountN remaining - blo .LGemmU8X8.M4.StoreOutputPartial - cbnz x13,.LGemmU8X8.M4.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - 
ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - -.LGemmU8X8.M4.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - stp q20,q21,[x11] - stp q22,q23,[x12] - cbnz x5,.LGemmU8X8.M4.ProcessNextColumnLoop - -.LGemmU8X8.M4.ExitKernel: - mov x0,#4 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmU8X8.M4.StoreOutputPartial: - cbz x13,.LGemmU8X8.M4.StoreOutputPartial.AddMode - -.LGemmU8X8.M4.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -.LGemmU8X8.M4.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -.LGemmU8X8.M4.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmU8X8.M4.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b .LGemmU8X8.M4.ExitKernel - -.LGemmU8X8.M4.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmU8X8.M4.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -.LGemmU8X8.M4.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmU8X8.M4.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -.LGemmU8X8.M4.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmU8X8.M4.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b .LGemmU8X8.M4.ExitKernel - -// -// Process 2 rows of the matrices. 
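As the comment block above the 4-row path notes, the accumulator tile is split into two independent sets that are only folded together when the K loop finishes, so consecutive dot products never chain through the same register. A minimal scalar illustration of that pattern (purely illustrative, not the kernel's code):

    #include <cstddef>
    #include <cstdint>

    // Two independent partial sums break the add-after-add dependency chain
    // inside the loop and are folded together only once at the end.
    int32_t DotWithTwoAccumulators(const int32_t* products, size_t count) {
        int32_t accLow = 0, accHigh = 0;
        size_t i = 0;
        for (; i + 1 < count; i += 2) {
            accLow += products[i];       // even-indexed K blocks
            accHigh += products[i + 1];  // odd-indexed K blocks
        }
        if (i < count) {
            accLow += products[i];
        }
        return accLow + accHigh;         // fold high results into low results
    }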
-// -.LGemmU8X8.M2.ProcessLoop: - dup v9.4s, v8.s[1] - dup v8.4s, v8.s[0] - -.LGemmU8X8.M2.ProcessNextColumnLoop: - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,.LGemmU8X8.M2.SkipScaleByZeroPointB - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v16.4s,v30.4s,v8.4s - mul v18.4s,v30.4s,v9.4s - mul v17.4s,v31.4s,v8.4s - mul v19.4s,v31.4s,v9.4s - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - b .LGemmU8X8.M2.ComputeBlockLoop - -.LGemmU8X8.M2.SkipScaleByZeroPointB: - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v8.4s - add v18.4s,v2.4s,v9.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v8.4s - add v19.4s,v3.4s,v9.4s - -.LGemmU8X8.M2.ComputeBlockLoop: - sub x3,x3,#1 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - UdotByElement 16, 0, 4, 0 - UdotByElement 17, 1, 4, 0 - UdotByElement 18, 0, 4, 1 - UdotByElement 19, 1, 4, 1 - cbz x3,.LGemmU8X8.M2.ComputeBlockLoopFinish - ldr d4,[x0],#8 // load packed A0.l for next iter - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - UdotByElement 16, 6, 5, 0 - UdotByElement 17, 7, 5, 0 - UdotByElement 18, 6, 5, 1 - UdotByElement 19, 7, 5, 1 - ldr d5,[x0],#8 // load packed A0.h for next iter - b .LGemmU8X8.M2.ComputeBlockLoop - -.LGemmU8X8.M2.ComputeBlockLoopFinish: - add x10,x2,x6,lsl #2 // compute output row 2 - subs x5,x5,#8 // adjust CountN remaining - UdotByElement 16, 6, 5, 0 - UdotByElement 17, 7, 5, 0 - UdotByElement 18, 6, 5, 1 - UdotByElement 19, 7, 5, 1 - blo .LGemmU8X8.M2.StoreOutputPartial - cbnz x13,.LGemmU8X8.M2.SkipAccumulateOutput - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - -.LGemmU8X8.M2.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - cbnz x5,.LGemmU8X8.M2.ProcessNextColumnLoop - -.LGemmU8X8.M2.ExitKernel: - mov x0,#2 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -.LGemmU8X8.M2.StoreOutputPartial: - cbz x13,.LGemmU8X8.M2.StoreOutputPartial.AddMode - -.LGemmU8X8.M2.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -.LGemmU8X8.M2.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -.LGemmU8X8.M2.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmU8X8.M2.ExitKernel - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b .LGemmU8X8.M2.ExitKernel - -.LGemmU8X8.M2.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmU8X8.M2.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -.LGemmU8X8.M2.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmU8X8.M2.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -.LGemmU8X8.M2.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmU8X8.M2.ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b .LGemmU8X8.M2.ExitKernel - -// -// Process 1 row of the matrices. -// -.LGemmU8X8.M1.ProcessLoop: - - dup v8.4s,v8.s[0] - -.LGemmU8X8.M1.ProcessNextColumnLoop: - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - cbz x9,.LGemmU8X8.M1.SkipScaleByZeroPointB - ld1 {v30.4s},[x9],#16 // load ZeroPointB0 - ld1 {v31.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v30.4s,v8.4s - mul v17.4s,v31.4s,v8.4s - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - b .LGemmU8X8.M1.ComputeBlockLoop - -.LGemmU8X8.M1.SkipScaleByZeroPointB: - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v8.4s - add v17.4s,v3.4s,v8.4s - -.LGemmU8X8.M1.ComputeBlockLoop: - sub x3,x3,#1 - UdotByElement 16, 0, 4, 0 - UdotByElement 17, 1, 4, 0 - cbz x3,.LGemmU8X8.M1.ComputeBlockLoopFinish - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - UdotByElement 16, 6, 4, 1 - UdotByElement 17, 7, 4, 1 - ldr d4,[x0],#8 // load packed A0 for next iter - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k for next iter - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k for next iter - b .LGemmU8X8.M1.ComputeBlockLoop - -.LGemmU8X8.M1.ComputeBlockLoopFinish: - subs x5,x5,#8 // adjust CountN remaining - UdotByElement 16, 6, 4, 1 - UdotByElement 17, 7, 4, 1 - blo .LGemmU8X8.M1.StoreOutputPartial - cbnz x13,.LGemmU8X8.M1.SkipAccumulateOutput - ldp q0,q1,[x2] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - -.LGemmU8X8.M1.SkipAccumulateOutput: - stp q16,q17,[x2],#32 - cbnz x5,.LGemmU8X8.M1.ProcessNextColumnLoop - 
-.LGemmU8X8.M1.ExitKernel: - mov x0,#1 // return number of rows handled - ldp d8,d9,[sp],#16 - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -.LGemmU8X8.M1.StoreOutputPartial: - cbz x13,.LGemmU8X8.M1.StoreOutputPartial.AddMode - -.LGemmU8X8.M1.StoreOutputPartial.ZeroMode: - tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial2.ZeroMode: - tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial1.ZeroMode: - tbz x5,#0,.LGemmU8X8.M1.ExitKernel - st1 {v16.s}[0],[x2] - b .LGemmU8X8.M1.ExitKernel - -.LGemmU8X8.M1.StoreOutputPartial.AddMode: - tbz x5,#2,.LGemmU8X8.M1.StoreOutputPartial2.AddMode - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial2.AddMode: - tbz x5,#1,.LGemmU8X8.M1.StoreOutputPartial1.AddMode - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -.LGemmU8X8.M1.StoreOutputPartial1.AddMode: - tbz x5,#0,.LGemmU8X8.M1.ExitKernel - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b .LGemmU8X8.M1.ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S deleted file mode 100644 index baf6e21e6ff06..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S +++ /dev/null @@ -1,922 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelUmmla.s - -Abstract: - - This module implements the kernels for the Int8 precision matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "asmmacro.h" - - .text - -// -// Stack frame layout for the ummla kernel. d8-d15, x19-x30 need save -// - .equ .LMlasQgemmKernel_backup_x19_x20, 0 - .equ .LMlasQgemmKernel_backup_x21_x22, 16 - .equ .LMlasQgemmKernel_backup_x23_x24, 32 - .equ .LMlasQgemmKernel_backup_x25_x26, 48 - .equ .LMlasQgemmKernel_backup_x27_x28, 64 - .equ .LMlasQgemmKernel_backup_d8_d9, 80 - .equ .LMlasQgemmKernel_backup_d10_d11, 96 - .equ .LMlasQgemmKernel_backup_d12_d13, 112 - .equ .LMlasQgemmKernel_backup_d14_d15, 128 - .equ .LMlasQgemmKernel_SavedRegisters, 144 - .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 - - -// -// Init Row Accumulators -// -// Generates the code to initialize the accumulators for a single row of the output -// block. 
-// -// -// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum -// x7 for RowSumsBuffer pointer -// x10 for ColumnSumBuffer pointer -// x11 for ZeroPointB buffer pointer -// -// v12~v13 for RowSums values -// v14~v15 for ColumnSums values -// v0~v3 for ZeroPointB values -// - .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg - - mul v7.4s, v\RowSumReg\().4s, v8.4s - mov v\Vec1Reg\().16b, v7.16b - add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s -.if \Columns\() > 2 - mul v7.4s, v\RowSumReg\().4s, v9.4s - mov v\Vec2Reg\().16b, v7.16b - add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s -.endif -.if \Columns\() > 4 - mul v7.4s, v\RowSumReg\().4s, v10.4s - mov v\Vec3Reg\().16b, v7.16b - add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s -.endif -.if \Columns\() > 6 - mul v7.4s, v\RowSumReg\().4s, v11.4s - mov v\Vec4Reg\().16b, v7.16b - add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s -.endif - - .endm - - -// -// InitBlockAccumulators -// -// Generates the code to initialize the accumulators for 8x8 output -// block. -// - .macro InitBlockAccumulators Mode, Columns, Rows - - ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] -.if \Columns\() > 4 - ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] -.endif - // v4~v7 will be set to matrixB after this, so, they can used now - dup v4.4s,v14.s[0] // broadcast column - dup v5.4s,v14.s[1] - dup v6.4s,v14.s[2] - dup v7.4s,v14.s[3] - - zip1 v0.4s, v4.4s, v5.4s - zip2 v1.4s, v6.4s, v7.4s -.if \Columns\() > 4 - dup v4.4s,v15.s[0] // broadcast column - dup v5.4s,v15.s[1] - dup v6.4s,v15.s[2] - dup v7.4s,v15.s[3] - - zip1 v2.4s, v4.4s, v5.4s - zip2 v3.4s, v6.4s, v7.4s -.endif - - // v8~v11 will anyway get set in MatrixA loading, so they are free to use now - movi v8.4s, #1 - movi v9.4s, #1 - movi v10.4s, #1 - movi v11.4s, #1 - - cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB - - ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] - ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] - - dup v6.4s, v4.s[0] - dup v7.4s, v4.s[1] - zip1 v8.4s, v6.4s, v7.4s - - dup v6.4s, v4.s[2] - dup v7.4s, v4.s[3] - zip1 v9.4s, v6.4s, v7.4s - - dup v6.4s, v5.s[0] - dup v7.4s, v5.s[1] - zip1 v10.4s, v6.4s, v7.4s - - dup v6.4s, v5.s[2] - dup v7.4s, v5.s[3] - zip1 v11.4s, v6.4s, v7.4s - -.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: - dup v4.4s, v12.s[0] //boardcast RowSums - dup v5.4s, v12.s[1] - - uzp1 v6.2d, v4.2d, v5.2d - - InitRowAccumulators \Columns\(),16,17,18,19,6 -.if \Rows\() > 2 - dup v4.4s, v12.s[2] //boardcast RowSums - dup v5.4s, v12.s[3] - - uzp1 v6.2d, v4.2d, v5.2d - - InitRowAccumulators \Columns\(),20,21,22,23,6 -.endif -.if \Rows\() > 4 - dup v4.4s,v13.s[0] // broadcast row sums - dup v5.4s,v13.s[1] - - uzp1 v6.2d, v4.2d, v5.2d - - InitRowAccumulators \Columns\(),24,25,26,27,6 -.endif -.if \Rows\() > 6 - dup v4.4s,v13.s[2] // broadcast row sums - dup v5.4s,v13.s[3] - - uzp1 v6.2d, v4.2d, v5.2d - InitRowAccumulators \Columns\(),28,29,30,31,6 -.endif - - .endm - - -// LoadPackedMatrixABy16Elements -// -// Generates the code to load 16 elements from matrix A. -// - .macro LoadPackedMatrixABy16Elements Rows -.if \Rows\() == 1 - ldr q8,[x0],#8 -.else - ldr q8,[x0],#16 - -.if \Rows\() > 2 - ldr q9,[x0],#16 -.endif - -.if \Rows\() > 4 - ldr q10,[x0],#16 -.endif - -.if \Rows\() > 6 - ldr q11,[x0],#16 -.endif -.endif - .endm - - -// -// MultiplyAccumulateRow -// -// Generates the code to multiply and accumulate a single row of the output -// block. 
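The row multiply-accumulate below is built on the Armv8.6 I8MM `ummla` instruction: both byte sources are viewed as 2x8 matrices of unsigned 8-bit values, and the widened 2x2 product of the first with the transpose of the second is added to the four 32-bit lanes of the destination. A minimal scalar sketch of one issue (the function name and lane-layout comment are illustrative, not MLAS code):

    #include <cstdint>

    // vd holds a 2x2 int32 tile in lanes {r0c0, r0c1, r1c0, r1c1}.
    void ummla_reference(uint32_t vd[4], const uint8_t vn[16], const uint8_t vm[16]) {
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 2; ++j) {
                uint32_t sum = vd[2 * i + j];
                for (int k = 0; k < 8; ++k) {
                    sum += uint32_t(vn[8 * i + k]) * uint32_t(vm[8 * j + k]);
                }
                vd[2 * i + j] = sum;
            }
        }
    }

This 2x2 tiling is why each accumulator pairs up two output rows; the output macros later un-interleave them.
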
-// - - .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - - ummla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b -.if \Columns\() > 2 - ummla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b -.endif -.if \Columns\() > 4 - ummla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b -.endif -.if \Columns\() > 6 - ummla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b -.endif - - .endm - -// -// MultiplyAccumulateBlock -// -// Generates the code to multiply and accumulate into the output block. -// - - .macro MultiplyAccumulateBlock Columns, Rows - - MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 -.if \Rows\() > 2 - MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 -.endif -.if \Rows\() > 4 - MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 -.endif -.if \Rows\() > 6 - MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 -.endif - - .endm - -// -// ComputeBlockLoop -// -// Generates the code to loop over K entries of the input matrices to produce -// the output block. -// - - .macro ComputeBlockLoop Mode, Columns, Rows - - InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() - - sub x9,x3,#1 // block count to process - tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: - - LoadPackedMatrixABy16Elements \Rows\() - ld1 {v4.16b - v7.16b}, [x1], #64 - MultiplyAccumulateBlock \Columns\(),\Rows\() - - sub x9,x9,#1 - tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop -.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: - add x9,x9,#1 // correct for over-subtract above - cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: - LoadPackedMatrixABy16Elements \Rows\() - ld1 {v4.16b - v7.16b}, [x1], #64 - MultiplyAccumulateBlock \Columns\(),\Rows\() - -.L\Mode\().Output\Columns\().x\Rows\().Block: - - .endm - - -// -// OutputRow2Element -// OutputRow4Element -// OutputRow6Element -// OutputRow8Element -// OutputRow10Element -// OutputRow12Element -// OutputRow14Element -// OutputRow16Element -// -// Generates the code to store elements to the output block. 
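Because each accumulator register carries a 2x2 tile (two output rows interleaved), the OutputRow*Element macros below first pair accumulators with uzp1/uzp2 so that each vector again holds four consecutive columns of one row before touching memory. An illustrative C++ view of that shuffle (hypothetical names; not kernel code):

    #include <cstdint>

    // tile0 covers columns {0,1}, tile1 covers columns {2,3} of the same two rows.
    void deinterleave_tiles(const uint32_t tile0[4], const uint32_t tile1[4],
                            uint32_t row0[4], uint32_t row1[4]) {
        row0[0] = tile0[0]; row0[1] = tile0[1];  // uzp1: even 64-bit halves -> row 0
        row0[2] = tile1[0]; row0[3] = tile1[1];
        row1[0] = tile0[2]; row1[1] = tile0[3];  // uzp2: odd 64-bit halves -> row 1
        row1[2] = tile1[2]; row1[3] = tile1[3];
    }
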
-// - - .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr s8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr s9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - mov v8.S[2], v9.S[0] - add v8.4s,v8.4s,v\Vec1Reg\().4s - - mov w27, v8.S[0] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov w27, v8.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - mov w27, v\Vec1Reg\().S[0] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov w27, v\Vec1Reg\().S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.endif - - .endm - - - .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr d8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr d9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - - mov v8.D[1], v9.D[0] - - add v8.4s,v8.4s,v\Vec1Reg\().4s - - mov x27, v8.D[0] - mov x28, v8.D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.else - mov x27, v\Vec1Reg\().D[0] - mov x28, v\Vec1Reg\().D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.endif - - .endm - - - .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr d8,[\AddrReg1\()],#8 - ldr w28,[\AddrReg1\()],#-8 - mov v8.S[2], w28 -.if \last_row\() == 0 - ldr d9,[\AddrReg2\()],#8 - ldr w27,[\AddrReg2\()],#-8 - mov v9.S[2], w27 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - mov x27, v8.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v8.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov x27, v9.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v9.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - mov x27, v4.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v4.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov x27, v5.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v5.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.endif - - .endm - - - .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - -.endif - - .endm - - - .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr w28, [\AddrReg1\()],#-16 - -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr w27,[\AddrReg2\()],#-16 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - str 
q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - mov v8.S[0], w28 - mov v8.S[2], w27 - - add v8.4s,v8.4s,v\Vec3Reg\().4s - - mov w27, v8.S[0] - mov w28, v8.S[2] - - str w27, [\AddrReg1\()],#4 -.if \last_row\() == 0 - str w28, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - mov w27, v\Vec3Reg\().S[0] - mov w28, v\Vec3Reg\().S[2] - - str w27, [\AddrReg1\()],#4 -.if \last_row\() == 0 - str w28, [\AddrReg2\()],#4 -.endif -.endif - -.endm - - - .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr d10,[\AddrReg1\()],#-16 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr d11,[\AddrReg2\()],#-16 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - mov v11.D[0],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - - mov v10.D[1], v11.D[0] - - add v10.4s,v10.4s,v\Vec3Reg\().4s - - mov x27, v10.D[0] - mov x28, v10.D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - mov x27, v\Vec3Reg\().D[0] - mov x28, v\Vec3Reg\().D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif -.endif - - .endm - - .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr d10,[\AddrReg1\()],#8 - ldr w28, [\AddrReg1\()],#-24 - mov v10.S[2], w28 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr d11,[\AddrReg2\()],#8 - ldr w27,[\AddrReg2\()],#-24 - mov v11.S[2], w27 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - - mov v11.D[0],x27 - mov v11.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - add v10.4s,v10.4s,v6.4s - add v11.4s,v11.4s,v7.4s - - str q8,[\AddrReg1\()],#16 - - mov x27, v10.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v10.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 - mov x27, v11.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v11.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - str q4,[\AddrReg1\()],#16 - mov x27, v6.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v6.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 - mov x27, v7.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v7.S[2] - str w27, [\AddrReg2\()],#4 -.endif -.endif - - .endm - - - .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldp q8,q10,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldp q9,q11,[\AddrReg2\()],#0 -.else - mov x27,#0 - 
mov v9.D[0],x27 - mov v9.D[1],x27 - - mov v11.D[0],x27 - mov v11.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - add v8.4s,v8.4s,v4.4s - add v9.4s,v9.4s,v5.4s - add v10.4s,v10.4s,v6.4s - add v11.4s,v11.4s,v7.4s - - stp q8,q10,[\AddrReg1\()],#32 -.if \last_row\() == 0 - stp q9,q11,[\AddrReg2\()],#32 -.endif -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - stp q4,q6,[\AddrReg1\()],#32 -.if \last_row\() == 0 - stp q5,q7,[\AddrReg2\()],#32 -.endif -.endif - - .endm - -// -// OutputBlock -// -// Generates the code to store the output block. -// - - .macro OutputBlock Mode, Columns, Rows - - OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) - -.if \Rows\() > 2 - OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) -.endif - -.if \Rows\() > 4 - OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) -.endif - -.if \Rows\() > 6 - OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) -.endif - - .endm -// -// ProcessRows -// -// Generates the code to process a compute and store the output block for a -// fixed number of rows. -// - - .macro ProcessRows Mode, Rows - mov x4,#\Rows\() // return number of rows handled - cmp x5,#6 - ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() - -.L\Mode\().ProcessNextColumnLoop8x\Rows\(): - ComputeBlockLoop \Mode\(),8,\Rows\() - - sub x5,x5,#8 - cmp x5,#0 - blt .L\Mode\().Output14ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),16,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#6 - bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() - cbz x5,.L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop6x\Rows\(): - - cmp x5,#4 - ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() - ComputeBlockLoop \Mode\(),6,\Rows\() - sub x5,x5,#6 - cmp x5,#0 - blt .L\Mode\().Output10ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),12,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#4 - bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() - b .L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop4x\Rows\(): - cmp x5,#2 - ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() - ComputeBlockLoop \Mode\(),4,\Rows\() - sub x5,x5,#4 - cmp x5,#0 - blt .L\Mode\().Output6ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),8,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#2 - bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() - b .L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop2x\Rows\(): - ComputeBlockLoop \Mode\(),2,\Rows\() - sub x5,x5,#2 - cmp x5,#0 - blt .L\Mode\().Output2ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),4,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#2 - b .L\Mode\().ExitKernel - -.L\Mode\().Output14ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),14,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output10ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),10,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output6ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),6,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output2ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),2,\Rows\() - b .L\Mode\().ExitKernel - - .endm - - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. 
The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - -Return Value: - - Returns the number of rows handled. - ---*/ - - .macro QgemmU8X8KernelUmmlaFunction Mode - - FUNCTION_ENTRY MlasGemmU8X8KernelUmmla\Mode\() - - ldr x10,[sp, #0] - ldr x11,[sp,#8] - - stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! - stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] - stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] - stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] - stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] - stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] - stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] - stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] - stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] - - add x13,x2,x6,lsl #2 // compute matrix C plus 1 row - add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows - add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows - add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows - add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows - add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows - add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows - - mov x8,x0 // save matrix A - -// -// Process 8 rows of the matrices. -// - ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 - cmp x4,#8 - blt .L\Mode\().ProcessCountMLessThan8 - ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 - ProcessRows \Mode\(),8 - -// -// Restore non-volatile registers and return. -// - -.L\Mode\().ExitKernel: - mov x0,x4 - - ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] - ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] - ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] - ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] - ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] - ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] - ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] - ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] - ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters - - ret - -// -// Process 4 rows of the matrix. 
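The RowSumBuffer/ColumnSumBuffer/ZeroPointB parameters described above follow from the usual expansion of a zero-point-offset GEMM: the kernel itself accumulates only the raw unsigned products, and the zero-point correction terms are folded into the accumulator initialization. For reference, a scalar version of the computation the driver ultimately produces (illustrative; the exact sign and pre-scaling conventions live in the packing and row/column-sum code, and ZeroPointB may instead be a single per-tensor value):

    #include <cstdint>
    #include <cstddef>

    void qgemm_reference(const uint8_t* A, const uint8_t* B, int32_t* C,
                         size_t M, size_t N, size_t K,
                         int32_t ZeroPointA, const int32_t* ZeroPointB) {
        for (size_t m = 0; m < M; ++m) {
            for (size_t n = 0; n < N; ++n) {
                int32_t acc = 0;
                for (size_t k = 0; k < K; ++k) {
                    acc += (int32_t(A[m * K + k]) - ZeroPointA) *
                           (int32_t(B[k * N + n]) - ZeroPointB[n]);
                }
                C[m * N + n] = acc;
            }
        }
    }

Expanding the product shows the correction terms, ZeroPointB[n] * rowsum(A, m), ZeroPointA * colsum(B, n), and K * ZeroPointA * ZeroPointB[n], that the row and column sum buffers account for.
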
-// - -.L\Mode\().ProcessCountMLessThan8: - cmp x4,#4 - blt .L\Mode\().ProcessCountMLessThan4 - ProcessRows \Mode\(),4 - b .L\Mode\().ExitKernel - -// -// Process 2 row of the matrix. -// - -.L\Mode\().ProcessCountMLessThan4: - cmp x4,#2 - blt .L\Mode\().ProcessCountMLessThan2 - - ProcessRows \Mode\(),2 - b .L\Mode\().ExitKernel - - -// -// Process the last row of the matrix. -// - -.L\Mode\().ProcessCountMLessThan2: - ProcessRows \Mode\(),1 - b .L\Mode\().ExitKernel - - - .endm - - QgemmU8X8KernelUmmlaFunction Zero - QgemmU8X8KernelUmmlaFunction Add - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S deleted file mode 100644 index e424c30515e9f..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S +++ /dev/null @@ -1,907 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - SbgemmKernelNeon.s - -Abstract: - - This module implements the kernels for the bfloat16 half precision matrix/matrix - multiply operation (SBGEMM). - ---*/ - -#include "asmmacro.h" - - .text - -// -// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save -// - .equ .LMlasSbgemmKernel_backup_x19_x20, 0 - .equ .LMlasSbgemmKernel_backup_x21_x22, 16 - .equ .LMlasSbgemmKernel_backup_x23_x24, 32 - .equ .LMlasSbgemmKernel_backup_x25_x26, 48 - .equ .LMlasSbgemmKernel_backup_x27_x28, 64 - .equ .LMlasSbgemmKernel_backup_d8_d9, 80 - .equ .LMlasSbgemmKernel_backup_d10_d11, 96 - .equ .LMlasSbgemmKernel_backup_d12_d13, 112 - .equ .LMlasSbgemmKernel_backup_d14_d15, 128 - .equ .LMlasSbgemmKernel_SavedRegisters, 144 - .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 - - -// -// ClearRowAccumulators -// -// Generates the code to clear the accumulators for a single row of the output -// block. -// - - .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - - mov v\Vec1Reg\().16b,v0.16b -.if \Columns\() > 2 - mov v\Vec2Reg\().16b,v1.16b -.endif -.if \Columns\() > 4 - mov v\Vec3Reg\().16b,v2.16b -.endif -.if \Columns\() > 6 - mov v\Vec4Reg\().16b,v3.16b -.endif - - .endm - -// -// InitBlockAccumulators -// -// Generates the code to init the accumulators for a single row of the output -// block. 
-// - - .macro InitBlockAccumulators Mode, Columns, Rows - - //check if the Bias != nullptr - cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd - - ld1 {v14.4s},[x8],#16 // load Bias[0] - // v4~v7 will be set to matrixB after this, so, they can used now - dup v4.4s,v14.s[0] // broadcast Bias - dup v5.4s,v14.s[1] - dup v6.4s,v14.s[2] - dup v7.4s,v14.s[3] - - zip1 v0.4s, v4.4s, v5.4s - zip2 v1.4s, v6.4s, v7.4s -.if \Columns\() > 4 - ld1 {v15.4s},[x8],#16 // load Bias[4] - dup v4.4s,v15.s[0] // broadcast Bias - dup v5.4s,v15.s[1] - dup v6.4s,v15.s[2] - dup v7.4s,v15.s[3] - - zip1 v2.4s, v4.4s, v5.4s - zip2 v3.4s, v6.4s, v7.4s -.endif - - b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() - -.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: - eor v0.16b,v0.16b,v0.16b // No bias, reset regs - eor v1.16b,v1.16b,v1.16b - eor v2.16b,v2.16b,v2.16b - eor v3.16b,v3.16b,v3.16b - -.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): - InitRowAccumulators \Columns\(),16,17,18,19 -.if \Rows\() > 2 - InitRowAccumulators \Columns\(),20,21,22,23 -.endif -.if \Rows\() > 4 - InitRowAccumulators \Columns\(),24,25,26,27 -.endif -.if \Rows\() > 6 - InitRowAccumulators \Columns\(),28,29,30,31 -.endif - - .endm - -// LoadMatrixAElementsBy8 -// -// Generates the code to load 4 or 8 elements from matrix A. -// - .macro LoadMatrixAElementsBy8 Rows - - ldr q8,[x0],#16 - bfcvtn v8.4h, v8.4s -.if \Rows\() > 1 - ldr q1,[x10],#16 - bfcvtn2 v8.8h, v1.4s -.endif - -.if \Rows\() > 2 - ldr q9,[x11],#16 - bfcvtn v9.4h, v9.4s -.endif -.if \Rows\() > 3 - ldr q1,[x12],#16 - bfcvtn2 v9.8h, v1.4s -.endif - -.if \Rows\() > 4 - ldr q10,[x20],#16 - bfcvtn v10.4h, v10.4s -.endif -.if \Rows\() > 5 - ldr q1,[x21],#16 - bfcvtn2 v10.8h, v1.4s -.endif - -.if \Rows\() > 6 - ldr q11,[x22],#16 - bfcvtn v11.4h, v11.4s -.endif -.if \Rows\() > 7 - ldr q1,[x23],#16 - bfcvtn2 v11.8h, v1.4s -.endif - - .endm - - -// -// MultiplyAccumulateRow -// -// Generates the code to multiply and accumulate a single row of the output -// block. -// - - .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - - bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h -.if \Columns\() > 2 - bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h -.endif -.if \Columns\() > 4 - bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h -.endif -.if \Columns\() > 6 - bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h -.endif - - .endm - -// -// MultiplyAccumulateBlock -// -// Generates the code to multiply and accumulate into the output block. -// - - .macro MultiplyAccumulateBlock Columns, Rows - - MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 -.if \Rows\() > 2 - MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 -.endif -.if \Rows\() > 4 - MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 -.endif -.if \Rows\() > 6 - MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 -.endif - - .endm - -// -// ComputeBlockLoop -// -// Generates the code to loop over K entries of the input matrices to produce -// the output block. 
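The SBGEMM block loop narrows rows of fp32 A to bfloat16 on the fly (bfcvtn fills the low half of a vector, bfcvtn2 the high half, so one q register ends up holding a 2x4 bf16 fragment of A) and then uses bfmmla, which accumulates a 2x2 fp32 tile in the same pattern as ummla does for bytes. A rough scalar model of one bfmmla issue (illustrative; the instruction's intermediate rounding and accumulation order are defined by the Arm ISA and are not modelled here):

    #include <cstdint>
    #include <cstring>

    static float bf16_to_float(uint16_t h) {
        uint32_t bits = uint32_t(h) << 16;  // bf16 is the high half of an fp32
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    void bfmmla_reference(float vd[4], const uint16_t vn[8], const uint16_t vm[8]) {
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 2; ++j) {
                float sum = vd[2 * i + j];
                for (int k = 0; k < 4; ++k) {
                    sum += bf16_to_float(vn[4 * i + k]) * bf16_to_float(vm[4 * j + k]);
                }
                vd[2 * i + j] = sum;
            }
        }
    }
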
-// - - .macro ComputeBlockLoop Mode, Columns, Rows - - InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() - - add x10,x0,x6,lsl #2 // compute matrix A plus 1 row -.if \Rows\() > 2 - add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows - add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows -.endif -.if \Rows\() > 4 - add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows - add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows -.endif -.if \Rows\() > 6 - add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows - add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows -.endif - sub x9,x3,#4 // block count to process - tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: - - LoadMatrixAElementsBy8 \Rows\() - ldr q4, [x1],#16 -.if \Columns\() > 2 - ldr q5,[x1],#16 -.endif -.if \Columns\() > 4 - ldr q6,[x1],#16 -.endif -.if \Columns\() > 6 - ldr q7,[x1],#16 -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\() - - sub x9,x9,#4 - tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop -.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: - add x9,x9,#4 // correct for over-subtract above - cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: - LoadMatrixAElementsBy8 \Rows\() - ldr q4, [x1],#16 -.if \Columns\() > 2 - ldr q5,[x1],#16 -.endif -.if \Columns\() > 4 - ldr q6,[x1],#16 -.endif -.if \Columns\() > 6 - ldr q7,[x1],#16 -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\() - -.L\Mode\().Output\Columns\().x\Rows\().Block: - - .endm - - -// -// OutputRow2Element -// OutputRow4Element -// OutputRow6Element -// OutputRow8Element -// OutputRow10Element -// OutputRow12Element -// OutputRow14Element -// OutputRow16Element -// -// Generates the code to store elements to the output block. 
-// - - .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr s8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr s9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - mov v8.S[2], v9.S[0] - - fadd v8.4s,v8.4s,v\Vec1Reg\().4s - - mov w27, v8.S[0] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov w27, v8.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - mov w27, v\Vec1Reg\().S[0] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov w27, v\Vec1Reg\().S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.endif - - .endm - - - .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr d8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr d9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - - mov v8.D[1], v9.D[0] - - fadd v8.4s,v8.4s,v\Vec1Reg\().4s - - mov x27, v8.D[0] - mov x28, v8.D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.else - mov x27, v\Vec1Reg\().D[0] - mov x28, v\Vec1Reg\().D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.endif - - .endm - - - .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr d8,[\AddrReg1\()],#8 - ldr w28,[\AddrReg1\()],#-8 - mov v8.S[2], w28 -.if \last_row\() == 0 - ldr d9,[\AddrReg2\()],#8 - ldr w27,[\AddrReg2\()],#-8 - mov v9.S[2], w27 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - fadd v8.4s,v8.4s,v4.4s - fadd v9.4s,v9.4s,v5.4s - - mov x27, v8.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v8.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov x27, v9.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v9.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - mov x27, v4.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v4.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - mov x27, v5.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v5.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.endif - - .endm - - - .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#0 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - fadd v8.4s,v8.4s,v4.4s - fadd v9.4s,v9.4s,v5.4s - - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - -.endif - - .endm - - - .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr w28, [\AddrReg1\()],#-16 - -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr w27,[\AddrReg2\()],#-16 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - fadd v8.4s,v8.4s,v4.4s - fadd v9.4s,v9.4s,v5.4s 
- - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - mov v8.S[0], w28 - mov v8.S[2], w27 - - fadd v8.4s,v8.4s,v\Vec3Reg\().4s - - mov w27, v8.S[0] - mov w28, v8.S[2] - - str w27, [\AddrReg1\()],#4 -.if \last_row\() == 0 - str w28, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - mov w27, v\Vec3Reg\().S[0] - mov w28, v\Vec3Reg\().S[2] - - str w27, [\AddrReg1\()],#4 -.if \last_row\() == 0 - str w28, [\AddrReg2\()],#4 -.endif -.endif - -.endm - - - .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr d10,[\AddrReg1\()],#-16 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr d11,[\AddrReg2\()],#-16 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - mov v11.D[0],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - fadd v8.4s,v8.4s,v4.4s - fadd v9.4s,v9.4s,v5.4s - - str q8,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 -.endif - - mov v10.D[1], v11.D[0] - - fadd v10.4s,v10.4s,v\Vec3Reg\().4s - - mov x27, v10.D[0] - mov x28, v10.D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - - str q4,[\AddrReg1\()],#16 -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 -.endif - mov x27, v\Vec3Reg\().D[0] - mov x28, v\Vec3Reg\().D[1] - - str x27, [\AddrReg1\()],#8 -.if \last_row\() == 0 - str x28, [\AddrReg2\()],#8 -.endif -.endif - - .endm - - .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldr q8,[\AddrReg1\()],#16 - ldr d10,[\AddrReg1\()],#8 - ldr w28, [\AddrReg1\()],#-24 - mov v10.S[2], w28 -.if \last_row\() == 0 - ldr q9,[\AddrReg2\()],#16 - ldr d11,[\AddrReg2\()],#8 - ldr w27,[\AddrReg2\()],#-24 - mov v11.S[2], w27 -.else - mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - - mov v11.D[0],x27 - mov v11.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - fadd v8.4s,v8.4s,v4.4s - fadd v9.4s,v9.4s,v5.4s - fadd v10.4s,v10.4s,v6.4s - fadd v11.4s,v11.4s,v7.4s - - str q8,[\AddrReg1\()],#16 - - mov x27, v10.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v10.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - str q9,[\AddrReg2\()],#16 - mov x27, v11.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v11.S[2] - str w27, [\AddrReg2\()],#4 -.endif - -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - str q4,[\AddrReg1\()],#16 - mov x27, v6.D[0] - str x27, [\AddrReg1\()],#8 - mov w27, v6.S[2] - str w27, [\AddrReg1\()],#4 - -.if \last_row\() == 0 - str q5,[\AddrReg2\()],#16 - mov x27, v7.D[0] - str x27, [\AddrReg2\()],#8 - mov w27, v7.S[2] - str w27, [\AddrReg2\()],#4 -.endif -.endif - - .endm - - - .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row - -.ifeqs "\Mode\()","Add" - ldp q8,q10,[\AddrReg1\()],#0 -.if \last_row\() == 0 - ldp q9,q11,[\AddrReg2\()],#0 -.else - 
mov x27,#0 - mov v9.D[0],x27 - mov v9.D[1],x27 - - mov v11.D[0],x27 - mov v11.D[1],x27 -.endif - uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d - - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - fadd v8.4s,v8.4s,v4.4s - fadd v9.4s,v9.4s,v5.4s - fadd v10.4s,v10.4s,v6.4s - fadd v11.4s,v11.4s,v7.4s - - stp q8,q10,[\AddrReg1\()],#32 -.if \last_row\() == 0 - stp q9,q11,[\AddrReg2\()],#32 -.endif -.else - uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d - uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d - - stp q4,q6,[\AddrReg1\()],#32 -.if \last_row\() == 0 - stp q5,q7,[\AddrReg2\()],#32 -.endif -.endif - - .endm - -// -// OutputBlock -// -// Generates the code to store the output block. -// - - .macro OutputBlock Mode, Columns, Rows - - OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) - -.if \Rows\() > 2 - OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) -.endif - -.if \Rows\() > 4 - OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) -.endif - -.if \Rows\() > 6 - OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) -.endif - - .endm -// -// ProcessRows -// -// Generates the code to process a compute and store the output block for a -// fixed number of rows. -// - - .macro ProcessRows Mode, Rows - mov x4,#\Rows\() // return number of rows handled - cmp x5,#6 - ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() - -.L\Mode\().ProcessNextColumnLoop8x\Rows\(): - ComputeBlockLoop \Mode\(),8,\Rows\() - - sub x5,x5,#8 - cmp x5,#0 - blt .L\Mode\().Output14ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),16,\Rows\() - mov x0,x26 // reload matrix A - cmp x5,#6 - bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() - cbz x5,.L\Mode\().ExitKernel - - -.L\Mode\().ProcessNextColumnLoop6x\Rows\(): - - cmp x5,#4 - ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() - ComputeBlockLoop \Mode\(),6,\Rows\() - sub x5,x5,#6 - cmp x5,#0 - blt .L\Mode\().Output10ElementsOnlyFor\Rows\() - OutputBlock \Mode\(),12,\Rows\() - - mov x0,x26 // reload matrix A - cmp x5,#4 - bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() - b .L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop4x\Rows\(): - cmp x5,#2 - ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() - ComputeBlockLoop \Mode\(),4,\Rows\() - sub x5,x5,#4 - cmp x5,#0 - blt .L\Mode\().Output6ElementsOnlyFor\Rows\() - - OutputBlock \Mode\(),8,\Rows\() - - mov x0,x26 // reload matrix A - cmp x5,#2 - bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() - b .L\Mode\().ExitKernel - -.L\Mode\().ProcessNextColumnLoop2x\Rows\(): - ComputeBlockLoop \Mode\(),2,\Rows\() - sub x5,x5,#2 - cmp x5,#0 - blt .L\Mode\().Output2ElementsOnlyFor\Rows\() - - OutputBlock \Mode\(),4,\Rows\() - - mov x0,x26 // reload matrix A - cmp x5,#2 - b .L\Mode\().ExitKernel - -.L\Mode\().Output14ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),14,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output10ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),10,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output6ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),6,\Rows\() - b .L\Mode\().ExitKernel - - -.L\Mode\().Output2ElementsOnlyFor\Rows\(): - OutputBlock \Mode\(),2,\Rows\() - b .L\Mode\().ExitKernel - - .endm - - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. 
- -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. - - C (x2) - Supplies the address of matrix C. - - CountK (x3) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda (x6) - Supplies the first dimension of matrix A. - - ldc (x7) - Supplies the first dimension of matrix C. - - Bias - Supplies the address of Bias Vector [1xn] - - -Return Value: - - Returns the number of rows handled. - ---*/ - .macro SbgemmKernelNeonFunction Mode - - FUNCTION_ENTRY MlasSbgemmKernel\Mode\() - - ldr x8, [sp, #0] //Bias vector - - stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! - stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] - stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] - stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] - stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] - stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] - stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] - stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] - stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] - - add x13,x2,x7,lsl #2 // compute matrix C plus 1 row - add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows - add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows - add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows - add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows - add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows - add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows - - mov x26,x0 // save matrix A -// -// Process 8 rows of the matrices. -// - cmp x4,#8 - blt .L\Mode\().ProcessCountMLessThan8 - ProcessRows \Mode\(),8 - -// -// Restore non-volatile registers and return. -// - -.L\Mode\().ExitKernel: - mov x0,x4 - - ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] - ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] - ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] - ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] - ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] - ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] - ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] - ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] - ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters - - ret - -// -// Process 4 rows of the matrix. -// - -.L\Mode\().ProcessCountMLessThan8: - cmp x4,#4 - blt .L\Mode\().ProcessCountMLessThan4 - ProcessRows \Mode\(),4 - b .L\Mode\().ExitKernel - -// -// Process 2 row of the matrix. -// - -.L\Mode\().ProcessCountMLessThan4: - cmp x4,#2 - blt .L\Mode\().ProcessCountMLessThan2 - - ProcessRows \Mode\(),2 - b .L\Mode\().ExitKernel - - -// -// Process the last row of the matrix. 
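For reference, the overall result one invocation contributes can be written as the following hedged scalar sketch (Zero entry point shown; the Add entry point accumulates onto the existing C values instead of overwriting, and Bias may be null). B is shown row-major here for readability, whereas the real kernel reads the packed bf16 layout produced by MlasSbgemmCopyPackB/MlasSbgemmTransposePackB:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    static float to_bf16_and_back(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        bits &= 0xFFFF0000u;  // truncation shown for simplicity; the hardware conversion rounds
        std::memcpy(&x, &bits, sizeof(bits));
        return x;
    }

    void sbgemm_reference_zero(const float* A, const float* B, float* C,
                               size_t CountK, size_t Rows, size_t CountN,
                               size_t lda, size_t ldb, size_t ldc, const float* Bias) {
        for (size_t m = 0; m < Rows; ++m) {
            for (size_t n = 0; n < CountN; ++n) {
                float acc = Bias ? Bias[n] : 0.0f;
                for (size_t k = 0; k < CountK; ++k) {
                    acc += to_bf16_and_back(A[m * lda + k]) * to_bf16_and_back(B[k * ldb + n]);
                }
                C[m * ldc + n] = acc;
            }
        }
    }
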
-// - -.L\Mode\().ProcessCountMLessThan2: - ProcessRows \Mode\(),1 - b .L\Mode\().ExitKernel - - - .endm - - SbgemmKernelNeonFunction Zero - SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/aarch64/SgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SgemmKernelNeon.S deleted file mode 100644 index 26555b5c6480b..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/SgemmKernelNeon.S +++ /dev/null @@ -1,482 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelNeon.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - ---*/ - -#include "asmmacro.h" - - .text - -// -// ClearRowAccumulators -// -// Generates the code to clear the accumulators for a single row of the output -// block. -// - - .macro ClearRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - - movi v\Vec1Reg\().16b,#0 - movi v\Vec2Reg\().16b,#0 -.if \Columns\() > 8 - movi v\Vec3Reg\().16b,#0 - movi v\Vec4Reg\().16b,#0 -.endif - - .endm - -// -// ClearBlockAccumulators -// -// Generates the code to clear the accumulators for a single row of the output -// block. -// - - .macro ClearBlockAccumulators Columns, Rows - - ClearRowAccumulators \Columns\(),16,17,18,19 -.if \Rows\() >= 2 - ClearRowAccumulators \Columns\(),20,21,22,23 -.endif -.if \Rows\() >= 4 - ClearRowAccumulators \Columns\(),24,25,26,27 - ClearRowAccumulators \Columns\(),28,29,30,31 -.endif - - .endm - -// -// LoadMatrixAElementsBy4 -// LoadMatrixAElementsBy1 -// -// Generates the code to load 1 or 4 elements from matrix A. -// - - .macro LoadMatrixAElementsBy4 Rows - - ldr q8,[x0],#16 -.if \Rows\() >= 2 - ldr q9,[x10],#16 -.endif -.if \Rows\() >= 4 - ldr q10,[x11],#16 - ldr q11,[x12],#16 -.endif - - .endm - - .macro LoadMatrixAElementsBy1 Rows - - ldr s8,[x0],#4 -.if \Rows\() >= 2 - ldr s9,[x10],#4 -.endif -.if \Rows\() >= 4 - ldr s10,[x11],#4 - ldr s11,[x12],#4 -.endif - - .endm - -// -// MultiplyAccumulateRow -// -// Generates the code to multiply and accumulate a single row of the output -// block. -// - - .macro MultiplyAccumulateRow Columns, MatrixAReg, Broadcast, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - - fmla v\Vec1Reg\().4s,v4.4s,\MatrixAReg\().s[\Broadcast\()] - fmla v\Vec2Reg\().4s,v5.4s,\MatrixAReg\().s[\Broadcast\()] -.if \Columns\() > 8 - fmla v\Vec3Reg\().4s,v6.4s,\MatrixAReg\().s[\Broadcast\()] - fmla v\Vec4Reg\().4s,v7.4s,\MatrixAReg\().s[\Broadcast\()] -.endif - - .endm - -// -// MultiplyAccumulateBlock -// -// Generates the code to multiply and accumulate into the output block. -// - - .macro MultiplyAccumulateBlock Columns, Rows, Broadcast - - MultiplyAccumulateRow \Columns\(),v8,\Broadcast\(),16,17,18,19 -.if \Rows\() >= 2 - MultiplyAccumulateRow \Columns\(),v9,\Broadcast\(),20,21,22,23 -.endif -.if \Rows\() >= 4 - MultiplyAccumulateRow \Columns\(),v10,\Broadcast\(),24,25,26,27 - MultiplyAccumulateRow \Columns\(),v11,\Broadcast\(),28,29,30,31 -.endif - - .endm - -// -// ComputeBlockLoop -// -// Generates the code to loop over K entries of the input matrices to produce -// the output block. 
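Each pass through the block loop below broadcasts one element from every loaded row of A against 16 consecutive packed values of B, i.e. a rank-1 update per k step. In scalar form (hypothetical names and row-major views, not kernel code):

    // One k step of the 4x16 tile: fmla v16..v31, v4..v7, v8..v11.s[Broadcast]
    void sgemm_rank1_update(float C[4][16], const float a_col[4], const float b_row[16]) {
        for (int i = 0; i < 4; ++i) {
            for (int n = 0; n < 16; ++n) {
                C[i][n] += a_col[i] * b_row[n];
            }
        }
    }
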
-// - - .macro ComputeBlockLoop Mode, Columns, Rows - - ClearBlockAccumulators \Columns\(),\Rows\() - -.if \Rows\() >= 2 - add x10,x0,x6,lsl #2 // compute matrix A plus 1 row -.endif -.if \Rows\() >= 4 - add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows - add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows -.endif - - sub x9,x3,#4 // decrement block count to process - tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: - LoadMatrixAElementsBy4 \Rows\() - ldp q4,q5,[x1],#64*4 -.if \Columns\() > 8 - ldp q6,q7,[x1,#-56*4] -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\(),0 - ldp q4,q5,[x1,#-48*4] -.if \Columns\() > 8 - ldp q6,q7,[x1,#-40*4] -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\(),1 - ldp q4,q5,[x1,#-32*4] -.if \Columns\() > 8 - ldp q6,q7,[x1,#-24*4] -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\(),2 - ldp q4,q5,[x1,#-16*4] -.if \Columns\() > 8 - ldp q6,q7,[x1,#-8*4] -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\(),3 - sub x9,x9,#4 - tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop - -.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: - add x9,x9,#4 // correct for over-subtract above - cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block - -.L\Mode\().Compute\Columns\().x\Rows\().BlockBy1Loop: - LoadMatrixAElementsBy1 \Rows\() - ldp q4,q5,[x1],#16*4 -.if \Columns\() > 8 - ldp q6,q7,[x1,#-8*4] -.endif - MultiplyAccumulateBlock \Columns\(),\Rows\(),0 - sub x9,x9,#1 - cbnz x9,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy1Loop - -.L\Mode\().Output\Columns\().x\Rows\().Block: - - .endm - -// -// MultiplyAlphaRow -// -// Generates the code to multiply a single row of the output block by the alpha -// value. -// - - .macro MultiplyAlphaRow Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - -.if \Columns\() <= 4 - fmul v\Vec1Reg\().4s,v\Vec1Reg\().4s,v0.s[0] -.elif \Columns\() <= 8 - fmul v\Vec1Reg\().4s,v\Vec1Reg\().4s,v0.s[0] - fmul v\Vec2Reg\().4s,v\Vec2Reg\().4s,v0.s[0] -.elif \Columns\() <= 12 - fmul v\Vec1Reg\().4s,v\Vec1Reg\().4s,v0.s[0] - fmul v\Vec2Reg\().4s,v\Vec2Reg\().4s,v0.s[0] - fmul v\Vec3Reg\().4s,v\Vec3Reg\().4s,v0.s[0] -.else - fmul v\Vec1Reg\().4s,v\Vec1Reg\().4s,v0.s[0] - fmul v\Vec2Reg\().4s,v\Vec2Reg\().4s,v0.s[0] - fmul v\Vec3Reg\().4s,v\Vec3Reg\().4s,v0.s[0] - fmul v\Vec4Reg\().4s,v\Vec4Reg\().4s,v0.s[0] -.endif - - .endm - -// -// MultiplyAlphaBlock -// -// Generates the code to multiply the output block by the alpha value. -// - - .macro MultiplyAlphaBlock Columns, Rows - - MultiplyAlphaRow \Columns\(),16,17,18,19 -.if \Rows\() >= 2 - MultiplyAlphaRow \Columns\(),20,21,22,23 -.endif -.if \Rows\() >= 4 - MultiplyAlphaRow \Columns\(),24,25,26,27 - MultiplyAlphaRow \Columns\(),28,29,30,31 -.endif - - .endm - -// -// OutputRow1Element -// OutputRow2Element -// OutputRow4Element -// OutputRow8Element -// OutputRow16Element -// -// Generates the code to store elements to the output block. 
-// - - .macro OutputRow1Element Mode, AddrReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - -.ifeqs "\Mode\()","Add" - ld1 {v4.s}[0],[\AddrReg\()] - fmla v4.2s,v\Vec1Reg\().2s,v0.s[0] - st1 {v4.s}[0],[\AddrReg\()] // post-increment not needed for last element -.else - st1 {v\Vec1Reg\().s}[0],[\AddrReg\()]// post-increment not needed for last element -.endif - - .endm - - .macro OutputRow2Element Mode, AddrReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - -.ifeqs "\Mode\()","Add" - ld1 {v4.2s},[\AddrReg\()] - fmla v4.2s,v\Vec1Reg\().2s,v0.s[0] - st1 {v4.2s},[\AddrReg\()],#2*4 -.else - st1 {v\Vec1Reg\().2s},[\AddrReg\()],#2*4 -.endif - dup v\Vec1Reg\().4s,v\Vec1Reg\().s[2] // shift remaining elements down - - .endm - - .macro OutputRow4Element Mode, AddrReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - -.ifeqs "\Mode\()","Add" - ld1 {v4.4s},[\AddrReg\()] - fmla v4.4s,v\Vec1Reg\().4s,v0.s[0] - st1 {v4.4s},[\AddrReg\()],#4*4 -.else - st1 {v\Vec1Reg\().4s},[\AddrReg\()],#4*4 -.endif - mov v\Vec1Reg\().16b,v\Vec2Reg\().16b // shift remaining elements down - - .endm - - .macro OutputRow8Element Mode, AddrReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - -.ifeqs "\Mode\()","Add" - ldp q4,q5,[\AddrReg\()] - fmla v4.4s,v\Vec1Reg\().4s,v0.s[0] - fmla v5.4s,v\Vec2Reg\().4s,v0.s[0] - stp q4,q5,[\AddrReg\()],#8*4 -.else - stp q\Vec1Reg\(),q\Vec2Reg\(),[\AddrReg\()],#8*4 -.endif - mov v\Vec1Reg\().16b,v\Vec3Reg\().16b // shift remaining elements down - mov v\Vec2Reg\().16b,v\Vec4Reg\().16b - - .endm - - .macro OutputRow16Element Mode, AddrReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg - -.ifeqs "\Mode\()","Add" - ldp q4,q5,[\AddrReg\()] - ldp q6,q7,[\AddrReg\(),#8*4] - fmla v4.4s,v\Vec1Reg\().4s,v0.s[0] - fmla v5.4s,v\Vec2Reg\().4s,v0.s[0] - fmla v6.4s,v\Vec3Reg\().4s,v0.s[0] - fmla v7.4s,v\Vec4Reg\().4s,v0.s[0] - stp q4,q5,[\AddrReg\()],#16*4 - stp q6,q7,[\AddrReg\(),#-8*4] -.else - stp q\Vec1Reg\(),q\Vec2Reg\(),[\AddrReg\()],#16*4 - stp q\Vec3Reg\(),q\Vec4Reg\(),[\AddrReg\(),#-8*4] -.endif - - .endm - -// -// OutputBlock -// -// Generates the code to store the output block. -// - - .macro OutputBlock Mode, Columns, Rows - - OutputRow\Columns\()Element \Mode\(),x2,16,17,18,19 -.if \Rows\() >= 2 - OutputRow\Columns\()Element \Mode\(),x13,20,21,22,23 -.endif -.if \Rows\() >= 4 - OutputRow\Columns\()Element \Mode\(),x14,24,25,26,27 - OutputRow\Columns\()Element \Mode\(),x15,28,29,30,31 -.endif - - .endm - -// -// ProcessRows -// -// Generates the code to process a compute and store the output block for a -// fixed number of rows. 
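The Zero and Add entry points generated below differ only in how alpha and the existing C values enter the output path: Zero mode scales the accumulators once (MultiplyAlphaBlock) and stores them directly, while Add mode loads C and fuses the scale into the accumulating fmla. Per element, roughly (illustrative only):

    // alpha arrives in s0.
    inline void store_zero_mode(float& c, float acc, float alpha) { c = alpha * acc; }
    inline void store_add_mode (float& c, float acc, float alpha) { c += alpha * acc; }
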
-// - - .macro ProcessRows Mode, Rows - - mov x4,#\Rows\() // return number of rows handled - cmp x5,#8 - ble .L\Mode\().ProcessRemainingCountN\Rows\() - -.L\Mode\().ProcessNextColumnLoop16x\Rows\(): - ComputeBlockLoop \Mode\(),16,\Rows\() -.ifeqs "\Mode\()","Zero" - MultiplyAlphaBlock 16,\Rows\() -.endif - sub x5,x5,#16 - tbnz x5,#63,.L\Mode\().OutputMasked16x\Rows\().Block - OutputBlock \Mode\(),16,\Rows\() - mov x0,x8 // reload matrix A - cmp x5,#8 - bgt .L\Mode\().ProcessNextColumnLoop16x\Rows\() - cbz x5,.L\Mode\().ExitKernel - -.L\Mode\().ProcessRemainingCountN\Rows\(): - ComputeBlockLoop \Mode\(),8,\Rows\() -.ifeqs "\Mode\()","Zero" - MultiplyAlphaBlock 8,\Rows\() -.endif - -.L\Mode\().OutputMasked16x\Rows\().Block: - tbz x5,#3,.L\Mode\().OutputRemaining7x\Rows\().Block - OutputBlock \Mode\(),8,\Rows\() - -.L\Mode\().OutputRemaining7x\Rows\().Block: - tbz x5,#2,.L\Mode\().OutputRemaining3x\Rows\().Block - OutputBlock \Mode\(),4,\Rows\() - -.L\Mode\().OutputRemaining3x\Rows\().Block: - tbz x5,#1,.L\Mode\().OutputRemaining1x\Rows\().Block - OutputBlock \Mode\(),2,\Rows\() - -.L\Mode\().OutputRemaining1x\Rows\().Block: - tbz x5,#0,.L\Mode\().ExitKernel - OutputBlock \Mode\(),1,\Rows\() - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C (x2) - Supplies the address of matrix C. - - CountK (x3) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda (x6) - Supplies the first dimension of matrix A. - - ldc (x7) - Supplies the first dimension of matrix C. - - Alpha (s0) - Supplies the scalar multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ - - .macro SgemmKernelNeonFunction Mode - - FUNCTION_ENTRY MlasSgemmKernel\Mode\() - - stp d8,d9,[sp,#-32]! - stp d10,d11,[sp,#16] - - add x13,x2,x7,lsl #2 // compute matrix C plus 1 row - add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows - add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows - mov x8,x0 // save matrix A - -// -// Process 4 rows of the matrices. -// - - cmp x4,#4 - blt .L\Mode\().ProcessCountMLessThan4 - ProcessRows \Mode\(),4 - -// -// Restore non-volatile registers and return. -// - -.L\Mode\().ExitKernel: - mov x0,x4 - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#32 - ret - -// -// Process 2 rows of the matrices. -// - -.L\Mode\().ProcessCountMLessThan4: - cmp x4,#2 - blt .L\Mode\().ProcessCountMLessThan2 - ProcessRows \Mode\(),2 - b .L\Mode\().ExitKernel - -// -// Process 1 row of the matrices. 
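Because the kernel reports how many rows it actually handled (at most 4 here, 8 in the mmla variants), the caller is expected to loop, advancing A and C by the reported row count. A hedged sketch of such a driver loop (the function type and parameter order are illustrative, not the actual MLAS interface):

    #include <cstddef>

    using SgemmKernelFn = size_t (*)(const float* A, const float* B, float* C,
                                     size_t CountK, size_t CountM, size_t CountN,
                                     size_t lda, size_t ldc, float alpha);

    void run_all_rows(SgemmKernelFn kernel, const float* A, const float* B, float* C,
                      size_t CountK, size_t CountM, size_t CountN,
                      size_t lda, size_t ldc, float alpha) {
        while (CountM > 0) {
            size_t handled = kernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
            A += handled * lda;   // next unprocessed rows of A
            C += handled * ldc;   // matching rows of C
            CountM -= handled;
        }
    }
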
-// - -.L\Mode\().ProcessCountMLessThan2: - ProcessRows \Mode\(),1 - b .L\Mode\().ExitKernel - - .endm - - SgemmKernelNeonFunction Zero - SgemmKernelNeonFunction Add - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SgemvKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SgemvKernelNeon.S deleted file mode 100644 index c935f06255604..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/SgemvKernelNeon.S +++ /dev/null @@ -1,303 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemvKernelNeon.s - -Abstract: - - This module implements the kernels for the single precision matrix/vector - multiply operation (SGEMV). - ---*/ - -#include "asmmacro.h" - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. This handles the special case of M=1. - - The elements in matrix B are not transposed. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. - - C (x2) - Supplies the address of matrix C. - - CountK (x3) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountN (x4) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb (x5) - Supplies the first dimension of matrix B. - - ZeroMode (x6) - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasGemvFloatKernel - - cmp x4,#64 - blo .LSgemvN.ProcessRemainingCountN - mov x14,x0 // preserve vector A - -// -// Process 64 columns at a time in a loop. -// - -.LSgemvN.ProcessColumnLoopBy64: - ldr q4,[x1] - add x15,x1,#256 // compute next matrix B - ldr q5,[x1,#16] - tst w6,0xFF // ZeroMode? 
- mov x13,x3 // reload CountK - ldr q6,[x1,#32] - beq .LSgemvN.LoadOutputBy64 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - movi v20.4s,#0 - movi v21.4s,#0 - movi v22.4s,#0 - movi v23.4s,#0 - movi v24.4s,#0 - movi v25.4s,#0 - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - b .LSgemvN.MultiplyAccumulateBy64 - -.LSgemvN.LoadOutputBy64: - ldp q16,q17,[x2] - ldp q18,q19,[x2,#32] - ldp q20,q21,[x2,#64] - ldp q22,q23,[x2,#96] - ldp q24,q25,[x2,#128] - ldp q26,q27,[x2,#160] - ldp q28,q29,[x2,#192] - ldp q30,q31,[x2,#224] - -.LSgemvN.MultiplyAccumulateBy64: - ld1r {v0.4s},[x0] // broadcast next vector A element - add x0,x0,4 // advance vector A by 1 element - sub x13,x13,#1 // decrement K remaining - fmla v16.4s,v4.4s,v0.4s - ldr q7,[x1,#48] - fmla v17.4s,v5.4s,v0.4s - ldr q4,[x1,#64] - fmla v18.4s,v6.4s,v0.4s - ldr q5,[x1,#80] - fmla v19.4s,v7.4s,v0.4s - ldr q6,[x1,#96] - fmla v20.4s,v4.4s,v0.4s - ldr q7,[x1,#112] - fmla v21.4s,v5.4s,v0.4s - ldr q4,[x1,#128] - fmla v22.4s,v6.4s,v0.4s - ldr q5,[x1,#144] - fmla v23.4s,v7.4s,v0.4s - ldr q6,[x1,#160] - fmla v24.4s,v4.4s,v0.4s - ldr q7,[x1,#176] - fmla v25.4s,v5.4s,v0.4s - ldr q4,[x1,#192] - fmla v26.4s,v6.4s,v0.4s - ldr q5,[x1,#208] - fmla v27.4s,v7.4s,v0.4s - ldr q6,[x1,#224] - fmla v28.4s,v4.4s,v0.4s - ldr q7,[x1,#240] - add x1,x1,x5,lsl #2 // compute next matrix B row address - cbz x13,.LSgemvN.StoreOutputBy64 - ldr q4,[x1] // load data for next iteration - fmla v29.4s,v5.4s,v0.4s - ldr q5,[x1,#16] - fmla v30.4s,v6.4s,v0.4s - ldr q6,[x1,#32] - fmla v31.4s,v7.4s,v0.4s - b .LSgemvN.MultiplyAccumulateBy64 - -.LSgemvN.StoreOutputBy64: - stp q16,q17,[x2] - fmla v29.4s,v5.4s,v0.4s // finish computing tail vectors - stp q18,q19,[x2,#32] - fmla v30.4s,v6.4s,v0.4s - stp q20,q21,[x2,#64] - fmla v31.4s,v7.4s,v0.4s - stp q22,q23,[x2,#96] - sub x4,x4,#64 // subtract 64 columns - stp q24,q25,[x2,#128] - mov x0,x14 // reload vector A - stp q26,q27,[x2,#160] - mov x1,x15 // load next matrix B - stp q28,q29,[x2,#192] - stp q30,q31,[x2,#224] - add x2,x2,#256 // advance vector C by 64 columns - cbz x4,.LSgemvN.ExitKernel - cmp x4,#64 - bhs .LSgemvN.ProcessColumnLoopBy64 - -// -// Process the remaining 1 to 63 columns. -// - -.LSgemvN.ProcessRemainingCountN: - tst w6,0xFF // ZeroMode? 
- beq .LSgemvN.LoadOutputPartial32 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - movi v20.4s,#0 - movi v21.4s,#0 - movi v22.4s,#0 - movi v23.4s,#0 - movi v24.4s,#0 - movi v25.4s,#0 - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 // trailing float[2] - movi v1.4s,#0 // trailing float[1] - b .LSgemvN.ProcessNextPartialRow - -.LSgemvN.LoadOutputPartial32: - mov x15,x2 - tbz x4,#5,.LSgemvN.LoadOutputPartial16 - ldp q16,q17,[x15],#128 - ldp q18,q19,[x15,#-96] - ldp q20,q21,[x15,#-64] - ldp q22,q23,[x15,#-32] - -.LSgemvN.LoadOutputPartial16: - tbz x4,#4,.LSgemvN.LoadOutputPartial8 - ldp q24,q25,[x15],#64 - ldp q26,q27,[x15,#-32] - -.LSgemvN.LoadOutputPartial8: - tbz x4,#3,.LSgemvN.LoadOutputPartial4 - ldp q28,q29,[x15],#32 - -.LSgemvN.LoadOutputPartial4: - tbz x4,#2,.LSgemvN.LoadOutputPartial2 - ldr q30,[x15],#16 - -.LSgemvN.LoadOutputPartial2: - tbz x4,#1,.LSgemvN.LoadOutputPartial1 - ldr d31,[x15],#8 - -.LSgemvN.LoadOutputPartial1: - tbz x4,#0,.LSgemvN.ProcessNextPartialRow - ldr s1,[x15] - -.LSgemvN.ProcessNextPartialRow: - ld1r {v0.4s},[x0] - add x0,x0,4 - sub x3,x3,#1 // decrement K remaining - mov x15,x1 - -.LSgemvN.MultiplyAccumulatePartial32: - tbz x4,#5,.LSgemvN.MultiplyAccumulatePartial16 - ldp q4,q5,[x15],#128 - fmla v16.4s,v4.4s,v0.4s - ldp q6,q7,[x15,#-96] - fmla v17.4s,v5.4s,v0.4s - ldp q4,q5,[x15,#-64] - fmla v18.4s,v6.4s,v0.4s - fmla v19.4s,v7.4s,v0.4s - ldp q6,q7,[x15,#-32] - fmla v20.4s,v4.4s,v0.4s - fmla v21.4s,v5.4s,v0.4s - fmla v22.4s,v6.4s,v0.4s - fmla v23.4s,v7.4s,v0.4s - -.LSgemvN.MultiplyAccumulatePartial16: - tbz x4,#4,.LSgemvN.MultiplyAccumulatePartial8 - ldp q4,q5,[x15],#64 - fmla v24.4s,v4.4s,v0.4s - ldp q6,q7,[x15,#-32] - fmla v25.4s,v5.4s,v0.4s - fmla v26.4s,v6.4s,v0.4s - fmla v27.4s,v7.4s,v0.4s - -.LSgemvN.MultiplyAccumulatePartial8: - tbz x4,#3,.LSgemvN.MultiplyAccumulatePartial4 - ldp q4,q5,[x15],#32 - fmla v28.4s,v4.4s,v0.4s - fmla v29.4s,v5.4s,v0.4s - -.LSgemvN.MultiplyAccumulatePartial4: - tbz x4,#2,.LSgemvN.MultiplyAccumulatePartial2 - ldr q4,[x15],#16 - fmla v30.4s,v4.4s,v0.4s - -.LSgemvN.MultiplyAccumulatePartial2: - tbz x4,#1,.LSgemvN.MultiplyAccumulatePartial1 - ldr d4,[x15],#8 - fmla v31.4s,v4.4s,v0.4s - -.LSgemvN.MultiplyAccumulatePartial1: - tbz x4,#0,.LSgemvN.AdvancePartialRow - ldr s4,[x15] - fmla v1.4s,v4.4s,v0.4s - -.LSgemvN.AdvancePartialRow: - add x1,x1,x5,lsl #2 // compute next matrix B row address - cbnz x3,.LSgemvN.ProcessNextPartialRow - -.LSgemvN.StoreOutputPartial32: - tbz x4,#5,.LSgemvN.StoreOutputPartial16 - stp q16,q17,[x2],#128 - stp q18,q19,[x2,#-96] - stp q20,q21,[x2,#-64] - stp q22,q23,[x2,#-32] - -.LSgemvN.StoreOutputPartial16: - tbz x4,#4,.LSgemvN.StoreOutputPartial8 - stp q24,q25,[x2],#64 - stp q26,q27,[x2,#-32] - -.LSgemvN.StoreOutputPartial8: - tbz x4,#3,.LSgemvN.StoreOutputPartial4 - stp q28,q29,[x2],#32 - -.LSgemvN.StoreOutputPartial4: - tbz x4,#2,.LSgemvN.StoreOutputPartial2 - str q30,[x2],#16 - -.LSgemvN.StoreOutputPartial2: - tbz x4,#1,.LSgemvN.StoreOutputPartial1 - str d31,[x2],#8 - -.LSgemvN.StoreOutputPartial1: - tbz x4,#0,.LSgemvN.ExitKernel - str s1,[x2] - -.LSgemvN.ExitKernel: - ret - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelNeon.S deleted file mode 100644 index f236ef4ed1742..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelNeon.S +++ /dev/null @@ -1,536 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
- -Licensed under the MIT License. - -Module Name: - - SymQgemmS8KernelNeon.S - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM), where the right hand side is symmetrically quantized, - i.e. zero point being zero. - - This kernel only requires prepacking of the right hand side, which is usually - constant. When the packed right hand side is cached, we achieves higher performance - by avoid packing all together. - ---*/ - -#include "asmmacro.h" - -// -// Stack frame layout for the S8S8 kernel. -// d8-d15, x19-x30 need to be preserved if used -// - - .equ .LSQGemmS8Frame_SavedRegisters, (8 * 8) - .equ .LSQGemmS8Frame_ColumnSumBuffer, 0 + .LSQGemmS8Frame_SavedRegisters - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - lda (x7) - Supplies the first dimension of matrix A. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasSymQgemmS8KernelNeon - - stp d8,d9,[sp,#-.LSQGemmS8Frame_SavedRegisters]! - stp d10,d11,[sp,#16] - stp d12,d13,[sp,#32] - stp d14,d15,[sp,#48] - ldr x13,[sp,#.LSQGemmS8Frame_ColumnSumBuffer] - mov x14,x0 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq M1_ProcessLoop - cmp x4,#4 // CountM < 4? - blo M2_ProcessLoop - -// -// Process 4 rows of the matrices. -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v8.b[0] v9.b[0] v10.b[0] v11.b[0]| -// | ... ... ... ... | -// |v8.b[7] v9.b[7] v10.b[7] v11.b[7]| -// A 4x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v24.4s v25.4s v26.4s v27.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v28.4s v29.4s v30.4s v31.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// - -M4_ProcessNextColumnLoop: - mov x0,x14 // reload matrix A0 - mov x3,x15 // reload PackedCountK - ldr d0,[x0],#8 // Load A0 - add x9,x14,x7 // A1 - ldr d2,[x0],#8 // Load A0 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d4,d8,[x1],#64 // B - movi v18.4s,#0 - movi v19.4s,#0 - ldp d5,d9,[x1,#-48] - movi v20.4s,#0 - movi v21.4s,#0 - ldp d6,d10,[x1,#-32] - movi v22.4s,#0 - movi v23.4s,#0 - ldp d7,d11,[x1,#-16] - movi v24.4s,#0 - movi v25.4s,#0 - add x10,x9,x7 // A2 - ldp d1,d3,[x9],#16 // Load A1 - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - add x11,x10,x7 // A3 - -M4_ComputeBlockLoop: - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x10],#16 // Load A2 - sadalp v16.4s,v12.8h - sadalp v17.4s,v13.8h - sadalp v18.4s,v14.8h - sadalp v19.4s,v15.8h - sub x3,x3,#1 - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - ldp d1,d3,[x11],#16 // Load A3 - sadalp v20.4s,v12.8h - sadalp v21.4s,v13.8h - sadalp v22.4s,v14.8h - sadalp v23.4s,v15.8h - cbz x3,M4_ComputeBlockLoopFinish - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x0],#16 // Load A0 next iter - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - ldp d4,d8,[x1],#64 // B - smlal v13.8h,v3.8b,v9.8b - ldp d5,d9,[x1,#-48] - smlal v14.8h,v3.8b,v10.8b - ldp d6,d10,[x1,#-32] - smlal v15.8h,v3.8b,v11.8b - ldp d7,d11,[x1,#-16] - sadalp v28.4s,v12.8h - ldp d1,d3,[x9],#16 // Load A1 next iter - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - b M4_ComputeBlockLoop - -M4_ComputeBlockLoopFinish: - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ld1 {v2.4s},[x13],#16 // load ColumnSumBuffer[0] - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - addp v24.4s,v24.4s,v25.4s - addp v26.4s,v26.4s,v27.4s - addp v28.4s,v28.4s,v29.4s - addp v30.4s,v30.4s,v31.4s - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - - // accumulator += column sum B - add v16.4s,v16.4s,v2.4s - add v20.4s,v20.4s,v2.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v2.4s - -M4_StoreOutput: - add x10,x2,x6,lsl #2 - add x11,x10,x6,lsl #2 - add x12,x11,x6,lsl #2 - subs x5,x5,#4 // 
adjust CountN remaining - blo M4_StoreOutputPartial - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - st1 {v24.4s},[x11] - st1 {v28.4s},[x12] - cbnz x5,M4_ProcessNextColumnLoop - -M4_ExitKernel: - mov x0,#4 // return number of rows handled - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LSQGemmS8Frame_SavedRegisters - ret - -M4_StoreOutputPartial: - -M4_StoreOutputPartial_ZeroMode: - tbz x5,#1,M4_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - st1 {v24.2s},[x11],#8 - dup v24.4s,v24.s[2] - st1 {v28.2s},[x12],#8 - dup v28.4s,v28.s[2] - -M4_StoreOutputPartial1_ZeroMode: - tbz x5,#0,M4_ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - st1 {v24.s}[0],[x11] - st1 {v28.s}[0],[x12] - b M4_ExitKernel - -// -// Process 2 rows of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// Each row sum replicated to all 4 elements of a vector register -// v30 v31 -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 2x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. (v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) - -M2_ProcessLoop: - -M2_ProcessNextColumnLoop: - ldp d4,d24,[x1],#16 // B - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // Load A0 - add x9,x14,x7 // A1 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d5,d25,[x1],#16 - movi v18.4s,#0 - movi v19.4s,#0 - ldp d6,d26,[x1],#16 - movi v20.4s,#0 - movi v21.4s,#0 - ldp d7,d27,[x1],#16 - movi v22.4s,#0 - movi v23.4s,#0 - ldp d1,d3,[x9],#16 // Load A1 - -M2_ComputeBlockLoop: - sub x3,x3,#1 - smull v28.8h,v0.8b,v4.8b - smull v29.8h,v0.8b,v5.8b - smull v30.8h,v0.8b,v6.8b - smull v31.8h,v0.8b,v7.8b - cbz x3,M2_ComputeBlockLoopFinish - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // Load A0 - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v29.8h,v3.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v30.8h,v3.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v31.8h,v3.8b,v27.8b - ldp d7,d27,[x1],#16 - sadalp v20.4s,v28.8h - ldp d1,d3,[x9],#16 // Load A1 - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - b M2_ComputeBlockLoop - -M2_ComputeBlockLoopFinish: - ld1 {v0.4s},[x13],#16 // load ColumnSumBuffer[0] - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - smlal v29.8h,v3.8b,v25.8b - 
smlal v30.8h,v3.8b,v26.8b - smlal v31.8h,v3.8b,v27.8b - sadalp v20.4s,v28.8h - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - - // accumulator = column sum B - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v0.4s - -M2_StoreOutput: - add x10,x2,x6,lsl #2 - subs x5,x5,#4 // adjust CountN remaining - blo M2_StoreOutputPartial - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - cbnz x5,M2_ProcessNextColumnLoop - -M2_ExitKernel: - mov x0,#2 // return number of rows handled - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LSQGemmS8Frame_SavedRegisters - ret - -M2_StoreOutputPartial: - -M2_StoreOutputPartial_ZeroMode: - tbz x5,#1,M2_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - -M2_StoreOutputPartial1_ZeroMode: - tbz x5,#0,M2_ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - b M2_ExitKernel - -// -// Process 1 row of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// row sum replicated to all 4 elements of a vector register -// v31 -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 1x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// -M1_ProcessLoop: - -M1_ProcessNextColumnLoop: - ldp d4,d24,[x1],#16 // B - ldp d5,d25,[x1],#16 - ldp d6,d26,[x1],#16 - ldp d7,d27,[x1],#16 - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - -M1_ComputeBlockLoop: - sub x3,x3,#1 - smull v20.8h,v0.8b,v4.8b - smull v21.8h,v0.8b,v5.8b - cbz x3,M1_ComputeBlockLoopFinish - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v21.8h,v2.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v22.8h,v2.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v23.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // A0 - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - ldp d7,d27,[x1],#16 - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - b M1_ComputeBlockLoop - -M1_ComputeBlockLoopFinish: - ld1 {v4.4s},[x13],#16 // load ColumnSumBuffer[0] - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - smlal v21.8h,v2.8b,v25.8b - smlal v22.8h,v2.8b,v26.8b - smlal v23.8h,v2.8b,v27.8b - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v16.4s,v16.4s,v18.4s - - // accumulator += column sum B - add v16.4s,v16.4s,v4.4s - -M1_StoreOutput: - subs x5,x5,#4 // adjust CountN remaining - blo M1_StoreOutputPartial - st1 {v16.4s},[x2],#16 - cbnz x5,M1_ProcessNextColumnLoop - -M1_ExitKernel: - mov x0,#1 // return number of rows handled - ldp d14,d15,[sp,#48] - ldp d12,d13,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LSQGemmS8Frame_SavedRegisters - ret - -M1_StoreOutputPartial: - -M1_StoreOutputPartial_ZeroMode: - tbz x5,#1,M1_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -M1_StoreOutputPartial1_ZeroMode: - tbz x5,#0,M1_ExitKernel - st1 {v16.s}[0],[x2] - b M1_ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelSdot.S b/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelSdot.S deleted file mode 100644 index be4b2d3c22e51..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelSdot.S +++ /dev/null @@ -1,389 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SymQgemmS8KernelSdot.S - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM), where the right hand side is symmetrically quantized, - i.e. zero point being zero. - - This kernel only requires prepacking of the right hand side, which is usually - constant. When the packed right hand side is cached, we achieves higher performance - by avoid packing all together. - - This version utilizes dot product instructions, and uses 128b loads - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout d8-d15, x19-x30 need to be preserved if used -// - .equ .LGemmS8S8KernelFrame_SavedRegisters, (4 * 8) - .equ .LGemmS8S8KernelFrame_ColumnSumBuffer, (0 + .LGemmS8S8KernelFrame_SavedRegisters) - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. 
- - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - Packed K should be 16x - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - lda (x7) - Supplies the first dimension of matrix A. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasSymQgemmS8KernelSdot - - stp d8,d9,[sp,#-.LGemmS8S8KernelFrame_SavedRegisters]! - ldr x8,[sp,#.LGemmS8S8KernelFrame_ColumnSumBuffer] - stp d10,d11,[sp,#16] - - // compute C pointers: x2, x16, x17, x6 - cmp x4,#2 // M < 2 ? - add x16,x2,x6,lsl #2 // x16 -> C1 - add x17,x2,x6,lsl #3 // x17 -> C2 - csel x16,x2,x16,lo // if M < 2 x16/C1 -> C0 - csel x17,x16,x17,ls // if M <= 2 x17/C2 -> C1 - cmp x4,#4 // M < 4 ? - mov x12,#4 // set max M to 4 - add x6,x16,x6,lsl #3 // x6 -> C3 - mov x9,x0 // save A0 - mov x10,x3 // save K - csel x6,x17,x6,lo // if M < 4 x6/C3 -> C2 - csel x4,x12,x4,hi // if M > 4 M = 4; - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... | -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x0 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x12 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x13 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v22.s[3] v26.s[0]_v26.s[3] v30.s[0]_v30.s[3]| x17 -// x14 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x6 -// ------------------ ---------------------------------------------------------------------------- - -ProcessNextColumnLoop: - ldr q16,[x8],#16 // Init accumulators with column sums - ldr q20,[x8],#16 - ldr q24,[x8],#16 - ldr q28,[x8],#16 - mov x0,x9 // reload A0 - cmp x4,#2 // M < 2 ? - add x12,x9,x7 // x12 -> A1 - add x13,x0,x7,lsl #1 // x13 -> A2 - ldr q4,[x1],#16 // Load B - csel x12,x0,x12,lo // if M < 2 A1 -> A0 - csel x13,x12,x13,ls // if M <= 2 A2 -> A1 - cmp x4,4 // M < 4 ? 
- add x14,x12,x7,lsl #1 // x14 -> A3 - ldr q5,[x1],#16 - csel x14,x13,x14,lo // if M < 4 A3 -> A2 - ldr d0,[x0],#8 // Load A0 1st/2nd block of 4 - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldr d1,[x12],#8 // Load A1 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - ldr d2,[x13],#8 // Load A2 - mov v22.16b,v20.16b - mov v23.16b,v20.16b - ldr d3,[x14],#8 // Load A3 - mov v25.16b,v24.16b - mov v26.16b,v24.16b - ldr q6,[x1],#16 - mov v27.16b,v24.16b - mov v29.16b,v28.16b - ldr q7,[x1],#16 - subs x3,x10,#2 // one loop iteration and epilogue consume k = 32 - mov v30.16b,v28.16b - mov v31.16b,v28.16b - b.lo BlockLoopEpilogue // Need 32 k for main loop - -BlockLoop: - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldr d8,[x0],#8 // Load A0 3rd/4th block of 4 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],#16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x12],#8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],#16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x13],#8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],#16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x14],#8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],#16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],#16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],#16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],#16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr q7,[x1],#16 - SdotByElement 16, 4, 8,0 - SdotByElement 17, 4, 9,0 - ldr d0,[x0],#8 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],#16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - ldr d1,[x12],#8 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],#16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - ldr d2,[x13],#8 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],#16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - ldr d3,[x14],#8 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],#16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr q4,[x1],#16 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr q5,[x1],#16 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - ldr q6,[x1],#16 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - subs x3,x3,#1 // k -= 16 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - ldr q7,[x1],#16 - b.hs BlockLoop - -BlockLoopEpilogue: - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldr d8,[x0],#8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],#16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x12],#8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],#16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x13],#8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],#16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x14],#8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],#16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - 
SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],#16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],#16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],#16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr q7,[x1],#16 - SdotByElement 16, 4, 8,0 - SdotByElement 17, 4, 9,0 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],#16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],#16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],#16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],#16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - subs x5,x5,#16 // adjust CountN remaining - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - blo StoreOutputPartial - stp q16,q20,[x2],#32 - stp q24,q28,[x2],#32 - stp q17,q21,[x16],#32 - stp q25,q29,[x16],#32 - stp q18,q22,[x17],#32 - stp q26,q30,[x17],#32 - stp q19,q23,[x6],#32 - stp q27,q31,[x6],#32 - cbnz x5,ProcessNextColumnLoop - -ExitKernel: - mov x0,x4 // return number of rows handled - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LGemmS8S8KernelFrame_SavedRegisters - ret - -// -// Store the partial 1 to 15 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -StoreOutputPartial: - tbz x5,#3,StoreOutputPartial4 - stp q16,q20,[x2],#32 - mov v16.16b,v24.16b // shift remaining elements down - mov v20.16b,v28.16b - stp q17,q21,[x16],#32 - mov v17.16b,v25.16b - mov v21.16b,v29.16b - stp q18,q22,[x17],#32 - mov v18.16b,v26.16b - mov v22.16b,v30.16b - stp q19,q23,[x6],#32 - mov v19.16b,v27.16b - mov v23.16b,v31.16b - -StoreOutputPartial4: - tbz x5,#2,StoreOutputPartial2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v20.16b // shift remaining elements down - st1 {v17.4s},[x16],#16 - mov v17.16b,v21.16b - st1 {v18.4s},[x17],#16 - mov v18.16b,v22.16b - st1 {v19.4s},[x6],#16 - mov v19.16b,v23.16b - -StoreOutputPartial2: - tbz x5,#1,StoreOutputPartial1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v17.2s},[x16],#8 - dup v17.4s,v17.s[2] - st1 {v18.2s},[x17],#8 - dup v18.4s,v18.s[2] - st1 {v19.2s},[x6],#8 - dup v19.4s,v19.s[2] - -StoreOutputPartial1: - tbz x5,#0,ExitKernel - st1 {v16.s}[0],[x2] - st1 {v17.s}[0],[x16] - st1 {v18.s}[0],[x17] - st1 {v19.s}[0],[x6] - b ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelSdotLd64.S b/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelSdotLd64.S deleted file mode 100644 index 80bbb49ac83a7..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/SymQgemmS8KernelSdotLd64.S +++ /dev/null @@ -1,452 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
- -Module Name: - - SymQgemmS8KernelSdot.S - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM), where the right hand side is symmetrically quantized, - i.e. zero point being zero. - - This kernel only requires prepacking of the right hand side, which is usually - constant. When the packed right hand side is cached, we achieves higher performance - by avoid packing all together. - - This version utilizes dot product instructions, and uses only 64b loads that performs - better on cores with narrow memory interface such as A55 - ---*/ - -#include "asmmacro.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout, d8-d15, x19-x30 need to be preserved if used -// - .equ .LGemmS8S8KernelFrame_SavedRegisters, (6 * 8) - .equ .LGemmS8S8KernelFrame_ColumnSumBuffer, (0 + .LGemmS8S8KernelFrame_SavedRegisters) - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - Packed K should be 16x - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - lda (x7) - Supplies the first dimension of matrix A. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasSymQgemmS8KernelSdotLd64 - - stp d8,d9,[sp,#-.LGemmS8S8KernelFrame_SavedRegisters]! - ldr x8,[sp,#.LGemmS8S8KernelFrame_ColumnSumBuffer] - cmp x4,#2 // M < 2 ? - stp d10,d11,[sp,#16] - add x16,x2,x6,lsl #2 // x16 -> C1 - add x17,x2,x6,lsl #3 // x17 -> C2 - stp x20,x21,[sp,#32] - csel x16,x2,x16,lo // if M < 2 x16/C1 -> C0 - mov x12,#4 // set max M to 4 - csel x17,x16,x17,ls // if M <= 2 x17/C2 -> C1 - cmp x4,#4 // M < 4 ? - add x6,x16,x6,lsl #3 // x6 -> C3 - mov x9,x0 // save A0 - mov x10,x3 // save K - csel x6,x17,x6,lo // if M < 4 x6/C3 -> C2 - csel x4,x12,x4,hi // if M > 4 M = 4; - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x0 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x12 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x13 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v22.s[3] v26.s[0]_v26.s[3] v30.s[0]_v30.s[3]| x17 -// x14 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x6 -// ------------------ ---------------------------------------------------------------------------- - -ProcessNextColumnLoop: - ldr q16,[x8],#16 // Init accumulators with column sums - ldr q20,[x8],#16 - ldr q24,[x8],#16 - ldr q28,[x8],#16 - mov x0,x9 // reload A0 - cmp x4,#2 // M < 2 ? - ldr q4,[x1],#16 // Load B - add x12,x9,x7 // x12 -> A1 - add x13,x0,x7,lsl #1 // x13 -> A2 - csel x12,x0,x12,lo // if M < 2 A1 -> A0 - ldr d0,[x0],#8 // Load A0 1st/2nd block of 4 - csel x13,x12,x13,ls // if M <= 2 A2 -> A1 - cmp x4,4 // M < 4 ? - ldr d5,[x1],#8 - add x14,x12,x7,lsl #1 // x14 -> A3 - ldr d1,[x12],#8 // Load A1 - csel x14,x13,x14,lo // if M < 4 A3 -> A2 - ldr d2,[x13],#8 // Load A2 - mov v17.16b,v16.16b - ldr d3,[x14],#8 // Load A3 - mov v18.16b,v16.16b - ldr x15,[x1],#8 - mov v19.16b,v16.16b - ldr d6,[x1],#8 - mov v21.16b,v20.16b - ldr x20,[x1],#8 - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - subs x3,x10,#2 // one loop iteration and epilogue consume k = 32 - mov v30.16b,v28.16b - mov v31.16b,v28.16b - b.lo BlockLoopEpilogue // Need 32 k for main loop - -BlockLoop: - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,0 - ldr x21,[x1],#8 - SdotByElement 17, 4, 1,0 - ins v5.d[1],x15 - SdotByElement 18, 4, 2,0 - ldr d8,[x0],#8 // Load A0 3rd/4th block of 4 - SdotByElement 19, 4, 3,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,0 - ldr x11,[x1],#8 - SdotByElement 21, 5, 1,0 - ins v6.d[1],x20 - SdotByElement 22, 5, 2,0 - ldr d9,[x12],#8 - SdotByElement 23, 5, 3,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,0 - ldr x15,[x1],#8 - SdotByElement 25, 6, 1,0 - ins v7.d[1],x21 - SdotByElement 26, 6, 2,0 - ldr d10,[x13],#8 - SdotByElement 27, 6, 3,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,0 - ldr x20,[x1],#8 - SdotByElement 29, 7, 1,0 - ins v4.d[1],x11 - SdotByElement 30, 7, 2,0 - ldr d11,[x14],#8 - SdotByElement 31, 7, 3,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,1 - ldr x21,[x1],#8 - SdotByElement 17, 4, 1,1 - ins v5.d[1],x15 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,1 - ldr x11,[x1],#8 - SdotByElement 21, 5, 1,1 - ins v6.d[1],x20 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,1 - ldr x15,[x1],#8 - SdotByElement 25, 6, 1,1 - ins v7.d[1],x21 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,1 - ldr x20,[x1],#8 - SdotByElement 29, 7, 1,1 - ins v4.d[1],x11 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,0 - ldr x21,[x1],#8 - SdotByElement 17, 4, 9,0 - ins v5.d[1],x15 - SdotByElement 18, 4,10,0 - ldr d0,[x0],#8 - SdotByElement 19, 4,11,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,0 - ldr x11,[x1],#8 - SdotByElement 21, 5, 9,0 - ins v6.d[1],x20 - SdotByElement 22, 
5,10,0 - ldr d1,[x12],#8 - SdotByElement 23, 5,11,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,0 - ldr x15,[x1],#8 - SdotByElement 25, 6, 9,0 - ins v7.d[1],x21 - SdotByElement 26, 6,10,0 - ldr d2,[x13],#8 - SdotByElement 27, 6,11,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,0 - ldr x20,[x1],#8 - SdotByElement 29, 7, 9,0 - ins v4.d[1],x11 - SdotByElement 30, 7,10,0 - ldr d3,[x14],#8 - SdotByElement 31, 7,11,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,1 - ldr x21,[x1],#8 - SdotByElement 17, 4, 9,1 - ins v5.d[1],x15 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,1 - ldr x11,[x1],#8 - SdotByElement 21, 5, 9,1 - ins v6.d[1],x20 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,1 - ldr x15,[x1],#8 - SdotByElement 25, 6, 9,1 - ins v7.d[1],x21 - SdotByElement 26, 6,10,1 - subs x3,x3,#1 // k -= 16 - SdotByElement 27, 6,11,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,1 - ldr x20,[x1],#8 - SdotByElement 29, 7, 9,1 - ins v4.d[1],x11 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - b.hs BlockLoop - -BlockLoopEpilogue: - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,0 - ldr x21,[x1],#8 - SdotByElement 17, 4, 1,0 - ins v5.d[1],x15 - SdotByElement 18, 4, 2,0 - ldr d8,[x0],#8 - SdotByElement 19, 4, 3,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,0 - ldr x11,[x1],#8 - SdotByElement 21, 5, 1,0 - ins v6.d[1],x20 - SdotByElement 22, 5, 2,0 - ldr d9,[x12],#8 - SdotByElement 23, 5, 3,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,0 - ldr x15,[x1],#8 - SdotByElement 25, 6, 1,0 - ins v7.d[1],x21 - SdotByElement 26, 6, 2,0 - ldr d10,[x13],#8 - SdotByElement 27, 6, 3,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,0 - ldr x20,[x1],#8 - SdotByElement 29, 7, 1,0 - ins v4.d[1],x11 - SdotByElement 30, 7, 2,0 - ldr d11,[x14],#8 - SdotByElement 31, 7, 3,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,1 - ldr x21,[x1],#8 - SdotByElement 17, 4, 1,1 - ins v5.d[1],x15 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,1 - ldr x11,[x1],#8 - SdotByElement 21, 5, 1,1 - ins v6.d[1],x20 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,1 - ldr x15,[x1],#8 - SdotByElement 25, 6, 1,1 - ins v7.d[1],x21 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,1 - ldr x20,[x1],#8 - SdotByElement 29, 7, 1,1 - ins v4.d[1],x11 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,0 - ldr x21,[x1],#8 - SdotByElement 17, 4, 9,0 - ins v5.d[1],x15 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,0 - ldr x11,[x1],#8 - SdotByElement 21, 5, 9,0 - ins v6.d[1],x20 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,0 - ldr x15,[x1],#8 - SdotByElement 25, 6, 9,0 - ins v7.d[1],x21 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,0 - ldr x20,[x1],#8 - SdotByElement 29, 7, 9,0 - ins v4.d[1],x11 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,1 - ldr x21,[x1],#8 - SdotByElement 17, 4, 9,1 - ins v5.d[1],x15 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - ins v6.d[1],x20 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - ins v7.d[1],x21 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 
28, 7, 8,1 - SdotByElement 29, 7, 9,1 - subs x5,x5,#16 // adjust CountN remaining - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - blo StoreOutputPartial - stp q16,q20,[x2],#32 - stp q24,q28,[x2],#32 - stp q17,q21,[x16],#32 - stp q25,q29,[x16],#32 - stp q18,q22,[x17],#32 - stp q26,q30,[x17],#32 - stp q19,q23,[x6],#32 - stp q27,q31,[x6],#32 - cbnz x5,ProcessNextColumnLoop - -ExitKernel: - mov x0,x4 // return number of rows handled - ldp x20,x21,[sp,#32] - ldp d10,d11,[sp,#16] - ldp d8,d9,[sp],#.LGemmS8S8KernelFrame_SavedRegisters - ret - -// -// Store the partial 1 to 15 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -StoreOutputPartial: - tbz x5,#3,StoreOutputPartial4 - stp q16,q20,[x2],#32 - mov v16.16b,v24.16b // shift remaining elements down - mov v20.16b,v28.16b - stp q17,q21,[x16],#32 - mov v17.16b,v25.16b - mov v21.16b,v29.16b - stp q18,q22,[x17],#32 - mov v18.16b,v26.16b - mov v22.16b,v30.16b - stp q19,q23,[x6],#32 - mov v19.16b,v27.16b - mov v23.16b,v31.16b - -StoreOutputPartial4: - tbz x5,#2,StoreOutputPartial2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v20.16b // shift remaining elements down - st1 {v17.4s},[x16],#16 - mov v17.16b,v21.16b - st1 {v18.4s},[x17],#16 - mov v18.16b,v22.16b - st1 {v19.4s},[x6],#16 - mov v19.16b,v23.16b - -StoreOutputPartial2: - tbz x5,#1,StoreOutputPartial1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v17.2s},[x16],#8 - dup v17.4s,v17.s[2] - st1 {v18.2s},[x17],#8 - dup v18.4s,v18.s[2] - st1 {v19.2s},[x6],#8 - dup v19.4s,v19.s[2] - -StoreOutputPartial1: - tbz x5,#0,ExitKernel - st1 {v16.s}[0],[x2] - st1 {v17.s}[0],[x16] - st1 {v18.s}[0],[x17] - st1 {v19.s}[0],[x6] - b ExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/aarch64/asmmacro.h b/onnxruntime/core/mlas/lib/aarch64/asmmacro.h deleted file mode 100644 index 72982db00352f..0000000000000 --- a/onnxruntime/core/mlas/lib/aarch64/asmmacro.h +++ /dev/null @@ -1,95 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - asmmacro.h - -Abstract: - - This module implements common macros for the assembly modules. - ---*/ - -/*++ - -Macro Description: - - This macro emits the assembler directives to annotate a new function. - -Arguments: - - FunctionName - Supplies the name of the function. - ---*/ - - .macro FUNCTION_ENTRY FunctionName - - .p2align 2 -#if defined(__APPLE__) - .globl _\FunctionName\() -_\FunctionName\(): -#else - .globl \FunctionName\() - .type \FunctionName\(),%function -\FunctionName\(): -#endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count is greater than or - equal to Value. - -Arguments: - - Count - Supplies the variable used in the comparison. - - Value - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCountGE Count1, Value1, Statement - -.if (\Count1\() >= \Value1\()) - \Statement\() -.endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count1 is greater than or - equal to Value1 and Count2 is greater than or equal to Value2. - -Arguments: - - Count1 - Supplies the variable used in the comparison. - - Value1 - Supplies the static used in the comparison. - - Count2 - Supplies the variable used in the comparison. - - Value2 - Supplies the static used in the comparison. 
- - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement - -.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) - \Statement\() -.endif - - .endm diff --git a/onnxruntime/core/mlas/lib/activate.cpp b/onnxruntime/core/mlas/lib/activate.cpp deleted file mode 100644 index df3b884a7e7c9..0000000000000 --- a/onnxruntime/core/mlas/lib/activate.cpp +++ /dev/null @@ -1,521 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - activate.cpp - -Abstract: - - This module implements the fused activation and bias addition routines. - ---*/ - -#include "mlasi.h" - -// -// Templates for bias addition functions. -// - -template -struct MLAS_BIAS_ADDITION; - -template<> -struct MLAS_BIAS_ADDITION -{ - MLAS_FLOAT32X4 BiasBroadcast; - - void LoadNext(const float*& Bias) - { - BiasBroadcast = MlasBroadcastFloat32x4(Bias++); - } - - MLAS_FLOAT32X4 Add(MLAS_FLOAT32X4 Value) - { - return MlasAddFloat32x4(Value, BiasBroadcast); - } - - float Add(float Value) - { - return Value + MlasExtractLaneFloat32x4<0>(BiasBroadcast); - } -}; - -template<> -struct MLAS_BIAS_ADDITION -{ - void LoadNext(const float*& Bias) - { - MLAS_UNREFERENCED_PARAMETER(Bias); - } - - MLAS_FLOAT32X4 Add(MLAS_FLOAT32X4 Value) - { - return Value; - } - - float Add(float Value) - { - return Value; - } -}; - -// -// Templates for activation functions. -// - -template -struct MLAS_ACTIVATION_FUNCTION; - -template<> -struct MLAS_ACTIVATION_FUNCTION -{ - MLAS_ACTIVATION_FUNCTION(const MLAS_ACTIVATION* Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - } - - MLAS_FLOAT32X4 Activate(MLAS_FLOAT32X4 Value) - { - return Value; - } - - float Activate(float Value) - { - return Value; - } -}; - -template<> -struct MLAS_ACTIVATION_FUNCTION -{ - const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - MLAS_ACTIVATION_FUNCTION(const MLAS_ACTIVATION* Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - } - - MLAS_FLOAT32X4 Activate(MLAS_FLOAT32X4 Value) - { - return MlasMaximumFloat32x4(ZeroFloat32x4, Value); - } - - float Activate(float Value) - { -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_cvtss_f32(Activate(_mm_set_ss(Value))); -#else - return std::max(Value, 0.0f); -#endif - } -}; - -template<> -struct MLAS_ACTIVATION_FUNCTION -{ - const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - MLAS_FLOAT32X4 AlphaBroadcast; - - MLAS_ACTIVATION_FUNCTION(const MLAS_ACTIVATION* Activation) - { - AlphaBroadcast = MlasBroadcastFloat32x4(&Activation->Parameters.LeakyRelu.alpha); - } - - MLAS_FLOAT32X4 Activate(MLAS_FLOAT32X4 Value) - { - MLAS_FLOAT32X4 ValueTimesAlpha = MlasMultiplyFloat32x4(Value, AlphaBroadcast); - -#if defined(MLAS_NEON_INTRINSICS) -#if defined(_WIN32) - return vbslq_f32(vcleq_z_f32_ex(Value), ValueTimesAlpha, Value); -#else - // N.B. Standard NEON headers lack an intrinsic for the "vcle #0" form. 
- return vbslq_f32(vcleq_f32(Value, ZeroFloat32x4), ValueTimesAlpha, Value); -#endif -#elif defined(MLAS_AVX_INTRINSICS) - return _mm_blendv_ps(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); -#elif defined(MLAS_SSE2_INTRINSICS) - return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value)); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value)); -#else - return MlasBlendFloat32x4(ValueTimesAlpha, Value, ZeroFloat32x4 < Value); -#endif - } - - float Activate(float Value) - { - float ValueTimesAlpha = Value * MlasExtractLaneFloat32x4<0>(AlphaBroadcast); - -#if defined(MLAS_SSE2_INTRINSICS) - return (Value >= MlasExtractLaneFloat32x4<0>(ZeroFloat32x4)) ? Value : ValueTimesAlpha; -#else - return (Value >= 0.0f) ? Value : ValueTimesAlpha; -#endif - } -}; - -template<> -struct MLAS_ACTIVATION_FUNCTION -{ - MLAS_FLOAT32X4 MinimumBroadcast; - MLAS_FLOAT32X4 MaximumBroadcast; - - MLAS_ACTIVATION_FUNCTION(const MLAS_ACTIVATION* Activation) - { - MinimumBroadcast = MlasBroadcastFloat32x4(&Activation->Parameters.Clip.minimum); - MaximumBroadcast = MlasBroadcastFloat32x4(&Activation->Parameters.Clip.maximum); - } - - MLAS_FLOAT32X4 Activate(MLAS_FLOAT32X4 Value) - { - Value = MlasMaximumFloat32x4(MinimumBroadcast, Value); - Value = MlasMinimumFloat32x4(MaximumBroadcast, Value); - - return Value; - } - - float Activate(float Value) - { -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_cvtss_f32(Activate(_mm_set_ss(Value))); -#else - Value = std::max(Value, MlasExtractLaneFloat32x4<0>(MinimumBroadcast)); - Value = std::min(Value, MlasExtractLaneFloat32x4<0>(MaximumBroadcast)); - - return Value; -#endif - } -}; - -template<> -struct MLAS_ACTIVATION_FUNCTION -{ - MLAS_FLOAT32X4 AlphaBroadcast; - MLAS_FLOAT32X4 BetaBroadcast; - MLAS_FLOAT32X4 MinimumBroadcast; - MLAS_FLOAT32X4 MaximumBroadcast; - - MLAS_ACTIVATION_FUNCTION(const MLAS_ACTIVATION* Activation) - { - AlphaBroadcast = MlasBroadcastFloat32x4(&Activation->Parameters.HardSigmoid.alpha); - BetaBroadcast = MlasBroadcastFloat32x4(&Activation->Parameters.HardSigmoid.beta); - MinimumBroadcast = MlasZeroFloat32x4(); - MaximumBroadcast = MlasBroadcastFloat32x4(1.0f); - } - - MLAS_FLOAT32X4 Activate(MLAS_FLOAT32X4 Value) - { - Value = MlasMultiplyAddFloat32x4(Value, AlphaBroadcast, BetaBroadcast); - Value = MlasMinimumFloat32x4(MaximumBroadcast, Value); - Value = MlasMaximumFloat32x4(MinimumBroadcast, Value); - - return Value; - } - - float Activate(float Value) - { -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_cvtss_f32(Activate(_mm_set_ss(Value))); -#else - Value = MlasExtractLaneFloat32x4<0>(AlphaBroadcast) * Value + MlasExtractLaneFloat32x4<0>(BetaBroadcast); - Value = std::min(Value, MlasExtractLaneFloat32x4<0>(MaximumBroadcast)); - Value = std::max(Value, MlasExtractLaneFloat32x4<0>(MinimumBroadcast)); - - return Value; -#endif - } -}; - -template -void -MlasActivationKernel( - const MLAS_ACTIVATION* Activation, - float* Buffer, - const float* Bias, - size_t M, - size_t N, - size_t ldc - ) -/*++ - -Routine Description: - - This routine steps over the output matrix and invokes the templated bias - addition and activation functions. - -Arguments: - - Activation - Supplies the parameters for the activation. - - Buffer - Supplies the output matrix. - - Bias - Supplies the optional bias vector. 
- - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - ldc - Supplies the number of elements per row of the output matrix. - -Return Value: - - None. - ---*/ -{ - MLAS_ACTIVATION_FUNCTION ActivationFunction(Activation); - MLAS_BIAS_ADDITION BiasAddition; - - // - // Step through each row of the output matrix. - // - - while (M-- > 0) { - - float* buffer = Buffer; - size_t n = N; - - BiasAddition.LoadNext(Bias); - - if (n >= 4) { - - do { - - MLAS_FLOAT32X4 Vector = BiasAddition.Add(MlasLoadFloat32x4(buffer)); - MlasStoreFloat32x4(buffer, ActivationFunction.Activate(Vector)); - buffer += 4; - n -= 4; - - } while (n >= 4); - } - - while (n > 0) { - - float Scalar = BiasAddition.Add(*buffer); - *buffer++ = ActivationFunction.Activate(Scalar); - n -= 1; - } - - Buffer += ldc; - } -} - -template<> -inline -void -MlasActivationKernel( - const MLAS_ACTIVATION* Activation, - float* Buffer, - const float* Bias, - size_t M, - size_t N, - size_t ldc - ) -/*++ - -Routine Description: - - This routine is invoked for the special case of an identity operation with - no bias addition, which translates to a no-op. - -Arguments: - - Activation - Supplies the parameters for the activation. - - Buffer - Supplies the output matrix. - - Bias - Supplies the optional bias vector. - - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - ldc - Supplies the number of elements per row of the output matrix. - -Return Value: - - None. - ---*/ -{ - // - // No operation. - // - - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_UNREFERENCED_PARAMETER(Buffer); - MLAS_UNREFERENCED_PARAMETER(Bias); - MLAS_UNREFERENCED_PARAMETER(M); - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(ldc); -} - -template -inline -void -MlasActivationKernel( - const MLAS_ACTIVATION* Activation, - float* Buffer, - const float* Bias, - size_t M, - size_t N, - size_t ldc - ) -/*++ - -Routine Description: - - This routine invokes the appropriate activation kernel based on the - optional bias vector. - -Arguments: - - Activation - Supplies the parameters for the activation. - - Buffer - Supplies the output matrix. - - Bias - Supplies the optional bias vector. - - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - ldc - Supplies the number of elements per row of the output matrix. - -Return Value: - - None. - ---*/ -{ - if (Bias != nullptr) { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - } else { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - } -} - -void -MLASCALL -MlasActivation( - const MLAS_ACTIVATION* Activation, - float* Buffer, - const float* Bias, - size_t M, - size_t N, - size_t ldc - ) -/*++ - -Routine Description: - - This routine applies an activation function to the output matrix after - optionally adding a bias vector. - -Arguments: - - Activation - Supplies the parameters for the activation. - - Buffer - Supplies the output matrix. - - Bias - Supplies the optional bias vector. - - M - Supplies the number of elements of the bias vector and the number of - rows in the output matrix. - - N - Supplies the number of columns of the output matrix. - - ldc - Supplies the number of elements per row of the output matrix. - -Return Value: - - None. 
- ---*/ -{ - switch (Activation->ActivationKind) { - - case MlasIdentityActivation: - { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - break; - } - - case MlasReluActivation: - { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - break; - } - - case MlasLeakyReluActivation: - { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - break; - } - - case MlasTanhActivation: - { - if (Bias != nullptr) { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - } - - if (N == ldc) { - MlasComputeTanh(Buffer, Buffer, M * N); - } else { - while (M-- > 0) { - MlasComputeTanh(Buffer, Buffer, N); - Buffer += ldc; - } - } - - break; - } - - case MlasLogisticActivation: - { - if (Bias != nullptr) { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - } - - if (N == ldc) { - MlasComputeLogistic(Buffer, Buffer, M * N); - } else { - while (M-- > 0) { - MlasComputeLogistic(Buffer, Buffer, N); - Buffer += ldc; - } - } - - break; - } - - case MlasClipActivation: - { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - break; - } - - case MlasHardSigmoidActivation: - { - MlasActivationKernel(Activation, Buffer, Bias, M, N, ldc); - break; - } - - case MlasActivationKindCount: - { - MLAS_THROW_EX(std::runtime_error, "bad mlas activation kind"); - break; - } - } -} diff --git a/onnxruntime/core/mlas/lib/activate_fp16.cpp b/onnxruntime/core/mlas/lib/activate_fp16.cpp deleted file mode 100644 index 776ec67fccc1a..0000000000000 --- a/onnxruntime/core/mlas/lib/activate_fp16.cpp +++ /dev/null @@ -1,885 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - activate_fp16.cpp - -Abstract: - - This module implements the activation routines for fp16 data types - ---*/ - -#include "fp16_common.h" - -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - -// -// Templates for activation functions. 
-// - -template -struct MLAS_HALF_ACTIVATION_FUNCTION; - -template <> -struct MLAS_HALF_ACTIVATION_FUNCTION { - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { return Value; } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { return Value; } - - float Activate(float Value) { return Value; } -}; - -template<> -struct MLAS_HALF_ACTIVATION_FUNCTION -{ - const MLAS_FLOAT16X8 ZeroVec = MlasZeroFloat16x8(); - - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) - { - return MlasMaximumFloat16x8(ZeroVec, Value); - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) - { - return MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(ZeroVec), Value); - } -}; - -template<> -struct MLAS_HALF_ACTIVATION_FUNCTION -{ - const MLAS_FLOAT16X8 ZeroVec = MlasZeroFloat16x8(); - - MLAS_FLOAT16X8 AlphaBroadcast; - - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - const _mlas_fp16_ alpha = MLAS_Float2Half(Activation.Parameters.LeakyRelu.alpha); - AlphaBroadcast = MlasBroadcastFloat16x8(alpha); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) - { - MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16x8(Value, AlphaBroadcast); - return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec), - ValueTimesAlpha, Value); - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) - { - MLAS_FLOAT16X4 ValueTimesAlpha = - MlasMultiplyFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); - return MlasBitwiseSelectFloat16x4( - MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha, - Value); - } -}; - -template <> -struct MLAS_HALF_ACTIVATION_FUNCTION { -#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) - // - // Ported from XNNPACK (f16-tanh-aarch64-neonfp16arith-expm1minus-rr1-p3h2-div.c) - // - - // - // Constants for float16x8_t - // - - // The smallest z for which tanhh(-z) is saturated at -1.0h. - const float16x8_t vsat_cutoff_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x4482))); // 0x1.208p+2h - // Large number such that ulp(magic bias) == 0.5 and magic bias === 7.5 mod 2**8. - const float16x8_t vmagic_bias_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x620F))); // 0x1.83Cp+9h - const float16x8_t vminus_log2e_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0xBDC5))); // -0x1.714p+0h - const float16x8_t vln2_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x398C))); // 0x1.630p-1h - // Coefficients of polynomial approximation - // exp(-2t) - 1 ~ t * (-2 + t * (c2 + t * c3)) - // on [-log(2)/4, log(2)/4] - const float16x8_t vc3_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0xBD5B))); // -0x1.56Cp+0h - const float16x8_t vc2_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x4008))); // 0x1.020p+1h - const float16x8_t vtwo_16x8 = vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x4000))); // 2.0h - const float16x8_t vminus_one_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0xBC00))); // -1.0h - // Mask for the sign bit. - const uint16x8_t vsign_mask_16x8 = vmovq_n_u16(UINT16_C(0x8000)); - - // - // Constants for float16x4_t - // - - // The smallest z for which tanhh(-z) is saturated at -1.0h. - const float16x4_t vsat_cutoff_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x4482))); // 0x1.208p+2h - // Large number such that ulp(magic bias) == 0.5 and magic bias === 7.5 mod 2**8. 
- const float16x4_t vmagic_bias_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x620F))); // 0x1.83Cp+9h - const float16x4_t vminus_log2e_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0xBDC5))); // -0x1.714p+0h - const float16x4_t vln2_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x398C))); // 0x1.630p-1h - // Coefficients of polynomial approximation - // exp(-2t) - 1 ~ t * (-2 + t * (c2 + t * c3)) - // on [-log(2)/4, log(2)/4] - const float16x4_t vc3_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0xBD5B))); // -0x1.56Cp+0h - const float16x4_t vc2_16x4 = vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x4008))); // 0x1.020p+1h - const float16x4_t vtwo_16x4 = vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x4000))); // 2.0h - const float16x4_t vminus_one_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0xBC00))); // -1.0h - // Mask for the sign bit. - const uint16x4_t vsign_mask_16x4 = vmov_n_u16(UINT16_C(0x8000)); - - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 vx) - { - // General structure of the algorithm: - // - // / -expm1(-2x) / (2 + expm1(-2x)) if x >= 0 - // f(x) := - // \ -f(-x) if x <= 0 - // - // First we compute y := expm1(-2z) / (2 + expm1(-2z)) where z = abs(x), - // then set its sign according to the sign of x: f(x) := sign(x) * abs(y). - float16x8_t vz = vabsq_f16(vx); - - // The function saturates at -1 for large positive inputs: tanhh(-z) == -1.0h for z >= - // sat_cutoff ~= 9.010913. To guarantee this behavior, we clip input z at sat_cutoff, and - // leverage the fact that for our implementation tanhf(sat_cutoff) == -1.0h. NaN inputs are - // passed unchanged. - vz = vminq_f16(vz, vsat_cutoff_16x8); - - // Compute reduced argument n := round(-z / log(2), 1). - // We do it by adding a large number (magic bias), which cause rounding of the result to 1 - // fractional bit, then subtracing the large number back. The trick with adding large number - // is valid only within certain bounds - // (|-z / log(2)| <= 2**10, i.e. |z| <= 0x1.630p+7 = 177.5), but that is acceptable, because - // inputs x outside of [-4.5078125, 4.5078125] (i.e. z outsize [0, 4.5078125]) saturate - // tanhh(x). Additionally, we fuse addition of the floating-point exponent bias (15) into - // the magic bias. Note that addition-subtraction of the large number doesn't cause overflow - // for inputs in this range. - float16x8_t vn = vfmaq_f16(vmagic_bias_16x8, vz, vminus_log2e_16x8); - - // Create a floating-point number s (scale) such that s == 2**(2n) for inputs which don't - // cause underflow, i.e. 0 <= z <= 4.5078125, and -7 <= n <= 0 accordingly. - const float16x8_t vs = vreinterpretq_f16_s16(vshlq_n_s16(vreinterpretq_s16_f16(vn), 10)); - - // Subtract the large number back to get final n := round(-z / log(2), 1) as a - // floating-point number. - vn = vsubq_f16(vn, vmagic_bias_16x8); - - // Compute reduced argument t := z + n * log(2). Note that -t = -z - n * log(2). - const float16x8_t vt = vfmaq_f16(vz, vn, vln2_16x8); - - // Compute degree-3 polynomial approximation for exp(-2t) - 1 on [-log(2)/4, log(2)/4]. 
- // P(t) = t * (-2 + t * (c2 + t * c3)) - // = t * (-p) - float16x8_t vp = vfmaq_f16(vc2_16x8, vc3_16x8, vt); - vp = vfmsq_f16(vtwo_16x8, vp, vt); - - // Reconstruct the exp(-2z) - 1 value: - // exp(-2z) - 1 = s * (t * (-2 + t * (c2 + t * c3)) + 1) - 1 - // = s * t * (-p) + (s - 1) - // = (s - 1) - (t * s) * p - const float16x8_t vts = vmulq_f16(vt, vs); - const float16x8_t vsmo = vaddq_f16(vs, vminus_one_16x8); - const float16x8_t vemo = vfmsq_f16(vsmo, vp, vts); - - // Denominator of the tanh fraction: exp(-2z) + 1 = expm1(-2z) + 2 - const float16x8_t vepo = vaddq_f16(vemo, vtwo_16x8); - - // Reconstruct y = expm1(-2z) / (expm1(-2z) + 2) - float16x8_t vy = vdivq_f16(vemo, vepo); - - // Reconstruct tanh(x) = copysign(y, x) - vy = vbslq_f16(vsign_mask_16x8, vx, vy); - - return vy; - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 vx) - { - // General structure of the algorithm: - // - // / -expm1(-2x) / (2 + expm1(-2x)) if x >= 0 - // f(x) := - // \ -f(-x) if x <= 0 - // - // First we compute y := expm1(-2z) / (2 + expm1(-2z)) where z = abs(x), - // then set its sign according to the sign of x: f(x) := sign(x) * abs(y). - float16x4_t vz = vabs_f16(vx); - - // The function saturates at -1 for large positive inputs: tanhh(-z) == -1.0h for z >= - // sat_cutoff ~= 9.010913. To guarantee this behavior, we clip input z at sat_cutoff, and - // leverage the fact that for our implementation tanhf(sat_cutoff) == -1.0h. NaN inputs are - // passed unchanged. - vz = vmin_f16(vz, vsat_cutoff_16x4); - - // Compute reduced argument n := round(-z / log(2), 1). - // We do it by adding a large number (magic bias), which cause rounding of the result to 1 - // fractional bit, then subtracing the large number back. The trick with adding large number - // is valid only within certain bounds - // (|-z / log(2)| <= 2**10, i.e. |z| <= 0x1.630p+7 = 177.5), but that is acceptable, because - // inputs x outside of [-4.5078125, 4.5078125] (i.e. z outsize [0, 4.5078125]) saturate - // tanhh(x). Additionally, we fuse addition of the floating-point exponent bias (15) into - // the magic bias. Note that addition-subtraction of the large number doesn't cause overflow - // for inputs in this range. - float16x4_t vn = vfma_f16(vmagic_bias_16x4, vz, vminus_log2e_16x4); - - // Create a floating-point number s (scale) such that s == 2**(2n) for inputs which don't - // cause underflow, i.e. 0 <= z <= 4.5078125, and -7 <= n <= 0 accordingly. - const float16x4_t vs = vreinterpret_f16_s16(vshl_n_s16(vreinterpret_s16_f16(vn), 10)); - - // Subtract the large number back to get final n := round(-z / log(2), 1) as a - // floating-point number. - vn = vsub_f16(vn, vmagic_bias_16x4); - - // Compute reduced argument t := z + n * log(2). Note that -t = -z - n * log(2). - const float16x4_t vt = vfma_f16(vz, vn, vln2_16x4); - - // Compute degree-3 polynomial approximation for exp(-2t) - 1 on [-log(2)/4, log(2)/4]. 
- // P(t) = t * (-2 + t * (c2 + t * c3)) - // = t * (-p) - float16x4_t vp = vfma_f16(vc2_16x4, vc3_16x4, vt); - vp = vfms_f16(vtwo_16x4, vp, vt); - - // Reconstruct the exp(-2z) - 1 value: - // exp(-2z) - 1 = s * (t * (-2 + t * (c2 + t * c3)) + 1) - 1 - // = s * t * (-p) + (s - 1) - // = (s - 1) - (t * s) * p - const float16x4_t vts = vmul_f16(vt, vs); - const float16x4_t vsmo = vadd_f16(vs, vminus_one_16x4); - const float16x4_t vemo = vfms_f16(vsmo, vp, vts); - - // Denominator of the tanh fraction: exp(-2z) + 1 = expm1(-2z) + 2 - const float16x4_t vepo = vadd_f16(vemo, vtwo_16x4); - - // Reconstruct y = expm1(-2z) / (expm1(-2z) + 2) - float16x4_t vy = vdiv_f16(vemo, vepo); - - // Reconstruct tanh(x) = copysign(y, x) - vy = vbsl_f16(vsign_mask_16x4, vx, vy); - - return vy; - } -#else - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_THROW_EX(std::runtime_error, "unsupported target architecture"); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 vx) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_THROW_EX(std::runtime_error, "unsupported target architecture"); - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 vx) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_THROW_EX(std::runtime_error, "unsupported target architecture"); - } -#endif -}; - -template <> -struct MLAS_HALF_ACTIVATION_FUNCTION { -#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) - // - // Ported from XNNPACK (f16-sigmoid-aarch64-neonfp16arith-rr2-p3-div.c). - // - - // - // Constants for float16x8_t - // - - // Large number such that ulp(magic bias) == 1 and magic bias === 15 mod 2**9. - const float16x8_t vmagic_bias_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x660F))); // 0x1.83Cp+10h - const float16x8_t vminus_log2e_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0xBDC5))); // -0x1.714p+0h - const float16x8_t vln2_hi_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x398C))); // 0x1.630p-1h - const float16x8_t vln2_lo_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x8AF4))); // -0x1.BD0p-13h - // Coefficient of polynomial approximation - // exp(-t) ~ 1 + t * (c1 + t * c2) - // on [-log(2)/2, log(2)/2] - const float16x8_t vc3_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0xB156))); // -0x1.558p-3h - const float16x8_t vc2_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x3808))); // 0x1.020p-1h - const float16x8_t vone_16x8 = vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0x3C00))); // 1.0h - // The largest z for which sigmoidh(-z) is normalized. - // This number is also the largest z for which exph(-z) is normalized. - const float16x8_t vdenorm_cutoff_16x8 = - vreinterpretq_f16_u16(vmovq_n_u16(UINT16_C(0xC8DA))); // -0x1.368p+3h - - // - // Constants for float16x4_t - // - - // Large number such that ulp(magic bias) == 1 and magic bias === 15 mod 2**9. 
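Before the logistic specialization continues below, it is worth spelling out the identity the tanh kernel above evaluates: tanh is reconstructed from expm1(-2|x|), and the sign of x is copied back at the end. The NEON code replaces expm1 with the magic-bias range reduction and the degree-3 polynomial, and clips |x| at sat_cutoff so the half-precision result saturates cleanly at ±1. A scalar float reference of the same identity (a sketch using <cmath>, not the deleted code):

// tanh(x) = sign(x) * |expm1(-2|x|) / (expm1(-2|x|) + 2)|  -- the identity used above.
float TanhReference(float x)
{
    float z = std::fabs(x);
    float emo = std::expm1(-2.0f * z);      // exp(-2z) - 1, in (-1, 0]
    float y = emo / (emo + 2.0f);           // equals -tanh(z)
    return std::copysign(std::fabs(y), x);  // restore the sign of x; NaN propagates
}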
- const float16x4_t vmagic_bias_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x660F))); // 0x1.83Cp+10h - const float16x4_t vminus_log2e_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0xBDC5))); // -0x1.714p+0h - const float16x4_t vln2_hi_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x398C))); // 0x1.630p-1h - const float16x4_t vln2_lo_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x8AF4))); // -0x1.BD0p-13h - // Coefficient of polynomial approximation - // exp(-t) ~ 1 + t * (c1 + t * c2) - // on [-log(2)/2, log(2)/2] - const float16x4_t vc3_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0xB156))); // -0x1.558p-3h - const float16x4_t vc2_16x4 = vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x3808))); // 0x1.020p-1h - const float16x4_t vone_16x4 = vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0x3C00))); // 1.0h - // The largest z for which sigmoidh(-z) is normalized. - // This number is also the largest z for which exph(-z) is normalized. - const float16x4_t vdenorm_cutoff_16x4 = - vreinterpret_f16_u16(vmov_n_u16(UINT16_C(0xC8DA))); // -0x1.368p+3h - - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 vx) - { - // General structure of the algorithm: - // - // / exp(x) / (1 + exp(x)) if x <= 0 - // f[x] := - // \ 1 - f[-x] if x >= 0 - // - // First we compute f[-z] := exp(-z) / (1 + exp(-z)) where z = abs(x), - // then replace result with 1 - f[-z] if x >= 0. - const float16x8_t vz = vabsq_f16(vx); - - // Compute reduced argument n := round(-z / ln2). - // We do it by adding a large number (magic bias) to the product z * (-1/ln2), which - // cause rounding of the result to an integer, then subtracing the large number back. The - // first addition is combined with multiplication by -log2e into a single FMA instruction. - // The trick with adding large number is valid only within certain bounds - // (|-x / ln2| <= 2**9, i.e. |z| <= 0x1.630p+8 = 355.0), but that is acceptable, because - // inputs outside of [-9.703125, 8.3125] (i.e. z outside [0, 9.703125]) underflow or - // saturate sigmoidh(x). We fixup the result for such inputs at the very end of the - // algorithm. - float16x8_t vn = vfmaq_f16(vmagic_bias_16x8, vz, vminus_log2e_16x8); - - // Create a floating-point number s (scale) such that s == 2**n for inputs which don't cause - // underflow, i.e. -9.703125 <= -z <= 0.0, and -14 <= n <= 0 accordingly. - const float16x8_t vs = vreinterpretq_f16_s16(vshlq_n_s16(vreinterpretq_s16_f16(vn), 10)); - - // Subtract the large number back to get the final n := round(-z / ln2) as a - // floating-point number. - vn = vsubq_f16(vn, vmagic_bias_16x8); - - // Compute reduced argument t := z - n * log(2). Note that -t = -z - n * log(2). - // Use Cody-Waite range reduction method (note two constants to represent -ln(2)) to - // improve accuracy. 
- float16x8_t vt = vfmaq_f16(vz, vn, vln2_hi_16x8); - vt = vfmaq_f16(vt, vn, vln2_lo_16x8); - - // Compute degree-3 polynomial approximation for exp(-t) on [-log(2)/2, log(2)/2]: - // P(t) = 1 + t * (-1 + t * (c2 + t * c3)) = -(1 - t * p) - float16x8_t vp = vfmaq_f16(vc2_16x8, vc3_16x8, vt); - vp = vfmsq_f16(vone_16x8, vp, vt); - - // Reconstruct the exp(-z) value: - // e = s * (1 + t * (-1 + t * (c2 + t * c3)) - // = s * (1 - t * (-p)) - // = s - (t * s) * (-p) - vt = vmulq_f16(vt, vs); - float16x8_t ve = vfmsq_f16(vs, vp, vt); - - // Denominator of the sigmoid fraction: 1.0 + exp(-z) - float16x8_t vd = vaddq_f16(ve, vone_16x8); - - // Reconstruct sigmoid(-z) = exp(-z) / (1.0 + exp(-z)) - float16x8_t vf = vdivq_f16(ve, vd); - - // For inputs below denormal cutoff, replace output with +0.0f. - // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. - vf = vreinterpretq_f16_u16( - vbicq_u16(vreinterpretq_u16_f16(vf), vcagtq_f16(vx, vdenorm_cutoff_16x8))); - - // Reconstruct sigmoid(x) = x < 0 ? sigmoid(-z) : 1.0 - sigmoid(-z) - const uint16x8_t vm = vcltq_f16(vx, vreinterpretq_f16_u16(vmovq_n_u16(0))); - vf = vbslq_f16(vm, vf, vsubq_f16(vone_16x8, vf)); - - return vf; - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 vx) - { - // General structure of the algorithm: - // - // / exp(x) / (1 + exp(x)) if x <= 0 - // f[x] := - // \ 1 - f[-x] if x >= 0 - // - // First we compute f[-z] := exp(-z) / (1 + exp(-z)) where z = abs(x), - // then replace result with 1 - f[-z] if x >= 0. - const float16x4_t vz = vabs_f16(vx); - - // Compute reduced argument n := round(-z / ln2). - // We do it by adding a large number (magic bias) to the product z * (-1/ln2), which - // cause rounding of the result to an integer, then subtracing the large number back. The - // first addition is combined with multiplication by -log2e into a single FMA instruction. - // The trick with adding large number is valid only within certain bounds - // (|-x / ln2| <= 2**9, i.e. |z| <= 0x1.630p+8 = 355.0), but that is acceptable, because - // inputs outside of [-9.703125, 8.3125] (i.e. z outside [0, 9.703125]) underflow or - // saturate sigmoidh(x). We fixup the result for such inputs at the very end of the - // algorithm. - float16x4_t vn = vfma_f16(vmagic_bias_16x4, vz, vminus_log2e_16x4); - - // Create a floating-point number s (scale) such that s == 2**n for inputs which don't cause - // underflow, i.e. -9.703125 <= -z <= 0.0, and -14 <= n <= 0 accordingly. - const float16x4_t vs = vreinterpret_f16_s16(vshl_n_s16(vreinterpret_s16_f16(vn), 10)); - - // Subtract the large number back to get the final n := round(-z / ln2) as a - // floating-point number. - vn = vsub_f16(vn, vmagic_bias_16x4); - - // Compute reduced argument t := z - n * log(2). Note that -t = -z - n * log(2). - // Use Cody-Waite range reduction method (note two constants to represent -ln2) to - // improve accuracy. 
- float16x4_t vt = vfma_f16(vz, vn, vln2_hi_16x4); - vt = vfma_f16(vt, vn, vln2_lo_16x4); - - // Compute degree-3 polynomial approximation for exp(-t) on [-log(2)/2, log(2)/2]: - // P(t) = 1 + t * (-1 + t * (c2 + t * c3)) = -(1 - t * p) - float16x4_t vp = vfma_f16(vc2_16x4, vc3_16x4, vt); - vp = vfms_f16(vone_16x4, vp, vt); - - // Reconstruct the exp(-z) value: - // e = s * (1 + t * (-1 + t * (c2 + t * c3)) - // = s * (1 - t * (-p)) - // = s - (t * s) * (-p) - vt = vmul_f16(vt, vs); - float16x4_t ve = vfms_f16(vs, vp, vt); - - // Denominator of the sigmoid fraction: 1.0 + exp(-z) - float16x4_t vd = vadd_f16(ve, vone_16x4); - - // Reconstruct sigmoid(-z) = exp(-z) / (1.0 + exp(-z)) - float16x4_t vf = vdiv_f16(ve, vd); - - // For inputs below denormal cutoff, replace output with +0.0f. - // Note that for NaN inputs, comparison result is false, and outputs are left unchanged. - vf = vreinterpret_f16_u16( - vbic_u16(vreinterpret_u16_f16(vf), vcagt_f16(vx, vdenorm_cutoff_16x4))); - - // Reconstruct sigmoid(x) = x < 0 ? sigmoid(-z) : 1.0 - sigmoid(-z) - const uint16x4_t vm = vclt_f16(vx, vreinterpret_f16_u16(vmov_n_u16(0))); - vf = vbsl_f16(vm, vf, vsub_f16(vone_16x4, vf)); - - return vf; - } -#else - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_THROW_EX(std::runtime_error, "unsupported target architecture"); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 vx) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_THROW_EX(std::runtime_error, "unsupported target architecture"); - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 vx) - { - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_THROW_EX(std::runtime_error, "unsupported target architecture"); - } -#endif -}; - -template <> -struct MLAS_HALF_ACTIVATION_FUNCTION { - MLAS_FLOAT16X8 MinimumBroadcast; - MLAS_FLOAT16X8 MaximumBroadcast; - - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - const _mlas_fp16_ min = MLAS_Float2Half(Activation.Parameters.Clip.minimum); - MinimumBroadcast = MlasBroadcastFloat16x8(min); - const _mlas_fp16_ max = MLAS_Float2Half(Activation.Parameters.Clip.maximum); - MaximumBroadcast = MlasBroadcastFloat16x8(max); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) - { - Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); - Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); - - return Value; - } - - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) - { - Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); - Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); - return Value; - } -}; - -template<> -struct MLAS_HALF_ACTIVATION_FUNCTION -{ - MLAS_FLOAT16X8 AlphaBroadcast; - MLAS_FLOAT16X8 BetaBroadcast; - MLAS_FLOAT16X8 MinimumBroadcast; - MLAS_FLOAT16X8 MaximumBroadcast; - - MLAS_HALF_ACTIVATION_FUNCTION(const MLAS_ACTIVATION& Activation) - { - const _mlas_fp16_ alpha = MLAS_Float2Half(Activation.Parameters.HardSigmoid.alpha); - AlphaBroadcast = MlasBroadcastFloat16x8(alpha); - const _mlas_fp16_ beta = MLAS_Float2Half(Activation.Parameters.HardSigmoid.beta); - BetaBroadcast = MlasBroadcastFloat16x8(beta); - MinimumBroadcast = MlasZeroFloat16x8(); - MaximumBroadcast = MlasBroadcastFloat16x8(MLAS_Float2Half(1.0f)); - } - - MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) - { - Value = MlasMultiplyAddFloat16x8(Value, AlphaBroadcast, BetaBroadcast); - Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); - Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); - - return Value; - } 
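The logistic kernel earlier in this block follows the same recipe as tanh: compute sigmoid(-|x|) = exp(-|x|) / (1 + exp(-|x|)), then reflect it to 1 - f for non-negative inputs; the vector code additionally flushes results below the denormal cutoff to +0 and leaves NaN inputs unchanged. A scalar reference of the identity (sketch, using <cmath>):

// sigmoid(x) computed from e = exp(-|x|), matching the structure documented above.
float SigmoidReference(float x)
{
    float z = std::fabs(x);
    float e = std::exp(-z);               // exp(-z), in (0, 1]
    float f = e / (1.0f + e);             // sigmoid(-z), in (0, 0.5]
    return (x < 0.0f) ? f : 1.0f - f;     // reflect for x >= 0
}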
- - MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) - { - Value = MlasMultiplyAddFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), - MlasToLowHalfFloat16x4(BetaBroadcast)); - Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); - Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); - - return Value; - } -}; - -template -inline -void -MlasActivationKernel( - const MLAS_ACTIVATION& Activation, - _mlas_fp16_* Buffer, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) -{ - MLAS_HALF_ACTIVATION_FUNCTION ActivationFunction(Activation); - - auto* CRow = Buffer + StartM * ldc + StartN; - - while (CountM-- > 0) { - _mlas_fp16_* buffer = CRow; - size_t n = CountN; - - while (n >= 8) { - MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); - MlasStoreFloat16x8(buffer, ActivationFunction.Activate(Vector)); - buffer += 8; - n -= 8; - } - - if (n >= 4) { - MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); - MlasStoreFloat16x4(buffer, ActivationFunction.Activate(Vector)); - buffer += 4; - n -= 4; - } - - if (n > 0) { - MLAS_FLOAT16X4 buf; - std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); - MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf); - MlasStorePartialFloat16x4(buffer, res, n); - } - - CRow += ldc; - } -} - -template<> -inline -void -MlasActivationKernel( - const MLAS_ACTIVATION& Activation, - _mlas_fp16_* Buffer, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) -{ - // - // No operation. - // - - MLAS_UNREFERENCED_PARAMETER(Activation); - MLAS_UNREFERENCED_PARAMETER(Buffer); - MLAS_UNREFERENCED_PARAMETER(StartM); - MLAS_UNREFERENCED_PARAMETER(StartN); - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(ldc); -} - - -template -MLAS_FORCEINLINE -void -MlasActivationKernel( - const MLAS_ACTIVATION& Activation, - _mlas_fp16_* Buffer, - const _mlas_fp16_* Addon, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) -{ - MLAS_HALF_ACTIVATION_FUNCTION ActivationFunction(Activation); - - auto* CRow = Buffer + StartM * ldc + StartN; - const auto* ARow = Addon + StartM * ldc + StartN; - - while (CountM-- > 0) { - auto* buffer = CRow; - const auto* addsrc = ARow; - size_t n = CountN; - - while (n >= 8) { - MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc); - MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); - addsrc += 8; - Vector = MlasAddFloat16x8(Vector, AVec); - Vector = ActivationFunction.Activate(Vector); - MlasStoreFloat16x8(buffer, Vector); - buffer += 8; - n -= 8; - } - - if (n >= 4) { - MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc); - MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); - addsrc += 4; - Vector = MlasAddFloat16x4(Vector, AVec); - Vector = ActivationFunction.Activate(Vector); - MlasStoreFloat16x4(buffer, Vector); - buffer += 4; - n -= 4; - } - - if (n > 0) { - MLAS_FLOAT16X4 addbuf; - MLAS_FLOAT16X4 buf; - std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_)); - std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); - buf = MlasAddFloat16x4(buf, addbuf); - buf = ActivationFunction.Activate(buf); - MlasStorePartialFloat16x4(buffer, buf, n); - } - - CRow += ldc; - ARow += ldc; - } -} - - -void -MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process( - MLAS_FP16* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const -{ - auto* Buffer = reinterpret_cast<_mlas_fp16_*>(C); - switch (Activation_.ActivationKind) { - case MlasIdentityActivation: { - 
if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - case MlasReluActivation: { - if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - case MlasLeakyReluActivation: { - if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - case MlasTanhActivation: { - if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - case MlasLogisticActivation: { - if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - case MlasClipActivation: { - if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - case MlasHardSigmoidActivation: { - if (SumBuf_) { - MlasActivationKernel( - Activation_, Buffer, reinterpret_cast(SumBuf_), StartM, - StartN, CountM, CountN, ldc); - } else { - MlasActivationKernel(Activation_, Buffer, StartM, StartN, - CountM, CountN, ldc); - } - break; - } - - default: { - MLAS_THROW_EX(std::runtime_error, "bad mlas activation kind"); - return; - } - } -} - -#else -// Really dumb implementation when fp16 acceleration is not supported - -#include - -MLAS_FORCEINLINE -void -CvtFloat2Half( - _mlas_fp16_* dest, - const float* src, - size_t len -) -{ - for (size_t i = 0; i < len; i++) { - *dest++ = MLAS_Float2Half(*src++); - } -} - -void -MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process( - MLAS_FP16* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const -{ - std::vector buffer(CountM*CountN); - - _mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C); - auto* CRow = buffer.data(); - const _mlas_fp16_* CAdd = nullptr; - if (SumBuf_) { - CAdd = reinterpret_cast(SumBuf_) + StartM * ldc + StartN; - } - Output += StartM * ldc + StartN; - - while (CountM-- > 0) { - if (CAdd) { - for (size_t n = 0; n < CountN; n++) { - CRow[n] += MLAS_Half2Float(CAdd[n]); - } - CAdd += ldc; - } - MlasActivation(&this->Activation_, CRow, nullptr, 1, CountN, CountN); - - CvtFloat2Half(Output, CRow, CountN); - CRow += CountN; - Output += ldc; - } -} - -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc b/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc deleted file mode 100644 index ed885dcb7b781..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/AssembleAvx512Vnni.inc +++ /dev/null @@ -1,242 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. 
-; -; Module Name: -; -; AssembleAvx512Vnni.inc -; -; Abstract: -; -; This module contains macros to build VNNI instructions for toolchains that -; do not natively support this newer instruction set extension. -; -;-- - -; -; Map friendly register names to the encoded register index. -; - -ZmmIndex_zmm0 EQU 0 -ZmmIndex_zmm1 EQU 1 -ZmmIndex_zmm2 EQU 2 -ZmmIndex_zmm3 EQU 3 -ZmmIndex_zmm4 EQU 4 -ZmmIndex_zmm5 EQU 5 -ZmmIndex_zmm6 EQU 6 -ZmmIndex_zmm7 EQU 7 -ZmmIndex_zmm8 EQU 8 -ZmmIndex_zmm9 EQU 9 -ZmmIndex_zmm10 EQU 10 -ZmmIndex_zmm11 EQU 11 -ZmmIndex_zmm12 EQU 12 -ZmmIndex_zmm13 EQU 13 -ZmmIndex_zmm14 EQU 14 -ZmmIndex_zmm15 EQU 15 -ZmmIndex_zmm16 EQU 16 -ZmmIndex_zmm17 EQU 17 -ZmmIndex_zmm18 EQU 18 -ZmmIndex_zmm19 EQU 19 -ZmmIndex_zmm20 EQU 20 -ZmmIndex_zmm21 EQU 21 -ZmmIndex_zmm22 EQU 22 -ZmmIndex_zmm23 EQU 23 -ZmmIndex_zmm24 EQU 24 -ZmmIndex_zmm25 EQU 25 -ZmmIndex_zmm26 EQU 26 -ZmmIndex_zmm27 EQU 27 -ZmmIndex_zmm28 EQU 28 -ZmmIndex_zmm29 EQU 29 -ZmmIndex_zmm30 EQU 30 -ZmmIndex_zmm31 EQU 31 - -GprIndex_rax EQU 0 -GprIndex_rcx EQU 1 -GprIndex_rdx EQU 2 -GprIndex_rbx EQU 3 -GprIndex_rbp EQU 5 -GprIndex_rsi EQU 6 -GprIndex_rdi EQU 7 -GprIndex_r8 EQU 8 -GprIndex_r9 EQU 9 -GprIndex_r10 EQU 10 -GprIndex_r11 EQU 11 -GprIndex_r12 EQU 12 -GprIndex_r13 EQU 13 -GprIndex_r14 EQU 14 -GprIndex_r15 EQU 15 - -; -; Macro Description: -; -; This macro builds a VNNI instruction of the form: -; -; instr zmm1,zmm2,zmm3 -; -; Arguments: -; -; Opcode - Specifies the opcode for the VNNI instruction. -; -; DestReg - Specifies the destination register. -; -; Src1Reg - Specifies the first source register. -; -; Src2Reg - Specifies the second source register. -; - -VnniZmmZmmZmm MACRO Opcode, DestReg, Src1Reg, Src2Reg - - LOCAL Payload0, Payload1, Payload2, ModRMByte - - Payload0 = 002h ; "0F 38" prefix - Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) - Payload0 = Payload0 + ((((ZmmIndex_&Src2Reg& SHR 4) AND 1) XOR 1) SHL 6) - Payload0 = Payload0 + ((((ZmmIndex_&Src2Reg& SHR 3) AND 1) XOR 1) SHL 5) - Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 4) AND 1) XOR 1) SHL 4) - - Payload1 = 005h ; "66" prefix - Payload1 = Payload1 + (((ZmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) - - Payload2 = 040h ; 512-bit vector length - Payload2 = Payload2 + ((((ZmmIndex_&Src1Reg& SHR 4) AND 1) XOR 1) SHL 3) - - ModRMByte = 0C0h ; register form - ModRMByte = ModRMByte + ((ZmmIndex_&DestReg& AND 7) SHL 3) - ModRMByte = ModRMByte + (ZmmIndex_&Src2Reg& AND 7) - - db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte - - ENDM - -VpdpbusdZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 050h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbusdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 051h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpwssdZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 052h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpwssdsZmmZmmZmm MACRO DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 053h, DestReg, Src1Reg, Src2Reg - - ENDM - -; -; Macro Description: -; -; This macro builds a VNNI instruction of the form: -; -; instr zmm1,zmm2,DWORD BCST [BaseReg+IndexReg*Scale+ByteOffset] -; -; Arguments: -; -; Opcode - Specifies the opcode for the VNNI instruction. -; -; DestReg - Specifies the destination register. -; -; Src1Reg - Specifies the first source register. -; -; BaseReg - Specifies the base register of the broadcast operand. -; -; ByteOffset - Specifies the DWORD aligned byte offset for the broadcast -; operand. 
-; -; IndexReg - Specifies the optional index register of the broadcast operand. -; -; Scale - Specifies the scaling factor of the optional index register. -; - -VnniZmmZmmBroadcast MACRO Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - LOCAL Payload0, Payload1, Payload2, ModRMByte, SibByte - -.errnz (ByteOffset AND 3) - - Payload0 = 002h ; "0F 38" prefix - Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) -IFNB - Payload0 = Payload0 + ((((GprIndex_&IndexReg& SHR 3) AND 1) XOR 1) SHL 6) -ELSE - Payload0 = Payload0 + 040h ; zero logical index register -ENDIF - Payload0 = Payload0 + ((((GprIndex_&BaseReg& SHR 3) AND 1) XOR 1) SHL 5) - Payload0 = Payload0 + ((((ZmmIndex_&DestReg& SHR 4) AND 1) XOR 1) SHL 4) - - Payload1 = 005h ; "66" prefix - Payload1 = Payload1 + (((ZmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) - - Payload2 = 050h ; 512-bit vector length, broadcast - Payload2 = Payload2 + ((((ZmmIndex_&Src1Reg& SHR 4) AND 1) XOR 1) SHL 3) - - ModRMByte = 000h ; memory form - ModRMByte = ModRMByte + ((ZmmIndex_&DestReg& AND 7) SHL 3) -IFNB - ModRMByte = ModRMByte + 004h ; indicate SIB byte needed -ELSE - ModRMByte = ModRMByte + (GprIndex_&BaseReg& AND 7) -ENDIF -IF ByteOffset NE 0 - ModRMByte = ModRMByte + 040h ; indicate disp8 byte offset -ENDIF - -IFNB - SibByte = 0 -IF Scale EQ 2 - SibByte = SibByte + (1 SHL 6) -ELSEIF Scale EQ 4 - SibByte = SibByte + (2 SHL 6) -ELSEIF Scale EQ 8 - SibByte = SibByte + (3 SHL 6) -ELSEIF Scale NE 1 - .err -ENDIF - SibByte = SibByte + ((GprIndex_&IndexReg& AND 7) SHL 3) - SibByte = SibByte + (GprIndex_&BaseReg& AND 7) -ENDIF - - db 062h, Payload0, Payload1, Payload2, Opcode, ModRMByte -IFNB - db SibByte -ENDIF -IF ByteOffset NE 0 - db ByteOffset SHR 2 -ENDIF - - ENDM - -VpdpbusdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 050h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - ENDM - -VpdpbusdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 051h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - ENDM - -VpdpwssdZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 052h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - ENDM - -VpdpwssdsZmmZmmBroadcast MACRO DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 053h, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/AssembleAvxVnni.inc b/onnxruntime/core/mlas/lib/amd64/AssembleAvxVnni.inc deleted file mode 100644 index d5867e8884c44..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/AssembleAvxVnni.inc +++ /dev/null @@ -1,330 +0,0 @@ -;++ -; -; Copyright (c) 2020 Intel Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; AssembleAvxVnni.inc -; -; Abstract: -; -; This module contains macros to build AVXVNNI instructions for toolchains that -; do not natively support this newer instruction set extension. -; -;-- - -; -; Map friendly register names to the encoded register index. 
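As a quick sanity check on the EVEX register-form encoding above (my hand arithmetic, not part of the deleted file), expanding VpdpbusdZmmZmmZmm zmm4,zmm2,zmm0 gives Payload0 = 0F2h, Payload1 = 06Dh, Payload2 = 048h and ModRMByte = 0E0h, so the macro emits the byte sequence 62h 0F2h 06Dh 048h 050h 0E0h, which a disassembler should decode as vpdpbusd zmm4,zmm2,zmm0.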
-; - -YmmIndex_ymm0 EQU 0 -YmmIndex_ymm1 EQU 1 -YmmIndex_ymm2 EQU 2 -YmmIndex_ymm3 EQU 3 -YmmIndex_ymm4 EQU 4 -YmmIndex_ymm5 EQU 5 -YmmIndex_ymm6 EQU 6 -YmmIndex_ymm7 EQU 7 -YmmIndex_ymm8 EQU 8 -YmmIndex_ymm9 EQU 9 -YmmIndex_ymm10 EQU 10 -YmmIndex_ymm11 EQU 11 -YmmIndex_ymm12 EQU 12 -YmmIndex_ymm13 EQU 13 -YmmIndex_ymm14 EQU 14 -YmmIndex_ymm15 EQU 15 - -XmmIndex_xmm0 EQU 0 -XmmIndex_xmm1 EQU 1 -XmmIndex_xmm2 EQU 2 -XmmIndex_xmm3 EQU 3 -XmmIndex_xmm4 EQU 4 -XmmIndex_xmm5 EQU 5 -XmmIndex_xmm6 EQU 6 -XmmIndex_xmm7 EQU 7 -XmmIndex_xmm8 EQU 8 -XmmIndex_xmm9 EQU 9 -XmmIndex_xmm10 EQU 10 -XmmIndex_xmm11 EQU 11 -XmmIndex_xmm12 EQU 12 -XmmIndex_xmm13 EQU 13 -XmmIndex_xmm14 EQU 14 -XmmIndex_xmm15 EQU 15 - -; -; Macro Description: -; -; This macro builds a VNNI instruction of the form: -; -; instr ymm1,ymm2,ymm3 -; -; Arguments: -; -; Opcode - Specifies the opcode for the VNNI instruction. -; -; DestReg - Specifies the destination register. -; -; Src1Reg - Specifies the first source register. -; -; Src2Reg - Specifies the second source register. -; - -VnniYmmYmmYmm MACRO Opcode, DestReg, Src1Reg, Src2Reg - - LOCAL Payload0, Payload1, ModRMByte - - Payload0 = 002h ; "0F 38" prefix - Payload0 = Payload0 + ((((YmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) - Payload0 = Payload0 + (1 SHL 6) - Payload0 = Payload0 + ((((YmmIndex_&Src2Reg& SHR 3) AND 1) XOR 1) SHL 5) - - Payload1 = 005h ; "66" prefix - Payload1 = Payload1 + (((YmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) - - ModRMByte = 0C0h ; register form - ModRMByte = ModRMByte + ((YmmIndex_&DestReg& AND 7) SHL 3) - ModRMByte = ModRMByte + (YmmIndex_&Src2Reg& AND 7) - - db 0C4h, Payload0, Payload1, Opcode, ModRMByte - - ENDM - -VpdpbusdYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 050h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbusdsYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 051h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpwssdYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 052h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpwssdsYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 053h, DestReg, Src1Reg, Src2Reg - - ENDM - -; -; Macro Description: -; -; This macro builds a VNNI instruction of the form: -; -; instr xmm1,xmm2,xmm3 -; -; Arguments: -; -; Opcode - Specifies the opcode for the VNNI instruction. -; -; DestReg - Specifies the destination register. -; -; Src1Reg - Specifies the first source register. -; -; Src2Reg - Specifies the second source register. 
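The same exercise for the VEX form above: VpdpbusdsYmmYmmYmm ymm4,ymm2,ymm0 yields Payload0 = 0E2h, Payload1 = 06Dh and ModRMByte = 0E0h, so the emitted bytes are 0C4h 0E2h 06Dh 051h 0E0h, which should disassemble as vpdpbusds ymm4,ymm2,ymm0 (again my own calculation, offered only as a cross-check).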
-; - -VnniXmmXmmXmm MACRO Opcode, DestReg, Src1Reg, Src2Reg - - LOCAL Payload0, Payload1, ModRMByte - - Payload0 = 002h ; "0F 38" prefix - Payload0 = Payload0 + ((((XmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) - Payload0 = Payload0 + (1 SHL 6) - Payload0 = Payload0 + ((((XmmIndex_&Src2Reg& SHR 3) AND 1) XOR 1) SHL 5) - - Payload1 = 001h ; "66" prefix - Payload1 = Payload1 + (((XmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) - - ModRMByte = 0C0h ; register form - ModRMByte = ModRMByte + ((XmmIndex_&DestReg& AND 7) SHL 3) - ModRMByte = ModRMByte + (XmmIndex_&Src2Reg& AND 7) - - db 0C4h, Payload0, Payload1, Opcode, ModRMByte - - ENDM - -VpdpbusdXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 050h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbusdsXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 051h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpwssdXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 052h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpwssdsXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 053h, DestReg, Src1Reg, Src2Reg - - ENDM - -; -; Macro Description: -; -; This macro builds a VNNI instruction of the form: -; -; instr ymm1,ymm2,ymm3 -; -; Arguments: -; -; Opcode - Specifies the opcode for the VNNI instruction. -; -; Prefix - Specifies the opcode prefix for payload 1 -; -; DestReg - Specifies the destination register. -; -; Src1Reg - Specifies the first source register. -; -; Src2Reg - Specifies the second source register. -; - -Avx2VnniYmmYmmYmm MACRO Opcode, Prefix, DestReg, Src1Reg, Src2Reg - - LOCAL Payload0, Payload1, ModRMByte - - Payload0 = 002h ; "0F 38" prefix - Payload0 = Payload0 + ((((YmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) - Payload0 = Payload0 + (1 SHL 6) - Payload0 = Payload0 + ((((YmmIndex_&Src2Reg& SHR 3) AND 1) XOR 1) SHL 5) - - Payload1 = 004h + Prefix ; 256-bit length and opcode prefix - Payload1 = Payload1 + (((YmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) - - ModRMByte = 0C0h ; register form - ModRMByte = ModRMByte + ((YmmIndex_&DestReg& AND 7) SHL 3) - ModRMByte = ModRMByte + (YmmIndex_&Src2Reg& AND 7) - - db 0C4h, Payload0, Payload1, Opcode, ModRMByte - - ENDM - -VpdpbssdYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 050h, 003h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbssdsYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 051h, 003h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbsudYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 050h, 002h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbsudsYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 051h, 002h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbuudYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 050h, 000h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbuudsYmmYmmYmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 051h, 000h, DestReg, Src1Reg, Src2Reg - - ENDM - -; -; Macro Description: -; -; This macro builds a VNNI instruction of the form: -; -; instr xmm1,xmm2,xmm3 -; -; Arguments: -; -; Opcode - Specifies the opcode for the VNNI instruction. -; -; Prefix - Specifies the opcode prefix for payload 1 -; -; DestReg - Specifies the destination register. -; -; Src1Reg - Specifies the first source register. -; -; Src2Reg - Specifies the second source register. 
-; - -Avx2VnniXmmXmmXmm MACRO Opcode, Prefix, DestReg, Src1Reg, Src2Reg - - LOCAL Payload0, Payload1, ModRMByte - - Payload0 = 002h ; "0F 38" prefix - Payload0 = Payload0 + ((((XmmIndex_&DestReg& SHR 3) AND 1) XOR 1) SHL 7) - Payload0 = Payload0 + (1 SHL 6) - Payload0 = Payload0 + ((((XmmIndex_&Src2Reg& SHR 3) AND 1) XOR 1) SHL 5) - - Payload1 = 000h + Prefix ; 128-bit length and opcode prefix - Payload1 = Payload1 + (((XmmIndex_&Src1Reg& AND 15) XOR 15) SHL 3) - - ModRMByte = 0C0h ; register form - ModRMByte = ModRMByte + ((XmmIndex_&DestReg& AND 7) SHL 3) - ModRMByte = ModRMByte + (XmmIndex_&Src2Reg& AND 7) - - db 0C4h, Payload0, Payload1, Opcode, ModRMByte - - ENDM - -VpdpbssdXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 050h, 003h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbssdsXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 051h, 003h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbsudXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 050h, 002h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbsudsXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 051h, 002h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbuudXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 050h, 000h, DestReg, Src1Reg, Src2Reg - - ENDM - -VpdpbuudsXmmXmmXmm MACRO DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 051h, 000h, DestReg, Src1Reg, Src2Reg - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm deleted file mode 100644 index a42d7ff8730cb..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm +++ /dev/null @@ -1,974 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; ConvSymKernelAvx2.asm -; -; Abstract: -; -; This module implements the kernels for the symmetric quantized integer -; convolution operation. -; -; This implementation uses AVX2 and AVX VNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE ConvSymKernelCommon.inc -INCLUDE AssembleAvxVnni.inc - .list - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate a single row of the -; output block. -; -; Arguments: -; -; Vec1Reg - Supplies the low block accumulator register. -; -; Vec2Reg - Supplies the high block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from the filter buffer. -; -; ymm1 - Supplies the second vector loaded from the filter buffer. -; -; ymm2 - Supplies the broadcast value loaded from the input buffer. -; -; ymm3 - Supplies a scratch register for intermediate results. -; -; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. -; - -MultiplyAccumulateRowAvx2 MACRO Vec1Reg, Vec2Reg - - vpmaddubsw ymm3,ymm2,ymm0 - vpmaddwd ymm3,ymm3,ymm12 - vpaddd Vec1Reg,Vec1Reg,ymm3 - vpmaddubsw ymm2,ymm2,ymm1 - vpmaddwd ymm2,ymm2,ymm12 - vpaddd Vec2Reg,Vec2Reg,ymm2 - - ENDM - -MultiplyAccumulateRowAvxVnni MACRO Vec1Reg, Vec2Reg - - VpdpbusdsYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbusdsYmmYmmYmm Vec2Reg,ymm2,ymm1 - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from the filter to fetch elements. 
-; -; BroadcastOffset - Supplies the byte offset from the input to fetch elements. -; -; Implicit Arguments: -; -; rdx - Supplies the address of the filter buffer. -; -; r10 - Supplies the address of the base of the input buffer. -; -; Implicit Arguments (Avx2): -; -; r11-r13 - Supplies the relative byte offsets from the base of the input -; buffer to access the second through fourth rows. -; -; ymm4-ymm11 - Supplies the block accumulators. -; -; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. -; -; Implicit Arguments (AvxVnni): -; -; r11-r15 - Supplies the relative byte offsets from the base of the input -; buffer to access the second through sixth rows. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlock MACRO Isa, RowCount, VectorOffset, BroadcastOffset - - vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] - vmovdqu ymm1,YMMWORD PTR [rdx+VectorOffset+32] - EmitIfCountGE RowCount,1, - EmitIfCountGE RowCount,1, - EmitIfCountGE RowCount,2, - EmitIfCountGE RowCount,2, - EmitIfCountGE RowCount,3, - EmitIfCountGE RowCount,3, - EmitIfCountGE RowCount,4, - EmitIfCountGE RowCount,4, - EmitIfCountGE RowCount,5, - EmitIfCountGE RowCount,5, - EmitIfCountGE RowCount,6, - EmitIfCountGE RowCount,6, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple times -; and advancing the input and filter data pointers. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; RowCount - Supplies the number of rows to produce. -; -; UnrollLoop - Supplies a non-blank value if the loop should be unrolled to -; improve performance. -; -; Implicit Arguments: -; -; rax - Supplies the number of input channels. -; -; rdx - Supplies the address of the filter buffer. -; -; r10 - Supplies the address of the base of the input buffer. -; - -ComputeBlockLoop MACRO Isa, RowCount, UnrollLoop - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL ComputeBlockLoopExit - -IFNB - sub rax,4*4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - ComputeBlock Isa,RowCount,0*64,0 - ComputeBlock Isa,RowCount,1*64,4 - ComputeBlock Isa,RowCount,2*64,8 - ComputeBlock Isa,RowCount,3*64,12 - add r10,4*4 ; advance input base address - add rdx,4*16*4 ; advance filter address - sub rax,4*4 ; decrement elements remaining - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add rax,4*4 ; correct for over-subtract above - jz ComputeBlockLoopExit -ENDIF - -ComputeBlockBy1Loop: - ComputeBlock Isa,RowCount,0*64,0 - add r10,4 ; advance input base address - add rdx,16*4 ; advance filter address - sub rax,4 ; decrement elements remaining - jnz ComputeBlockBy1Loop - -ComputeBlockLoopExit: - - ENDM - -; -; Macro Description: -; -; This macro generates code to convert the block accumulators from the matrix -; multiply loop to float values. -; -; Arguments: -; -; RegList - Supplies the list of vector registers to operate on. -; -; Implicit Arguments: -; -; ymm0 - Supplies the integer bias vector. -; -; ymm1 - Supplies the output scale vector. -; - -ConvertAccumulatorToFloatRegList MACRO RegList - -; -; Offset each value by the per-channel bias value, convert to floating point, -; and apply the output scale. -; - - EmitForEachRegister , - EmitForEachRegister , - EmitForEachRegister , - - ENDM - -; -; Macro Description: -; -; This macro generates code to convert float values to 32-bit integers in the -; range 0 to 255. 
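The MultiplyAccumulateRowAvx2 macro defined earlier emulates the single vpdpbusds of the VNNI path with a vpmaddubsw/vpmaddwd pair against the all-ones word vector kept in ymm12, followed by vpaddd. An intrinsics sketch of that emulation (my illustration, not code from the deleted file; note that vpmaddubsw saturates the 16-bit intermediate sums, which the fused VNNI instruction does not):

#include <immintrin.h>

// u8 x s8 dot products accumulated into 32-bit lanes, as in MultiplyAccumulateRowAvx2.
static inline __m256i DotAccumulateAvx2(__m256i acc, __m256i input_u8, __m256i filter_s8)
{
    const __m256i ones = _mm256_set1_epi16(1);                   // the 0x0001 word vector (ymm12)
    __m256i pairs = _mm256_maddubs_epi16(input_u8, filter_s8);   // vpmaddubsw: u8*s8 -> s16 pairs
    __m256i quads = _mm256_madd_epi16(pairs, ones);              // vpmaddwd: s16 pairs -> s32
    return _mm256_add_epi32(acc, quads);                         // vpaddd: accumulate
}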
-; -; Arguments: -; -; RegList - Supplies the list of vector registers to operate on. -; -; Implicit Arguments: -; -; ymm0 - Supplies the broadcasted minimum clip float value. -; -; This is set to static_cast(0 - ZeroPointValue). -; -; ymm1 - Supplies the broadcasted maximum clip float value. -; -; This is set to static_cast(255 - ZeroPointValue). -; -; ymm2 - Supplies the broadcasted zero point integer value. -; - -ConvertFloatToIntegerRegList MACRO RegList - -; -; Clip the float values to the integer range covered by the output zero point. -; This also keeps values outside the range INT_MIN to INT_MAX from converting -; to INT_MIN. -; - - EmitForEachRegister , - EmitForEachRegister , - -; -; Convert the float value to integer and add the zero point offset. -; - - EmitForEachRegister , - EmitForEachRegister , - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner kernel to compute a convolution -; for the elements of an output row for a set of filter rows. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; - -ConvSymKernelFunction MACRO Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a convolution for the elements -; of an output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points -; directly at the input tensor. -; -; If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an -; indirection buffer. Every pointer in the indirection buffer points at a -; InputChannels length vector (either from the input tensor or a vector of -; padding values). These are grouped in batches of length KernelSize. -; These batches are then repeated OutputCount times. -; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; KernelSize (r9) - Supplies the size of the kernel. -; -; If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. -; -; InputChannels - Supplies the number of input channels. -; -; This implementation requires the count to be a multiple of 4. -; -; OutputChannels - Supplies the number of output channels. -; -; ChannelCount - Supplies the number of channels this iteration produces. -; -; This implementation requires the count to be 8 or 16. -; -; OutputCount - Supplies the number of output elements this iteration produces. -; -IFIDNI , -; This implementation requires the count to be in the range 1 to 6. -ELSE -; This implementation requires the count to be in the range 1 to 4. -ENDIF -; -; PostProcessParams - Supplies the address of the post process parameter block. -; -; KernelFlags - Supplies additional flags controlling the operation. -; -; Return Value: -; -; None. 
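The indirection-buffer contract in the routine header above can be read as the following reference loop (names illustrative, dot-product body elided): pointer o * KernelSize + k supplies the InputChannels-long vector for output position o and kernel tap k, and padded positions simply point at a shared vector of padding values.

#include <cstddef>
#include <cstdint>

// Walk of the indirection buffer layout described above (sketch only).
void WalkIndirectionBuffer(const uint8_t* const* Input, size_t OutputCount,
                           size_t KernelSize, size_t InputChannels)
{
    for (size_t o = 0; o < OutputCount; o++) {
        for (size_t k = 0; k < KernelSize; k++) {
            const uint8_t* src = Input[o * KernelSize + k];  // input row or padding vector
            for (size_t c = 0; c < InputChannels; c++) {
                (void)src[c];  // accumulate against the filter for tap k, channel c
            }
        }
    }
}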
-; -;-- - - NESTED_ENTRY MlasConvSymKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (ConvSymKernelFrame.SavedR13) -IFIDNI , - save_reg r14,ConvSymKernelFrame.SavedR14 - save_reg r15,ConvSymKernelFrame.SavedR15 -ENDIF - save_xmm128 xmm6,ConvSymKernelFrame.SavedXmm6 - save_xmm128 xmm7,ConvSymKernelFrame.SavedXmm7 - save_xmm128 xmm8,ConvSymKernelFrame.SavedXmm8 - save_xmm128 xmm9,ConvSymKernelFrame.SavedXmm9 - save_xmm128 xmm10,ConvSymKernelFrame.SavedXmm10 - save_xmm128 xmm11,ConvSymKernelFrame.SavedXmm11 - save_xmm128 xmm12,ConvSymKernelFrame.SavedXmm12 -IFIDNI , - save_xmm128 xmm13,ConvSymKernelFrame.SavedXmm13 - save_xmm128 xmm14,ConvSymKernelFrame.SavedXmm14 - save_xmm128 xmm15,ConvSymKernelFrame.SavedXmm15 -ENDIF - - END_PROLOGUE - - lea rdi,[r9*8] - mov ebx,DWORD PTR ConvSymKernelFrame.OutputCount[rsp] - mov rsi,ConvSymKernelFrame.InputChannels[rsp] - mov ebp,DWORD PTR ConvSymKernelFrame.KernelFlags[rsp] - vpxor xmm4,xmm4,xmm4 - vpxor xmm5,xmm5,xmm5 - vpxor xmm6,xmm6,xmm6 - vpxor xmm7,xmm7,xmm7 - vpxor xmm8,xmm8,xmm8 - vpxor xmm9,xmm9,xmm9 - vpxor xmm10,xmm10,xmm10 - vpxor xmm11,xmm11,xmm11 -IFIDNI , - vpxor xmm12,xmm12,xmm12 - vpxor xmm13,xmm13,xmm13 - vpxor xmm14,xmm14,xmm14 - vpxor xmm15,xmm15,xmm15 -ELSE - vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] - vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] -ENDIF - -; -; Process an input block of length InputChannels for each element of the kernel. -; - -ProcessNextInputBlock: - test bpl,MLAS_CONV_SYM_FLAG_INPUT_DIRECT - jz InputIndirection - -; -; The input buffer points directly at the input data and this is effectively a -; GEMM operation (such as a pointwise convolution or an Im2Col transform). -; - -InputDirect: - xor r10,r10 - mov r11,rsi - lea r12,[r11+r11] - lea r13,[r12+r11] -IFIDNI , - lea r14,[r13+r11] - lea r15,[r14+r11] -ENDIF - cmp ebx,2 - cmovb r11,r10 ; use first row if output count is small - cmovbe r12,r10 - cmp ebx,4 - cmovb r13,r10 -IFIDNI , - cmovbe r14,r10 - cmp ebx,6 - cmovb r15,r10 -ENDIF - mov r10,rcx - jmp ComputeBlockLoopStart - -InputIndirection: - lea r11,[rcx+rdi] - lea r12,[rcx+rdi*2] - lea r13,[r11+rdi*2] -IFIDNI , - lea r14,[r12+rdi*2] - lea r15,[r13+rdi*2] -ENDIF - cmp ebx,2 - cmovb r11,rcx ; use first row if output count is small - cmovbe r12,rcx - cmp ebx,4 - cmovb r13,rcx -IFIDNI , - cmovbe r14,rcx - cmp ebx,6 - cmovb r15,rcx -ENDIF - mov r10,QWORD PTR [rcx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] -IFIDNI , - mov r14,QWORD PTR [r14] - mov r15,QWORD PTR [r15] -ENDIF - add rcx,8 ; advance indirection buffer address - sub r11,r10 ; compute deltas from base address - sub r12,r10 - sub r13,r10 -IFIDNI , - sub r14,r10 - sub r15,r10 -ENDIF - -ComputeBlockLoopStart: - mov rax,rsi ; reload input channels - cmp ebx,2 ; output count <= 2? - jbe ComputeBlockLoopBy2 -IFIDNI , - cmp ebx,4 ; output count <= 4? - jbe ComputeBlockLoopBy4 - ComputeBlockLoop Isa,6,UnrollLoop -ELSE - ComputeBlockLoop Isa,4,UnrollLoop -ENDIF - -ComputeBlockLoopDone: - dec r9 ; decrement input blocks remaining - jnz ProcessNextInputBlock - -; -; Apply the bias and convert the block accumulators to intermediate float values. 
-; - - mov rdx,ConvSymKernelFrame.PostProcessParams[rsp] - mov rsi,ConvSymKernelFrame.OutputChannels[rsp] - mov r11d,DWORD PTR ConvSymKernelFrame.ChannelCount[rsp] - mov rcx,ConvSymPostProcessParams.Bias[rdx] - mov r9,ConvSymPostProcessParams.Scale[rdx] - lea r10,[rsi*2+rsi] ; compute fourth row output offset - add r10,r8 - vmovdqu ymm0,YMMWORD PTR [rcx] ; load low bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz BroadcastScaleValue - vmovups ymm1,YMMWORD PTR [r9] ; load low scale vector - jmp ConvertLowAccumulatorsToFloat - -BroadcastScaleValue: - vbroadcastss ymm1,DWORD PTR [r9] - -ConvertLowAccumulatorsToFloat: -IFIDNI , - ConvertAccumulatorToFloatRegList -ELSE - ConvertAccumulatorToFloatRegList -ENDIF - cmp r11d,8 ; output single vector? - jbe ConvertFloatsToIntegers - vmovdqu ymm0,YMMWORD PTR [rcx+8*4] ; load high bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz ConvertHighAccumulatorsToFloat - vmovups ymm1,YMMWORD PTR [r9+8*4] ; load high scale vector - -ConvertHighAccumulatorsToFloat: -IFIDNI , - ConvertAccumulatorToFloatRegList -ELSE - ConvertAccumulatorToFloatRegList -ENDIF - -; -; Convert the intermediate float values to 32-bit integers in the range 0 to 255. -; - -ConvertFloatsToIntegers: - vbroadcastss ymm0,DWORD PTR ConvSymPostProcessParams.MinimumValue[rdx] - vbroadcastss ymm1,DWORD PTR ConvSymPostProcessParams.MaximumValue[rdx] - vpbroadcastd ymm2,DWORD PTR ConvSymPostProcessParams.OutputZeroPoint[rdx] -IFIDNI , - ConvertFloatToIntegerRegList -ELSE - ConvertFloatToIntegerRegList -ENDIF - cmp r11d,8 ; output single vector? - jbe StoreQuantizedOutputBy8 -IFIDNI , - ConvertFloatToIntegerRegList -ELSE - ConvertFloatToIntegerRegList -ENDIF - -; -; Pack with saturation and store 16 bytes to the output buffer. -; - -StoreQuantizedOutputBy16: -IFIDNI , - cmp ebx,5 - ja StoreQuantizedOutput6By16 - je StoreQuantizedOutput5By16 -ENDIF - cmp ebx,3 - ja StoreQuantizedOutput4By16 - je StoreQuantizedOutput3By16 - cmp ebx,1 - ja StoreQuantizedOutput2By16 - jmp StoreQuantizedOutput1By16 - -IFIDNI , -StoreQuantizedOutput6By16: - vextracti128 xmm0,ymm14,1 - vpackusdw xmm14,xmm14,xmm0 - vextracti128 xmm1,ymm15,1 - vpackusdw xmm15,xmm15,xmm1 - vpackuswb xmm14,xmm14,xmm15 - vmovdqu XMMWORD PTR [r10+rsi*2],xmm14 - -StoreQuantizedOutput5By16: - vextracti128 xmm0,ymm12,1 - vpackusdw xmm12,xmm12,xmm0 - vextracti128 xmm1,ymm13,1 - vpackusdw xmm13,xmm13,xmm1 - vpackuswb xmm12,xmm12,xmm13 - vmovdqu XMMWORD PTR [r10+rsi],xmm12 -ENDIF - -StoreQuantizedOutput4By16: - vextracti128 xmm0,ymm10,1 - vpackusdw xmm10,xmm10,xmm0 - vextracti128 xmm1,ymm11,1 - vpackusdw xmm11,xmm11,xmm1 - vpackuswb xmm10,xmm10,xmm11 - vmovdqu XMMWORD PTR [r10],xmm10 - -StoreQuantizedOutput3By16: - vextracti128 xmm0,ymm8,1 - vpackusdw xmm8,xmm8,xmm0 - vextracti128 xmm1,ymm9,1 - vpackusdw xmm9,xmm9,xmm1 - vpackuswb xmm8,xmm8,xmm9 - vmovdqu XMMWORD PTR [r8+rsi*2],xmm8 - -StoreQuantizedOutput2By16: - vextracti128 xmm0,ymm6,1 - vpackusdw xmm6,xmm6,xmm0 - vextracti128 xmm1,ymm7,1 - vpackusdw xmm7,xmm7,xmm1 - vpackuswb xmm6,xmm6,xmm7 - vmovdqu XMMWORD PTR [r8+rsi],xmm6 - -StoreQuantizedOutput1By16: - vextracti128 xmm0,ymm4,1 - vpackusdw xmm4,xmm4,xmm0 - vextracti128 xmm1,ymm5,1 - vpackusdw xmm5,xmm5,xmm1 - vpackuswb xmm4,xmm4,xmm5 - vmovdqu XMMWORD PTR [r8],xmm4 - -; -; Restore non-volatile registers and return. 
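Putting the two epilogue stages above together, each 32-bit accumulator is offset by the per-channel bias, converted to float, scaled, clipped to the range covered by the output zero point, rounded, re-centered on the zero point, and finally packed with saturation to unsigned bytes. A scalar sketch of that requantization (illustrative only; MinimumValue and MaximumValue are the clip bounds carried in the post-process parameters):

#include <algorithm>
#include <cmath>
#include <cstdint>

// One output element of the requantization pipeline above (sketch).
uint8_t Requantize(int32_t Accumulator, int32_t Bias, float Scale,
                   float MinimumValue, float MaximumValue, int32_t OutputZeroPoint)
{
    float f = static_cast<float>(Accumulator + Bias) * Scale;                // bias, convert, scale
    f = std::min(MaximumValue, std::max(MinimumValue, f));                   // clip to the output range
    int32_t q = static_cast<int32_t>(std::nearbyint(f)) + OutputZeroPoint;   // round and re-center
    return static_cast<uint8_t>(std::min(255, std::max(0, q)));              // saturating pack to u8
}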
-; - -ExitKernel: - vzeroupper - movaps xmm6,ConvSymKernelFrame.SavedXmm6[rsp] - movaps xmm7,ConvSymKernelFrame.SavedXmm7[rsp] - movaps xmm8,ConvSymKernelFrame.SavedXmm8[rsp] - movaps xmm9,ConvSymKernelFrame.SavedXmm9[rsp] - movaps xmm10,ConvSymKernelFrame.SavedXmm10[rsp] - movaps xmm11,ConvSymKernelFrame.SavedXmm11[rsp] - movaps xmm12,ConvSymKernelFrame.SavedXmm12[rsp] -IFIDNI , - movaps xmm13,ConvSymKernelFrame.SavedXmm13[rsp] - movaps xmm14,ConvSymKernelFrame.SavedXmm14[rsp] - movaps xmm15,ConvSymKernelFrame.SavedXmm15[rsp] - mov r14,ConvSymKernelFrame.SavedR14[rsp] - mov r15,ConvSymKernelFrame.SavedR15[rsp] -ENDIF - add rsp,(ConvSymKernelFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -; -; Pack with saturation and store 8 bytes to the output buffer. -; - -StoreQuantizedOutputBy8: -IFIDNI , - cmp ebx,5 - ja StoreQuantizedOutput6By8 - je StoreQuantizedOutput5By8 -ENDIF - cmp ebx,3 - ja StoreQuantizedOutput4By8 - je StoreQuantizedOutput3By8 - cmp ebx,1 - ja StoreQuantizedOutput2By8 - jmp StoreQuantizedOutput1By8 - -IFIDNI , -StoreQuantizedOutput6By8: - vextracti128 xmm0,ymm14,1 - vpackusdw xmm14,xmm14,xmm0 - vpackuswb xmm14,xmm14,xmm14 - vmovq QWORD PTR [r10+rsi*2],xmm14 - -StoreQuantizedOutput5By8: - vextracti128 xmm0,ymm12,1 - vpackusdw xmm12,xmm12,xmm0 - vpackuswb xmm12,xmm12,xmm12 - vmovq QWORD PTR [r10+rsi],xmm12 -ENDIF - -StoreQuantizedOutput4By8: - vextracti128 xmm0,ymm10,1 - vpackusdw xmm10,xmm10,xmm0 - vpackuswb xmm10,xmm10,xmm10 - vmovq QWORD PTR [r10],xmm10 - -StoreQuantizedOutput3By8: - vextracti128 xmm0,ymm8,1 - vpackusdw xmm8,xmm8,xmm0 - vpackuswb xmm8,xmm8,xmm8 - vmovq QWORD PTR [r8+rsi*2],xmm8 - -StoreQuantizedOutput2By8: - vextracti128 xmm0,ymm6,1 - vpackusdw xmm6,xmm6,xmm0 - vpackuswb xmm6,xmm6,xmm6 - vmovq QWORD PTR [r8+rsi],xmm6 - -StoreQuantizedOutput1By8: - vextracti128 xmm0,ymm4,1 - vpackusdw xmm4,xmm4,xmm0 - vpackuswb xmm4,xmm4,xmm4 - vmovq QWORD PTR [r8],xmm4 - jmp ExitKernel - -; -; Process the tail output counts out of line with a reduced block size. -; - -IFIDNI , -ComputeBlockLoopBy4: - ComputeBlockLoop Isa,4 - jmp ComputeBlockLoopDone -ENDIF - -ComputeBlockLoopBy2: - ComputeBlockLoop Isa,2 - jmp ComputeBlockLoopDone - - NESTED_END MlasConvSymKernel&Isa&, _TEXT - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate a single cell of the -; output block. -; -; Arguments: -; -; AccumReg - Supplies the register to accumulate into. -; -; Mult1Reg - Supplies the first multiplication operand register. This register -; may be trashed on return. -; -; Mult2Reg - Supplies the second multiplication operand register. -; - -DepthwiseMultiplyAccumulateCellAvx2 MACRO AccumReg, Mult1Reg, Mult2Reg - - vpmaddwd Mult1Reg,Mult1Reg,Mult2Reg - vpaddd AccumReg,AccumReg,Mult1Reg - - ENDM - -DepthwiseMultiplyAccumulateCellAvxVnni MACRO AccumReg, Mult1Reg, Mult2Reg - - VpdpbusdsYmmYmmYmm AccumReg,Mult1Reg,Mult2Reg - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner kernel to compute a depthwise -; convolution for the elements of an output row for a set of filter rows. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; - -ConvSymDepthwiseKernelFunction MACRO Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a depthwise convolution for the -; elements of an output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the indirection buffer. 
-; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; KernelSize (r9) - Supplies the size of the kernel. -; -; Channels - Supplies the number of input and output channels. -; -; ChannelOffset - Supplies the byte offset from the indirection buffer base -; address for this iteration. -; -; ChannelCount - Supplies the number of channels this iteration produces. -; -; This implementation requires the count to be 16. -; -; OutputCount - Supplies the number of output elements this iteration produces. -; -; This implementation requires the count to be in the range 1 to 4. -; -; PostProcessParams - Supplies the address of the post process parameter block. -; -; KernelFlags - Supplies additional flags controlling the operation. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasConvSymDepthwiseKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (ConvSymDepthwiseKernelFrame.SavedR13) - save_xmm128 xmm6,ConvSymDepthwiseKernelFrame.SavedXmm6 - save_xmm128 xmm7,ConvSymDepthwiseKernelFrame.SavedXmm7 - save_xmm128 xmm8,ConvSymDepthwiseKernelFrame.SavedXmm8 - save_xmm128 xmm9,ConvSymDepthwiseKernelFrame.SavedXmm9 - save_xmm128 xmm10,ConvSymDepthwiseKernelFrame.SavedXmm10 - save_xmm128 xmm11,ConvSymDepthwiseKernelFrame.SavedXmm11 - - END_PROLOGUE - - lea rdi,[r9*8] - mov ebx,DWORD PTR ConvSymDepthwiseKernelFrame.OutputCount[rsp] - mov rsi,ConvSymDepthwiseKernelFrame.Channels[rsp] - mov rax,ConvSymDepthwiseKernelFrame.ChannelOffset[rsp] - mov ebp,DWORD PTR ConvSymDepthwiseKernelFrame.KernelFlags[rsp] - vpxor xmm4,xmm4,xmm4 - vpxor xmm5,xmm5,xmm5 - vpxor xmm6,xmm6,xmm6 - vpxor xmm7,xmm7,xmm7 - vpxor xmm8,xmm8,xmm8 - vpxor xmm9,xmm9,xmm9 - vpxor xmm10,xmm10,xmm10 - vpxor xmm11,xmm11,xmm11 - -; -; Process an input block of length Channels for each element of the kernel. -; - -ProcessNextInputBlock: - vpmovsxbd ymm0,QWORD PTR [rdx] - vpmovsxbd ymm1,QWORD PTR [rdx+8] - lea r11,[rcx+rdi] - lea r12,[rcx+rdi*2] - lea r13,[r11+rdi*2] - cmp ebx,2 - cmovb r11,rcx ; use first row if output count is small - cmovbe r12,rcx - cmp ebx,4 - cmovb r13,rcx - mov r10,QWORD PTR [rcx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] - add rcx,8 ; advance indirection buffer address - vpmovzxbd ymm2,QWORD PTR [r10+rax] - vpmovzxbd ymm3,QWORD PTR [r10+rax+8] - DepthwiseMultiplyAccumulateCell&Isa& ymm4,ymm2,ymm0 - vpmovzxbd ymm2,QWORD PTR [r11+rax] - DepthwiseMultiplyAccumulateCell&Isa& ymm5,ymm3,ymm1 - vpmovzxbd ymm3,QWORD PTR [r11+rax+8] - DepthwiseMultiplyAccumulateCell&Isa& ymm6,ymm2,ymm0 - vpmovzxbd ymm2,QWORD PTR [r12+rax] - DepthwiseMultiplyAccumulateCell&Isa& ymm7,ymm3,ymm1 - vpmovzxbd ymm3,QWORD PTR [r12+rax+8] - DepthwiseMultiplyAccumulateCell&Isa& ymm8,ymm2,ymm0 - vpmovzxbd ymm2,QWORD PTR [r13+rax] - DepthwiseMultiplyAccumulateCell&Isa& ymm9,ymm3,ymm1 - vpmovzxbd ymm3,QWORD PTR [r13+rax+8] - DepthwiseMultiplyAccumulateCell&Isa& ymm10,ymm2,ymm0 - add rdx,rsi ; advance filter to next kernel - DepthwiseMultiplyAccumulateCell&Isa& ymm11,ymm3,ymm1 - dec r9 ; decrement input blocks remaining - jnz ProcessNextInputBlock - -; -; Apply the bias and convert the block accumulators to intermediate float values. 
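The loop above is the accumulation stage of the depthwise kernel; the bias and scale post-processing follows below. A scalar C++ sketch of what that loop computes, for illustration only (the names are placeholders; the actual kernel handles 16 channels and up to 4 output positions per pass, and filter is assumed to already point at the first channel of the group being processed):

    #include <cstddef>
    #include <cstdint>

    // Accumulate one depthwise output position across all kernel taps.
    // indirection[k] points at the (possibly padded) input row for tap k.
    void DepthwiseAccumulate(const std::uint8_t* const* indirection,
                             const std::int8_t* filter,   // taps are `channels` apart
                             std::int32_t* acc,           // one accumulator per channel
                             std::size_t kernelSize,
                             std::size_t channels,
                             std::size_t channelOffset,
                             std::size_t channelCount)
    {
        for (std::size_t k = 0; k < kernelSize; ++k) {
            const std::uint8_t* input = indirection[k] + channelOffset;
            const std::int8_t* filterRow = filter + k * channels;
            for (std::size_t c = 0; c < channelCount; ++c) {
                acc[c] += static_cast<std::int32_t>(input[c]) *
                          static_cast<std::int32_t>(filterRow[c]);
            }
        }
    }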
-; - - mov rdx,ConvSymDepthwiseKernelFrame.PostProcessParams[rsp] - mov rcx,ConvSymPostProcessParams.Bias[rdx] - mov r9,ConvSymPostProcessParams.Scale[rdx] - vmovdqu ymm0,YMMWORD PTR [rcx] ; load low bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz BroadcastScaleValue - vmovups ymm1,YMMWORD PTR [r9] ; load low scale vector - jmp ConvertLowAccumulatorsToFloat - -BroadcastScaleValue: - vbroadcastss ymm1,DWORD PTR [r9] - -ConvertLowAccumulatorsToFloat: - ConvertAccumulatorToFloatRegList - vmovdqu ymm0,YMMWORD PTR [rcx+8*4] ; load high bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz ConvertHighAccumulatorsToFloat - vmovups ymm1,YMMWORD PTR [r9+8*4] ; load high scale vector - -ConvertHighAccumulatorsToFloat: - ConvertAccumulatorToFloatRegList - -; -; Convert the intermediate float values to 32-bit integers in the range 0 to 255. -; - -ConvertFloatsToIntegers: - vbroadcastss ymm0,DWORD PTR ConvSymPostProcessParams.MinimumValue[rdx] - vbroadcastss ymm1,DWORD PTR ConvSymPostProcessParams.MaximumValue[rdx] - vpbroadcastd ymm2,DWORD PTR ConvSymPostProcessParams.OutputZeroPoint[rdx] - ConvertFloatToIntegerRegList - ConvertFloatToIntegerRegList - -; -; Pack with saturation and store 16 bytes to the output buffer. -; - -StoreQuantizedOutputBy16: - lea r10,[rsi*2+rsi] - cmp ebx,3 - ja StoreQuantizedOutput4By16 - je StoreQuantizedOutput3By16 - cmp ebx,1 - ja StoreQuantizedOutput2By16 - jmp StoreQuantizedOutput1By16 - -StoreQuantizedOutput4By16: - vextracti128 xmm0,ymm10,1 - vpackusdw xmm10,xmm10,xmm0 - vextracti128 xmm1,ymm11,1 - vpackusdw xmm11,xmm11,xmm1 - vpackuswb xmm10,xmm10,xmm11 - vmovdqu XMMWORD PTR [r8+r10],xmm10 - -StoreQuantizedOutput3By16: - vextracti128 xmm0,ymm8,1 - vpackusdw xmm8,xmm8,xmm0 - vextracti128 xmm1,ymm9,1 - vpackusdw xmm9,xmm9,xmm1 - vpackuswb xmm8,xmm8,xmm9 - vmovdqu XMMWORD PTR [r8+rsi*2],xmm8 - -StoreQuantizedOutput2By16: - vextracti128 xmm0,ymm6,1 - vpackusdw xmm6,xmm6,xmm0 - vextracti128 xmm1,ymm7,1 - vpackusdw xmm7,xmm7,xmm1 - vpackuswb xmm6,xmm6,xmm7 - vmovdqu XMMWORD PTR [r8+rsi],xmm6 - -StoreQuantizedOutput1By16: - vextracti128 xmm0,ymm4,1 - vpackusdw xmm4,xmm4,xmm0 - vextracti128 xmm1,ymm5,1 - vpackusdw xmm5,xmm5,xmm1 - vpackuswb xmm4,xmm4,xmm5 - vmovdqu XMMWORD PTR [r8],xmm4 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - movaps xmm6,ConvSymDepthwiseKernelFrame.SavedXmm6[rsp] - movaps xmm7,ConvSymDepthwiseKernelFrame.SavedXmm7[rsp] - movaps xmm8,ConvSymDepthwiseKernelFrame.SavedXmm8[rsp] - movaps xmm9,ConvSymDepthwiseKernelFrame.SavedXmm9[rsp] - movaps xmm10,ConvSymDepthwiseKernelFrame.SavedXmm10[rsp] - movaps xmm11,ConvSymDepthwiseKernelFrame.SavedXmm11[rsp] - add rsp,(ConvSymDepthwiseKernelFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasConvSymDepthwiseKernel&Isa&, _TEXT - - ENDM - -; -; Generate the convolution kernels. -; - -ConvSymKernelFunction Avx2 -ConvSymDepthwiseKernelFunction Avx2 - -ConvSymKernelFunction AvxVnni -ConvSymDepthwiseKernelFunction AvxVnni - - END diff --git a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx512Core.asm deleted file mode 100644 index cd1ddb55e8b74..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx512Core.asm +++ /dev/null @@ -1,926 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. 
-; -; Module Name: -; -; ConvSymKernelAvx512Core.asm -; -; Abstract: -; -; This module implements the kernels for the symmetric quantized integer -; convolution operation. -; -; This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE ConvSymKernelCommon.inc -INCLUDE AssembleAvx512Vnni.inc - .list - -; -; Macro Description: -; -; This macro generates code to setup registers that is common between -; convolution kernel types. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; r9 - Supplies the size of the kernel. -; -; Output: -; -; rbx - Supplies the address of the input buffer. -; -; rdi - Supplies the input indirection buffer stride. -; -IFIDNI , -; zmm7 - Supplies a 512-bit with the broadcasted word value 0x0001. -ENDIF -; -; zmm8-zmm31 - Supplies the zeroed block accumulators. -; -; k1-k4 - Supplies the opmask registers loaded with a 64-bit channel bitmask -; for KernelFrame.ChannelCount. -; - -SetupRegistersCommon MACRO Isa, KernelFrame - - mov rbx,rcx ; preserve base input address - lea rdi,[r9*8] ; indirection buffer offset to next output -IFIDNI , - mov esi,1 - vpbroadcastw zmm7,esi ; generate 512-bit word vector [0x0001] -ENDIF - EmitForEachRegister , - mov ecx,DWORD PTR KernelFrame.ChannelCount[rsp] - EmitForEachRegister , - dec ecx ; convert shift count to 0..63 - mov eax,2 - shl rax,cl ; compute 2 << ChannelShiftCount - dec rax ; convert to 64-bit channel bitmask - EmitForEachRegister , - kmovw k1,eax ; k1 = channel bitmask[0..15] - shr rax,16 - EmitForEachRegister , - kmovw k2,eax ; k2 = channel bitmask[16..31] - shr rax,16 - EmitForEachRegister , - kmovw k3,eax ; k3 = channel bitmask[32..47] - shr eax,16 - EmitForEachRegister , - kmovw k4,eax ; k4 = channel bitmask[48..63] - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate a single cell of the -; output block. -; -; Arguments: -; -; AccumReg - Supplies the register to accumulate into. -; -; Mult1Reg - Supplies the first multiplication operand register. -; -; Mult2Reg - Supplies the second multiplication operand register. -; -; Implicit Arguments: -; -; zmm5 - Supplies a scratch register for intermediate results. -; -; zmm7 - Supplies a 512-bit with the broadcasted word value 0x0001. -; - -MultiplyAccumulateCellAvx512Core MACRO AccumReg, Mult1Reg, Mult2Reg - - vpmaddubsw zmm5,Mult1Reg,Mult2Reg - vpmaddwd zmm5,zmm5,zmm7 - vpaddd AccumReg,AccumReg,zmm5 - - ENDM - -MultiplyAccumulateCellAvx512Vnni MACRO AccumReg, Mult1Reg, Mult2Reg - - VpdpbusdsZmmZmmZmm AccumReg,Mult1Reg,Mult2Reg - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; ColumnCount - Supplies the number of columns to produce. -; -; VectorOffset - Supplies the byte offset from the filter to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from the input to fetch elements. -; -; Implicit Arguments: -; -; rdx - Supplies the address of the filter buffer. -; -; rsi - Supplies the filter stride to access the packed data for the next 16 -; output channels. -; -; rbp - Supplies three times the above filter stride. -; -; r10 - Supplies the address of the base of the input buffer. 
-; -; r11-r15 - Supplies the relative byte offsets from the base of the input -; buffer to access the second through sixth rows. -; -; zmm8-zmm31 - Supplies the block accumulators. -; - -ComputeBlock MACRO Isa, ColumnCount, VectorOffset, BroadcastOffset - - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - vpbroadcastd zmm4,DWORD PTR [r10+BroadcastOffset] - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - vpbroadcastd zmm4,DWORD PTR [r10+r11+BroadcastOffset] - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - vpbroadcastd zmm4,DWORD PTR [r10+r12+BroadcastOffset] - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - vpbroadcastd zmm4,DWORD PTR [r10+r13+BroadcastOffset] - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - vpbroadcastd zmm4,DWORD PTR [r10+r14+BroadcastOffset] - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - vpbroadcastd zmm4,DWORD PTR [r10+r15+BroadcastOffset] - EmitIfCountGE ColumnCount,16, - EmitIfCountGE ColumnCount,32, - EmitIfCountGE ColumnCount,48, - EmitIfCountGE ColumnCount,64, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple times -; and advancing the input and filter data pointers. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; ColumnCount - Supplies the number of columns to produce. -; -; Implicit Arguments: -; -; rax - Supplies the number of byte elements to process (multiple of 4). -; -; rdx - Supplies the address of the filter buffer. -; -; rsi - Supplies the filter stride to access the packed data for the next 16 -; output channels. -; -; rbp - Supplies three times the above filter stride. -; -; r10 - Supplies the address of the base of the input buffer. -; -; r11-r15 - Supplies the relative byte offsets from the base of the input -; buffer to access the second through sixth rows. -; -; zmm8-zmm31 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO Isa, ColumnCount - - LOCAL ComputeBlockBy1Loop - -ComputeBlockBy1Loop: - ComputeBlock Isa,ColumnCount,0*64,0 - add r10,4 ; advance input base address - add rdx,16*4 ; advance filter address - sub rax,4 ; decrement elements remaining - jnz ComputeBlockBy1Loop - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner kernel to compute a convolution -; for the elements of an output row for a set of filter rows. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; - -ConvSymKernelFunction MACRO Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a convolution for the elements -; of an output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points -; directly at the input tensor. -; -; If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an -; indirection buffer. Every pointer in the indirection buffer points at a -; InputChannels length vector (either from the input tensor or a vector of -; padding values). 
These are grouped in batches of length KernelSize. -; These batches are then repeated OutputCount times. -; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; KernelSize (r9) - Supplies the size of the kernel. -; -; If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. -; -; InputChannels - Supplies the number of input channels. -; -; This implementation requires the count to be a multiple of 4. -; -; OutputChannels - Supplies the number of output channels. -; -; ChannelCount - Supplies the number of channels this iteration produces. -; -; This implementation requires the count to be in the range 1 to 64. -; -; OutputCount - Supplies the number of output elements this iteration produces. -; -; This implementation requires the count to be in the range 1 to 6. -; -; PostProcessParams - Supplies the address of the post process parameter block. -; -; KernelFlags - Supplies additional flags controlling the operation. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasConvSymKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - push_reg r14 - push_reg r15 - alloc_stack (ConvSymKernelFrame.SavedR15) - save_xmm128 xmm6,ConvSymKernelFrame.SavedXmm6 - save_xmm128 xmm7,ConvSymKernelFrame.SavedXmm7 - save_xmm128 xmm8,ConvSymKernelFrame.SavedXmm8 - save_xmm128 xmm9,ConvSymKernelFrame.SavedXmm9 - save_xmm128 xmm10,ConvSymKernelFrame.SavedXmm10 - save_xmm128 xmm11,ConvSymKernelFrame.SavedXmm11 - save_xmm128 xmm12,ConvSymKernelFrame.SavedXmm12 - save_xmm128 xmm13,ConvSymKernelFrame.SavedXmm13 - save_xmm128 xmm14,ConvSymKernelFrame.SavedXmm14 - save_xmm128 xmm15,ConvSymKernelFrame.SavedXmm15 - - END_PROLOGUE - - SetupRegistersCommon Isa,ConvSymKernelFrame - - mov rsi,ConvSymKernelFrame.InputChannels[rsp] - mov ecx,DWORD PTR ConvSymKernelFrame.ChannelCount[rsp] - shl rsi,4 ; 16 output channels per filter block - imul rsi,r9 ; compute filter stride - lea rbp,[rsi*2+rsi] - -; -; Process an input block of length InputChannels for each element of the kernel. -; -; To keep code size small, this kernel always computes a fixed number of output -; rows. If the output count is less than this fixed number, then the first row -; is duplicated into the unused slots and the results are discarded. -; - -ProcessNextInputBlock: - mov eax,DWORD PTR ConvSymKernelFrame.OutputCount[rsp] - test BYTE PTR ConvSymKernelFrame.KernelFlags[rsp],MLAS_CONV_SYM_FLAG_INPUT_DIRECT - jz InputIndirection - -; -; The input buffer points directly at the input data and this is effectively a -; GEMM operation (such as a pointwise convolution or an Im2Col transform). 
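A small C++ sketch of the two addressing modes described in the routine header above, for illustration only (the names are placeholders). With MLAS_CONV_SYM_FLAG_INPUT_DIRECT set, the buffer is the input tensor itself and row r starts at input + r * inputChannels (and kernelSize is 1); otherwise row r of kernel tap k is fetched through a pointer table laid out as OutputCount batches of KernelSize pointers:

    #include <cstddef>
    #include <cstdint>

    // Return the start of the input vector for output row `r` and kernel tap `k`.
    const std::uint8_t* InputRow(const void* inputOrIndirection,
                                 bool inputDirect,
                                 std::size_t r, std::size_t k,
                                 std::size_t inputChannels, std::size_t kernelSize)
    {
        if (inputDirect) {
            // Pointwise / Im2Col case: contiguous rows of length inputChannels.
            return static_cast<const std::uint8_t*>(inputOrIndirection) + r * inputChannels;
        }
        // Indirect case: pointers grouped in batches of kernelSize, one batch per output.
        auto table = static_cast<const std::uint8_t* const*>(inputOrIndirection);
        return table[r * kernelSize + k];
    }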
-; - -InputDirect: - xor r10,r10 - mov r11,ConvSymKernelFrame.InputChannels[rsp] - lea r12,[r11+r11] - lea r13,[r12+r11] - lea r14,[r13+r11] - lea r15,[r14+r11] - cmp eax,2 - cmovb r11,r10 ; use first row if output count is small - cmovbe r12,r10 - cmp eax,4 - cmovb r13,r10 - cmovbe r14,r10 - cmp eax,6 - cmovb r15,r10 - mov r10,rbx - jmp ComputeBlockLoopStart - -InputIndirection: - lea r11,[rbx+rdi] - lea r12,[rbx+rdi*2] - lea r13,[r11+rdi*2] - lea r14,[r12+rdi*2] - lea r15,[r13+rdi*2] - cmp eax,2 - cmovb r11,rbx ; use first row if output count is small - cmovbe r12,rbx - cmp eax,4 - cmovb r13,rbx - cmovbe r14,rbx - cmp eax,6 - cmovb r15,rbx - mov r10,QWORD PTR [rbx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] - mov r14,QWORD PTR [r14] - mov r15,QWORD PTR [r15] - add rbx,8 ; advance indirection buffer address - sub r11,r10 ; compute deltas from base address - sub r12,r10 - sub r13,r10 - sub r14,r10 - sub r15,r10 - -ComputeBlockLoopStart: - mov rax,ConvSymKernelFrame.InputChannels[rsp] - cmp ecx,16 - jbe ComputeBlockLoopBy16 - cmp ecx,32 - jbe ComputeBlockLoopBy32 - cmp ecx,48 - jbe ComputeBlockLoopBy48 - -ComputeBlockLoopBy64: - ComputeBlockLoop Isa,64 - jmp ComputeBlockLoopDone - -ComputeBlockLoopBy48: - ComputeBlockLoop Isa,48 - jmp ComputeBlockLoopDone - -ComputeBlockLoopBy32: - ComputeBlockLoop Isa,32 - jmp ComputeBlockLoopDone - -ComputeBlockLoopBy16: - ComputeBlockLoop Isa,16 - -ComputeBlockLoopDone: - dec r9 ; decrement input blocks remaining - jnz ProcessNextInputBlock - -; -; Post-process the block accumulators. -; - - mov ebx,DWORD PTR ConvSymKernelFrame.OutputCount[rsp] - mov rsi,ConvSymKernelFrame.OutputChannels[rsp] - mov rdx,ConvSymKernelFrame.PostProcessParams[rsp] - mov ebp,DWORD PTR ConvSymKernelFrame.KernelFlags[rsp] - call MlasConvSymPostProcessAvx512Core - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - movaps xmm6,ConvSymKernelFrame.SavedXmm6[rsp] - movaps xmm7,ConvSymKernelFrame.SavedXmm7[rsp] - movaps xmm8,ConvSymKernelFrame.SavedXmm8[rsp] - movaps xmm9,ConvSymKernelFrame.SavedXmm9[rsp] - movaps xmm10,ConvSymKernelFrame.SavedXmm10[rsp] - movaps xmm11,ConvSymKernelFrame.SavedXmm11[rsp] - movaps xmm12,ConvSymKernelFrame.SavedXmm12[rsp] - movaps xmm13,ConvSymKernelFrame.SavedXmm13[rsp] - movaps xmm14,ConvSymKernelFrame.SavedXmm14[rsp] - movaps xmm15,ConvSymKernelFrame.SavedXmm15[rsp] - add rsp,(ConvSymKernelFrame.SavedR15) - - BEGIN_EPILOGUE - - pop r15 - pop r14 - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasConvSymKernel&Isa&, _TEXT - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner kernel to compute a depthwise -; convolution for the elements of an output row for a set of filter rows. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; - -ConvSymDepthwiseKernelFunction MACRO Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a depthwise convolution for the -; elements of an output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input indirection buffer. -; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; KernelSize (r9) - Supplies the size of the kernel. -; -; Channels - Supplies the number of input and output channels. -; -; ChannelOffset - Supplies the byte offset from the indirection buffer base -; address for this iteration. 
-; -; ChannelCount - Supplies the number of channels this iteration produces. -; -; This implementation requires the count to be in the range 1 to 64. -; -; OutputCount - Supplies the number of output elements this iteration produces. -; -; This implementation requires the count to be in the range 1 to 6. -; -; PostProcessParams - Supplies the address of the post process parameter block. -; -; KernelFlags - Supplies additional flags controlling the operation. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasConvSymDepthwiseKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - push_reg r14 - push_reg r15 - alloc_stack (ConvSymDepthwiseKernelFrame.SavedR15) - save_xmm128 xmm6,ConvSymDepthwiseKernelFrame.SavedXmm6 - save_xmm128 xmm7,ConvSymDepthwiseKernelFrame.SavedXmm7 - save_xmm128 xmm8,ConvSymDepthwiseKernelFrame.SavedXmm8 - save_xmm128 xmm9,ConvSymDepthwiseKernelFrame.SavedXmm9 - save_xmm128 xmm10,ConvSymDepthwiseKernelFrame.SavedXmm10 - save_xmm128 xmm11,ConvSymDepthwiseKernelFrame.SavedXmm11 - save_xmm128 xmm12,ConvSymDepthwiseKernelFrame.SavedXmm12 - save_xmm128 xmm13,ConvSymDepthwiseKernelFrame.SavedXmm13 - save_xmm128 xmm14,ConvSymDepthwiseKernelFrame.SavedXmm14 - save_xmm128 xmm15,ConvSymDepthwiseKernelFrame.SavedXmm15 - - END_PROLOGUE - - SetupRegistersCommon Isa,ConvSymDepthwiseKernelFrame - - mov rsi,ConvSymDepthwiseKernelFrame.Channels[rsp] - mov ebp,DWORD PTR ConvSymDepthwiseKernelFrame.OutputCount[rsp] - mov rax,ConvSymDepthwiseKernelFrame.ChannelOffset[rsp] - mov ecx,DWORD PTR ConvSymDepthwiseKernelFrame.ChannelCount[rsp] - -; -; Process an input block of length Channels for each element of the kernel. -; -; To keep code size small, this kernel always computes a fixed number of output -; rows. If the output count is less than this fixed number, then the first row -; is duplicated into the unused slots and the results are discarded. 
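The cmovb/cmovbe sequences that follow implement the row duplication described above: six row pointers are always set up, but pointers for rows beyond OutputCount are redirected back to row 0, so the surplus lanes compute a harmless duplicate that is simply never stored. The same idea expressed in C++ (illustrative placeholder names):

    #include <cstddef>

    // Select the source row for lane `lane` when only `outputCount` rows are real;
    // surplus lanes recompute row 0 and are discarded by the store phase.
    inline std::size_t RowForLane(std::size_t lane, std::size_t outputCount)
    {
        return (lane < outputCount) ? lane : 0;
    }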
-; - -ProcessNextInputBlock: - lea r11,[rbx+rdi] - lea r12,[rbx+rdi*2] - lea r13,[r11+rdi*2] - lea r14,[r12+rdi*2] - lea r15,[r13+rdi*2] - cmp ebp,2 - cmovb r11,rbx ; use first row if output count is small - cmovbe r12,rbx - cmp ebp,4 - cmovb r13,rbx - cmovbe r14,rbx - cmp ebp,6 - cmovb r15,rbx - mov r10,QWORD PTR [rbx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] - mov r14,QWORD PTR [r14] - mov r15,QWORD PTR [r15] - add rbx,8 - cmp ecx,16 - jbe ComputeDepthwiseBlockBy16 - cmp ecx,32 - jbe ComputeDepthwiseBlockBy32 - cmp ecx,48 - jbe ComputeDepthwiseBlockBy48 - -ComputeDepthwiseBlockBy64: - vpmovzxbd zmm2{k4}{z},XMMWORD PTR [rdx+3*16] - vpmovzxbd zmm0{k4}{z},XMMWORD PTR [r10+rax+3*16] - vpmovzxbd zmm1{k4}{z},XMMWORD PTR [r11+rax+3*16] - MultiplyAccumulateCell&Isa& zmm11,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm15,zmm1,zmm2 - vpmovzxbd zmm0{k4}{z},XMMWORD PTR [r12+rax+3*16] - vpmovzxbd zmm1{k4}{z},XMMWORD PTR [r13+rax+3*16] - MultiplyAccumulateCell&Isa& zmm19,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm23,zmm1,zmm2 - vpmovzxbd zmm0{k4}{z},XMMWORD PTR [r14+rax+3*16] - vpmovzxbd zmm1{k4}{z},XMMWORD PTR [r15+rax+3*16] - MultiplyAccumulateCell&Isa& zmm27,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm31,zmm1,zmm2 - -ComputeDepthwiseBlockBy48: - vpmovzxbd zmm2{k3}{z},XMMWORD PTR [rdx+2*16] - vpmovzxbd zmm0{k3}{z},XMMWORD PTR [r10+rax+2*16] - vpmovzxbd zmm1{k3}{z},XMMWORD PTR [r11+rax+2*16] - MultiplyAccumulateCell&Isa& zmm10,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm14,zmm1,zmm2 - vpmovzxbd zmm0{k3}{z},XMMWORD PTR [r12+rax+2*16] - vpmovzxbd zmm1{k3}{z},XMMWORD PTR [r13+rax+2*16] - MultiplyAccumulateCell&Isa& zmm18,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm22,zmm1,zmm2 - vpmovzxbd zmm0{k3}{z},XMMWORD PTR [r14+rax+2*16] - vpmovzxbd zmm1{k3}{z},XMMWORD PTR [r15+rax+2*16] - MultiplyAccumulateCell&Isa& zmm26,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm30,zmm1,zmm2 - -ComputeDepthwiseBlockBy32: - vpmovzxbd zmm2{k2}{z},XMMWORD PTR [rdx+1*16] - vpmovzxbd zmm0{k2}{z},XMMWORD PTR [r10+rax+1*16] - vpmovzxbd zmm1{k2}{z},XMMWORD PTR [r11+rax+1*16] - MultiplyAccumulateCell&Isa& zmm9,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm13,zmm1,zmm2 - vpmovzxbd zmm0{k2}{z},XMMWORD PTR [r12+rax+1*16] - vpmovzxbd zmm1{k2}{z},XMMWORD PTR [r13+rax+1*16] - MultiplyAccumulateCell&Isa& zmm17,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm21,zmm1,zmm2 - vpmovzxbd zmm0{k2}{z},XMMWORD PTR [r14+rax+1*16] - vpmovzxbd zmm1{k2}{z},XMMWORD PTR [r15+rax+1*16] - MultiplyAccumulateCell&Isa& zmm25,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm29,zmm1,zmm2 - -ComputeDepthwiseBlockBy16: - vpmovzxbd zmm2{k1}{z},XMMWORD PTR [rdx] - vpmovzxbd zmm0{k1}{z},XMMWORD PTR [r10+rax] - vpmovzxbd zmm1{k1}{z},XMMWORD PTR [r11+rax] - MultiplyAccumulateCell&Isa& zmm8,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm12,zmm1,zmm2 - vpmovzxbd zmm0{k1}{z},XMMWORD PTR [r12+rax] - vpmovzxbd zmm1{k1}{z},XMMWORD PTR [r13+rax] - MultiplyAccumulateCell&Isa& zmm16,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm20,zmm1,zmm2 - vpmovzxbd zmm0{k1}{z},XMMWORD PTR [r14+rax] - vpmovzxbd zmm1{k1}{z},XMMWORD PTR [r15+rax] - MultiplyAccumulateCell&Isa& zmm24,zmm0,zmm2 - MultiplyAccumulateCell&Isa& zmm28,zmm1,zmm2 - add rdx,rsi ; advance filter to next kernel - dec r9 ; decrement input blocks remaining - jnz ProcessNextInputBlock - -; -; Post-process the block accumulators. 
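The post-processing call below, and the masked vpmovusdb stores inside MlasConvSymPostProcessAvx512Core, rely on the k1 through k4 opmasks that SetupRegistersCommon builds earlier in this module: a 64-bit mask with ChannelCount low bits set, split into four 16-lane pieces. A C++ sketch of that mask computation, for illustration only:

    #include <cstdint>

    // Build four 16-lane masks for a channel count in the range 1..64.
    void BuildChannelMasks(unsigned channelCount, std::uint16_t masks[4])
    {
        // (2 << (channelCount - 1)) - 1 sets the low channelCount bits and,
        // unlike (1 << channelCount) - 1, remains well defined when
        // channelCount is 64 (the shift count never reaches the word width).
        std::uint64_t bitmask = (std::uint64_t{2} << (channelCount - 1)) - 1;
        for (int i = 0; i < 4; ++i) {
            masks[i] = static_cast<std::uint16_t>(bitmask >> (16 * i));
        }
    }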
-; - - mov ebx,ebp - mov rdx,ConvSymDepthwiseKernelFrame.PostProcessParams[rsp] - mov ebp,DWORD PTR ConvSymDepthwiseKernelFrame.KernelFlags[rsp] - call MlasConvSymPostProcessAvx512Core - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - movaps xmm6,ConvSymDepthwiseKernelFrame.SavedXmm6[rsp] - movaps xmm7,ConvSymDepthwiseKernelFrame.SavedXmm7[rsp] - movaps xmm8,ConvSymDepthwiseKernelFrame.SavedXmm8[rsp] - movaps xmm9,ConvSymDepthwiseKernelFrame.SavedXmm9[rsp] - movaps xmm10,ConvSymDepthwiseKernelFrame.SavedXmm10[rsp] - movaps xmm11,ConvSymDepthwiseKernelFrame.SavedXmm11[rsp] - movaps xmm12,ConvSymDepthwiseKernelFrame.SavedXmm12[rsp] - movaps xmm13,ConvSymDepthwiseKernelFrame.SavedXmm13[rsp] - movaps xmm14,ConvSymDepthwiseKernelFrame.SavedXmm14[rsp] - movaps xmm15,ConvSymDepthwiseKernelFrame.SavedXmm15[rsp] - add rsp,(ConvSymDepthwiseKernelFrame.SavedR15) - - BEGIN_EPILOGUE - - pop r15 - pop r14 - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasConvSymDepthwiseKernel&Isa&, _TEXT - - ENDM - -; -; Macro Description: -; -; This macro generates code to convert the block accumulators from the matrix -; multiply loop to float values. -; -; Arguments: -; -; RegList - Supplies the list of vector registers to operate on. -; -; ScaleReg - Supplies the output scale vector. -; -; Implicit Arguments: -; -; zmm4 - Supplies the integer bias vector. -; - -ConvertAccumulatorToFloatRegList MACRO RegList, ScaleReg - -; -; Offset each value by the per-channel bias value, convert to floating point, -; and apply the output scale. -; - - EmitForEachRegister , - EmitForEachRegister , - EmitForEachRegister , - - ENDM - -; -; Macro Description: -; -; This macro generates code to convert float values to 32-bit integers in the -; range 0 to 255. -; -; Arguments: -; -; RegList - Supplies the list of vector registers to operate on. -; -; Implicit Arguments: -; -; zmm0 - Supplies the broadcasted minimum clip float value. -; -; This is set to static_cast(0 - ZeroPointValue). -; -; zmm1 - Supplies the broadcasted maximum clip float value. -; -; This is set to static_cast(255 - ZeroPointValue). -; -; zmm2 - Supplies the broadcasted zero point integer value. -; - -ConvertFloatToIntegerRegList MACRO RegList - -; -; Clip the float values to the integer range covered by the output zero point. -; This also keeps values outside the range INT_MIN to INT_MAX from converting -; to INT_MIN. -; - - EmitForEachRegister , - EmitForEachRegister , - -; -; Convert the float value to integer and add the zero point offset. -; - - EmitForEachRegister , - EmitForEachRegister , - - ENDM - -;++ -; -; Routine Description: -; -; This routine post processes the block accumulators produced by the convolution -; kernels, including type conversion, requantization, and storing to the output -; buffer. -; -; Arguments: -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasConvSymPostProcessAvx512Core, _TEXT - -; -; Apply the bias and convert the block accumulators to intermediate float values. 
-; - - mov r10,ConvSymPostProcessParams.Bias[rdx] - mov r11,ConvSymPostProcessParams.Scale[rdx] - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz BroadcastScaleValue - vmovups zmm0{k1}{z},ZMMWORD PTR [r11] - vmovups zmm1{k2}{z},ZMMWORD PTR [r11+16*4] - vmovups zmm2{k3}{z},ZMMWORD PTR [r11+32*4] - vmovups zmm3{k4}{z},ZMMWORD PTR [r11+48*4] - jmp ConvertAccumulatorsToFloat - -BroadcastScaleValue: - vbroadcastss zmm0,DWORD PTR [r11] - vmovups zmm1,zmm0 - vmovups zmm2,zmm0 - vmovups zmm3,zmm0 - -ConvertAccumulatorsToFloat: - cmp ecx,16 - jbe ConvertAccumulatorsToFloatBy16 - cmp ecx,32 - jbe ConvertAccumulatorsToFloatBy32 - cmp ecx,48 - jbe ConvertAccumulatorsToFloatBy48 - -ConvertAccumulatorsToFloatBy64: - vmovdqu32 zmm4{k4}{z},ZMMWORD PTR [r10+48*4] - ConvertAccumulatorToFloatRegList ,zmm3 - -ConvertAccumulatorsToFloatBy48: - vmovdqu32 zmm4{k3}{z},ZMMWORD PTR [r10+32*4] - ConvertAccumulatorToFloatRegList ,zmm2 - -ConvertAccumulatorsToFloatBy32: - vmovdqu32 zmm4{k2}{z},ZMMWORD PTR [r10+16*4] - ConvertAccumulatorToFloatRegList ,zmm1 - -ConvertAccumulatorsToFloatBy16: - vmovdqu32 zmm4{k1}{z},ZMMWORD PTR [r10] - ConvertAccumulatorToFloatRegList ,zmm0 - -; -; Convert the intermediate float values to 32-bit integers in the range 0 to 255. -; - - vbroadcastss zmm0,DWORD PTR ConvSymPostProcessParams.MinimumValue[rdx] - vbroadcastss zmm1,DWORD PTR ConvSymPostProcessParams.MaximumValue[rdx] - vpbroadcastd zmm2,DWORD PTR ConvSymPostProcessParams.OutputZeroPoint[rdx] - cmp ecx,16 - jbe ConvertFloatsToIntegerBy16 - cmp ecx,32 - jbe ConvertFloatsToIntegerBy32 - cmp ecx,48 - jbe ConvertFloatsToIntegerBy48 - -ConvertFloatsToIntegerBy64: - ConvertFloatToIntegerRegList - -ConvertFloatsToIntegerBy48: - ConvertFloatToIntegerRegList - -ConvertFloatsToIntegerBy32: - ConvertFloatToIntegerRegList - -ConvertFloatsToIntegerBy16: - ConvertFloatToIntegerRegList - -; -; Pack with saturation and store 1 to 64 bytes to the output buffer. -; - -StoreQuantizedOutput: - lea r9,[rsi*2+rsi] - add r9,r8 - cmp ebx,5 - ja StoreQuantizedOutput6 - je StoreQuantizedOutput5 - cmp ebx,3 - ja StoreQuantizedOutput4 - je StoreQuantizedOutput3 - cmp ebx,1 - ja StoreQuantizedOutput2 - jmp StoreQuantizedOutput1 - -StoreQuantizedOutput6: - vpmovusdb XMMWORD PTR [r9+rsi*2]{k1},zmm28 - vpmovusdb XMMWORD PTR [r9+rsi*2+16]{k2},zmm29 - vpmovusdb XMMWORD PTR [r9+rsi*2+32]{k3},zmm30 - vpmovusdb XMMWORD PTR [r9+rsi*2+48]{k4},zmm31 - -StoreQuantizedOutput5: - vpmovusdb XMMWORD PTR [r9+rsi]{k1},zmm24 - vpmovusdb XMMWORD PTR [r9+rsi+16]{k2},zmm25 - vpmovusdb XMMWORD PTR [r9+rsi+32]{k3},zmm26 - vpmovusdb XMMWORD PTR [r9+rsi+48]{k4},zmm27 - -StoreQuantizedOutput4: - vpmovusdb XMMWORD PTR [r9]{k1},zmm20 - vpmovusdb XMMWORD PTR [r9+16]{k2},zmm21 - vpmovusdb XMMWORD PTR [r9+32]{k3},zmm22 - vpmovusdb XMMWORD PTR [r9+48]{k4},zmm23 - -StoreQuantizedOutput3: - vpmovusdb XMMWORD PTR [r8+rsi*2]{k1},zmm16 - vpmovusdb XMMWORD PTR [r8+rsi*2+16]{k2},zmm17 - vpmovusdb XMMWORD PTR [r8+rsi*2+32]{k3},zmm18 - vpmovusdb XMMWORD PTR [r8+rsi*2+48]{k4},zmm19 - -StoreQuantizedOutput2: - vpmovusdb XMMWORD PTR [r8+rsi]{k1},zmm12 - vpmovusdb XMMWORD PTR [r8+rsi+16]{k2},zmm13 - vpmovusdb XMMWORD PTR [r8+rsi+32]{k3},zmm14 - vpmovusdb XMMWORD PTR [r8+rsi+48]{k4},zmm15 - -StoreQuantizedOutput1: - vpmovusdb XMMWORD PTR [r8]{k1},zmm8 - vpmovusdb XMMWORD PTR [r8+16]{k2},zmm9 - vpmovusdb XMMWORD PTR [r8+32]{k3},zmm10 - vpmovusdb XMMWORD PTR [r8+48]{k4},zmm11 - ret - - LEAF_END MlasConvSymPostProcessAvx512Core, _TEXT - -; -; Generate the convolution kernels. 
-; - -ConvSymKernelFunction Avx512Core -ConvSymDepthwiseKernelFunction Avx512Core - -ConvSymKernelFunction Avx512Vnni -ConvSymDepthwiseKernelFunction Avx512Vnni - - END diff --git a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/ConvSymKernelCommon.inc deleted file mode 100644 index 29a09395a696e..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelCommon.inc +++ /dev/null @@ -1,111 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; ConvSymKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the symmetric -; quantized integer convolution operation. -; -;-- - -; -; Define the convolution kernel flags. -; - -MLAS_CONV_SYM_FLAG_INPUT_DIRECT EQU 00000001h -MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE EQU 00000002h - -; -; Define the structure of the post process parameter block. -; - -ConvSymPostProcessParams STRUCT - - Bias QWORD ? - Scale QWORD ? - MinimumValue DWORD ? - MaximumValue DWORD ? - OutputZeroPoint DWORD ? - -ConvSymPostProcessParams ENDS - -; -; Stack frame layout for the symmetric convolution kernels. -; - -ConvSymKernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR15 QWORD ? - SavedR14 QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - InputChannels QWORD ? - OutputChannels QWORD ? - ChannelCount QWORD ? - OutputCount QWORD ? - PostProcessParams QWORD ? - KernelFlags QWORD ? - -ConvSymKernelFrame ENDS - -ConvSymDepthwiseKernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR15 QWORD ? - SavedR14 QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - Channels QWORD ? - ChannelOffset QWORD ? - ChannelCount QWORD ? - OutputCount QWORD ? - PostProcessParams QWORD ? - KernelFlags QWORD ? - -ConvSymDepthwiseKernelFrame ENDS diff --git a/onnxruntime/core/mlas/lib/amd64/DgemmKernelAvx.asm b/onnxruntime/core/mlas/lib/amd64/DgemmKernelAvx.asm deleted file mode 100644 index a7f9f6a206766..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/DgemmKernelAvx.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; DgemmKernelAvx.asm -; -; Abstract: -; -; This module implements the kernels for the double precision matrix/matrix -; multiply operation (DGEMM). -; -; This implementation uses AVX instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE DgemmKernelCommon.inc -INCLUDE FgemmKernelAvxCommon.inc - .list - -; -; Generate the GEMM kernel. 
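As the includes above suggest, the double precision kernel body is not written out separately: DgemmKernelCommon.inc defines the element-size constants and typed instruction aliases (addpf, vmovupf, and so on), and the shared Fgemm*.inc kernel bodies are then instantiated once per element type via "FgemmKernelAvxFunction Double". Conceptually this is the single-source specialization one would express in C++ with a template; a deliberately simplified sketch, not the actual kernel shape:

    #include <cstddef>

    // One kernel body, specialized per element type, mirroring the Fgemm macros.
    template <typename ElementType>
    void FgemmDotRow(const ElementType* a, const ElementType* b,
                     ElementType* c, std::size_t countK)
    {
        ElementType acc{};
        for (std::size_t k = 0; k < countK; ++k) {
            acc += a[k] * b[k];   // the asm picks addpf/mulpf per element type
        }
        c[0] += acc;
    }

    // "FgemmKernelAvxFunction Double" plays the role of instantiating the
    // template with double; the single precision build instantiates float.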
-; - -FgemmKernelAvxFunction Double - - END diff --git a/onnxruntime/core/mlas/lib/amd64/DgemmKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/DgemmKernelAvx512F.asm deleted file mode 100644 index 87d1a2aec82e1..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/DgemmKernelAvx512F.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; DgemmKernelAvx512F.asm -; -; Abstract: -; -; This module implements the kernels for the double precision matrix/matrix -; multiply operation (DGEMM). -; -; This implementation uses AVX512F instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE DgemmKernelCommon.inc -INCLUDE FgemmKernelAvx512FCommon.inc - .list - -; -; Generate the GEMM kernel. -; - -FgemmKernelAvx512FFunction Double - - END diff --git a/onnxruntime/core/mlas/lib/amd64/DgemmKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/DgemmKernelCommon.inc deleted file mode 100644 index 52ee7156437af..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/DgemmKernelCommon.inc +++ /dev/null @@ -1,45 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; DgemmKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the double -; precision matrix/matrix multiply operation (DGEMM). -; -;-- - -; -; Define the double precision parameters. -; - -FgemmElementShift EQU 3 -FgemmElementSize EQU (1 SHL FgemmElementShift) -FgemmElementPtr EQU QWORD PTR -FgemmElementBcst EQU QWORD BCST - -; -; Define the typed instructions for double precision. -; - -addpf EQU addpd -movupf EQU movupd - -vaddpf EQU vaddpd -vbroadcastsf EQU vbroadcastsd -vfmadd213pf EQU vfmadd213pd -vfmadd231pf EQU vfmadd231pd -vmaskmovpf EQU vmaskmovpd -vmovapf EQU vmovapd -vmovsf EQU vmovsd -vmovupf EQU vmovupd -vmulpf EQU vmulpd -vxorpf EQU vxorpd - -INCLUDE FgemmKernelCommon.inc diff --git a/onnxruntime/core/mlas/lib/amd64/DgemmKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/DgemmKernelFma3.asm deleted file mode 100644 index 5cddec31ddbec..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/DgemmKernelFma3.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; DgemmKernelFma3.asm -; -; Abstract: -; -; This module implements the kernels for the double precision matrix/matrix -; multiply operation (DGEMM). -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE DgemmKernelCommon.inc -INCLUDE FgemmKernelFma3Common.inc - .list - -; -; Generate the GEMM kernel. -; - -FgemmKernelFma3Function Double - - END diff --git a/onnxruntime/core/mlas/lib/amd64/DgemmKernelSse2.asm b/onnxruntime/core/mlas/lib/amd64/DgemmKernelSse2.asm deleted file mode 100644 index 6ac3bed97a641..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/DgemmKernelSse2.asm +++ /dev/null @@ -1,233 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; DgemmKernelSse2.asm -; -; Abstract: -; -; This module implements the kernels for the double precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses SSE2 instructions. 
-; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE DgemmKernelCommon.inc -INCLUDE FgemmKernelSse2Common.inc - .list - -; -; Macro Description: -; -; This macro multiplies and accumulates for a 8xN block of the output matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rdx - Supplies the address into the matrix B data. -; -; xmm0-xmm1 - Supplies up to four elements loaded from matrix A and matrix A -; plus one row. -; -; xmm8-xmm15 - Supplies the block accumulators. -; - -ComputeBlockSseBy8 MACRO RowCount - - movapd xmm4,XMMWORD PTR [rdx] - movapd xmm5,XMMWORD PTR [rdx+16] -IF RowCount EQ 2 - movapd xmm6,xmm4 - movapd xmm7,xmm5 -ENDIF - mulpd xmm4,xmm0 - mulpd xmm5,xmm0 - addpd xmm8,xmm4 - addpd xmm9,xmm5 -IF RowCount EQ 2 - mulpd xmm6,xmm1 - mulpd xmm7,xmm1 - addpd xmm12,xmm6 - addpd xmm13,xmm7 -ENDIF - movapd xmm4,XMMWORD PTR [rdx+32] - movapd xmm5,XMMWORD PTR [rdx+48] -IF RowCount EQ 2 - movapd xmm6,xmm4 - movapd xmm7,xmm5 -ENDIF - mulpd xmm4,xmm0 - mulpd xmm5,xmm0 - addpd xmm10,xmm4 - addpd xmm11,xmm5 -IF RowCount EQ 2 - mulpd xmm6,xmm1 - mulpd xmm7,xmm1 - addpd xmm14,xmm6 - addpd xmm15,xmm7 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Fallthrough - Supplies a non-blank value if the macro may fall through to -; the ExitKernel label. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; rsi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r8 - Supplies the address of matrix C. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r15 - Stores the ZeroMode argument from the stack frame. 
-; - -ProcessCountM MACRO RowCount, Fallthrough - - LOCAL ProcessNextColumnLoop8xN - LOCAL Compute8xNBlockBy1Loop - LOCAL Output8xNBlock - LOCAL OutputPartial8xNBlock - LOCAL OutputPartialLessThan6xNBlock - LOCAL OutputPartialLessThan4xNBlock - LOCAL OutputPartial1xNBlock - LOCAL SkipAccumulateOutput1xN - -ProcessNextColumnLoop8xN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - mov rdi,r9 ; reload CountK - -Compute8xNBlockBy1Loop: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 1, - ComputeBlockSseBy8 RowCount - add rdx,8*8 ; advance matrix B by 8 columns - add rcx,8 ; advance matrix A by 1 column - dec rdi - jne Compute8xNBlockBy1Loop - -Output8xNBlock: - movsd xmm2,QWORD PTR FgemmKernelFrame.Alpha[rsp] - movlhps xmm2,xmm2 - EmitIfCountGE RowCount, 1, - ; multiply by alpha - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - sub rbp,8 - jb OutputPartial8xNBlock - AccumulateAndStoreBlock RowCount, 4 - add r8,8*8 ; advance matrix C by 8 columns - mov rcx,rsi ; reload matrix A - test rbp,rbp - jnz ProcessNextColumnLoop8xN - jmp ExitKernel - -; -; Output a partial 8xN block to the matrix. -; - -OutputPartial8xNBlock: - add rbp,8 ; correct for over-subtract above - cmp ebp,2 - jb OutputPartial1xNBlock - cmp ebp,4 - jb OutputPartialLessThan4xNBlock - cmp ebp,6 - jb OutputPartialLessThan6xNBlock - AccumulateAndStoreBlock RowCount, 3 - test ebp,1 ; check if remaining count is small - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift remaining elements down - EmitIfCountGE RowCount, 2, - add r8,6*8 ; advance matrix C by 6 columns - jmp OutputPartial1xNBlock - -OutputPartialLessThan6xNBlock: - AccumulateAndStoreBlock RowCount, 2 - test ebp,1 ; check if remaining count is small - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift remaining elements down - EmitIfCountGE RowCount, 2, - add r8,4*8 ; advance matrix C by 4 columns - jmp OutputPartial1xNBlock - -OutputPartialLessThan4xNBlock: - AccumulateAndStoreBlock RowCount, 1 - test ebp,1 ; check if remaining count is small - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift remaining elements down - EmitIfCountGE RowCount, 2, - add r8,2*8 ; advance matrix C by 2 columns - -OutputPartial1xNBlock: - test r15b,r15b ; ZeroMode? - jnz SkipAccumulateOutput1xN - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - -SkipAccumulateOutput1xN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, -IFB - jmp ExitKernel -ENDIF - - ENDM - -; -; Generate the GEMM kernel. -; - -FgemmKernelSse2Function Double - - END diff --git a/onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm deleted file mode 100644 index c372d36e2e38b..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/ErfKernelFma3.asm +++ /dev/null @@ -1,569 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; ErfKernelFma3.asm -; -; Abstract: -; -; This module implements a kernel for computing the error function for a -; buffer of elements. -; -; This implementation uses AVX fused multiply/add instructions. 
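The kernel below splits its inputs by magnitude at ErfSplitBoundary: small |x| values are folded back as x + x * p(x^2) using the ErfSMALL_P* coefficients, larger values compute 1 - exp(-q(|x|)) using the ErfBIG_P* coefficients and an exp() that is itself approximated by the usual multiply-by-1/ln2, round, polynomial, scale-by-2^i reduction driven by the Exp_* constants, and the sign of the original input is reapplied at the end. A structural C++ sketch of that flow, for illustration only (poly_small, poly_big, exp_approx and the two bounds are placeholder stubs standing in for the constant tables, not the real coefficients):

    #include <cmath>

    // Stand-ins for the polynomial evaluations driven by the constant tables.
    float poly_small(float x2);          // built from ErfSMALL_P0..P5_Minus_One
    float poly_big(float a);             // built from ErfBIG_P0..P6_Minus_One
    float exp_approx(float x);           // built from the Exp_* constants
    extern const float kSplitBoundary;   // ErfSplitBoundary
    extern const float kUpperAbsRange;   // ErfUpperAbsRange

    float ErfApprox(float x)
    {
        float a = std::fmin(std::fabs(x), kUpperAbsRange);   // clamp like the kernel
        float r;
        if (a <= kSplitBoundary) {
            r = a + a * poly_small(a * a);                    // small-input path
        } else {
            r = 1.0f - exp_approx(-poly_big(a));              // large-input path
        }
        return std::copysign(r, x);                           // restore the sign
    }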
-; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - EXTERN MlasMaskMoveAvx:NEAR - EXTERN MlasErfConstants:NEAR - -; -; Structure layout for the erf constants block. -; - -ErfConstants STRUCT - - ErfUpperAbsRange DWORD ? - ErfSplitBoundary DWORD ? - ErfSMALL_P0 DWORD ? - ErfSMALL_P1 DWORD ? - ErfSMALL_P2 DWORD ? - ErfSMALL_P3 DWORD ? - ErfSMALL_P4 DWORD ? - ErfSMALL_P5_Minus_One DWORD ? - ErfReserve0 DWORD ? - ErfBIG_P0 DWORD ? - ErfBIG_P1 DWORD ? - ErfBIG_P2 DWORD ? - ErfBIG_P3 DWORD ? - ErfBIG_P4 DWORD ? - ErfBIG_P5 DWORD ? - ErfBIG_P6_Minus_One DWORD ? - ErfNegZero DWORD ? - ErfOne DWORD ? - - Exp_UpperRange DWORD ? - Exp_LowerRange DWORD ? - Exp_Log2Reciprocal DWORD ? - Exp_log2_hi DWORD ? - Exp_log2_lo DWORD ? - Exp_P0 DWORD ? - Exp_P1 DWORD ? - Exp_P2 DWORD ? - Exp_P3 DWORD ? - Exp_P4 DWORD ? - Exp_P5 DWORD ? - Exp_P6 DWORD ? - Exp_C DWORD ? - Exp_X7F DWORD ? - -ErfConstants ENDS - -; -; Stack frame layout for the erf kernel. -; - -ErfKernelFrame STRUCT - - ErfBuffer0 OWORD 8 DUP(?) - ErfBuffer1 OWORD 8 DUP(?) - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding0 QWORD ? - Padding1 QWORD ? - CountN QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - -ErfKernelFrame ENDS - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel for the error function. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Supplies the output buffer. -; -; N (r8) - Supplies the number of elements to process. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasErfKernelFma3, _TEXT - - alloc_stack (ErfKernelFrame.ReturnAddress) - - save_xmm128 xmm6,ErfKernelFrame.SavedXmm6 - save_xmm128 xmm7,ErfKernelFrame.SavedXmm7 - save_xmm128 xmm8,ErfKernelFrame.SavedXmm8 - save_xmm128 xmm9,ErfKernelFrame.SavedXmm9 - save_xmm128 xmm10,ErfKernelFrame.SavedXmm10 - save_xmm128 xmm11,ErfKernelFrame.SavedXmm11 - save_xmm128 xmm12,ErfKernelFrame.SavedXmm12 - save_xmm128 xmm13,ErfKernelFrame.SavedXmm13 - save_xmm128 xmm14,ErfKernelFrame.SavedXmm14 - save_xmm128 xmm15,ErfKernelFrame.SavedXmm15 - - END_PROLOGUE - - lea rax,MlasErfConstants - sub r8,8*4 - jb LErfProcessRemainingCount - -LComputeErf4x8Loop: - vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] - vmovups ymm0,YMMWORD PTR [rcx] ; original input vx0 - vmovups ymm1,YMMWORD PTR [rcx+32] ; original input vx1 - vmovups ymm2,YMMWORD PTR [rcx+64] ; original input vx2 - vmovups ymm3,YMMWORD PTR [rcx+96] ; original input vx3 - - vandps ymm4,ymm0,ymm15 ; vsign0 - vandps ymm5,ymm1,ymm15 ; vsign1 - vandps ymm6,ymm2,ymm15 ; vsign2 - vandps ymm7,ymm3,ymm15 ; vsign3 - vandnps ymm0,ymm15,ymm0 ; abs(vx0) va0 - vandnps ymm1,ymm15,ymm1 ; abs(vx1) va1 - vandnps ymm2,ymm15,ymm2 ; abs(vx2) va2 - vandnps ymm3,ymm15,ymm3 ; abs(vx3) va3 - - vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax] - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4 - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+32],ymm5 - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+64],ymm6 - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp+96],ymm7 - - vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax] - vminps ymm0,ymm0,ymm14 ; force abs value in range - vminps ymm1,ymm1,ymm14 - vminps ymm2,ymm2,ymm14 - vminps ymm3,ymm3,ymm14 - vmovaps ymm9,ymm8 - vmovaps ymm10,ymm8 - vmovaps ymm11,ymm8 - - vbroadcastss 
ymm15,ErfConstants.ErfSMALL_P1[rax] - vmulps ymm4,ymm0,ymm0 ; vs0 (square) - vmulps ymm5,ymm1,ymm1 ; vs1 - vmulps ymm6,ymm2,ymm2 ; vs2 - vmulps ymm7,ymm3,ymm3 ; vs3 - - vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax] - vfmadd213ps ymm8,ymm4,ymm15 - vfmadd213ps ymm9,ymm5,ymm15 - vfmadd213ps ymm10,ymm6,ymm15 - vfmadd213ps ymm11,ymm7,ymm15 - - vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax] - vfmadd213ps ymm8,ymm4,ymm14 - vfmadd213ps ymm9,ymm5,ymm14 - vfmadd213ps ymm10,ymm6,ymm14 - vfmadd213ps ymm11,ymm7,ymm14 - - vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax] - vfmadd213ps ymm8,ymm4,ymm13 - vfmadd213ps ymm9,ymm5,ymm13 - vfmadd213ps ymm10,ymm6,ymm13 - vfmadd213ps ymm11,ymm7,ymm13 - - vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax] - vfmadd213ps ymm8,ymm4,ymm15 - vfmadd213ps ymm9,ymm5,ymm15 - vfmadd213ps ymm10,ymm6,ymm15 - vfmadd213ps ymm11,ymm7,ymm15 - - vfmadd213ps ymm8,ymm4,ymm14 - vfmadd213ps ymm9,ymm5,ymm14 - vfmadd213ps ymm10,ymm6,ymm14 - vfmadd213ps ymm11,ymm7,ymm14 - - vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax] - vfmadd213ps ymm8,ymm0,ymm0 - vfmadd213ps ymm9,ymm1,ymm1 - vfmadd213ps ymm10,ymm2,ymm2 - vfmadd213ps ymm11,ymm3,ymm3 - - vcmpgtps ymm4,ymm0,ymm12 ; vmask0 - vcmpgtps ymm5,ymm1,ymm12 ; vmask1 - vcmpgtps ymm6,ymm2,ymm12 ; vmask2 - vcmpgtps ymm7,ymm3,ymm12 ; vmask3 - - vandnps ymm8,ymm4,ymm8 - vandnps ymm9,ymm5,ymm9 - vandnps ymm10,ymm6,ymm10 - vandnps ymm11,ymm7,ymm11 - - vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax] - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8 - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32],ymm9 - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64],ymm10 - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96],ymm11 - -LBiggerNumbers: - vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax] - vandps ymm0,ymm4,ymm0 - vandps ymm1,ymm5,ymm1 - vandps ymm2,ymm6,ymm2 - vandps ymm3,ymm7,ymm3 - vmovaps ymm9,ymm8 - vmovaps ymm10,ymm8 - vmovaps ymm11,ymm8 - - vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax] - vfmadd213ps ymm8,ymm0,ymm15 - vfmadd213ps ymm9,ymm1,ymm15 - vfmadd213ps ymm10,ymm2,ymm15 - vfmadd213ps ymm11,ymm3,ymm15 - - vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax] - vfmadd213ps ymm8,ymm0,ymm14 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax] - vfmadd213ps ymm8,ymm0,ymm13 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 - vfmadd213ps ymm9,ymm1,ymm15 - vfmadd213ps ymm10,ymm2,ymm15 - vfmadd213ps ymm11,ymm3,ymm15 - - vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax] - vfmadd213ps ymm8,ymm0,ymm14 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] - vfmadd213ps ymm8,ymm0,ymm13 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax] - vfmadd213ps ymm8,ymm0,ymm0 - vfmadd213ps ymm9,ymm1,ymm1 - vfmadd213ps ymm10,ymm2,ymm2 - vfmadd213ps ymm11,ymm3,ymm3 - - vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax] - vxorps ymm8,ymm8,ymm15 - vxorps ymm9,ymm9,ymm15 - vxorps ymm10,ymm10,ymm15 - vxorps ymm11,ymm11,ymm15 - - vbroadcastss ymm13,ErfConstants.Exp_C[rax] - vmovaps ymm5,ymm4 - vmovaps ymm6,ymm4 - vmovaps ymm7,ymm4 - - ; expf(ymm8 -- ymm11) - vmaxps ymm8,ymm8,ymm14 - vmaxps ymm9,ymm9,ymm14 - vmaxps ymm10,ymm10,ymm14 - vmaxps 
ymm11,ymm11,ymm14 - - vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax] - vfmadd213ps ymm4,ymm8,ymm13 - vfmadd213ps ymm5,ymm9,ymm13 - vfmadd213ps ymm6,ymm10,ymm13 - vfmadd213ps ymm7,ymm11,ymm13 - - vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax] - vmovaps ymm1,ymm0 - vmovaps ymm2,ymm0 - vmovaps ymm3,ymm0 - - vsubps ymm4,ymm4,ymm13 ; vr = round() - vsubps ymm5,ymm5,ymm13 - vsubps ymm6,ymm6,ymm13 - vsubps ymm7,ymm7,ymm13 - - vfmadd213ps ymm0,ymm4,ymm8 ; vf = vr * log2_hi + ve - vfmadd213ps ymm1,ymm5,ymm9 - vfmadd213ps ymm2,ymm6,ymm10 - vfmadd213ps ymm3,ymm7,ymm11 - - vbroadcastss ymm8,ErfConstants.Exp_P0[rax] - vfmadd231ps ymm0,ymm4,ymm15 ; vf += vr * log_2_lo - vfmadd231ps ymm1,ymm5,ymm15 - vfmadd231ps ymm2,ymm6,ymm15 - vfmadd231ps ymm3,ymm7,ymm15 - vmovaps ymm9,ymm8 - vmovaps ymm10,ymm8 - vmovaps ymm11,ymm8 - - vbroadcastss ymm14,ErfConstants.Exp_P1[rax] - vbroadcastss ymm13,ErfConstants.Exp_P2[rax] - vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p1 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm12,ErfConstants.Exp_P3[rax] - vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p2 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vbroadcastss ymm15,ErfConstants.Exp_P4[rax] - vfmadd213ps ymm8,ymm0,ymm12 ; *+ exp_p3 - vfmadd213ps ymm9,ymm1,ymm12 - vfmadd213ps ymm10,ymm2,ymm12 - vfmadd213ps ymm11,ymm3,ymm12 - - vbroadcastss ymm14,ErfConstants.Exp_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 ; *+ exp_p4 - vfmadd213ps ymm9,ymm1,ymm15 - vfmadd213ps ymm10,ymm2,ymm15 - vfmadd213ps ymm11,ymm3,ymm15 - - vbroadcastss ymm13,ErfConstants.Exp_P6[rax] - vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p5 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm12,ErfConstants.Exp_X7F[rax] - vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p6 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vcvttps2dq ymm4,ymm4 - vcvttps2dq ymm5,ymm5 - vcvttps2dq ymm6,ymm6 - vcvttps2dq ymm7,ymm7 - - vbroadcastss ymm15,ErfConstants.ErfOne[rax] - vpaddd ymm4,ymm4,ymm12 ; +127 - vpaddd ymm5,ymm5,ymm12 - vpaddd ymm6,ymm6,ymm12 - vpaddd ymm7,ymm7,ymm12 - - vpslld ymm4,ymm4,23 - vpslld ymm5,ymm5,23 - vpslld ymm6,ymm6,23 - vpslld ymm7,ymm7,23 - - vmulps ymm8,ymm8,ymm4 ; 2^i * exp(vf) - vmulps ymm9,ymm9,ymm5 - vmulps ymm10,ymm10,ymm6 - vmulps ymm11,ymm11,ymm7 - - vsubps ymm8,ymm15,ymm8 - vsubps ymm9,ymm15,ymm9 - vsubps ymm10,ymm15,ymm10 - vsubps ymm11,ymm15,ymm11 - - ; merge small numbers' result - vorps ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp] - vorps ymm9,ymm9,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+32] - vorps ymm10,ymm10,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+64] - vorps ymm11,ymm11,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp+96] - - ; copy sign - vorps ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp] - vorps ymm1,ymm9,YMMWORD PTR 32+ErfKernelFrame.ErfBuffer0[rsp] - vorps ymm2,ymm10,YMMWORD PTR 64+ErfKernelFrame.ErfBuffer0[rsp] - vorps ymm3,ymm11,YMMWORD PTR 96+ErfKernelFrame.ErfBuffer0[rsp] - - vmovups YMMWORD PTR [rdx],ymm0 - vmovups YMMWORD PTR [rdx+32],ymm1 - vmovups YMMWORD PTR [rdx+64],ymm2 - vmovups YMMWORD PTR [rdx+96],ymm3 - - add rcx,32*4 ; advance by 4*8 elements - add rdx,32*4 - sub r8,32 - jae LComputeErf4x8Loop - -LErfProcessRemainingCount: - add r8,32 ; correct for over-subtract above - jz LErfBatchExp - -LErfProcess1x8: - mov DWORD PTR ErfKernelFrame.CountN[rsp],r8d - vbroadcastss ymm3,DWORD PTR ErfKernelFrame.CountN[rsp] - - vpcmpgtd 
ymm3,ymm3,YMMWORD PTR [MlasMaskMoveAvx] - vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx] ; original input vx0 - - vandps ymm4,ymm0,ymm15 ; vsign0 - vandnps ymm0,ymm15,ymm0 ; abs(vx0) va0 - - vbroadcastss ymm14,ErfConstants.ErfUpperAbsRange[rax] - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp],ymm4 - - vbroadcastss ymm8,ErfConstants.ErfSMALL_P0[rax] - vminps ymm0,ymm0,ymm14 ; force abs value in range - - vbroadcastss ymm15,ErfConstants.ErfSMALL_P1[rax] - vmulps ymm4,ymm0,ymm0 ; vs0 (square) - - vbroadcastss ymm14,ErfConstants.ErfSMALL_P2[rax] - vfmadd213ps ymm8,ymm4,ymm15 - - vbroadcastss ymm13,ErfConstants.ErfSMALL_P3[rax] - vfmadd213ps ymm8,ymm4,ymm14 - - vbroadcastss ymm15,ErfConstants.ErfSMALL_P4[rax] - vfmadd213ps ymm8,ymm4,ymm13 - - vbroadcastss ymm14,ErfConstants.ErfSMALL_P5_Minus_One[rax] - vfmadd213ps ymm8,ymm4,ymm15 - - vfmadd213ps ymm8,ymm4,ymm14 - - vbroadcastss ymm12,ErfConstants.ErfSplitBoundary[rax] - vfmadd213ps ymm8,ymm0,ymm0 - - vcmpgtps ymm4,ymm0,ymm12 ; vmask0 - - vandnps ymm8,ymm4,ymm8 - - vmovups YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp],ymm8 - -LBiggerNumbersRemaining: - vbroadcastss ymm15,ErfConstants.ErfBIG_P1[rax] - vbroadcastss ymm8,ErfConstants.ErfBIG_P0[rax] - vandps ymm0,ymm4,ymm0 - - vbroadcastss ymm14,ErfConstants.ErfBIG_P2[rax] - vfmadd213ps ymm8,ymm0,ymm15 - - vbroadcastss ymm13,ErfConstants.ErfBIG_P3[rax] - vfmadd213ps ymm8,ymm0,ymm14 - - vbroadcastss ymm15,ErfConstants.ErfBIG_P4[rax] - vfmadd213ps ymm8,ymm0,ymm13 - - vbroadcastss ymm14,ErfConstants.ErfBIG_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 - - vbroadcastss ymm13,ErfConstants.ErfBIG_P6_Minus_One[rax] - vfmadd213ps ymm8,ymm0,ymm14 - - vbroadcastss ymm15,ErfConstants.ErfNegZero[rax] - vfmadd213ps ymm8,ymm0,ymm13 - - vbroadcastss ymm14,ErfConstants.Exp_LowerRange[rax] - vfmadd213ps ymm8,ymm0,ymm0 - - vbroadcastss ymm4,ErfConstants.Exp_Log2Reciprocal[rax] - vxorps ymm8,ymm8,ymm15 - - vbroadcastss ymm13,ErfConstants.Exp_C[rax] - - ; expf(ymm8 -- ymm11) - vmaxps ymm8,ymm8,ymm14 - - vbroadcastss ymm0,ErfConstants.Exp_log2_hi[rax] - vfmadd213ps ymm4,ymm8,ymm13 - - vbroadcastss ymm15,ErfConstants.Exp_log2_lo[rax] - - vsubps ymm4,ymm4,ymm13 ; vr = round() - - vfmadd213ps ymm0,ymm4,ymm8 ; vf = vr * log2_hi + ve - - vbroadcastss ymm8,ErfConstants.Exp_P0[rax] - - vfmadd231ps ymm0,ymm4,ymm15 ; vf += vr * log_2_lo - - vbroadcastss ymm14,ErfConstants.Exp_P1[rax] - - vbroadcastss ymm13,ErfConstants.Exp_P2[rax] - vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p1 - - vbroadcastss ymm12,ErfConstants.Exp_P3[rax] - vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p2 - - vbroadcastss ymm15,ErfConstants.Exp_P4[rax] - vfmadd213ps ymm8,ymm0,ymm12 ; *+ exp_p3 - - vbroadcastss ymm14,ErfConstants.Exp_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 ; *+ exp_p4 - - vbroadcastss ymm13,ErfConstants.Exp_P6[rax] - vfmadd213ps ymm8,ymm0,ymm14 ; *+ exp_p5 - - vbroadcastss ymm12,ErfConstants.Exp_X7F[rax] - vfmadd213ps ymm8,ymm0,ymm13 ; *+ exp_p6 - - vcvttps2dq ymm4,ymm4 - - vbroadcastss ymm15,ErfConstants.ErfOne[rax] - vpaddd ymm4,ymm4,ymm12 ; +127 - - vpslld ymm4,ymm4,23 - - vmulps ymm8,ymm8,ymm4 ; 2^i * exp(vf) - - vsubps ymm8,ymm15,ymm8 - - ; merge small numbers' result - vorps ymm8,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer1[rsp] - - ; copy sign - vorps ymm0,ymm8,YMMWORD PTR ErfKernelFrame.ErfBuffer0[rsp] - - vmaskmovps YMMWORD PTR [rdx],ymm3,ymm0 - - add rcx,8*4 - add rdx,8*4 - sub r8,8 - jg LErfProcess1x8 - -LErfBatchExp: - vzeroupper - movaps xmm6,ErfKernelFrame.SavedXmm6[rsp] - movaps xmm7,ErfKernelFrame.SavedXmm7[rsp] 
- movaps xmm8,ErfKernelFrame.SavedXmm8[rsp] - movaps xmm9,ErfKernelFrame.SavedXmm9[rsp] - movaps xmm10,ErfKernelFrame.SavedXmm10[rsp] - movaps xmm11,ErfKernelFrame.SavedXmm11[rsp] - movaps xmm12,ErfKernelFrame.SavedXmm12[rsp] - movaps xmm13,ErfKernelFrame.SavedXmm13[rsp] - movaps xmm14,ErfKernelFrame.SavedXmm14[rsp] - movaps xmm15,ErfKernelFrame.SavedXmm15[rsp] - add rsp,(ErfKernelFrame.ReturnAddress) - - BEGIN_EPILOGUE - - ret - - NESTED_END MlasErfKernelFma3, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/FgemmKernelAvx512FCommon.inc b/onnxruntime/core/mlas/lib/amd64/FgemmKernelAvx512FCommon.inc deleted file mode 100644 index 10d3b5ff21be2..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/FgemmKernelAvx512FCommon.inc +++ /dev/null @@ -1,511 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; FgemmKernelAvx512FCommon.inc -; -; Abstract: -; -; This module implements the kernels for the floating point matrix/matrix -; multiply operation (SGEMM and DGEMM). -; -; This implementation uses AVX512F instructions. -; -;-- - -; -; Macro Description: -; -; This macro multiplies and accumulates for 2 ZMMWORDs by N rows of the output -; matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; PrefetchOffset - Optionally supplies the byte offset from matrix B to -; prefetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r13 - Supplies the address into the matrix A data plus 6 rows. -; -; r14 - Supplies the address into the matrix A data plus 9 rows. -; -; zmm4-zmm27 - Supplies the block accumulators. 
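For readers less used to MASM macro soup, the block-compute step documented above boils down to a broadcast-and-FMA pattern: each of the RowCount rows contributes one element of A, which is broadcast and multiplied against the two packed ZMMWORDs of B, accumulating into that row's pair of accumulator registers. A scalar C++ sketch of one such k step follows; the names, the float element type, and the accumulator layout are illustrative assumptions, not MLAS code.

    // One k step of the 2-ZMMWORD block update (scalar reference, float case).
    constexpr size_t kZmmLanes = 16;                 // FgemmZmmElementCount for float

    void ComputeBlockBy2Reference(size_t RowCount, size_t lda, size_t k,
                                  const float* A,    // matrix A, row major
                                  const float* b0,   // first packed ZMMWORD of B for this k
                                  const float* b1,   // second packed ZMMWORD of B for this k
                                  float acc[][2][kZmmLanes]) {
      for (size_t row = 0; row < RowCount; ++row) {
        const float a = A[row * lda + k];            // vbroadcastsf
        for (size_t lane = 0; lane < kZmmLanes; ++lane) {
          acc[row][0][lane] += a * b0[lane];         // vfmadd231pf, first accumulator
          acc[row][1][lane] += a * b1[lane];         // vfmadd231pf, second accumulator
        }
      }
    }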
-; - -ComputeBlockAvx512FBy2 MACRO RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -IFNB - prefetcht0 [rdx+VectorOffset+PrefetchOffset] - prefetcht0 [rdx+r12+VectorOffset+PrefetchOffset] -ENDIF -IF RowCount EQ 1 - vbroadcastsf zmm3,FgemmElementPtr [rcx+BroadcastOffset] - vfmadd231pf zmm4,zmm3,ZMMWORD PTR [rdx+VectorOffset] - vfmadd231pf zmm5,zmm3,ZMMWORD PTR [rdx+r12+VectorOffset] -ELSE - vmovapf zmm0,ZMMWORD PTR [rdx+VectorOffset] - vmovapf zmm1,ZMMWORD PTR [rdx+r12+VectorOffset] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for 1 ZMMWORD by N rows of the output -; matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; PrefetchOffset - Optionally supplies the byte offset from matrix B to -; prefetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r13 - Supplies the address into the matrix A data plus 6 rows. -; -; r14 - Supplies the address into the matrix A data plus 9 rows. -; -; zmm4-zmm27 - Supplies the block accumulators. -; - -ComputeBlockAvx512FBy1 MACRO RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -IFNB - prefetcht0 [rdx+VectorOffset+PrefetchOffset] -ENDIF - vmovapf zmm0,ZMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ComputeBlock - Supplies the macro to compute a single block. -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. 
-; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockAvx512FLoop MACRO ComputeBlock, RowCount - -IF RowCount GT 3 - lea rbx,[r10*2+r10] -IF RowCount EQ 12 - lea r13,[rcx+rbx*2] ; compute matrix A plus 6 rows - lea r14,[r13+rbx] ; compute matrix A plus 9 rows -ENDIF - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - ComputeBlockLoop ComputeBlock, RowCount, -IF RowCount GT 3 - lea rbx,[rax*2+rax] -IF RowCount EQ 12 - lea r13,[r8+rbx*2] ; compute matrix C plus 6 rows - lea r14,[r13+rbx] ; compute matrix C plus 9 rows -ENDIF - add rbx,r8 ; compute matrix C plus 3 rows -ENDIF - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; rsi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r8 - Supplies the address of matrix C. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r15 - Stores the ZeroMode argument from the stack frame. -; - -ProcessCountM MACRO RowCount - - LOCAL ProcessNextColumnLoop2xN - LOCAL MultiplyAlpha2xNBlock - LOCAL Store2xNBlock - LOCAL Output1xNBlock - LOCAL Output1xNBlockWithMask - LOCAL MultiplyAlpha1xNBlockWithMask - LOCAL Store1xNBlockWithMask - LOCAL ProcessRemainingCountN - - cmp rbp,FgemmZmmElementCount - jbe ProcessRemainingCountN - -ProcessNextColumnLoop2xN: - EmitIfCountGE RowCount, 12, - ; clear upper block accumulators - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - ComputeBlockAvx512FLoop ComputeBlockAvx512FBy2, RowCount - add rdx,r12 ; advance matrix B by 64*CountK bytes - test r15b,r15b ; ZeroMode? 
- jnz MultiplyAlpha2xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - jmp Store2xNBlock - -MultiplyAlpha2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - -Store2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - add r8,64 ; advance matrix C by ZMMWORD -IF RowCount GT 3 - add rbx,64 ; advance matrix C plus 3 rows by ZMMWORD -IF RowCount EQ 12 - add r13,64 ; advance matrix C plus 6 rows by ZMMWORD - add r14,64 ; advance matrix C plus 9 rows by ZMMWORD -ENDIF -ENDIF - sub rbp,FgemmZmmElementCount - -Output1xNBlock: - sub rbp,FgemmZmmElementCount - jae Output1xNBlockWithMask - lea ecx,[ebp+FgemmZmmElementCount] - ; correct for over-subtract above - mov edi,1 - shl edi,cl - dec edi - kmovw k1,edi ; update mask for remaining columns - xor ebp,ebp ; no more columns remaining - -Output1xNBlockWithMask: - test r15b,r15b ; ZeroMode? 
- jnz MultiplyAlpha1xNBlockWithMask - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - jmp Store1xNBlockWithMask - -MultiplyAlpha1xNBlockWithMask: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - -Store1xNBlockWithMask: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - add r8,64 ; advance matrix C by ZMMWORD - mov rcx,rsi ; reload matrix A - vzeroall - cmp rbp,FgemmZmmElementCount - ja ProcessNextColumnLoop2xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - EmitIfCountGE RowCount, 12, - ; clear upper block accumulators - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - EmitIfCountGE RowCount, 12, - ComputeBlockAvx512FLoop ComputeBlockAvx512FBy1, RowCount - jmp Output1xNBlock - - ENDM - -; -; Macro Description: -; -; This macro generates the inner kernel to compute matrix multiplication. -; -; Arguments: -; -; Type - Supplies the element type string for function tags. -; - -FgemmKernelAvx512FFunction MACRO Type - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; lda - Supplies the first dimension of matrix A. -; -; ldc - Supplies the first dimension of matrix C. -; -; Alpha - Supplies the scalar alpha multiplier (see SGEMM definition). -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemm&Type&KernelAvx512F, _TEXT - - FgemmKernelEntry Avx512F - - mov r12,r9 - shl r12,6 ; compute 64*CountK bytes - mov edi,-1 - kmovw k1,edi ; update mask to write all columns - vbroadcastsf zmm31,FgemmElementPtr FgemmKernelFrame.Alpha[rsp] - -; -; Process CountM rows of the matrices. 
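Before the row-count dispatch below, the caller's side of this contract is worth spelling out: the kernel processes at most 12 rows per invocation (for this AVX-512 variant) and returns how many rows it actually handled, so the outer driver simply advances A and C by that count and calls again. A hedged sketch follows, with an illustrative signature rather than the real MLAS entry point.

    // Illustrative driver loop over CountM; the kernel signature is assumed here.
    using FgemmKernel = size_t (*)(const float* A, const float* B, float* C,
                                   size_t CountK, size_t CountM, size_t CountN,
                                   size_t lda, size_t ldc, float alpha, bool ZeroMode);

    void DriveRows(FgemmKernel kernel, const float* A, const float* packedB, float* C,
                   size_t M, size_t N, size_t K, size_t lda, size_t ldc, float alpha) {
      while (M > 0) {
        size_t rows = kernel(A, packedB, C, K, M, N, lda, ldc, alpha, /*ZeroMode=*/true);
        A += rows * lda;   // advance matrix A by the rows handled
        C += rows * ldc;   // advance matrix C by the rows handled
        M -= rows;
      }
    }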
-; - - cmp r11,12 - jb ProcessCountMLessThan12 - mov r11d,12 ; return 12 rows handled - ProcessCountM 12 - -ProcessCountMLessThan12: - cmp r11,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM6: - mov r11d,6 ; return 6 rows handled - ProcessCountM 6 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - FgemmKernelExit Avx512F - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM5: - ProcessCountM 5 - - NESTED_END MlasGemm&Type&KernelAvx512F, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/FgemmKernelAvxCommon.inc b/onnxruntime/core/mlas/lib/amd64/FgemmKernelAvxCommon.inc deleted file mode 100644 index 5b65c37bb957b..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/FgemmKernelAvxCommon.inc +++ /dev/null @@ -1,442 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; FgemmKernelAvxCommon.inc -; -; Abstract: -; -; This module implements the kernels for the floating point matrix/matrix -; multiply operation (SGEMM and DGEMM). -; -; This implementation uses AVX instructions. -; -;-- - - EXTERN MlasMaskMoveTableAvx:NEAR - -; -; Macro Description: -; -; This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output -; matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; PrefetchOffset - Optionally supplies the byte offset from matrix B to -; prefetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 2 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm8-ymm15 - Supplies the block accumulators. -; - -ComputeBlockAvxBy2 MACRO RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -IF RowCount EQ 1 - vbroadcastsf ymm3,FgemmElementPtr [rcx+BroadcastOffset] - vmulpf ymm4,ymm3,YMMWORD PTR [rdx+VectorOffset] - vaddpf ymm8,ymm8,ymm4 - vmulpf ymm5,ymm3,YMMWORD PTR [rdx+VectorOffset+32] - vaddpf ymm9,ymm9,ymm5 -ELSE - vmovapf ymm0,YMMWORD PTR [rdx+VectorOffset] - vmovapf ymm1,YMMWORD PTR [rdx+VectorOffset+32] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for 1 YMMWORD by N rows of the output -; matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. 
-; -; PrefetchOffset - Optionally supplies the byte offset from matrix B to -; prefetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 2 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm8-ymm15 - Supplies the block accumulators. -; - -ComputeBlockAvxBy1 MACRO RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -IF RowCount EQ 1 - vbroadcastsf ymm3,FgemmElementPtr [rcx+BroadcastOffset] - vmulpf ymm5,ymm3,YMMWORD PTR [rdx+VectorOffset] - vaddpf ymm9,ymm9,ymm5 -ELSE - vmovapf ymm0,YMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ComputeBlock - Supplies the macro to compute a single block. -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockAvxLoop MACRO ComputeBlock, RowCount - -IF RowCount GT 2 - lea rbx,[rcx+r10*2] ; compute matrix A plus 2 rows -ENDIF - ComputeBlockLoop ComputeBlock, RowCount, -IF RowCount GT 2 - lea rbx,[r8+rax*2] ; compute matrix C plus 2 rows -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Fallthrough - Supplies a non-blank value if the macro may fall through to -; the ExitKernel label. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; rsi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r8 - Supplies the address of matrix C. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r15 - Stores the ZeroMode argument from the stack frame. 
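One aspect of the column loop that follows deserves a plain-language note: columns of B and C are consumed two YMMWORDs at a time, then one at a time, and a final partial vector is written through a lane mask so that only the live columns of C are touched (that is the job of MlasMaskMoveTableAvx here, and of the (1 << n) - 1 kmovw mask in the AVX-512 variant). A minimal scalar equivalent of the masked tail store, with illustrative names:

    // Store only the `remaining` live lanes of the final partial vector of C.
    constexpr size_t kYmmLanes = 8;   // FgemmYmmElementCount for float

    void MaskedStoreTail(float* c, const float* result, size_t remaining) {
      // remaining < kYmmLanes; equivalent to vmaskmovps with a mask whose first
      // `remaining` lanes are set.
      for (size_t lane = 0; lane < remaining && lane < kYmmLanes; ++lane) {
        c[lane] = result[lane];
      }
    }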
-; - -ProcessCountM MACRO RowCount, Fallthrough - - LOCAL ProcessNextColumnLoop2xN - LOCAL Store2xNBlock - LOCAL ProcessRemainingCountN - LOCAL Store1xNBlock - LOCAL OutputMasked2xNBlock - LOCAL StoreMasked2xNBlock - LOCAL OutputMasked1xNBlock - LOCAL StoreMasked1xNBlock - - cmp rbp,FgemmYmmElementCount - jbe ProcessRemainingCountN - -ProcessNextColumnLoop2xN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - ComputeBlockAvxLoop ComputeBlockAvxBy2, RowCount - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - sub rbp,2*FgemmYmmElementCount - jb OutputMasked2xNBlock - test r15b,r15b ; ZeroMode? - jnz Store2xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - -Store2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - add r8,2*32 ; advance matrix C by 2 YMMWORDs - mov rcx,rsi ; reload matrix A - cmp rbp,FgemmYmmElementCount - ja ProcessNextColumnLoop2xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - ComputeBlockAvxLoop ComputeBlockAvxBy1, RowCount - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - cmp rbp,FgemmYmmElementCount - jb OutputMasked1xNBlock - test r15b,r15b ; ZeroMode? - jnz Store1xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - -Store1xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - jmp ExitKernel - -OutputMasked2xNBlock: - test r15b,r15b ; ZeroMode? - jnz StoreMasked2xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - -StoreMasked2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - add r8,32 ; advance matrix C by YMMWORD -IF RowCount GT 2 - add rbx,32 ; advance matrix C plus 2 rows by YMMWORD -ENDIF - add rbp,FgemmYmmElementCount ; correct for over-subtract above - -OutputMasked1xNBlock: - neg rbp - lea rcx,MlasMaskMoveTableAvx+8*4 - vmovdqu ymm0,YMMWORD PTR [rcx+rbp*FgemmElementSize] - test r15b,r15b ; ZeroMode? - jnz StoreMasked1xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - -StoreMasked1xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, -IFB - jmp ExitKernel -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates the inner kernel to compute matrix multiplication. 
-; -; Arguments: -; -; Type - Supplies the element type string for function tags. -; - -FgemmKernelAvxFunction MACRO Type - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; lda - Supplies the first dimension of matrix A. -; -; ldc - Supplies the first dimension of matrix C. -; -; Alpha - Supplies the scalar alpha multiplier (see SGEMM definition). -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemm&Type&KernelAvx, _TEXT - - FgemmKernelEntry Avx - - vbroadcastsf ymm2,FgemmElementPtr FgemmKernelFrame.Alpha[rsp] - -; -; Process 4 rows of the matrices. -; - - cmp r11,4 - jb ProcessCountMLessThan4 - mov r11d,4 ; return 4 rows handled - ProcessCountM 4, Fallthrough - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - FgemmKernelExit Avx - -; -; Process 2 rows of the matrices. -; - -ProcessCountMLessThan4: - cmp r11,2 - jb ProcessCountMLessThan2 - mov r11d,2 ; return 2 rows handled - ProcessCountM 2 - -; -; Process 1 row of the matrices. -; - -ProcessCountMLessThan2: - ProcessCountM 1 - - NESTED_END MlasGemm&Type&KernelAvx, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/FgemmKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/FgemmKernelCommon.inc deleted file mode 100644 index d80feac1e5ed4..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/FgemmKernelCommon.inc +++ /dev/null @@ -1,257 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; FgemmKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the floating -; point matrix/matrix multiply operation (SGEMM and DGEMM). -; -;-- - -; -; Stack frame layout for the floating point kernels. -; - -FgemmKernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR12 QWORD ? - SavedR13 QWORD ? - SavedR14 QWORD ? - SavedR15 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - lda QWORD ? - ldc QWORD ? - Alpha QWORD ? - ZeroMode QWORD ? - -FgemmKernelFrame ENDS - -; -; Define the number of elements per vector register. 
-; - -FgemmXmmElementCount EQU (16 / FgemmElementSize) -FgemmYmmElementCount EQU (32 / FgemmElementSize) -FgemmZmmElementCount EQU (64 / FgemmElementSize) - -; -; Macro Description: -; -; This macro implements the common prologue code for the SGEMM and DGEMM -; kernels. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; Return Registers: -; -; rax - Stores the length in bytes of a row from matrix C. -; -; rsi - Stores the address of the matrix A data. -; -; rbp - Stores the CountN argument from the stack frame. -; -; r10 - Stores the length in bytes of a row from matrix A. -; -; r11 - Stores the CountM argument from the stack frame. -; -; rbx, rsi, rdi - Previous values stored on the stack and the registers -; are available as temporaries. -; -; r15 - Stores the ZeroMode argument from the stack frame. -; - -FgemmKernelEntry MACRO Isa - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r15 - alloc_stack (FgemmKernelFrame.SavedR15) -IFIDNI , - save_reg r12,FgemmKernelFrame.SavedR12 - save_reg r13,FgemmKernelFrame.SavedR13 - save_reg r14,FgemmKernelFrame.SavedR14 -ENDIF - save_xmm128 xmm6,FgemmKernelFrame.SavedXmm6 - save_xmm128 xmm7,FgemmKernelFrame.SavedXmm7 - save_xmm128 xmm8,FgemmKernelFrame.SavedXmm8 - save_xmm128 xmm9,FgemmKernelFrame.SavedXmm9 - save_xmm128 xmm10,FgemmKernelFrame.SavedXmm10 - save_xmm128 xmm11,FgemmKernelFrame.SavedXmm11 - save_xmm128 xmm12,FgemmKernelFrame.SavedXmm12 - save_xmm128 xmm13,FgemmKernelFrame.SavedXmm13 - save_xmm128 xmm14,FgemmKernelFrame.SavedXmm14 - save_xmm128 xmm15,FgemmKernelFrame.SavedXmm15 - - END_PROLOGUE - -IFDIFI , - vzeroall -ENDIF - mov rsi,rcx - mov rbp,FgemmKernelFrame.CountN[rsp] - mov rax,FgemmKernelFrame.ldc[rsp] - shl rax,FgemmElementShift ; convert ldc to bytes - mov r10,FgemmKernelFrame.lda[rsp] - shl r10,FgemmElementShift ; convert lda to bytes - mov r11,FgemmKernelFrame.CountM[rsp] - movzx r15,BYTE PTR FgemmKernelFrame.ZeroMode[rsp] - - ENDM - -; -; Macro Description: -; -; This macro implements the common epilogue code for the SGEMM and DGEMM -; kernels. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; Implicit Arguments: -; -; r11d - Stores the number of rows handled. -; - -FgemmKernelExit MACRO Isa - - mov eax,r11d - movaps xmm6,FgemmKernelFrame.SavedXmm6[rsp] - movaps xmm7,FgemmKernelFrame.SavedXmm7[rsp] - movaps xmm8,FgemmKernelFrame.SavedXmm8[rsp] - movaps xmm9,FgemmKernelFrame.SavedXmm9[rsp] - movaps xmm10,FgemmKernelFrame.SavedXmm10[rsp] - movaps xmm11,FgemmKernelFrame.SavedXmm11[rsp] - movaps xmm12,FgemmKernelFrame.SavedXmm12[rsp] - movaps xmm13,FgemmKernelFrame.SavedXmm13[rsp] - movaps xmm14,FgemmKernelFrame.SavedXmm14[rsp] - movaps xmm15,FgemmKernelFrame.SavedXmm15[rsp] -IFIDNI , - mov r12,FgemmKernelFrame.SavedR12[rsp] - mov r13,FgemmKernelFrame.SavedR13[rsp] - mov r14,FgemmKernelFrame.SavedR14[rsp] -ENDIF - add rsp,(FgemmKernelFrame.SavedR15) - - BEGIN_EPILOGUE - - pop r15 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ComputeBlock - Supplies the macro to compute a single block. -; -; RowCount - Supplies the number of rows to access from matrix A. -; -; AdvanceMatrixAPlusRows - Supplies a non-zero value if the data pointer -; in rbx should also be advanced as part of the loop. 
-; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus N rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockLoop MACRO ComputeBlock, RowCount, AdvanceMatrixAPlusRows - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL OutputBlock - - mov rdi,r9 ; reload CountK - sub rdi,4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - ComputeBlock RowCount, 0, FgemmElementSize*0, 64*4 - ComputeBlock RowCount, 2*32, FgemmElementSize*1, 64*4 - add_immed rdx,2*2*32 ; advance matrix B by 128 bytes - ComputeBlock RowCount, 0, FgemmElementSize*2, 64*4 - ComputeBlock RowCount, 2*32, FgemmElementSize*3, 64*4 - add_immed rdx,2*2*32 ; advance matrix B by 128 bytes - add rcx,4*FgemmElementSize ; advance matrix A by 4 elements -IF AdvanceMatrixAPlusRows - add rbx,4*FgemmElementSize ; advance matrix A plus rows by 4 elements -IF RowCount GE 12 - add r13,4*FgemmElementSize - add r14,4*FgemmElementSize -ENDIF -ENDIF - sub rdi,4 - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add rdi,4 ; correct for over-subtract above - jz OutputBlock - -ComputeBlockBy1Loop: - ComputeBlock RowCount, 0, 0 - add rdx,2*32 ; advance matrix B by 64 bytes - add rcx,FgemmElementSize ; advance matrix A by 1 element -IF AdvanceMatrixAPlusRows - add rbx,FgemmElementSize ; advance matrix A plus rows by 1 element -IF RowCount GE 12 - add r13,FgemmElementSize - add r14,FgemmElementSize -ENDIF -ENDIF - dec rdi - jne ComputeBlockBy1Loop - -OutputBlock: - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/FgemmKernelFma3Common.inc b/onnxruntime/core/mlas/lib/amd64/FgemmKernelFma3Common.inc deleted file mode 100644 index 298b9dca81ee2..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/FgemmKernelFma3Common.inc +++ /dev/null @@ -1,503 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; FgemmKernelFma3Common.inc -; -; Abstract: -; -; This module implements the kernels for the floating point matrix/matrix -; multiply operation (SGEMM and DGEMM). -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - EXTERN MlasMaskMoveTableAvx:NEAR - -; -; Macro Description: -; -; This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output -; matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; PrefetchOffset - Optionally supplies the byte offset from matrix B to -; prefetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. 
-; - -ComputeBlockFma3By2 MACRO RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -IFNB - prefetcht0 [rdx+VectorOffset+PrefetchOffset] -ENDIF -IF RowCount EQ 1 - vbroadcastsf ymm3,FgemmElementPtr [rcx+BroadcastOffset] - vfmadd231pf ymm4,ymm3,YMMWORD PTR [rdx+VectorOffset] - vfmadd231pf ymm5,ymm3,YMMWORD PTR [rdx+VectorOffset+32] -ELSE - vmovapf ymm0,YMMWORD PTR [rdx+VectorOffset] - vmovapf ymm1,YMMWORD PTR [rdx+VectorOffset+32] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for 1 YMMWORD by N rows of the output -; matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; PrefetchOffset - Optionally supplies the byte offset from matrix B to -; prefetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockFma3By1 MACRO RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -IFNB - prefetcht0 [rdx+VectorOffset+PrefetchOffset] -ENDIF -IF RowCount EQ 1 - vbroadcastsf ymm3,FgemmElementPtr [rcx+BroadcastOffset] - vfmadd231pf ymm5,ymm3,YMMWORD PTR [rdx+VectorOffset] -ELSE - vmovapf ymm0,YMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ComputeBlock - Supplies the macro to compute a single block. -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. 
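The loop wrapper documented above, together with the shared ComputeBlockLoop macro it expands to, is ordinary loop unrolling over CountK: four block-compute steps per iteration, then a one-step remainder loop, advancing the A pointer by elements and the packed-B pointer by one packed block per step. In scalar C++ terms (illustrative names, not MLAS code):

    // Rough equivalent of ComputeBlockLoop: unroll the K loop by 4 with a remainder.
    void ComputeBlockLoopReference(size_t CountK, const float*& a, const float*& b,
                                   size_t packedBStride,
                                   void (*computeBlock)(const float*, const float*)) {
      size_t k = CountK;
      for (; k >= 4; k -= 4) {          // ComputeBlockBy4Loop
        for (int i = 0; i < 4; ++i) {
          computeBlock(a, b);
          a += 1;                       // advance matrix A by one element
          b += packedBStride;           // advance packed matrix B by one block
        }
      }
      for (; k > 0; --k) {              // ComputeBlockBy1Loop
        computeBlock(a, b);
        a += 1;
        b += packedBStride;
      }
    }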
-; - -ComputeBlockFma3Loop MACRO ComputeBlock, RowCount - -IF RowCount GT 3 - lea rbx,[r10*2+r10] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - ComputeBlockLoop ComputeBlock, RowCount, - vbroadcastsf ymm2,FgemmElementPtr FgemmKernelFrame.Alpha[rsp] -IF RowCount GT 3 - lea rbx,[rax*2+rax] - add rbx,r8 ; compute matrix C plus 3 rows -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Fallthrough - Supplies a non-blank value if the macro may fall through to -; the ExitKernelAndZeroUpper label. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; rsi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r8 - Supplies the address of matrix C. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r15 - Stores the ZeroMode argument from the stack frame. -; - -ProcessCountM MACRO RowCount, Fallthrough - - LOCAL ProcessNextColumnLoop2xN - LOCAL MultiplyAlpha2xNBlock - LOCAL Store2xNBlock - LOCAL ProcessRemainingCountN - LOCAL MultiplyAlpha1xNBlock - LOCAL Store1xNBlock - LOCAL OutputMasked2xNBlock - LOCAL MultiplyAlphaMasked2xNBlock - LOCAL StoreMasked2xNBlock - LOCAL OutputMasked1xNBlock - LOCAL MultiplyAlphaMasked1xNBlock - LOCAL StoreMasked1xNBlock - - cmp rbp,FgemmYmmElementCount - jbe ProcessRemainingCountN - -ProcessNextColumnLoop2xN: - ComputeBlockFma3Loop ComputeBlockFma3By2, RowCount - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - sub rbp,2*FgemmYmmElementCount - jb OutputMasked2xNBlock - test r15b,r15b ; ZeroMode? 
- jnz MultiplyAlpha2xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - jmp Store2xNBlock - -MultiplyAlpha2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - -Store2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - add r8,2*32 ; advance matrix C by 2 YMMWORDs - mov rcx,rsi ; reload matrix A - vzeroall - cmp rbp,FgemmYmmElementCount - ja ProcessNextColumnLoop2xN - test rbp,rbp - jz ExitKernel - -ProcessRemainingCountN: - ComputeBlockFma3Loop ComputeBlockFma3By1, RowCount - cmp rbp,FgemmYmmElementCount - jb OutputMasked1xNBlock - test r15b,r15b ; ZeroMode? - jnz MultiplyAlpha1xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp Store1xNBlock - -MultiplyAlpha1xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -Store1xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp ExitKernelAndZeroUpper - -OutputMasked2xNBlock: - test r15b,r15b ; ZeroMode? - jnz MultiplyAlphaMasked2xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp StoreMasked2xNBlock - -MultiplyAlphaMasked2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -StoreMasked2xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,32 ; advance matrix C by YMMWORD -IF RowCount GT 3 - add rbx,32 ; advance matrix C plus 3 rows by YMMWORD -ENDIF - add rbp,FgemmYmmElementCount ; correct for over-subtract above - -OutputMasked1xNBlock: - neg rbp - lea rcx,MlasMaskMoveTableAvx+8*4 - vmovdqu ymm0,YMMWORD PTR [rcx+rbp*FgemmElementSize] - test r15b,r15b ; ZeroMode? 
- jnz MultiplyAlphaMasked1xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp StoreMasked1xNBlock - -MultiplyAlphaMasked1xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -StoreMasked1xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, -IFB - jmp ExitKernelAndZeroUpper -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates the inner kernel to compute matrix multiplication. -; -; Arguments: -; -; Type - Supplies the element type string for function tags. -; - -FgemmKernelFma3Function MACRO Type - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; lda - Supplies the first dimension of matrix A. -; -; ldc - Supplies the first dimension of matrix C. -; -; Alpha - Supplies the scalar alpha multiplier (see GEMM definition). -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemm&Type&KernelFma3, _TEXT - - FgemmKernelEntry Fma3 - -; -; Process CountM rows of the matrices. -; - - cmp r11,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp r11,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp r11,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM6: - mov r11d,6 ; return 6 rows handled - ProcessCountM 6, Fallthrough - -; -; Restore non-volatile registers and return. -; - -ExitKernelAndZeroUpper: - vzeroupper - -ExitKernel: - FgemmKernelExit Fma3 - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM5: - ProcessCountM 5 - - NESTED_END MlasGemm&Type&KernelFma3, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/FgemmKernelSse2Common.inc b/onnxruntime/core/mlas/lib/amd64/FgemmKernelSse2Common.inc deleted file mode 100644 index 960263a642839..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/FgemmKernelSse2Common.inc +++ /dev/null @@ -1,156 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. 
-; -; Module Name: -; -; FgemmKernelSse2Common.inc -; -; Abstract: -; -; This module implements the kernels for the floating point matrix/matrix -; multiply operation (SGEMM and DGEMM). -; -; This implementation uses SSE2 instructions. -; -;-- - -; -; Macro Description: -; -; This stores the block accumulators to the output matrix with an optional -; accumulation of the existing contents of the output matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorCount - Supplies the number of vector columns to process. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; r8 - Supplies the address of matrix C. -; -; r15 - Stores the ZeroMode argument from the stack frame. -; -; xmm8-xmm15 - Supplies the block accumulators. -; - -AccumulateAndStoreBlock MACRO RowCount, VectorCount - - LOCAL SkipAccumulateOutput - - test r15b,r15b ; ZeroMode? - jnz SkipAccumulateOutput - EmitIfCount2GE RowCount, 1, VectorCount, 1, - EmitIfCount2GE RowCount, 1, VectorCount, 2, - EmitIfCount2GE RowCount, 1, VectorCount, 3, - EmitIfCount2GE RowCount, 1, VectorCount, 4, - EmitIfCount2GE RowCount, 2, VectorCount, 1, - EmitIfCount2GE RowCount, 2, VectorCount, 2, - EmitIfCount2GE RowCount, 2, VectorCount, 3, - EmitIfCount2GE RowCount, 2, VectorCount, 4, - EmitIfCount2GE RowCount, 1, VectorCount, 1, - EmitIfCount2GE RowCount, 1, VectorCount, 2, - EmitIfCount2GE RowCount, 1, VectorCount, 3, - EmitIfCount2GE RowCount, 1, VectorCount, 4, - EmitIfCount2GE RowCount, 2, VectorCount, 1, - EmitIfCount2GE RowCount, 2, VectorCount, 2, - EmitIfCount2GE RowCount, 2, VectorCount, 3, - EmitIfCount2GE RowCount, 2, VectorCount, 4, - -SkipAccumulateOutput: - EmitIfCount2GE RowCount, 1, VectorCount, 1, - EmitIfCount2GE RowCount, 1, VectorCount, 2, - EmitIfCount2GE RowCount, 1, VectorCount, 3, - EmitIfCount2GE RowCount, 1, VectorCount, 4, - EmitIfCount2GE RowCount, 2, VectorCount, 1, - EmitIfCount2GE RowCount, 2, VectorCount, 2, - EmitIfCount2GE RowCount, 2, VectorCount, 3, - EmitIfCount2GE RowCount, 2, VectorCount, 4, - - ENDM - -; -; Macro Description: -; -; This macro generates the inner kernel to compute matrix multiplication. -; -; Arguments: -; -; Type - Supplies the element type string for function tags. -; - -FgemmKernelSse2Function MACRO Type - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9d) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; lda - Supplies the first dimension of matrix A. -; -; ldc - Supplies the first dimension of matrix C. -; -; Alpha - Supplies the scalar alpha multiplier (see SGEMM definition). -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. 
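Across all of these kernels the Alpha and ZeroMode arguments described above have the same per-element meaning: ZeroMode selects between overwriting C with the scaled dot product and accumulating into whatever C already holds (the beta = 0 versus beta = 1 cases of the GEMM definition). Roughly, per output element (illustrative):

    // ZeroMode semantics per element of C.
    inline void StoreElement(float& c, float dot, float alpha, bool ZeroMode) {
      if (ZeroMode) {
        c = alpha * dot;    // output matrix is zero initialized: overwrite
      } else {
        c += alpha * dot;   // otherwise accumulate into the existing contents
      }
    }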
-; -;-- - - NESTED_ENTRY MlasGemm&Type&KernelSse, _TEXT - - FgemmKernelEntry Sse - -; -; Process CountM rows of the matrices. -; - - cmp r11,2 - jb ProcessCountM1 - mov r11d,2 ; return 2 rows handled - ProcessCountM 2, Fallthrough - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - FgemmKernelExit Sse - -ProcessCountM1: - ProcessCountM 1 - - NESTED_END MlasGemm&Type&KernelSse, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/LogisticKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/LogisticKernelFma3.asm deleted file mode 100644 index e50a99baf12fd..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/LogisticKernelFma3.asm +++ /dev/null @@ -1,157 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; LogisticKernelFma3.asm -; -; Abstract: -; -; This module implements a kernel for computing the logistic function for a -; buffer of elements. -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE TransKernelCommon.inc - .list - - EXTERN MlasMaskMoveTableAvx:NEAR - EXTERN MlasLogisticConstants:NEAR - -;++ -; -; Routine Description: -; -; This routine implements the a vectorized kernel for the logistic function. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Supplies the output buffer. -; -; N (r8) - Supplies the number of elements to process. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasComputeLogisticF32KernelFma3, _TEXT - - alloc_stack (TransKernelFrame.ReturnAddress) - - save_xmm128 xmm6,TransKernelFrame.SavedXmm6 - save_xmm128 xmm7,TransKernelFrame.SavedXmm7 - save_xmm128 xmm8,TransKernelFrame.SavedXmm8 - save_xmm128 xmm9,TransKernelFrame.SavedXmm9 - save_xmm128 xmm10,TransKernelFrame.SavedXmm10 - save_xmm128 xmm11,TransKernelFrame.SavedXmm11 - save_xmm128 xmm12,TransKernelFrame.SavedXmm12 - save_xmm128 xmm13,TransKernelFrame.SavedXmm13 - save_xmm128 xmm14,TransKernelFrame.SavedXmm14 - save_xmm128 xmm15,TransKernelFrame.SavedXmm15 - - END_PROLOGUE - - lea rax,MlasLogisticConstants - vbroadcastss ymm4,LogisticConstants.LowerRange[rax] - vbroadcastss ymm5,LogisticConstants.UpperRange[rax] - vbroadcastss ymm6,LogisticConstants.alpha_9[rax] - vbroadcastss ymm7,LogisticConstants.alpha_7[rax] - vbroadcastss ymm8,LogisticConstants.alpha_5[rax] - vbroadcastss ymm9,LogisticConstants.alpha_3[rax] - vbroadcastss ymm10,LogisticConstants.alpha_1[rax] - vbroadcastss ymm11,LogisticConstants.beta_10[rax] - vbroadcastss ymm12,LogisticConstants.beta_6[rax] - vbroadcastss ymm13,LogisticConstants.beta_4[rax] - vbroadcastss ymm14,LogisticConstants.beta_2[rax] - vbroadcastss ymm15,LogisticConstants.beta_0[rax] - - sub r8,8 - jb ProcessRemainingCount - -ComputeLogisticBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rcx] ; clamp lower bound - vmovaps ymm2,ymm7 - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vmulps ymm1,ymm0,ymm0 ; x2 - vbroadcastss ymm3,LogisticConstants.beta_8[rax] - vfmadd231ps ymm2,ymm1,ymm6 ; p = x2 * alpha_9 + alpha_7 - vfmadd213ps ymm2,ymm1,ymm8 ; p = x2 * p + alpha_5 - vfmadd213ps ymm2,ymm1,ymm9 ; p = x2 * p + alpha_3 - vfmadd213ps ymm2,ymm1,ymm10 ; p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm11 ; q = x2 * beta_10 + beta_8 - vfmadd213ps ymm3,ymm1,ymm12 ; q = x2 * q + beta_6 - vfmadd213ps ymm3,ymm1,ymm13 ; q = x2 * q + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0 - vmulps ymm2,ymm0,ymm2 ; p = x * p - 
vbroadcastss ymm0,LogisticConstants.one_half[rax] - vdivps ymm2,ymm2,ymm3 - vxorps ymm3,ymm3,ymm3 - vaddps ymm0,ymm2,ymm0 ; logistic = p / q + 0.5 - vmaxps ymm0,ymm3,ymm0 ; clamp lower bound - add rcx,8*4 ; advance input by 8 elements - vmovups YMMWORD PTR [rdx],ymm0 - add rdx,8*4 ; advance output by 8 elements - sub r8,8 - jae ComputeLogisticBy8Loop - -ProcessRemainingCount: - add r8,8 ; correct for over-subtract above - jz ExitKernel - neg r8 - lea r10,MlasMaskMoveTableAvx+8*4 - vmovups ymm2,YMMWORD PTR [r10+r8*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rcx] - vmaxps ymm0,ymm4,ymm0 ; clamp lower bound - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vmulps ymm1,ymm0,ymm0 ; x2 - vbroadcastss ymm3,LogisticConstants.beta_8[rax] - vfmadd231ps ymm7,ymm1,ymm6 ; p = x2 * alpha_9 + alpha_7 - vfmadd213ps ymm7,ymm1,ymm8 ; p = x2 * p + alpha_5 - vfmadd213ps ymm7,ymm1,ymm9 ; p = x2 * p + alpha_3 - vfmadd213ps ymm7,ymm1,ymm10 ; p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm11 ; q = x2 * beta_10 + beta_8 - vfmadd213ps ymm3,ymm1,ymm12 ; q = x2 * q + beta_6 - vfmadd213ps ymm3,ymm1,ymm13 ; q = x2 * q + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0 - vmulps ymm7,ymm0,ymm7 ; p = x * p - vbroadcastss ymm0,LogisticConstants.one_half[rax] - vdivps ymm7,ymm7,ymm3 - vxorps ymm3,ymm3,ymm3 - vaddps ymm0,ymm7,ymm0 ; logistic = p / q + 0.5 - vmaxps ymm0,ymm3,ymm0 ; clamp lower bound - vmaskmovps YMMWORD PTR [rdx],ymm2,ymm0 - -ExitKernel: - vzeroupper - movaps xmm6,TransKernelFrame.SavedXmm6[rsp] - movaps xmm7,TransKernelFrame.SavedXmm7[rsp] - movaps xmm8,TransKernelFrame.SavedXmm8[rsp] - movaps xmm9,TransKernelFrame.SavedXmm9[rsp] - movaps xmm10,TransKernelFrame.SavedXmm10[rsp] - movaps xmm11,TransKernelFrame.SavedXmm11[rsp] - movaps xmm12,TransKernelFrame.SavedXmm12[rsp] - movaps xmm13,TransKernelFrame.SavedXmm13[rsp] - movaps xmm14,TransKernelFrame.SavedXmm14[rsp] - movaps xmm15,TransKernelFrame.SavedXmm15[rsp] - add rsp,(TransKernelFrame.ReturnAddress) - - BEGIN_EPILOGUE - - ret - - NESTED_END MlasComputeLogisticF32KernelFma3, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAmx.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAmx.asm deleted file mode 100644 index 4dd292ff1a21c..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAmx.asm +++ /dev/null @@ -1,554 +0,0 @@ -;++ -; -;Copyright (c) Microsoft Corporation. All rights reserved. -; -;Licensed under the MIT License. -; -;Module Name: -; -; QgemmU8S8KernelAmx.asm -; -;Abstract: -; -; This module implements the packing functions for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; These packing functions are suited for AMX Qgemm kernel. The implementation only -; uses AVX2 instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - -; -; Stack frame layout for the U8S8 CopyPackB routine. -; - -GemmU8S8CopyPackBFrame STRUCT - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountK QWORD ? - ColumnSumBuffer QWORD ? - BIsSigned QWORD ? - -GemmU8S8CopyPackBFrame ENDS - -;++ -; -; Routine Description: -; -; This routine copies elements from the source matrix to the destination -; packed buffer. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; B (rdx) - Supplies the address of the source matrix. 
-; -; ldb (r8) - Supplies the number of elements per row of the source matrix. -; -; CountN (r9) - Supplies the number of columns of the source matrix to copy. -; -; CountK - Supplies the number of rows of the source matrix to copy. -; -; ColumnSumBuffer - Supplies the address of the buffer to receive the sums of -; the elements along each of the columns. -; -; BIsSigned - Supplies true if the source matrix is signed data, else false -; if the source matrix is unsigned data. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasGemmU8S8CopyPackBAmx, _TEXT - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - END_PROLOGUE - - mov rsi,rdx ; Save B - lea rdi,[r8+r8*2] ; compute ldb * 3 - mov r10,GemmU8S8CopyPackBFrame.CountK[rsp] - mov r11,GemmU8S8CopyPackBFrame.ColumnSumBuffer[rsp] - lea r12,[r10+3] ; compute extra padding for 64|K - shr r12,2 - neg r12 - and r12,15 - vpcmpeqw ymm0,ymm0,ymm0 ; generate word vector [0xFFFF] - vpsrlw ymm0,ymm0,15 ; generate word vector [0x0001] - vpsllw ymm1,ymm0,8 ; generate word vector [0x0100] - vpor ymm1,ymm0,ymm1 ; generate word vector [0x0101] - -; -; Compute the bit flip vector to adjust input from U8 to S8. -; - - vpxor xmm2,xmm2,xmm2 ; generate word vector [0x0000] - cmp BYTE PTR GemmU8S8CopyPackBFrame.BIsSigned[rsp],0 - jnz CopyPackB_SkipUnsignedBitFlipVector - vpsllw ymm2,ymm1,7 ; generate word vector [0x8080] - -CopyPackB_SkipUnsignedBitFlipVector: - -; -; Process 16 columns of matrix B in a loop. -; - - sub r9,16 ; CountN -= 16 - jb CopyPackB_ProcessRemainingColumns - -CopyPackB_ProcessNextColumnN16: - vpxord xmm30,xmm30,xmm30 ; clear column accumulators - vpxord xmm31,xmm31,xmm31 - mov rdx,rsi ; rdx -> B start of 16 columns - add rsi,16 ; advance next matrix B by 16 columns - mov rbx,r10 ; reload rows remaining - sub rbx,4 - jb CopyPackB_ProcessRemainingRowsN16 - -CopyPackB_ProcessNextRowLoopN16: - vmovdqu64 xmm16,XMMWORD PTR [rdx] ; load 4 rows - vmovdqu64 xmm17,XMMWORD PTR [rdx+r8] - vmovdqu64 xmm18,XMMWORD PTR [rdx+r8*2] - vmovdqu64 xmm19,XMMWORD PTR [rdx+rdi] - lea rdx,[rdx+r8*4] ; advance matrix B by 4 rows - -CopyPackB_InterleaveRowDataN16: - vpunpcklbw xmm3,xmm16,xmm17 ; interleave row data - vpunpckhbw xmm17,xmm16,xmm17 - vpunpcklbw xmm16,xmm18,xmm19 - vpunpckhbw xmm19,xmm18,xmm19 - vpunpcklwd xmm18,xmm3,xmm16 - vpunpckhwd xmm3,xmm3,xmm16 - vpunpcklwd xmm16,xmm17,xmm19 - vpunpckhwd xmm17,xmm17,xmm19 - vinserti64x2 ymm18,ymm18,xmm3,1 - vinserti64x2 ymm16,ymm16,xmm17,1 - vpxord ymm18,ymm18,ymm2 ; optionally adjust unsigned data - vpxord ymm16,ymm16,ymm2 - vmovdqu64 YMMWORD PTR [rcx],ymm18 ; store interleaved rows - vmovdqu64 YMMWORD PTR [rcx+32],ymm16 - vpmaddubsw ymm18,ymm1,ymm18 ; horizontal byte+byte=word per row - vpmaddwd ymm18,ymm18,ymm0 ; horizontal word+word=dword per row - vpaddd ymm30,ymm30,ymm18 ; accumulate per column - vpmaddubsw ymm16,ymm1,ymm16 - vpmaddwd ymm16,ymm16,ymm0 - vpaddd ymm31,ymm31,ymm16 - add rcx,64 ; advance matrix D by 64 bytes - sub rbx,4 ; subtract rows remaining - jae CopyPackB_ProcessNextRowLoopN16 - -; -; Process the less than 4 remaining rows where the row has 16 columns. -; - -CopyPackB_ProcessRemainingRowsN16: - add rbx,4 ; correct for over-subtract above - jz CopyPackB_StoreColumnSumBufferN16 - vmovdqu64 xmm16,XMMWORD PTR [rdx] - vmovaps xmm17,xmm2 - vmovaps xmm18,xmm2 - vmovaps xmm19,xmm2 - xor ebx,ebx ; no more rows remaining - test r10b,2 ; (CountK & 2) != 0? - jz CopyPackB_InterleaveRowDataN16 - vmovdqu64 xmm17,XMMWORD PTR [rdx+r8] - test r10b,1 ; (CountK & 1) != 0? 
- jz CopyPackB_InterleaveRowDataN16 - vmovdqu64 xmm18,XMMWORD PTR [rdx+r8*2] - jmp CopyPackB_InterleaveRowDataN16 - -CopyPackB_StoreColumnSumBufferN16: - vmovdqu64 YMMWORD PTR [r11],ymm30 - vmovdqu64 YMMWORD PTR [r11+32],ymm31 - test r12,r12 - jz CopyPackB_N16K64PaddingFinished - mov rax, r12 - vpxord xmm30,xmm30,xmm30 - -CopyPackB_N16K64Padding: - vmovdqu64 YMMWORD PTR [rcx],ymm30 ; store 0 - vmovdqu64 YMMWORD PTR [rcx+32],ymm30 - add rcx,64 - dec rax - jnz CopyPackB_N16K64Padding - -CopyPackB_N16K64PaddingFinished: - add r11,16*4 ; advance column sum buffer by 16 dwords - sub r9,16 ; subtract columns remaining - jae CopyPackB_ProcessNextColumnN16 - -CopyPackB_ProcessRemainingColumns: - add r9,16 ; correct for over-subtract above - jnz CopyPackB_ProcessColumnNUnaligned - -; -; Restore non-volatile registers and return. -; - -CopyPackB_ExitRoutine: - vzeroupper - - BEGIN_EPILOGUE - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -; -; Process the remaining columns of matrix B. -; - -CopyPackB_ProcessColumnNUnaligned: - vpxord xmm30,xmm30,xmm30 ; clear column accumulators - vpxord xmm31,xmm31,xmm31 - mov rax,rcx ; save rcx (D) - mov rcx,r9 ; load left over N - neg ecx ; compute load mask for left over N - and ecx,63 - mov rbx,-1 - shr rbx,cl - kmovq k1,rbx - mov rcx,rax ; restore rcx (D) - sub r10,4 - jb CopyPackB_ProcessRemainingRowsNUnaligned - -CopyPackB_ProcessNextRowLoopNUnaligned: - vmovdqu64 xmm16,xmm2 - vmovdqu8 xmm16 {k1},XMMWORD PTR [rsi] - vmovdqu64 xmm17,xmm2 - vmovdqu8 xmm17 {k1},XMMWORD PTR [rsi+r8] - vmovdqu64 xmm18,xmm2 - vmovdqu8 xmm18 {k1},XMMWORD PTR [rsi+r8*2] - vmovdqu64 xmm19,xmm2 - vmovdqu8 xmm19 {k1},XMMWORD PTR [rsi+rdi] - lea rsi,[rsi+r8*4] ; advance next matrix B by 4 rows - -CopyPackB_InterleaveRowDataUnaligned: - vpunpcklbw xmm3,xmm16,xmm17 ; interleave row data - vpunpckhbw xmm17,xmm16,xmm17 - vpunpcklbw xmm16,xmm18,xmm19 - vpunpckhbw xmm19,xmm18,xmm19 - vpunpcklwd xmm18,xmm3,xmm16 - vpunpckhwd xmm3,xmm3,xmm16 - vpunpcklwd xmm16,xmm17,xmm19 - vpunpckhwd xmm17,xmm17,xmm19 - vinserti64x2 ymm18,ymm18,xmm3,1 - vinserti64x2 ymm16,ymm16,xmm17,1 - vpxord ymm18,ymm18,ymm2 ; optionally adjust unsigned data - vpxord ymm16,ymm16,ymm2 - vmovdqu64 YMMWORD PTR [rcx],ymm18 ; store interleaved rows - vmovdqu64 YMMWORD PTR [rcx+32],ymm16 - vpmaddubsw ymm18,ymm1,ymm18 ; horizontal byte+byte=word per row - vpmaddwd ymm18,ymm18,ymm0 ; horizontal word+word=dword per row - vpaddd ymm30,ymm30,ymm18 ; accumulate per column - vpmaddubsw ymm16,ymm1,ymm16 - vpmaddwd ymm16,ymm16,ymm0 - vpaddd ymm31,ymm31,ymm16 - add rcx,64 ; advance matrix D by 64 bytes - sub r10,4 ; subtract rows remaining - jae CopyPackB_ProcessNextRowLoopNUnaligned - -; -; Process the less than 4 remaining rows where the row has less than 16 columns. -; - -CopyPackB_ProcessRemainingRowsNUnaligned: - add r10,4 - jz CopyPackB_StoreColumnSumBufferNUnaligned - - vmovaps xmm16,xmm2 - vmovdqu8 xmm16 {k1},XMMWORD PTR [rsi] - vmovaps xmm17,xmm2 - vmovaps xmm18,xmm2 - vmovaps xmm19,xmm2 - mov rbx,r10 - xor r10b,r10b ; no more rows remaining - test bl,2 ; (CountK & 2) != 0? - jz CopyPackB_InterleaveRowDataUnaligned - vmovdqu8 xmm17 {k1},XMMWORD PTR [rsi+r8] - test bl,1 ; (CountK & 1) != 0? 
- jz CopyPackB_InterleaveRowDataUnaligned - vmovdqu8 xmm18 {k1},XMMWORD PTR [rsi+r8*2] - jmp CopyPackB_InterleaveRowDataUnaligned - -CopyPackB_StoreColumnSumBufferNUnaligned: - vmovdqu64 YMMWORD PTR [r11],ymm30 - vmovdqu64 YMMWORD PTR [r11+32],ymm31 - test r12,r12 - jz CopyPackB_ExitRoutine - mov rax, r12 - vpxord xmm30,xmm30,xmm30 - -CopyPackB_K64Padding: - vmovdqu64 YMMWORD PTR [rcx],ymm30 ; store 0 - vmovdqu64 YMMWORD PTR [rcx+32],ymm30 - add rcx,64 - dec rax - jne CopyPackB_K64Padding - jmp CopyPackB_ExitRoutine - - NESTED_END MlasGemmU8S8CopyPackBAmx, _TEXT - - -; -; Stack frame layout for the U8S8 CopyPackA routine. -; - -GemmU8S8CopyPackAFrame STRUCT - - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountK QWORD ? - RowSumBuffer QWORD ? - -GemmU8S8CopyPackAFrame ENDS - -;++ -; -; Routine Description: -; -; This routine copies elements from the source matrix to the destination -; packed buffer. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; A (rdx) - Supplies the address of the source matrix. -; -; lda (r8) - Supplies the number of elements per row of the source matrix. -; -; CountM (r9) - Supplies the number of rows of the source matrix to copy. -; -; CountK - Supplies the number of columns of the source matrix to copy. -; -; RowSumBuffer - Supplies the address of the buffer to receive the sums of -; the elements along each of the rows. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasGemmU8S8CopyPackAAmx, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - END_PROLOGUE - - mov rdi,rcx ; save D - mov rsi,rdx ; save A - mov r10,GemmU8S8CopyPackAFrame.CountK[rsp] - mov r12,GemmU8S8CopyPackAFrame.RowSumBuffer[rsp] - lea r11,[r10+63] - and r11,NOT 63 ; align CountK up to 64 - vpternlogd zmm30,zmm30,zmm30,255 ; generate word vector [0xFFFF] - vpsrlw zmm30,zmm30,15 ; generate word vector [0x0001] - vpsllw zmm31,zmm30,8 ; generate word vector [0x0100] - vpord zmm31,zmm30,zmm31 ; generate word vector [0x0101] - lea r13,[r8+r8*2] ; compute ldb * 3 - lea rax,[r11+r11*2] ; compute AlignedCountK * 3 - mov ecx,r10d ; CountK - neg ecx - and ecx,63 - mov rbx,-1 - shr rbx,cl ; mask for left over k < 64 - kmovq k1,rbx ; mask - -; -; Process 4 rows of matrix A in a loop. 
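Reviewer note on the CopyPackB routines deleted above: besides writing the interleaved packed data, they accumulate one int32 sum per column of B and, when the source is unsigned, re-bias it to signed by XORing with 0x80 (the 0x8080 word vector built in the prologue). A minimal scalar sketch of just that sum/bias logic follows; the helper name and the flat row-major layout are hypothetical, and the real 4-row interleave and K padding are omitted.

    #include <cstddef>
    #include <cstdint>

    // Scalar sketch: per-column sums computed while packing B, with the optional
    // unsigned->signed re-bias (x ^ 0x80 reinterpreted as int8 equals x - 128).
    void ReferencePackBColumnSums(const uint8_t* B, size_t ldb,
                                  size_t CountK, size_t CountN,
                                  int32_t* ColumnSumBuffer, bool BIsSigned) {
        const uint8_t Flip = BIsSigned ? 0x00 : 0x80;  // mirrors the 0x8080 vector
        for (size_t n = 0; n < CountN; n++) {
            int32_t Sum = 0;
            for (size_t k = 0; k < CountK; k++) {
                Sum += static_cast<int8_t>(B[k * ldb + n] ^ Flip);
            }
            ColumnSumBuffer[n] = Sum;
        }
    }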
-; - - sub r9,4 ; m -= 4 - jb CopyPackA_ProcessRemainingRows - -CopyPackA_ProcessNextRowM4: - vpxor xmm0,xmm0,xmm0 ; clear row accumulators - vpxor xmm1,xmm1,xmm1 - vpxor xmm2,xmm2,xmm2 - vpxor xmm3,xmm3,xmm3 - mov rdx,rsi ; src = A row beginning - mov rcx,rdi ; dst = D row beginning - lea rsi,[rsi+r8*4] ; advance next matrix A by 4 rows - lea rdi,[rdi+r11*4] ; advance next matrix D by 4 rows - mov rbx,r10 ; k = CountK - sub rbx,64 - jb CopyPackA_ProcessRemainingColumnsM4 - -CopyPackA_ProcessNextColumnLoopM4: - vmovdqu64 zmm16,ZMMWORD PTR [rdx] - vmovdqu64 zmm17,ZMMWORD PTR [rdx+r8] - vmovdqu64 zmm18,ZMMWORD PTR [rdx+r8*2] - vmovdqu64 zmm19,ZMMWORD PTR [rdx+r13] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vmovdqu64 ZMMWORD PTR [rcx+r11],zmm17 - vmovdqu64 ZMMWORD PTR [rcx+r11*2],zmm18 - vmovdqu64 ZMMWORD PTR [rcx+rax],zmm19 - vpmaddubsw zmm16,zmm16,zmm31 ; horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 ; add words to row accumulators - vpmaddubsw zmm17,zmm17,zmm31 - vpaddw zmm1,zmm1,zmm17 - vpmaddubsw zmm18,zmm18,zmm31 - vpaddw zmm2,zmm2,zmm18 - vpmaddubsw zmm19,zmm19,zmm31 - vpaddw zmm3,zmm3,zmm19 - add rdx,64 ; src += 64 - add rcx,64 ; dst += 64 - sub rbx,64 ; k -= 64 - jae CopyPackA_ProcessNextColumnLoopM4 - -CopyPackA_ProcessRemainingColumnsM4: - add rbx,64 ; correct for over-subtract above - jz CopyPackA_ReduceRowSumBufferM4 - - vmovdqu8 zmm16{k1}{z},ZMMWORD PTR [rdx] - vmovdqu8 zmm17{k1}{z},ZMMWORD PTR [rdx+r8] - vmovdqu8 zmm18{k1}{z},ZMMWORD PTR [rdx+r8*2] - vmovdqu8 zmm19{k1}{z},ZMMWORD PTR [rdx+r13] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vmovdqu64 ZMMWORD PTR [rcx+r11],zmm17 - vmovdqu64 ZMMWORD PTR [rcx+r11*2],zmm18 - vmovdqu64 ZMMWORD PTR [rcx+rax],zmm19 - vpmaddubsw zmm16,zmm16,zmm31 ; horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 ; add words to row accumulators - vpmaddubsw zmm17,zmm17,zmm31 - vpaddw zmm1,zmm1,zmm17 - vpmaddubsw zmm18,zmm18,zmm31 - vpaddw zmm2,zmm2,zmm18 - vpmaddubsw zmm19,zmm19,zmm31 - vpaddw zmm3,zmm3,zmm19 - -; -; Reduce the sums for the four rows of output. -; - -CopyPackA_ReduceRowSumBufferM4: - vpmaddwd zmm0,zmm0,zmm30 ; horizontal word+word=dword per row - vpmaddwd zmm1,zmm1,zmm30 - vpmaddwd zmm2,zmm2,zmm30 - vpmaddwd zmm3,zmm3,zmm30 - vextracti64x4 ymm16,zmm0,1 ; fold zmm -> ymm - vextracti64x4 ymm17,zmm1,1 - vextracti64x4 ymm18,zmm2,1 - vextracti64x4 ymm19,zmm3,1 - vpaddd ymm0,ymm0,ymm16 - vpaddd ymm1,ymm1,ymm17 - vpaddd ymm2,ymm2,ymm18 - vpaddd ymm3,ymm3,ymm19 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0 - vphaddd ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0 - vextracti128 xmm1,ymm0,1 ; fold ymm -> xmm - vpaddd xmm0,xmm0,xmm1 - vmovdqu XMMWORD PTR [r12],xmm0 - add r12,4*4 ; advance row sum buffer by 4 dwords - sub r9,4 ; m -= 4 - jae CopyPackA_ProcessNextRowM4 - -CopyPackA_ProcessRemainingRows: - add r9,4 ; correct for over-subtract above - jz CopyPackA_ExitRoutine - -; -; Process a single row of matrix A in a loop. 
-; - -CopyPackA_ProcessNextRowM1: - vpxor xmm0,xmm0,xmm0 ; clear row accumulator - mov rdx,rsi ; src = A - mov rcx,rdi ; dst = D - add rsi,r8 ; A to next row - add rdi,r11 ; D to next row - mov rbx,r10 ; k = CountK - sub rbx,64 ; k -= 64 - jb CopyPackA_ProcessRemainingColumnsM1 - -CopyPackA_ProcessNextColumnLoopM1: - vmovdqu64 zmm16,ZMMWORD PTR [rdx] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vpmaddubsw zmm16,zmm16,zmm31 ; horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 ; add words to row accumulators - add rdx,64 ; src += 64 - add rcx,64 ; dst += 64 - sub rbx,64 ; k -= 64 - jae CopyPackA_ProcessNextColumnLoopM1 - -CopyPackA_ProcessRemainingColumnsM1: - add rbx,64 ; correct for over-subtract above - jz CopyPackA_ReduceRowSumBufferM1 - - vmovdqu8 zmm16{k1}{z},ZMMWORD PTR [rdx] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vpmaddubsw zmm16,zmm16,zmm31 ; horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 ; add words to row accumulators - -; -; Reduce the sum for the single row of output. -; - -CopyPackA_ReduceRowSumBufferM1: - vpmaddwd zmm0,zmm0,zmm30 ; horizontal word+word=dword per row - vextracti64x4 ymm16,zmm0,1 ; fold zmm -> ymm - vpaddd ymm0,ymm0,ymm16 - vextracti128 xmm1,ymm0,1 ; fold ymm -> xmm - vpaddd xmm0,xmm0,xmm1 ; reduction - vphaddd xmm0,xmm0,xmm0 - vphaddd xmm0,xmm0,xmm0 - vmovd DWORD PTR [r12],xmm0 - add r12,4 ; advance row sum buffer by 1 dword - dec r9 ; decrement rows remaining - jnz CopyPackA_ProcessNextRowM1 - -; -; Restore non-volatile registers and return. -; - -CopyPackA_ExitRoutine: - vzeroupper - - BEGIN_EPILOGUE - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasGemmU8S8CopyPackAAmx, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm deleted file mode 100644 index bc5ac69e730d1..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8S8KernelAvx2.asm +++ /dev/null @@ -1,866 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8S8KernelAvx2.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX2 instructions. -; Support for AVX-VNNI-INT8 for certain code paths. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE AssembleAvxVnni.inc - .list - - EXTERN MlasMaskMoveTableAvx:NEAR - -; -; Stack frame layout for the Int8 CopyPackA routine. -; - -GemmInt8CopyPackAFrame STRUCT - - PaddedMatrixAData OWORD 4 DUP (?) - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - Padding QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountK QWORD ? - RowSumBuffer QWORD ? - -GemmInt8CopyPackAFrame ENDS - -; -; Stack frame layout for the Int8 CopyPackB routine. -; - -GemmInt8CopyPackBFrame STRUCT - - PaddedMatrixBData OWORD 4 DUP (?) - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - Padding QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountK QWORD ? - ColumnSumBuffer QWORD ? - BIsSigned QWORD ? 
- -GemmInt8CopyPackBFrame ENDS - -;++ -; -; Routine Description: -; -; This routine copies elements from the source matrix to the destination -; packed buffer. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; A (rdx) - Supplies the address of the source matrix. -; -; lda (r8) - Supplies the number of elements per row of the source matrix. -; -; CountM (r9) - Supplies the number of rows of the source matrix to copy. -; -; CountK - Supplies the number of columns of the source matrix to copy. -; -; RowSumBuffer - Supplies the address of the buffer to receive the sums of -; the elements along each of the rows. -; -; Return Value: -; -; None. -; -;-- - -MlasGemmCopyPackAAvx2 MACRO ASigned - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (GemmInt8CopyPackAFrame.SavedR13) - save_xmm128 xmm6,GemmInt8CopyPackAFrame.SavedXmm6 - save_xmm128 xmm7,GemmInt8CopyPackAFrame.SavedXmm7 - save_xmm128 xmm8,GemmInt8CopyPackAFrame.SavedXmm8 - save_xmm128 xmm9,GemmInt8CopyPackAFrame.SavedXmm9 - save_xmm128 xmm10,GemmInt8CopyPackAFrame.SavedXmm10 - - END_PROLOGUE - - mov rdi,rcx - mov rsi,rdx - mov r10,GemmInt8CopyPackAFrame.CountK[rsp] - lea r11,[r10+3] - and r11,NOT 3 ; align CountK up to quad count - mov r12,GemmInt8CopyPackAFrame.RowSumBuffer[rsp] - vpcmpeqw ymm8,ymm8,ymm8 ; generate word vector [0xFFFF] - vpsrlw ymm8,ymm8,15 ; generate word vector [0x0001] - vpsllw ymm9,ymm8,8 ; generate word vector [0x0100] - vpor ymm9,ymm8,ymm9 ; generate word vector [0x0101] - -; -; Compute the conditional load/store mask for an unaligned CountK. -; - - mov eax,r10d - and eax,15 ; isolate unaligned count - add eax,3 - shr eax,2 ; align unaligned count to quad count - neg rax - lea rbx,MlasMaskMoveTableAvx+8*4 - vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] - -; -; Zero initialize the padded stack buffers. -; - - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0 - -; -; Process 4 rows of matrix A in a loop. 
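Reviewer note on the MlasGemmCopyPackAAvx2 macro deleted above: while copying A it produces one sum per row, and the ASigned macro parameter selects whether the bytes are treated as signed (AVX-VNNI-INT8 dot products) or unsigned (the vpmaddubsw/vpmaddwd pair against the broadcast 0x0101/0x0001 constants). A scalar equivalent of those row sums is sketched below; the helper name is hypothetical and a dense row-major A is assumed.

    #include <cstddef>
    #include <cstdint>

    // Scalar sketch of the per-row sums produced while packing A. ASigned mirrors
    // the macro parameter: signed bytes for the VNNI-INT8 variant, unsigned bytes
    // for the plain AVX2 variant.
    void ReferencePackARowSums(const uint8_t* A, size_t lda,
                               size_t CountM, size_t CountK,
                               int32_t* RowSumBuffer, bool ASigned) {
        for (size_t m = 0; m < CountM; m++) {
            int32_t Sum = 0;
            for (size_t k = 0; k < CountK; k++) {
                const uint8_t Value = A[m * lda + k];
                Sum += ASigned ? static_cast<int32_t>(static_cast<int8_t>(Value))
                               : static_cast<int32_t>(Value);
            }
            RowSumBuffer[m] = Sum;
        }
    }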
-; - - sub r9,4 - jb ProcessRemainingRows - -ProcessNextRowM4: - vpxor xmm0,xmm0,xmm0 ; clear row accumulators - vpxor xmm1,xmm1,xmm1 - vpxor xmm2,xmm2,xmm2 - vpxor xmm3,xmm3,xmm3 - lea r13,[r8+r8*2] ; compute lda * 3 - lea rax,[r11+r11*2] ; compute output stride * 3 - mov rdx,rsi - mov rcx,rdi - lea rsi,[rsi+r8*4] ; advance next matrix A by 4 rows - lea rdi,[rdi+r11*4] ; advance next matrix D by 4 rows - mov rbx,r10 ; reload columns remaining - sub rbx,32 - jb ProcessRemainingColumnsM4 - -ProcessNextColumnLoopM4: - vmovdqu ymm4,YMMWORD PTR [rdx] - vmovdqu ymm5,YMMWORD PTR [rdx+r8] - vmovdqu ymm6,YMMWORD PTR [rdx+r8*2] - vmovdqu ymm7,YMMWORD PTR [rdx+r13] - vmovdqu YMMWORD PTR [rcx],ymm4 - vmovdqu YMMWORD PTR [rcx+r11],ymm5 - vmovdqu YMMWORD PTR [rcx+r11*2],ymm6 - vmovdqu YMMWORD PTR [rcx+rax],ymm7 -IF ASigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 - VpdpbssdYmmYmmYmm ymm1,ymm5,ymm9 - VpdpbssdYmmYmmYmm ymm2,ymm6,ymm9 - VpdpbssdYmmYmmYmm ymm3,ymm7,ymm9 -ELSE - vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators - vpmaddubsw ymm5,ymm5,ymm9 - vpaddw ymm1,ymm1,ymm5 - vpmaddubsw ymm6,ymm6,ymm9 - vpaddw ymm2,ymm2,ymm6 - vpmaddubsw ymm7,ymm7,ymm9 - vpaddw ymm3,ymm3,ymm7 -ENDIF - add rdx,32 ; advance matrix A by 32 bytes - add rcx,32 ; advance matrix D by 32 bytes - sub rbx,32 ; subtract columns remaining - jae ProcessNextColumnLoopM4 - -ProcessRemainingColumnsM4: - add rbx,32 ; correct for over-subtract above - jz ReduceRowSumBufferM4 - test bl,16 ; (CountK & 16) != 0? - jz CopyRemainingCountKLessThan16M4 - vmovdqu xmm4,XMMWORD PTR [rdx] - vmovdqu xmm5,XMMWORD PTR [rdx+r8] - vmovdqu xmm6,XMMWORD PTR [rdx+r8*2] - vmovdqu xmm7,XMMWORD PTR [rdx+r13] - vmovdqu XMMWORD PTR [rcx],xmm4 - vmovdqu XMMWORD PTR [rcx+r11],xmm5 - vmovdqu XMMWORD PTR [rcx+r11*2],xmm6 - vmovdqu XMMWORD PTR [rcx+rax],xmm7 -IF ASigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 - VpdpbssdYmmYmmYmm ymm1,ymm5,ymm9 - VpdpbssdYmmYmmYmm ymm2,ymm6,ymm9 - VpdpbssdYmmYmmYmm ymm3,ymm7,ymm9 -ELSE - vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators - vpmaddubsw xmm5,xmm5,xmm9 - vpaddw ymm1,ymm1,ymm5 - vpmaddubsw xmm6,xmm6,xmm9 - vpaddw ymm2,ymm2,ymm6 - vpmaddubsw xmm7,xmm7,xmm9 - vpaddw ymm3,ymm3,ymm7 -ENDIF - add rdx,16 ; advance matrix A by 16 bytes - add rcx,16 ; advance matrix D by 16 bytes - test bl,15 ; test for unaligned columns - jz ReduceRowSumBufferM4 - -; -; Copy the unaligned CountK columns to a zero padded stack buffer. -; - -CopyRemainingCountKLessThan16M4: -.errnz GemmInt8CopyPackAFrame.PaddedMatrixAData - mov rbp,rsp ; GemmInt8CopyPackAFrame.PaddedMatrixAData - test bl,8 ; (CountK & 8) != 0? - jz CopyRemainingCountKLessThan8M4 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r8] - mov QWORD PTR [rbp+16],rax - mov rax,QWORD PTR [rdx+r8*2] - mov QWORD PTR [rbp+32],rax - mov rax,QWORD PTR [rdx+r13] - mov QWORD PTR [rbp+48],rax - add rdx,8 - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountKLessThan8M4: - test bl,4 ; (CountK & 4) != 0? - jz CopyRemainingCountKLessThan4M4 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r8] - mov DWORD PTR [rbp+16],eax - mov eax,DWORD PTR [rdx+r8*2] - mov DWORD PTR [rbp+32],eax - mov eax,DWORD PTR [rdx+r13] - mov DWORD PTR [rbp+48],eax - add rdx,4 - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountKLessThan4M4: - test bl,2 ; (CountK & 2) != 0? 
- jz CopyRemainingCountKLessThan2M4 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r8] - mov WORD PTR [rbp+16],ax - movzx eax,WORD PTR [rdx+r8*2] - mov WORD PTR [rbp+32],ax - movzx eax,WORD PTR [rdx+r13] - mov WORD PTR [rbp+48],ax - add rdx,2 - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountKLessThan2M4: - test bl,1 ; (CountK & 1) != 0? - jz ProcessPaddedMatrixADataM4 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r8] - mov BYTE PTR [rbp+16],al - movzx eax,BYTE PTR [rdx+r8*2] - mov BYTE PTR [rbp+32],al - movzx eax,BYTE PTR [rdx+r13] - mov BYTE PTR [rbp+48],al - -; -; Process the remaining CountK columns using the zero padded stack buffer. -; - -ProcessPaddedMatrixADataM4: - vmovdqu xmm4,XMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp] - vmovdqu xmm5,XMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp+16] - vmovdqu xmm6,XMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp+32] - vmovdqu xmm7,XMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp+48] - lea rax,[rcx+r11*2] ; compute matrix D plus 2 rows - vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 - vpmaskmovd XMMWORD PTR [rcx+r11],xmm10,xmm5 - vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 - vpmaskmovd XMMWORD PTR [rax+r11],xmm10,xmm7 -IF ASigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 - VpdpbssdYmmYmmYmm ymm1,ymm5,ymm9 - VpdpbssdYmmYmmYmm ymm2,ymm6,ymm9 - VpdpbssdYmmYmmYmm ymm3,ymm7,ymm9 -ELSE - vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators - vpmaddubsw xmm5,xmm5,xmm9 - vpaddw ymm1,ymm1,ymm5 - vpmaddubsw xmm6,xmm6,xmm9 - vpaddw ymm2,ymm2,ymm6 - vpmaddubsw xmm7,xmm7,xmm9 - vpaddw ymm3,ymm3,ymm7 -ENDIF - -; -; Reduce the sums for the four rows of output. -; - -ReduceRowSumBufferM4: -IF ASigned EQ 1 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0 -ELSE - vpmaddwd ymm0,ymm0,ymm8 ; horizontal word+word=dword per row - vpmaddwd ymm1,ymm1,ymm8 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0 - vpmaddwd ymm2,ymm2,ymm8 - vpmaddwd ymm3,ymm3,ymm8 -ENDIF - vphaddd ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0 - vextracti128 xmm1,ymm0,1 ; extract high dwords - vpaddd xmm0,xmm0,xmm1 ; reduce low/high dwords - vmovdqu XMMWORD PTR [r12],xmm0 - add r12,4*4 ; advance row sum buffer by 4 dwords - sub r9,4 ; subtract rows remaining - jae ProcessNextRowM4 - -ProcessRemainingRows: - add r9,4 ; correct for over-subtract above - jz ExitRoutine - -; -; Process a single row of matrix A in a loop. -; - -ProcessNextRowM1: - vpxor xmm0,xmm0,xmm0 ; clear row accumulator - mov rdx,rsi - mov rcx,rdi - add rsi,r8 - add rdi,r11 - mov rbx,r10 ; reload columns remaining - sub rbx,32 - jb ProcessRemainingColumnsM1 - -ProcessNextColumnLoopM1: - vmovdqu ymm4,YMMWORD PTR [rdx] - vmovdqu YMMWORD PTR [rcx],ymm4 -IF ASigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 -ELSE - vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators -ENDIF - add rdx,32 ; advance matrix A by 32 bytes - add rcx,32 ; advance matrix D by 32 bytes - sub rbx,32 ; subtract columns remaining - jae ProcessNextColumnLoopM1 - -ProcessRemainingColumnsM1: - add rbx,32 ; correct for over-subtract above - jz ReduceRowSumBufferM1 - test bl,16 ; (CountK & 16) != 0? 
- jz CopyRemainingCountKLessThan16M1 - vmovdqu xmm4,XMMWORD PTR [rdx] - vmovdqu XMMWORD PTR [rcx],xmm4 -IF ASigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 -ELSE - vpmaddubsw xmm4,xmm4,xmm9 ; horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators -ENDIF - add rdx,16 ; advance matrix A by 16 bytes - add rcx,16 ; advance matrix D by 16 bytes - test bl,15 ; test for unaligned columns - jz ReduceRowSumBufferM1 - -; -; Copy the unaligned CountK columns to a zero padded stack buffer. -; - -CopyRemainingCountKLessThan16M1: -.errnz GemmInt8CopyPackAFrame.PaddedMatrixAData - mov rbp,rsp ; GemmInt8CopyPackAFrame.PaddedMatrixAData - test bl,8 ; (CountK & 8) != 0? - jz CopyRemainingCountKLessThan8M1 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountKLessThan8M1: - test bl,4 ; (CountK & 4) != 0? - jz CopyRemainingCountKLessThan4M1 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountKLessThan4M1: - test bl,2 ; (CountK & 2) != 0? - jz CopyRemainingCountKLessThan2M1 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountKLessThan2M1: - test bl,1 ; (CountK & 1) != 0? - jz ProcessPaddedMatrixADataM1 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -; -; Process the remaining CountK columns using the zero padded stack buffer. -; - -ProcessPaddedMatrixADataM1: - vmovdqu xmm4,XMMWORD PTR GemmInt8CopyPackAFrame.PaddedMatrixAData[rsp] - vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 -IF ASigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 -ELSE - vpmaddubsw ymm4,ymm4,ymm9 ; horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 ; add words to row accumulators -ENDIF - -; -; Reduce the sum for the single row of output. -; - -ReduceRowSumBufferM1: -IF ASigned EQ 0 - vpmaddwd ymm0,ymm0,ymm8 ; horizontal word+word=dword per row -ENDIF - vextracti128 xmm1,ymm0,1 ; extract high dwords - vpaddd xmm0,xmm0,xmm1 ; reduction - vphaddd xmm0,xmm0,xmm0 - vphaddd xmm0,xmm0,xmm0 - vmovd DWORD PTR [r12],xmm0 - add r12,4 ; advance row sum buffer by 1 dword - dec r9 ; decrement rows remaining - jnz ProcessNextRowM1 - -; -; Restore non-volatile registers and return. -; - -ExitRoutine: - vzeroupper - movaps xmm6,GemmInt8CopyPackAFrame.SavedXmm6[rsp] - movaps xmm7,GemmInt8CopyPackAFrame.SavedXmm7[rsp] - movaps xmm8,GemmInt8CopyPackAFrame.SavedXmm8[rsp] - movaps xmm9,GemmInt8CopyPackAFrame.SavedXmm9[rsp] - movaps xmm10,GemmInt8CopyPackAFrame.SavedXmm10[rsp] - add rsp,(GemmInt8CopyPackAFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - ENDM - - NESTED_ENTRY MlasGemmU8S8CopyPackAAvx2, _TEXT - MlasGemmCopyPackAAvx2 0 - NESTED_END MlasGemmU8S8CopyPackAAvx2, _TEXT - - NESTED_ENTRY MlasGemmS8CopyPackAAvx2Vnni, _TEXT - MlasGemmCopyPackAAvx2 1 - NESTED_END MlasGemmS8CopyPackAAvx2Vnni, _TEXT - -;++ -; -; Routine Description: -; -; This routine copies elements from the source matrix to the destination -; packed buffer. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; B (rdx) - Supplies the address of the source matrix. -; -; ldb (r8) - Supplies the number of elements per row of the source matrix. -; -; CountN (r9) - Supplies the number of columns of the source matrix to copy. -; -; CountK - Supplies the number of rows of the source matrix to copy. 
-; -; ColumnSumBuffer - Supplies the address of the buffer to receive the sums of -; the elements along each of the columns. -; -; BIsSigned - Supplies true if the source matrix is signed data, else false -; if the source matrix is unsigned data. -; -; Return Value: -; -; None. -; -;-- - -MlasGemmCopyPackBAvx2 MACRO IsVnni, BSigned - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - alloc_stack (GemmInt8CopyPackBFrame.SavedRdi) - save_xmm128 xmm6,GemmInt8CopyPackBFrame.SavedXmm6 - save_xmm128 xmm7,GemmInt8CopyPackBFrame.SavedXmm7 - save_xmm128 xmm8,GemmInt8CopyPackBFrame.SavedXmm8 - save_xmm128 xmm9,GemmInt8CopyPackBFrame.SavedXmm9 - - END_PROLOGUE - - mov rsi,rdx - lea rdi,[r8+r8*2] ; compute ldb * 3 - mov r10,GemmInt8CopyPackBFrame.CountK[rsp] - mov r11,GemmInt8CopyPackBFrame.ColumnSumBuffer[rsp] - vpcmpeqw ymm7,ymm7,ymm7 ; generate word vector [0xFFFF] - vpsrlw ymm7,ymm7,15 ; generate word vector [0x0001] - vpsllw ymm8,ymm7,8 ; generate word vector [0x0100] - vpor ymm8,ymm7,ymm8 ; generate word vector [0x0101] - -; -; Compute the bit flip vector to adjust input from U8 to S8. -; - - vpxor xmm9,xmm9,xmm9 ; generate word vector [0x0000] -IF IsVnni EQ 0 - cmp BYTE PTR GemmInt8CopyPackBFrame.BIsSigned[rsp],0 - jnz SkipUnsignedBitFlipVector - vpsllw ymm9,ymm8,7 ; generate word vector [0x8080] -ENDIF -SkipUnsignedBitFlipVector: - -; -; Process 16 columns of matrix B in a loop. -; - - sub r9,16 - jb ProcessRemainingColumns - -ProcessNextColumnN16: - vpxor xmm0,xmm0,xmm0 ; clear column accumulators - vpxor xmm1,xmm1,xmm1 - mov rdx,rsi - add rsi,16 ; advance next matrix B by 16 columns - mov rbx,r10 ; reload rows remaining - sub rbx,4 - jb ProcessRemainingRowsN16 - -ProcessNextRowLoopN16: - vmovdqu xmm2,XMMWORD PTR [rdx] ; load 4 rows - vmovdqu xmm3,XMMWORD PTR [rdx+r8] - vmovdqu xmm4,XMMWORD PTR [rdx+r8*2] - vmovdqu xmm5,XMMWORD PTR [rdx+rdi] - lea rdx,[rdx+r8*4] ; advance matrix B by 4 rows - -InterleaveRowDataN16: - vpunpcklbw xmm6,xmm2,xmm3 ; interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vpunpcklbw xmm2,xmm4,xmm5 - vpunpckhbw xmm5,xmm4,xmm5 - vpunpcklwd xmm4,xmm6,xmm2 - vpunpckhwd xmm6,xmm6,xmm2 - vpunpcklwd xmm2,xmm3,xmm5 - vpunpckhwd xmm3,xmm3,xmm5 - vinserti128 ymm4,ymm4,xmm6,1 - vinserti128 ymm2,ymm2,xmm3,1 -IF IsVnni EQ 0 - vpxor ymm4,ymm4,ymm9 ; optionally adjust unsigned data - vpxor ymm2,ymm2,ymm9 -ENDIF - vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows - vmovdqu YMMWORD PTR [rcx+32],ymm2 -IF IsVnni EQ 1 - IF BSigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbssdYmmYmmYmm ymm1,ymm2,ymm8 - ELSE - VpdpbuudYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbuudYmmYmmYmm ymm1,ymm2,ymm8 - ENDIF -ELSE - vpmaddubsw ymm4,ymm8,ymm4 ; horizontal byte+byte=word per row - vpmaddwd ymm4,ymm4,ymm7 ; horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 ; accumulate per column - vpmaddubsw ymm2,ymm8,ymm2 - vpmaddwd ymm2,ymm2,ymm7 - vpaddd ymm1,ymm1,ymm2 -ENDIF - add rcx,64 ; advance matrix D by 64 bytes - sub rbx,4 ; subtract rows remaining - jae ProcessNextRowLoopN16 - -; -; Process the less than 4 remaining rows where the row has 16 columns. -; - -ProcessRemainingRowsN16: - add rbx,4 ; correct for over-subtract above - jz StoreColumnSumBufferN16 - vmovdqu xmm2,XMMWORD PTR [rdx] - vmovaps xmm3,xmm9 - vmovaps xmm4,xmm9 - vmovaps xmm5,xmm9 - xor ebx,ebx ; no more rows remaining - test r10b,2 ; (CountK & 2) != 0? - jz InterleaveRowDataN16 - vmovdqu xmm3,XMMWORD PTR [rdx+r8] - test r10b,1 ; (CountK & 1) != 0? 
- jz InterleaveRowDataN16 - vmovdqu xmm4,XMMWORD PTR [rdx+r8*2] - jmp InterleaveRowDataN16 - -StoreColumnSumBufferN16: - vmovdqu YMMWORD PTR [r11],ymm0 - vmovdqu YMMWORD PTR [r11+32],ymm1 - add r11,16*4 ; advance column sum buffer by 16 dwords - sub r9,16 ; subtract columns remaining - jae ProcessNextColumnN16 - -ProcessRemainingColumns: - add r9,16 ; correct for over-subtract above - jnz ProcessColumnNUnaligned - -; -; Restore non-volatile registers and return. -; - -ExitRoutine: - vzeroupper - movaps xmm6,GemmInt8CopyPackBFrame.SavedXmm6[rsp] - movaps xmm7,GemmInt8CopyPackBFrame.SavedXmm7[rsp] - movaps xmm8,GemmInt8CopyPackBFrame.SavedXmm8[rsp] - movaps xmm9,GemmInt8CopyPackBFrame.SavedXmm9[rsp] - add rsp,(GemmInt8CopyPackBFrame.SavedRdi) - - BEGIN_EPILOGUE - - pop rdi - pop rsi - pop rbx - pop rbp - ret - -; -; Process the remaining columns of matrix B. -; - -ProcessColumnNUnaligned: - vpxor xmm0,xmm0,xmm0 ; clear column accumulators - vpxor xmm1,xmm1,xmm1 - vmovdqu YMMWORD PTR GemmInt8CopyPackBFrame.PaddedMatrixBData[rsp],ymm9 - vmovdqu YMMWORD PTR GemmInt8CopyPackBFrame.PaddedMatrixBData[rsp+32],ymm9 - sub r10,4 - jb ProcessRemainingRowsNUnaligned - -ProcessNextRowLoopNUnaligned: - mov rdx,rsi -.errnz GemmInt8CopyPackBFrame.PaddedMatrixBData - mov rbp,rsp ; GemmInt8CopyPackBFrame.PaddedMatrixBData - test r9b,8 ; (CountN & 8) != 0? - jz CopyRemainingCountNLessThan8K4 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r8] - mov QWORD PTR [rbp+16],rax - mov rax,QWORD PTR [rdx+r8*2] - mov QWORD PTR [rbp+32],rax - mov rax,QWORD PTR [rdx+rdi] - mov QWORD PTR [rbp+48],rax - add rdx,8 ; advance matrix B - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountNLessThan8K4: - test r9b,4 ; (CountN & 4) != 0? - jz CopyRemainingCountNLessThan4K4 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r8] - mov DWORD PTR [rbp+16],eax - mov eax,DWORD PTR [rdx+r8*2] - mov DWORD PTR [rbp+32],eax - mov eax,DWORD PTR [rdx+rdi] - mov DWORD PTR [rbp+48],eax - add rdx,4 ; advance matrix B - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountNLessThan4K4: - test r9b,2 ; (CountN & 2) != 0? - jz CopyRemainingCountNLessThan2K4 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r8] - mov WORD PTR [rbp+16],ax - movzx eax,WORD PTR [rdx+r8*2] - mov WORD PTR [rbp+32],ax - movzx eax,WORD PTR [rdx+rdi] - mov WORD PTR [rbp+48],ax - add rdx,2 ; advance matrix B - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountNLessThan2K4: - test r9b,1 ; (CountN & 1) != 0? 
- jz ProcessPaddedMatrixBData - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r8] - mov BYTE PTR [rbp+16],al - movzx eax,BYTE PTR [rdx+r8*2] - mov BYTE PTR [rbp+32],al - movzx eax,BYTE PTR [rdx+rdi] - mov BYTE PTR [rbp+48],al - -ProcessPaddedMatrixBData: - vmovdqu xmm2,XMMWORD PTR GemmInt8CopyPackBFrame.PaddedMatrixBData[rsp] - vmovdqu xmm3,XMMWORD PTR GemmInt8CopyPackBFrame.PaddedMatrixBData[rsp+16] - vmovdqu xmm4,XMMWORD PTR GemmInt8CopyPackBFrame.PaddedMatrixBData[rsp+32] - vmovdqu xmm5,XMMWORD PTR GemmInt8CopyPackBFrame.PaddedMatrixBData[rsp+48] - vpunpcklbw xmm6,xmm2,xmm3 ; interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vpunpcklbw xmm2,xmm4,xmm5 - vpunpckhbw xmm5,xmm4,xmm5 - vpunpcklwd xmm4,xmm6,xmm2 - vpunpckhwd xmm6,xmm6,xmm2 - vpunpcklwd xmm2,xmm3,xmm5 - vpunpckhwd xmm3,xmm3,xmm5 - vinserti128 ymm4,ymm4,xmm6,1 - vinserti128 ymm2,ymm2,xmm3,1 -IF IsVnni EQ 0 - vpxor ymm4,ymm4,ymm9 ; optionally adjust unsigned data - vpxor ymm2,ymm2,ymm9 -ENDIF - vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows - vmovdqu YMMWORD PTR [rcx+32],ymm2 -IF IsVnni EQ 1 - IF BSigned EQ 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbssdYmmYmmYmm ymm1,ymm2,ymm8 - ELSE - VpdpbuudYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbuudYmmYmmYmm ymm1,ymm2,ymm8 - ENDIF -ELSE - vpmaddubsw ymm4,ymm8,ymm4 ; horizontal byte+byte=word per row - vpmaddwd ymm4,ymm4,ymm7 ; horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 ; accumulate per column - vpmaddubsw ymm2,ymm8,ymm2 - vpmaddwd ymm2,ymm2,ymm7 - vpaddd ymm1,ymm1,ymm2 -ENDIF - lea rsi,[rsi+r8*4] ; advance next matrix B by 4 rows - add rcx,64 ; advance matrix D by 64 bytes - sub r10,4 ; subtract rows remaining - jae ProcessNextRowLoopNUnaligned - -ProcessRemainingRowsNUnaligned: - add r10,4 - jz StoreColumnSumBufferNUnaligned - -; -; Process the less than 4 remaining rows where the row has less than 16 columns. -; - -.errnz GemmInt8CopyPackBFrame.PaddedMatrixBData - mov rbp,rsp ; GemmInt8CopyPackBFrame.PaddedMatrixBData - vmovdqu YMMWORD PTR [rbp],ymm9 - vmovdqu YMMWORD PTR [rbp+32],ymm9 - -CopyUnalignedRowLoop: - lea rdi,[rbp+16] ; advance next padded buffer by 16 bytes - mov rdx,rsi - test r9b,8 ; (CountN & 8) != 0? - jz CopyRemainingCountNLessThan8KSmall - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 ; advance matrix B - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountNLessThan8KSmall: - test r9b,4 ; (CountN & 4) != 0? - jz CopyRemainingCountNLessThan4KSmall - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 ; advance matrix B - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountNLessThan4KSmall: - test r9b,2 ; (CountN & 2) != 0? - jz CopyRemainingCountNLessThan2KSmall - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 ; advance matrix B - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountNLessThan2KSmall: - test r9b,1 ; (CountN & 1) != 0? 
- jz DoneCopyRemainingCountNKSmall - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -DoneCopyRemainingCountNKSmall: - dec r10 - jz ProcessPaddedMatrixBData - add rsi,r8 ; advance next matrix B by 1 row - mov rbp,rdi - jmp CopyUnalignedRowLoop - -StoreColumnSumBufferNUnaligned: - vmovdqu YMMWORD PTR [r11],ymm0 - vmovdqu YMMWORD PTR [r11+32],ymm1 - jmp ExitRoutine - -ENDM - - NESTED_ENTRY MlasGemmU8S8CopyPackBAvx2, _TEXT - MlasGemmCopyPackBAvx2 0 ; sign variable not checked if IsVnni = 0 - NESTED_END MlasGemmU8S8CopyPackBAvx2, _TEXT - - NESTED_ENTRY MlasGemmU8CopyPackBAvx2Vnni, _TEXT - MlasGemmCopyPackBAvx2 1, 0 - NESTED_END MlasGemmU8CopyPackBAvx2Vnni, _TEXT - - NESTED_ENTRY MlasGemmS8CopyPackBAvx2Vnni, _TEXT - MlasGemmCopyPackBAvx2 1, 1 - NESTED_END MlasGemmS8CopyPackBAvx2Vnni, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm deleted file mode 100644 index 30c97fd36fa17..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm +++ /dev/null @@ -1,667 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8U8KernelAvx2.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX2 instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - EXTERN MlasMaskMoveTableAvx:NEAR - -; -; Stack frame layout for the U8U8 CopyPackA routine. -; - -GemmU8U8CopyPackAFrame STRUCT - - PaddedMatrixAData OWORD 4 DUP (?) - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - Padding QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountK QWORD ? - RowSumBuffer QWORD ? - -GemmU8U8CopyPackAFrame ENDS - -; -; Stack frame layout for the U8U8 CopyPackB routine. -; - -GemmU8U8CopyPackBFrame STRUCT - - PaddedMatrixBData OWORD 2 DUP (?) - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountK QWORD ? - ColumnSumBuffer QWORD ? - -GemmU8U8CopyPackBFrame ENDS - -;++ -; -; Routine Description: -; -; This routine copies elements from the source matrix to the destination -; packed buffer. -; -; The kernel expects that elements from matrix A have been zero extended to -; 16-bits and padded to a multiple of 32-bits (two pairs of 16-bit values). -; The kernel can then efficiently broadcast 32-bits from the packed buffer -; and avoid expensive shuffling inside the kernel. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; A (rdx) - Supplies the address of the source matrix. -; -; lda (r8) - Supplies the number of elements per row of the source matrix. -; -; CountM (r9) - Supplies the number of rows of the source matrix to copy. -; -; CountK - Supplies the number of columns of the source matrix to copy. -; -; RowSumBuffer - Supplies the address of the buffer to receive the sums of -; the elements along each of the rows. -; -; Return Value: -; -; None. 
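Reviewer note on the U8U8 CopyPackA routine described above: it zero-extends each source byte to 16 bits and pads CountK up to an even count so the compute kernel can broadcast aligned 32-bit pairs without shuffling. A scalar sketch of that packed layout follows; the helper name is hypothetical, and the per-row sums that the real routine accumulates in the same pass (with the CountK <= 128 limit noted in the comments) are omitted here.

    #include <cstddef>
    #include <cstdint>

    // Scalar sketch of the U8U8 A-packing: bytes zero-extended to 16 bits, with
    // CountK rounded up to a pair so each row stays 32-bit aligned.
    void ReferenceU8U8PackA(const uint8_t* A, size_t lda,
                            size_t CountM, size_t CountK, uint16_t* D) {
        const size_t AlignedCountK = (CountK + 1) & ~size_t(1);  // pair alignment
        for (size_t m = 0; m < CountM; m++) {
            for (size_t k = 0; k < AlignedCountK; k++) {
                D[m * AlignedCountK + k] =
                    (k < CountK) ? static_cast<uint16_t>(A[m * lda + k]) : 0;
            }
        }
    }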
-; -;-- - - NESTED_ENTRY MlasGemmU8U8CopyPackAAvx2, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (GemmU8U8CopyPackAFrame.SavedR13) - save_xmm128 xmm6,GemmU8U8CopyPackAFrame.SavedXmm6 - save_xmm128 xmm7,GemmU8U8CopyPackAFrame.SavedXmm7 - save_xmm128 xmm8,GemmU8U8CopyPackAFrame.SavedXmm8 - save_xmm128 xmm9,GemmU8U8CopyPackAFrame.SavedXmm9 - - END_PROLOGUE - - mov rdi,rcx - mov rsi,rdx - mov r10,GemmU8U8CopyPackAFrame.CountK[rsp] - lea r11,[r10+1] - and r11,NOT 1 ; align CountK up to pair count - mov r12,GemmU8U8CopyPackAFrame.RowSumBuffer[rsp] - vpcmpeqw ymm8,ymm8,ymm8 ; generate word vector [0xFFFF] - vpsrlw ymm8,ymm8,15 ; generate word vector [0x0001] - -; -; Compute the conditional load/store mask for an unaligned CountK. -; - - mov eax,r10d - and eax,15 ; isolate unaligned count - inc eax - shr eax,1 ; align unaligned count to pair count - neg rax - lea rbx,MlasMaskMoveTableAvx+8*4 - vmovdqu ymm9,YMMWORD PTR [rbx+rax*4] - -; -; Zero initialize the padded stack buffers. -; - - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32],ymm0 - -; -; Process 4 rows of matrix A in a loop. -; -; Zero extend the source bytes to 16-bits and write to the packed buffer. -; -; The packed buffer has the same data ordering as the source bytes, but CountK -; is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding -; bytes are zero filled. -; -; These 16-bit values are also accumulated into an intermediate per-row -; accumulator. CountK cannot be greater than 128 to avoid overflowing these -; signed 16-bit accumulators. -; - - sub r9,4 - jb ProcessRemainingRows - -ProcessNextRowM4: - vpxor xmm0,xmm0,xmm0 ; clear row accumulators - vpxor xmm1,xmm1,xmm1 - vpxor xmm2,xmm2,xmm2 - vpxor xmm3,xmm3,xmm3 - mov rdx,rsi - mov rcx,rdi - lea rsi,[rsi+r8*4] ; advance next matrix A by 4 rows - lea rdi,[rdi+r11*8] ; advance next matrix D by 4 rows - mov rbx,r10 ; reload columns remaining - sub rbx,16 - jb ProcessRemainingColumnsM4 - -ProcessNextColumnLoopM4: - lea rax,[rdx+r8*2] ; compute matrix A plus 2 rows - vpmovzxbw ymm4,XMMWORD PTR [rdx] - vpmovzxbw ymm5,XMMWORD PTR [rdx+r8] - vpmovzxbw ymm6,XMMWORD PTR [rax] - vpmovzxbw ymm7,XMMWORD PTR [rax+r8] - lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows - vmovdqu YMMWORD PTR [rcx],ymm4 - vmovdqu YMMWORD PTR [rcx+r11*2],ymm5 - vmovdqu YMMWORD PTR [rax],ymm6 - vmovdqu YMMWORD PTR [rax+r11*2],ymm7 - vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns - vpaddw ymm1,ymm1,ymm5 - vpaddw ymm2,ymm2,ymm6 - vpaddw ymm3,ymm3,ymm7 - add rdx,16 ; advance matrix A by 16 bytes - add rcx,16*2 ; advance matrix D by 16 words - sub rbx,16 ; subtract columns remaining - jae ProcessNextColumnLoopM4 - -ProcessRemainingColumnsM4: - add rbx,16 ; correct for over-subtract above - jz ReduceRowSumBufferM4 - -; -; Copy the unaligned CountK columns to a zero padded stack buffer. -; - -.errnz GemmU8U8CopyPackAFrame.PaddedMatrixAData - mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData - test bl,8 ; (CountK & 8) != 0? 
- jz CopyRemainingCountKLessThan8M4 - lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r8] - mov QWORD PTR [rbp+16],rax - mov rax,QWORD PTR [r13] - mov QWORD PTR [rbp+32],rax - mov rax,QWORD PTR [r13+r8] - mov QWORD PTR [rbp+48],rax - add rdx,8 - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountKLessThan8M4: - test bl,4 ; (CountK & 4) != 0? - jz CopyRemainingCountKLessThan4M4 - lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r8] - mov DWORD PTR [rbp+16],eax - mov eax,DWORD PTR [r13] - mov DWORD PTR [rbp+32],eax - mov eax,DWORD PTR [r13+r8] - mov DWORD PTR [rbp+48],eax - add rdx,4 - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountKLessThan4M4: - test bl,2 ; (CountK & 2) != 0? - jz CopyRemainingCountKLessThan2M4 - lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r8] - mov WORD PTR [rbp+16],ax - movzx eax,WORD PTR [r13] - mov WORD PTR [rbp+32],ax - movzx eax,WORD PTR [r13+r8] - mov WORD PTR [rbp+48],ax - add rdx,2 - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountKLessThan2M4: - test bl,1 ; (CountK & 1) != 0? - jz ProcessPaddedMatrixADataM4 - lea r13,[rdx+r8*2] ; compute matrix A plus 2 rows - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r8] - mov BYTE PTR [rbp+16],al - movzx eax,BYTE PTR [r13] - mov BYTE PTR [rbp+32],al - movzx eax,BYTE PTR [r13+r8] - mov BYTE PTR [rbp+48],al - -; -; Process the remaining CountK columns using the zero padded stack buffer. -; - -ProcessPaddedMatrixADataM4: - vpmovzxbw ymm4,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp] - vpmovzxbw ymm5,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+16] - vpmovzxbw ymm6,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+32] - vpmovzxbw ymm7,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp+48] - lea rax,[rcx+r11*4] ; compute matrix D plus 2 rows - vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 - vpmaskmovd YMMWORD PTR [rcx+r11*2],ymm9,ymm5 - vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6 - vpmaskmovd YMMWORD PTR [rax+r11*2],ymm9,ymm7 - vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns - vpaddw ymm1,ymm1,ymm5 - vpaddw ymm2,ymm2,ymm6 - vpaddw ymm3,ymm3,ymm7 - -; -; Reduce the sums for the four rows of output. -; - -ReduceRowSumBufferM4: - vpmaddwd ymm0,ymm0,ymm8 ; horizontal word+word=dword per row - vpmaddwd ymm1,ymm1,ymm8 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum1/Sum0 - vpmaddwd ymm2,ymm2,ymm8 - vpmaddwd ymm3,ymm3,ymm8 - vphaddd ymm1,ymm2,ymm3 ; reduce and interleave Sum3/Sum2 - vphaddd ymm0,ymm0,ymm1 ; reduce and interleave Sum3/Sum2/Sum1/Sum0 - vextracti128 xmm1,ymm0,1 ; extract high dwords - vpaddd xmm0,xmm0,xmm1 ; reduce low/high dwords - vmovdqu XMMWORD PTR [r12],xmm0 - add r12,4*4 ; advance row sum buffer by 4 dwords - sub r9,4 ; subtract rows remaining - jae ProcessNextRowM4 - -ProcessRemainingRows: - add r9,4 ; correct for over-subtract above - jz ExitRoutine - -; -; Process a single row of matrix A in a loop. 
-; - -ProcessNextRowM1: - vpxor xmm0,xmm0,xmm0 ; clear row accumulator - mov rdx,rsi - mov rcx,rdi - add rsi,r8 - lea rdi,[rdi+r11*2] - mov rbx,r10 ; reload columns remaining - sub rbx,16 - jb ProcessRemainingColumnsM1 - -ProcessNextColumnLoopM1: - vpmovzxbw ymm4,XMMWORD PTR [rdx] - vmovdqu YMMWORD PTR [rcx],ymm4 - vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns - add rdx,16 ; advance matrix A by 16 bytes - add rcx,16*2 ; advance matrix D by 16 words - sub rbx,16 ; subtract columns remaining - jae ProcessNextColumnLoopM1 - -ProcessRemainingColumnsM1: - add rbx,16 ; correct for over-subtract above - jz ReduceRowSumBufferM1 - -; -; Copy the unaligned CountK columns to a zero padded stack buffer. -; - -.errnz GemmU8U8CopyPackAFrame.PaddedMatrixAData - mov rbp,rsp ; GemmU8U8CopyPackAFrame.PaddedMatrixAData - test bl,8 ; (CountK & 8) != 0? - jz CopyRemainingCountKLessThan8M1 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountKLessThan8M1: - test bl,4 ; (CountK & 4) != 0? - jz CopyRemainingCountKLessThan4M1 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountKLessThan4M1: - test bl,2 ; (CountK & 2) != 0? - jz CopyRemainingCountKLessThan2M1 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountKLessThan2M1: - test bl,1 ; (CountK & 1) != 0? - jz ProcessPaddedMatrixADataM1 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -; -; Process the remaining CountK columns using the zero padded stack buffer. -; - -ProcessPaddedMatrixADataM1: - vpmovzxbw ymm4,XMMWORD PTR GemmU8U8CopyPackAFrame.PaddedMatrixAData[rsp] - vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 - vpaddw ymm0,ymm0,ymm4 ; accumulate per row along columns - -; -; Reduce the sum for the single row of output. -; - -ReduceRowSumBufferM1: - vpmaddwd ymm0,ymm0,ymm8 ; horizontal word+word=dword per row - vextracti128 xmm1,ymm0,1 ; extract high dwords - vpaddd xmm0,xmm0,xmm1 ; reduction - vphaddd xmm0,xmm0,xmm0 - vphaddd xmm0,xmm0,xmm0 - vmovd DWORD PTR [r12],xmm0 - add r12,4 ; advance row sum buffer by 1 dword - dec r9 ; decrement rows remaining - jnz ProcessNextRowM1 - -; -; Restore non-volatile registers and return. -; - -ExitRoutine: - vzeroupper - movaps xmm6,GemmU8U8CopyPackAFrame.SavedXmm6[rsp] - movaps xmm7,GemmU8U8CopyPackAFrame.SavedXmm7[rsp] - movaps xmm8,GemmU8U8CopyPackAFrame.SavedXmm8[rsp] - movaps xmm9,GemmU8U8CopyPackAFrame.SavedXmm9[rsp] - add rsp,(GemmU8U8CopyPackAFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasGemmU8U8CopyPackAAvx2, _TEXT - -;++ -; -; Routine Description: -; -; This routine copies elements from the source matrix to the destination -; packed buffer. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; B (rdx) - Supplies the address of the source matrix. -; -; ldb (r8) - Supplies the number of elements per row of the source matrix. -; -; CountN (r9) - Supplies the number of columns of the source matrix to copy. -; -; CountK - Supplies the number of rows of the source matrix to copy. -; -; ColumnSumBuffer - Supplies the address of the buffer to receive the sums of -; the elements along each of the columns. -; -; Return Value: -; -; None. 
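Reviewer note on why the packing routines above carry RowSumBuffer and ColumnSumBuffer at all: in a quantized GEMM the zero points can be folded into precomputed sums, so the kernel accumulates raw byte products and applies a cheap fixup afterwards. The sketch below shows that generic identity only; it is not the exact fixup removed in this PR, and all names are hypothetical (MLAS variants typically pre-negate or pre-scale the sums during packing).

    #include <cstddef>
    #include <cstdint>

    // Generic sketch: with zero points za/zb,
    //   sum_k (A[m][k]-za)*(B[k][n]-zb)
    //     = sum_k A*B - zb*RowSum[m] - za*ColSum[n] + K*za*zb,
    // so row/column sums let the kernel skip per-element zero-point subtraction.
    void ApplyZeroPointFixup(int32_t* C, size_t ldc, size_t M, size_t N, size_t K,
                             const int32_t* RowSums, const int32_t* ColSums,
                             int32_t ZeroPointA, int32_t ZeroPointB) {
        for (size_t m = 0; m < M; m++) {
            for (size_t n = 0; n < N; n++) {
                C[m * ldc + n] += static_cast<int32_t>(K) * ZeroPointA * ZeroPointB
                                  - ZeroPointB * RowSums[m]
                                  - ZeroPointA * ColSums[n];
            }
        }
    }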
-; -;-- - - NESTED_ENTRY MlasGemmU8U8CopyPackBAvx2, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - alloc_stack (GemmU8U8CopyPackBFrame.SavedRsi) - - END_PROLOGUE - - mov rsi,rdx - mov r10,GemmU8U8CopyPackBFrame.CountK[rsp] - mov r11,GemmU8U8CopyPackBFrame.ColumnSumBuffer[rsp] - vpcmpeqw ymm5,ymm5,ymm5 ; generate word vector [0xFFFF] - vpsrlw ymm5,ymm5,15 ; generate word vector [0x0001] - -; -; Zero initialize the padded stack buffers. -; - - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp],ymm0 - -; -; Process 16 columns of matrix B in a loop. -; - - sub r9,16 - jb ProcessRemainingColumns - -ProcessNextColumnN16: - vpxor xmm0,xmm0,xmm0 ; clear column accumulators - vpxor xmm1,xmm1,xmm1 - mov rdx,rsi - add rsi,16 ; advance next matrix B by 16 columns - mov rbx,r10 ; reload rows remaining - sub rbx,2 - jb ProcessRemainingRowsN16 - -ProcessNextRowLoopN16: - vmovdqu xmm2,XMMWORD PTR [rdx] ; load 2 rows - vmovdqu xmm3,XMMWORD PTR [rdx+r8] - lea rdx,[rdx+r8*2] ; advance matrix B by 2 rows - vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows - vmovdqu XMMWORD PTR [rcx+16],xmm3 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - add rcx,32 ; advance matrix D by 32 bytes - vpmaddwd ymm4,ymm4,ymm5 ; horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 ; accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - sub rbx,2 ; subtract rows remaining - jae ProcessNextRowLoopN16 - -ProcessRemainingRowsN16: - add rbx,2 ; correct for over-subtract above - jz StoreColumnSumBufferN16 - vpmovzxbw ymm4,XMMWORD PTR [rdx] - vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows - vextracti128 xmm3,ymm4,1 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - vpmaddwd ymm4,ymm4,ymm5 ; horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 ; accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - add rcx,32 ; advance matrix D by 32 bytes - -StoreColumnSumBufferN16: - vmovdqu YMMWORD PTR [r11],ymm0 - vmovdqu YMMWORD PTR [r11+32],ymm1 - add r11,16*4 ; advance column sum buffer by 16 dwords - sub r9,16 ; subtract columns remaining - jae ProcessNextColumnN16 - -ProcessRemainingColumns: - add r9,16 ; correct for over-subtract above - jnz ProcessColumnNUnaligned - -; -; Restore non-volatile registers and return. -; - -ExitRoutine: - vzeroupper - add rsp,(GemmU8U8CopyPackBFrame.SavedRsi) - - BEGIN_EPILOGUE - - pop rsi - pop rbx - pop rbp - ret - -; -; Process the remaining columns of matrix B. -; - -ProcessColumnNUnaligned: - vpxor xmm0,xmm0,xmm0 ; clear column accumulators - vpxor xmm1,xmm1,xmm1 - sub r10,2 - jb ProcessRemainingRowsNUnaligned - -ProcessNextRowLoopNUnaligned: - mov rdx,rsi -.errnz GemmU8U8CopyPackBFrame.PaddedMatrixBData - mov rbp,rsp ; GemmU8U8CopyPackBFrame.PaddedMatrixBData - test r9b,8 ; (CountN & 8) != 0? - jz CopyRemainingCountNLessThan8K2 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r8] - mov QWORD PTR [rbp+16],rax - add rdx,8 ; advance matrix B - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountNLessThan8K2: - test r9b,4 ; (CountN & 4) != 0? - jz CopyRemainingCountNLessThan4K2 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r8] - mov DWORD PTR [rbp+16],eax - add rdx,4 ; advance matrix B - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountNLessThan4K2: - test r9b,2 ; (CountN & 2) != 0? 
- jz CopyRemainingCountNLessThan2K2 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r8] - mov WORD PTR [rbp+16],ax - add rdx,2 ; advance matrix B - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountNLessThan2K2: - test r9b,1 ; (CountN & 1) != 0? - jz ProcessPaddedMatrixBDataK2 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r8] - mov BYTE PTR [rbp+16],al - -ProcessPaddedMatrixBDataK2: - vmovdqu xmm2,XMMWORD PTR XMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp] - vmovdqu xmm3,XMMWORD PTR XMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp+16] - vpunpcklbw xmm4,xmm2,xmm3 ; interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vmovdqu XMMWORD PTR [rcx],xmm4 ; store interleaved rows - vmovdqu XMMWORD PTR [rcx+16],xmm3 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - vpmaddwd ymm4,ymm4,ymm5 ; horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 ; accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - lea rsi,[rsi+r8*2] ; advance next matrix B by 2 rows - add rcx,32 ; advance matrix D by 32 bytes - sub r10,2 ; subtract columns remaining - jae ProcessNextRowLoopNUnaligned - -ProcessRemainingRowsNUnaligned: - add r10,2 - jz StoreColumnSumBufferNUnaligned - mov rdx,rsi -.errnz GemmU8U8CopyPackBFrame.PaddedMatrixBData - mov rbp,rsp ; GemmU8U8CopyPackBFrame.PaddedMatrixBData - test r9b,8 ; (CountN & 8) != 0? - jz CopyRemainingCountNLessThan8K1 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 ; advance matrix B - add rbp,8 ; advance padded buffer destination - -CopyRemainingCountNLessThan8K1: - test r9b,4 ; (CountN & 4) != 0? - jz CopyRemainingCountNLessThan4K1 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 ; advance matrix B - add rbp,4 ; advance padded buffer destination - -CopyRemainingCountNLessThan4K1: - test r9b,2 ; (CountN & 2) != 0? - jz CopyRemainingCountNLessThan2K1 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 ; advance matrix B - add rbp,2 ; advance padded buffer destination - -CopyRemainingCountNLessThan2K1: - test r9b,1 ; (CountN & 1) != 0? - jz ProcessPaddedMatrixBDataK1 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -ProcessPaddedMatrixBDataK1: - vpmovzxbw ymm4,XMMWORD PTR GemmU8U8CopyPackBFrame.PaddedMatrixBData[rsp] - vmovdqu YMMWORD PTR [rcx],ymm4 ; store interleaved rows - vextracti128 xmm3,ymm4,1 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - vpmaddwd ymm4,ymm4,ymm5 ; horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 ; accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - -StoreColumnSumBufferNUnaligned: - vmovdqu YMMWORD PTR [r11],ymm0 - vmovdqu YMMWORD PTR [r11+32],ymm1 - jmp ExitRoutine - - NESTED_END MlasGemmU8U8CopyPackBAvx2, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm deleted file mode 100644 index 1705a15fa4dc7..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm +++ /dev/null @@ -1,1014 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8X8KernelAvx2.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX2 and AVX VNNI instructions. -; AVX-VNNI-INT8 support also included. 
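Reviewer note on the QgemmU8X8 kernel module whose deletion starts above: its inner loops are built around VNNI-style byte dot products (VpdpbusdYmmYmmYmm and relatives). A scalar model of one 32-bit lane is sketched below for reference; the function is hypothetical, signedness of each operand follows the instruction variant (busd = unsigned by signed, bssd = signed by signed, and so on), and the word-level saturation that the non-VNNI vpmaddubsw fallback can hit is ignored.

    #include <cstdint>

    // Scalar model of one dword lane of a VNNI byte dot product: four byte
    // products summed and accumulated into the destination without saturation.
    int32_t DotProductLaneU8S8(int32_t Acc, const uint8_t* A4, const int8_t* B4) {
        for (int i = 0; i < 4; i++) {
            Acc += static_cast<int32_t>(A4[i]) * static_cast<int32_t>(B4[i]);
        }
        return Acc;
    }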
-; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE AssembleAvxVnni.inc - .list - - EXTERN MlasMaskMoveTableAvx:NEAR - -; -; Stack frame layout for the Int8 kernel. -; - -GemmInt8KernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - ldc QWORD ? - RowSumBuffer QWORD ? - ColumnSumBuffer QWORD ? - ZeroPointB QWORD ? - ZeroMode QWORD ? - -GemmInt8KernelFrame ENDS - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single row of the -; output block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 16). -; -; Vec2Reg - Supplies the low block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from matrix B. -; -; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 16). -; -; ymm2 - Supplies the broadcast value loaded from matrix A. -; -; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. -; - -MultiplyAccumulateRowU8S8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg - - vpmaddubsw ymm3,ymm2,ymm0 - vpmaddwd ymm3,ymm3,ymm12 -IF ColumnCount EQ 16 - vpaddd Vec1Reg,Vec1Reg,ymm3 - vpmaddubsw ymm2,ymm2,ymm1 - vpmaddwd ymm2,ymm2,ymm12 - vpaddd Vec2Reg,Vec2Reg,ymm2 -ELSE - vpaddd Vec2Reg,Vec2Reg,ymm3 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm11 - Supplies the block accumulators. -; -; ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. -; - -ComputeBlockAvx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset, ASigned, BSigned - -IF RowCount EQ 1 - vpbroadcastd ymm2,DWORD PTR [rcx+BroadcastOffset] - vpmaddubsw ymm3,ymm2,YMMWORD PTR [rdx+VectorOffset] - vpmaddwd ymm3,ymm3,ymm12 -IF ColumnCount EQ 16 - vpaddd ymm4,ymm4,ymm3 - vpmaddubsw ymm2,ymm2,YMMWORD PTR [rdx+VectorOffset+32] - vpmaddwd ymm2,ymm2,ymm12 - vpaddd ymm5,ymm5,ymm2 -ELSE - vpaddd ymm5,ymm5,ymm3 -ENDIF -ELSE - vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single row of the -; output block. 
-; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 16). -; -; Vec2Reg - Supplies the low block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from matrix B. -; -; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 16). -; -; ymm2 - Supplies the broadcast value loaded from matrix A. -; - -MultiplyAccumulateRowAvxVnni MACRO ColumnCount, Vec1Reg, Vec2Reg, ASigned, BSigned - -IF ASigned EQ 1 - IF BSigned EQ 1 - IF ColumnCount EQ 16 - VpdpbssdYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbssdYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbssdYmmYmmYmm Vec2Reg,ymm2,ymm0 - ENDIF - ELSE - IF ColumnCount EQ 16 - VpdpbsudYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbsudYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbsudYmmYmmYmm Vec2Reg,ymm2,ymm0 - ENDIF - ENDIF -ELSE - IF BSigned EQ 1 - IF ColumnCount EQ 16 - VpdpbusdYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbusdYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbusdYmmYmmYmm Vec2Reg,ymm2,ymm0 - ENDIF - ELSE - IF ColumnCount EQ 16 - VpdpbuudYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbuudYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbuudYmmYmmYmm Vec2Reg,ymm2,ymm0 - ENDIF - ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockAvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset, ASigned, BSigned - - vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple times -; and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm11 - Supplies the block accumulators. 
-; - -ComputeBlockLoop MACRO Isa, ColumnCount, RowCount, ASigned, BSigned - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL ComputeBlockLoopExit - - mov rsi,r9 ; reload row length remaining - -IF (ColumnCount EQ 16) AND (RowCount EQ 1) - sub rsi,4*4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - ComputeBlock&Isa& ColumnCount, RowCount, 0*64, 0, ASigned, BSigned - ComputeBlock&Isa& ColumnCount, RowCount, 1*64, 4, ASigned, BSigned - ComputeBlock&Isa& ColumnCount, RowCount, 2*64, 8, ASigned, BSigned - ComputeBlock&Isa& ColumnCount, RowCount, 3*64, 12, ASigned, BSigned - add rcx,4*4 ; advance matrix A by 4 quads - add rdx,4*64 ; advance matrix B - sub rsi,4*4 - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add rsi,4*4 ; correct for over-subtract above - jz ComputeBlockLoopExit -ENDIF - -ComputeBlockBy1Loop: - ComputeBlock&Isa& ColumnCount, RowCount, 0, 0, ASigned, BSigned - add rcx,4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,64 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop - -ComputeBlockLoopExit: - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single row of the -; output block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; Vec1Reg - Supplies the high block accumulator register (when ColumnCount -; is 16). -; -; Vec2Reg - Supplies the low block accumulator register. -; -; Implicit Arguments: -; -; ymm0 - Supplies the first vector loaded from matrix B. -; -; ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount -; is 16). -; -; ymm2 - Supplies the broadcast value loaded from matrix A. -; - -MultiplyAccumulateRowU8U8Avx2 MACRO ColumnCount, Vec1Reg, Vec2Reg - - vpmaddwd ymm3,ymm2,ymm0 -IF ColumnCount EQ 16 - vpaddd Vec1Reg,Vec1Reg,ymm3 - vpmaddwd ymm2,ymm2,ymm1 - vpaddd Vec2Reg,Vec2Reg,ymm2 -ELSE - vpaddd Vec2Reg,Vec2Reg,ymm3 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockU8U8Avx2 MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset - - vpmovzxbw ymm0,XMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple times -; and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. 
-; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount - - LOCAL ComputeBlockBy2Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL ExitComputeBlockLoop - - mov rsi,r9 ; reload row length remaining - -IF (ColumnCount EQ 16) AND ((RowCount AND 1) EQ 0) - sub rsi,2*4 - jb ProcessRemainingBlocks - -ComputeBlockBy2Loop: - ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 - ComputeBlockU8U8&Isa& ColumnCount, RowCount, 32, 4 - add rcx,2*4 ; advance matrix A by 2 pairs -IF RowCount GT 3 - add rbx,2*4 ; advance matrix A plus 3 rows by 2 pairs -ENDIF - add rdx,2*32 ; advance matrix B - sub rsi,2*4 - jae ComputeBlockBy2Loop - -ProcessRemainingBlocks: - add rsi,2*4 ; correct for over-subtract above - jz ExitComputeBlockLoop - ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 - add rdx,32 ; advance matrix B -ELSE -ComputeBlockBy1Loop: - ComputeBlockU8U8&Isa& ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 pair -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 pair -ENDIF - add rdx,32 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop -ENDIF - -ExitComputeBlockLoop: - - ENDM - -; -; Macro Description: -; -; This macro generates code to produce an output block for a set of columns -; and rows. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r11 - Supplies the address of the row sum buffer. -; -; r12 - Supplies the address of the column sum buffer. -; -; r13 - Optionally supplies the address of the matrix B zero point buffer. -; -; ymm4-ymm15 - Supplies the block accumulators. -; - -ProduceOutputBlock MACRO ColumnCount, RowCount, ASigned, BSigned - - LOCAL SkipScaleByZeroPointB - LOCAL AccumulatorsInitialized - LOCAL ProduceWithInt8AvxVnni - LOCAL ProduceWithU8U8Avx2 - LOCAL ExitProduceOutputBlock - -; -; Initialize the accumulators with the row and column sums. -; - - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, -IF ColumnCount EQ 16 - vmovdqu ymm0,YMMWORD PTR [r12] - vmovdqu ymm1,YMMWORD PTR [r12+32] - add r12,16*4 ; advance ColumnSumBuffer by 16 columns -ELSE - vmovdqu ymm1,YMMWORD PTR [r12] -ENDIF - test r13,r13 ; per column zero points? 
- jz SkipScaleByZeroPointB -IF ColumnCount EQ 16 - vmovdqu ymm2,YMMWORD PTR [r13] - vmovdqu ymm3,YMMWORD PTR [r13+32] - add r13,16*4 ; advance ZeroPointB by 16 columns -ELSE - vmovdqu ymm3,YMMWORD PTR [r13] -ENDIF - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - jmp AccumulatorsInitialized - -SkipScaleByZeroPointB: - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - -AccumulatorsInitialized: - -; -; Iterate over the length of a matrix A row to produce the output accumulators. -; - -IF RowCount GT 3 - lea rbx,[r9*2+r9] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - cmp DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],0 - jg ProduceWithU8U8Avx2 -IF RowCount LE 4 - jl ProduceWithInt8AvxVnni - ComputeBlockLoop Avx2, ColumnCount, RowCount, ASigned, BSigned - jmp ExitProduceOutputBlock -ENDIF - -ProduceWithInt8AvxVnni: - ComputeBlockLoop AvxVnni, ColumnCount, RowCount, ASigned, BSigned - jmp ExitProduceOutputBlock - -ProduceWithU8U8Avx2: - ComputeBlockLoopU8U8 Avx2, ColumnCount, RowCount - -ExitProduceOutputBlock: -IF RowCount GT 3 - lea rbx,[rax*2+rax] - add rbx,r8 ; compute matrix C plus 3 rows -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; r8 - Supplies the address of matrix C. -; -; rdi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r10b - Supplies the zero mode flag. -; -; r11 - Supplies the address of the row sum buffer. -; -; r12 - Supplies the address of the column sum buffer. -; -; r13 - Optionally supplies the address of the matrix B zero point buffer. 
-; - -ProcessCountM MACRO RowCount, ASigned, BSigned, Fallthrough - - LOCAL ProcessNextColumnLoop16xN - LOCAL SkipAccumulateOutput16xNBlock - LOCAL OutputMasked16xNBlock - LOCAL ExitProcessCountM - LOCAL ProcessRemainingCountN - LOCAL SkipAccumulateOutput8xNBlock - LOCAL SkipAccumulateOutputMasked16xNBlock - LOCAL OutputMasked8xNBlock - LOCAL SkipAccumulateOutputMasked8xNBlock - - cmp rbp,8 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop16xN: - ProduceOutputBlock 16, RowCount, ASigned, BSigned - sub rbp,16 - jb OutputMasked16xNBlock - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput16xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput16xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rdi ; reload matrix A - cmp rbp,8 - ja ProcessNextColumnLoop16xN - test rbp,rbp - jnz ProcessRemainingCountN - -ExitProcessCountM: - mov eax,RowCount - jmp ExitKernel - -ProcessRemainingCountN: - ProduceOutputBlock 8, RowCount, ASigned, BSigned - cmp rbp,8 - jb OutputMasked8xNBlock - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput8xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput8xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp ExitProcessCountM - -OutputMasked16xNBlock: - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutputMasked16xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutputMasked16xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,8*4 ; advance matrix C by 8 columns -IF RowCount GT 3 - add rbx,8*4 ; advance matrix C plus 3 rows by 8 columns -ENDIF - add rbp,8 ; correct for over-subtract above - -OutputMasked8xNBlock: - neg rbp - lea rcx,MlasMaskMoveTableAvx+8*4 - vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] - test r10b,r10b ; ZeroMode? 
- jnz SkipAccumulateOutputMasked8xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutputMasked8xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - jmp ExitProcessCountM - - ENDM - - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmCopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmCopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PackedCountK (r9) - Supplies the number of packed columns from matrix A and -; the number of packed rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumBuffer - Supplies the sum of each row from matrix A. These values have -; been pre-scaled by the zero point offset of matrix B if the offset is -; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be -; scaled by the per-column zero point offsets of matrix B. These values are -; accumulated into every row of matrix C. -; -; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. These values are accumulated into -; every column of matrix C. -; -; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix -; B, else nullptr if the matrix B is using per-tensor quantization. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. 
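The argument block above fully specifies how the output tile is assembled from the packed products, the row and column sums, and the optional per-column zero points. A scalar C++ model of that contract may make the interaction easier to follow; the helper name and the Products layout are illustrative, not the kernel itself.

```cpp
#include <cstdint>
#include <cstddef>

// Scalar model of the documented contract: ColumnSumBuffer is added to every
// column, RowSumBuffer is added to every row (scaled per column by ZeroPointB
// when it is supplied), and ZeroMode selects store versus accumulate into C.
void InitAndStoreTile(int32_t* C, size_t ldc,
                      const int32_t* RowSum, const int32_t* ColumnSum,
                      const int32_t* ZeroPointB,   // nullptr => per-tensor offset
                      const int32_t* Products,     // Products[m*CountN+n] = sum_k A[m,k]*B[k,n]
                      size_t CountM, size_t CountN, bool ZeroMode)
{
    for (size_t m = 0; m < CountM; ++m) {
        for (size_t n = 0; n < CountN; ++n) {
            int32_t acc = ColumnSum[n];
            acc += ZeroPointB ? RowSum[m] * ZeroPointB[n] : RowSum[m];
            acc += Products[m * CountN + n];
            C[m * ldc + n] = ZeroMode ? acc : C[m * ldc + n] + acc;
        }
    }
}
```

In the listing, the same initialization is performed vector-wide in ProduceOutputBlock before the compute-block loop runs, and the ZeroMode test is deferred until the final stores.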
-; -;-- - -MlasGemmInt8KernelAvx2 MACRO ASigned, BSigned - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - alloc_stack (GemmInt8KernelFrame.SavedR13) - save_xmm128 xmm6,GemmInt8KernelFrame.SavedXmm6 - save_xmm128 xmm7,GemmInt8KernelFrame.SavedXmm7 - save_xmm128 xmm8,GemmInt8KernelFrame.SavedXmm8 - save_xmm128 xmm9,GemmInt8KernelFrame.SavedXmm9 - save_xmm128 xmm10,GemmInt8KernelFrame.SavedXmm10 - save_xmm128 xmm11,GemmInt8KernelFrame.SavedXmm11 - save_xmm128 xmm12,GemmInt8KernelFrame.SavedXmm12 - save_xmm128 xmm13,GemmInt8KernelFrame.SavedXmm13 - save_xmm128 xmm14,GemmInt8KernelFrame.SavedXmm14 - save_xmm128 xmm15,GemmInt8KernelFrame.SavedXmm15 - - END_PROLOGUE - - mov DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],eax - mov rdi,rcx - mov rbx,GemmInt8KernelFrame.CountM[rsp] - mov rbp,GemmInt8KernelFrame.CountN[rsp] - mov rax,GemmInt8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - shl r9,2 ; convert to row length - movzx r10,BYTE PTR GemmInt8KernelFrame.ZeroMode[rsp] - mov r11,GemmInt8KernelFrame.RowSumBuffer[rsp] - mov r12,GemmInt8KernelFrame.ColumnSumBuffer[rsp] - mov r13,GemmInt8KernelFrame.ZeroPointB[rsp] - vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] - vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] - cmp DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],0 - je CheckCountM4OrMore ; U8S8 AVX2 kernel requires extra registers - -; -; Process CountM rows of the matrices. -; - -CheckCountM6OrMore: - cmp rbx,5 - ja ProcessCountM6 - je ProcessCountM5 - -CheckCountM4OrMore: - cmp rbx,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp rbx,1 - je ProcessCountM1 - -ProcessCountM2: - ProcessCountM 2, ASigned, BSigned - -ProcessCountM4: - ProcessCountM 4, ASigned, BSigned - -ProcessCountM6: - ProcessCountM 6, ASigned, BSigned - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - movaps xmm6,GemmInt8KernelFrame.SavedXmm6[rsp] - movaps xmm7,GemmInt8KernelFrame.SavedXmm7[rsp] - movaps xmm8,GemmInt8KernelFrame.SavedXmm8[rsp] - movaps xmm9,GemmInt8KernelFrame.SavedXmm9[rsp] - movaps xmm10,GemmInt8KernelFrame.SavedXmm10[rsp] - movaps xmm11,GemmInt8KernelFrame.SavedXmm11[rsp] - movaps xmm12,GemmInt8KernelFrame.SavedXmm12[rsp] - movaps xmm13,GemmInt8KernelFrame.SavedXmm13[rsp] - movaps xmm14,GemmInt8KernelFrame.SavedXmm14[rsp] - movaps xmm15,GemmInt8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmInt8KernelFrame.SavedR13) - - BEGIN_EPILOGUE - - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - -ProcessCountM1: - ProcessCountM 1, ASigned, BSigned - -ProcessCountM3: - ProcessCountM 3, ASigned, BSigned - -ProcessCountM5: - ProcessCountM 5, ASigned, BSigned - - ENDM - -; -; Reduce code size for the various types of kernels by sharing the outer logic -; and switching on the selector codes (using sign bit to discriminate). 
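The selector codes referenced here are loaded into eax by the entry stubs that follow and saved into PreviousP1Home; the shared body then tests the saved value's sign to choose a compute path. A hedged C++ rendering of that convention is below; the enum and function are illustrative, and the mapping is inferred from the mov eax / xor eax instructions in this listing.

```cpp
// Selector convention inferred from the entry points: negative selects the
// VNNI compute loop, zero the U8S8 AVX2 loop, positive the U8U8 AVX2 loop.
enum class KernelPath { AvxVnni, U8S8Avx2, U8U8Avx2 };

inline KernelPath SelectPath(int selector)
{
    if (selector < 0) return KernelPath::AvxVnni;   // mov eax,-1 at the *Vnni entries
    if (selector > 0) return KernelPath::U8U8Avx2;  // mov eax,1 at the U8U8 AVX2 entry
    return KernelPath::U8S8Avx2;                    // xor eax,eax at the U8S8 AVX2 entry
}
```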
-; - - NESTED_ENTRY MlasGemmU8S8KernelAvxVnni, _TEXT - - mov eax,-1 - MlasGemmInt8KernelAvx2 0, 1 - - NESTED_END MlasGemmU8S8KernelAvxVnni, _TEXT - - NESTED_ENTRY MlasGemmU8U8KernelAvx2Vnni, _TEXT - - mov eax,-1 - MlasGemmInt8KernelAvx2 0, 0 - - NESTED_END MlasGemmU8U8KernelAvx2Vnni, _TEXT - - NESTED_ENTRY MlasGemmU8U8KernelAvx2, _TEXT - - mov eax,1 - MlasGemmInt8KernelAvx2 0, 0 - - NESTED_END MlasGemmU8U8KernelAvx2, _TEXT - - NESTED_ENTRY MlasGemmU8S8KernelAvx2, _TEXT - - xor eax,eax - MlasGemmInt8KernelAvx2 0, 1 - - NESTED_END MlasGemmU8S8KernelAvx2, _TEXT - - NESTED_ENTRY MlasGemmS8S8KernelAvx2Vnni, _TEXT - - mov eax,-1 - MlasGemmInt8KernelAvx2 1, 1 - - NESTED_END MlasGemmS8S8KernelAvx2Vnni, _TEXT - - NESTED_ENTRY MlasGemmS8U8KernelAvx2Vnni, _TEXT - - mov eax,-1 - MlasGemmInt8KernelAvx2 1, 0 - - NESTED_END MlasGemmS8U8KernelAvx2Vnni, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm deleted file mode 100644 index 9606c3ddb006f..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx512Core.asm +++ /dev/null @@ -1,764 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemmU8X8KernelAvx512Core.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/matrix -; multiply operation (QGEMM). -; -; This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE AssembleAvx512Vnni.inc - .list - -; -; Stack frame layout for the U8X8 kernel. -; - -GemmU8X8KernelFrame STRUCT - - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - SavedR14 QWORD ? - SavedR13 QWORD ? - SavedR12 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountM QWORD ? - CountN QWORD ? - ldc QWORD ? - RowSumBuffer QWORD ? - ColumnSumBuffer QWORD ? - ZeroPointB QWORD ? - ZeroMode QWORD ? - -GemmU8X8KernelFrame ENDS - -; -; Macro Description: -; -; This macro generates code to load packed data from matrix B. -; -; Arguments: -; -; VecReg - Supplies the register to load the data into. -; -; AddressOperand - Supplies the address operand. -; - -LoadPackedMatrixBU8S8 MACRO VecReg, AddressOperand - - vmovdqu32 VecReg,ZMMWORD PTR AddressOperand - - ENDM - -LoadPackedMatrixBU8U8 MACRO VecReg, AddressOperand - - vpmovzxbw VecReg,YMMWORD PTR AddressOperand - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulator a single cell of the -; output block. -; -; Arguments: -; -; AccumReg - Supplies the register to accumulate into. -; -; Mult1Reg - Supplies the first multiplication operand register. -; -; Mult2Reg - Supplies the second multiplication operand register. -; -; Implicit Arguments: -; -; zmm4 - Supplies a scratch register for intermediate results. -; -; zmm13 - Supplies a 512-bit with the broadcasted word value 0x0001. 
-; - -MultiplyAccumulateCellU8S8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg - - vpmaddubsw zmm4,Mult1Reg,Mult2Reg - vpmaddwd zmm4,zmm4,zmm13 - vpaddd AccumReg,AccumReg,zmm4 - - ENDM - -MultiplyAccumulateCellU8S8Avx512Vnni MACRO AccumReg, Mult1Reg, Mult2Reg - - VpdpbusdsZmmZmmZmm AccumReg,Mult1Reg,Mult2Reg - - ENDM - -MultiplyAccumulateCellU8U8Avx512Core MACRO AccumReg, Mult1Reg, Mult2Reg - - vpmaddwd zmm4,Mult1Reg,Mult2Reg - vpaddd AccumReg,AccumReg,zmm4 - - ENDM - -; -; Macro Description: -; -; This macro generates code to multiply and accumulate each row of the output -; block. -; -; Arguments: -; -; Type - Supplies the type of kernel to generate (U8S8 or U8U8). -; -; Isa - Supplies the instruction set architecture string. -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm13 - Supplies a 512-bit with the broadcasted word value 0x0001. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlock MACRO Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset - -IF ColumnCount GE 48 - LoadPackedMatrixB&Type& zmm0,[rdx+VectorOffset] - LoadPackedMatrixB&Type& zmm1,[rdx+r14+VectorOffset] - LoadPackedMatrixB&Type& zmm2,[rdx+r14*2+VectorOffset] -ELSEIF ColumnCount GE 32 - LoadPackedMatrixB&Type& zmm1,[rdx+VectorOffset] - LoadPackedMatrixB&Type& zmm2,[rdx+r14+VectorOffset] -ELSE - LoadPackedMatrixB&Type& zmm2,[rdx+VectorOffset] -ENDIF - EmitIfCountGE RowCount, 1, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCountGE RowCount, 2, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCountGE RowCount, 3, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCountGE RowCount, 4, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCountGE RowCount, 5, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCountGE RowCount, 6, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple times -; and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string. -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rbx - Supplies the address into the matrix A data plus 3 rows. -; -; rcx - Supplies the address into the matrix A data. 
-; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; -; zmm14-zmm31 - Supplies the block accumulators. -; - -ComputeBlockLoopU8S8 MACRO Isa, ColumnCount, RowCount - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL ComputeBlockLoopExit - - mov rsi,r9 ; reload row length remaining - -IF (RowCount EQ 1) OR ((RowCount AND 1) EQ 0) - sub rsi,4*4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0*64, 0 - ComputeBlock U8S8, Isa, ColumnCount, RowCount, 1*64, 4 - ComputeBlock U8S8, Isa, ColumnCount, RowCount, 2*64, 8 - ComputeBlock U8S8, Isa, ColumnCount, RowCount, 3*64, 12 - add rcx,4*4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4*4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,4*64 ; advance matrix B - sub rsi,4*4 ; decrement quads remaining - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add rsi,4*4 ; correct for over-subtract above - jz ComputeBlockLoopExit -ENDIF - -ComputeBlockBy1Loop: - ComputeBlock U8S8, Isa, ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 quad -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 quad -ENDIF - add rdx,64 ; advance matrix B - sub rsi,4 ; decrement quads remaining - jnz ComputeBlockBy1Loop - -ComputeBlockLoopExit: - - ENDM - -ComputeBlockLoopU8U8 MACRO Isa, ColumnCount, RowCount - - LOCAL ComputeBlockBy1Loop - - mov rsi,r9 ; reload row length remaining - -ComputeBlockBy1Loop: - ComputeBlock U8U8, Isa, ColumnCount, RowCount, 0, 0 - add rcx,4 ; advance matrix A by 1 pair -IF RowCount GT 3 - add rbx,4 ; advance matrix A plus 3 rows by 1 pair -ENDIF - add rdx,32 ; advance matrix B - sub rsi,4 - jnz ComputeBlockBy1Loop - - ENDM - -; -; Macro Description: -; -; This macro generates code to produce an output block for a set of columns -; and rows. -; -; Arguments: -; -; ColumnCount - Supplies the number of columns to produce. -; -; RowCount - Supplies the number of rows to produce. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address into the matrix A data. -; -; rdx - Supplies the address into the matrix B data. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r11 - Supplies the address of the row sum buffer. -; -; r12 - Supplies the address of the column sum buffer. -; - -ProduceOutputBlock MACRO ColumnCount, RowCount - - LOCAL SkipScaleByZeroPointB - LOCAL AccumulatorsInitialized - LOCAL ProduceWithU8S8Avx512Core - LOCAL ProduceWithU8U8Avx512Core - LOCAL ExitProduceOutputBlock - -; -; Initialize the accumulators with the row and column sums. -; - -IF ColumnCount GE 32 -IF ColumnCount GE 48 - vmovdqu32 zmm2,ZMMWORD PTR [r12] - vmovdqu32 zmm1,ZMMWORD PTR [r12+64] - vmovdqu32 zmm0,ZMMWORD PTR [r12+128] -ELSE - vmovdqu32 zmm1,ZMMWORD PTR [r12] - vmovdqu32 zmm0,ZMMWORD PTR [r12+64] -ENDIF - add_immed r12,ColumnCount*4 ; advance ColumnSumBuffer by N columns -ELSE - vmovdqu32 zmm0,ZMMWORD PTR [r12] -ENDIF - test r13,r13 ; per column zero points? 
- jz SkipScaleByZeroPointB -IF ColumnCount GE 32 -IF ColumnCount GE 48 - vmovdqu32 zmm5,ZMMWORD PTR [r13] - vmovdqu32 zmm4,ZMMWORD PTR [r13+64] - vmovdqu32 zmm3,ZMMWORD PTR [r13+128] -ELSE - vmovdqu32 zmm4,ZMMWORD PTR [r13] - vmovdqu32 zmm3,ZMMWORD PTR [r13+64] -ENDIF - add_immed r13,ColumnCount*4 ; advance ZeroPointB by N columns -ELSE - vmovdqu32 zmm3,ZMMWORD PTR [r13] -ENDIF - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - jmp AccumulatorsInitialized - -SkipScaleByZeroPointB: - EmitIfCount2GE RowCount, 1, ColumnCount, 16, - EmitIfCount2GE RowCount, 1, ColumnCount, 32, - EmitIfCount2GE RowCount, 1, ColumnCount, 48, - EmitIfCount2GE RowCount, 2, ColumnCount, 16, - EmitIfCount2GE RowCount, 2, ColumnCount, 32, - EmitIfCount2GE RowCount, 2, ColumnCount, 48, - EmitIfCount2GE RowCount, 3, ColumnCount, 16, - EmitIfCount2GE RowCount, 3, ColumnCount, 32, - EmitIfCount2GE RowCount, 3, ColumnCount, 48, - EmitIfCount2GE RowCount, 4, ColumnCount, 16, - EmitIfCount2GE RowCount, 4, ColumnCount, 32, - EmitIfCount2GE RowCount, 4, ColumnCount, 48, - EmitIfCount2GE RowCount, 5, ColumnCount, 16, - EmitIfCount2GE RowCount, 5, ColumnCount, 32, - EmitIfCount2GE RowCount, 5, ColumnCount, 48, - EmitIfCount2GE RowCount, 6, ColumnCount, 16, - EmitIfCount2GE RowCount, 6, ColumnCount, 32, - EmitIfCount2GE RowCount, 6, ColumnCount, 48, - -AccumulatorsInitialized: - -; -; Iterate over the length of a matrix A row to produce the output accumulators. 
-; - -IF RowCount GT 3 - lea rbx,[r9*2+r9] - add rbx,rcx ; compute matrix A plus 3 rows -ENDIF - cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 - je ProduceWithU8S8Avx512Core - jg ProduceWithU8U8Avx512Core - ComputeBlockLoopU8S8 Avx512Vnni, ColumnCount, RowCount - jmp ExitProduceOutputBlock - -ProduceWithU8U8Avx512Core: - ComputeBlockLoopU8U8 Avx512Core, ColumnCount, RowCount - jmp ExitProduceOutputBlock - -ProduceWithU8S8Avx512Core: - ComputeBlockLoopU8S8 Avx512Core, ColumnCount, RowCount - -ExitProduceOutputBlock: -IF RowCount GT 3 - lea rbx,[rax*2+rax] - add rbx,r8 ; compute matrix C plus 3 rows -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; r8 - Supplies the address of matrix C. -; -; rdi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r9 - Supplies the length in bytes of a row from matrix A. -; -; r10b - Supplies the zero mode flag. -; -; r11 - Supplies the address of the row sum buffer. -; -; r12 - Supplies the address of the column sum buffer. -; -; r14 - Supplies the stride in bytes of between packed blocks of matrix B. -; - -ProcessCountM MACRO RowCount - - LOCAL ProcessNextColumnLoop32xN - LOCAL Output32xNBlock - LOCAL SkipAccumulateOutput32xNBlock - LOCAL Output16xNBlock - LOCAL Output16xNBlockWithMask - LOCAL SkipAccumulateOutput16xNBlockWithMask - LOCAL ProcessRemainingCountN - LOCAL ProcessNextColumnLoop48xN - LOCAL SkipAccumulateOutput48xNBlock - - cmp rbp,32 - ja ProcessNextColumnLoop48xN - cmp rbp,16 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop32xN: - ProduceOutputBlock 32, RowCount - add rdx,r14 ; advance matrix B by packed block stride - -Output32xNBlock: - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput32xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput32xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns -IF RowCount GT 3 - add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns -ENDIF - sub rbp,16 - -Output16xNBlock: - sub rbp,16 - jae Output16xNBlockWithMask - lea ecx,[ebp+16] ; correct for over-subtract above - mov esi,1 - shl esi,cl - dec esi - kmovw k1,esi ; update mask for remaining columns - xor ebp,ebp ; no more columns remaining - -Output16xNBlockWithMask: - test r10b,r10b ; ZeroMode? 
- jnz SkipAccumulateOutput16xNBlockWithMask - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput16xNBlockWithMask: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rdi ; reload matrix A - cmp rbp,32 - ja ProcessNextColumnLoop48xN - cmp rbp,16 - ja ProcessNextColumnLoop32xN - test rbp,rbp - jnz ProcessRemainingCountN - mov eax,RowCount - jmp ExitKernel - -ProcessRemainingCountN: - ProduceOutputBlock 16, RowCount - jmp Output16xNBlock - -ProcessNextColumnLoop48xN: - ProduceOutputBlock 48, RowCount - lea rdx,[rdx+r14*2] ; advance matrix B by packed block stride - test r10b,r10b ; ZeroMode? - jnz SkipAccumulateOutput48xNBlock - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - -SkipAccumulateOutput48xNBlock: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 6, - add r8,16*4 ; advance matrix C by 16 columns -IF RowCount GT 3 - add rbx,16*4 ; advance matrix C plus 3 rows by 16 columns -ENDIF - sub rbp,16 - jmp Output32xNBlock - - ENDM - -; -; Reduce code size for the various types of kernels by sharing the outer logic -; and switching on the selector codes (using sign bit to discriminate). -; - - LEAF_ENTRY MlasGemmU8S8KernelAvx512Vnni, _TEXT - - mov eax,-1 - jmp MlasGemmU8X8KernelAvx512Core - - LEAF_END MlasGemmU8S8KernelAvx512Vnni, _TEXT - - LEAF_ENTRY MlasGemmU8U8KernelAvx512Core, _TEXT - - mov eax,1 - jmp MlasGemmU8X8KernelAvx512Core - - LEAF_END MlasGemmU8U8KernelAvx512Core, _TEXT - - LEAF_ENTRY MlasGemmU8S8KernelAvx512Core, _TEXT - - xor eax,eax - jmp MlasGemmU8X8KernelAvx512Core - - LEAF_END MlasGemmU8S8KernelAvx512Core, _TEXT - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. The matrix data has been packed -; using MlasGemmU8X8CopyPackAAvx2. -; -; B (rdx) - Supplies the address of matrix B. The matrix data has been packed -; using MlasGemmU8X8CopyPackBAvx2. -; -; C (r8) - Supplies the address of matrix C. -; -; PackedCountK (r9) - Supplies the number of packed columns from matrix A and -; the number of packed rows from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldc - Supplies the first dimension of matrix C. -; -; RowSumBuffer - Supplies the sum of each row from matrix A. These values have -; been pre-scaled by the zero point offset of matrix B if the offset is -; per-tensor (ZeroPointB is nullptr). Otherwise, these values must be -; scaled by the per-column zero point offsets of matrix B. These values are -; accumulated into every row of matrix C. -; -; ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied -; by the zero point offset of matrix A. 
These values are accumulated into -; every column of matrix C. -; -; ZeroPointB - Optionally supplies the per-column zero point offsets of matrix -; B, else nullptr if the matrix B is using per-tensor quantization. -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - NESTED_ENTRY MlasGemmU8X8KernelAvx512Core, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r12 - push_reg r13 - push_reg r14 - alloc_stack (GemmU8X8KernelFrame.SavedR14) - save_xmm128 xmm13,GemmU8X8KernelFrame.SavedXmm13 - save_xmm128 xmm14,GemmU8X8KernelFrame.SavedXmm14 - save_xmm128 xmm15,GemmU8X8KernelFrame.SavedXmm15 - - END_PROLOGUE - - mov DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],eax - mov rdi,rcx - mov rbx,GemmU8X8KernelFrame.CountM[rsp] - mov rbp,GemmU8X8KernelFrame.CountN[rsp] - mov rax,GemmU8X8KernelFrame.ldc[rsp] - shl rax,2 ; convert ldc to bytes - shl r9,2 ; convert to row length - movzx r10,BYTE PTR GemmU8X8KernelFrame.ZeroMode[rsp] - mov r11,GemmU8X8KernelFrame.RowSumBuffer[rsp] - mov r12,GemmU8X8KernelFrame.ColumnSumBuffer[rsp] - mov r13,GemmU8X8KernelFrame.ZeroPointB[rsp] - mov esi,-1 - kmovw k1,esi ; update mask to write all columns - neg esi - vpbroadcastw zmm13,esi ; generate 512-bit word vector [0x0001] - lea rsi,[r9*8] ; compute matrix B packed stride (U8U8) - lea r14,[rsi*2] ; compute matrix B packed stride (U8S8) - cmp DWORD PTR GemmU8X8KernelFrame.PreviousP1Home[rsp],0 - cmovg r14,rsi ; select matrix B packed stride - -; -; Process CountM rows of the matrices. -; - - cmp rbx,5 - ja ProcessCountM6 - je ProcessCountM5 - cmp rbx,3 - ja ProcessCountM4 - je ProcessCountM3 - cmp rbx,1 - ja ProcessCountM2 - -ProcessCountM1: - ProcessCountM 1 - -ProcessCountM2: - ProcessCountM 2 - -ProcessCountM3: - ProcessCountM 3 - -ProcessCountM4: - ProcessCountM 4 - -ProcessCountM5: - ProcessCountM 5 - -ProcessCountM6: - ProcessCountM 6 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - movaps xmm13,GemmU8X8KernelFrame.SavedXmm13[rsp] - movaps xmm14,GemmU8X8KernelFrame.SavedXmm14[rsp] - movaps xmm15,GemmU8X8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmU8X8KernelFrame.SavedR14) - - BEGIN_EPILOGUE - - pop r14 - pop r13 - pop r12 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasGemmU8X8KernelAvx512Core, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm deleted file mode 100644 index 3a6f83d2c2ada..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx2.asm +++ /dev/null @@ -1,374 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemvU8S8KernelAvx2.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/vector -; multiply operation (QGEMV). -; -; This implementation uses AVX2 instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - EXTERN MlasMaskMoveAvx:NEAR - EXTERN MlasTranspose4x4BytesAvx:NEAR - -; -; Stack frame layout for the U8S8 kernel. -; - -GemvU8S8KernelFrame STRUCT - - SavedXmm6 OWORD ? - Padding QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountN QWORD ? - ldb QWORD ? 
- -GemvU8S8KernelFrame ENDS - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix/vector multiplication. -; -; Arguments: -; -; A (rcx) - Supplies the address of vector A. -; -; B (rdx) - Supplies the address of matrix B. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from vector A and the number -; of rows from matrix B to iterate over. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldb - Supplies the first dimension of matrix B. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasGemvU8S8KernelAvx2, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - alloc_stack (GemvU8S8KernelFrame.SavedRdi) - save_xmm128 xmm6,GemvU8S8KernelFrame.SavedXmm6 - - END_PROLOGUE - - mov rsi,rdx - mov rdi,GemvU8S8KernelFrame.ldb[rsp] - mov r10,GemvU8S8KernelFrame.CountN[rsp] - mov r11,rsp ; set ZeroMode to any non-zero value - vpcmpeqw ymm6,ymm6,ymm6 ; generate word vector [0xFFFF] - vpsrlw ymm6,ymm6,15 ; generate word vector [0x0001] - -; -; Process 4 rows of matrix B in a loop. -; - - sub r9,4 - jb ProcessRemainingRows - -ProcessRowLoop4: - mov rdx,rsi ; reload matrix B - lea rsi,[rsi+rdi*4] ; advance matrix B by 4 rows - mov rbx,r8 ; reload matrix C - mov rbp,r10 ; reload CountN - vpbroadcastd ymm0,DWORD PTR [rcx] - add rcx,4 ; advance matrix A by 4 bytes - -; -; Process sets of 32 columns from the 4 rows in a loop. -; -; Some permute operations are deferred until the final store of the 4x32 block -; as these permutes are expensive. -; - -ProcessColumnLoop4By32: - cmp rbp,32 - jb ProcessColumnLoop4By8 - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - vmovdqu ymm2,YMMWORD PTR [rdx] - vmovdqu ymm3,YMMWORD PTR [rdx+rdi] - vmovdqu ymm4,YMMWORD PTR [rax] - vmovdqu ymm5,YMMWORD PTR [rax+rdi] - vpunpcklbw ymm1,ymm2,ymm3 ; interleave row data bytes - vpunpckhbw ymm2,ymm2,ymm3 - vpunpcklbw ymm3,ymm4,ymm5 - vpunpckhbw ymm4,ymm4,ymm5 - vpunpcklwd ymm5,ymm1,ymm3 ; interleave row data words - vpunpckhwd ymm1,ymm1,ymm3 - vpunpcklwd ymm3,ymm2,ymm4 - vpunpckhwd ymm2,ymm2,ymm4 - vpmaddubsw ymm5,ymm0,ymm5 ; multiply and reduce - vpmaddwd ymm5,ymm5,ymm6 - vpmaddubsw ymm1,ymm0,ymm1 - vpmaddwd ymm1,ymm1,ymm6 - vpmaddubsw ymm3,ymm0,ymm3 - vpmaddwd ymm3,ymm3,ymm6 - vpmaddubsw ymm2,ymm0,ymm2 - vpmaddwd ymm2,ymm2,ymm6 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By32 - vpaddd ymm5,ymm5,YMMWORD PTR [rbx] - vpaddd ymm1,ymm1,YMMWORD PTR [rbx+32] - vpaddd ymm3,ymm3,YMMWORD PTR [rbx+64] - vpaddd ymm2,ymm2,YMMWORD PTR [rbx+96] - -SkipAccumulateOutput4By32: - cmp r9,4 ; final 4x32 block? - jae StoreOutput4By32 - vperm2i128 ymm4,ymm5,ymm1,31h ; interleave vector results - vperm2i128 ymm5,ymm5,ymm1,20h - vperm2i128 ymm1,ymm3,ymm2,20h - vperm2i128 ymm2,ymm3,ymm2,31h - vmovaps ymm3,ymm4 - -StoreOutput4By32: - vmovdqu YMMWORD PTR [rbx],ymm5 - vmovdqu YMMWORD PTR [rbx+32],ymm1 - vmovdqu YMMWORD PTR [rbx+64],ymm3 - vmovdqu YMMWORD PTR [rbx+96],ymm2 - add rdx,32 ; advance matrix B by 32 bytes - add rbx,32*4 ; advance matrix C by 32 columns - sub rbp,32 ; decrement CountN - jnz ProcessColumnLoop4By32 - -AdvanceRowLoop4: - xor r11,r11 ; clear ZeroMode - sub r9,4 ; decrement CountK - jae ProcessRowLoop4 - -ProcessRemainingRows: - add r9,4 ; correct for over-subtract above - jnz ProcessRemainingSmallK - -; -; Restore non-volatile registers and return. 
-; - -ExitKernel: - vzeroupper - movaps xmm6,GemvU8S8KernelFrame.SavedXmm6[rsp] - add rsp,(GemvU8S8KernelFrame.SavedRdi) - - BEGIN_EPILOGUE - - pop rdi - pop rsi - pop rbx - pop rbp - ret - -; -; Process sets of 8 columns from the 4 rows in a loop. -; - -ProcessColumnLoop4By8: - cmp ebp,8 - jb ProcessColumn4By4 - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - vmovq xmm2,QWORD PTR [rdx] - vmovq xmm3,QWORD PTR [rdx+rdi] - vmovq xmm4,QWORD PTR [rax] - vmovq xmm5,QWORD PTR [rax+rdi] - vpunpcklbw xmm2,xmm2,xmm3 ; interleave row data bytes - vpunpcklbw xmm4,xmm4,xmm5 - vpunpcklwd xmm1,xmm2,xmm4 ; interleave row data words - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm1,ymm1,xmm2,1 ; concatenate vector - vpmaddubsw ymm1,ymm0,ymm1 ; multiply and reduce - vpmaddwd ymm1,ymm1,ymm6 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By8 - vpaddd ymm1,ymm1,YMMWORD PTR [rbx] - -SkipAccumulateOutput4By8: - vmovdqu YMMWORD PTR [rbx],ymm1 - add rdx,8 ; advance matrix B by 8 bytes - add rbx,8*4 ; advance matrix C by 8 columns - sub ebp,8 ; decrement CountN - jnz ProcessColumnLoop4By8 - jmp AdvanceRowLoop4 - -; -; Process a set of 4 columns from the 4 rows. -; - -ProcessColumn4By4: - test ebp,4 ; (CountN & 4) != 0? - jz ProcessColumn4BySmallN - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - vmovd xmm1,DWORD PTR [rdx] - vpinsrd xmm1,xmm1,DWORD PTR [rdx+rdi],1 - vpinsrd xmm1,xmm1,DWORD PTR [rax],2 - vpinsrd xmm1,xmm1,DWORD PTR [rax+rdi],3 - vpshufb xmm1,xmm1,XMMWORD PTR [MlasTranspose4x4BytesAvx] - vpmaddubsw xmm1,xmm0,xmm1 ; multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By4 - vpaddd xmm1,xmm1,XMMWORD PTR [rbx] - -SkipAccumulateOutput4By4: - vmovdqu XMMWORD PTR [rbx],xmm1 - and ebp,3 ; (CountN & 3) != 0? - jz AdvanceRowLoop4 - add rdx,4 ; advance matrix B by 4 bytes - add rbx,4*4 ; advance matrix C by 4 columns - -; -; Process the remaining 1 to 3 columns from the 4 rows. -; - -ProcessColumn4BySmallN: - mov DWORD PTR GemvU8S8KernelFrame.CountN[rsp],ebp - vbroadcastss xmm2,DWORD PTR GemvU8S8KernelFrame.CountN[rsp] - vpcmpgtd xmm2,xmm2,XMMWORD PTR [MlasMaskMoveAvx] - vpxor xmm1,xmm1,xmm1 - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - cmp ebp,2 ; (CountN & 2) != 0? - jb ProcessColumn4By1 - vpinsrw xmm1,xmm1,WORD PTR [rdx],0 - vpinsrw xmm1,xmm1,WORD PTR [rdx+rdi],2 - vpinsrw xmm1,xmm1,WORD PTR [rax],4 - vpinsrw xmm1,xmm1,WORD PTR [rax+rdi],6 - je ComputeOutput4BySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rdx+2],2 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+rdi+2],6 - vpinsrb xmm1,xmm1,BYTE PTR [rax+2],10 - vpinsrb xmm1,xmm1,BYTE PTR [rax+rdi+2],14 - jmp ComputeOutput4BySmallN - -ProcessColumn4By1: - vpinsrb xmm1,xmm1,BYTE PTR [rdx],0 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+rdi],4 - vpinsrb xmm1,xmm1,BYTE PTR [rax],8 - vpinsrb xmm1,xmm1,BYTE PTR [rax+rdi],12 - -ComputeOutput4BySmallN: - vpshufb xmm1,xmm1,XMMWORD PTR [MlasTranspose4x4BytesAvx] - vpmaddubsw xmm1,xmm0,xmm1 ; multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 ; ZeroMode? - jnz StoreOutput4BySmallN - vpmaskmovd xmm3,xmm2,XMMWORD PTR [rbx] - vpaddd xmm1,xmm1,xmm3 - -StoreOutput4BySmallN: - vpmaskmovd XMMWORD PTR [rbx],xmm2,xmm1 - jmp AdvanceRowLoop4 - -; -; Broadcast the remaining 1 to 3 values from vector A. 
-; - -ProcessRemainingSmallK: - vpxor xmm5,xmm5,xmm5 ; keep zero vector for vpinsrb/vpinsrw - cmp r9d,2 - jb LoadVectorASingleRemainingByte - vpinsrw xmm0,xmm5,WORD PTR [rcx],0 - je BroadcastVectorARemainingBytes - vpinsrb xmm0,xmm0,BYTE PTR [rcx+2],2 - jmp BroadcastVectorARemainingBytes - -LoadVectorASingleRemainingByte: - vpinsrb xmm0,xmm5,BYTE PTR [rcx],0 - -BroadcastVectorARemainingBytes: - vpshufd xmm0,xmm0,0 ; broadcast values - -; -; Process a set of 4 columns from the remaining rows. -; - -ProcessColumnLoopSmallKBy4: - cmp r10,4 - jb ProcessColumnLoopSmallKBySmallN - vmovd xmm1,DWORD PTR [rsi] - cmp r9d,2 - jb ComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+rdi],1 - je ComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+rdi*2],2 - -ComputeOutputSmallKBy4: - vpshufb xmm1,xmm1,XMMWORD PTR [MlasTranspose4x4BytesAvx] - vpmaddubsw xmm1,xmm0,xmm1 ; multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutputSmallKBy4 - vpaddd xmm1,xmm1,XMMWORD PTR [r8] - -SkipAccumulateOutputSmallKBy4: - vmovdqu XMMWORD PTR [r8],xmm1 - add rsi,4 ; advance matrix B by 4 bytes - add r8,4*4 ; advance matrix C by 4 columns - sub r10,4 ; decrement CountN - jnz ProcessColumnLoopSmallKBy4 - jmp ExitKernel - -; -; Process the remaining 1 to 3 columns from the remaining rows. -; -; Single step through each of the columns to keep code size small for the -; uncommon path (typically the row count is a multiple of 4). -; - -ProcessColumnLoopSmallKBySmallN: - vpinsrb xmm1,xmm5,BYTE PTR [rsi],0 - cmp r9d,2 - jb ComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+rdi],1 - je ComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+rdi*2],2 - -ComputeOutputSmallKBySmallN: - vpmaddubsw xmm1,xmm0,xmm1 ; multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutputSmallKBySmallN - vmovd xmm3,DWORD PTR [r8] - vpaddd xmm1,xmm1,xmm3 - -SkipAccumulateOutputSmallKBySmallN: - vmovd DWORD PTR [r8],xmm1 - inc rsi ; advance matrix B by 1 byte - add r8,4 ; advance matrix C by 1 column - dec r10 - jnz ProcessColumnLoopSmallKBySmallN - jmp ExitKernel - - NESTED_END MlasGemvU8S8KernelAvx2, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Common.inc b/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Common.inc deleted file mode 100644 index a97cad9d90180..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Common.inc +++ /dev/null @@ -1,372 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemvU8S8KernelAvx512Common.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the quantized -; integer matrix/vector multiply operation (QGEMV) for the AVX512 core and -; AVX512VNNI kernels. -; -;-- - -GemvU8S8KernelFrame STRUCT - - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountN QWORD ? - ldb QWORD ? - -GemvU8S8KernelFrame ENDS - -; -; Macro Description: -; -; This macro generates the common AVX512 code for the inner kernel to compute -; matrix/vector multiplication. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string for function tags. 
-; - -GemvU8S8KernelAvx512Function MACRO Isa - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix/vector multiplication. -; -; Arguments: -; -; A (rcx) - Supplies the address of vector A. -; -; B (rdx) - Supplies the address of matrix B. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from vector A and the number -; of rows from matrix B to iterate over. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldb - Supplies the first dimension of matrix B. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasGemvU8S8Kernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - - END_PROLOGUE - - mov rdi,rcx - mov rsi,rdx - mov r10,GemvU8S8KernelFrame.CountN[rsp] - mov ecx,r10d - and ecx,15 ; isolate unaligned count - mov eax,1 - shl eax,cl - dec eax - kmovw k1,eax ; compute vector load/store mask - mov rcx,GemvU8S8KernelFrame.ldb[rsp] - mov r11,rsp ; set ZeroMode to any non-zero value -IFIDNI , - mov eax,1 - vpbroadcastw zmm29,eax -ENDIF - -; -; Process 4 rows of matrix B in a loop. -; - - sub r9,4 - jb ProcessRemainingRows - -ProcessRowLoop4: - mov rdx,rsi ; reload matrix B - lea rsi,[rsi+rcx*4] ; advance matrix B by 4 rows - mov rbx,r8 ; reload matrix C - mov rbp,r10 ; reload CountN - vpbroadcastd zmm28,DWORD PTR [rdi] - add rdi,4 ; advance matrix A by 4 bytes - -; -; Process sets of 64 columns from the 4 rows in a loop. -; -; Some permute operations are deferred until the final store of the 4x64 block -; as these permutes are expensive. -; - -ProcessColumnLoop4By64: - cmp rbp,64 - jb ProcessColumnLoop4By16 - lea rax,[rdx+rcx*2] ; compute matrix B plus 2 rows - vmovdqu32 zmm16,ZMMWORD PTR [rdx] - vmovdqu32 zmm17,ZMMWORD PTR [rdx+rcx] - vmovdqu32 zmm18,ZMMWORD PTR [rax] - vmovdqu32 zmm19,ZMMWORD PTR [rax+rcx] - vpunpcklbw zmm20,zmm16,zmm17 ; interleave row data bytes - vpunpckhbw zmm21,zmm16,zmm17 - vpunpcklbw zmm22,zmm18,zmm19 - vpunpckhbw zmm23,zmm18,zmm19 - vpunpcklwd zmm16,zmm20,zmm22 ; interleave row data words - vpunpckhwd zmm17,zmm20,zmm22 - vpunpcklwd zmm18,zmm21,zmm23 - vpunpckhwd zmm19,zmm21,zmm23 -IFIDNI , - vpmaddubsw zmm16,zmm28,zmm16 - vpmaddwd zmm20,zmm16,zmm29 - vpmaddubsw zmm17,zmm28,zmm17 - vpmaddwd zmm21,zmm17,zmm29 - vpmaddubsw zmm18,zmm28,zmm18 - vpmaddwd zmm22,zmm18,zmm29 - vpmaddubsw zmm19,zmm28,zmm19 - vpmaddwd zmm23,zmm19,zmm29 -ELSE - vpxord zmm20,zmm20,zmm20 - vpxord zmm21,zmm21,zmm21 - vpxord zmm22,zmm22,zmm22 - vpxord zmm23,zmm23,zmm23 - VpdpbusdsZmmZmmZmm zmm20,zmm28,zmm16 - VpdpbusdsZmmZmmZmm zmm21,zmm28,zmm17 - VpdpbusdsZmmZmmZmm zmm22,zmm28,zmm18 - VpdpbusdsZmmZmmZmm zmm23,zmm28,zmm19 -ENDIF - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By64 - vpaddd zmm20,zmm20,ZMMWORD PTR [rbx] - vpaddd zmm21,zmm21,ZMMWORD PTR [rbx+16*4] - vpaddd zmm22,zmm22,ZMMWORD PTR [rbx+32*4] - vpaddd zmm23,zmm23,ZMMWORD PTR [rbx+48*4] - -SkipAccumulateOutput4By64: - cmp r9,4 ; final 4x64 block? 
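For reference, the routine header above defines the QGEMV contract: vector A holds CountK unsigned 8-bit values, matrix B holds CountK x CountN signed 8-bit values with leading dimension ldb, and C receives 32-bit sums. ZeroMode starts non-zero and is cleared after the first group of four rows, so the first pass over K overwrites C and later passes accumulate into it. A minimal C++ sketch of that contract; the function name and the per-call ZeroMode parameter are illustrative only:

    #include <cstddef>
    #include <cstdint>

    // Reference semantics of the deleted MlasGemvU8S8Kernel* kernels:
    // C[1 x CountN] = (ZeroMode ? 0 : C) + A[1 x CountK] * B[CountK x CountN].
    void GemvU8S8Reference(const uint8_t* a, const int8_t* b, int32_t* c,
                           size_t count_k, size_t count_n, size_t ldb, bool zero_mode) {
        for (size_t n = 0; n < count_n; ++n) {
            int32_t sum = zero_mode ? 0 : c[n];
            for (size_t k = 0; k < count_k; ++k) {
                sum += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k * ldb + n]);
            }
            c[n] = sum;
        }
    }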
- jae StoreOutput4By64 - vextracti32x4 XMMWORD PTR [rbx],zmm20,0 - vextracti32x4 XMMWORD PTR [rbx+4*4],zmm21,0 - vextracti32x4 XMMWORD PTR [rbx+8*4],zmm22,0 - vextracti32x4 XMMWORD PTR [rbx+12*4],zmm23,0 - vextracti32x4 XMMWORD PTR [rbx+16*4],zmm20,1 - vextracti32x4 XMMWORD PTR [rbx+20*4],zmm21,1 - vextracti32x4 XMMWORD PTR [rbx+24*4],zmm22,1 - vextracti32x4 XMMWORD PTR [rbx+28*4],zmm23,1 - vextracti32x4 XMMWORD PTR [rbx+32*4],zmm20,2 - vextracti32x4 XMMWORD PTR [rbx+36*4],zmm21,2 - vextracti32x4 XMMWORD PTR [rbx+40*4],zmm22,2 - vextracti32x4 XMMWORD PTR [rbx+44*4],zmm23,2 - vextracti32x4 XMMWORD PTR [rbx+48*4],zmm20,3 - vextracti32x4 XMMWORD PTR [rbx+52*4],zmm21,3 - vextracti32x4 XMMWORD PTR [rbx+56*4],zmm22,3 - vextracti32x4 XMMWORD PTR [rbx+60*4],zmm23,3 - jmp AdvanceColumnLoop64 - -StoreOutput4By64: - vmovdqu32 ZMMWORD PTR [rbx],zmm20 - vmovdqu32 ZMMWORD PTR [rbx+16*4],zmm21 - vmovdqu32 ZMMWORD PTR [rbx+32*4],zmm22 - vmovdqu32 ZMMWORD PTR [rbx+48*4],zmm23 - -AdvanceColumnLoop64: - add rdx,64 ; advance matrix B by 64 bytes - add rbx,64*4 ; advance matrix C by 64 columns - sub rbp,64 ; decrement CountN - jnz ProcessColumnLoop4By64 - -AdvanceRowLoop4: - xor r11,r11 ; clear ZeroMode - sub r9,4 ; decrement CountK - jae ProcessRowLoop4 - -ProcessRemainingRows: - add r9,4 ; correct for over-subtract above - jnz ProcessRemainingSmallK - -ExitKernel: - vzeroupper - - BEGIN_EPILOGUE - - pop rdi - pop rsi - pop rbx - pop rbp - ret - -; -; Process sets of 16 columns from the 4 rows in a loop or process the remaining -; 1 to 15 columns. -; - -ProcessColumnLoop4By16: - lea rax,[rdx+rcx*2] ; compute matrix B plus 2 rows - cmp ebp,16 - jb LoadPartialVector4BySmallN - vmovdqu xmm2,XMMWORD PTR [rdx] - vmovdqu xmm3,XMMWORD PTR [rdx+rcx] - vmovdqu xmm4,XMMWORD PTR [rax] - vmovdqu xmm5,XMMWORD PTR [rax+rcx] - jmp ComputeOutput4By16 - -LoadPartialVector4BySmallN: - vmovdqu8 zmm2{k1}{z},ZMMWORD PTR [rdx] - vmovdqu8 zmm3{k1}{z},ZMMWORD PTR [rdx+rcx] - vmovdqu8 zmm4{k1}{z},ZMMWORD PTR [rax] - vmovdqu8 zmm5{k1}{z},ZMMWORD PTR [rax+rcx] - -ComputeOutput4By16: - vpunpcklbw xmm1,xmm2,xmm3 ; interleave row data bytes - vpunpckhbw xmm2,xmm2,xmm3 - vpunpcklbw xmm3,xmm4,xmm5 - vpunpckhbw xmm4,xmm4,xmm5 - vpunpcklwd xmm5,xmm1,xmm3 ; interleave row data words - vpunpckhwd xmm1,xmm1,xmm3 - vpunpcklwd xmm3,xmm2,xmm4 - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm5,ymm5,xmm1,1 ; concatenate 256-bit vector - vinserti128 ymm3,ymm3,xmm2,1 - vshufi32x4 zmm16,zmm5,zmm3,044h ; concatenate 512-bit vector -IFIDNI , - vpmaddubsw zmm16,zmm28,zmm16 - vpmaddwd zmm20,zmm16,zmm29 -ELSE - vpxord zmm20,zmm20,zmm20 - VpdpbusdsZmmZmmZmm zmm20,zmm28,zmm16 -ENDIF - cmp ebp,16 - jb StorePartialVector4BySmallN - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By16 - vpaddd zmm20,zmm20,ZMMWORD PTR [rbx] - -SkipAccumulateOutput4By16: - vmovdqu32 ZMMWORD PTR [rbx],zmm20 - add rdx,16 ; advance matrix B by 16 bytes - add rbx,16*4 ; advance matrix C by 16 columns - sub ebp,16 ; decrement CountN - jnz ProcessColumnLoop4By16 - jmp AdvanceRowLoop4 - -StorePartialVector4BySmallN: - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4BySmallN - vpaddd zmm20{k1}{z},zmm20,ZMMWORD PTR [rbx] - -SkipAccumulateOutput4BySmallN: - vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm20 - jmp AdvanceRowLoop4 - -; -; Broadcast the remaining 1 to 3 values from vector A. 
-; - -ProcessRemainingSmallK: - vpxor xmm0,xmm0,xmm0 - cmp r9d,2 - jb LoadVectorASingleRemainingByte - vpinsrw xmm0,xmm0,WORD PTR [rdi],0 - je BroadcastVectorARemainingBytes - vpinsrb xmm0,xmm0,BYTE PTR [rdi+2],2 - jmp BroadcastVectorARemainingBytes - -LoadVectorASingleRemainingByte: - vpinsrb xmm0,xmm0,BYTE PTR [rdi],0 - -BroadcastVectorARemainingBytes: - vpbroadcastd zmm28,xmm0 ; broadcast values - -; -; Process sets of 16 columns from the remaining rows in a loop or process the -; remaining 1 to 15 columns. -; - -ProcessColumnLoopSmallKBy16: - vpxor xmm3,xmm3,xmm3 ; clear optional row vectors - vpxor xmm4,xmm4,xmm4 - vpxor xmm5,xmm5,xmm5 - cmp r10d,16 - jb LoadPartialVectorSmallKBySmallN - vmovdqu xmm2,XMMWORD PTR [rsi] - cmp r9d,2 - jb ComputeOutputSmallKBy16 - vmovdqu xmm3,XMMWORD PTR [rsi+rcx] - je ComputeOutputSmallKBy16 - vmovdqu xmm4,XMMWORD PTR [rsi+rcx*2] - jmp ComputeOutputSmallKBy16 - -LoadPartialVectorSmallKBySmallN: - vmovdqu8 zmm2{k1}{z},ZMMWORD PTR [rsi] - cmp r9d,2 - jb ComputeOutputSmallKBy16 - vmovdqu8 zmm3{k1}{z},ZMMWORD PTR [rsi+rcx] - je ComputeOutputSmallKBy16 - vmovdqu8 zmm4{k1}{z},ZMMWORD PTR [rsi+rcx*2] - jmp ComputeOutputSmallKBy16 - -ComputeOutputSmallKBy16: - vpunpcklbw xmm1,xmm2,xmm3 ; interleave row data bytes - vpunpckhbw xmm2,xmm2,xmm3 - vpunpcklbw xmm3,xmm4,xmm5 - vpunpckhbw xmm4,xmm4,xmm5 - vpunpcklwd xmm5,xmm1,xmm3 ; interleave row data words - vpunpckhwd xmm1,xmm1,xmm3 - vpunpcklwd xmm3,xmm2,xmm4 - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm5,ymm5,xmm1,1 ; concatenate 256-bit vector - vinserti128 ymm3,ymm3,xmm2,1 - vshufi32x4 zmm16,zmm5,zmm3,044h ; concatenate 512-bit vector -IFIDNI , - vpmaddubsw zmm16,zmm28,zmm16 - vpmaddwd zmm20,zmm16,zmm29 -ELSE - vpxord zmm20,zmm20,zmm20 - VpdpbusdsZmmZmmZmm zmm20,zmm28,zmm16 -ENDIF - cmp r10d,16 - jb StorePartialVectorSmallKBySmallN - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutputSmallKBy16 - vpaddd zmm20,zmm20,ZMMWORD PTR [r8] - -SkipAccumulateOutputSmallKBy16: - vmovdqu32 ZMMWORD PTR [r8],zmm20 - add rsi,16 ; advance matrix B by 16 bytes - add r8,16*4 ; advance matrix C by 16 columns - sub r10d,16 ; decrement CountN - jnz ProcessColumnLoopSmallKBy16 - jmp ExitKernel - -StorePartialVectorSmallKBySmallN: - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutputSmallKBySmallN - vpaddd zmm20{k1}{z},zmm20,ZMMWORD PTR [r8] - -SkipAccumulateOutputSmallKBySmallN: - vmovdqu32 ZMMWORD PTR [r8]{k1},zmm20 - jmp ExitKernel - - NESTED_END MlasGemvU8S8Kernel&Isa&, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm b/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm deleted file mode 100644 index c1727b3e34d25..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Core.asm +++ /dev/null @@ -1,31 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemvU8S8KernelAvx512Core.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/vector -; multiply operation (QGEMV). -; -; This implementation uses AVX512 core instructions (BW/DQ/VL). -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE QgemvU8S8KernelAvx512Common.inc - .list - -; -; Generate the GEMV kernel. 
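Both instruction paths in the ISA-conditional blocks above compute the same 4-way byte dot product per 32-bit lane: the generic AVX512 path pairs vpmaddubsw (u8 times s8 products summed into 16-bit lanes) with vpmaddwd against the all-ones vector kept in zmm29, while the VNNI path folds this into a single VpdpbusdsZmmZmmZmm with accumulation. A scalar C++ model of what one 32-bit lane receives; saturation of the intermediate 16-bit sums and of the VNNI accumulate is ignored here:

    #include <cstdint>

    // One 32-bit lane of the byte dot product: four u8 * s8 products accumulated.
    int32_t DotProduct4Lane(const uint8_t a[4], const int8_t b[4], int32_t accumulator) {
        int32_t sum = 0;
        for (int i = 0; i < 4; ++i) {
            sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        }
        return accumulator + sum;
    }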
-; - -GemvU8S8KernelAvx512Function Avx512Core - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm deleted file mode 100644 index cb175881f556e..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvx512Vnni.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemvU8S8KernelAvx512Vnni.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/vector -; multiply operation (QGEMV). -; -; This implementation uses AVX512VNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE QgemvU8S8KernelAvx512Common.inc -INCLUDE AssembleAvx512Vnni.inc - .list - -; -; Generate the GEMV kernel. -; - -GemvU8S8KernelAvx512Function Avx512Vnni - - END diff --git a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm b/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm deleted file mode 100644 index be13dcebbc580..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemvU8S8KernelAvxVnni.asm +++ /dev/null @@ -1,385 +0,0 @@ -;++ -; -; Copyright (c) 2020 Intel Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemvU8S8KernelAvxVnni.asm -; -; Abstract: -; -; This module implements the kernels for the quantized integer matrix/vector -; multiply operation (QGEMV). -; -; This implementation uses AVXVNNI instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE AssembleAvxVnni.inc - .list - - EXTERN MlasMaskMoveAvx:NEAR - EXTERN MlasTranspose4x4BytesAvx:NEAR - -; -; Stack frame layout for the U8S8 kernel. -; - -GemvU8S8KernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - Padding QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountN QWORD ? - ldb QWORD ? - -GemvU8S8KernelFrame ENDS - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix/vector multiplication. -; -; Arguments: -; -; A (rcx) - Supplies the address of vector A. -; -; B (rdx) - Supplies the address of matrix B. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from vector A and the number -; of rows from matrix B to iterate over. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldb - Supplies the first dimension of matrix B. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasGemvU8S8KernelAvxVnni, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - alloc_stack (GemvU8S8KernelFrame.SavedRdi) - save_xmm128 xmm6,GemvU8S8KernelFrame.SavedXmm6 - save_xmm128 xmm7,GemvU8S8KernelFrame.SavedXmm7 - save_xmm128 xmm8,GemvU8S8KernelFrame.SavedXmm8 - save_xmm128 xmm9,GemvU8S8KernelFrame.SavedXmm9 - save_xmm128 xmm10,GemvU8S8KernelFrame.SavedXmm10 - - END_PROLOGUE - - mov rsi,rdx - mov rdi,GemvU8S8KernelFrame.ldb[rsp] - mov r10,GemvU8S8KernelFrame.CountN[rsp] - mov r11,rsp ; set ZeroMode to any non-zero value - -; -; Process 4 rows of matrix B in a loop. 
-; - - sub r9,4 - jb ProcessRemainingRows - -ProcessRowLoop4: - mov rdx,rsi ; reload matrix B - lea rsi,[rsi+rdi*4] ; advance matrix B by 4 rows - mov rbx,r8 ; reload matrix C - mov rbp,r10 ; reload CountN - vpbroadcastd ymm0,DWORD PTR [rcx] - add rcx,4 ; advance matrix A by 4 bytes - -; -; Process sets of 32 columns from the 4 rows in a loop. -; -; Some permute operations are deferred until the final store of the 4x32 block -; as these permutes are expensive. -; - -ProcessColumnLoop4By32: - cmp rbp,32 - jb ProcessColumnLoop4By8 - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - vmovdqu ymm2,YMMWORD PTR [rdx] - vmovdqu ymm3,YMMWORD PTR [rdx+rdi] - vmovdqu ymm4,YMMWORD PTR [rax] - vmovdqu ymm5,YMMWORD PTR [rax+rdi] - vpunpcklbw ymm1,ymm2,ymm3 ; interleave row data bytes - vpunpckhbw ymm2,ymm2,ymm3 - vpxor ymm7,ymm7,ymm7 - vpunpcklbw ymm3,ymm4,ymm5 - vpunpckhbw ymm4,ymm4,ymm5 - vpxor ymm8,ymm8,ymm8 - vpunpcklwd ymm5,ymm1,ymm3 ; interleave row data words - vpunpckhwd ymm1,ymm1,ymm3 - vpxor ymm9,ymm9,ymm9 - vpunpcklwd ymm3,ymm2,ymm4 - vpunpckhwd ymm2,ymm2,ymm4 - vpxor ymm10,ymm10,ymm10 - VpdpbusdsYmmYmmYmm ymm7,ymm0,ymm5 - VpdpbusdsYmmYmmYmm ymm8,ymm0,ymm1 - VpdpbusdsYmmYmmYmm ymm9,ymm0,ymm3 - VpdpbusdsYmmYmmYmm ymm10,ymm0,ymm2 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By32 - vpaddd ymm7,ymm7,YMMWORD PTR [rbx] - vpaddd ymm8,ymm8,YMMWORD PTR [rbx+32] - vpaddd ymm9,ymm9,YMMWORD PTR [rbx+64] - vpaddd ymm10,ymm10,YMMWORD PTR [rbx+96] - -SkipAccumulateOutput4By32: - cmp r9,4 ; final 4x32 block? - jae StoreOutput4By32 - vperm2i128 ymm4,ymm7,ymm8,31h ; interleave vector results - vperm2i128 ymm7,ymm7,ymm8,20h - vperm2i128 ymm8,ymm9,ymm10,20h - vperm2i128 ymm10,ymm9,ymm10,31h - vmovaps ymm9,ymm4 - -StoreOutput4By32: - vmovdqu YMMWORD PTR [rbx],ymm7 - vmovdqu YMMWORD PTR [rbx+32],ymm8 - vmovdqu YMMWORD PTR [rbx+64],ymm9 - vmovdqu YMMWORD PTR [rbx+96],ymm10 - add rdx,32 ; advance matrix B by 32 bytes - add rbx,32*4 ; advance matrix C by 32 columns - sub rbp,32 ; decrement CountN - jnz ProcessColumnLoop4By32 - -AdvanceRowLoop4: - xor r11,r11 ; clear ZeroMode - sub r9,4 ; decrement CountK - jae ProcessRowLoop4 - -ProcessRemainingRows: - add r9,4 ; correct for over-subtract above - jnz ProcessRemainingSmallK - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - vzeroupper - movaps xmm6,GemvU8S8KernelFrame.SavedXmm6[rsp] - movaps xmm7,GemvU8S8KernelFrame.SavedXmm7[rsp] - movaps xmm8,GemvU8S8KernelFrame.SavedXmm8[rsp] - movaps xmm9,GemvU8S8KernelFrame.SavedXmm9[rsp] - movaps xmm10,GemvU8S8KernelFrame.SavedXmm10[rsp] - add rsp,(GemvU8S8KernelFrame.SavedRdi) - - BEGIN_EPILOGUE - - pop rdi - pop rsi - pop rbx - pop rbp - ret - -; -; Process sets of 8 columns from the 4 rows in a loop. -; - -ProcessColumnLoop4By8: - cmp ebp,8 - jb ProcessColumn4By4 - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - vmovq xmm2,QWORD PTR [rdx] - vmovq xmm3,QWORD PTR [rdx+rdi] - vmovq xmm4,QWORD PTR [rax] - vmovq xmm5,QWORD PTR [rax+rdi] - vpunpcklbw xmm2,xmm2,xmm3 ; interleave row data bytes - vpunpcklbw xmm4,xmm4,xmm5 - vpunpcklwd xmm1,xmm2,xmm4 ; interleave row data words - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm1,ymm1,xmm2,1 ; concatenate vector - vpxor ymm8,ymm8,ymm8 - VpdpbusdsYmmYmmYmm ymm8,ymm0,ymm1 - test r11,r11 ; ZeroMode? 
- jnz SkipAccumulateOutput4By8 - vpaddd ymm8,ymm8,YMMWORD PTR [rbx] - -SkipAccumulateOutput4By8: - vmovdqu YMMWORD PTR [rbx],ymm8 - add rdx,8 ; advance matrix B by 8 bytes - add rbx,8*4 ; advance matrix C by 8 columns - sub ebp,8 ; decrement CountN - jnz ProcessColumnLoop4By8 - jmp AdvanceRowLoop4 - -; -; Process a set of 4 columns from the 4 rows. -; - -ProcessColumn4By4: - test ebp,4 ; (CountN & 4) != 0? - jz ProcessColumn4BySmallN - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - vmovd xmm1,DWORD PTR [rdx] - vpinsrd xmm1,xmm1,DWORD PTR [rdx+rdi],1 - vpinsrd xmm1,xmm1,DWORD PTR [rax],2 - vpinsrd xmm1,xmm1,DWORD PTR [rax+rdi],3 - vpshufb xmm1,xmm1,XMMWORD PTR [MlasTranspose4x4BytesAvx] - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutput4By4 - vpaddd xmm8,xmm8,XMMWORD PTR [rbx] - -SkipAccumulateOutput4By4: - vmovdqu XMMWORD PTR [rbx],xmm8 - and ebp,3 ; (CountN & 3) != 0? - jz AdvanceRowLoop4 - add rdx,4 ; advance matrix B by 4 bytes - add rbx,4*4 ; advance matrix C by 4 columns - -; -; Process the remaining 1 to 3 columns from the 4 rows. -; - -ProcessColumn4BySmallN: - mov DWORD PTR GemvU8S8KernelFrame.CountN[rsp],ebp - vbroadcastss xmm2,DWORD PTR GemvU8S8KernelFrame.CountN[rsp] - vpcmpgtd xmm2,xmm2,XMMWORD PTR [MlasMaskMoveAvx] - vpxor xmm1,xmm1,xmm1 - lea rax,[rdx+rdi*2] ; compute matrix B plus 2 rows - cmp ebp,2 ; (CountN & 2) != 0? - jb ProcessColumn4By1 - vpinsrw xmm1,xmm1,WORD PTR [rdx],0 - vpinsrw xmm1,xmm1,WORD PTR [rdx+rdi],2 - vpinsrw xmm1,xmm1,WORD PTR [rax],4 - vpinsrw xmm1,xmm1,WORD PTR [rax+rdi],6 - je ComputeOutput4BySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rdx+2],2 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+rdi+2],6 - vpinsrb xmm1,xmm1,BYTE PTR [rax+2],10 - vpinsrb xmm1,xmm1,BYTE PTR [rax+rdi+2],14 - jmp ComputeOutput4BySmallN - -ProcessColumn4By1: - vpinsrb xmm1,xmm1,BYTE PTR [rdx],0 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+rdi],4 - vpinsrb xmm1,xmm1,BYTE PTR [rax],8 - vpinsrb xmm1,xmm1,BYTE PTR [rax+rdi],12 - -ComputeOutput4BySmallN: - vpshufb xmm1,xmm1,XMMWORD PTR [MlasTranspose4x4BytesAvx] - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 ; ZeroMode? - jnz StoreOutput4BySmallN - vpmaskmovd xmm3,xmm2,XMMWORD PTR [rbx] - vpaddd xmm8,xmm8,xmm3 - -StoreOutput4BySmallN: - vpmaskmovd XMMWORD PTR [rbx],xmm2,xmm8 - jmp AdvanceRowLoop4 - -; -; Broadcast the remaining 1 to 3 values from vector A. -; - -ProcessRemainingSmallK: - vpxor xmm5,xmm5,xmm5 ; keep zero vector for vpinsrb/vpinsrw - cmp r9d,2 - jb LoadVectorASingleRemainingByte - vpinsrw xmm0,xmm5,WORD PTR [rcx],0 - je BroadcastVectorARemainingBytes - vpinsrb xmm0,xmm0,BYTE PTR [rcx+2],2 - jmp BroadcastVectorARemainingBytes - -LoadVectorASingleRemainingByte: - vpinsrb xmm0,xmm5,BYTE PTR [rcx],0 - -BroadcastVectorARemainingBytes: - vpshufd xmm0,xmm0,0 ; broadcast values - -; -; Process a set of 4 columns from the remaining rows. -; - -ProcessColumnLoopSmallKBy4: - cmp r10,4 - jb ProcessColumnLoopSmallKBySmallN - vmovd xmm1,DWORD PTR [rsi] - cmp r9d,2 - jb ComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+rdi],1 - je ComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+rdi*2],2 - -ComputeOutputSmallKBy4: - vpshufb xmm1,xmm1,XMMWORD PTR [MlasTranspose4x4BytesAvx] - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 ; ZeroMode? 
- jnz SkipAccumulateOutputSmallKBy4 - vpaddd xmm8,xmm8,XMMWORD PTR [r8] - -SkipAccumulateOutputSmallKBy4: - vmovdqu XMMWORD PTR [r8],xmm8 - add rsi,4 ; advance matrix B by 4 bytes - add r8,4*4 ; advance matrix C by 4 columns - sub r10,4 ; decrement CountN - jnz ProcessColumnLoopSmallKBy4 - jmp ExitKernel - -; -; Process the remaining 1 to 3 columns from the remaining rows. -; -; Single step through each of the columns to keep code size small for the -; uncommon path (typically the row count is a multiple of 4). -; - -ProcessColumnLoopSmallKBySmallN: - vpinsrb xmm1,xmm5,BYTE PTR [rsi],0 - cmp r9d,2 - jb ComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+rdi],1 - je ComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+rdi*2],2 - -ComputeOutputSmallKBySmallN: - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 ; ZeroMode? - jnz SkipAccumulateOutputSmallKBySmallN - vmovd xmm3,DWORD PTR [r8] - vpaddd xmm8,xmm8,xmm3 - -SkipAccumulateOutputSmallKBySmallN: - vmovd DWORD PTR [r8],xmm8 - inc rsi ; advance matrix B by 1 byte - add r8,4 ; advance matrix C by 1 column - dec r10 - jnz ProcessColumnLoopSmallKBySmallN - jmp ExitKernel - - NESTED_END MlasGemvU8S8KernelAvxVnni, _TEXT - - END \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/amd64/QgemvU8X8KernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/QgemvU8X8KernelCommon.inc deleted file mode 100644 index 6b5c942d39761..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/QgemvU8X8KernelCommon.inc +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; QgemvU8X8KernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the quantized -; integer matrix/matrix multiply operation (QGEMM) for the AVX2 kernels. -; -;-- - -GemvU8X8KernelFrame STRUCT - - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountN QWORD ? - ldb QWORD ? - -GemvU8X8KernelFrame ENDS diff --git a/onnxruntime/core/mlas/lib/amd64/SconvKernelAvx.asm b/onnxruntime/core/mlas/lib/amd64/SconvKernelAvx.asm deleted file mode 100644 index dd6c12e2caecf..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SconvKernelAvx.asm +++ /dev/null @@ -1,367 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SconvKernelAvx.asm -; -; Abstract: -; -; This module implements the kernels for the single precision convolution -; operation. -; -; This implementation uses AVX instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SconvKernelAvxCommon.inc - .list - -; -; Macro Description: -; -; This macro multiplies and accumulates for FilterCount by OutputCount block -; of the output buffer. -; -; Arguments: -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; VectorOffset - Supplies the byte offset from the filter buffer to fetch -; elements. -; -; BroadcastOffset - Supplies the byte offset from the input buffer to fetch -; elements. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rdx - Supplies the address of the filter buffer. 
-; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbx - Supplies the address of the filter buffer plus 2 * FilterStride. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; ymm0-ymm7 - Supplies the block accumulators. -; - -ComputeBlock MACRO KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -IFIDNI , - vmovups ymm12,YMMWORD PTR [rdx] - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 2, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, -IF OutputCount EQ 1 - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, - EmitIfCountGE FilterCount, 4, -ELSE - EmitIfCountGE FilterCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCountGE FilterCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCountGE FilterCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows. -; -; Arguments: -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description) when -; KernelType!=Depthwise. Supplies the address of the filter buffer when -; KernelType=Depthwise. -; -; rbp - Supplies the DilationWidth parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r15 - Supplies the InputStride parameter (see function description). -; - -ProcessFilterCountN MACRO KernelFrame, KernelType, FilterCount - - LOCAL ProcessOutputCount - LOCAL ProcessNextOutputCountBy2 - LOCAL ProcessRemainingOutputCount - LOCAL ProcessOutputCountRightPadAndRemaining - -; -; Process the output blocks that include left padding. -; - - mov r10,KernelFrame.OutputCountLeftPad[rsp] - test r10,r10 - jz ProcessOutputCount - call MlasConv&KernelType&FloatSingleAvxFilter&FilterCount - -; -; Process the output blocks that do not include any padding. 
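The ProcessFilterCountN driver above (continued below) walks one output row in three phases: positions that overlap the left padding go through the single-output helper, fully interior positions are processed two at a time with no padding checks, and any leftover interior position is merged with the right-padding positions and sent back through the helper. A hedged C++ sketch of that control flow, with the generated kernel bodies abstracted into callables:

    #include <cstddef>

    // Illustrative three-phase traversal of one output row (AVX path, 2-wide interior tiles).
    template <typename ProcessSingle, typename ProcessPair>
    void TraverseOutputRow(size_t left_pad, size_t interior, size_t right_pad,
                           ProcessSingle process_single, ProcessPair process_pair) {
        for (size_t i = 0; i < left_pad; ++i) process_single();    // may touch padding
        while (interior >= 2) { process_pair(); interior -= 2; }   // no padding checks needed
        for (size_t i = 0; i < interior + right_pad; ++i) process_single();  // tail + right padding
    }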
-; - -ProcessOutputCount: - mov r10,KernelFrame.OutputCount[rsp] - sub r10,2 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy2: - ProcessOutputCountN Avx, KernelFrame, KernelType, 8, FilterCount, 2 - lea rdi,[rdi+r9*2] ; advance input by 2 elements - sub r10,2 - jae ProcessNextOutputCountBy2 - -ProcessRemainingOutputCount: - add r10,2 ; correct for over-subtract above - -; -; Process the output blocks that include right padding plus any remaining output -; blocks from above. -; - -ProcessOutputCountRightPadAndRemaining: - add r10,KernelFrame.OutputCountRightPad[rsp] - jz ExitKernel - call MlasConv&KernelType&FloatSingleAvxFilter&FilterCount - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows for a pointwise convolution. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbp - Supplies the InputStride parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r10 - Supplies the OutputCount parameter (see function description). -; -; r12 - Supplies the address of the filter buffer. -; - -ProcessPointwiseFilterCountN MACRO FilterCount - - LOCAL ProcessNextOutputCountBy2 - LOCAL ProcessRemainingOutputCount - - sub r10,2 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy2: - ProcessPointwiseOutputCountN Avx, 8, FilterCount, 2 - lea rdi,[rdi+r9*2] ; advance input by 2 elements - sub r10,2 - jae ProcessNextOutputCountBy2 - -ProcessRemainingOutputCount: - add r10,2 ; correct for over-subtract above - jz ExitKernel - ProcessPointwiseOutputCountN Avx, 8, FilterCount, 1 - - ENDM - -; -; Generate the convolution kernels. -; - -SconvKernelFunction Nchw, 8, Avx -SconvKernelFunction Nchwc, 8, Avx, BiasFilter -SconvKernelDepthwiseFunction 8, Avx -SconvKernelPointwiseFunction Avx, BiasFilter - -; -; Macro Description: -; -; This macro generates code to process an output block after the inner -; convolution kernel has executed and then stores the output block to the -; output buffer. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; - - IRP FilterCount, <1, 2, 3, 4> - IRP OutputCount, <1, 2, 3> - - LEAF_ENTRY MlasConvPostProcessFloatAvxFilter&FilterCount&Output&OutputCount, _TEXT - - PUBLIC MlasConvPostProcessFloatFma3Filter&FilterCount&Output&OutputCount -MlasConvPostProcessFloatFma3Filter&FilterCount&Output&OutputCount:: - -IF FilterCount GT 2 - lea rbx,[r8+rax*2] ; compute output plus 2 rows -ENDIF - -; -; Test if the existing contents of the output buffer should be accumulated -; with the output block. 
-; - - test dl,MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - jz SkipAccumulateOutput - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - -SkipAccumulateOutput: - -; -; Test if the bias buffer should be accumulated with the output block. -; - - test dl,MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - jz SkipBiasAddition -IF OutputCount EQ 1 - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, -ELSE - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, -ENDIF - -SkipBiasAddition: - -; -; Test for fused ReLU activation. -; - - test dl,MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - jz SkipReluActivation - vxorps xmm15,xmm15,xmm15 - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - -SkipReluActivation: - -; -; Store the output block in the output buffer. 
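The post-processing sequence above applies, in order, the optional accumulate into the existing output, the optional bias addition, and the optional fused ReLU, each selected by one of the MLAS_CONV_KERNEL_FLAG_* bits (their values match the EQU definitions in SconvKernelCommon.inc later in this diff). A simplified scalar C++ model; per-channel bias indexing is collapsed to per-element for brevity and the function name is illustrative:

    #include <algorithm>
    #include <cstddef>

    constexpr unsigned kAccumulateOutput = 0x1;  // MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT
    constexpr unsigned kBiasAddition     = 0x2;  // MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION
    constexpr unsigned kReluActivation   = 0x4;  // MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION

    // Simplified model of the convolution output-block epilogue.
    void PostProcessOutputBlock(float* output, const float* accumulators, const float* bias,
                                size_t count, unsigned flags) {
        for (size_t i = 0; i < count; ++i) {
            float value = accumulators[i];
            if (flags & kAccumulateOutput) value += output[i];
            if (flags & kBiasAddition)     value += bias[i];   // per-channel in the real kernels
            if (flags & kReluActivation)   value = std::max(value, 0.0f);
            output[i] = value;
        }
    }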
-; - - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - add_immed r8,OutputCount*8*4 ; advance output by N nchw8c blocks - ret - - LEAF_END MlasConvPostProcessFloatAvxFilter&FilterCount&Output&OutputCount, _TEXT - - ENDM - ENDM - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SconvKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/SconvKernelAvx512F.asm deleted file mode 100644 index 43eaa88e49946..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SconvKernelAvx512F.asm +++ /dev/null @@ -1,518 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SconvKernelAvx512F.asm -; -; Abstract: -; -; This module implements the kernels for the single precision convolution -; operation. -; -; This implementation uses AVX512F instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SconvKernelCommon.inc - .list - -; -; Macro Description: -; -; This macro generates code to clear the block accumulators. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; zmm0-zmm23 - Supplies the block accumulators. -; - -ClearBlock MACRO FilterCount, OutputCount - - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 1, OutputCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 5, - EmitIfCount2GE FilterCount, 1, OutputCount, 6, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 4, - EmitIfCount2GE FilterCount, 3, OutputCount, 5, - EmitIfCount2GE FilterCount, 3, OutputCount, 6, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 5, - EmitIfCount2GE FilterCount, 4, OutputCount, 6, - - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for FilterCount by OutputCount block -; of the output buffer. -; -; Arguments: -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; VectorOffset - Supplies the byte offset from the filter buffer to fetch -; elements. 
-; -; BroadcastOffset - Supplies the byte offset from the input buffer to fetch -; elements. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rdx - Supplies the address of the filter buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbx - Supplies the address of the filter buffer plus 2 * FilterStride. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r14 - Supplies the address of the input buffer plus 3 * StrideWidth. -; -; zmm0-zmm23 - Supplies the block accumulators. -; - -ComputeBlock MACRO KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -IFIDNI , - vmovups zmm24,ZMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, - EmitIfCountGE OutputCount, 4, - EmitIfCountGE OutputCount, 5, - EmitIfCountGE OutputCount, 6, -ELSE -IF FilterCount EQ 1 - vmovups zmm24,ZMMWORD PTR [rdx+VectorOffset] - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, - EmitIfCountGE OutputCount, 4, - EmitIfCountGE OutputCount, 5, - EmitIfCountGE OutputCount, 6, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, - EmitIfCountGE OutputCount, 4, - EmitIfCountGE OutputCount, 5, - EmitIfCountGE OutputCount, 6, -IF OutputCount EQ 1 - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, -ELSE - EmitIfCountGE FilterCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 1, OutputCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 5, - EmitIfCount2GE FilterCount, 1, OutputCount, 6, - EmitIfCountGE FilterCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCountGE FilterCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 4, - EmitIfCount2GE FilterCount, 3, OutputCount, 5, - EmitIfCount2GE FilterCount, 3, OutputCount, 6, - EmitIfCountGE FilterCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 5, - EmitIfCount2GE FilterCount, 4, OutputCount, 6, -ENDIF -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows. -; -; Arguments: -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description) when -; KernelType!=Depthwise. Supplies the address of the filter buffer when -; KernelType=Depthwise. 
-; -; rbp - Supplies the DilationWidth parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r15 - Supplies the InputStride parameter (see function description). -; - -ProcessFilterCountN MACRO KernelFrame, KernelType, FilterCount - - LOCAL ProcessOutputCount - LOCAL ProcessNextOutputCountBy6 - LOCAL ProcessRemainingOutputCount - LOCAL ProcessRemainingOutputCountLessThan3 - LOCAL ProcessRemainingOutputCount1 - LOCAL ProcessOutputCountRightPadAndRemaining - -; -; Process the output blocks that include left padding. -; - - mov r10,KernelFrame.OutputCountLeftPad[rsp] - test r10,r10 - jz ProcessOutputCount - call MlasConv&KernelType&FloatSingleAvx512FFilter&FilterCount - -; -; Process the output blocks that do not include any padding. -; - -ProcessOutputCount: - mov r10,KernelFrame.OutputCount[rsp] - sub r10,6 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy6: - ProcessOutputCountN Avx512F, KernelFrame, KernelType, 16, FilterCount, 6 - lea rax,[r9*2+r9] - lea rdi,[rdi+rax*2] ; advance input by 6 elements - sub r10,6 - jae ProcessNextOutputCountBy6 - -ProcessRemainingOutputCount: - add r10,6 ; correct for over-subtract above - jz ProcessOutputCountRightPadAndRemaining - cmp r10,3 - jb ProcessRemainingOutputCountLessThan3 - ProcessOutputCountN Avx512F, KernelFrame, KernelType, 16, FilterCount, 3 - lea rax,[r9*2+r9] - add rdi,rax ; advance input by 3 elements - sub r10,3 - jz ProcessOutputCountRightPadAndRemaining - -ProcessRemainingOutputCountLessThan3: - cmp r10,1 - je ProcessOutputCountRightPadAndRemaining - ProcessOutputCountN Avx512F, KernelFrame, KernelType, 16, FilterCount, 2 - lea rdi,[rdi+r9*2] ; advance input by 2 elements - sub r10,2 - -; -; Process the output blocks that include right padding plus any remaining output -; blocks from above. -; - -ProcessOutputCountRightPadAndRemaining: - add r10,KernelFrame.OutputCountRightPad[rsp] - jz ExitKernel - call MlasConv&KernelType&FloatSingleAvx512FFilter&FilterCount - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows for a pointwise convolution. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbp - Supplies the InputStride parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r10 - Supplies the OutputCount parameter (see function description). -; -; r12 - Supplies the address of the filter buffer. 
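The AVX512F driver macro above steps through the output row six blocks at a time and then splits the remainder into 3-, 2- and 1-block tiles, always taking the widest tile that still fits; the pointwise variant below follows the same pattern. A minimal sketch of that tiling, with the generated ProcessOutputCountN and ProcessPointwiseOutputCountN bodies abstracted into a callable:

    #include <cstddef>

    // Remainder tiling used by the AVX512F drivers: 6 at a time, then 3 / 2 / 1.
    template <typename ProcessTile>
    void TileOutputRow(size_t output_count, ProcessTile process_tile) {
        while (output_count >= 6) { process_tile(6); output_count -= 6; }
        if (output_count >= 3)    { process_tile(3); output_count -= 3; }
        if (output_count >= 2)    { process_tile(2); output_count -= 2; }
        if (output_count == 1)    { process_tile(1); }
    }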
-; - -ProcessPointwiseFilterCountN MACRO FilterCount - - LOCAL ProcessNextOutputCountBy6 - LOCAL ProcessRemainingOutputCount - LOCAL ProcessRemainingOutputCountLessThan3 - LOCAL ProcessRemainingOutputCount1 - - sub r10,6 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy6: - ProcessPointwiseOutputCountN Avx512F, 16, FilterCount, 6 - lea rax,[r9*2+r9] - lea rdi,[rdi+rax*2] ; advance input by 6 elements - sub r10,6 - jae ProcessNextOutputCountBy6 - -ProcessRemainingOutputCount: - add r10,6 ; correct for over-subtract above - jz ExitKernel - cmp r10,3 - jb ProcessRemainingOutputCountLessThan3 - ProcessPointwiseOutputCountN Avx512F, 16, FilterCount, 3 - lea rax,[r9*2+r9] - add rdi,rax ; advance input by 3 elements - sub r10,3 - jz ExitKernel - -ProcessRemainingOutputCountLessThan3: - cmp r10,2 - jb ProcessRemainingOutputCount1 - ProcessPointwiseOutputCountN Avx512F, 16, FilterCount, 2 - jmp ExitKernel - -ProcessRemainingOutputCount1: - ProcessPointwiseOutputCountN Avx512F, 16, FilterCount, 1 - - ENDM - -; -; Generate the convolution kernels. -; -; N.B. BiasFilter is not used here as the AVX-512 EVEX instruction encoding -; efficiently compresses aligned relative byte offsets. -; - -SconvKernelFunction Nchw, 16, Avx512F -SconvKernelFunction Nchwc, 16, Avx512F -SconvKernelDepthwiseFunction 16, Avx512F -SconvKernelPointwiseFunction Avx512F - -; -; Macro Description: -; -; This macro generates code to process an output block after the inner -; convolution kernel has executed and then stores the output block to the -; output buffer. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; - - IRP FilterCount, <1, 2, 3, 4> - IRP OutputCount, <1, 2, 3, 6> - - LEAF_ENTRY MlasConvPostProcessFloatAvx512FFilter&FilterCount&Output&OutputCount, _TEXT - -IF FilterCount GT 2 - lea rbx,[r8+rax*2] ; compute output plus 2 rows -ENDIF - -; -; Test if the existing contents of the output buffer should be accumulated -; with the output block. -; - - test dl,MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - jz SkipAccumulateOutput - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 1, OutputCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 5, - EmitIfCount2GE FilterCount, 1, OutputCount, 6, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 4, - EmitIfCount2GE FilterCount, 3, OutputCount, 5, - EmitIfCount2GE FilterCount, 3, OutputCount, 6, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 5, - EmitIfCount2GE FilterCount, 4, OutputCount, 6, - -SkipAccumulateOutput: - -; -; Test if the bias buffer should be accumulated with the output block. 
-; - - test dl,MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - jz SkipBiasAddition -IF OutputCount EQ 1 - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, -ELSE - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 1, OutputCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 5, - EmitIfCount2GE FilterCount, 1, OutputCount, 6, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 4, - EmitIfCount2GE FilterCount, 3, OutputCount, 5, - EmitIfCount2GE FilterCount, 3, OutputCount, 6, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 5, - EmitIfCount2GE FilterCount, 4, OutputCount, 6, -ENDIF - -SkipBiasAddition: - -; -; Test for fused ReLU activation. -; - - test dl,MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - jz SkipReluActivation - vpxord zmm24,zmm24,zmm24 - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 1, OutputCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 5, - EmitIfCount2GE FilterCount, 1, OutputCount, 6, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - -SkipReluActivation: - -; -; Store the output block in the output buffer. 
-; - - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 1, OutputCount, 4, - EmitIfCount2GE FilterCount, 1, OutputCount, 5, - EmitIfCount2GE FilterCount, 1, OutputCount, 6, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 4, - EmitIfCount2GE FilterCount, 2, OutputCount, 5, - EmitIfCount2GE FilterCount, 2, OutputCount, 6, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 4, - EmitIfCount2GE FilterCount, 3, OutputCount, 5, - EmitIfCount2GE FilterCount, 3, OutputCount, 6, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 5, - EmitIfCount2GE FilterCount, 4, OutputCount, 6, - add_immed r8,OutputCount*16*4 ; advance output by N nchw16c blocks - ret - - LEAF_END MlasConvPostProcessFloatAvx512FFilter&FilterCount&Output&OutputCount, _TEXT - - ENDM - ENDM - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SconvKernelAvxCommon.inc b/onnxruntime/core/mlas/lib/amd64/SconvKernelAvxCommon.inc deleted file mode 100644 index 40aa54b005231..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SconvKernelAvxCommon.inc +++ /dev/null @@ -1,51 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SconvKernelAvxCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the single -; precision convolution operation for the AVX and FMA3 kernels. -; -;-- - -INCLUDE SconvKernelCommon.inc - -; -; Macro Description: -; -; This macro generates code to clear the block accumulators. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; ymm0-ymm11 - Supplies the block accumulators. -; - -ClearBlock MACRO FilterCount, OutputCount - - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/SconvKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/SconvKernelCommon.inc deleted file mode 100644 index 2f5f997c6f12c..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SconvKernelCommon.inc +++ /dev/null @@ -1,909 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SconvKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the single -; precision convolution operation. 
-; -;-- - -; -; Define the convolution kernel flags. -; - -MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT EQU 00000001h -MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION EQU 00000002h -MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION EQU 00000004h -MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION EQU 00000008h - -; -; Stack frame layout for the convolution kernels. -; - -SconvKernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR12 QWORD ? - SavedR13 QWORD ? - SavedR14 QWORD ? - SavedR15 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? ; Input - PreviousP2Home QWORD ? ; Filter - PreviousP3Home QWORD ? ; Output - PreviousP4Home QWORD ? ; StrideWidth - DilationWidth QWORD ? - FilterCount QWORD ? - InputStride QWORD ? - FilterStride QWORD ? - OutputStride QWORD ? - KernelHeight QWORD ? - KernelWidth QWORD ? - InputBase QWORD ? - InputWidth QWORD ? - DilatedInputWidth QWORD ? - OutputCountLeftPad QWORD ? - OutputCount QWORD ? - OutputCountRightPad QWORD ? - Bias QWORD ? - Flags QWORD ? - -SconvKernelFrame ENDS - -SconvKernelSingleFrame STRUCT - - ReturnAddress QWORD ? - KernelFrame SconvKernelFrame <> - -SconvKernelSingleFrame ENDS - -SconvKernelDepthwiseFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR12 QWORD ? - SavedR13 QWORD ? - SavedR14 QWORD ? - SavedR15 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? ; Input - PreviousP2Home QWORD ? ; Filter - PreviousP3Home QWORD ? ; Output - PreviousP4Home QWORD ? ; StrideWidth - DilationWidth QWORD ? - InputStride QWORD ? - KernelHeight QWORD ? - KernelWidth QWORD ? - InputBase QWORD ? - InputWidth QWORD ? - DilatedInputWidth QWORD ? - OutputCountLeftPad QWORD ? - OutputCount QWORD ? - OutputCountRightPad QWORD ? - Bias QWORD ? - Flags QWORD ? - -SconvKernelDepthwiseFrame ENDS - -SconvKernelDepthwiseSingleFrame STRUCT - - ReturnAddress QWORD ? - KernelFrame SconvKernelDepthwiseFrame <> - -SconvKernelDepthwiseSingleFrame ENDS - -SconvKernelPointwiseFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - SavedR12 QWORD ? - SavedR14 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? ; Input - PreviousP2Home QWORD ? ; Filter - PreviousP3Home QWORD ? ; Output - PreviousP4Home QWORD ? ; StrideWidth - InputChannels QWORD ? - FilterCount QWORD ? - InputStride QWORD ? - FilterStride QWORD ? - OutputStride QWORD ? - OutputCount QWORD ? - Bias QWORD ? - Flags QWORD ? - -SconvKernelPointwiseFrame ENDS - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a vector of input -; blocks and a vector of filter blocks to produce a matrix of output blocks. -; -; OutputCount=1 generates special case code to handle padding blocks. All -; other output counts assume no padding. 
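The OutputCount=1 special case mentioned above guards every kernel tap with a single unsigned comparison: the macro body below computes (Input - InputBase) and compares it against InputWidth, so a tap in the left padding (where the subtraction wraps around as an unsigned value) and a tap in the right padding (where the offset runs past the valid width) both fail the same test and are skipped. A hedged C++ sketch of that bounds trick, with illustrative names:

    #include <cstddef>
    #include <cstdint>

    // One unsigned comparison covers both padding edges: for left padding the
    // pointer difference wraps to a huge unsigned value, for right padding it
    // simply exceeds the valid width; either way the tap is skipped.
    bool InputTapIsValid(const float* input, const float* input_base, size_t input_width_bytes) {
        uintptr_t offset = reinterpret_cast<uintptr_t>(input) -
                           reinterpret_cast<uintptr_t>(input_base);
        return offset < input_width_bytes;
    }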
-; -; Arguments: -; -; Isa - Supplies the instruction set architecture string for function tags. -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; KernelType - Supplies the type of kernel to be generated. -; -; BlockSize - Supplies the number of elements per block. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description) when -; KernelType!=Depthwise. Supplies the address of the filter buffer when -; KernelType=Depthwise. -; -; rbp - Supplies the DilationWidth parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r15 - Supplies the InputStride parameter (see function description). -; - -ProcessOutputCountN MACRO Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount - - LOCAL ProcessNextRow - LOCAL ProcessNextColumn - LOCAL HandlePostProcessing - LOCAL SkipOverPadding - - mov rcx,rdi -IFIDNI , - mov rdx,rsi -ELSE - mov rdx,KernelFrame.PreviousP2Home[rsp] -ENDIF - mov r11,KernelFrame.KernelHeight[rsp] - mov r12,KernelFrame.KernelWidth[rsp] -IF OutputCount EQ 1 - mov r13,KernelFrame.InputBase[rsp] - mov r14,KernelFrame.InputWidth[rsp] - neg r13 ; keep negative for lea usage below -ENDIF - ClearBlock FilterCount, OutputCount - test r11,r11 ; zero sized kernel? - jz HandlePostProcessing - -ProcessNextRow: - mov rax,r12 ; reload kernel width remaining - -ProcessNextColumn: -IF OutputCount EQ 1 - lea rbx,[rcx+r13] ; compute (Input - InputBase) - cmp rbx,r14 ; (Input - InputBase) >= InputWidth? - jae SkipOverPadding -ENDIF -IF OutputCount GT 3 - lea r14,[r9+r9*2] - add r14,rcx ; compute input plus 3 blocks -ENDIF -IF FilterCount GT 2 - lea rbx,[rdx+rsi*2] ; compute filter plus 2 rows -ENDIF -IFIDNI , -IF BlockSize EQ 16 - IRP Index, <0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15> - ComputeBlock KernelType, FilterCount, OutputCount, Index*16*4, Index*4 - ENDM -ELSE - IRP Index, <0, 1, 2, 3, 4, 5, 6, 7> - ComputeBlock KernelType, FilterCount, OutputCount, (Index-4)*8*4, Index*4 - ENDM -ENDIF -ELSE - ComputeBlock KernelType, FilterCount, OutputCount, 0, 0 -ENDIF - -SkipOverPadding: - add rcx,rbp ; advance input by dilation width -IFIDNI , - add rdx,BlockSize*BlockSize*4 ; advance filter by 8i8o/16i16o block -ELSE - add rdx,BlockSize*4 ; advance filter by 8o/16o block -ENDIF - dec rax ; decrement columns remaining - jnz ProcessNextColumn - add rcx,r15 ; advance input to next row -IF OutputCount EQ 1 - sub r13,KernelFrame.DilatedInputWidth[rsp] - ; advance input base to next row -ENDIF - dec r11 ; decrement rows remaining - jnz ProcessNextRow - -; -; Handle post processing of the output block. -; - -HandlePostProcessing: - mov edx,DWORD PTR KernelFrame.Flags[rsp] -IF FilterCount GT 1 - mov rax,KernelFrame.OutputStride[rsp] -ENDIF - mov rcx,KernelFrame.Bias[rsp] - call MlasConvPostProcessFloat&Isa&Filter&FilterCount&Output&OutputCount - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner convolution kernel. -; -; Arguments: -; -; KernelType - Supplies the type of kernel to be generated. -; -; BlockSize - Supplies the number of elements per block. -; -; Isa - Supplies the instruction set architecture string for function tags. 
-; -; BiasFilter - Supplies a non-blank value if the address of the filter buffer -; should be biased to point to the middle of a OIhw8i8o block in order to -; reduce the code size from relative byte offsets. -; - -SconvKernelFunction MACRO KernelType, BlockSize, Isa, BiasFilter - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a convolution for the elements -; of an output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; The address is biased to include padding blocks for the left width -; dimension. The address is not biased to include padding rows for the -; left height dimension; these are accounted for in the outer kernel. -; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; StrideWidth (r9) - Supplies the length in bytes of the blocked stride width. -; -; DilationWidth - Supplies the length in bytes of the blocked dilation width. -; -; FilterCount - Supplies the number of filters to process in this iteration. -; -; InputStride - Supplies the length in bytes to advance the input buffer to -; the next input row. -; -; FilterStride - Supplies the length in bytes to advance the filter buffer -; to the next set of filters. -; -; OutputStride - Supplies the length in bytes to advance the output buffer -; to the next output address associated with the next set of filters. -; -; KernelHeight - Supplies the height of the kernel to apply. This height may -; be less than the original kernel height after removing any padding -; rows. -; -; KernelWidth - Supplies the width of the kernel to apply. -; -; InputBase - Supplies the address of the valid input buffer. -; -; This parameter is similar to the Input parameter, but does not include -; the padding blocks for the left width dimension. This parameter is used -; with the following InputWidth parameter in order to validate that the -; current input buffer address in bounds and not in the left or right -; width padding region. -; -; InputWidth - Supplies the length in bytes of the blocked input width. -; -; DilatedInputWidth - Supplies the length in bytes to advance the input base -; buffer to the next input row including dilation. -; -; OutputCountLeftPad - Supplies the number of output elements that include -; one or more padding elements from the left edge. -; -; OutputCount - Supplies the number of output elements that do not include -; any padding elements. -; -; OutputCountRightPad - Supplies the number of output elements that include -; one or more padding elements from the right edge. -; -; Bias - Supplies the address of the bias buffer. -; -; Flags - Supplies additional flags controlling the convolution operation, -; especially post calculation options. -; -; Return Value: -; -; None. 
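The arguments documented above describe, in effect, a dilated 2-D filter sweep in which taps that fall into the left or right width padding are skipped via the InputBase/InputWidth bounds check. A hedged, element-wise C++ sketch of what one output element amounts to; the blocked NCHWc layout, byte strides, and register tiling of the real kernel are deliberately omitted, and all names are illustrative.

#include <cstddef>

// Element-wise view of one output position of the convolution described above.
// The real kernel operates on 8/16-element NCHWc channel blocks, uses byte
// strides, and detects padding with an unsigned compare of (Input - InputBase)
// against InputWidth; this sketch uses indices instead.
float ConvOutputElement(const float* inputRow,      // one input row, valid elements only
                        std::ptrdiff_t startTap,     // first tap position; negative means left padding
                        std::size_t inputWidth,      // number of valid elements in the row
                        std::size_t rowStride,       // elements from one input row to the next
                        const float* filter,
                        std::size_t kernelHeight,
                        std::size_t kernelWidth,
                        std::size_t dilationWidth) {
    float acc = 0.0f;
    for (std::size_t kh = 0; kh < kernelHeight; ++kh) {
        for (std::size_t kw = 0; kw < kernelWidth; ++kw) {
            std::ptrdiff_t pos = startTap + static_cast<std::ptrdiff_t>(kw * dilationWidth);
            // Skip taps that land in the left (pos < 0) or right (pos >= width) padding.
            if (pos >= 0 && pos < static_cast<std::ptrdiff_t>(inputWidth)) {
                acc += inputRow[kh * rowStride + pos] * filter[kh * kernelWidth + kw];
            }
        }
    }
    return acc;
}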
-; -;-- - - NESTED_ENTRY MlasConv&KernelType&FloatKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r15 - push_reg r14 - push_reg r13 - push_reg r12 - alloc_stack (SconvKernelFrame.SavedR12) - - save_xmm128 xmm6,SconvKernelFrame.SavedXmm6 - save_xmm128 xmm7,SconvKernelFrame.SavedXmm7 - save_xmm128 xmm8,SconvKernelFrame.SavedXmm8 - save_xmm128 xmm9,SconvKernelFrame.SavedXmm9 - save_xmm128 xmm10,SconvKernelFrame.SavedXmm10 - save_xmm128 xmm11,SconvKernelFrame.SavedXmm11 - save_xmm128 xmm12,SconvKernelFrame.SavedXmm12 - save_xmm128 xmm13,SconvKernelFrame.SavedXmm13 - save_xmm128 xmm14,SconvKernelFrame.SavedXmm14 - save_xmm128 xmm15,SconvKernelFrame.SavedXmm15 - - END_PROLOGUE - - mov rdi,rcx -IFNB - add_immed rdx,4*8*4 -ENDIF - mov SconvKernelFrame.PreviousP2Home[rsp],rdx - mov rsi,SconvKernelFrame.FilterStride[rsp] - mov rbp,SconvKernelFrame.DilationWidth[rsp] - mov r11,SconvKernelFrame.FilterCount[rsp] - mov r15,SconvKernelFrame.InputStride[rsp] - -; -; Process the specified number of filter rows. -; - - cmp r11,3 - je ProcessFilterCount3 - jb ProcessFilterCountLessThan3 - ProcessFilterCountN SconvKernelFrame, KernelType, 4 - jmp ExitKernel - -ProcessFilterCount3: - ProcessFilterCountN SconvKernelFrame, KernelType, 3 - jmp ExitKernel - -ProcessFilterCountLessThan3: - cmp r11,2 - jb ProcessFilterCount1 - ProcessFilterCountN SconvKernelFrame, KernelType, 2 - jmp ExitKernel - -ProcessFilterCount1: - ProcessFilterCountN SconvKernelFrame, KernelType, 1 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: -IFDIFI , - vzeroupper -ENDIF - movaps xmm6,SconvKernelFrame.SavedXmm6[rsp] - movaps xmm7,SconvKernelFrame.SavedXmm7[rsp] - movaps xmm8,SconvKernelFrame.SavedXmm8[rsp] - movaps xmm9,SconvKernelFrame.SavedXmm9[rsp] - movaps xmm10,SconvKernelFrame.SavedXmm10[rsp] - movaps xmm11,SconvKernelFrame.SavedXmm11[rsp] - movaps xmm12,SconvKernelFrame.SavedXmm12[rsp] - movaps xmm13,SconvKernelFrame.SavedXmm13[rsp] - movaps xmm14,SconvKernelFrame.SavedXmm14[rsp] - movaps xmm15,SconvKernelFrame.SavedXmm15[rsp] - add rsp,(SconvKernelFrame.SavedR12) - - BEGIN_EPILOGUE - - pop r12 - pop r13 - pop r14 - pop r15 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasConv&KernelType&FloatKernel&Isa&, _TEXT - -IFDIFI , - -; -; Generate out-of-band helpers for handling output blocks involving padding. -; - - IRP FilterCount, <1, 2, 3, 4> - - LEAF_ENTRY MlasConv&KernelType&FloatSingle&Isa&Filter&FilterCount, _TEXT - -ProcessNextOutputCount: - ProcessOutputCountN Isa, SconvKernelSingleFrame.KernelFrame, KernelType, BlockSize, FilterCount, 1 - add rdi,r9 ; advance input by 1 element - dec r10 ; decrement output count remaining - jnz ProcessNextOutputCount - ret - - LEAF_END MlasConv&KernelType&FloatSingle&Isa&Filter&FilterCount, _TEXT - - ENDM - -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner convolution kernel for the special -; case of a depthwise separable convolution. -; -; Arguments: -; -; BlockSize - Supplies the number of elements per block. -; -; Isa - Supplies the instruction set architecture string for function tags. -; - -SconvKernelDepthwiseFunction MACRO BlockSize, Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a convolution for the elements -; of an output row for a set of filter rows. -; -; Depthwise seperable convolutions are a form of grouped convolution where -; the number of input and output channels per group are one. 
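As a rough illustration of the "one input and output channel per group" property just described, a scalar C++ sketch of a depthwise convolution could look like the following; the names, the "valid" padding, and the plain NCHW indexing are simplifications rather than the blocked layout the deleted kernel actually uses.

#include <cstddef>

// Each channel is convolved with its own kernelH x kernelW filter; unlike a
// standard convolution there is no summation across input channels.
void DepthwiseConv2D(const float* input, const float* filter, float* output,
                     std::size_t channels, std::size_t inputH, std::size_t inputW,
                     std::size_t kernelH, std::size_t kernelW) {
    const std::size_t outputH = inputH - kernelH + 1;  // "valid" padding, stride 1
    const std::size_t outputW = inputW - kernelW + 1;
    for (std::size_t c = 0; c < channels; ++c) {
        for (std::size_t oh = 0; oh < outputH; ++oh) {
            for (std::size_t ow = 0; ow < outputW; ++ow) {
                float acc = 0.0f;
                for (std::size_t kh = 0; kh < kernelH; ++kh) {
                    for (std::size_t kw = 0; kw < kernelW; ++kw) {
                        acc += input[(c * inputH + oh + kh) * inputW + (ow + kw)] *
                               filter[(c * kernelH + kh) * kernelW + kw];
                    }
                }
                output[(c * outputH + oh) * outputW + ow] = acc;
            }
        }
    }
}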
-; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; The address is biased to include padding blocks for the left width -; dimension. The address is not biased to include padding rows for the -; left height dimension; these are accounted for in the outer kernel. -; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; StrideWidth (r9) - Supplies the length in bytes of the blocked stride width. -; -; DilationWidth - Supplies the length in bytes of the blocked dilation width. -; -; InputStride - Supplies the length in bytes to advance the input buffer to -; the next input row. -; -; KernelHeight - Supplies the height of the kernel to apply. This height may -; be less than the original kernel height after removing any padding -; rows. -; -; KernelWidth - Supplies the width of the kernel to apply. -; -; InputBase - Supplies the address of the valid input buffer. -; -; This parameter is similar to the Input parameter, but does not include -; the padding blocks for the left width dimension. This parameter is used -; with the following InputWidth parameter in order to validate that the -; current input buffer address in bounds and not in the left or right -; width padding region. -; -; InputWidth - Supplies the length in bytes of the blocked input width. -; -; DilatedInputWidth - Supplies the length in bytes to advance the input base -; buffer to the next input row including dilation. -; -; OutputCountLeftPad - Supplies the number of output elements that include -; one or more padding elements from the left edge. -; -; OutputCount - Supplies the number of output elements that do not include -; any padding elements. -; -; OutputCountRightPad - Supplies the number of output elements that include -; one or more padding elements from the right edge. -; -; Bias - Supplies the address of the bias buffer. -; -; Flags - Supplies additional flags controlling the convolution operation, -; especially post calculation options. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasConvDepthwiseFloatKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r15 - push_reg r14 - push_reg r13 - push_reg r12 - alloc_stack (SconvKernelDepthwiseFrame.SavedR12) - - save_xmm128 xmm6,SconvKernelDepthwiseFrame.SavedXmm6 - save_xmm128 xmm7,SconvKernelDepthwiseFrame.SavedXmm7 - save_xmm128 xmm8,SconvKernelDepthwiseFrame.SavedXmm8 - save_xmm128 xmm9,SconvKernelDepthwiseFrame.SavedXmm9 - save_xmm128 xmm10,SconvKernelDepthwiseFrame.SavedXmm10 - save_xmm128 xmm11,SconvKernelDepthwiseFrame.SavedXmm11 - save_xmm128 xmm12,SconvKernelDepthwiseFrame.SavedXmm12 - save_xmm128 xmm13,SconvKernelDepthwiseFrame.SavedXmm13 - save_xmm128 xmm14,SconvKernelDepthwiseFrame.SavedXmm14 - save_xmm128 xmm15,SconvKernelDepthwiseFrame.SavedXmm15 - - END_PROLOGUE - - mov rdi,rcx - mov rsi,rdx - mov rbp,SconvKernelDepthwiseFrame.DilationWidth[rsp] - mov r15,SconvKernelDepthwiseFrame.InputStride[rsp] - -; -; Process the specified number of filter rows. -; - - ProcessFilterCountN SconvKernelDepthwiseFrame, Depthwise, 1 - -; -; Restore non-volatile registers and return. 
-; - -ExitKernel: -IFDIFI , - vzeroupper -ENDIF - movaps xmm6,SconvKernelDepthwiseFrame.SavedXmm6[rsp] - movaps xmm7,SconvKernelDepthwiseFrame.SavedXmm7[rsp] - movaps xmm8,SconvKernelDepthwiseFrame.SavedXmm8[rsp] - movaps xmm9,SconvKernelDepthwiseFrame.SavedXmm9[rsp] - movaps xmm10,SconvKernelDepthwiseFrame.SavedXmm10[rsp] - movaps xmm11,SconvKernelDepthwiseFrame.SavedXmm11[rsp] - movaps xmm12,SconvKernelDepthwiseFrame.SavedXmm12[rsp] - movaps xmm13,SconvKernelDepthwiseFrame.SavedXmm13[rsp] - movaps xmm14,SconvKernelDepthwiseFrame.SavedXmm14[rsp] - movaps xmm15,SconvKernelDepthwiseFrame.SavedXmm15[rsp] - add rsp,(SconvKernelDepthwiseFrame.SavedR12) - - BEGIN_EPILOGUE - - pop r12 - pop r13 - pop r14 - pop r15 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasConvDepthwiseFloatKernel&Isa&, _TEXT - -IFDIFI , - -; -; Generate out-of-band helpers for handling output blocks involving padding. -; - - LEAF_ENTRY MlasConvDepthwiseFloatSingle&Isa&Filter1, _TEXT - -ProcessNextOutputCount: - ProcessOutputCountN Isa, SconvKernelDepthwiseSingleFrame.KernelFrame, Depthwise, BlockSize, 1, 1 - add rdi,r9 ; advance input by 1 element - dec r10 ; decrement output count remaining - jnz ProcessNextOutputCount - ret - - LEAF_END MlasConvDepthwiseFloatSingle&Isa&Filter1, _TEXT - -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a vector of input -; blocks and a vector of filter blocks to produce a matrix of output blocks -; for a pointwise convolution. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string for function tags. -; -; BlockSize - Supplies the number of elements per block. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbp - Supplies the InputStride parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r12 - Supplies the address of the filter buffer. -; - -ProcessPointwiseOutputCountN MACRO Isa, BlockSize, FilterCount, OutputCount - - LOCAL ProcessNextInputBlock - LOCAL SkipAccumulateOutput - LOCAL SkipBiasAddition - LOCAL SkipReluActivation - - mov rcx,rdi - mov rdx,r12 - mov r11,SconvKernelPointwiseFrame.InputChannels[rsp] - ClearBlock FilterCount, OutputCount - -ProcessNextInputBlock: -IF OutputCount GT 3 - lea r14,[r9+r9*2] - add r14,rcx ; compute input plus 3 blocks -ENDIF -IF FilterCount GT 2 - lea rbx,[rdx+rsi*2] ; compute filter plus 2 rows -ENDIF -IF BlockSize EQ 16 - IRP Index, <0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15> - ComputeBlock Pointwise, FilterCount, OutputCount, Index*16*4, Index*4 - ENDM -ELSE - IRP Index, <0, 1, 2, 3, 4, 5, 6, 7> - ComputeBlock Pointwise, FilterCount, OutputCount, (Index-4)*8*4, Index*4 - ENDM -ENDIF - add rcx,rbp ; advance input to next channel block - add rdx,BlockSize*BlockSize*4 ; advance filter by 8i8o/16i16o block - dec r11 ; decrement input blocks remaining - jnz ProcessNextInputBlock - -; -; Handle post processing of the output block. 
-; - - mov edx,DWORD PTR SconvKernelPointwiseFrame.Flags[rsp] -IF FilterCount GT 1 - mov rax,SconvKernelPointwiseFrame.OutputStride[rsp] -ENDIF - mov rcx,SconvKernelPointwiseFrame.Bias[rsp] - call MlasConvPostProcessFloat&Isa&Filter&FilterCount&Output&OutputCount - - ENDM - -;++ -; -; Macro Description: -; -; This macro generates code for the inner convolution kernel for the special -; case where the kernel dimensions are 1. -; -; Arguments: -; -; Isa - Supplies the instruction set architecture string for function tags. -; -; BiasFilter - Supplies a non-blank value if the address of the filter buffer -; should be biased to point to the middle of a OIhw8i8o block in order to -; reduce the code size from relative byte offsets. -; -;-- - -SconvKernelPointwiseFunction MACRO Isa, BiasFilter - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute a convolution for the elements -; of an output row for a set of filter rows. -; -; Pointwise convolutions have a kernel size of one. To simplify this -; implementation, no input padding is allowed, which matches typical usage in -; models. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; Filter (rdx) - Supplies the address of the filter buffer. -; -; Output (r8) - Supplies the address of the output buffer. -; -; StrideWidth (r9) - Supplies the length in bytes of the blocked stride width. -; -; InputChannels - Supplies the number of input channels to process. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; InputStride - Supplies the length in bytes to advance the input buffer to -; the next input channel of the same input row. -; -; FilterStride - Supplies the length in bytes to advance the filter buffer -; to the next set of filters. -; -; OutputStride - Supplies the length in bytes to advance the output buffer -; to the next output address associated with the next set of filters. -; -; OutputCount - Supplies the number of output elements. -; -; Bias - Supplies the address of the bias buffer. -; -; Flags - Supplies additional flags controlling the convolution operation, -; especially post calculation options. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasConvPointwiseFloatKernel&Isa&, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r14 - push_reg r12 - alloc_stack (SconvKernelPointwiseFrame.SavedR12) - - save_xmm128 xmm6,SconvKernelPointwiseFrame.SavedXmm6 - save_xmm128 xmm7,SconvKernelPointwiseFrame.SavedXmm7 - save_xmm128 xmm8,SconvKernelPointwiseFrame.SavedXmm8 - save_xmm128 xmm9,SconvKernelPointwiseFrame.SavedXmm9 - save_xmm128 xmm10,SconvKernelPointwiseFrame.SavedXmm10 - save_xmm128 xmm11,SconvKernelPointwiseFrame.SavedXmm11 - save_xmm128 xmm12,SconvKernelPointwiseFrame.SavedXmm12 - save_xmm128 xmm13,SconvKernelPointwiseFrame.SavedXmm13 - save_xmm128 xmm14,SconvKernelPointwiseFrame.SavedXmm14 - save_xmm128 xmm15,SconvKernelPointwiseFrame.SavedXmm15 - - END_PROLOGUE - - mov rdi,rcx -IFNB - lea r12,[rdx+4*8*4] -ELSE - mov r12,rdx -ENDIF - mov r10,SconvKernelPointwiseFrame.OutputCount[rsp] - mov r11,SconvKernelPointwiseFrame.FilterCount[rsp] - mov rsi,SconvKernelPointwiseFrame.FilterStride[rsp] - mov rbp,SconvKernelPointwiseFrame.InputStride[rsp] - -; -; Process the specified number of filter rows. 
-; - - cmp r11,3 - je ProcessFilterCount3 - jb ProcessFilterCountLessThan3 - ProcessPointwiseFilterCountN 4 - jmp ExitKernel - -ProcessFilterCount3: - ProcessPointwiseFilterCountN 3 - jmp ExitKernel - -ProcessFilterCountLessThan3: - cmp r11,2 - jb ProcessFilterCount1 - ProcessPointwiseFilterCountN 2 - jmp ExitKernel - -ProcessFilterCount1: - ProcessPointwiseFilterCountN 1 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: -IFDIFI , - vzeroupper -ENDIF - movaps xmm6,SconvKernelPointwiseFrame.SavedXmm6[rsp] - movaps xmm7,SconvKernelPointwiseFrame.SavedXmm7[rsp] - movaps xmm8,SconvKernelPointwiseFrame.SavedXmm8[rsp] - movaps xmm9,SconvKernelPointwiseFrame.SavedXmm9[rsp] - movaps xmm10,SconvKernelPointwiseFrame.SavedXmm10[rsp] - movaps xmm11,SconvKernelPointwiseFrame.SavedXmm11[rsp] - movaps xmm12,SconvKernelPointwiseFrame.SavedXmm12[rsp] - movaps xmm13,SconvKernelPointwiseFrame.SavedXmm13[rsp] - movaps xmm14,SconvKernelPointwiseFrame.SavedXmm14[rsp] - movaps xmm15,SconvKernelPointwiseFrame.SavedXmm15[rsp] - add rsp,(SconvKernelPointwiseFrame.SavedR12) - - BEGIN_EPILOGUE - - pop r12 - pop r14 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - NESTED_END MlasConvPointwiseFloatKernel&Isa&, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/SconvKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/SconvKernelFma3.asm deleted file mode 100644 index 38b260b4df93d..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SconvKernelFma3.asm +++ /dev/null @@ -1,262 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SconvKernelFma3.asm -; -; Abstract: -; -; This module implements the kernels for the single precision convolution -; operation. -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SconvKernelAvxCommon.inc - .list - -; -; Share the post process functions with the AVX implementation. -; - - IRP FilterCount, <1, 2, 3, 4> - IRP OutputCount, <1, 2, 3> - - EXTERN MlasConvPostProcessFloatFma3Filter&FilterCount&Output&OutputCount:NEAR - - ENDM - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for FilterCount by OutputCount block -; of the output buffer. -; -; Arguments: -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; VectorOffset - Supplies the byte offset from the filter buffer to fetch -; elements. -; -; BroadcastOffset - Supplies the byte offset from the input buffer to fetch -; elements. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rdx - Supplies the address of the filter buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbx - Supplies the address of the filter buffer plus 2 * FilterStride. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; ymm0-ymm11 - Supplies the block accumulators. 
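Conceptually, for the NCHWc case this macro emits a FilterCount-by-OutputCount grid of fused multiply-adds: one 8-wide filter vector per filter row multiplied by one broadcast input element per output column. A hedged C++ sketch of that tile update follows, with plain arrays standing in for the ymm accumulators and a hypothetical function name; it lines up with the FilterCount 1-4 and OutputCount 1-3 variants generated elsewhere in this file.

#include <cstddef>

constexpr std::size_t kBlock = 8;   // floats per NCHWc block in the AVX/FMA3 kernels

// acc[f][o][lane] += filterRows[f][lane] * inputElems[o]
// for a FilterCount x OutputCount tile of vector accumulators.
void ComputeBlockTile(float acc[][3][kBlock],            // [filterCount][outputCount<=3][kBlock]
                      const float filterRows[][kBlock],   // one kBlock-wide filter vector per filter row
                      const float* inputElems,            // one broadcast input element per output column
                      std::size_t filterCount,
                      std::size_t outputCount) {
    for (std::size_t f = 0; f < filterCount; ++f) {
        for (std::size_t o = 0; o < outputCount; ++o) {
            for (std::size_t lane = 0; lane < kBlock; ++lane) {
                acc[f][o][lane] += filterRows[f][lane] * inputElems[o];  // fused multiply-add
            }
        }
    }
}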
-; - -ComputeBlock MACRO KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -IFIDNI , - vmovups ymm12,YMMWORD PTR [rdx] - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -IF OutputCount EQ 1 - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 4, -ELSE - EmitIfCountGE FilterCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 2, - EmitIfCount2GE FilterCount, 1, OutputCount, 3, - EmitIfCountGE FilterCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 3, - EmitIfCountGE FilterCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 2, - EmitIfCount2GE FilterCount, 3, OutputCount, 3, - EmitIfCountGE FilterCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 2, - EmitIfCount2GE FilterCount, 4, OutputCount, 3, -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows. -; -; Arguments: -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description) when -; KernelType!=Depthwise. Supplies the address of the filter buffer when -; KernelType=Depthwise. -; -; rbp - Supplies the DilationWidth parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r15 - Supplies the InputStride parameter (see function description). -; - -ProcessFilterCountN MACRO KernelFrame, KernelType, FilterCount - - LOCAL ProcessOutputCountLeftPad - LOCAL ProcessOutputCount - LOCAL ProcessNextOutputCountBy3 - LOCAL ProcessRemainingOutputCount - LOCAL ProcessRemainingOutputCount1 - LOCAL ProcessOutputCountRightPadAndRemaining - -; -; Process the output blocks that include left padding. -; - - mov r10,KernelFrame.OutputCountLeftPad[rsp] - test r10,r10 - jz ProcessOutputCount - call MlasConv&KernelType&FloatSingleFma3Filter&FilterCount - -; -; Process the output blocks that do not include any padding. -; - -ProcessOutputCount: - mov r10,KernelFrame.OutputCount[rsp] - sub r10,3 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy3: - ProcessOutputCountN Fma3, KernelFrame, KernelType, 8, FilterCount, 3 - lea rax,[r9*2+r9] - add rdi,rax ; advance input by 3 elements - sub r10,3 - jae ProcessNextOutputCountBy3 - -ProcessRemainingOutputCount: - add r10,3 ; correct for over-subtract above - jz ProcessOutputCountRightPadAndRemaining - cmp r10,2 - jb ProcessOutputCountRightPadAndRemaining - ProcessOutputCountN Fma3, KernelFrame, KernelType, 8, FilterCount, 2 - lea rdi,[rdi+r9*2] ; advance input by 2 elements - sub r10,2 - -; -; Process the output blocks that include right padding plus any remaining output -; blocks from above. 
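At a higher level, the macro above covers an output row in three phases: left-padded outputs, the unpadded interior blocked three (then two) at a time, and, as the code just below continues, the right-padded remainder. A rough C++ driver sketch of that split, with placeholder lambdas standing in for the padding-aware single-output path and the no-padding blocked path:

#include <cstddef>

// Hypothetical helpers: processSingle is the padding-aware one-output-at-a-time
// path; processBlock is the no-padding path handling 1..3 outputs per call.
void ProcessOutputRow(std::size_t leftPad, std::size_t interior, std::size_t rightPad) {
    auto processSingle = [](std::size_t /*count*/) { /* MlasConv...FloatSingle... analogue */ };
    auto processBlock  = [](std::size_t /*count*/) { /* ProcessOutputCountN analogue */ };

    processSingle(leftPad);               // outputs overlapping the left padding
    std::size_t remaining = interior;
    while (remaining >= 3) {              // main loop: three output blocks per iteration
        processBlock(3);
        remaining -= 3;
    }
    if (remaining == 2) {                 // two-output remainder
        processBlock(2);
        remaining = 0;
    }
    processSingle(remaining + rightPad);  // leftover single output plus right padding
}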
-; - -ProcessOutputCountRightPadAndRemaining: - add r10,KernelFrame.OutputCountRightPad[rsp] - jz ExitKernel - call MlasConv&KernelType&FloatSingleFma3Filter&FilterCount - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows for a pointwise convolution. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbp - Supplies the InputStride parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r10 - Supplies the OutputCount parameter (see function description). -; -; r12 - Supplies the address of the filter buffer. -; - -ProcessPointwiseFilterCountN MACRO FilterCount - - LOCAL ProcessNextOutputCountBy3 - LOCAL ProcessRemainingOutputCount - LOCAL ProcessRemainingOutputCount1 - - sub r10,3 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy3: - ProcessPointwiseOutputCountN Fma3, 8, FilterCount, 3 - lea rax,[r9*2+r9] - add rdi,rax ; advance input by 3 elements - sub r10,3 - jae ProcessNextOutputCountBy3 - -ProcessRemainingOutputCount: - add r10,3 ; correct for over-subtract above - jz ExitKernel - cmp r10,2 - jb ProcessRemainingOutputCount1 - ProcessPointwiseOutputCountN Fma3, 8, FilterCount, 2 - jmp ExitKernel - -ProcessRemainingOutputCount1: - ProcessPointwiseOutputCountN Fma3, 8, FilterCount, 1 - - ENDM - -; -; Generate the convolution kernels. -; - -SconvKernelFunction Nchw, 8, Fma3 -SconvKernelFunction Nchwc, 8, Fma3, BiasFilter -SconvKernelDepthwiseFunction 8, Fma3 -SconvKernelPointwiseFunction Fma3, BiasFilter - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SconvKernelSse2.asm b/onnxruntime/core/mlas/lib/amd64/SconvKernelSse2.asm deleted file mode 100644 index 57eb402ff0c16..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SconvKernelSse2.asm +++ /dev/null @@ -1,337 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SconvKernelSse2.asm -; -; Abstract: -; -; This module implements the kernels for the single precision convolution -; operation. -; -; This implementation uses SSE2 instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SconvKernelCommon.inc - .list - -; -; Macro Description: -; -; This macro generates code to clear the block accumulators. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; xmm0-xmm7 - Supplies the block accumulators. -; - -ClearBlock MACRO FilterCount, OutputCount - - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for FilterCount by OutputCount block -; of the output buffer. -; -; Arguments: -; -; KernelType - Supplies the type of kernel to be generated. 
-; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; VectorOffset - Supplies the byte offset from the filter buffer to fetch -; elements. -; -; BroadcastOffset - Supplies the byte offset from the input buffer to fetch -; elements. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rdx - Supplies the address of the filter buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbx - Supplies the address of the filter buffer plus 2 * FilterStride. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; xmm0-xmm7 - Supplies the block accumulators. -; - -ComputeBlock MACRO KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -IFIDNI , - movups xmm8,XMMWORD PTR [rdx] - movups xmm9,XMMWORD PTR [rdx+16] - movups xmm10,XMMWORD PTR [rcx] - movups xmm11,XMMWORD PTR [rcx+16] - mulps xmm8,xmm10 - addps xmm0,xmm8 - mulps xmm9,xmm11 - addps xmm1,xmm9 -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 1, - EmitIfCountGE FilterCount, 1, - EmitIfCountGE FilterCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCountGE FilterCount, 2, - EmitIfCountGE FilterCount, 2, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCountGE FilterCount, 3, - EmitIfCountGE FilterCount, 3, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCountGE FilterCount, 4, - EmitIfCountGE FilterCount, 4, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows. -; -; Arguments: -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; KernelType - Supplies the type of kernel to be generated. -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbp - Supplies the DilationWidth parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r15 - Supplies the InputStride parameter (see function description). 
-; - -ProcessFilterCountN MACRO KernelFrame, KernelType, FilterCount - - LOCAL ProcessNextOutputCount - - mov r10,KernelFrame.OutputCountLeftPad[rsp] - add r10,KernelFrame.OutputCount[rsp] - add r10,KernelFrame.OutputCountRightPad[rsp] - -ProcessNextOutputCount: - ProcessOutputCountN Sse, KernelFrame, KernelType, 8, FilterCount, 1 - add rdi,r9 ; advance input by 1 element - dec r10 - jnz ProcessNextOutputCount - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute the convolution for a specified number -; of filter rows for a pointwise convolution. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rsi - Supplies the FilterStride parameter (see function description). -; -; rbp - Supplies the InputStride parameter (see function description). -; -; r8 - Supplies the address of the output buffer. -; -; r9 - Supplies the StrideWidth parameter (see function description). -; -; r10 - Supplies the OutputCount parameter (see function description). -; -; r12 - Supplies the address of the filter buffer. -; - -ProcessPointwiseFilterCountN MACRO FilterCount - - LOCAL ProcessNextOutputCount - -ProcessNextOutputCount: - ProcessPointwiseOutputCountN Sse, 8, FilterCount, 1 - add rdi,r9 ; advance input by 1 element - dec r10 - jnz ProcessNextOutputCount - - ENDM - -; -; Generate the convolution kernels. -; - -SconvKernelFunction Nchw, 8, Sse -SconvKernelFunction Nchwc, 8, Sse, BiasFilter -SconvKernelDepthwiseFunction 8, Sse -SconvKernelPointwiseFunction Sse, BiasFilter - -; -; Macro Description: -; -; This macro generates code to process an output block after the inner -; convolution kernel has executed and then stores the output block to the -; output buffer. -; -; Arguments: -; -; FilterCount - Supplies the number of rows from the filter to process. -; -; OutputCount - Supplies the number of output blocks to produce. -; - - IRP FilterCount, <1, 2, 3, 4> - IRP OutputCount, <1> - - LEAF_ENTRY MlasConvPostProcessFloatSseFilter&FilterCount&Output&OutputCount, _TEXT - -IF FilterCount GT 2 - lea rbx,[r8+rax*2] ; compute output plus 2 rows -ENDIF - -; -; Test if the existing contents of the output buffer should be accumulated -; with the output block. -; - - test dl,MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - jz SkipAccumulateOutput - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - -SkipAccumulateOutput: - -; -; Test if the bias buffer should be accumulated with the output block. 
-; - - test dl,MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - jz SkipBiasAddition - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - -SkipBiasAddition: - -; -; Test for fused ReLU activation. -; - - test dl,MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - jz SkipReluActivation - xorps xmm15,xmm15 - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - -SkipReluActivation: - -; -; Store the output block in the output buffer. -; - - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 1, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 2, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 3, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - EmitIfCount2GE FilterCount, 4, OutputCount, 1, - add_immed r8,OutputCount*8*4 ; advance output by N nchw8c blocks - ret - - LEAF_END MlasConvPostProcessFloatSseFilter&FilterCount&Output&OutputCount, _TEXT - - ENDM - ENDM - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SgemmKernelAvx.asm b/onnxruntime/core/mlas/lib/amd64/SgemmKernelAvx.asm deleted file mode 100644 index 9ba503df16805..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SgemmKernelAvx.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelAvx.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses AVX instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SgemmKernelCommon.inc -INCLUDE FgemmKernelAvxCommon.inc - .list - -; -; Generate the GEMM kernel. -; - -FgemmKernelAvxFunction Float - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SgemmKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/SgemmKernelAvx512F.asm deleted file mode 100644 index d59880a1f552a..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SgemmKernelAvx512F.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelAvx512F.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses AVX512F instructions. 
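The deleted SgemmKernel*.asm wrappers above are intentionally thin: each includes SgemmKernelCommon.inc plus a shared Fgemm* implementation and instantiates it for Float, with the double-precision counterparts doing the same elsewhere. In C++ terms this is roughly a template instantiation; the sketch below is only an analogy (a naive triple loop, not the blocked, register-tiled kernel these files generate).

#include <cstddef>

// The Fgemm* common code is written once in terms of an "element" type and
// typed instruction aliases (vmulpf, vaddpf, ...); the Sgemm/Dgemm wrappers
// pick the concrete type, much like instantiating a template.
template <typename Element>
void FgemmKernelReference(const Element* A, const Element* B, Element* C,
                          std::size_t M, std::size_t N, std::size_t K) {
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            Element acc = Element(0);
            for (std::size_t k = 0; k < K; ++k) {
                acc += A[m * K + k] * B[k * N + n];
            }
            C[m * N + n] = acc;
        }
    }
}

// "SgemmKernelAvx.asm" corresponds roughly to instantiating the shared kernel for float:
void SgemmKernelReference(const float* A, const float* B, float* C,
                          std::size_t M, std::size_t N, std::size_t K) {
    FgemmKernelReference<float>(A, B, C, M, N, K);
}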
-; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SgemmKernelCommon.inc -INCLUDE FgemmKernelAvx512FCommon.inc - .list - -; -; Generate the GEMM kernel. -; - -FgemmKernelAvx512FFunction Float - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SgemmKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/SgemmKernelCommon.inc deleted file mode 100644 index f7bc7870bb3ae..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SgemmKernelCommon.inc +++ /dev/null @@ -1,45 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the single -; precision matrix/matrix multiply operation (SGEMM). -; -;-- - -; -; Define the single precision parameters. -; - -FgemmElementShift EQU 2 -FgemmElementSize EQU (1 SHL FgemmElementShift) -FgemmElementPtr EQU DWORD PTR -FgemmElementBcst EQU DWORD BCST - -; -; Define the typed instructions for single precision. -; - -addpf EQU addps -movupf EQU movups - -vaddpf EQU vaddps -vbroadcastsf EQU vbroadcastss -vfmadd213pf EQU vfmadd213ps -vfmadd231pf EQU vfmadd231ps -vmaskmovpf EQU vmaskmovps -vmovapf EQU vmovaps -vmovsf EQU vmovss -vmovupf EQU vmovups -vmulpf EQU vmulps -vxorpf EQU vxorps - -INCLUDE FgemmKernelCommon.inc diff --git a/onnxruntime/core/mlas/lib/amd64/SgemmKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/SgemmKernelFma3.asm deleted file mode 100644 index 3651ad18f5333..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SgemmKernelFma3.asm +++ /dev/null @@ -1,32 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelFma3.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SgemmKernelCommon.inc -INCLUDE FgemmKernelFma3Common.inc - .list - -; -; Generate the GEMM kernel. -; - -FgemmKernelFma3Function Float - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SgemmKernelM1Avx.asm b/onnxruntime/core/mlas/lib/amd64/SgemmKernelM1Avx.asm deleted file mode 100644 index 418c8e332eb46..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SgemmKernelM1Avx.asm +++ /dev/null @@ -1,581 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelM1Avx.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). This handles the special case of M=1. -; -; This implementation uses AVX instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - EXTERN MlasMaskMoveAvx:NEAR - -; -; Stack frame layout for the SGEMM M=1 kernels. -; - -SgemmKernelM1Frame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - CountN QWORD ? - ldb QWORD ? - Beta QWORD ? - -SgemmKernelM1Frame ENDS - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. This handles the special case of M=1. -; -; The elements in matrix B are not transposed. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. 
-; -; B (rdx) - Supplies the address of matrix B. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; ldb - Supplies the first dimension of matrix B. -; -; Beta - Supplies the scalar beta multiplier (see SGEMM definition). -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasSgemmKernelM1Avx, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - alloc_stack (SgemmKernelM1Frame.SavedRsi) - save_xmm128 xmm6,SgemmKernelM1Frame.SavedXmm6 - save_xmm128 xmm7,SgemmKernelM1Frame.SavedXmm7 - save_xmm128 xmm8,SgemmKernelM1Frame.SavedXmm8 - - END_PROLOGUE - - mov rbx,SgemmKernelM1Frame.ldb[rsp] - shl rbx,2 ; convert ldb to bytes - mov r10,r8 - mov r11,rdx - mov rbp,SgemmKernelM1Frame.CountN[rsp] - -; -; Compute the initial results mask for zeroing or accumulate mode. -; - - vxorps xmm0,xmm0,xmm0 - vcmpeqss xmm0,xmm0,DWORD PTR SgemmKernelM1Frame.Beta[rsp] - vshufps xmm0,xmm0,xmm0,0 - vinsertf128 ymm0,ymm0,xmm0,1 - -; -; Compute the conditional load/store mask for an unaligned CountN. -; - - mov eax,ebp - and eax,7 - vmovd xmm7,eax - vshufps xmm7,xmm7,xmm7,0 - vpcmpgtd xmm6,xmm7,XMMWORD PTR [MlasMaskMoveAvx+16] - vpcmpgtd xmm7,xmm7,XMMWORD PTR [MlasMaskMoveAvx] - vinsertf128 ymm7,ymm7,xmm6,1 - -; -; Process 4 rows of the matrices in a loop. -; - - sub r9,4 - jb ProcessRemainingCountK - -ProcessRowLoop4: - vbroadcastss ymm2,DWORD PTR [rcx] - mov rax,rbp ; reload CountN - vbroadcastss ymm3,DWORD PTR [rcx+4] - mov rdx,r11 ; reload matrix B - vbroadcastss ymm4,DWORD PTR [rcx+8] - mov r8,r10 ; reload matrix C - vbroadcastss ymm5,DWORD PTR [rcx+12] - add rcx,4*4 ; advance matrix A by 4 columns - lea r11,[rdx+rbx*4] ; advance matrix B by 4 rows - sub rax,16 - jb ProcessRemainingCountN4 - -ProcessColumnLoop4: - lea rsi,[rdx+rbx*2] ; compute matrix B plus 2 rows - vmulps ymm1,ymm2,YMMWORD PTR [rdx] - vmulps ymm6,ymm2,YMMWORD PTR [rdx+32] - vmulps ymm8,ymm3,YMMWORD PTR [rdx+rbx] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm3,YMMWORD PTR [rdx+rbx+32] - vaddps ymm6,ymm6,ymm8 - vmulps ymm8,ymm4,YMMWORD PTR [rsi] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm4,YMMWORD PTR [rsi+32] - vaddps ymm6,ymm6,ymm8 - vmulps ymm8,ymm5,YMMWORD PTR [rsi+rbx] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm5,YMMWORD PTR [rsi+rbx+32] - vaddps ymm6,ymm6,ymm8 - vandnps ymm8,ymm0,YMMWORD PTR [r8] - vaddps ymm1,ymm1,ymm8 - vandnps ymm8,ymm0,YMMWORD PTR [r8+32] - vaddps ymm6,ymm6,ymm8 - vmovups YMMWORD PTR [r8],ymm1 - vmovups YMMWORD PTR [r8+32],ymm6 - add rdx,16*4 ; advance matrix B by 16 columns - add r8,16*4 ; advance matrix C by 16 columns - sub rax,16 - jae ProcessColumnLoop4 - -ProcessRemainingCountN4: - test al,15 ; test for unaligned columns - jz ProcessedRemainingCountN4 - test al,8 ; CountN >= 8? 
- jz ProcessRemainingCountNSmall4 - lea rsi,[rdx+rbx*2] ; compute matrix B plus 2 rows - vmulps ymm1,ymm2,YMMWORD PTR [rdx] - vmulps ymm8,ymm3,YMMWORD PTR [rdx+rbx] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm4,YMMWORD PTR [rsi] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm5,YMMWORD PTR [rsi+rbx] - vaddps ymm1,ymm1,ymm8 - vandnps ymm8,ymm0,YMMWORD PTR [r8] - vaddps ymm1,ymm1,ymm8 - vmovups YMMWORD PTR [r8],ymm1 - add rdx,8*4 ; advance matrix B by 8 columns - add r8,8*4 ; advance matrix C by 8 columns - test al,7 - jz ProcessedRemainingCountN4 - -ProcessRemainingCountNSmall4: - lea rsi,[rdx+rbx*2] ; compute matrix B plus 2 rows - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vmulps ymm1,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx+rbx] - vmulps ymm8,ymm3,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm8,ymm4,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi+rbx] - vmulps ymm8,ymm5,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [r8] - vandnps ymm6,ymm0,ymm6 - vaddps ymm1,ymm1,ymm6 - vmaskmovps YMMWORD PTR [r8],ymm7,ymm1 - -ProcessedRemainingCountN4: - vxorps xmm0,xmm0,xmm0 ; switch to accumulate mode - sub r9,4 - jae ProcessRowLoop4 - -ProcessRemainingCountK: - test r9d,2 - jnz ProcessRowLoop2 - test r9d,1 - jnz ProcessRowLoop1 - -ExitKernel: - vzeroupper - movaps xmm6,SgemmKernelM1Frame.SavedXmm6[rsp] - movaps xmm7,SgemmKernelM1Frame.SavedXmm7[rsp] - movaps xmm8,SgemmKernelM1Frame.SavedXmm8[rsp] - add rsp,(SgemmKernelM1Frame.SavedRsi) - - BEGIN_EPILOGUE - - pop rsi - pop rbx - pop rbp - ret - -; -; Process 2 rows of the matrices. -; - -ProcessRowLoop2: - vbroadcastss ymm2,DWORD PTR [rcx] - mov rax,rbp ; reload CountN - vbroadcastss ymm3,DWORD PTR [rcx+4] - mov rdx,r11 ; reload matrix B - mov r8,r10 ; reload matrix C - add rcx,2*4 ; advance matrix A by 2 columns - lea r11,[rdx+rbx*2] ; advance matrix B by 2 rows - sub rax,8 - jb ProcessRemainingCountN2 - -ProcessColumnLoop2: - vmulps ymm1,ymm2,YMMWORD PTR [rdx] - vmulps ymm8,ymm3,YMMWORD PTR [rdx+rbx] - vaddps ymm1,ymm1,ymm8 - vandnps ymm6,ymm0,YMMWORD PTR [r8] - vaddps ymm1,ymm1,ymm6 - vmovups YMMWORD PTR [r8],ymm1 - add rdx,8*4 ; advance matrix B by 8 columns - add r8,8*4 ; advance matrix C by 8 columns - sub rax,8 - jae ProcessColumnLoop2 - -ProcessRemainingCountN2: - test al,7 ; test for unaligned columns - jz ProcessedRemainingCountN2 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vmulps ymm1,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx+rbx] - vmulps ymm8,ymm3,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [r8] - vandnps ymm6,ymm0,ymm6 - vaddps ymm1,ymm1,ymm6 - vmaskmovps YMMWORD PTR [r8],ymm7,ymm1 - -ProcessedRemainingCountN2: - test r9d,1 - jz ExitKernel - vxorps xmm0,xmm0,xmm0 ; switch to accumulate mode - -; -; Process 1 row of the matrices. 
-; - -ProcessRowLoop1: - vbroadcastss ymm2,DWORD PTR [rcx] - mov rax,rbp ; reload CountN - mov rdx,r11 ; reload matrix B - mov r8,r10 ; reload matrix C - sub rax,8 - jb ProcessRemainingCountN1 - -ProcessColumnLoop1: - vmulps ymm1,ymm2,YMMWORD PTR [rdx] - vandnps ymm6,ymm0,YMMWORD PTR [r8] - vaddps ymm1,ymm1,ymm6 - vmovups YMMWORD PTR [r8],ymm1 - add rdx,8*4 ; advance matrix B by 8 columns - add r8,8*4 ; advance matrix C by 8 columns - sub rax,8 - jae ProcessColumnLoop1 - -ProcessRemainingCountN1: - test al,7 ; test for unaligned columns - jz ExitKernel - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vmulps ymm1,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [r8] - vandnps ymm6,ymm0,ymm6 - vaddps ymm1,ymm1,ymm6 - vmaskmovps YMMWORD PTR [r8],ymm7,ymm1 - jmp ExitKernel - - NESTED_END MlasSgemmKernelM1Avx, _TEXT - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. This handles the special case of M=1. -; -; The elements in matrix B are transposed. -; -; Arguments: -; -; A (rcx) - Supplies the address of matrix A. -; -; B (rdx) - Supplies the address of matrix B. The elements are transposed. -; -; C (r8) - Supplies the address of matrix C. -; -; CountK (r9) - Supplies the number of columns from matrix A and the number -; of columns from matrix B to iterate over. -; -; CountN - Supplies the number of rows from matrix B and the number of columns -; from matrix C to iterate over. -; -; ldb - Supplies the first dimension of matrix B. -; -; Beta - Supplies the scalar beta multiplier (see SGEMM definition). -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasSgemmKernelM1TransposeBAvx, _TEXT - - rex_push_reg rbp - push_reg rbx - push_reg rsi - alloc_stack (SgemmKernelM1Frame.SavedRsi) - save_xmm128 xmm6,SgemmKernelM1Frame.SavedXmm6 - save_xmm128 xmm7,SgemmKernelM1Frame.SavedXmm7 - - END_PROLOGUE - - mov rbx,SgemmKernelM1Frame.ldb[rsp] - shl rbx,2 ; convert ldb to bytes - mov r10,rcx - mov r11,rdx - mov rbp,SgemmKernelM1Frame.CountN[rsp] - -; -; Compute the results mask for zeroing or accumulate mode. -; - - vxorps xmm0,xmm0,xmm0 - vcmpeqss xmm0,xmm0,DWORD PTR SgemmKernelM1Frame.Beta[rsp] - vshufps xmm0,xmm0,xmm0,0 - -; -; Compute the conditional load/store mask for an unaligned CountK. -; - - mov eax,r9d - and eax,7 - vmovd xmm7,eax - vshufps xmm7,xmm7,xmm7,0 - vpcmpgtd xmm6,xmm7,XMMWORD PTR [MlasMaskMoveAvx+16] - vpcmpgtd xmm7,xmm7,XMMWORD PTR [MlasMaskMoveAvx] - vinsertf128 ymm7,ymm7,xmm6,1 - -; -; Process 4 rows of the matrices in a loop. 
-; - - sub rbp,4 - jb ProcessRemainingCountN - -ProcessRowLoop4: - vxorps xmm2,xmm2,xmm2 ; clear row accumulators - vxorps xmm3,xmm3,xmm3 - vxorps xmm4,xmm4,xmm4 - vxorps xmm5,xmm5,xmm5 - mov rcx,r10 ; reload matrix A - mov rdx,r11 ; reload matrix B - mov rax,r9 ; reload CountK - lea r11,[rdx+rbx*4] ; advance matrix B by 4 rows - sub rax,8 - jb ProcessRemainingCountK4 - -ProcessColumnLoop4: - lea rsi,[rdx+rbx*2] ; compute matrix B plus 2 rows - vmovups ymm1,YMMWORD PTR [rcx] - vmulps ymm6,ymm1,YMMWORD PTR [rdx] - vaddps ymm2,ymm2,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rdx+rbx] - vaddps ymm3,ymm3,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rsi] - vaddps ymm4,ymm4,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rsi+rbx] - vaddps ymm5,ymm5,ymm6 - add rcx,8*4 ; advance matrix A by 8 columns - add rdx,8*4 ; advance matrix B by 8 columns - sub rax,8 - jae ProcessColumnLoop4 - -ProcessRemainingCountK4: - test al,7 ; test for unaligned columns - jz Output4x1Block - lea rsi,[rdx+rbx*2] ; compute matrix B plus 2 rows - vmaskmovps ymm1,ymm7,YMMWORD PTR [rcx] - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm2,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx+rbx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm3,ymm3,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm6,ymm1,ymm6 - vaddps ymm4,ymm4,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi+rbx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm5,ymm5,ymm6 - -; -; Reduce and output the row accumulators. -; - -Output4x1Block: - vunpcklps ymm6,ymm2,ymm3 ; transpose row accumulators - vunpckhps ymm1,ymm2,ymm3 - vunpcklps ymm2,ymm4,ymm5 - vunpckhps ymm3,ymm4,ymm5 - vunpcklpd ymm4,ymm6,ymm2 - vunpckhpd ymm5,ymm6,ymm2 - vaddps ymm4,ymm4,ymm5 - vunpcklpd ymm6,ymm1,ymm3 - vunpckhpd ymm2,ymm1,ymm3 - vaddps ymm4,ymm4,ymm6 - vaddps ymm4,ymm4,ymm2 - vextractf128 xmm5,ymm4,1 - vaddps xmm4,xmm4,xmm5 - vandnps xmm6,xmm0,XMMWORD PTR [r8] - vaddps xmm4,xmm4,xmm6 - vmovups XMMWORD PTR [r8],xmm4 - add r8,4*4 ; advance matrix C by 4 columns - sub rbp,4 - jae ProcessRowLoop4 - -ProcessRemainingCountN: - test ebp,2 - jnz ProcessRowLoop2 - test ebp,1 - jnz ProcessRowLoop1 - -ExitKernel: - vzeroupper - movaps xmm6,SgemmKernelM1Frame.SavedXmm6[rsp] - movaps xmm7,SgemmKernelM1Frame.SavedXmm7[rsp] - add rsp,(SgemmKernelM1Frame.SavedRsi) - - BEGIN_EPILOGUE - - pop rsi - pop rbx - pop rbp - ret - -; -; Process 2 rows of the matrices. -; - -ProcessRowLoop2: - vxorps xmm2,xmm2,xmm2 ; clear row accumulators - vxorps xmm3,xmm3,xmm3 - mov rcx,r10 ; reload matrix A - mov rdx,r11 ; reload matrix B - mov rax,r9 ; reload CountK - lea r11,[rdx+rbx*2] ; advance matrix B by 2 rows - sub rax,8 - jb ProcessRemainingCountK2 - -ProcessColumnLoop2: - vmovups ymm1,YMMWORD PTR [rcx] - vmulps ymm6,ymm1,YMMWORD PTR [rdx] - vaddps ymm2,ymm2,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rdx+rbx] - vaddps ymm3,ymm3,ymm6 - add rcx,8*4 ; advance matrix A by 8 columns - add rdx,8*4 ; advance matrix B by 8 columns - sub rax,8 - jae ProcessColumnLoop2 - -ProcessRemainingCountK2: - test al,7 ; test for unaligned columns - jz Output2x1Block - vmaskmovps ymm1,ymm7,YMMWORD PTR [rcx] - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm2,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx+rbx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm3,ymm3,ymm6 - -; -; Reduce and output the row accumulators. 
-; - -Output2x1Block: - vunpcklps ymm4,ymm2,ymm3 ; reduce row accumulators - vunpckhps ymm2,ymm2,ymm3 - vaddps ymm2,ymm2,ymm4 - vextractf128 xmm4,ymm2,1 - vaddps xmm2,xmm2,xmm4 - vmovhlps xmm4,xmm2,xmm2 - vaddps xmm2,xmm2,xmm4 - vmovsd xmm3,QWORD PTR [r8] - vandnps xmm3,xmm0,xmm3 - vaddps xmm2,xmm2,xmm3 - vmovsd QWORD PTR [r8],xmm2 - add r8,2*4 ; advance matrix C by 2 columns - test ebp,1 - jz ExitKernel - -; -; Process 1 row of the matrices. -; - -ProcessRowLoop1: - vxorps xmm2,xmm2,xmm2 ; clear row accumulators - mov rcx,r10 ; reload matrix A - mov rdx,r11 ; reload matrix B - mov rax,r9 ; reload CountK - sub rax,8 - jb ProcessRemainingCountK1 - -ProcessColumnLoop1: - vmovups ymm1,YMMWORD PTR [rcx] - vmulps ymm6,ymm1,YMMWORD PTR [rdx] - vaddps ymm2,ymm2,ymm6 - add rcx,8*4 ; advance matrix A by 8 columns - add rdx,8*4 ; advance matrix B by 8 columns - sub rax,8 - jae ProcessColumnLoop1 - -ProcessRemainingCountK1: - test al,7 ; test for unaligned columns - jz Output1x1Block - vmaskmovps ymm1,ymm7,YMMWORD PTR [rcx] - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm2,ymm2,ymm6 - -; -; Reduce and output the row accumulators. -; - -Output1x1Block: - vhaddps ymm2,ymm2,ymm2 ; reduce row accumulators - vhaddps ymm2,ymm2,ymm2 - vextractf128 xmm4,ymm2,1 - vaddss xmm2,xmm2,xmm4 - vmovss xmm3,DWORD PTR [r8] - vandnps xmm3,xmm0,xmm3 - vaddss xmm2,xmm2,xmm3 - vmovss DWORD PTR [r8],xmm2 - jmp ExitKernel - - NESTED_END MlasSgemmKernelM1TransposeBAvx, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SgemmKernelSse2.asm b/onnxruntime/core/mlas/lib/amd64/SgemmKernelSse2.asm deleted file mode 100644 index 4146a0e37646e..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SgemmKernelSse2.asm +++ /dev/null @@ -1,280 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelSse2.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses SSE2 instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SgemmKernelCommon.inc -INCLUDE FgemmKernelSse2Common.inc - .list - -; -; Macro Description: -; -; This macro multiplies and accumulates for a 16xN block of the output matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; Shuffle - Supplies the shuffle mask to extract the element from matrix A. -; -; Implicit Arguments: -; -; rdx - Supplies the address into the matrix B data. -; -; xmm0-xmm1 - Supplies up to four elements loaded from matrix A and matrix A -; plus one row. -; -; xmm8-xmm15 - Supplies the block accumulators. 
-; - -ComputeBlockSseBy16 MACRO RowCount, VectorOffset, Shuffle - - movaps xmm4,XMMWORD PTR [rdx+VectorOffset] - movaps xmm5,XMMWORD PTR [rdx+VectorOffset+16] - pshufd xmm2,xmm0,Shuffle -IF RowCount EQ 2 - pshufd xmm3,xmm1,Shuffle - movaps xmm6,xmm4 - movaps xmm7,xmm5 -ENDIF - mulps xmm4,xmm2 - mulps xmm5,xmm2 - addps xmm8,xmm4 - addps xmm9,xmm5 -IF RowCount EQ 2 - mulps xmm6,xmm3 - mulps xmm7,xmm3 - addps xmm12,xmm6 - addps xmm13,xmm7 -ENDIF - movaps xmm4,XMMWORD PTR [rdx+VectorOffset+32] - movaps xmm5,XMMWORD PTR [rdx+VectorOffset+48] -IF RowCount EQ 2 - movaps xmm6,xmm4 - movaps xmm7,xmm5 -ENDIF - mulps xmm4,xmm2 - mulps xmm5,xmm2 - addps xmm10,xmm4 - addps xmm11,xmm5 -IF RowCount EQ 2 - mulps xmm6,xmm3 - mulps xmm7,xmm3 - addps xmm14,xmm6 - addps xmm15,xmm7 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute matrix multiplication for a fixed set -; of rows. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; Fallthrough - Supplies a non-blank value if the macro may fall through to -; the ExitKernel label. -; -; Implicit Arguments: -; -; rax - Supplies the length in bytes of a row from matrix C. -; -; rcx - Supplies the address of matrix A. -; -; rdx - Supplies the address of matrix B. -; -; rsi - Supplies the address of matrix A. -; -; rbp - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; r8 - Supplies the address of matrix C. -; -; r9 - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; r10 - Supplies the length in bytes of a row from matrix A. -; -; r15 - Stores the ZeroMode argument from the stack frame. -; - -ProcessCountM MACRO RowCount, Fallthrough - - LOCAL ProcessNextColumnLoop16xN - LOCAL Compute16xNBlockBy4Loop - LOCAL ProcessRemaining16xNBlocks - LOCAL Compute16xNBlockBy1Loop - LOCAL Output16xNBlock - LOCAL OutputPartial16xNBlock - LOCAL OutputPartialLessThan12xNBlock - LOCAL OutputPartialLessThan8xNBlock - LOCAL OutputPartialLessThan4xNBlock - LOCAL SkipAccumulateOutput2xN - LOCAL OutputPartial1xNBlock - LOCAL SkipAccumulateOutput1xN - -ProcessNextColumnLoop16xN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - mov rdi,r9 ; reload CountK - sub rdi,4 - jb ProcessRemaining16xNBlocks - -Compute16xNBlockBy4Loop: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - ComputeBlockSseBy16 RowCount, 0, 000h - ComputeBlockSseBy16 RowCount, 16*4, 055h - sub rdx,-32*4 ; advance matrix B by 32 columns - ComputeBlockSseBy16 RowCount, 0, 0AAh - ComputeBlockSseBy16 RowCount, 16*4, 0FFh - sub rdx,-32*4 ; advance matrix B by 32 columns - add rcx,4*4 ; advance matrix A by 4 columns - sub rdi,4 - jae Compute16xNBlockBy4Loop - -ProcessRemaining16xNBlocks: - add rdi,4 ; correct for over-subtract above - jz Output16xNBlock - -Compute16xNBlockBy1Loop: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - ComputeBlockSseBy16 RowCount, 0, 000h - add rdx,16*4 ; advance matrix B by 16 columns - add rcx,4 ; advance matrix A by 1 column - dec rdi - jne Compute16xNBlockBy1Loop - -Output16xNBlock: - movss xmm2,DWORD PTR FgemmKernelFrame.Alpha[rsp] - shufps xmm2,xmm2,0 - EmitIfCountGE RowCount, 1, - ; multiply by alpha - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, 
- EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, - sub rbp,16 - jb OutputPartial16xNBlock - AccumulateAndStoreBlock RowCount, 4 - add r8,16*4 ; advance matrix C by 16 columns - mov rcx,rsi ; reload matrix A - test rbp,rbp - jnz ProcessNextColumnLoop16xN - jmp ExitKernel - -; -; Output a partial 16xN block to the matrix. -; - -OutputPartial16xNBlock: - add rbp,16 ; correct for over-subtract above - cmp ebp,4 - jb OutputPartialLessThan4xNBlock - cmp ebp,8 - jb OutputPartialLessThan8xNBlock - cmp ebp,12 - jb OutputPartialLessThan12xNBlock - AccumulateAndStoreBlock RowCount, 3 - and ebp,3 ; check if remaining count is small - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift remaining elements down - EmitIfCountGE RowCount, 2, - add r8,12*4 ; advance matrix C by 12 columns - jmp OutputPartialLessThan4xNBlock - -OutputPartialLessThan12xNBlock: - AccumulateAndStoreBlock RowCount, 2 - and ebp,3 ; check if remaining count is small - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift remaining elements down - EmitIfCountGE RowCount, 2, - add r8,8*4 ; advance matrix C by 8 columns - jmp OutputPartialLessThan4xNBlock - -OutputPartialLessThan8xNBlock: - AccumulateAndStoreBlock RowCount, 1 - and ebp,3 ; check if remaining count is small - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift remaining elements down - EmitIfCountGE RowCount, 2, - add r8,4*4 ; advance matrix C by 4 columns - -OutputPartialLessThan4xNBlock: - test ebp,2 - jz OutputPartial1xNBlock - test r15b,r15b ; ZeroMode? - jnz SkipAccumulateOutput2xN - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - -SkipAccumulateOutput2xN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - test ebp,1 ; check if remaining count is odd - jz ExitKernel - EmitIfCountGE RowCount, 1, - ; shift third element down - EmitIfCountGE RowCount, 2, - add r8,2*4 ; advance matrix C by 2 columns - -OutputPartial1xNBlock: - test r15b,r15b ; ZeroMode? - jnz SkipAccumulateOutput1xN - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, - -SkipAccumulateOutput1xN: - EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 2, -IFB - jmp ExitKernel -ENDIF - - ENDM - -; -; Generate the GEMM kernel. -; - -FgemmKernelSse2Function Float - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm deleted file mode 100644 index bfdff7009191e..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm +++ /dev/null @@ -1,248 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SoftmaxKernelAvx.asm -; -; Abstract: -; -; This module implements the kernels for the single precision softmax -; operation. -; -; This implementation uses AVX instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - EXTERN MlasMinimumF32Value:NEAR - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel to find the maximum value of -; the supplied buffer. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; N (rdx) - Supplies the number of elements to process. -; -; Return Value: -; -; Returns the maximum value of the supplied buffer. 
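As a reference for the kernel being removed here, a scalar C++ sketch of the reduction MlasReduceMaximumF32KernelAvx performs; the function name and signature are illustrative, not the MLAS interface, and std::numeric_limits<float>::lowest() stands in for MlasMinimumF32Value.

#include <algorithm>
#include <cstddef>
#include <limits>

// Scalar sketch: start from the smallest representable float (the identity
// value the kernel broadcasts from MlasMinimumF32Value) and fold in every
// element. The assembly merely vectorizes and unrolls this loop before
// collapsing the vector accumulators to a single scalar result.
float ReduceMaximumF32(const float* input, size_t n) {
    float maximum = std::numeric_limits<float>::lowest();
    for (size_t i = 0; i < n; ++i) {
        maximum = std::max(maximum, input[i]);
    }
    return maximum;
}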
-; -;-- - - LEAF_ENTRY MlasReduceMaximumF32KernelAvx, _TEXT - - vbroadcastss ymm0,DWORD PTR [MlasMinimumF32Value] - test rdx,rdx - jz ExitKernel - cmp rdx,8 - jb ProcessRemainingCountBy1 - cmp rdx,32 - jb ProcessRemainingCountBy8 - vmovaps ymm1,ymm0 - vmovaps ymm2,ymm0 - vmovaps ymm3,ymm0 - -ProcessRemainingCountBy32: - vmaxps ymm0,ymm0,YMMWORD PTR [rcx] - vmaxps ymm1,ymm1,YMMWORD PTR [rcx+8*4] - sub rdx,32 - vmaxps ymm2,ymm2,YMMWORD PTR [rcx+16*4] - vmaxps ymm3,ymm3,YMMWORD PTR [rcx+24*4] - add rcx,32*4 ; advance input by 32 elements - cmp rdx,32 - jae ProcessRemainingCountBy32 - vmaxps ymm0,ymm0,ymm1 ; reduce to single vector - vmaxps ymm2,ymm2,ymm3 - vmaxps ymm0,ymm0,ymm2 - -ProcessRemainingCountBy8: - cmp rdx,8 - jb ProcessRemainingCountLessThan8 - vmaxps ymm0,ymm0,YMMWORD PTR [rcx] - sub rdx,8 - add rcx,8*4 ; advance input by 8 elements - jmp ProcessRemainingCountBy8 - -ProcessRemainingCountLessThan8: - vextractf128 xmm1,ymm0,1 ; reduce to single scalar - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,0EEh - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,055h - vmaxss xmm0,xmm0,xmm1 - test rdx,rdx - jz ExitKernel - -ProcessRemainingCountBy1: - vmaxss xmm0,xmm0,DWORD PTR [rcx] - add rcx,4 ; advance input by 1 element - dec edx - jnz ProcessRemainingCountBy1 - -ExitKernel: - vzeroupper - ret - - LEAF_END MlasReduceMaximumF32KernelAvx, _TEXT - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel to produce the final output for -; the softmax operation. -; -; Arguments: -; -; Output (rcx) - Supplies the output buffer. -; -; N (rdx) - Supplies the number of elements to process. -; -; Parameters (r8) - Supplies an array containing the scale value. -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasComputeSoftmaxOutputF32KernelAvx, _TEXT - - vbroadcastss ymm4,DWORD PTR [r8] ; broadcast scale value - cmp rdx,32 - jb ProcessRemainingCountBy8 - -ProcessRemainingCountBy32: - vmulps ymm0,ymm4,YMMWORD PTR [rcx] - vmulps ymm1,ymm4,YMMWORD PTR [rcx+8*4] - sub rdx,32 - vmulps ymm2,ymm4,YMMWORD PTR [rcx+16*4] - vmulps ymm3,ymm4,YMMWORD PTR [rcx+24*4] - vmovups YMMWORD PTR [rcx],ymm0 - vmovups YMMWORD PTR [rcx+8*4],ymm1 - vmovups YMMWORD PTR [rcx+16*4],ymm2 - vmovups YMMWORD PTR [rcx+24*4],ymm3 - add rcx,32*4 ; advance output by 32 elements - cmp rdx,32 - jae ProcessRemainingCountBy32 - -ProcessRemainingCountBy8: - cmp rdx,8 - jb ProcessRemainingCountLessThan8 - vmulps ymm0,ymm4,YMMWORD PTR [rcx] - sub rdx,8 - vmovups YMMWORD PTR [rcx],ymm0 - add rcx,8*4 ; advance output by 8 elements - jmp ProcessRemainingCountBy8 - -ProcessRemainingCountLessThan8: - test rdx,rdx - jz ExitKernel - -ProcessRemainingCountBy1: - vmulss xmm0,xmm4,DWORD PTR [rcx] - vmovss DWORD PTR [rcx],xmm0 - add rcx,4 ; advance output by 1 element - dec edx - jnz ProcessRemainingCountBy1 - -ExitKernel: - vzeroupper - ret - - LEAF_END MlasComputeSoftmaxOutputF32KernelAvx, _TEXT - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel to produce the final output for -; the log softmax operation. -; -; Arguments: -; -; Input (rcx) - Supplies the output buffer. -; -; Output (rdx) - Supplies the output buffer. -; -; N (r8) - Supplies the number of elements to process. -; -; Parameters (r9) - Supplies an array containing the negative maximum and -; logarithm values. -; -; Return Value: -; -; None. 
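A scalar C++ sketch of what these two post-processing kernels compute; the function names and flat parameters are illustrative (the assembly receives the scale and the negative-maximum/logarithm values through small parameter arrays), not the MLAS interface.

#include <cstddef>

// Softmax post-processing: rescale the already-exponentiated values in place.
void SoftmaxOutputF32(float* output, size_t n, float scale) {
    for (size_t i = 0; i < n; ++i) {
        output[i] *= scale;   // scale = 1 / sum(exp(x - max))
    }
}

// LogSoftmax post-processing: add the negative row maximum, then subtract
// log(SumExp) separately, mirroring the "do as two steps for numeric
// stability" comment in the assembly.
void LogSoftmaxOutputF32(const float* input, float* output, size_t n,
                         float negative_maximum, float log_sum_exp) {
    for (size_t i = 0; i < n; ++i) {
        float shifted = input[i] + negative_maximum;  // x - max
        output[i] = shifted - log_sum_exp;            // (x - max) - log(SumExp)
    }
}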
-; -;-- - - LEAF_ENTRY MlasComputeLogSoftmaxOutputF32KernelAvx, _TEXT - - vbroadcastss ymm4,DWORD PTR [r9] ; broadcast negative minimum value - vbroadcastss ymm5,DWORD PTR [r9+4] ; broadcast log(SumExp) - cmp r8,32 - jb ProcessRemainingCountBy8 - -ProcessRemainingCountBy32: - vaddps ymm0,ymm4,YMMWORD PTR [rcx] - vaddps ymm1,ymm4,YMMWORD PTR [rcx+8*4] - sub r8,32 - vaddps ymm2,ymm4,YMMWORD PTR [rcx+16*4] - vaddps ymm3,ymm4,YMMWORD PTR [rcx+24*4] - add rcx,32*4 ; advance input by 32 elements - vsubps ymm0,ymm0,ymm5 ; do as two steps for numeric stability - vsubps ymm1,ymm1,ymm5 - vsubps ymm2,ymm2,ymm5 - vsubps ymm3,ymm3,ymm5 - vmovups YMMWORD PTR [rdx],ymm0 - vmovups YMMWORD PTR [rdx+8*4],ymm1 - vmovups YMMWORD PTR [rdx+16*4],ymm2 - vmovups YMMWORD PTR [rdx+24*4],ymm3 - add rdx,32*4 ; advance output by 32 elements - cmp r8,32 - jae ProcessRemainingCountBy32 - -ProcessRemainingCountBy8: - cmp r8,8 - jb ProcessRemainingCountLessThan8 - vaddps ymm0,ymm4,YMMWORD PTR [rcx] - add rcx,8*4 ; advance input by 8 elements - vsubps ymm0,ymm0,ymm5 ; do as two steps for numeric stability - sub r8,8 - vmovups YMMWORD PTR [rdx],ymm0 - add rdx,8*4 ; advance output by 8 elements - jmp ProcessRemainingCountBy8 - -ProcessRemainingCountLessThan8: - test r8,r8 - jz ExitKernel - -ProcessRemainingCountBy1: - vaddss xmm0,xmm4,DWORD PTR [rcx] - add rcx,4 ; advance input by 1 element - vsubss xmm0,xmm0,xmm5 - vmovss DWORD PTR [rdx],xmm0 - add rdx,4 ; advance output by 1 element - dec r8d - jnz ProcessRemainingCountBy1 - -ExitKernel: - vzeroupper - ret - - LEAF_END MlasComputeLogSoftmaxOutputF32KernelAvx, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm deleted file mode 100644 index 3e83bc852f558..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx512F.asm +++ /dev/null @@ -1,103 +0,0 @@ -;++ -; -;Copyright (c) Microsoft Corporation. All rights reserved. -; -;Licensed under the MIT License. -; -;Module Name: -; -; SoftmaxKernelAvx512F.asm -; -;Abstract: -; -; This module implements the kernels for the single precision softmax -; operation. -; -; This implementation uses AVX512F instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - EXTERN MlasMinimumF32Value:NEAR - -;++ -; -;Routine Description: -; -; This routine implements a vectorized kernel to find the maximum value of -; the supplied buffer. -; -;Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; N (rdx) - Supplies the number of elements to process. -; -;Return Value: -; -; Returns the maximum value of the supplied buffer. 
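The AVX512F variant computes the same reduction as the AVX kernel above; the distinctive part of both is the unrolling with several independent accumulators so the compare chain does not serialize. A scalar C++ sketch of that pattern, with illustrative names and a step of 4 standing in for the 32- and 64-element vector blocks:

#include <algorithm>
#include <cstddef>
#include <limits>

// Keep four independent running maxima, fold them together, then finish the
// tail one element at a time — the shape of the ProcessRemainingCountBy*
// loops in the kernels.
float ReduceMaximumUnrolled(const float* input, size_t n) {
    float m0 = std::numeric_limits<float>::lowest();
    float m1 = m0, m2 = m0, m3 = m0;
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        m0 = std::max(m0, input[i + 0]);
        m1 = std::max(m1, input[i + 1]);
        m2 = std::max(m2, input[i + 2]);
        m3 = std::max(m3, input[i + 3]);
    }
    float maximum = std::max(std::max(m0, m1), std::max(m2, m3));
    for (; i < n; ++i) {      // remaining elements
        maximum = std::max(maximum, input[i]);
    }
    return maximum;
}

The independent accumulators remove the loop-carried dependency on a single maximum register, which is why the kernels only reduce to one vector after the wide loop exits.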
-; -;-- - - LEAF_ENTRY MlasReduceMaximumF32KernelAvx512F, _TEXT - - vbroadcastss zmm0,DWORD PTR [MlasMinimumF32Value] - test rdx,rdx - jz ExitKernel - cmp rdx,16 - jb ProcessRemainingCountBy1 - cmp rdx,64 - jb ProcessRemainingCountBy16 - vmovaps zmm1,zmm0 - vmovaps zmm2,zmm0 - vmovaps zmm3,zmm0 - -ProcessRemainingCountBy64: - vmaxps zmm0,zmm0,ZMMWORD PTR [rcx] - vmaxps zmm1,zmm1,ZMMWORD PTR [rcx+16*4] - sub rdx,64 - vmaxps zmm2,zmm2,ZMMWORD PTR [rcx+32*4] - vmaxps zmm3,zmm3,ZMMWORD PTR [rcx+48*4] - add rcx,64*4 ; advance input by 64 elements - cmp rdx,64 - jae ProcessRemainingCountBy64 - vmaxps zmm0,zmm0,zmm1 ; reduce to single vector - vmaxps zmm2,zmm2,zmm3 - vmaxps zmm0,zmm0,zmm2 - -ProcessRemainingCountBy16: - cmp rdx,16 - jb ProcessRemainingCountLessThan16 - vmaxps zmm0,zmm0,ZMMWORD PTR [rcx] - sub rdx,16 - add rcx,16*4 ; advance input by 16 elements - jmp ProcessRemainingCountBy16 - -ProcessRemainingCountLessThan16: - vextractf32x8 ymm1,zmm0,1 ; reduce to single scalar - vmaxps ymm0,ymm0,ymm1 - vextractf128 xmm1,ymm0,1 - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,0EEh - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,055h - vmaxss xmm0,xmm0,xmm1 - test rdx,rdx - jz ExitKernel - -ProcessRemainingCountBy1: - vmaxss xmm0,xmm0,DWORD PTR [rcx] - add rcx,4 ; advance input by 1 element - dec edx - jnz ProcessRemainingCountBy1 - -ExitKernel: - vzeroupper - ret - - LEAF_END MlasReduceMaximumF32KernelAvx512F, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvx.asm b/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvx.asm deleted file mode 100644 index 3d0f64f447f8e..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvx.asm +++ /dev/null @@ -1,220 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SpoolKernelAvx.asm -; -; Abstract: -; -; This module implements the kernels for the single precision pooling -; operation. -; -; This implementation uses AVX instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SpoolKernelAvxCommon.inc - .list - -; -; Macro Description: -; -; This macro generates code to initialize registers used across the kernel. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; - -InitializeKernel MACRO PoolingType - -IFIDNI , - mov DWORD PTR SpoolKernelFrame.PreviousP1Home[rsp],0FF7FFFFFh - vbroadcastss ymm5,DWORD PTR SpoolKernelFrame.PreviousP1Home[rsp] -ELSE - vxorps xmm5,xmm5,xmm5 ; initialize default divisor vector -IFIDNI , - mov rax,SpoolKernelFrame.KernelHeight[rsp] - imul rax,SpoolKernelFrame.KernelWidth[rsp] - vcvtsi2ss xmm5,xmm5,rax -ELSE - vcvtsi2ss xmm5,xmm5,SpoolKernelFrame.ActualKernelSize[rsp] -ENDIF - vshufps xmm5,xmm5,xmm5,0 - vinsertf128 ymm5,ymm5,xmm5,1 ; AVX lacks "vbroadcastss ymm5,xmm5" -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to clear the pooling intermediates. -; -; For PoolingType==Maximum, the pooling intermediates are set to the minimum -; float value. Otherwise, the pooling intermediates are cleared to zero. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; ymm0-ymm2 - Supplies the pooling intermediates. -; -; ymm5 - Supplies a vector containing the minimum float value broadcasted, -; if PoolingType==Maximum. 
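A scalar C++ sketch of the ClearBlock / ComputeBlock / PostProcessBlock flow these pooling macros implement for a single output value; the enum, names, and the flat samples array are illustrative simplifications of the blocked NCHWc layout the kernels actually walk.

#include <algorithm>
#include <cfloat>
#include <cstddef>

enum class PoolingType { Maximum, AverageExcludePad, AverageIncludePad };

// Initialize the intermediate (minimum float for max pooling, zero for the
// average variants), fold in the non-padding kernel samples, then divide by
// the appropriate count. valid_count is assumed non-zero for ExcludePad.
float PoolOneBlock(const float* samples, size_t kernel_size,
                   size_t valid_count, PoolingType type) {
    float acc = (type == PoolingType::Maximum) ? -FLT_MAX : 0.0f;
    for (size_t i = 0; i < valid_count; ++i) {
        if (type == PoolingType::Maximum) {
            acc = std::max(acc, samples[i]);
        } else {
            acc += samples[i];
        }
    }
    if (type == PoolingType::AverageExcludePad) {
        acc /= static_cast<float>(valid_count);   // divide by non-padding blocks
    } else if (type == PoolingType::AverageIncludePad) {
        acc /= static_cast<float>(kernel_size);   // divide by full kernel size
    }
    return acc;
}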
-; - -ClearBlock MACRO PoolingType, OutputCount - -IFIDNI , - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF - -IFIDNI , -IF OutputCount EQ 1 - xor rsi,rsi ; reset valid block counter -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to sample the input buffer and update the pooling -; intermediates as appropriate. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; r8 - Supplies the StrideWidth parameter (see function description). -; -; ymm0-ymm2 - Supplies the pooling intermediates. -; - -ComputeBlock MACRO PoolingType, OutputCount - -IFIDNI , - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF - -IFIDNI , -IF OutputCount EQ 1 - inc rsi ; increment valid block counter -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to process and store the pooling intermediates. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rdx - Supplies the address of the output buffer. -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; ymm0-ymm2 - Supplies the pooling intermediates. -; -; ymm5 - Supplies the kernel size computed by InitializeKernel, if -; PoolingType=AverageExcludePad, else the actual kernel size, if -; PoolingType=AverageIncludePad. -; - -PostProcessBlock MACRO PoolingType, OutputCount - -; -; If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -; blocks. OutputCount=1 generates code to count the number of blocks accessed by -; ComputeBlock. Other cases use the kernel size computed by InitializeKernel. -; - -IFIDNI , -IF OutputCount EQ 1 - vxorps xmm4,xmm4,xmm4 - vcvtsi2ss xmm4,xmm4,rsi ; convert valid block counter - vshufps xmm4,xmm4,xmm4,0 - vinsertf128 ymm4,ymm4,xmm4,1 ; AVX lacks "vbroadcastss ymm4,xmm4" - vdivps ymm0,ymm0,ymm4 -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF -ENDIF - -; -; If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -; - -IFIDNI , - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF - -; -; Store the output block in the output buffer. -; - - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, - add rdx,OutputCount*8*4 ; advance output by N nchw8c blocks - - ENDM - -; -; Generate the pooling kernels. 
-; - -SpoolKernelFunction Maximum, Avx -SpoolKernelFunction AverageExcludePad, Avx -SpoolKernelFunction AverageIncludePad, Avx - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvx512F.asm deleted file mode 100644 index ca1fb690194d3..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvx512F.asm +++ /dev/null @@ -1,214 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SpoolKernelAvx512F.asm -; -; Abstract: -; -; This module implements the kernels for the single precision pooling -; operation. -; -; This implementation uses AVX512F instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SpoolKernelAvxCommon.inc - .list - -; -; Macro Description: -; -; This macro generates code to initialize registers used across the kernel. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; - -InitializeKernel MACRO PoolingType - -IFIDNI , - mov DWORD PTR SpoolKernelFrame.PreviousP1Home[rsp],0FF7FFFFFh - vbroadcastss zmm5,DWORD PTR SpoolKernelFrame.PreviousP1Home[rsp] -ELSE - vxorps xmm5,xmm5,xmm5 ; initialize default divisor vector -IFIDNI , - mov rax,SpoolKernelFrame.KernelHeight[rsp] - imul rax,SpoolKernelFrame.KernelWidth[rsp] - vcvtsi2ss xmm5,xmm5,rax -ELSE - vcvtsi2ss xmm5,xmm5,SpoolKernelFrame.ActualKernelSize[rsp] -ENDIF - vbroadcastss zmm5,xmm5 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to clear the pooling intermediates. -; -; For PoolingType==Maximum, the pooling intermediates are set to the minimum -; float value. Otherwise, the pooling intermediates are cleared to zero. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; zmm0-zmm2 - Supplies the pooling intermediates. -; -; zmm5 - Supplies a vector containing the minimum float value broadcasted, -; if PoolingType==Maximum. -; - -ClearBlock MACRO PoolingType, OutputCount - -IFIDNI , - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF - -IFIDNI , -IF OutputCount EQ 1 - xor rsi,rsi ; reset valid block counter -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to sample the input buffer and update the pooling -; intermediates as appropriate. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; r8 - Supplies the StrideWidth parameter (see function description). -; -; zmm0-zmm2 - Supplies the pooling intermediates. 
-; - -ComputeBlock MACRO PoolingType, OutputCount - -IFIDNI , - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF - -IFIDNI , -IF OutputCount EQ 1 - inc rsi ; increment valid block counter -ENDIF -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to process and store the pooling intermediates. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rdx - Supplies the address of the output buffer. -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; zmm0-zmm2 - Supplies the pooling intermediates. -; -; zmm5 - Supplies the kernel size computed by InitializeKernel, if -; PoolingType=AverageExcludePad, else the actual kernel size, if -; PoolingType=AverageIncludePad. -; - -PostProcessBlock MACRO PoolingType, OutputCount - -; -; If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -; blocks. OutputCount=1 generates code to count the number of blocks accessed by -; ComputeBlock. Other cases use the kernel size computed by InitializeKernel. -; - -IFIDNI , -IF OutputCount EQ 1 - vxorps xmm4,xmm4,xmm4 - vcvtsi2ss xmm4,xmm4,rsi ; convert valid block counter - vbroadcastss zmm4,xmm4 - vdivps zmm0,zmm0,zmm4 -ELSE - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF -ENDIF - -; -; If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -; - -IFIDNI , - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, -ENDIF - - EmitIfCountGE OutputCount, 1, - EmitIfCountGE OutputCount, 2, - EmitIfCountGE OutputCount, 3, - add rdx,OutputCount*16*4 ; advance output by N nchw16c blocks - - ENDM - -; -; Generate the pooling kernels. -; - -SpoolKernelFunction Maximum, Avx512F -SpoolKernelFunction AverageExcludePad, Avx512F -SpoolKernelFunction AverageIncludePad, Avx512F - - END diff --git a/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvxCommon.inc b/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvxCommon.inc deleted file mode 100644 index fd082af2bf3fe..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SpoolKernelAvxCommon.inc +++ /dev/null @@ -1,147 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SpoolKernelAvxCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the single -; precision pooling operation for the AVX and AVX512F kernels. -; -;-- - -INCLUDE SpoolKernelCommon.inc - -; -; Macro Description: -; -; This macro generates code for the inner pooling kernel. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; Isa - Supplies the instruction set architecture string for function tags. -; - -SpoolKernelFunction MACRO PoolingType, Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute pooling for the elements of an -; output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; The address is biased to include padding blocks for the left width -; dimension. 
The address is not biased to include padding rows for the -; left height dimension; these are accounted for in the outer kernel. -; -; Output (rdx) - Supplies the address of the output buffer. -; -; StrideWidth (r8) - Supplies the length in bytes of the blocked stride width. -; -; DilationWidth (r9) - Supplies the length in bytes of the blocked dilation -; width. -; -; InputStride - Supplies the length in bytes to advance the input buffer to -; the next input row. -; -; ActualKernelSize - Supplies the size of the kernel based on the original -; kernel dimensions, used for PoolingType=AverageIncludePad. -; -; KernelHeight - Supplies the height of the kernel to apply. This height may -; be less than the original kernel height after removing any padding -; rows. -; -; KernelWidth - Supplies the width of the kernel to apply. -; -; InputBase - Supplies the address of the valid input buffer. -; -; This parameter is similar to the Input parameter, but does not include -; the padding blocks for the left width dimension. This parameter is used -; with the following InputWidth parameter in order to validate that the -; current input buffer address in bounds and not in the left or right -; width padding region. -; -; InputWidth - Supplies the length in bytes of the blocked input width. -; -; DilatedInputWidth - Supplies the length in bytes to advance the input base -; buffer to the next input row including dilation. -; -; OutputCountLeftPad - Supplies the number of output elements that include -; one or more padding elements from the left edge. -; -; OutputCount - Supplies the number of output elements that do not include -; any padding elements. -; -; OutputCountRightPad - Supplies the number of output elements that include -; one or more padding elements from the right edge. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasPool&PoolingType&FloatKernel&Isa&, _TEXT - - SpoolKernelEntry PoolingType - -ProcessOutputCountLeftPad: - mov r10,SpoolKernelFrame.OutputCountLeftPad[rsp] - test r10,r10 - jz ProcessOutputCount - call MlasPool&PoolingType&FloatSingle&Isa& - -ProcessOutputCount: - mov r10,SpoolKernelFrame.OutputCount[rsp] - sub r10,3 - jb ProcessRemainingOutputCount - -ProcessNextOutputCountBy3: - ProcessOutputCountN SpoolKernelFrame, PoolingType, 3 - lea rax,[r8*2+r8] - add rdi,rax ; advance input by 3 elements - sub r10,3 - jae ProcessNextOutputCountBy3 - -ProcessRemainingOutputCount: - add r10,3 ; correct for over-subtract above - -ProcessOutputCountRightPad: - add r10,SpoolKernelFrame.OutputCountRightPad[rsp] - jz ExitKernel - call MlasPool&PoolingType&FloatSingle&Isa& - -ExitKernel: - vzeroupper - SpoolKernelExit - - NESTED_END MlasPool&PoolingType&FloatKernel&Isa&, _TEXT - -; -; Generate out-of-band helpers for handling output blocks involving padding. -; - - LEAF_ENTRY MlasPool&PoolingType&FloatSingle&Isa&, _TEXT - -ProcessNextOutputCount: - ProcessOutputCountN SpoolKernelSingleFrame.KernelFrame, PoolingType, 1 - add rdi,r8 ; advance input by 1 element - dec r10 ; decrement output count remaining - jnz ProcessNextOutputCount - ret - - LEAF_END MlasPool&PoolingType&FloatSingle&Isa&, _TEXT - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/SpoolKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/SpoolKernelCommon.inc deleted file mode 100644 index 3f735bb27cde2..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SpoolKernelCommon.inc +++ /dev/null @@ -1,183 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. 
-; -; Module Name: -; -; SpoolKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the single -; precision pooling operation. -; -;-- - -; -; Stack frame layout for the pooling kernels. -; - -SpoolKernelFrame STRUCT - - SavedR12 QWORD ? - SavedR13 QWORD ? - SavedR14 QWORD ? - SavedRdi QWORD ? - SavedRsi QWORD ? - SavedRbx QWORD ? - SavedRbp QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? ; Input - PreviousP2Home QWORD ? ; Output - PreviousP3Home QWORD ? ; StrideWidth - PreviousP4Home QWORD ? ; DilationWidth - InputStride QWORD ? - ActualKernelSize QWORD ? - KernelHeight QWORD ? - KernelWidth QWORD ? - InputBase QWORD ? - InputWidth QWORD ? - DilatedInputWidth QWORD ? - OutputCountLeftPad QWORD ? - OutputCount QWORD ? - OutputCountRightPad QWORD ? - -SpoolKernelFrame ENDS - -SpoolKernelSingleFrame STRUCT - - ReturnAddress QWORD ? - KernelFrame SpoolKernelFrame <> - -SpoolKernelSingleFrame ENDS - -; -; Macro Description: -; -; This macro generates the common prologue code for the pooling kernels. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; - -SpoolKernelEntry MACRO PoolingType - - rex_push_reg rbp - push_reg rbx - push_reg rsi - push_reg rdi - push_reg r14 - push_reg r13 - push_reg r12 - - END_PROLOGUE - - mov rdi,rcx - mov rbp,SpoolKernelFrame.InputStride[rsp] - InitializeKernel PoolingType - - ENDM - -; -; Macro Description: -; -; This macro generates the common epilogue code for the pooling kernels. -; -; Arguments: -; -; None. -; - -SpoolKernelExit MACRO - - BEGIN_EPILOGUE - - pop r12 - pop r13 - pop r14 - pop rdi - pop rsi - pop rbx - pop rbp - ret - - ENDM - -; -; Macro Description: -; -; This macro generates code to compute pooling for a vector of input blocks -; to produce a matrix of output blocks. -; -; OutputCount=1 generates special case code to handle padding blocks. All -; other output counts assume no padding. -; -; Arguments: -; -; KernelFrame - Supplies the symbol name to access the convolution kernel -; stack. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rdi - Supplies the address of the input buffer. -; -; rdx - Supplies the address of the output buffer. -; -; r8 - Supplies the StrideWidth parameter (see function description). -; -; r9 - Supplies the DilationWidth parameter (see function description). -; -; rbp - Supplies the InputStride parameter (see function description). -; - -ProcessOutputCountN MACRO KernelFrame, PoolingType, OutputCount - - LOCAL ProcessNextRow - LOCAL ProcessNextColumn - LOCAL SkipOverPadding - LOCAL HandlePostProcessing - - mov rcx,rdi - mov r11,KernelFrame.KernelHeight[rsp] - mov r12,KernelFrame.KernelWidth[rsp] -IF OutputCount EQ 1 - mov r13,KernelFrame.InputBase[rsp] - mov r14,KernelFrame.InputWidth[rsp] - neg r13 ; keep negative for lea usage below -ENDIF - ClearBlock PoolingType, OutputCount - test r11,r11 ; zero sized kernel? - jz HandlePostProcessing - -ProcessNextRow: - mov rax,r12 - -ProcessNextColumn: -IF OutputCount EQ 1 - lea rbx,[rcx+r13] ; compute (Input - InputBase) - cmp rbx,r14 ; (Input - InputBase) >= InputWidth? 
- jae SkipOverPadding -ENDIF - ComputeBlock PoolingType, OutputCount - -SkipOverPadding: - add rcx,r9 ; advance input by dilation width - dec rax ; decrement columns remaining - jnz ProcessNextColumn - add rcx,rbp ; advance input to next row -IF OutputCount EQ 1 - sub r13,KernelFrame.DilatedInputWidth[rsp] - ; advance input base to next row -ENDIF - dec r11 - jnz ProcessNextRow - -HandlePostProcessing: - PostProcessBlock PoolingType, OutputCount - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/SpoolKernelSse2.asm b/onnxruntime/core/mlas/lib/amd64/SpoolKernelSse2.asm deleted file mode 100644 index d5c0ef8c9af42..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/SpoolKernelSse2.asm +++ /dev/null @@ -1,296 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SpoolKernelSse2.asm -; -; Abstract: -; -; This module implements the kernels for the single precision pooling -; operation. -; -; This implementation uses SSE2 instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE SpoolKernelCommon.inc - .list - -; -; Macro Description: -; -; This macro generates code to initialize registers used across the kernel. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; - -InitializeKernel MACRO PoolingType - -IFIDNI , - mov eax,0FF7FFFFFh - movd xmm5,eax - shufps xmm5,xmm5,0 -ENDIF - -IFIDNI , - cvtsi2ss xmm5,SpoolKernelFrame.ActualKernelSize[rsp] - shufps xmm5,xmm5,0 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to clear the pooling intermediates. -; -; For PoolingType==Maximum, the pooling intermediates are set to the minimum -; float value. Otherwise, the pooling intermediates are cleared to zero. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; xmm0-xmm1 - Supplies the pooling intermediates. -; -; xmm5 - Supplies a vector containing the minimum float value broadcasted, -; if PoolingType==Maximum. -; - -ClearBlock MACRO PoolingType, OutputCount - -IFIDNI , - movaps xmm0,xmm5 - movaps xmm1,xmm5 -ELSE - xorps xmm0,xmm0 - xorps xmm1,xmm1 -ENDIF - -IFIDNI , - xor rsi,rsi ; reset valid block counter -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to sample the input buffer and update the pooling -; intermediates as appropriate. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the input buffer. -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; r8 - Supplies the StrideWidth parameter (see function description). -; -; xmm0-xmm1 - Supplies the pooling intermediates. -; - -ComputeBlock MACRO PoolingType, OutputCount - -IFIDNI , - maxps xmm0,XMMWORD PTR [rcx] - maxps xmm1,XMMWORD PTR [rcx+16] -ELSE - addps xmm0,XMMWORD PTR [rcx] - addps xmm1,XMMWORD PTR [rcx+16] -ENDIF - -IFIDNI , - inc rsi ; increment valid block counter -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to process and store the pooling intermediates. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. 
-; -; OutputCount - Supplies the number of output blocks to produce. -; -; Implicit Arguments: -; -; rdx - Supplies the address of the output buffer. -; -; rsi - Supplies the number of blocks accessed by ComputeBlock, if -; PoolingType=AverageExcludePad and OutputCount=1. -; -; xmm0-xmm1 - Supplies the pooling intermediates. -; -; xmm5 - Supplies the kernel size computed by InitializeKernel, if -; PoolingType=AverageExcludePad, else the actual kernel size, if -; PoolingType=AverageIncludePad. -; - -PostProcessBlock MACRO PoolingType, OutputCount - -; -; If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -; blocks. -; - -IFIDNI , - xorps xmm4,xmm4 - cvtsi2ss xmm4,rsi ; convert valid block counter - shufps xmm4,xmm4,0 - divps xmm0,xmm4 - divps xmm1,xmm4 -ENDIF - -; -; If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -; - -IFIDNI , - divps xmm0,xmm5 - divps xmm1,xmm5 -ENDIF - -; -; Store the output block in the output buffer. -; - - movups XMMWORD PTR [rdx],xmm0 - movups XMMWORD PTR [rdx+16],xmm1 - add rdx,8*4 ; advance output by 1 nchw8c block - - ENDM - -; -; Macro Description: -; -; This macro generates code for the inner pooling kernel. -; -; Arguments: -; -; PoolingType - Supplies the pooling type string. -; -; Isa - Supplies the instruction set architecture string for function tags. -; - -SpoolKernelFunction MACRO PoolingType, Isa - -;++ -; -; Routine Description: -; -; This routine is the inner kernel to compute pooling for the elements of an -; output row for a set of filter rows. -; -; Arguments: -; -; Input (rcx) - Supplies the address of the input buffer. -; -; The address is biased to include padding blocks for the left width -; dimension. The address is not biased to include padding rows for the -; left height dimension; these are accounted for in the outer kernel. -; -; Output (rdx) - Supplies the address of the output buffer. -; -; StrideWidth (r8) - Supplies the length in bytes of the blocked stride width. -; -; DilationWidth (r9) - Supplies the length in bytes of the blocked dilation -; width. -; -; InputStride - Supplies the length in bytes to advance the input buffer to -; the next input row. -; -; ActualKernelSize - Supplies the size of the kernel based on the original -; kernel dimensions, used for PoolingType=AverageIncludePad. -; -; KernelHeight - Supplies the height of the kernel to apply. This height may -; be less than the original kernel height after removing any padding -; rows. -; -; KernelWidth - Supplies the width of the kernel to apply. -; -; InputBase - Supplies the address of the valid input buffer. -; -; This parameter is similar to the Input parameter, but does not include -; the padding blocks for the left width dimension. This parameter is used -; with the following InputWidth parameter in order to validate that the -; current input buffer address in bounds and not in the left or right -; width padding region. -; -; InputWidth - Supplies the length in bytes of the blocked input width. -; -; DilatedInputWidth - Supplies the length in bytes to advance the input base -; buffer to the next input row including dilation. -; -; OutputCountLeftPad - Supplies the number of output elements that include -; one or more padding elements from the left edge. -; -; OutputCount - Supplies the number of output elements that do not include -; any padding elements. -; -; OutputCountRightPad - Supplies the number of output elements that include -; one or more padding elements from the right edge. 
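A C++ sketch of the bounds test this InputBase / InputWidth pair enables (the ProcessOutputCountN macro in SpoolKernelCommon.inc performs it with a single unsigned compare); the names here are illustrative.

#include <cstdint>

// (Input - InputBase) wraps to a huge unsigned value in the left-padding
// region and exceeds InputWidth in the right-padding region, so one unsigned
// comparison rejects both.
inline bool IsPaddingBlock(const char* input, const char* input_base,
                           uintptr_t input_width_bytes) {
    uintptr_t offset = reinterpret_cast<uintptr_t>(input) -
                       reinterpret_cast<uintptr_t>(input_base);
    return offset >= input_width_bytes;   // true for left or right padding
}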
-; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasPool&PoolingType&FloatKernel&Isa&, _TEXT - - SpoolKernelEntry PoolingType - - mov r10,SpoolKernelFrame.OutputCountLeftPad[rsp] - add r10,SpoolKernelFrame.OutputCount[rsp] - add r10,SpoolKernelFrame.OutputCountRightPad[rsp] - jz ExitKernel - -ProcessNextOutputCount: - ProcessOutputCountN SpoolKernelFrame, PoolingType, 1 - add rdi,r8 ; advance input by 1 element - dec r10 - jnz ProcessNextOutputCount - -ExitKernel: - SpoolKernelExit - - NESTED_END MlasPool&PoolingType&FloatKernel&Isa&, _TEXT - - ENDM - -; -; Generate the pooling kernels. -; - -SpoolKernelFunction Maximum, Sse -SpoolKernelFunction AverageExcludePad, Sse -SpoolKernelFunction AverageIncludePad, Sse - - END diff --git a/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm deleted file mode 100644 index 6d94d533d72ad..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm +++ /dev/null @@ -1,150 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; TanhKernelFma3.asm -; -; Abstract: -; -; This module implements a kernel for computing the hyperbolic tangent -; function for a buffer of elements. -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE TransKernelCommon.inc - .list - - EXTERN MlasMaskMoveTableAvx:NEAR - EXTERN MlasTanhConstants:NEAR - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel for the hyperbolic tangent -; function. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Supplies the output buffer. -; -; N (r8) - Supplies the number of elements to process. -; -; Return Value: -; -; None. 
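A scalar C++ sketch of the rational approximation this kernel evaluates: an odd polynomial in x over an even polynomial in x, both computed with Horner's rule on x^2 after clamping x into the approximation's valid range. The coefficient layout mirrors MlasTanhConstants, but the struct and function below are illustrative placeholders only.

#include <algorithm>

struct TanhApproxConstants {
    float lower_range, upper_range;
    float alpha[7];   // alpha_1, alpha_3, ..., alpha_13 (lowest degree first)
    float beta[4];    // beta_0, beta_2, beta_4, beta_6  (lowest degree first)
};

float TanhApprox(float x, const TanhApproxConstants& c) {
    x = std::min(std::max(x, c.lower_range), c.upper_range);  // clamp bounds
    float x2 = x * x;
    float p = c.alpha[6];
    for (int i = 5; i >= 0; --i) {
        p = p * x2 + c.alpha[i];   // Horner over the odd-power coefficients
    }
    float q = c.beta[3];
    for (int i = 2; i >= 0; --i) {
        q = q * x2 + c.beta[i];    // Horner over the even-power coefficients
    }
    return (x * p) / q;            // tanh = p / q, with p made odd by the x factor
}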
-; -;-- - - NESTED_ENTRY MlasComputeTanhF32KernelFma3, _TEXT - - alloc_stack (TransKernelFrame.ReturnAddress) - - save_xmm128 xmm6,TransKernelFrame.SavedXmm6 - save_xmm128 xmm7,TransKernelFrame.SavedXmm7 - save_xmm128 xmm8,TransKernelFrame.SavedXmm8 - save_xmm128 xmm9,TransKernelFrame.SavedXmm9 - save_xmm128 xmm10,TransKernelFrame.SavedXmm10 - save_xmm128 xmm11,TransKernelFrame.SavedXmm11 - save_xmm128 xmm12,TransKernelFrame.SavedXmm12 - save_xmm128 xmm13,TransKernelFrame.SavedXmm13 - save_xmm128 xmm14,TransKernelFrame.SavedXmm14 - save_xmm128 xmm15,TransKernelFrame.SavedXmm15 - - END_PROLOGUE - - lea rax,MlasTanhConstants - vbroadcastss ymm4,TanhConstants.LowerRange[rax] - vbroadcastss ymm5,TanhConstants.UpperRange[rax] - vbroadcastss ymm6,TanhConstants.alpha_13[rax] - vbroadcastss ymm7,TanhConstants.alpha_11[rax] - vbroadcastss ymm8,TanhConstants.alpha_9[rax] - vbroadcastss ymm9,TanhConstants.alpha_7[rax] - vbroadcastss ymm10,TanhConstants.alpha_5[rax] - vbroadcastss ymm11,TanhConstants.alpha_3[rax] - vbroadcastss ymm12,TanhConstants.alpha_1[rax] - vbroadcastss ymm13,TanhConstants.beta_6[rax] - vbroadcastss ymm14,TanhConstants.beta_2[rax] - vbroadcastss ymm15,TanhConstants.beta_0[rax] - - sub r8,8 - jb ProcessRemainingCount - -ComputeTanhBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rcx] ; clamp lower bound - vmovaps ymm2,ymm7 - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vmulps ymm1,ymm0,ymm0 ; x2 - vbroadcastss ymm3,TanhConstants.beta_4[rax] - vfmadd231ps ymm2,ymm1,ymm6 ; p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm2,ymm1,ymm8 ; p = x2 * p + alpha_9 - vfmadd213ps ymm2,ymm1,ymm9 ; p = x2 * p + alpha_7 - vfmadd213ps ymm2,ymm1,ymm10 ; p = x2 * p + alpha_5 - vfmadd213ps ymm2,ymm1,ymm11 ; p = x2 * p + alpha_3 - vfmadd213ps ymm2,ymm1,ymm12 ; p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 ; q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0 - vmulps ymm2,ymm0,ymm2 ; p = x * p - vdivps ymm0,ymm2,ymm3 ; tanh = p / q - add rcx,8*4 ; advance input by 8 elements - vmovups YMMWORD PTR [rdx],ymm0 - add rdx,8*4 ; advance output by 8 elements - sub r8,8 - jae ComputeTanhBy8Loop - -ProcessRemainingCount: - add r8,8 ; correct for over-subtract above - jz ExitKernel - neg r8 - lea r10,MlasMaskMoveTableAvx+8*4 - vmovups ymm2,YMMWORD PTR [r10+r8*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rcx] - vmaxps ymm0,ymm4,ymm0 ; clamp lower bound - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vmulps ymm1,ymm0,ymm0 ; x2 - vbroadcastss ymm3,TanhConstants.beta_4[rax] - vfmadd231ps ymm7,ymm1,ymm6 ; p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm7,ymm1,ymm8 ; p = x2 * p + alpha_9 - vfmadd213ps ymm7,ymm1,ymm9 ; p = x2 * p + alpha_7 - vfmadd213ps ymm7,ymm1,ymm10 ; p = x2 * p + alpha_5 - vfmadd213ps ymm7,ymm1,ymm11 ; p = x2 * p + alpha_3 - vfmadd213ps ymm7,ymm1,ymm12 ; p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 ; q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0 - vmulps ymm7,ymm0,ymm7 ; p = x * p - vdivps ymm0,ymm7,ymm3 ; tanh = p / q - vmaskmovps YMMWORD PTR [rdx],ymm2,ymm0 - -ExitKernel: - vzeroupper - movaps xmm6,TransKernelFrame.SavedXmm6[rsp] - movaps xmm7,TransKernelFrame.SavedXmm7[rsp] - movaps xmm8,TransKernelFrame.SavedXmm8[rsp] - movaps xmm9,TransKernelFrame.SavedXmm9[rsp] - movaps xmm10,TransKernelFrame.SavedXmm10[rsp] - movaps xmm11,TransKernelFrame.SavedXmm11[rsp] - movaps xmm12,TransKernelFrame.SavedXmm12[rsp] - movaps 
xmm13,TransKernelFrame.SavedXmm13[rsp] - movaps xmm14,TransKernelFrame.SavedXmm14[rsp] - movaps xmm15,TransKernelFrame.SavedXmm15[rsp] - add rsp,(TransKernelFrame.ReturnAddress) - - BEGIN_EPILOGUE - - ret - - NESTED_END MlasComputeTanhF32KernelFma3, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/TransKernelAvx512F.asm b/onnxruntime/core/mlas/lib/amd64/TransKernelAvx512F.asm deleted file mode 100644 index 2a34d7884c0c6..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/TransKernelAvx512F.asm +++ /dev/null @@ -1,272 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; TransKernelAvx512F.asm -; -; Abstract: -; -; This module implements kernels for various transcendental functions. -; -; This implementation uses AVX512F instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE TransKernelCommon.inc - .list - - EXTERN MlasExpConstants:NEAR - EXTERN MlasOpmask16BitTableAvx512:NEAR - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel for the exponential function. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Supplies the output buffer. -; -; N (r8) - Supplies the number of elements to process. -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasComputeExpF32KernelAvx512F, _TEXT - - lea rax,MlasExpConstants - vbroadcastss zmm21,ExpConstants.LowerRange[rax] - vbroadcastss zmm22,ExpConstants.RoundingBias[rax] - vbroadcastss zmm23,ExpConstants.Log2Reciprocal[rax] - vbroadcastss zmm24,ExpConstants.Log2High[rax] - vbroadcastss zmm25,ExpConstants.Log2Low[rax] - vbroadcastss zmm26,ExpConstants.poly_0[rax] - vbroadcastss zmm27,ExpConstants.poly_1[rax] - vbroadcastss zmm28,ExpConstants.poly_2[rax] - vbroadcastss zmm29,ExpConstants.poly_3[rax] - vbroadcastss zmm30,ExpConstants.poly_4[rax] - vbroadcastss zmm31,ExpConstants.poly_56[rax] - - sub r8,16 - jb ProcessRemainingCount - -ComputeExpBy16Loop: - vmaxps zmm16,zmm21,ZMMWORD PTR [rcx] ; clamp lower bound - vmovaps zmm18,zmm23 - vfmadd213ps zmm18,zmm16,zmm22 ; (input / ln2) plus rounding bias - vmovaps zmm17,zmm26 ; p = poly_0 - vsubps zmm18,zmm18,zmm22 ; m = round(input / ln2) - vfmadd231ps zmm16,zmm18,zmm24 ; range reduce: x -= (m * ln2_high) - vfmadd231ps zmm16,zmm18,zmm25 ; range reduce: x -= (m * ln2_low) - vmovaps zmm17,zmm26 ; p = poly_0 - vfmadd213ps zmm17,zmm16,zmm27 ; p = p * x + poly_1 - vfmadd213ps zmm17,zmm16,zmm28 ; p = p * x + poly_2 - vfmadd213ps zmm17,zmm16,zmm29 ; p = p * x + poly_3 - vfmadd213ps zmm17,zmm16,zmm30 ; p = p * x + poly_4 - vfmadd213ps zmm17,zmm16,zmm31 ; p = p * x + poly_5 - vfmadd213ps zmm17,zmm16,zmm31 ; p = p * x + poly_6 - vscalefps zmm17,zmm17,zmm18 ; scale p with exponent - add rcx,16*4 ; advance input by 16 elements - vmovups ZMMWORD PTR [rdx],zmm17 - add rdx,16*4 ; advance output by 16 elements - sub r8,16 - jae ComputeExpBy16Loop - -ProcessRemainingCount: - add r8,16 ; correct for over-subtract above - jz ExitKernel - lea r10,MlasOpmask16BitTableAvx512 - kmovw k1,WORD PTR [r10+r8*2] - vmaxps zmm16{k1}{z},zmm21,ZMMWORD PTR [rcx] - ; clamp lower bound - vfmadd213ps zmm23,zmm16,zmm22 ; (input / ln2) plus rounding bias - vsubps zmm23,zmm23,zmm22 ; round(input / ln2) - vfmadd231ps zmm16,zmm23,zmm24 ; range reduce: x -= (m * ln2_high) - vfmadd231ps zmm16,zmm23,zmm25 ; range reduce: x -= (m * ln2_low) - vfmadd213ps zmm26,zmm16,zmm27 ; p = p * x + poly_1 - vfmadd213ps zmm26,zmm16,zmm28 ; p = p * x + poly_2 - vfmadd213ps zmm26,zmm16,zmm29 ; p = p * 
x + poly_3 - vfmadd213ps zmm26,zmm16,zmm30 ; p = p * x + poly_4 - vfmadd213ps zmm26,zmm16,zmm31 ; p = p * x + poly_5 - vfmadd213ps zmm26,zmm16,zmm31 ; p = p * x + poly_6 - vscalefps zmm26,zmm26,zmm23 ; scale p with exponent - vmovups ZMMWORD PTR [rdx]{k1},zmm26 - -ExitKernel: - ret - - LEAF_END MlasComputeExpF32KernelAvx512F, _TEXT - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel for the sum of exponential -; functions. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Optionally supplies the output buffer. When used for Softmax, -; the output buffer is used to store the intermediate exp() results. When -; used for LogSoftmax, the intermediate exp() results are not required. -; -; N (r8) - Supplies the number of elements to process. -; -; NegativeMaximum (r9) - Supplies the address of the negative maximum value -; that is added to each element before computing the exponential function. -; -; Return Value: -; -; Returns the sum of the exponential functions. -; -;-- - - LEAF_ENTRY MlasComputeSumExpF32KernelAvx512F, _TEXT - - lea rax,MlasExpConstants - vbroadcastss zmm21,ExpConstants.LowerRange[rax] - vbroadcastss zmm22,ExpConstants.RoundingBias[rax] - vbroadcastss zmm23,ExpConstants.Log2Reciprocal[rax] - vbroadcastss zmm24,ExpConstants.Log2High[rax] - vbroadcastss zmm25,ExpConstants.Log2Low[rax] - vbroadcastss zmm26,ExpConstants.poly_0[rax] - vbroadcastss zmm27,ExpConstants.poly_1[rax] - vbroadcastss zmm28,ExpConstants.poly_2[rax] - vbroadcastss zmm29,ExpConstants.poly_3[rax] - vbroadcastss zmm30,ExpConstants.poly_4[rax] - vbroadcastss zmm31,ExpConstants.poly_56[rax] - - vbroadcastss zmm19,DWORD PTR [r9] ; broadcast negative maximum value - vpxord zmm20,zmm20,zmm20 ; clear exp() accumulator - sub r8,48 - jb ProcessRemainingCount - -ComputeExpBy48Loop: - vaddps zmm0,zmm19,ZMMWORD PTR [rcx] ; bias by negative maximum value - vaddps zmm3,zmm19,ZMMWORD PTR [rcx+64] - vaddps zmm16,zmm19,ZMMWORD PTR [rcx+128] - vmaxps zmm0,zmm21,zmm0 ; clamp lower bound - vmovaps zmm2,zmm23 - vmaxps zmm3,zmm21,zmm3 - vmovaps zmm5,zmm23 - vmaxps zmm16,zmm21,zmm16 - vmovaps zmm18,zmm23 - vfmadd213ps zmm2,zmm0,zmm22 ; (input / ln2) plus rounding bias - vfmadd213ps zmm5,zmm3,zmm22 - vfmadd213ps zmm18,zmm16,zmm22 - vmovaps zmm1,zmm26 ; p = poly_0 - vmovaps zmm4,zmm26 - vmovaps zmm17,zmm26 - vsubps zmm2,zmm2,zmm22 ; m = round(input / ln2) - vsubps zmm5,zmm5,zmm22 - vsubps zmm18,zmm18,zmm22 - vfmadd231ps zmm0,zmm2,zmm24 ; range reduce: x -= (m * ln2_high) - vfmadd231ps zmm3,zmm5,zmm24 - vfmadd231ps zmm16,zmm18,zmm24 - vfmadd231ps zmm0,zmm2,zmm25 ; range reduce: x -= (m * ln2_low) - vfmadd231ps zmm3,zmm5,zmm25 - vfmadd231ps zmm16,zmm18,zmm25 - vfmadd213ps zmm1,zmm0,zmm27 ; p = p * x + poly_1 - vfmadd213ps zmm4,zmm3,zmm27 - vfmadd213ps zmm17,zmm16,zmm27 - vfmadd213ps zmm1,zmm0,zmm28 ; p = p * x + poly_2 - vfmadd213ps zmm4,zmm3,zmm28 - vfmadd213ps zmm17,zmm16,zmm28 - vfmadd213ps zmm1,zmm0,zmm29 ; p = p * x + poly_3 - vfmadd213ps zmm4,zmm3,zmm29 - vfmadd213ps zmm17,zmm16,zmm29 - vfmadd213ps zmm1,zmm0,zmm30 ; p = p * x + poly_4 - vfmadd213ps zmm4,zmm3,zmm30 - vfmadd213ps zmm17,zmm16,zmm30 - vfmadd213ps zmm1,zmm0,zmm31 ; p = p * x + poly_5 - vfmadd213ps zmm4,zmm3,zmm31 - vfmadd213ps zmm17,zmm16,zmm31 - vfmadd213ps zmm1,zmm0,zmm31 ; p = p * x + poly_6 - vfmadd213ps zmm4,zmm3,zmm31 - vfmadd213ps zmm17,zmm16,zmm31 - vscalefps zmm1,zmm1,zmm2 - vscalefps zmm4,zmm4,zmm5 - vscalefps zmm17,zmm17,zmm18 - vaddps zmm20,zmm20,zmm1 ; accumulate exp() results - vaddps 
zmm20,zmm20,zmm4 - vaddps zmm20,zmm20,zmm17 - add rcx,48*4 ; advance input by 48 elements - test rdx,rdx - jz SkipStoreResultsBy48 - vmovups ZMMWORD PTR [rdx],zmm1 - vmovups ZMMWORD PTR [rdx+64],zmm4 - vmovups ZMMWORD PTR [rdx+128],zmm17 - add rdx,48*4 ; advance output by 48 elements - -SkipStoreResultsBy48: - sub r8,48 - jae ComputeExpBy48Loop - -ProcessRemainingCount: - add r8,48 ; correct for over-subtract above - jz ReduceAccumulator - mov eax,-1 - kmovw k1,eax ; update mask to access all elements - -ComputeExpBy16Loop: - cmp r8,16 - jae ProcessSingleVector - lea r10,MlasOpmask16BitTableAvx512 - kmovw k1,WORD PTR [r10+r8*2] - -ProcessSingleVector: - vaddps zmm0{k1}{z},zmm19,ZMMWORD PTR [rcx] - ; bias by negative maximum value - vmaxps zmm0,zmm21,zmm0 ; clamp lower bound - vmovaps zmm2,zmm23 - vfmadd213ps zmm2,zmm0,zmm22 ; (input / ln2) plus rounding bias - vmovaps zmm1,zmm26 ; p = poly_0 - vsubps zmm2,zmm2,zmm22 ; m = round(input / ln2) - vfmadd231ps zmm0,zmm2,zmm24 ; range reduce: x -= (m * ln2_high) - vfmadd231ps zmm0,zmm2,zmm25 ; range reduce: x -= (m * ln2_low) - vfmadd213ps zmm1,zmm0,zmm27 ; p = p * x + poly_1 - vfmadd213ps zmm1,zmm0,zmm28 ; p = p * x + poly_2 - vfmadd213ps zmm1,zmm0,zmm29 ; p = p * x + poly_3 - vfmadd213ps zmm1,zmm0,zmm30 ; p = p * x + poly_4 - vfmadd213ps zmm1,zmm0,zmm31 ; p = p * x + poly_5 - vfmadd213ps zmm1,zmm0,zmm31 ; p = p * x + poly_6 - vscalefps zmm1,zmm1,zmm2 - vaddps zmm20{k1},zmm20,zmm1 ; accumulate exp() results - add rcx,16*4 ; advance input by 16 elements - test rdx,rdx - jz SkipStoreResultsBy16 - vmovups ZMMWORD PTR [rdx]{k1},zmm1 - add rdx,16*4 ; advance output by 16 elements - -SkipStoreResultsBy16: - sub r8,16 - ja ComputeExpBy16Loop - -ReduceAccumulator: - vextractf64x4 ymm0,zmm20,1 - vaddps zmm0,zmm0,zmm20 - vhaddps ymm0,ymm0,ymm0 - vhaddps ymm0,ymm0,ymm0 - vextractf128 xmm1,ymm0,1 - vaddss xmm0,xmm0,xmm1 - - vzeroupper - ret - - LEAF_END MlasComputeSumExpF32KernelAvx512F, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/TransKernelCommon.inc b/onnxruntime/core/mlas/lib/amd64/TransKernelCommon.inc deleted file mode 100644 index 96582fe46c851..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/TransKernelCommon.inc +++ /dev/null @@ -1,111 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; TransKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the -; transcendental functions. -; -;-- - -; -; Structure layout for the exponential function constants block. -; - -ExpConstants STRUCT - - LowerRange DWORD ? - UpperRange DWORD ? - LowerRangeSumExp DWORD ? - UpperRangeSumExp DWORD ? - RoundingBias DWORD ? - Log2Reciprocal DWORD ? - Log2High DWORD ? - Log2Low DWORD ? - poly_0 DWORD ? - poly_1 DWORD ? - poly_2 DWORD ? - poly_3 DWORD ? - poly_4 DWORD ? - poly_56 DWORD ? - MinimumExponent DWORD ? - MaximumExponent DWORD ? - -ExpConstants ENDS - -; -; Structure layout for the logistic constants block. -; - -LogisticConstants STRUCT - - LowerRange DWORD ? - UpperRange DWORD ? - alpha_9 DWORD ? - alpha_7 DWORD ? - alpha_5 DWORD ? - alpha_3 DWORD ? - alpha_1 DWORD ? - beta_10 DWORD ? - beta_8 DWORD ? - beta_6 DWORD ? - beta_4 DWORD ? - beta_2 DWORD ? - beta_0 DWORD ? - one_half DWORD ? - -LogisticConstants ENDS - -; -; Structure layout for the tanh constants block. -; - -TanhConstants STRUCT - - LowerRange DWORD ? - UpperRange DWORD ? - alpha_13 DWORD ? - alpha_11 DWORD ? - alpha_9 DWORD ? - alpha_7 DWORD ? 
- alpha_5 DWORD ? - alpha_3 DWORD ? - alpha_1 DWORD ? - beta_6 DWORD ? - beta_4 DWORD ? - beta_2 DWORD ? - beta_0 DWORD ? - -TanhConstants ENDS - -; -; Stack frame layout for the transcedental functions. -; - -TransKernelFrame STRUCT - - SavedXmm6 OWORD ? - SavedXmm7 OWORD ? - SavedXmm8 OWORD ? - SavedXmm9 OWORD ? - SavedXmm10 OWORD ? - SavedXmm11 OWORD ? - SavedXmm12 OWORD ? - SavedXmm13 OWORD ? - SavedXmm14 OWORD ? - SavedXmm15 OWORD ? - Padding QWORD ? - ReturnAddress QWORD ? - PreviousP1Home QWORD ? - PreviousP2Home QWORD ? - PreviousP3Home QWORD ? - PreviousP4Home QWORD ? - -TransKernelFrame ENDS diff --git a/onnxruntime/core/mlas/lib/amd64/TransKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/TransKernelFma3.asm deleted file mode 100644 index 993bb0686c4d7..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/TransKernelFma3.asm +++ /dev/null @@ -1,379 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; TransKernelFma3.asm -; -; Abstract: -; -; This module implements kernels for various transcendental functions. -; -; This implementation uses AVX fused multiply/add instructions. -; -;-- - - .xlist -INCLUDE mlasi.inc -INCLUDE TransKernelCommon.inc - .list - - EXTERN MlasMaskMoveTableAvx:NEAR - EXTERN MlasExpConstants:NEAR - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel for the exponential function. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Supplies the output buffer. -; -; N (r8) - Supplies the number of elements to process. -; -; Return Value: -; -; None. -; -;-- - - NESTED_ENTRY MlasComputeExpF32KernelFma3, _TEXT - - alloc_stack (TransKernelFrame.ReturnAddress) - - save_xmm128 xmm6,TransKernelFrame.SavedXmm6 - save_xmm128 xmm7,TransKernelFrame.SavedXmm7 - save_xmm128 xmm8,TransKernelFrame.SavedXmm8 - save_xmm128 xmm9,TransKernelFrame.SavedXmm9 - save_xmm128 xmm10,TransKernelFrame.SavedXmm10 - save_xmm128 xmm11,TransKernelFrame.SavedXmm11 - save_xmm128 xmm12,TransKernelFrame.SavedXmm12 - save_xmm128 xmm13,TransKernelFrame.SavedXmm13 - save_xmm128 xmm14,TransKernelFrame.SavedXmm14 - save_xmm128 xmm15,TransKernelFrame.SavedXmm15 - - END_PROLOGUE - - lea rax,MlasExpConstants - vbroadcastss ymm4,ExpConstants.LowerRange[rax] - vbroadcastss ymm5,ExpConstants.UpperRange[rax] - vbroadcastss ymm6,ExpConstants.MinimumExponent[rax] - vbroadcastss ymm7,ExpConstants.MaximumExponent[rax] - vbroadcastss ymm8,ExpConstants.RoundingBias[rax] - vbroadcastss ymm9,ExpConstants.Log2Low[rax] - vbroadcastss ymm10,ExpConstants.poly_0[rax] - vbroadcastss ymm11,ExpConstants.poly_1[rax] - vbroadcastss ymm12,ExpConstants.poly_2[rax] - vbroadcastss ymm13,ExpConstants.poly_3[rax] - vbroadcastss ymm14,ExpConstants.poly_4[rax] - vbroadcastss ymm15,ExpConstants.poly_56[rax] - - sub r8,8 - jb ProcessRemainingCount - -ComputeExpBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rcx] ; clamp lower bound - vbroadcastss ymm2,ExpConstants.Log2Reciprocal[rax] - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vbroadcastss ymm3,ExpConstants.Log2High[rax] - vfmadd213ps ymm2,ymm0,ymm8 ; (x / ln2) plus rounding bias - vsubps ymm1,ymm2,ymm8 ; m = round(x / ln2) - vfmadd231ps ymm0,ymm1,ymm3 ; range reduce: x -= (m * ln2_high) - vfmadd231ps ymm0,ymm1,ymm9 ; range reduce: x -= (m * ln2_low) - vmovaps ymm1,ymm10 ; p = poly_0 - vfmadd213ps ymm1,ymm0,ymm11 ; p = p * x + poly_1 - vpslld ymm2,ymm2,23 ; shift m to exponent field - vfmadd213ps ymm1,ymm0,ymm12 ; p = p * x + poly_2 - vpminsd 
ymm3,ymm2,ymm7 ; clamp upper normal exponent to +127 - vfmadd213ps ymm1,ymm0,ymm13 ; p = p * x + poly_3 - vpmaxsd ymm3,ymm3,ymm6 ; clamp lower normal exponent to -126 - vfmadd213ps ymm1,ymm0,ymm14 ; p = p * x + poly_4 - vpsubd ymm2,ymm2,ymm3 ; compute overflow exponent - vpaddd ymm3,ymm3,ymm7 ; add exponent bias to normal scale - vpaddd ymm2,ymm2,ymm7 ; add exponent bias to overflow scale - vfmadd213ps ymm1,ymm0,ymm15 ; p = p * x + poly_56 - vmulps ymm0,ymm0,ymm2 ; scale x with overflow exponent - vfmadd213ps ymm1,ymm0,ymm2 ; p = p * (x * overflow) + overflow - vmulps ymm1,ymm1,ymm3 ; scale p with normal exponent - add rcx,8*4 ; advance input by 8 elements - vmovups YMMWORD PTR [rdx],ymm1 - add rdx,8*4 ; advance output by 8 elements - sub r8,8 - jae ComputeExpBy8Loop - -ProcessRemainingCount: - add r8,8 ; correct for over-subtract above - jz ExitKernel - neg r8 - lea r10,MlasMaskMoveTableAvx+8*4 - vmovups ymm2,YMMWORD PTR [r10+r8*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rcx] - vmaxps ymm0,ymm4,ymm0 ; clamp lower bound - vbroadcastss ymm4,ExpConstants.Log2Reciprocal[rax] - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vbroadcastss ymm3,ExpConstants.Log2High[rax] - vfmadd213ps ymm4,ymm0,ymm8 ; (x / ln2) plus rounding bias - vsubps ymm1,ymm4,ymm8 ; m = round(x / ln2) - vfmadd231ps ymm0,ymm1,ymm3 ; range reduce: x -= (m * ln2_high) - vfmadd231ps ymm0,ymm1,ymm9 ; range reduce: x -= (m * ln2_low) - vmovaps ymm1,ymm10 ; p = poly_0 - vfmadd213ps ymm1,ymm0,ymm11 ; p = p * x + poly_1 - vpslld ymm4,ymm4,23 ; shift m to exponent field - vfmadd213ps ymm1,ymm0,ymm12 ; p = p * x + poly_2 - vpminsd ymm3,ymm4,ymm7 ; clamp upper normal exponent to +127 - vfmadd213ps ymm1,ymm0,ymm13 ; p = p * x + poly_3 - vpmaxsd ymm3,ymm3,ymm6 ; clamp lower normal exponent to -126 - vfmadd213ps ymm1,ymm0,ymm14 ; p = p * x + poly_4 - vpsubd ymm4,ymm4,ymm3 ; compute overflow exponent - vpaddd ymm3,ymm3,ymm7 ; add exponent bias to normal scale - vpaddd ymm4,ymm4,ymm7 ; add exponent bias to overflow scale - vfmadd213ps ymm1,ymm0,ymm15 ; p = p * x + poly_5 - vmulps ymm0,ymm0,ymm4 ; scale x with overflow exponent - vfmadd213ps ymm1,ymm0,ymm4 ; p = p * (x * overflow) + overflow - vmulps ymm1,ymm1,ymm3 ; scale p with normal exponent - vmaskmovps YMMWORD PTR [rdx],ymm2,ymm1 - -ExitKernel: - vzeroupper - movaps xmm6,TransKernelFrame.SavedXmm6[rsp] - movaps xmm7,TransKernelFrame.SavedXmm7[rsp] - movaps xmm8,TransKernelFrame.SavedXmm8[rsp] - movaps xmm9,TransKernelFrame.SavedXmm9[rsp] - movaps xmm10,TransKernelFrame.SavedXmm10[rsp] - movaps xmm11,TransKernelFrame.SavedXmm11[rsp] - movaps xmm12,TransKernelFrame.SavedXmm12[rsp] - movaps xmm13,TransKernelFrame.SavedXmm13[rsp] - movaps xmm14,TransKernelFrame.SavedXmm14[rsp] - movaps xmm15,TransKernelFrame.SavedXmm15[rsp] - add rsp,(TransKernelFrame.ReturnAddress) - - BEGIN_EPILOGUE - - ret - - NESTED_END MlasComputeExpF32KernelFma3, _TEXT - -;++ -; -; Routine Description: -; -; This routine implements a vectorized kernel for the sum of exponential -; functions. -; -; Arguments: -; -; Input (rcx) - Supplies the input buffer. -; -; Output (rdx) - Optionally supplies the output buffer. When used for Softmax, -; the output buffer is used to store the intermediate exp() results. When -; used for LogSoftmax, the intermediate exp() results are not required. -; -; N (r8) - Supplies the number of elements to process. -; -; NegativeMaximum (r9) - Supplies the address of the negative maximum value -; that is added to each element before computing the exponential function. 
-; -; Return Value: -; -; Returns the sum of the exponential functions. -; -;-- - - NESTED_ENTRY MlasComputeSumExpF32KernelFma3, _TEXT - - alloc_stack (TransKernelFrame.ReturnAddress) - - save_xmm128 xmm6,TransKernelFrame.SavedXmm6 - save_xmm128 xmm7,TransKernelFrame.SavedXmm7 - save_xmm128 xmm8,TransKernelFrame.SavedXmm8 - save_xmm128 xmm9,TransKernelFrame.SavedXmm9 - save_xmm128 xmm10,TransKernelFrame.SavedXmm10 - save_xmm128 xmm11,TransKernelFrame.SavedXmm11 - save_xmm128 xmm12,TransKernelFrame.SavedXmm12 - save_xmm128 xmm13,TransKernelFrame.SavedXmm13 - save_xmm128 xmm14,TransKernelFrame.SavedXmm14 - save_xmm128 xmm15,TransKernelFrame.SavedXmm15 - - END_PROLOGUE - - lea rax,MlasExpConstants - vbroadcastss ymm9,DWORD PTR [r9] ; broadcast negative maximum value - vxorps xmm10,xmm10,xmm10 ; clear exp() accumulator - sub r8,24 - jb ProcessRemainingCount - -ComputeExpBy24Loop: - vbroadcastss ymm11,ExpConstants.LowerRangeSumExp[rax] - vbroadcastss ymm2,ExpConstants.Log2Reciprocal[rax] - vaddps ymm0,ymm9,YMMWORD PTR [rcx] ; bias by negative maximum value - vaddps ymm3,ymm9,YMMWORD PTR [rcx+32] - vaddps ymm6,ymm9,YMMWORD PTR [rcx+64] - vbroadcastss ymm15,ExpConstants.RoundingBias[rax] - vmaxps ymm0,ymm11,ymm0 ; clamp lower bound - vmovaps ymm5,ymm2 - vmaxps ymm3,ymm11,ymm3 - vmovaps ymm8,ymm2 - vmaxps ymm6,ymm11,ymm6 - vbroadcastss ymm13,ExpConstants.Log2High[rax] - vfmadd213ps ymm2,ymm0,ymm15 ; (x / ln2) plus rounding bias - vfmadd213ps ymm5,ymm3,ymm15 - vfmadd213ps ymm8,ymm6,ymm15 - vbroadcastss ymm14,ExpConstants.Log2Low[rax] - vsubps ymm1,ymm2,ymm15 ; m = round(x / ln2) - vsubps ymm4,ymm5,ymm15 - vsubps ymm7,ymm8,ymm15 - vfmadd231ps ymm0,ymm1,ymm13 ; range reduce: x -= (m * ln2_high) - vfmadd231ps ymm3,ymm4,ymm13 - vfmadd231ps ymm6,ymm7,ymm13 - vfmadd231ps ymm0,ymm1,ymm14 ; range reduce: x -= (m * ln2_low) - vfmadd231ps ymm3,ymm4,ymm14 - vfmadd231ps ymm6,ymm7,ymm14 - vbroadcastss ymm1,ExpConstants.poly_0[rax] - vbroadcastss ymm13,ExpConstants.poly_1[rax] - vmovaps ymm4,ymm1 - vmovaps ymm7,ymm1 - vfmadd213ps ymm1,ymm0,ymm13 ; p = p * x + poly_1 - vfmadd213ps ymm4,ymm3,ymm13 - vfmadd213ps ymm7,ymm6,ymm13 - vbroadcastss ymm14,ExpConstants.poly_2[rax] - vpslld ymm2,ymm2,23 ; shift m to exponent field - vpslld ymm5,ymm5,23 - vpslld ymm8,ymm8,23 - vbroadcastss ymm15,ExpConstants.MaximumExponent[rax] - vfmadd213ps ymm1,ymm0,ymm14 ; p = p * x + poly_2 - vfmadd213ps ymm4,ymm3,ymm14 - vfmadd213ps ymm7,ymm6,ymm14 - vbroadcastss ymm13,ExpConstants.poly_3[rax] - vpaddd ymm2,ymm2,ymm15 ; add exponent bias to scale - vpaddd ymm5,ymm5,ymm15 - vpaddd ymm8,ymm8,ymm15 - vbroadcastss ymm14,ExpConstants.poly_4[rax] - vfmadd213ps ymm1,ymm0,ymm13 ; p = p * x + poly_3 - vfmadd213ps ymm4,ymm3,ymm13 - vfmadd213ps ymm7,ymm6,ymm13 - vbroadcastss ymm15,ExpConstants.poly_56[rax] - vfmadd213ps ymm1,ymm0,ymm14 ; p = p * x + poly_4 - vfmadd213ps ymm4,ymm3,ymm14 - vfmadd213ps ymm7,ymm6,ymm14 - vfmadd213ps ymm1,ymm0,ymm15 ; p = p * x + poly_5 - vfmadd213ps ymm4,ymm3,ymm15 - vfmadd213ps ymm7,ymm6,ymm15 - vfmadd213ps ymm1,ymm0,ymm15 ; p = p * x + poly_6 - vfmadd213ps ymm4,ymm3,ymm15 - vfmadd213ps ymm7,ymm6,ymm15 - vmulps ymm1,ymm1,ymm2 ; scale p with exponent - vmulps ymm4,ymm4,ymm5 - vaddps ymm10,ymm10,ymm1 ; accumulate exp() results - vmulps ymm7,ymm7,ymm8 - vaddps ymm10,ymm10,ymm4 - add rcx,24*4 ; advance input by 24 elements - vaddps ymm10,ymm10,ymm7 - test rdx,rdx - jz SkipStoreResultsBy24 - vmovups YMMWORD PTR [rdx],ymm1 - vmovups YMMWORD PTR [rdx+32],ymm4 - vmovups YMMWORD PTR [rdx+64],ymm7 - add rdx,24*4 ; advance 
output by 24 elements - -SkipStoreResultsBy24: - sub r8,24 - jae ComputeExpBy24Loop - -ProcessRemainingCount: - add r8,24 ; correct for over-subtract above - jz ReduceAccumulator - vbroadcastss ymm11,ExpConstants.LowerRangeSumExp[rax] - -ComputeExpBy8Loop: - cmp r8,8 ; remaining count < 8? - jb LoadPartialVector - vmovups ymm0,YMMWORD PTR [rcx] - jmp ProcessSingleVector - -LoadPartialVector: - lea r10,MlasMaskMoveTableAvx+8*4 - neg r8 ; carry flag unchanged - vmovups ymm3,YMMWORD PTR [r10+r8*4] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx] - vandps ymm9,ymm9,ymm3 ; mask unused maximum value to 0.0 - -ProcessSingleVector: - vbroadcastss ymm2,ExpConstants.Log2Reciprocal[rax] - vaddps ymm0,ymm9,ymm0 ; bias by negative maximum value - vbroadcastss ymm15,ExpConstants.RoundingBias[rax] - vmaxps ymm0,ymm11,ymm0 ; clamp lower bound - vbroadcastss ymm13,ExpConstants.Log2High[rax] - vfmadd213ps ymm2,ymm0,ymm15 ; (input / ln2) plus rounding bias - vbroadcastss ymm14,ExpConstants.Log2Low[rax] - vsubps ymm1,ymm2,ymm15 ; round(input / ln2) - vfmadd231ps ymm0,ymm1,ymm13 ; range reduce: x -= (m * ln2_high) - vfmadd231ps ymm0,ymm1,ymm14 ; range reduce: x -= (m * ln2_low) - vbroadcastss ymm1,ExpConstants.poly_0[rax] - vbroadcastss ymm13,ExpConstants.poly_1[rax] - vfmadd213ps ymm1,ymm0,ymm13 ; p = p * x + poly_1 - vbroadcastss ymm14,ExpConstants.poly_2[rax] - vpslld ymm2,ymm2,23 ; shift m to exponent field - vbroadcastss ymm15,ExpConstants.MaximumExponent[rax] - vfmadd213ps ymm1,ymm0,ymm14 ; p = p * x + poly_2 - vbroadcastss ymm13,ExpConstants.poly_3[rax] - vpaddd ymm2,ymm2,ymm15 ; add exponent bias to scale - vbroadcastss ymm14,ExpConstants.poly_4[rax] - vfmadd213ps ymm1,ymm0,ymm13 ; p = p * x + poly_3 - vbroadcastss ymm15,ExpConstants.poly_56[rax] - vfmadd213ps ymm1,ymm0,ymm14 ; p = p * x + poly_4 - vfmadd213ps ymm1,ymm0,ymm15 ; p = p * x + poly_5 - vfmadd213ps ymm1,ymm0,ymm15 ; p = p * x + poly_6 - vmulps ymm1,ymm1,ymm2 - jb StorePartialVector ; remaining count < 8? - vaddps ymm10,ymm10,ymm1 ; accumulate exp() results - test rdx,rdx ; store exp() results? - jz SkipStoreResultsBy8 - vmovups YMMWORD PTR [rdx],ymm1 - add rdx,8*4 ; advance output by 8 elements - -SkipStoreResultsBy8: - add rcx,8*4 ; advance input by 8 elements - sub r8,8 - jnz ComputeExpBy8Loop - jmp ReduceAccumulator - -StorePartialVector: - vandps ymm1,ymm1,ymm3 ; mask unused exp() results to 0.0 - vaddps ymm10,ymm10,ymm1 ; accumulate exp() results - test rdx,rdx ; store exp() results? 
- jz ReduceAccumulator - vmaskmovps YMMWORD PTR [rdx],ymm3,ymm1 - -ReduceAccumulator: - vhaddps ymm10,ymm10,ymm10 - vhaddps ymm10,ymm10,ymm10 - vextractf128 xmm0,ymm10,1 - vaddss xmm0,xmm0,xmm10 - - vzeroupper - movaps xmm6,TransKernelFrame.SavedXmm6[rsp] - movaps xmm7,TransKernelFrame.SavedXmm7[rsp] - movaps xmm8,TransKernelFrame.SavedXmm8[rsp] - movaps xmm9,TransKernelFrame.SavedXmm9[rsp] - movaps xmm10,TransKernelFrame.SavedXmm10[rsp] - movaps xmm11,TransKernelFrame.SavedXmm11[rsp] - movaps xmm12,TransKernelFrame.SavedXmm12[rsp] - movaps xmm13,TransKernelFrame.SavedXmm13[rsp] - movaps xmm14,TransKernelFrame.SavedXmm14[rsp] - movaps xmm15,TransKernelFrame.SavedXmm15[rsp] - add rsp,(TransKernelFrame.ReturnAddress) - - BEGIN_EPILOGUE - - ret - - NESTED_END MlasComputeSumExpF32KernelFma3, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm deleted file mode 100644 index 800863c77a230..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm +++ /dev/null @@ -1,153 +0,0 @@ -;++ -; -; Copyright (c) Intel Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; cvtfp16Avx2.asm -; -; Abstract: -; -; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - .const - -SINGLE_SIZE equ 4 -HALF_SIZE equ 2 -LOW_SELECTOR equ 00100000b -HIGH_SELECTOR equ 00110001b - - SUBTTL "Convert buffer of half-precision floats to single-precision floats" -;++ -; -; Routine Description: -; -; This routine converts the source buffer of half-precision floats to the -; destination buffer of single-precision floats. -; -; This implementation uses AVX2 instructions. -; -; Arguments: -; -; Source (rcx) - Supplies the address of the source buffer of half-precision -; floats. -; -; Destination (rdx) - Supplies the address of the destination buffer of -; single-precision floats. -; -; Count (r8) - Supplies the number of elements to convert. -; -; Return Value: -; -; None. 
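Both deleted conversion kernels, this AVX_NE_CONVERT one and the SSE2 fallback in cvtfp16a.asm further below, implement the same element-wise half-to-float mapping; the AVX version merely converts even- and odd-indexed elements separately and re-interleaves them. A scalar C++ sketch of that mapping, built from the same mask, exponent-adjust, and magic-denormal constants the SSE2 kernel loads (the function name is illustrative and not part of MLAS):

#include <cstdint>
#include <cstring>

// Scalar reference for one IEEE half-to-float conversion, following the
// MlasFp16MaskSign / MlasFp16AdjustExponent / MlasFp16MagicDenormal scheme.
float HalfToFloatSketch(uint16_t h)
{
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;  // sign to bit 31
    uint32_t em = h & 0x7FFFu;                                  // exponent+mantissa
    uint32_t bits;

    if (em >= 0x7C00u) {
        // Infinity / NaN: rebias twice so the exponent field becomes 0xFF.
        bits = (em << 13) + 0x38000000u + 0x38000000u;
    } else if (em < 0x0400u) {
        // Zero / denormal: the "magic denormal" trick. Interpreting
        // (em << 13) + 0x38800000 as a float and subtracting 2^-14 yields
        // em * 2^-24, the value of the half-precision denormal.
        uint32_t u = (em << 13) + 0x38800000u;
        float f, magic;
        std::memcpy(&f, &u, sizeof(f));
        uint32_t m = 0x38800000u;
        std::memcpy(&magic, &m, sizeof(magic));
        float d = f - magic;
        std::memcpy(&bits, &d, sizeof(bits));
    } else {
        // Normal number: shift mantissa/exponent into place and rebias the
        // exponent by (127 - 15) << 23 == 0x38000000.
        bits = (em << 13) + 0x38000000u;
    }

    bits |= sign;
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}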
-; -;-- - - -LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT - - test r8, r8 ; Check if we have any elements to convert - jz ExitRoutine - cmp r8, 8 - jb ConvertMaskedVectors - cmp r8, 16 - jb Convert128Vectors - - - -Convert256Vectors: - vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes - vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes - vunpcklps ymm2, ymm0, ymm1 ; Interleave low part - vunpckhps ymm1, ymm0, ymm1 ; Interleave high part - vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order - vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order - vmovups ymmword PTR [rdx], ymm0 ; Store the low part - vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part - - add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements - add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements - sub r8, 16 ; Reduce the counter by 16 elements - - jz ExitRoutine ; If we are done, exit - cmp r8, 16 ; If the vector is big enough, we go again - jae Convert256Vectors - cmp r8, 8 ; Check if we have enough elements to convert - jb ConvertMaskedVectors - - - -Convert128Vectors: - vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes - vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes - vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order - vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order - vmovups xmmword PTR [rdx], xmm0 ; Store the low part - vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part - - add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements - add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements - sub r8, 8 ; Reduce the counter by 8 elements - - jz ExitRoutine ; If we are done, exit - - - -ConvertMaskedVectors: - vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes - vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes - vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order - vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order - - cmp r8, 4 ; Check if we can store the complete lower vector - jae ConvertLowerVector - - vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones - cmp r8, 2 ; Check how many converts we need - jb ConvertLower1 - ja ConvertLower3 - vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values - jmp ConvertLowerMaskedVector -ConvertLower1: - vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value - jmp ConvertLowerMaskedVector -ConvertLower3: - vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values -ConvertLowerMaskedVector: - vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples - jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing -ConvertLowerVector: - vmovups xmmword PTR [rdx], xmm0 ; Store the low part - sub r8, 4 ; Check if we still need to convert - jz ExitRoutine - - - add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements - vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones - cmp r8, 2 ; Check how many converts we need - jb ConvertUpper1 - ja ConvertUpper3 - vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values - jmp ConvertMaskedUpperVector -ConvertUpper1: - vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value - jmp ConvertMaskedUpperVector -ConvertUpper3: - vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values -ConvertMaskedUpperVector: - vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples - -ExitRoutine: - ret - - 
LEAF_END MlasCastF16ToF32KernelAvx, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm deleted file mode 100644 index 0ad98d3115208..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm +++ /dev/null @@ -1,124 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; cvtfp16a.asm -; -; Abstract: -; -; This module implements routines to convert between FP16 and FP32 formats. -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - - .const - - ALIGN 16 -MlasFp16MaskSign DD 4 DUP (00007FFFh) -MlasFp16CompareInfinity DD 4 DUP (00007C00h) -MlasFp16CompareSmallest DD 4 DUP (00000400h) -MlasFp16AdjustExponent DD 4 DUP (38000000h) -MlasFp16MagicDenormal DD 4 DUP (38800000h) - - SUBTTL "Convert buffer of half-precision floats to single-precision floats" -;++ -; -; Routine Description: -; -; This routine converts the source buffer of half-precision floats to the -; destination buffer of single-precision floats. -; -; This implementation uses SSE2 instructions. -; -; Arguments: -; -; Source (rcx) - Supplies the address of the source buffer of half-precision -; floats. -; -; Destination (rdx) - Supplies the address of the destination buffer of -; single-precision floats. -; -; Count (r8) - Supplies the number of elements to convert. -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasCastF16ToF32KernelSse, _TEXT - - test r8,r8 - jz ExitRoutine - cmp r8,4 - jb LoadPartialVector - -LoadFullVector: - movq xmm0,QWORD PTR [rcx] - add rcx,4*2 ; advance S by 4 elements - -ConvertHalfToFloat: - punpcklwd xmm0,xmm0 ; duplicate 4 WORDs to 4 DWORDs - movaps xmm1,xmm0 ; isolate exponent/mantissa - pand xmm1,XMMWORD PTR [MlasFp16MaskSign] - pxor xmm0,xmm1 ; isolate sign bit - movaps xmm2,XMMWORD PTR [MlasFp16CompareInfinity] - pcmpgtd xmm2,xmm1 ; test for infinity/NaNs - movaps xmm3,XMMWORD PTR [MlasFp16CompareSmallest] - pcmpgtd xmm3,xmm1 ; test for denormals - pandn xmm2,XMMWORD PTR [MlasFp16AdjustExponent] - pslld xmm1,13 ; shift exponent/mask into place - movaps xmm4,xmm1 - paddd xmm1,XMMWORD PTR [MlasFp16AdjustExponent] - paddd xmm1,xmm2 ; adjust exponent again for infinity/NaNs - paddd xmm4,XMMWORD PTR [MlasFp16MagicDenormal] - pslld xmm0,16 ; shift sign into place - subps xmm4,XMMWORD PTR [MlasFp16MagicDenormal] - pand xmm4,xmm3 ; select elements that are denormals - pandn xmm3,xmm1 ; select elements that are not denormals - por xmm3,xmm4 ; blend the selected values together - por xmm0,xmm3 ; merge sign into exponent/mantissa - - cmp r8,4 ; storing full vector? 
- jb StorePartialVector - movups XMMWORD PTR [rdx],xmm0 - add rdx,4*4 ; advance D by 4 elements - sub r8,4 - jz ExitRoutine - cmp r8,4 - jae LoadFullVector - -LoadPartialVector: - pxor xmm0,xmm0 - pinsrw xmm0,WORD PTR [rcx],0 - cmp r8,2 - jb ConvertHalfToFloat - pinsrw xmm0,WORD PTR [rcx+2],1 - je ConvertHalfToFloat - pinsrw xmm0,WORD PTR [rcx+4],2 - jmp ConvertHalfToFloat - -StorePartialVector: - cmp r8,2 - jb StoreLastElement - movsd QWORD PTR [rdx],xmm0 - je ExitRoutine - movhlps xmm0,xmm0 ; shift third element down - add rdx,4*2 ; advance D by 2 elements - -StoreLastElement: - movss DWORD PTR [rdx],xmm0 - -ExitRoutine: - ret - - LEAF_END MlasCastF16ToF32KernelSse, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amd64/mlasi.inc b/onnxruntime/core/mlas/lib/amd64/mlasi.inc deleted file mode 100644 index 2db3147168727..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/mlasi.inc +++ /dev/null @@ -1,115 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; mlasi.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the Microsoft -; Machine Learning algebra subprogram library. -; -;-- - - .xlist -INCLUDE macamd64.inc - .list - -; -; Macro Description: -; -; This macro generates an optimization for "add reg,128" which can instead -; be encoded as "sub reg,-128" to reduce code size by using a signed 8-bit -; value. -; -; Arguments: -; -; Register - Supplies the register to be added to. -; -; Immediate - Supplies the immediate to add to the register. -; - -add_immed MACRO Register, Immediate - -IF (Immediate NE 128) - add Register,Immediate -ELSE - sub Register,-Immediate ; smaller encoding -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro conditionally emits the statement if Count is greater than or -; equal to Value. -; -; Arguments: -; -; Count - Supplies the variable used in the comparison. -; -; Value - Supplies the static used in the comparison. -; -; Statement - Supplies the statement to conditionally emit. -; - -EmitIfCountGE MACRO Count, Value, Statement - -IF (Count GE Value) - Statement -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro conditionally emits the statement if Count1 is greater than or -; equal to Value1 and Count2 is greater than or equal to Value2. -; -; Arguments: -; -; Count1 - Supplies the variable used in the comparison. -; -; Value1 - Supplies the static used in the comparison. -; -; Count2 - Supplies the variable used in the comparison. -; -; Value2 - Supplies the static used in the comparison. -; -; Statement - Supplies the statement to conditionally emit. -; - -EmitIfCount2GE MACRO Count1, Value1, Count2, Value2, Statement - -IF (Count1 GE Value1) AND (Count2 GE Value2) - Statement -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro emits the statement for each register listed in the register -; list. The statement can use RegItem to access the current register. -; -; Arguments: -; -; RegList - Supplies the list of registers. -; -; Statement - Supplies the statement to emit. -; - -EmitForEachRegister MACRO RegList, Statement - -IRP RegItem, - Statement -ENDM - - ENDM diff --git a/onnxruntime/core/mlas/lib/amd64/sgemma.asm b/onnxruntime/core/mlas/lib/amd64/sgemma.asm deleted file mode 100644 index 80805d337e500..0000000000000 --- a/onnxruntime/core/mlas/lib/amd64/sgemma.asm +++ /dev/null @@ -1,181 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. 
-; -; Module Name: -; -; sgemma.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -;-- - - .xlist -INCLUDE mlasi.inc - .list - -;++ -; -; Routine Description: -; -; This routine transposes elements from the source matrix to the destination -; packed buffer. -; -; 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 -; rows in the destination packed buffer. -; -; This implementation uses SSE2 instructions. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; B (rdx) - Supplies the address of the source matrix. -; -; ldb (r8d) - Supplies the number of elements per row of the source matrix. -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasSgemmTransposePackB16x4Sse, _TEXT - - shl r8,2 ; convert ldb to bytes - mov r9d,4 ; transpose four 4x4 blocks - -TransposeBlockLoop: - lea rax,[rdx+r8*2] - movups xmm0,XMMWORD PTR [rdx] - movups xmm1,XMMWORD PTR [rdx+r8] - movups xmm2,XMMWORD PTR [rax] - movups xmm3,XMMWORD PTR [rax+r8] - movaps xmm4,xmm0 - unpcklps xmm4,xmm1 - unpckhps xmm0,xmm1 - movaps xmm5,xmm2 - unpcklps xmm5,xmm3 - unpckhps xmm2,xmm3 - movaps xmm1,xmm4 - unpcklpd xmm1,xmm5 - unpckhpd xmm4,xmm5 - movaps xmm3,xmm0 - unpcklpd xmm3,xmm2 - unpckhpd xmm0,xmm2 - movaps XMMWORD PTR [rcx+16*4*0],xmm1 - movaps XMMWORD PTR [rcx+16*4*1],xmm4 - movaps XMMWORD PTR [rcx+16*4*2],xmm3 - movaps XMMWORD PTR [rcx+16*4*3],xmm0 - add rcx,4*4 - lea rdx,[rax+r8*2] - dec r9d - jnz TransposeBlockLoop - ret - - LEAF_END MlasSgemmTransposePackB16x4Sse, _TEXT - -; -; Transpose8x4BlockAvx -; -; 4 columns of 8 rows from the source matrix are transposed to 8 columns of 4 -; rows in the destination packed buffer. -; -; This implementation uses AVX instructions. -; -; Arguments: -; -; StoreOffset - Supplies the relative byte offset into the destination packed -; buffer. -; -; Implicit Arguments: -; -; rcx - Supplies the address of the destination packed buffer. -; -; rdx - Supplies the address of the source matrix. -; -; r8 - Supplies the number of elements per row of the source matrix. -; - -TransposePackB8x4BlockAvx MACRO StoreOffset - -; -; Load 4 columns from 8 rows of the source matrix into the lower and upper -; halves of 4 YMM registers. -; - - lea rax,[rdx+r8*2] - vmovups xmm0,XMMWORD PTR [rdx] - vmovups xmm1,XMMWORD PTR [rdx+r8] - lea rdx,[rax+r8*2] - vmovups xmm2,XMMWORD PTR [rax] - vmovups xmm3,XMMWORD PTR [rax+r8] - lea rax,[rdx+r8*2] - vinsertf128 ymm0,ymm0,XMMWORD PTR [rdx],1 - vinsertf128 ymm1,ymm1,XMMWORD PTR [rdx+r8],1 - vinsertf128 ymm2,ymm2,XMMWORD PTR [rax],1 - vinsertf128 ymm3,ymm3,XMMWORD PTR [rax+r8],1 - -; -; Transpose the lower and upper halves of the 4 YMM registers as two 4x4 -; matrices and store the output to the destination packed buffer. -; - - vunpcklps ymm4,ymm0,ymm1 - vunpckhps ymm5,ymm0,ymm1 - vunpcklps ymm0,ymm2,ymm3 - vunpckhps ymm1,ymm2,ymm3 - vunpcklpd ymm2,ymm4,ymm0 - vunpckhpd ymm3,ymm4,ymm0 - vmovaps YMMWORD PTR [rcx+16*4*0+StoreOffset],ymm2 - vmovaps YMMWORD PTR [rcx+16*4*1+StoreOffset],ymm3 - vunpcklpd ymm0,ymm5,ymm1 - vunpckhpd ymm4,ymm5,ymm1 - vmovaps YMMWORD PTR [rcx+16*4*2+StoreOffset],ymm0 - vmovaps YMMWORD PTR [rcx+16*4*3+StoreOffset],ymm4 - - ENDM - -;++ -; -; Routine Description: -; -; This routine transposes elements from the source matrix to the destination -; packed buffer. -; -; 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 -; rows in the destination packed buffer. 
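Stated as plain C++, the packing performed by MlasSgemmTransposePackB16x4Sse above and the AVX variant below reduces to the loop here; this is an editorial scalar reference with an illustrative name, not part of the deleted file.

#include <cstddef>

// Scalar reference for the 16x4 transpose pack: 4 columns of 16 rows of the
// source matrix B become 4 packed rows of 16 contiguous elements each.
void TransposePackB16x4Sketch(float* D, const float* B, size_t ldb)
{
    for (size_t row = 0; row < 16; ++row) {
        for (size_t col = 0; col < 4; ++col) {
            D[col * 16 + row] = B[row * ldb + col];
        }
    }
}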
-; -; This implementation uses AVX instructions. -; -; Arguments: -; -; D (rcx) - Supplies the address of the destination packed buffer. -; -; B (rdx) - Supplies the address of the source matrix. -; -; ldb (r8d) - Supplies the number of elements per row of the source matrix. -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasSgemmTransposePackB16x4Avx, _TEXT - - shl r8,2 ; convert ldb to bytes - TransposePackB8x4BlockAvx 0*4 - lea rdx,[rax+r8*2] - TransposePackB8x4BlockAvx 8*4 - vzeroupper - ret - - LEAF_END MlasSgemmTransposePackB16x4Avx, _TEXT - - END diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h deleted file mode 100644 index caf94af02362d..0000000000000 --- a/onnxruntime/core/mlas/lib/amx_common.h +++ /dev/null @@ -1,80 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - amx_common.h - -Abstract: - - Intrinsic and inline functions for amx processing. - ---*/ - -#pragma once - -#include "mlasi.h" - -#ifdef _WIN32 -#define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) - -#define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) - -#define tile_dpbusd(dst, src1, src2) _tile_dpbusd(dst, src1, src2) - -#define tile_dpbuud(dst, src1, src2) _tile_dpbuud(dst, src1, src2) - -#define tile_loadd(dst, base, stride) _tile_loadd(dst, base, stride) - -#define tile_stream_loadd(dst, base, stride) _tile_stream_loadd(dst, base, stride) - -#define tile_stored(dst, base, stride) _tile_stored(dst, base, stride) - -#define tile_loadconfig(config) \ - _tile_loadconfig(config) - -#define tile_storeconfig(config) _tile_storeconfig(config) - -#else - -#define tile_dpbusd_internal(dst,src1,src2) \ -__asm__ volatile (".set Payload1, 0x01\n\t" \ - ".set Payload1, Payload1 + (("#src2" & 15) ^ 15) << 3\n\t" \ - ".set ModRMByte, 0xC0\n\t" \ - ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ - ".set ModRMByte, ModRMByte + ("#src1")\n\t" \ - ".byte 0xC4, 0xE2, Payload1, 0x5E, ModRMByte\n\t") - -#define tile_dpbusd(dst,src1,src2) \ -tile_dpbusd_internal(dst,src1,src2) - -#define tile_loadd_internal1(dst,base,stride) \ - __asm__ volatile (".set ModRMByte, 0x04\n\t" \ - ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ - ".byte 0xC4, 0xE2, 0x7B, 0x4B, ModRMByte, 0x18\n\t" \ - :: "a" ((const void*) (base)), "b" ((long) (stride))) - -#define tile_loadd(dst,base,stride) \ - tile_loadd_internal1(dst, base, stride) - - -#define tile_stored_internal1(dst,base,stride) \ - __asm__ volatile (".set ModRMByte, 0x04\n\t" \ - ".set ModRMByte, ModRMByte + ("#dst" << 3)\n\t" \ - ".byte 0xC4, 0xE2, 0x7A, 0x4B, ModRMByte, 0x18\n\t" \ - :: "a" ((const void*) (base)), "b" ((long) (stride))) - -#define tile_stored(dst,base,stride) \ -tile_stored_internal1(dst, base, stride) - - -#define tile_loadconfig(config) \ -__asm__ volatile (".byte 0xC4, 0xE2, 0x78, 0x49, 0x00" :: "a" (((const void *)config))) \ - -#define tile_storeconfig(config) \ -__asm__ volatile (".byte 0xC4, 0xE2, 0x79, 0x49, 0x00" :: "a" (((const void *)config))) \ - -#endif diff --git a/onnxruntime/core/mlas/lib/arm/sgemmc.cpp b/onnxruntime/core/mlas/lib/arm/sgemmc.cpp deleted file mode 100644 index 200ec331176a6..0000000000000 --- a/onnxruntime/core/mlas/lib/arm/sgemmc.cpp +++ /dev/null @@ -1,531 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
- -Module Name: - - sgemmc.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - ---*/ - -#include "mlasi.h" - -template -size_t -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - float32x4_t Row0Block0; - float32x4_t Row0Block1; - float32x4_t Row0Block2; - float32x4_t Row0Block3; - - float32x4_t Row1Block0; - float32x4_t Row1Block1; - float32x4_t Row1Block2; - float32x4_t Row1Block3; - -#if defined(_WIN32) - - if (!ProcessTwoRows) { - UNREFERENCED_PARAMETER(lda); - UNREFERENCED_PARAMETER(ldc); - } - -#endif - - do { - - float32x4_t BElements0; - float32x4_t BElements1; - float32x4_t BElements2; - float32x4_t BElements3; - - float32x2_t Row0AElements; - float32x2_t Row1AElements; - - // - // Clear the block accumulators. - // - - Row0Block0 = vdupq_n_f32(0.0f); - Row0Block1 = vdupq_n_f32(0.0f); - Row0Block2 = vdupq_n_f32(0.0f); - Row0Block3 = vdupq_n_f32(0.0f); - - if (ProcessTwoRows) { - Row1Block0 = vdupq_n_f32(0.0f); - Row1Block1 = vdupq_n_f32(0.0f); - Row1Block2 = vdupq_n_f32(0.0f); - Row1Block3 = vdupq_n_f32(0.0f); - } - - // - // Compute the 16x1 or 16x2 output block. 
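As a reference point for the NEON block that follows, the arithmetic for one row of A against one packed 16-wide panel of B reduces to the scalar loop below; this is an editorial sketch with an illustrative name, where ZeroMode corresponds to the MlasSgemmKernelZero versus MlasSgemmKernelAdd entry points defined later in this file.

#include <cstddef>

// Scalar sketch of the inner kernel for a single row of A and 16 columns.
void SgemmRowPanelSketch(const float* A, const float* PackedB, float* C,
                         size_t CountK, float alpha, bool ZeroMode)
{
    for (size_t n = 0; n < 16; ++n) {
        float acc = 0.0f;
        for (size_t k = 0; k < CountK; ++k) {
            acc += A[k] * PackedB[k * 16 + n];   // B was packed 16 columns wide
        }
        float value = alpha * acc;
        C[n] = ZeroMode ? value : C[n] + value;  // overwrite or accumulate into C
    }
}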
- // - - const float* a = A; - size_t k = CountK; - - while (k >= 2) { - - Row0AElements = vld1_f32(a); - - if (ProcessTwoRows) { - Row1AElements = vld1_f32(a + lda); - } - - BElements0 = vld1q_f32(B + 0); - BElements1 = vld1q_f32(B + 4); - BElements2 = vld1q_f32(B + 8); - BElements3 = vld1q_f32(B + 12); - - Row0Block0 = vmlaq_lane_f32(Row0Block0, BElements0, Row0AElements, 0); - Row0Block1 = vmlaq_lane_f32(Row0Block1, BElements1, Row0AElements, 0); - Row0Block2 = vmlaq_lane_f32(Row0Block2, BElements2, Row0AElements, 0); - Row0Block3 = vmlaq_lane_f32(Row0Block3, BElements3, Row0AElements, 0); - - if (ProcessTwoRows) { - Row1Block0 = vmlaq_lane_f32(Row1Block0, BElements0, Row1AElements, 0); - Row1Block1 = vmlaq_lane_f32(Row1Block1, BElements1, Row1AElements, 0); - Row1Block2 = vmlaq_lane_f32(Row1Block2, BElements2, Row1AElements, 0); - Row1Block3 = vmlaq_lane_f32(Row1Block3, BElements3, Row1AElements, 0); - } - - BElements0 = vld1q_f32(B + 16); - BElements1 = vld1q_f32(B + 20); - BElements2 = vld1q_f32(B + 24); - BElements3 = vld1q_f32(B + 28); - - Row0Block0 = vmlaq_lane_f32(Row0Block0, BElements0, Row0AElements, 1); - Row0Block1 = vmlaq_lane_f32(Row0Block1, BElements1, Row0AElements, 1); - Row0Block2 = vmlaq_lane_f32(Row0Block2, BElements2, Row0AElements, 1); - Row0Block3 = vmlaq_lane_f32(Row0Block3, BElements3, Row0AElements, 1); - - if (ProcessTwoRows) { - Row1Block0 = vmlaq_lane_f32(Row1Block0, BElements0, Row1AElements, 1); - Row1Block1 = vmlaq_lane_f32(Row1Block1, BElements1, Row1AElements, 1); - Row1Block2 = vmlaq_lane_f32(Row1Block2, BElements2, Row1AElements, 1); - Row1Block3 = vmlaq_lane_f32(Row1Block3, BElements3, Row1AElements, 1); - } - - a += 2; - B += 32; - k -= 2; - } - - if (k > 0) { - - Row0AElements = vld1_dup_f32(a); - - if (ProcessTwoRows) { - Row1AElements = vld1_dup_f32(a + lda); - } - - BElements0 = vld1q_f32(B + 0); - BElements1 = vld1q_f32(B + 4); - BElements2 = vld1q_f32(B + 8); - BElements3 = vld1q_f32(B + 12); - - Row0Block0 = vmlaq_lane_f32(Row0Block0, BElements0, Row0AElements, 0); - Row0Block1 = vmlaq_lane_f32(Row0Block1, BElements1, Row0AElements, 0); - Row0Block2 = vmlaq_lane_f32(Row0Block2, BElements2, Row0AElements, 0); - Row0Block3 = vmlaq_lane_f32(Row0Block3, BElements3, Row0AElements, 0); - - if (ProcessTwoRows) { - Row1Block0 = vmlaq_lane_f32(Row1Block0, BElements0, Row1AElements, 0); - Row1Block1 = vmlaq_lane_f32(Row1Block1, BElements1, Row1AElements, 0); - Row1Block2 = vmlaq_lane_f32(Row1Block2, BElements2, Row1AElements, 0); - Row1Block3 = vmlaq_lane_f32(Row1Block3, BElements3, Row1AElements, 0); - } - - B += 16; - } - - // - // Multiply by the alpha value. - // - - Row0Block0 = vmulq_n_f32(Row0Block0, alpha); - Row0Block1 = vmulq_n_f32(Row0Block1, alpha); - Row0Block2 = vmulq_n_f32(Row0Block2, alpha); - Row0Block3 = vmulq_n_f32(Row0Block3, alpha); - - if (ProcessTwoRows) { - Row1Block0 = vmulq_n_f32(Row1Block0, alpha); - Row1Block1 = vmulq_n_f32(Row1Block1, alpha); - Row1Block2 = vmulq_n_f32(Row1Block2, alpha); - Row1Block3 = vmulq_n_f32(Row1Block3, alpha); - } - - if (CountN >= 16) { - - // - // Store the entire output block. 
- // - - if (!ZeroMode) { - Row0Block0 = vaddq_f32(Row0Block0, vld1q_f32(C)); - Row0Block1 = vaddq_f32(Row0Block1, vld1q_f32(C + 4)); - Row0Block2 = vaddq_f32(Row0Block2, vld1q_f32(C + 8)); - Row0Block3 = vaddq_f32(Row0Block3, vld1q_f32(C + 12)); - } - - vst1q_f32(C, Row0Block0); - vst1q_f32(C + 4, Row0Block1); - vst1q_f32(C + 8, Row0Block2); - vst1q_f32(C + 12, Row0Block3); - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0 = vaddq_f32(Row1Block0, vld1q_f32(C + ldc)); - Row1Block1 = vaddq_f32(Row1Block1, vld1q_f32(C + ldc + 4)); - Row1Block2 = vaddq_f32(Row1Block2, vld1q_f32(C + ldc + 8)); - Row1Block3 = vaddq_f32(Row1Block3, vld1q_f32(C + ldc + 12)); - } - - vst1q_f32(C + ldc, Row1Block0); - vst1q_f32(C + ldc + 4, Row1Block1); - vst1q_f32(C + ldc + 8, Row1Block2); - vst1q_f32(C + ldc + 12, Row1Block3); - } - - } else { - - // - // Store the partial output block. - // - - if ((CountN & 8) != 0) { - - if (!ZeroMode) { - Row0Block0 = vaddq_f32(Row0Block0, vld1q_f32(C)); - Row0Block1 = vaddq_f32(Row0Block1, vld1q_f32(C + 4)); - } - - vst1q_f32(C, Row0Block0); - vst1q_f32(C + 4, Row0Block1); - Row0Block0 = Row0Block2; - Row0Block1 = Row0Block3; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0 = vaddq_f32(Row1Block0, vld1q_f32(C + ldc)); - Row1Block1 = vaddq_f32(Row1Block1, vld1q_f32(C + ldc + 4)); - } - - vst1q_f32(C + ldc, Row1Block0); - vst1q_f32(C + ldc + 4, Row1Block1); - Row1Block0 = Row1Block2; - Row1Block1 = Row1Block3; - } - - C += 8; - } - - if ((CountN & 4) != 0) { - - if (!ZeroMode) { - Row0Block0 = vaddq_f32(Row0Block0, vld1q_f32(C)); - } - - vst1q_f32(C, Row0Block0); - Row0Block0 = Row0Block1; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0 = vaddq_f32(Row1Block0, vld1q_f32(C + ldc)); - } - - vst1q_f32(C + ldc, Row1Block0); - Row1Block0 = Row1Block1; - } - - C += 4; - } - - float32x2_t Row0Block0High; - float32x2_t Row0Block0Low; - - float32x2_t Row1Block0High; - float32x2_t Row1Block0Low; - - Row0Block0High = vget_high_f32(Row0Block0); - Row0Block0Low = vget_low_f32(Row0Block0); - - if (ProcessTwoRows) { - Row1Block0High = vget_high_f32(Row1Block0); - Row1Block0Low = vget_low_f32(Row1Block0); - } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Row0Block0Low = vadd_f32(Row0Block0Low, vld1_f32(C)); - } - - vst1_f32(C, Row0Block0Low); - Row0Block0Low = Row0Block0High; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0Low = vadd_f32(Row1Block0Low, vld1_f32(C + ldc)); - } - - vst1_f32(C + ldc, Row1Block0Low); - Row1Block0Low = Row1Block0High; - } - - C += 2; - } - - if ((CountN & 1) != 0) { - - if (!ZeroMode) { - Row0Block0Low = vadd_f32(Row0Block0Low, vld1_dup_f32(C)); - } - - vst1_lane_f32(C, Row0Block0Low, 0); - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0Low = vadd_f32(Row1Block0Low, vld1_dup_f32(C + ldc)); - } - - vst1_lane_f32(C + ldc, Row1Block0Low, 0); - } - } - - break; - } - - C += 16; - CountN -= 16; - - } while (CountN > 0); - - return ProcessTwoRows ? 2 : 1; -} - -template -size_t -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. 
- - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - size_t RowsHandled; - - if (CountM >= 2) { - RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountN, lda, ldc, alpha); - } - - return RowsHandled; -} - -size_t -MLASCALL -MlasSgemmKernelZero( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - return MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); -} - -size_t -MLASCALL -MlasSgemmKernelAdd( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. 
- ---*/ -{ - return MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); -} diff --git a/onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h b/onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h deleted file mode 100644 index 69f27bbb7fe63..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/AssembleDotProduct.h +++ /dev/null @@ -1,73 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - AssembleDotProduct.h - -Abstract: - - This module contains macros to build Advanced SIMD dot product instructions - for toolchains that do not natively support this newer instruction set - extension. - - This implementation uses ARM v8.4 dot product instructions. - ---*/ - -/*++ - -Macro Description: - - This macro builds a SDOT instruction of the form: - - SDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index] - -Arguments: - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - - Index - Specifies the element index of the second source register. - ---*/ - - MACRO - SdotByElement $DestReg, $Src1Reg, $Src2Reg, $Index - - DCD 0x4F80E000:OR:($DestReg):OR:($Src1Reg:SHL:5):OR:($Src2Reg:SHL:16):OR:(($Index:AND:2):SHL:10):OR:(($Index:AND:1):SHL:21) - - MEND - -/*++ - -Macro Description: - - This macro builds a UDOT instruction of the form: - - UDOT DestReg.4s, Src1Reg.16b, Src2Reg.4b[Index] - -Arguments: - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - - Index - Specifies the element index of the second source register. - ---*/ - - MACRO - UdotByElement $DestReg, $Src1Reg, $Src2Reg, $Index - - DCD 0x6F80E000:OR:($DestReg):OR:($Src1Reg:SHL:5):OR:($Src2Reg:SHL:16):OR:(($Index:AND:2):SHL:10):OR:(($Index:AND:1):SHL:21) - - MEND diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm deleted file mode 100644 index d9eafb8203b80..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm +++ /dev/null @@ -1,577 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymS8KernelDot.asm - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "kxarm64.h" -#include "AssembleDotProduct.h" - -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// -#define ConvSymFrame_SavedRegisters (6 * 8) -#define ConvSymFrame_PostProcessParams 0 + ConvSymFrame_SavedRegisters -#define ConvSymFrame_KernelFlags 8 + ConvSymFrame_SavedRegisters - -#define ConvSymPostProcessParams_Bias 0 -#define ConvSymPostProcessParams_Scale 8 -#define ConvSymPostProcessParams_Min 16 -#define ConvSymPostProcessParams_Max 20 -#define ConvSymPostProcessParams_ZeroPoint 24 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Points to the indirection buffer. Every pointer in the indirection - buffer points at a InputChannels length vector (either from the input tensor - or a vector of padding values). These are grouped in batches of length - KernelSize. 
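The SdotByElement macro defined in the deleted AssembleDotProduct.h above hand-assembles the indexed form of the Advanced SIMD SDOT instruction. In scalar terms, SDOT Vd.4S, Vn.16B, Vm.4B[Index] performs four independent 4-way int8 dot products, sharing one 4-byte group of the second source across all accumulator lanes; a C++ sketch with an illustrative name, not part of MLAS:

#include <cstdint>

void SdotByElementSketch(int32_t acc[4], const int8_t a[16],
                         const int8_t b[16], int index)
{
    for (int lane = 0; lane < 4; ++lane) {
        for (int i = 0; i < 4; ++i) {
            // Each 32-bit accumulator lane gains the dot product of its own
            // 4 bytes of "a" with the 4-byte group of "b" selected by "index".
            acc[lane] += static_cast<int32_t>(a[4 * lane + i]) *
                         static_cast<int32_t>(b[4 * index + i]);
        }
    }
}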
These batches are then repeated OutputCount times. - - Filter (x1) - Points to the filter buffer. - - Output (x2) - Points the output buffer. - - KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). - Must be > 1 - - InputChannels (x4/x7) - Number of input channels. - - OutputChannels (x5) - Number of output channels. - - ChannelCount (x6) - Number of output channels this iteration produces. - - OutputCount (x7) - Number of output elements this iteration produces. - - This implementation requires the count to be no larger than 4. - - PostProcessParams (x8) - Points to the post process parameter block. - - KernelFlags - (w10) Additional flags controlling the operation. - -Return Value: - - None. - ---*/ - NESTED_ENTRY MlasConvSymS8KernelDot - - PROLOG_SAVE_REG_PAIR d8,d9,#-ConvSymFrame_SavedRegisters! - PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] - PROLOG_SAVE_REG d10,#16 - PROLOG_NOP cmp x7,2 // OutputCount < 2 ? - PROLOG_SAVE_REG d11,#24 - PROLOG_NOP add x16,x2,x5 // x16 -> C1 - PROLOG_SAVE_REG x19,#32 - lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) - csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 - add x4,x4,3 // InputChannels align to 4 - add x17,x16,x5 // x17 -> C2 - ldr x11,[x8,#ConvSymPostProcessParams_Bias] - csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 - bic x4,x4,3 - cmp x7,4 // OutputCount < 4 ? - ldr w10,[sp,#ConvSymFrame_KernelFlags] - add x5,x17,x5 // x5 -> C3 - ldr x19,[x8,#ConvSymPostProcessParams_Scale] - csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 - - // TODO!! tiptoe around loading biases if we need to support - // output channels none divisible by 16 -OutputChannelLoop - ldp q16,q20,[x11],32 // Init accumulators with biases - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldp q24,q28,[x11],32 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - mov v30.16b,v28.16b - mov v31.16b,v28.16b - mov x9,x3 // restore KernelSize * sizeof(int8_t*) - -KernelSizeLoop - ldr x12,[x0] // x12 -> A0 - cmp x16,x2 - b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 - cmp x17,x16 - lsl x14,x3,#1 - ldr x13,[x0,x3] // x13 -> A1 - b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 - cmp x5,x17 - add x15,x3,x3,lsl#1 - ldr x14,[x0,x14] // x14 -> A2 - b.eq SkipLoadA3 // C3==C2 -> A2=A3 - ldr x15,[x0,x15] // x15 -> A3 - b FinishLoadAPtr -SkipLoadA1 - mov x13,x12 -SkipLoadA2 - mov x14,x13 -SkipLoadA3 - mov x15,x14 - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 -// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 -// ------------------ ---------------------------------------------------------------------------- - -FinishLoadAPtr - subs x7,x4,16 // Need 16 input channels for loop - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - b.lo InChannels8 - - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - subs x7,x7,16 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - ldr q6,[x1],16 - ldr q7,[x1],16 - b.lo InChLoopEpilogue // Need 32 input channels for main loop - -InputChannelLoop - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldr d8,[x12],8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x13],8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x14],8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x15],8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,0 - SdotByElement 17, 4, 9,0 - ldr d0,[x12],8 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - ldr d1,[x13],8 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - ldr d2,[x14],8 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - ldr d3,[x15],8 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - subs x7,x7,16 // InputChannels -= 16 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - ldr q7,[x1],16 - b.hs InputChannelLoop - -InChLoopEpilogue - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldr d8,[x12],8 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - ldr 
q4,[x1],16 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - ldr d9,[x13],8 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldr d10,[x14],8 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - ldr d11,[x15],8 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr q4,[x1],16 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr q5,[x1],16 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr q6,[x1],16 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,0 - SdotByElement 17, 4, 9,0 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr q4,[x1],16 - SdotByElement 20, 5, 8,0 - SdotByElement 21, 5, 9,0 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr q5,[x1],16 - SdotByElement 24, 6, 8,0 - SdotByElement 25, 6, 9,0 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr q6,[x1],16 - SdotByElement 28, 7, 8,0 - SdotByElement 29, 7, 9,0 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr q7,[x1],16 - SdotByElement 16, 4, 8,1 - SdotByElement 17, 4, 9,1 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - tst x7,15 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - b.ne InChannels8 // 4 ~ 12 InputChannels - subs x9,x9,8 // KernelSize-=1 - b.hi KernelSizeLoop - -Requantize - tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ldr w13,[x8,#ConvSymPostProcessParams_ZeroPoint] - beq BroadcastScaleValue - ldp q0,q1,[x19],32 // load scale vector - ldp q2,q3,[x19],32 - b AccumulatorsToFloat - -BroadcastScaleValue - ld1r {v0.4s},[x19] // load scale Value - mov v1.16b, v0.16b - mov v2.16b, v0.16b - mov v3.16b, v0.16b - -AccumulatorsToFloat - scvtf v16.4s,v16.4s // convert to float - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - scvtf v24.4s,v24.4s - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - fmul v16.4s,v16.4s,v0.4s // multiply by scale - fmul v17.4s,v17.4s,v0.4s - fmul v18.4s,v18.4s,v0.4s - fmul v19.4s,v19.4s,v0.4s - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v1.4s - fmul v22.4s,v22.4s,v1.4s - fmul v23.4s,v23.4s,v1.4s - fmul v24.4s,v24.4s,v2.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v2.4s - fmul v27.4s,v27.4s,v2.4s - fmul v28.4s,v28.4s,v3.4s - fmul v29.4s,v29.4s,v3.4s - fmul v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v3.4s - fcvtns v16.4s,v16.4s // convert to int - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - fcvtns v24.4s,v24.4s - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns 
v31.4s,v31.4s - - sqxtn v16.4h,v16.4s - sqxtn v17.4h,v17.4s - sqxtn v18.4h,v18.4s - sqxtn v19.4h,v19.4s - sqxtn v24.4h,v24.4s - sqxtn v25.4h,v25.4s - sqxtn v26.4h,v26.4s - sqxtn v27.4h,v27.4s - dup v4.8h,w13 // zero point - sqxtn2 v16.8h,v20.4s - sqxtn2 v17.8h,v21.4s - sqxtn2 v18.8h,v22.4s - sqxtn2 v19.8h,v23.4s - sqxtn2 v24.8h,v28.4s - sqxtn2 v25.8h,v29.4s - sqxtn2 v26.8h,v30.4s - sqxtn2 v27.8h,v31.4s - sqadd v16.8h,v16.8h,v4.8h - sqadd v17.8h,v17.8h,v4.8h - sqadd v18.8h,v18.8h,v4.8h - sqadd v19.8h,v19.8h,v4.8h - sqadd v24.8h,v24.8h,v4.8h - sqadd v25.8h,v25.8h,v4.8h - sqadd v26.8h,v26.8h,v4.8h - sqadd v27.8h,v27.8h,v4.8h - sqxtn v0.8b,v16.8h - sqxtn v1.8b,v17.8h - sqxtn v2.8b,v18.8h - sqxtn v3.8b,v19.8h - sqxtn2 v0.16b,v24.8h - sqxtn2 v1.16b,v25.8h - subs x6,x6,16 // processed 16 output channels - sqxtn2 v2.16b,v26.8h - sqxtn2 v3.16b,v27.8h - b.lo PartialStore - - st1 {v3.16b},[x5],16 // Store full 4 x 16 - st1 {v2.16b},[x17],16 - sub x0,x0,x3 // Restore pointer to A: a -= ks - st1 {v1.16b},[x16],16 - st1 {v0.16b},[x2],16 - b.hi OutputChannelLoop - -ExitKernel - EPILOG_RESTORE_REG x19,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#ConvSymFrame_SavedRegisters! - EPILOG_RETURN - -InChannels8 - tbz x7,3,InChannels4 - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldp q4, q5, [x1], 32 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - tbz x7,2,SkipInCh4 - -InChannels4 - ldr s0,[x12],4 - ldr q4,[x1],16 - ldr s1,[x13],4 - ldr s2,[x14],4 - ldr s3,[x15],4 - ldr q5,[x1],16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - -SkipInCh4 - subs x9,x9,8 // ks -= 1 - b.hi KernelSizeLoop - b Requantize - -PartialStore - tbz x6,3,LT8Store - str d3,[x5],8 // no less than 8 channels - str d2,[x17],8 - dup d3,v3.d[1] - dup d2,v2.d[1] - str d1,[x16],8 - str d0,[x2],8 - dup d1,v1.d[1] - dup d0,v0.d[1] -LT8Store - tbz x6,2,LT4Store - str s3,[x5],4 - str s2,[x17],4 - dup s3,v3.s[1] - dup s2,v2.s[1] - str s1,[x16],4 - str s0,[x2],4 - dup s1,v1.s[1] - dup s0,v0.s[1] -LT4Store - tbz x6,1, LT2Store - str h3,[x5],2 - str h2,[x17],2 - dup h3,v3.h[1] - dup h2,v2.h[1] - str h1,[x16],2 - str h0,[x2],2 - dup h1,v1.h[1] - dup h0,v0.h[1] -LT2Store - tbz x6,0,ExitKernel - str b3,[x5] - str b2,[x17] - 
str b1,[x16] - str b0,[x2] - b ExitKernel - - NESTED_END MlasConvSymS8KernelDot - - END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDotLd64.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDotLd64.asm deleted file mode 100644 index d513c8ae5f807..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDotLd64.asm +++ /dev/null @@ -1,654 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymS8KernelDotLd64.S - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "kxarm64.h" -#include "AssembleDotProduct.h" - -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// -#define ConvSymFrame_SavedRegisters (10 * 8) -#define ConvSymFrame_PostProcessParams (0 + ConvSymFrame_SavedRegisters) -#define ConvSymFrame_KernelFlags (8 + ConvSymFrame_SavedRegisters) - -#define ConvSymPostProcessParams_Bias 0 -#define ConvSymPostProcessParams_Scale 8 -#define ConvSymPostProcessParams_Min 16 -#define ConvSymPostProcessParams_Max 20 -#define ConvSymPostProcessParams_ZeroPoint 24 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Points to the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Points to the filter buffer. - - Output (x2) - Points the output buffer. - - KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (x4/x7) - Number of input channels. - - OutputChannels (x5) - Number of output channels. - - ChannelCount (x6) - Number of output channels this iteration produces. - - OutputCount (x7) - Number of output elements this iteration produces. - - This implementation requires the count to be no larger than 4. - - PostProcessParams (x8) - Points to the post process parameter block. - - KernelFlags - (w10) Additional flags controlling the operation. - -Return Value: - - None. - ---*/ - NESTED_ENTRY MlasConvSymS8KernelDotLd64 - - PROLOG_SAVE_REG_PAIR d8,d9,#-ConvSymFrame_SavedRegisters! - PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] - PROLOG_SAVE_REG d10,#16 - PROLOG_NOP cmp x7,2 // OutputCount < 2 ? - PROLOG_SAVE_REG d11,#24 - PROLOG_NOP add x16,x2,x5 // x16 -> C1 - PROLOG_SAVE_REG x19,#32 - PROLOG_NOP lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) - PROLOG_SAVE_REG x20,#40 - PROLOG_NOP csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 - PROLOG_SAVE_REG x21,#48 - PROLOG_NOP add x4,x4,3 // InputChannels align to 4 - PROLOG_SAVE_REG x22,#56 - PROLOG_NOP add x17,x16,x5 // x17 -> C2 - PROLOG_SAVE_REG x23,#64 - ldr x11,[x8,#ConvSymPostProcessParams_Bias] - csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 - bic x4,x4,3 - cmp x7,4 // OutputCount < 4 ? 
- ldr w10,[sp,#ConvSymFrame_KernelFlags] - add x5,x17,x5 // x5 -> C3 - ldr x19,[x8,#ConvSymPostProcessParams_Scale] - csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 - - // TODO!! tiptoe around loading biases if we need to support - // output channels none divisible by 16 -OutputChannelLoop - ldp q16,q20,[x11],32 // Init accumulators with biases - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldp q24,q28,[x11],32 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - mov v30.16b,v28.16b - mov v31.16b,v28.16b - mov x9,x3 // restore KernelSize * sizeof(int8_t*) - -KernelSizeLoop - ldr x12,[x0] // x12 -> A0 - cmp x16,x2 - b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 - cmp x17,x16 - lsl x14,x3,#1 - ldr x13,[x0,x3] // x13 -> A1 - b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 - cmp x5,x17 - add x15,x3,x3,lsl#1 - ldr x14,[x0,x14] // x14 -> A2 - b.eq SkipLoadA3 // C3==C2 -> A2=A3 - ldr x15,[x0,x15] // x15 -> A3 - b FinishLoadAPtr -SkipLoadA1 - mov x13,x12 -SkipLoadA2 - mov x14,x13 -SkipLoadA3 - mov x15,x14 - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... | -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 -// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 -// ------------------ ---------------------------------------------------------------------------- - -FinishLoadAPtr - subs x7,x4,16 // Need 16 input channels for loop - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - b.lo InChannels8 - - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - subs x7,x7,16 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr d5,[x1],#8 - ldr x21,[x1],#8 - ldr d6,[x1],#8 - ldr x22,[x1],#8 - ldr d7,[x1],#8 - b.lo InChLoopEpilogue // Need 32 input channels for main loop - -InputChannelLoop - SdotByElement 16, 4, 0,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,0 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,0 - ldr d8,[x12],8 - SdotByElement 19, 4, 3,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,0 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,0 - ldr d9,[x13],8 - SdotByElement 23, 5, 3,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,0 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,0 - ldr d10,[x14],8 - SdotByElement 27, 6, 3,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,0 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,0 - ldr d11,[x15],8 - SdotByElement 31, 7, 3,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,1 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,1 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,1 - ins 
v6.d[1],x22 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,1 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,1 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,1 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,1 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,0 - ins v5.d[1],x21 - SdotByElement 18, 4,10,0 - ldr d0,[x12],8 - SdotByElement 19, 4,11,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 9,0 - ins v6.d[1],x22 - SdotByElement 22, 5,10,0 - ldr d1,[x13],8 - SdotByElement 23, 5,11,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 9,0 - ins v7.d[1],x23 - SdotByElement 26, 6,10,0 - ldr d2,[x14],8 - SdotByElement 27, 6,11,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 9,0 - ins v4.d[1],x20 - SdotByElement 30, 7,10,0 - ldr d3,[x15],8 - SdotByElement 31, 7,11,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,1 - ins v5.d[1],x21 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,1 - ldr x20,[x1],#8 - SdotByElement 21, 5, 9,1 - ins v6.d[1],x22 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,1 - ldr x21,[x1],#8 - SdotByElement 25, 6, 9,1 - ins v7.d[1],x23 - SdotByElement 26, 6,10,1 - subs x7,x7,16 // InputChannels -= 16 - SdotByElement 27, 6,11,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,1 - ldr x22,[x1],#8 - SdotByElement 29, 7, 9,1 - ins v4.d[1],x20 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - ldr d7,[x1],#8 - b.hs InputChannelLoop - -InChLoopEpilogue - SdotByElement 16, 4, 0,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,0 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,0 - ldr d8,[x12],8 - SdotByElement 19, 4, 3,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,0 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,0 - ldr d9,[x13],8 - SdotByElement 23, 5, 3,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,0 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,0 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,0 - ldr d10,[x14],8 - SdotByElement 27, 6, 3,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,0 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,0 - ldr d11,[x15],8 - SdotByElement 31, 7, 3,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 0,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 1,1 - ins v5.d[1],x21 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - ldr d4,[x1],#8 - SdotByElement 20, 5, 0,1 - ldr x20,[x1],#8 - SdotByElement 21, 5, 1,1 - ins v6.d[1],x22 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - ldr d5,[x1],#8 - SdotByElement 24, 6, 0,1 - ldr x21,[x1],#8 - SdotByElement 25, 6, 1,1 - ins v7.d[1],x23 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - ldr d6,[x1],#8 - SdotByElement 28, 7, 0,1 - ldr x22,[x1],#8 - SdotByElement 29, 7, 1,1 - ins v4.d[1],x20 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,0 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,0 - ins v5.d[1],x21 - SdotByElement 18, 4,10,0 - SdotByElement 19, 4,11,0 - ldr d4,[x1],#8 - SdotByElement 20, 5, 8,0 - ldr x20,[x1],#8 - SdotByElement 21, 5, 9,0 - ins v6.d[1],x22 - SdotByElement 22, 5,10,0 - SdotByElement 23, 5,11,0 - ldr d5,[x1],#8 - SdotByElement 24, 6, 8,0 - 
ldr x21,[x1],#8 - SdotByElement 25, 6, 9,0 - ins v7.d[1],x23 - SdotByElement 26, 6,10,0 - SdotByElement 27, 6,11,0 - ldr d6,[x1],#8 - SdotByElement 28, 7, 8,0 - ldr x22,[x1],#8 - SdotByElement 29, 7, 9,0 - ins v4.d[1],x20 - SdotByElement 30, 7,10,0 - SdotByElement 31, 7,11,0 - ldr d7,[x1],#8 - SdotByElement 16, 4, 8,1 - ldr x23,[x1],#8 - SdotByElement 17, 4, 9,1 - ins v5.d[1],x21 - SdotByElement 18, 4,10,1 - SdotByElement 19, 4,11,1 - SdotByElement 20, 5, 8,1 - SdotByElement 21, 5, 9,1 - ins v6.d[1],x22 - SdotByElement 22, 5,10,1 - SdotByElement 23, 5,11,1 - SdotByElement 24, 6, 8,1 - SdotByElement 25, 6, 9,1 - ins v7.d[1],x23 - SdotByElement 26, 6,10,1 - SdotByElement 27, 6,11,1 - SdotByElement 28, 7, 8,1 - SdotByElement 29, 7, 9,1 - SdotByElement 30, 7,10,1 - SdotByElement 31, 7,11,1 - - tst x7,15 - b.ne InChannels8 // 4 ~ 12 InputChannels - - subs x9,x9,8 // KernelSize-=1 - b.hi KernelSizeLoop - -Requantize - tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ldr w13,[x8,#ConvSymPostProcessParams_ZeroPoint] - beq BroadcastScaleValue - ldp q0,q1,[x19],32 // load scale vector - ldp q2,q3,[x19],32 - b AccumulatorsToFloat - -BroadcastScaleValue - ld1r {v0.4s},[x19] // load scale Value - mov v1.16b, v0.16b - mov v2.16b, v0.16b - mov v3.16b, v0.16b - -AccumulatorsToFloat - scvtf v16.4s,v16.4s // convert to float - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - scvtf v24.4s,v24.4s - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - fmul v16.4s,v16.4s,v0.4s // multiply by scale - fmul v17.4s,v17.4s,v0.4s - fmul v18.4s,v18.4s,v0.4s - fmul v19.4s,v19.4s,v0.4s - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v1.4s - fmul v22.4s,v22.4s,v1.4s - fmul v23.4s,v23.4s,v1.4s - fmul v24.4s,v24.4s,v2.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v2.4s - fmul v27.4s,v27.4s,v2.4s - fmul v28.4s,v28.4s,v3.4s - fmul v29.4s,v29.4s,v3.4s - fmul v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v3.4s - fcvtns v16.4s,v16.4s // convert to int - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - fcvtns v24.4s,v24.4s - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - - sqxtn v16.4h,v16.4s - sqxtn v17.4h,v17.4s - sqxtn v18.4h,v18.4s - sqxtn v19.4h,v19.4s - sqxtn v24.4h,v24.4s - sqxtn v25.4h,v25.4s - sqxtn v26.4h,v26.4s - sqxtn v27.4h,v27.4s - dup v4.8h,w13 // zero point - sqxtn2 v16.8h,v20.4s - sqxtn2 v17.8h,v21.4s - sqxtn2 v18.8h,v22.4s - sqxtn2 v19.8h,v23.4s - sqxtn2 v24.8h,v28.4s - sqxtn2 v25.8h,v29.4s - sqxtn2 v26.8h,v30.4s - sqxtn2 v27.8h,v31.4s - sqadd v16.8h,v16.8h,v4.8h - sqadd v17.8h,v17.8h,v4.8h - sqadd v18.8h,v18.8h,v4.8h - sqadd v19.8h,v19.8h,v4.8h - sqadd v24.8h,v24.8h,v4.8h - sqadd v25.8h,v25.8h,v4.8h - sqadd v26.8h,v26.8h,v4.8h - sqadd v27.8h,v27.8h,v4.8h - sqxtn v0.8b,v16.8h - sqxtn v1.8b,v17.8h - sqxtn v2.8b,v18.8h - sqxtn v3.8b,v19.8h - sqxtn2 v0.16b,v24.8h - sqxtn2 v1.16b,v25.8h - subs x6,x6,16 // processed 16 output channels - sqxtn2 v2.16b,v26.8h - sqxtn2 v3.16b,v27.8h - b.lo PartialStore - - st1 {v3.16b},[x5],16 // Store full 4 x 16 - st1 {v2.16b},[x17],16 - sub x0,x0,x3 // Restore pointer to A: a -= ks - st1 {v1.16b},[x16],16 - st1 {v0.16b},[x2],16 - b.hi OutputChannelLoop - 
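
For reference, the requantization tail above (scvtf, fmul by the per-channel or broadcast scale, fcvtns, sqadd of the zero point, then the saturating narrows) reduces to the following scalar arithmetic per output value. This is a minimal C++ sketch, not code from this diff; the helper name is made up.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar equivalent of the vector requantization sequence above.
inline int8_t RequantizeToS8(int32_t acc, float scale, int32_t zero_point) {
    // scvtf + fmul: widen the int32 accumulator to float and apply the scale.
    float scaled = static_cast<float>(acc) * scale;
    // fcvtns: convert back to integer, rounding to nearest (ties to even).
    int32_t rounded = static_cast<int32_t>(std::nearbyint(scaled));
    // sqadd + sqxtn: add the output zero point, then saturate to the int8 range.
    return static_cast<int8_t>(std::clamp(rounded + zero_point, -128, 127));
}
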
-ExitKernel - EPILOG_RESTORE_REG x23,#64 - EPILOG_RESTORE_REG_PAIR x21,x22,#48 - EPILOG_RESTORE_REG_PAIR x19,x20,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#ConvSymFrame_SavedRegisters! - EPILOG_RETURN - -InChannels8 - tbz x7,3,InChannels4 - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - ldp q4, q5, [x1], 32 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - SdotByElement 16, 4, 0,1 - SdotByElement 17, 4, 1,1 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,1 - SdotByElement 19, 4, 3,1 - SdotByElement 20, 5, 0,1 - SdotByElement 21, 5, 1,1 - SdotByElement 22, 5, 2,1 - SdotByElement 23, 5, 3,1 - SdotByElement 24, 6, 0,1 - SdotByElement 25, 6, 1,1 - SdotByElement 26, 6, 2,1 - SdotByElement 27, 6, 3,1 - SdotByElement 28, 7, 0,1 - SdotByElement 29, 7, 1,1 - SdotByElement 30, 7, 2,1 - SdotByElement 31, 7, 3,1 - tbz x7,2,SkipInCh4 - -InChannels4 - ldr s0,[x12],4 - ldr q4,[x1],16 - ldr s1,[x13],4 - ldr s2,[x14],4 - ldr s3,[x15],4 - ldr q5, [x1], 16 - SdotByElement 16, 4, 0,0 - SdotByElement 17, 4, 1,0 - ldp q6, q7, [x1], 32 - SdotByElement 18, 4, 2,0 - SdotByElement 19, 4, 3,0 - SdotByElement 20, 5, 0,0 - SdotByElement 21, 5, 1,0 - SdotByElement 22, 5, 2,0 - SdotByElement 23, 5, 3,0 - SdotByElement 24, 6, 0,0 - SdotByElement 25, 6, 1,0 - SdotByElement 26, 6, 2,0 - SdotByElement 27, 6, 3,0 - SdotByElement 28, 7, 0,0 - SdotByElement 29, 7, 1,0 - SdotByElement 30, 7, 2,0 - SdotByElement 31, 7, 3,0 - -SkipInCh4 - subs x9,x9,8 // ks -= 1 - b.hi KernelSizeLoop - b Requantize - -PartialStore - tbz x6,3,LT8Store - str d3,[x5],8 // no less than 8 channels - str d2,[x17],8 - dup d3,v3.d[1] - dup d2,v2.d[1] - str d1,[x16],8 - str d0,[x2],8 - dup d1,v1.d[1] - dup d0,v0.d[1] -LT8Store - tbz x6,2,LT4Store - str s3,[x5],4 - str s2,[x17],4 - dup s3,v3.s[1] - dup s2,v2.s[1] - str s1,[x16],4 - str s0,[x2],4 - dup s1,v1.s[1] - dup s0,v0.s[1] -LT4Store - tbz x6,1, LT2Store - str h3,[x5],2 - str h2,[x17],2 - dup h3,v3.h[1] - dup h2,v2.h[1] - str h1,[x16],2 - str h0,[x2],2 - dup h1,v1.h[1] - dup h0,v0.h[1] -LT2Store - tbz x6,0,ExitKernel - str b3,[x5] - str b2,[x17] - str b1,[x16] - str b0,[x2] - b ExitKernel - - NESTED_END MlasConvSymS8KernelDotLd64 - - END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm deleted file mode 100644 index c22730231310b..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm +++ /dev/null @@ -1,405 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymS8KernelNeon.asm - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "kxarm64.h" - -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - -// -// Stack frame layout for the symmetric convolution kernel. 
-// d8-d15, x19-x30 need to be preserved if used -// -#define ConvSymFrame_SavedNeonRegisters (8 * 8) -#define ConvSymFrame_SavedRegisters ConvSymFrame_SavedNeonRegisters -#define ConvSymFrame_PostProcessParams 0 + ConvSymFrame_SavedRegisters -#define ConvSymFrame_KernelFlags 8 + ConvSymFrame_SavedRegisters - -#define ConvSymPostProcessParams_Bias 0 -#define ConvSymPostProcessParams_Scale 8 -#define ConvSymPostProcessParams_Min 16 -#define ConvSymPostProcessParams_Max 20 -#define ConvSymPostProcessParams_ZeroPoint 24 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the indirect buffer. Every pointer in - the indirection buffer points at a InputChannels length vector (either - from the input tensor or a vector of padding values). These are grouped - in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. Must be > 1 - - InputChannels (x4) - Supplies the number of input channels. - - This implementation requires the count to be a multiple of 8. - - OutputChannels (x5) - Supplies the number of output channels. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 8. - - OutputCount (x7) - Supplies the number of output elements this iteration produces. - - This implementation requires the count to be 1 or 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - NESTED_ENTRY MlasConvSymS8KernelNeon - - PROLOG_SAVE_REG_PAIR d8,d9,#-ConvSymFrame_SavedRegisters! - PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] - PROLOG_NOP ldrb w10,[sp,#ConvSymFrame_KernelFlags] - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - mov x9,x3 // save kernel size - ldr x11,[x8,#ConvSymPostProcessParams_Bias] - mov x16,x4 // save input channels - ldr x12,[x8,#ConvSymPostProcessParams_Scale] - cmp x7,2 // if OutputCount < 2 - add x5,x2,x5 // c1 = c0 + ldc - add x4,x4,7 // kc = (kc + 7) & ~7 - csel x5,x2,x5,lo // if OutputCount < 2 c1 = c0 - bic x4,x4,7 - ldp s16,s18,[x11],8 // init accumulators with bias - ldp s20,s22,[x11],8 - ldp s24,s26,[x11],8 - ldp s28,s30,[x11],8 - mov v17.16b,v16.16b - mov v19.16b,v18.16b - mov v21.16b,v20.16b - mov v23.16b,v22.16b - mov v25.16b,v24.16b - mov v27.16b,v26.16b - mov v29.16b,v28.16b - mov v31.16b,v30.16b - -// Nested loops, inner loop: input channel; outter loop: kernel size -// Each inner iteration processes 8 input channels, 2 output pixels, 8 output channels. -// -// B 8x8 -// ------------------------------------------------------------------ -// |v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] | -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] | -// A 2x8 ------------------------------------------------------------------ -// ------------------ ------------------------------------------------------------------ -// x13-> |v0.b[0]..v0.b[7]| |v16.4s v18.4s v20.4s v22.4s v24.4s v26.4s v28.4s v30.4s | -// x15-> |v1.b[0]..v1.b[7]| |v17.4s v19.4s v21.4s v23.4s v25.4s v27.4s v29.4s v31.4s | -// ------------------ ------------------------------------------------------------------ -// When Input Channels greater than 16, unroll: -// A registers v6 v7, -// B registers v8 v9 -// - -KernelSizeLoop - - // Load next 2 A pointers - cmp x7,2 // test if OutputCount < 2 - ldr x13,[x0] // x13 -> A0 - bhs LoadA1 - ldr x15,[x0],#8 // x15 -> A0 - b BlockLoopPrologue -LoadA1 - ldr x15,[x0,x3,lsl#3] // x15 -> A1 - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop -BlockLoopPrologue - ldr d4,[x1] - subs x14,x4,16 // input channel - 16 - ldr d5,[x1,8] - blo InputChannel8 // less than 16 deep, no unroll - - ldr d0,[x13],8 - ldr d1,[x15],8 - ldr d8,[x1,64] - ldr d9,[x1,72] - ldr d6,[x13],8 - subs x14,x14,16 // input channel - 16 - ldr d7,[x15],8 - blo BlockLoopEpilogue // need 32 input channel for full unrolled loop - -Blockloop - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d5,[x1,40] - sadalp v19.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,56] - sadalp v23.4s, v15.8h - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - add x1,x1,128 - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1] // Read B - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - ldr d0,[x13],8 // Read A0 - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d1,[x15],8 // Read A1 - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - ldr d5,[x1,8] // Read B - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,64] // Read B - smlal v14.8h,v9.8b,v6.8b - ldr d6,[x13],8 // Read A0 - smlal v15.8h,v9.8b,v7.8b - ldr d7,[x15],8 // Read A1 - sadalp v28.4s,v12.8h - ldr d9,[x1,72] // Read B - sadalp v29.4s,v13.8h - subs x14,x14,16 - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - b.hs Blockloop - -BlockLoopEpilogue // remaining 16 input channels - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v19.4s,v11.8h - ldr d5,[x1,40] - smlal 
v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - ldr d5,[x1,56] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - add x1,x1,128 - - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - tbnz x14,3,InputChannel8 - - subs x9,x9,1 - b.hi KernelSizeLoop - -Requantize - ldr w11,[x8,#ConvSymPostProcessParams_ZeroPoint] - tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - beq BroadcastScaleValue - ld1 {v4.4s,v5.4s},[x12] // load scale vector - b AccumulatorsToFloat - -BroadcastScaleValue - ld1r {v4.4s},[x12] // load scale Value - mov v5.16b, v4.16b - -AccumulatorsToFloat - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - addp v17.4s,v17.4s,v19.4s - addp v21.4s,v21.4s,v23.4s - addp v25.4s,v25.4s,v27.4s - addp v29.4s,v29.4s,v31.4s - addp v0.4s,v16.4s,v20.4s - addp v1.4s,v24.4s,v28.4s - addp v2.4s,v17.4s,v21.4s - addp v3.4s,v25.4s,v29.4s - scvtf v0.4s,v0.4s // convert to float - scvtf v1.4s,v1.4s - scvtf v2.4s,v2.4s - scvtf v3.4s,v3.4s - fmul v0.4s,v0.4s,v4.4s // multiply by scale - fmul v1.4s,v1.4s,v5.4s - fmul v2.4s,v2.4s,v4.4s - fmul v3.4s,v3.4s,v5.4s - fcvtns v0.4s,v0.4s // convert to int - fcvtns v1.4s,v1.4s - dup v9.8h,w11 - fcvtns v2.4s,v2.4s - fcvtns v3.4s,v3.4s - sqxtn v0.4h,v0.4s - sqxtn2 v0.8h,v1.4s - sqxtn v2.4h,v2.4s - sqxtn2 v2.8h,v3.4s - sqadd v0.8h,v0.8h,v9.8h - sqadd v2.8h,v2.8h,v9.8h - sqxtn v0.8b,v0.8h // shorten to int8 - sqxtn2 v0.16b,v2.8h - st1 {v0.d}[1],[x5] // full 2x8 store to c - st1 {v0.8b},[x2] - -ExitKernel - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! 
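
The non-dot-product kernel above accumulates int8 products through 16-bit lanes (smull/smlal) and then folds adjacent lane pairs into 32-bit accumulators (sadalp). A minimal scalar sketch of that scheme, purely illustrative:

#include <cstddef>
#include <cstdint>

// Pairwise-widening dot product: two int8*int8 products are summed first
// (the 16-bit smull/smlal stage), then each pair sum is folded into the
// 32-bit accumulator (the sadalp stage).
int32_t DotS8Pairwise(const int8_t* a, const int8_t* b, size_t len /* even */) {
    int32_t acc = 0;
    for (size_t i = 0; i < len; i += 2) {
        int32_t pair = a[i] * b[i] + a[i + 1] * b[i + 1];
        acc += pair;
    }
    return acc;
}
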
- EPILOG_RETURN - -InputChannel8 - ldr d0,[x13] - ldr d1,[x15] - ldr d4,[x1] - ldr d5,[x1,8] - ldr d6,[x1,16] - ldr d7,[x1,24] - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,32] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,40] - smull v12.8h,v6.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - ldr d6,[x1,48] - sadalp v17.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - ldr d7,[x1,56] - sadalp v19.4s,v11.8h - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - smull v12.8h,v6.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - sadalp v27.4s,v11.8h - add x1,x1,64 - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - - // ks loop - subs x9,x9,1 - b.hi KernelSizeLoop - b Requantize - - NESTED_END MlasConvSymS8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm deleted file mode 100644 index 7b0917b95551c..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm +++ /dev/null @@ -1,631 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymKernelNeonDot.asm - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "kxarm64.h" - -#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - -// -// Stack frame layout for the symmetric convolution kernel. -// d8-d15, x19-x30 need to be preserved if used -// -#define ConvSymFrame_SavedNeonRegisters (8 * 8) -#define ConvSymFrame_SavedRegisters ConvSymFrame_SavedNeonRegisters -#define ConvSymFrame_PostProcessParams 0 + ConvSymFrame_SavedRegisters -#define ConvSymFrame_KernelFlags 8 + ConvSymFrame_SavedRegisters - -#define ConvSymPostProcessParams_Bias 0 -#define ConvSymPostProcessParams_Scale 8 -#define ConvSymPostProcessParams_Min 16 -#define ConvSymPostProcessParams_Max 20 -#define ConvSymPostProcessParams_ZeroPoint 24 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Points to the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Points to the filter buffer. - - Output (x2) - Points the output buffer. - - KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (x4/x7) - Number of input channels. - - OutputChannels (x5) - Number of output channels. - - ChannelCount (x6) - Number of output channels this iteration produces. - - OutputCount (x7) - Number of output elements this iteration produces. 
- - This implementation requires the count to be no larger than 4. - - PostProcessParams (x8) - Points to the post process parameter block. - - KernelFlags - (w10) Additional flags controlling the operation. - -Return Value: - - None. - ---*/ - NESTED_ENTRY MlasConvSymU8KernelDot - - PROLOG_SAVE_REG_PAIR d8,d9,#-64! - PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] - PROLOG_NOP ldr w10,[sp,#ConvSymFrame_KernelFlags] - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR x19,x20,#48 - - // compute C pointers: x2, x16, x17, x5 - cmp x7,2 // OutputCount < 2 ? - add x16,x2,x5 // x16 -> C1 - lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) - csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 - mov x20,x4 - add x4,x4,3 // InputChannels align to 4 - add x17,x16,x5 // x17 -> C2 - ldr x11,[x8,#ConvSymPostProcessParams_Bias] - csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 - bic x4,x4,3 - cmp x7,4 // OutputCount < 4 ? - add x5,x17,x5 // x5 -> C3 - ldr x19,[x8,#ConvSymPostProcessParams_Scale] - csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 - movi v12.16b,128 // for top bit flipping - -OutputChannelLoop - ldp q16,q20,[x11],32 // Init accumulators with biases - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldp q24,q28,[x11],32 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - mov v30.16b,v28.16b - mov v31.16b,v28.16b - mov x9,x3 // restore KernelSize * sizeof(int8_t*) - -KernelSizeLoop - tst w10,#MLAS_CONV_SYM_FLAG_INPUT_DIRECT - beq InputIndirection - -InputDirect - cmp x16,x2 - mov x12,x0 // x12 -> A0 - add x13,x0,x20 // x13 -> A1 = A0 + input channels - csel x13,x0,x13,eq - cmp x17,x16 - add x14,x0,x20,lsl#1 // x14 -> A2 - csel x14,x13,x14,eq - cmp x5,x17 - add x15,x13,x20,lsl#1 // x15 -> A3 - csel x15,x14,x15,eq - b FinishLoadAPtr - -InputIndirection - ldr x12,[x0] // x12 -> A0 - cmp x16,x2 - b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 - cmp x17,x16 - lsl x14,x3,#1 - ldr x13,[x0,x3] // x13 -> A1 - b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 - cmp x5,x17 - add x15,x3,x3,lsl#1 - ldr x14,[x0,x14] // x14 -> A2 - b.eq SkipLoadA3 // C3==C2 -> A2=A3 - ldr x15,[x0,x15] // x15 -> A3 - b FinishLoadAPtr -SkipLoadA1 - mov x13,x12 -SkipLoadA2 - mov x14,x13 -SkipLoadA3 - mov x15,x14 - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 -// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 -// ------------------ ---------------------------------------------------------------------------- - -FinishLoadAPtr - subs x7,x4,16 // Need 16 input channels for loop - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - b.lo InChannels8 - - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - subs x7,x7,16 - ldr d2,[x14],8 - ldr d3,[x15],8 - ldr q5,[x1],16 - ldr q6,[x1],16 - ldr q7,[x1],16 - b.lo InChLoopEpilogue // Need 32 input channels for main loop - -InputChannelLoop - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - sdot v16.4s,v4.16b,v0.4b[0] - eor v2.8b,v2.8b,v12.8b - sdot v17.4s,v4.16b,v1.4b[0] - eor v3.8b,v3.8b,v12.8b - ldr d8,[x12],8 - sdot v18.4s,v4.16b,v2.4b[0] - sdot v19.4s,v4.16b,v3.4b[0] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v0.4b[0] - sdot v21.4s,v5.16b,v1.4b[0] - ldr d9,[x13],8 - sdot v22.4s,v5.16b,v2.4b[0] - sdot v23.4s,v5.16b,v3.4b[0] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v0.4b[0] - sdot v25.4s,v6.16b,v1.4b[0] - ldr d10,[x14],8 - sdot v26.4s,v6.16b,v2.4b[0] - sdot v27.4s,v6.16b,v3.4b[0] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v0.4b[0] - sdot v29.4s,v7.16b,v1.4b[0] - ldr d11,[x15],8 - sdot v30.4s,v7.16b,v2.4b[0] - sdot v31.4s,v7.16b,v3.4b[0] - ldr q7,[x1],16 - sdot v16.4s,v4.16b,v0.4b[1] - sdot v17.4s,v4.16b,v1.4b[1] - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v0.4b[1] - sdot v21.4s,v5.16b,v1.4b[1] - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v0.4b[1] - sdot v25.4s,v6.16b,v1.4b[1] - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v0.4b[1] - sdot v29.4s,v7.16b,v1.4b[1] - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - eor v8.8b,v8.8b,v12.8b - ldr q7,[x1],16 - eor v9.8b,v9.8b,v12.8b - sdot v16.4s,v4.16b,v8.4b[0] - eor v10.8b,v10.8b,v12.8b - sdot v17.4s,v4.16b,v9.4b[0] - ldr d0,[x12],8 - eor v11.8b,v11.8b,v12.8b - sdot v18.4s,v4.16b,v10.4b[0] - sdot v19.4s,v4.16b,v11.4b[0] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v8.4b[0] - sdot v21.4s,v5.16b,v9.4b[0] - ldr d1,[x13],8 - sdot v22.4s,v5.16b,v10.4b[0] - sdot v23.4s,v5.16b,v11.4b[0] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v8.4b[0] - sdot v25.4s,v6.16b,v9.4b[0] - ldr d2,[x14],8 - sdot v26.4s,v6.16b,v10.4b[0] - sdot v27.4s,v6.16b,v11.4b[0] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v8.4b[0] - sdot v29.4s,v7.16b,v9.4b[0] - ldr d3,[x15],8 - sdot v30.4s,v7.16b,v10.4b[0] - sdot v31.4s,v7.16b,v11.4b[0] - ldr q7,[x1],16 - sdot v16.4s,v4.16b,v8.4b[1] - sdot v17.4s,v4.16b,v9.4b[1] - sdot v18.4s,v4.16b,v10.4b[1] - sdot v19.4s,v4.16b,v11.4b[1] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v8.4b[1] - sdot v21.4s,v5.16b,v9.4b[1] - sdot v22.4s,v5.16b,v10.4b[1] - sdot v23.4s,v5.16b,v11.4b[1] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v8.4b[1] - sdot v25.4s,v6.16b,v9.4b[1] - sdot 
v26.4s,v6.16b,v10.4b[1] - sdot v27.4s,v6.16b,v11.4b[1] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v8.4b[1] - sdot v29.4s,v7.16b,v9.4b[1] - subs x7,x7,16 // InputChannels -= 16 - sdot v30.4s,v7.16b,v10.4b[1] - sdot v31.4s,v7.16b,v11.4b[1] - ldr q7,[x1],16 - b.hs InputChannelLoop - -InChLoopEpilogue - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - sdot v16.4s,v4.16b,v0.4b[0] - eor v2.8b,v2.8b,v12.8b - sdot v17.4s,v4.16b,v1.4b[0] - eor v3.8b,v3.8b,v12.8b - ldr d8,[x12],8 - sdot v18.4s,v4.16b,v2.4b[0] - sdot v19.4s,v4.16b,v3.4b[0] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v0.4b[0] - sdot v21.4s,v5.16b,v1.4b[0] - ldr d9,[x13],8 - sdot v22.4s,v5.16b,v2.4b[0] - sdot v23.4s,v5.16b,v3.4b[0] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v0.4b[0] - sdot v25.4s,v6.16b,v1.4b[0] - ldr d10,[x14],8 - sdot v26.4s,v6.16b,v2.4b[0] - sdot v27.4s,v6.16b,v3.4b[0] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v0.4b[0] - sdot v29.4s,v7.16b,v1.4b[0] - ldr d11,[x15],8 - sdot v30.4s,v7.16b,v2.4b[0] - sdot v31.4s,v7.16b,v3.4b[0] - ldr q7,[x1],16 - sdot v16.4s,v4.16b,v0.4b[1] - sdot v17.4s,v4.16b,v1.4b[1] - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v0.4b[1] - sdot v21.4s,v5.16b,v1.4b[1] - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v0.4b[1] - sdot v25.4s,v6.16b,v1.4b[1] - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v0.4b[1] - sdot v29.4s,v7.16b,v1.4b[1] - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - eor v8.8b,v8.8b,v12.8b - ldr q7,[x1],16 - eor v9.8b,v9.8b,v12.8b - sdot v16.4s,v4.16b,v8.4b[0] - eor v10.8b,v10.8b,v12.8b - sdot v17.4s,v4.16b,v9.4b[0] - eor v11.8b,v11.8b,v12.8b - sdot v18.4s,v4.16b,v10.4b[0] - sdot v19.4s,v4.16b,v11.4b[0] - ldr q4,[x1],16 - sdot v20.4s,v5.16b,v8.4b[0] - sdot v21.4s,v5.16b,v9.4b[0] - sdot v22.4s,v5.16b,v10.4b[0] - sdot v23.4s,v5.16b,v11.4b[0] - ldr q5,[x1],16 - sdot v24.4s,v6.16b,v8.4b[0] - sdot v25.4s,v6.16b,v9.4b[0] - sdot v26.4s,v6.16b,v10.4b[0] - sdot v27.4s,v6.16b,v11.4b[0] - ldr q6,[x1],16 - sdot v28.4s,v7.16b,v8.4b[0] - sdot v29.4s,v7.16b,v9.4b[0] - sdot v30.4s,v7.16b,v10.4b[0] - sdot v31.4s,v7.16b,v11.4b[0] - ldr q7,[x1],16 - sdot v16.4s,v4.16b,v8.4b[1] - sdot v17.4s,v4.16b,v9.4b[1] - sdot v18.4s,v4.16b,v10.4b[1] - sdot v19.4s,v4.16b,v11.4b[1] - sdot v20.4s,v5.16b,v8.4b[1] - sdot v21.4s,v5.16b,v9.4b[1] - sdot v22.4s,v5.16b,v10.4b[1] - sdot v23.4s,v5.16b,v11.4b[1] - sdot v24.4s,v6.16b,v8.4b[1] - sdot v25.4s,v6.16b,v9.4b[1] - sdot v26.4s,v6.16b,v10.4b[1] - sdot v27.4s,v6.16b,v11.4b[1] - sdot v28.4s,v7.16b,v8.4b[1] - sdot v29.4s,v7.16b,v9.4b[1] - sdot v30.4s,v7.16b,v10.4b[1] - sdot v31.4s,v7.16b,v11.4b[1] - - TST x7,15 - B.NE InChannels8 // 4 ~ 12 InputChannels - - subs x9,x9,8 // KernelSize-=1 - b.hi KernelSizeLoop - -Requantize - tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ldr w13,[x8,#ConvSymPostProcessParams_ZeroPoint] - beq BroadcastScaleValue - ldp q0,q1,[x19],32 // load scale vector - ldp q2,q3,[x19],32 - b AccumulatorsToFloat - -BroadcastScaleValue - ld1r {v0.4s},[x19] // load scale Value - mov v1.16b, v0.16b - mov v2.16b, v0.16b - mov v3.16b, v0.16b - -AccumulatorsToFloat - scvtf v16.4s,v16.4s // convert to float - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - scvtf v24.4s,v24.4s - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf 
v31.4s,v31.4s - fmul v16.4s,v16.4s,v0.4s // multiply by scale - fmul v17.4s,v17.4s,v0.4s - fmul v18.4s,v18.4s,v0.4s - fmul v19.4s,v19.4s,v0.4s - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v1.4s - fmul v22.4s,v22.4s,v1.4s - fmul v23.4s,v23.4s,v1.4s - fmul v24.4s,v24.4s,v2.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v2.4s - fmul v27.4s,v27.4s,v2.4s - fmul v28.4s,v28.4s,v3.4s - fmul v29.4s,v29.4s,v3.4s - fmul v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v3.4s - fcvtns v16.4s,v16.4s // convert to int - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - fcvtns v24.4s,v24.4s - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - - sqxtn v16.4h,v16.4s - sqxtn v17.4h,v17.4s - sqxtn v18.4h,v18.4s - sqxtn v19.4h,v19.4s - sqxtn v24.4h,v24.4s - sqxtn v25.4h,v25.4s - sqxtn v26.4h,v26.4s - sqxtn v27.4h,v27.4s - dup v4.8h,w13 // zero point - sqxtn2 v16.8h,v20.4s - sqxtn2 v17.8h,v21.4s - sqxtn2 v18.8h,v22.4s - sqxtn2 v19.8h,v23.4s - sqxtn2 v24.8h,v28.4s - sqxtn2 v25.8h,v29.4s - sqxtn2 v26.8h,v30.4s - sqxtn2 v27.8h,v31.4s - sqadd v16.8h,v16.8h,v4.8h - sqadd v17.8h,v17.8h,v4.8h - sqadd v18.8h,v18.8h,v4.8h - sqadd v19.8h,v19.8h,v4.8h - sqadd v24.8h,v24.8h,v4.8h - sqadd v25.8h,v25.8h,v4.8h - sqadd v26.8h,v26.8h,v4.8h - sqadd v27.8h,v27.8h,v4.8h - sqxtun v0.8b,v16.8h - sqxtun v1.8b,v17.8h - sqxtun v2.8b,v18.8h - sqxtun v3.8b,v19.8h - sqxtun2 v0.16b,v24.8h - sqxtun2 v1.16b,v25.8h - subs x6,x6,16 // processed 16 output channels - sqxtun2 v2.16b,v26.8h - sqxtun2 v3.16b,v27.8h - b.lo PartialStore - - st1 {v3.16b},[x5],16 // Store full 4 x 16 - st1 {v2.16b},[x17],16 - sub x0,x0,x3 // Restore pointer to A: a -= ks - st1 {v1.16b},[x16],16 - st1 {v0.16b},[x2],16 - b.hi OutputChannelLoop - -ExitKernel - EPILOG_RESTORE_REG_PAIR x19,x20,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! 
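
The U8 dot-product kernel above XORs the activations with 0x80 (movi v12.16b,128 followed by the eor instructions) so that unsigned uint8 inputs can be fed to the signed sdot instruction: flipping the top bit and reinterpreting the byte as int8 is the same as subtracting 128, a fixed offset that can be compensated outside the inner loop (not shown in this diff). The check below is illustrative only and assumes the usual two's-complement narrowing conversion.

#include <cassert>
#include <cstdint>

// Verify the bit-flip identity behind the eor-with-0x80 trick:
// reinterpreting (u XOR 0x80) as int8 equals u - 128 for every uint8 u.
int main() {
    for (int u = 0; u < 256; ++u) {
        int8_t flipped = static_cast<int8_t>(static_cast<uint8_t>(u ^ 0x80));
        assert(static_cast<int>(flipped) == u - 128);
    }
    return 0;
}
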
- EPILOG_RETURN - -InChannels8 - tbz x7,3,InChannels4 - ldr d0,[x12],8 - ldr q4,[x1],16 - ldr d1,[x13],8 - ldr d2,[x14],8 - ldr d3,[x15],8 - eor v0.8b,v0.8b,v12.8b - ldr q5,[x1],16 - eor v1.8b,v1.8b,v12.8b - sdot v16.4s,v4.16b,v0.4b[0] - sdot v17.4s,v4.16b,v1.4b[0] - eor v2.8b,v2.8b,v12.8b - ldp q6,q7,[x1],32 - eor v3.8b,v3.8b,v12.8b - sdot v18.4s,v4.16b,v2.4b[0] - sdot v19.4s,v4.16b,v3.4b[0] - sdot v20.4s,v5.16b,v0.4b[0] - sdot v21.4s,v5.16b,v1.4b[0] - sdot v22.4s,v5.16b,v2.4b[0] - sdot v23.4s,v5.16b,v3.4b[0] - sdot v24.4s,v6.16b,v0.4b[0] - sdot v25.4s,v6.16b,v1.4b[0] - ldp q4,q5,[x1],32 - sdot v26.4s,v6.16b,v2.4b[0] - sdot v27.4s,v6.16b,v3.4b[0] - sdot v28.4s,v7.16b,v0.4b[0] - sdot v29.4s,v7.16b,v1.4b[0] - sdot v30.4s,v7.16b,v2.4b[0] - sdot v31.4s,v7.16b,v3.4b[0] - sdot v16.4s,v4.16b,v0.4b[1] - sdot v17.4s,v4.16b,v1.4b[1] - ldp q6,q7,[x1],32 - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - sdot v20.4s,v5.16b,v0.4b[1] - sdot v21.4s,v5.16b,v1.4b[1] - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - sdot v24.4s,v6.16b,v0.4b[1] - sdot v25.4s,v6.16b,v1.4b[1] - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - sdot v28.4s,v7.16b,v0.4b[1] - sdot v29.4s,v7.16b,v1.4b[1] - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - tbz x7,2,SkipInCh4 - -InChannels4 - ldr s0,[x12],4 - ldr q4,[x1],16 - ldr s1,[x13],4 - ldr s2,[x14],4 - ldr s3,[x15],4 - eor v0.8b,v0.8b,v12.8b - ldr q5,[x1],16 - eor v1.8b,v1.8b,v12.8b - sdot v16.4s,v4.16b,v0.4b[0] - sdot v17.4s,v4.16b,v1.4b[0] - eor v2.8b,v2.8b,v12.8b - ldp q6,q7,[x1],32 - eor v3.8b,v3.8b,v12.8b - sdot v18.4s,v4.16b,v2.4b[0] - sdot v19.4s,v4.16b,v3.4b[0] - sdot v20.4s,v5.16b,v0.4b[0] - sdot v21.4s,v5.16b,v1.4b[0] - sdot v22.4s,v5.16b,v2.4b[0] - sdot v23.4s,v5.16b,v3.4b[0] - sdot v24.4s,v6.16b,v0.4b[0] - sdot v25.4s,v6.16b,v1.4b[0] - sdot v26.4s,v6.16b,v2.4b[0] - sdot v27.4s,v6.16b,v3.4b[0] - sdot v28.4s,v7.16b,v0.4b[0] - sdot v29.4s,v7.16b,v1.4b[0] - sdot v30.4s,v7.16b,v2.4b[0] - sdot v31.4s,v7.16b,v3.4b[0] - -SkipInCh4 - subs x9,x9,8 // ks -= 1 - b.hi KernelSizeLoop - b Requantize - -PartialStore - tbz x6,3,LT8Store - str d3,[x5],8 // no less than 8 channels - str d2,[x17],8 - dup d3,v3.d[1] - dup d2,v2.d[1] - str d1,[x16],8 - str d0,[x2],8 - dup d1,v1.d[1] - dup d0,v0.d[1] -LT8Store - tbz x6,2,LT4Store - str s3,[x5],4 - str s2,[x17],4 - dup s3,v3.s[1] - dup s2,v2.s[1] - str s1,[x16],4 - str s0,[x2],4 - dup s1,v1.s[1] - dup s0,v0.s[1] -LT4Store - tbz x6,1, LT2Store - str h3,[x5],2 - str h2,[x17],2 - dup h3,v3.h[1] - dup h2,v2.h[1] - str h1,[x16],2 - str h0,[x2],2 - dup h1,v1.h[1] - dup h0,v0.h[1] -LT2Store - tbz x6,0,ExitKernel - str b3,[x5] - str b2,[x17] - str b1,[x16] - str b0,[x2] - b ExitKernel - - NESTED_END MlasConvSymU8KernelDot - - END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm deleted file mode 100644 index a8a11fe6209d1..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm +++ /dev/null @@ -1,436 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymU8KernelNeon.asm - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - ---*/ - -#include "kxarm64.h" - -#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - -// -// Stack frame layout for the symmetric convolution kernel. 
-// d8-d15, x19-x30 need to be preserved if used -// -#define ConvSymFrame_SavedNeonRegisters (8 * 8) -#define ConvSymFrame_SavedRegisters ConvSymFrame_SavedNeonRegisters -#define ConvSymFrame_PostProcessParams 0 + ConvSymFrame_SavedRegisters -#define ConvSymFrame_KernelFlags 8 + ConvSymFrame_SavedRegisters - -#define ConvSymPostProcessParams_Bias 0 -#define ConvSymPostProcessParams_Scale 8 -#define ConvSymPostProcessParams_Min 16 -#define ConvSymPostProcessParams_Max 20 -#define ConvSymPostProcessParams_ZeroPoint 24 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (x4) - Supplies the number of input channels. - - This implementation requires the count to be a multiple of 8. - - OutputChannels (x5) - Supplies the number of output channels. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 8. - - OutputCount (x7) - Supplies the number of output elements this iteration produces. - - This implementation requires the count to be 1 or 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - NESTED_ENTRY MlasConvSymU8KernelNeon - - PROLOG_SAVE_REG_PAIR d8,d9,#-64! - PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] - PROLOG_NOP ldrb w10,[sp,#ConvSymFrame_KernelFlags] - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - mov x9,x3 // save kernel size - ldr x11,[x8,#ConvSymPostProcessParams_Bias] - mov x16,x4 // save input channels - ldr x12,[x8,#ConvSymPostProcessParams_Scale] - cmp x7,2 // if OutputCount < 2 - add x5,x2,x5 // c1 = c0 + ldc - add x4,x4,7 // kc = (kc + 7) & ~7 - csel x5,x2,x5,lo // if OutputCount < 2 c1 = c0 - bic x4,x4,7 - ldp s16,s18,[x11],8 // init accumulators with bias - ldp s20,s22,[x11],8 - ldp s24,s26,[x11],8 - ldp s28,s30,[x11],8 - mov v17.16b,v16.16b - mov v19.16b,v18.16b - mov v21.16b,v20.16b - mov v23.16b,v22.16b - mov v25.16b,v24.16b - mov v27.16b,v26.16b - mov v29.16b,v28.16b - mov v31.16b,v30.16b - -// Nested loops, inner loop: input channel; outter loop: kernel size -// Each inner iteration processes 8 input channels, 2 output pixels, 8 output channels. -// -// B 8x8 -// ------------------------------------------------------------------ -// |v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] | -// | ... ... ... ... ... ... ... ... 
| -// |v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] | -// A 2x8 ------------------------------------------------------------------ -// ------------------ ------------------------------------------------------------------ -// x13-> |v0.b[0]..v0.b[7]| |v16.4s v18.4s v20.4s v22.4s v24.4s v26.4s v28.4s v30.4s | -// x15-> |v1.b[0]..v1.b[7]| |v17.4s v19.4s v21.4s v23.4s v25.4s v27.4s v29.4s v31.4s | -// ------------------ ------------------------------------------------------------------ -// When Input Channels greater than 16, unroll: -// A registers v6 v7, -// B registers v8 v9 -// - -KernelSizeLoop - - // Load next 2 A pointers - tst w10,#MLAS_CONV_SYM_FLAG_INPUT_DIRECT - ldr d4,[x1] - ldr d5,[x1,8] - beq InputIndirection - -InputDirect - mov x13,x0 // x13 -> A0 - add x15,x0,x16 // x15 -> A1 = A0 + input channels - b BlockLoopPrologue - -InputIndirection - cmp x7,2 // test if OutputCount < 2 - ldr x13,[x0] // x13 -> A0 - blo SkipLoadA1 - ldr x15,[x0,x3,lsl#3] // x15 -> A1 -SkipLoadA1 - -BlockLoopPrologue - cmp x7,2 // test if OutputCount < 2 - add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop - csel x15,x13,x15,lo // if OutputCount < 2 x15 -> A0 - subs x14,x4,16 // input channel - 16 - movi v12.8b,128 - blo InputChannel8 // less than 16 deep, no unroll - - ldr d0,[x13],8 - ldr d1,[x15],8 - ldr d8,[x1,64] - ldr d9,[x1,72] - ldr d6,[x13],8 - subs x14,x14,16 // input channel - 16 - ldr d7,[x15],8 - blo BlockLoopEpilogue // need 32 input channel for full unrolled loop - -Blockloop - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - eor v6.8b,v6.8b,v12.8b - eor v7.8b,v7.8b,v12.8b - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d5,[x1,40] - sadalp v19.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,56] - sadalp v23.4s, v15.8h - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - add x1,x1,128 - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1] // Read B - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - ldr d0,[x13],8 // Read A0 - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - ldr d1,[x15],8 // Read A1 - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - ldr d5,[x1,8] // Read B - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,64] // Read B - smlal v14.8h,v9.8b,v6.8b - ldr d6,[x13],8 // Read A0 - smlal v15.8h,v9.8b,v7.8b - ldr d7,[x15],8 // Read A1 - sadalp v28.4s,v12.8h - ldr d9,[x1,72] // Read B - sadalp v29.4s,v13.8h - subs x14,x14,16 - sadalp v30.4s,v14.8h - movi v12.8b,128 - sadalp v31.4s,v15.8h - b.hs Blockloop - -BlockLoopEpilogue // remaining 16 input channels - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,16] - smull 
v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,24] - eor v6.8b,v6.8b,v12.8b - eor v7.8b,v7.8b,v12.8b - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,80] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,88] - smull v12.8h,v4.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - ldr d4,[x1,32] - sadalp v17.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v19.4s,v11.8h - ldr d5,[x1,40] - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - ldr d8,[x1,96] - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - ldr d9,[x1,104] - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,48] - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - ldr d5,[x1,56] - smlal v2.8h,v8.8b,v6.8b - smlal v3.8h,v8.8b,v7.8b - ldr d8,[x1,112] - smlal v10.8h,v9.8b,v6.8b - smlal v11.8h,v9.8b,v7.8b - ldr d9,[x1,120] - smull v12.8h,v4.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v4.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v5.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v5.8b,v1.8b - sadalp v27.4s,v11.8h - smlal v12.8h,v8.8b,v6.8b - smlal v13.8h,v8.8b,v7.8b - smlal v14.8h,v9.8b,v6.8b - smlal v15.8h,v9.8b,v7.8b - add x1,x1,128 - - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - movi v12.8b,128 - tbnz x14,3,InputChannel8 - - subs x9,x9,1 - b.hi KernelSizeLoop - -Requantize - ldr w11,[x8,#ConvSymPostProcessParams_ZeroPoint] - tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - beq BroadcastScaleValue - ld1 {v4.4s,v5.4s},[x12] // load scale vector - b AccumulatorsToFloat - -BroadcastScaleValue - ld1r {v4.4s},[x12] // load scale Value - mov v5.16b, v4.16b - -AccumulatorsToFloat - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - addp v17.4s,v17.4s,v19.4s - addp v21.4s,v21.4s,v23.4s - addp v25.4s,v25.4s,v27.4s - addp v29.4s,v29.4s,v31.4s - addp v0.4s,v16.4s,v20.4s - addp v1.4s,v24.4s,v28.4s - addp v2.4s,v17.4s,v21.4s - addp v3.4s,v25.4s,v29.4s - scvtf v0.4s,v0.4s // convert to float - scvtf v1.4s,v1.4s - scvtf v2.4s,v2.4s - scvtf v3.4s,v3.4s - fmul v0.4s,v0.4s,v4.4s // multiply by scale - fmul v1.4s,v1.4s,v5.4s - fmul v2.4s,v2.4s,v4.4s - fmul v3.4s,v3.4s,v5.4s - fcvtns v0.4s,v0.4s // convert to int - fcvtns v1.4s,v1.4s - dup v9.8h,w11 - fcvtns v2.4s,v2.4s - fcvtns v3.4s,v3.4s - sqxtn v0.4h,v0.4s - sqxtn2 v0.8h,v1.4s - sqxtn v2.4h,v2.4s - sqxtn2 v2.8h,v3.4s - sqadd v0.8h,v0.8h,v9.8h - sqadd v2.8h,v2.8h,v9.8h - sqxtun v0.8b,v0.8h // shorten to int8 - sqxtun2 v0.16b,v2.8h - st1 {v0.d}[1],[x5] // full 2x8 store to c - st1 {v0.8b},[x2] - -ExitKernel - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! 
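
Both NEON kernels above consume either a direct input pointer or an indirection buffer, as the routine comments describe: KernelSize pointers per output pixel, each addressing an InputChannels-long vector (real input data or a padding vector), repeated OutputCount times. A reference loop over that layout might look like the sketch below; the filter layout shown is a simplifying assumption, not the packed layout the assembly kernels actually expect.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference accumulation over an indirection buffer: `indirection` holds
// output_count * kernel_size pointers, each to an input_channels-long vector.
std::vector<int32_t> ConvSymAccumulate(const int8_t* const* indirection,
                                       const int8_t* filter,  // [kernel_size][input_channels][output_channels] (assumed)
                                       size_t kernel_size,
                                       size_t input_channels,
                                       size_t output_channels,
                                       size_t output_count) {
    std::vector<int32_t> acc(output_count * output_channels, 0);
    for (size_t p = 0; p < output_count; ++p) {
        for (size_t k = 0; k < kernel_size; ++k) {
            const int8_t* a = indirection[p * kernel_size + k];  // input row or padding vector
            for (size_t c = 0; c < input_channels; ++c) {
                for (size_t oc = 0; oc < output_channels; ++oc) {
                    acc[p * output_channels + oc] +=
                        int32_t{a[c]} * int32_t{filter[(k * input_channels + c) * output_channels + oc]};
                }
            }
        }
    }
    return acc;  // bias and requantization are applied afterwards
}
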
- EPILOG_RETURN - -InputChannel8 - ldr d0,[x13] - ldr d1,[x15] - ldr d4,[x1] - ldr d5,[x1,8] - ldr d6,[x1,16] - ldr d7,[x1,24] - eor v0.8b,v0.8b,v12.8b - eor v1.8b,v1.8b,v12.8b - smull v2.8h,v4.8b,v0.8b - smull v3.8h,v4.8b,v1.8b - ldr d4,[x1,32] - smull v10.8h,v5.8b,v0.8b - smull v11.8h,v5.8b,v1.8b - ldr d5,[x1,40] - smull v12.8h,v6.8b,v0.8b - sadalp v16.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - ldr d6,[x1,48] - sadalp v17.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v18.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - ldr d7,[x1,56] - sadalp v19.4s,v11.8h - smull v2.8h,v4.8b,v0.8b - sadalp v20.4s,v12.8h - smull v3.8h,v4.8b,v1.8b - sadalp v21.4s,v13.8h - smull v10.8h,v5.8b,v0.8b - sadalp v22.4s,v14.8h - smull v11.8h,v5.8b,v1.8b - sadalp v23.4s,v15.8h - smull v12.8h,v6.8b,v0.8b - sadalp v24.4s,v2.8h - smull v13.8h,v6.8b,v1.8b - sadalp v25.4s,v3.8h - smull v14.8h,v7.8b,v0.8b - sadalp v26.4s,v10.8h - smull v15.8h,v7.8b,v1.8b - sadalp v27.4s,v11.8h - add x1,x1,64 - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - - // ks loop - subs x9,x9,1 - b.hi KernelSizeLoop - b Requantize - - NESTED_END MlasConvSymU8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvKernelSize9Neon.asm b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvKernelSize9Neon.asm deleted file mode 100644 index 335e11be1e1d9..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvKernelSize9Neon.asm +++ /dev/null @@ -1,662 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DepthwiseQConvKernelSize9Neon.asm - -Abstract: - - This module implements the routine for the depthwise convolution - operation with symmetrically quantized integer values for kernel - size 9. ie, 3x3, 1x9, 9x1 - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the depthwise conv kernel. -// d8-d15, x19-x30 need to be preserved if used -// - -#define ConvSymDepthwisePostProcessParams_Bias 0 -#define ConvSymDepthwisePostProcessParams_Scale 8 -#define ConvSymDepthwisePostProcessParams_ZeroPoint 24 - -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX 1 - -#define MlasConvSymDepthwiseKernelSize9_backup_x19_x20 0 -#define MlasConvSymDepthwiseKernelSize9_backup_x21_x22 16 -#define MlasConvSymDepthwiseKernelSize9_backup_x23_x24 32 -#define MlasConvSymDepthwiseKernelSize9_backup_x25_x26 48 -#define MlasConvSymDepthwiseKernelSize9_backup_x27_x28 64 -#define MlasConvSymDepthwiseKernelSize9_backup_d8_d9 80 -#define MlasConvSymDepthwiseKernelSize9_backup_d10_d11 96 -#define MlasConvSymDepthwiseKernelSize9_backup_d12_d13 112 -#define MlasConvSymDepthwiseKernelSize9_backup_d14_d15 128 -#define MlasConvSymDepthwiseKernelSize9_SavedRegisters 144 -#define MlasConvSymDepthwiseKernelSize9_SavedRegisters_Neg -144 - - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise quantized convolution - on kernel size 9 for U8S8. - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Channels (x2) - Supplies the number of input and output channels. - - Output (x3) - Supplies the address of the output buffer. - - OutputCount (x4)- Supplies the number of image pixels. - - PostProcessParams (x5) - Supplies the address of the post process parameter block. - - KernelFlags (x6) - Supplies additional flags controlling the operation. - -Return Value: - - None. 
- ---*/ - - LEAF_ENTRY MlasConvSymDepthwiseKernelSize9Arm64U8S8 - - PROLOG_SAVE_REG_PAIR x19, x20, #MlasConvSymDepthwiseKernelSize9_SavedRegisters_Neg ! - PROLOG_SAVE_REG_PAIR x21, x22, #MlasConvSymDepthwiseKernelSize9_backup_x21_x22 - PROLOG_SAVE_REG_PAIR x23, x24, #MlasConvSymDepthwiseKernelSize9_backup_x23_x24 - PROLOG_SAVE_REG_PAIR x25, x26, #MlasConvSymDepthwiseKernelSize9_backup_x25_x26 - PROLOG_SAVE_REG_PAIR x27, x28, #MlasConvSymDepthwiseKernelSize9_backup_x27_x28 - PROLOG_SAVE_REG_PAIR d8, d9, #MlasConvSymDepthwiseKernelSize9_backup_d8_d9 - PROLOG_SAVE_REG_PAIR d10, d11, #MlasConvSymDepthwiseKernelSize9_backup_d10_d11 - PROLOG_SAVE_REG_PAIR d12, d13, #MlasConvSymDepthwiseKernelSize9_backup_d12_d13 - PROLOG_SAVE_REG_PAIR d14, d15, #MlasConvSymDepthwiseKernelSize9_backup_d14_d15 - - ldr x9, [x5, #ConvSymDepthwisePostProcessParams_Bias] - ldr x8, [x5, #ConvSymDepthwisePostProcessParams_Scale] - add x5, x5, #ConvSymDepthwisePostProcessParams_ZeroPoint - ins v12.d[0], x1 // Filter - ins v13.d[0], x9 // Bias - ins v13.d[1], x8 // Scale - ld1r {v0.8h}, [x5] // v0.8h <--- vector for output zero point - movi v5.16b, #0x80 - - tbnz x6, #MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, MlasConvSymDepthwiseKernelSize9_SkipPerTensorScaleInit - ld1r {v1.4s}, [x8] // load and dup scale value - mov v2.16b, v1.16b - mov v3.16b, v1.16b - mov v4.16b, v1.16b - -MlasConvSymDepthwiseKernelSize9_SkipPerTensorScaleInit - - add x9, x3, x2 // x9 <---- Ouput1, x3 is Ouput0 - cbz x4, MlasConvSymDepthwiseKernelSize9_Exit - -MlasConvSymDepthwiseKernelSize9_OutputLoop - ldp x20, x21, [x0], #72 // input ptrs for Output0 - ldp x22, x23, [x0, #-56] - sub x4, x4, #1 - ldp x24, x25, [x0, #-40] - ldp x26, x27, [x0, #-24] - ldur x28, [x0, #-8] - - cbz x4, MlasConvSymDepthwiseKernelSize9_Dup_Inputs - ldp x10, x11, [x0], #72 // input ptrs for Output0 - ldp x12, x13, [x0, #-56] - sub x4, x4, #1 - ldp x14, x15, [x0, #-40] - ldp x16, x17, [x0, #-24] - ldur x19, [x0, #-8] - b MlasConvSymDepthwiseKernelSize9_Loaded_Input - -MlasConvSymDepthwiseKernelSize9_Dup_Inputs - mov x9, x3 // Output1 <-- Output0 - mov x10, x20 - mov x11, x21 - mov x12, x22 - mov x13, x23 - mov x14, x24 - mov x15, x25 - mov x16, x26 - mov x17, x27 - mov x19, x28 - -MlasConvSymDepthwiseKernelSize9_Loaded_Input - - eor x8, x8, x8 // Processed channels - umov x1, v12.d[0] // filter - umov x5, v13.d[0] // bias - umov x7, v13.d[1] // scale - - cmp x8, x2 // Save one register by not using count down to zero here - bhs MlasConvSymDepthwiseKernelSize9_Finish_Channels16_Loop - -MlasConvSymDepthwiseKernelSize9_Channels16_Loop - ld1 {v10.16b}, [x1], x2 // vk0 - ldr q16, [x20, x8] // out0 vi0 - ldr q17, [x10, x8] // out1 vi0 - ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [x5], #64 // bias vacc 0-15 for outs - ld1 {v11.16b}, [x1], x2 // vk1 - ldr q18, [x21, x8] // out0 vi1 - ldr q19, [x11, x8] // out1 vi1 - - eor v16.16b, v16.16b, v5.16b // -128 to signed int8 - eor v17.16b, v17.16b, v5.16b - ld1 {v14.16b}, [x1], x2 // vk2 - eor v18.16b, v18.16b, v5.16b - eor v19.16b, v19.16b, v5.16b - - ldr q20, [x22, x8] // out0 vi2 - smull v24.8h, v10.8b, v16.8b - smull2 v25.8h, v10.16b, v16.16b - ldr q21, [x12, x8] // out1 vi2 - smull v26.8h, v10.8b, v17.8b - ld1 {v15.16b}, [x1], x2 // vk3 - smull2 v27.8h, v10.16b, v17.16b - ldr q22, [x23, x8] // out0 vi3 - smull v28.8h, v11.8b, v18.8b - smull2 v29.8h, v11.16b, v18.16b - ldr q23, [x13, x8] // out1 vi3 - smull v30.8h, v11.8b, v19.8b - smull2 v31.8h, v11.16b, v19.16b - - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - eor 
v22.16b, v22.16b, v5.16b - eor v23.16b, v23.16b, v5.16b - ld1 {v10.16b}, [x1], x2 // vk4 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - ld1 {v11.16b}, [x1], x2 // vk5 - - saddw v16.4s, v6.4s, v24.4h // dup acc for out1 - saddw2 v17.4s, v7.4s, v24.8h - saddw v18.4s, v8.4s, v25.4h - saddw2 v19.4s, v9.4s, v25.8h - - ldr q20, [x24, x8] // out0 vi4 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x14, x8] // out1 vi4 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - ldr q22, [x25, x8] // out0 vi5 - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - ldr q23, [x15, x8] // out1 vi5 - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - ld1 {v14.16b}, [x1], x2 // vk6 - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - eor v22.16b, v22.16b, v5.16b - eor v23.16b, v23.16b, v5.16b - ld1 {v15.16b}, [x1], x2 // vk7 - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - smull v28.8h, v11.8b, v22.8b - smull2 v29.8h, v11.16b, v22.16b - smull v30.8h, v11.8b, v23.8b - smull2 v31.8h, v11.16b, v23.16b - - ldr q20, [x26, x8] // out0 vi6 - ldr q21, [x16, x8] // out1 vi6 - ldr q22, [x27, x8] // out0 vi7 - ldr q23, [x17, x8] // out1 vi7 - tbz x6, #MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, DonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9 - ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x7], #64 // scales 0-15 for outs - -DonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9 - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - eor v22.16b, v22.16b, v5.16b - eor v23.16b, v23.16b, v5.16b - ldr q10, [x1] // vk8 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - ldr q20, [x28, x8] // out0 vi8 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x19, x8] // out1 vi8 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - eor v20.16b, v20.16b, v5.16b - eor v21.16b, v21.16b, v5.16b - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - scvtf v16.4s, v16.4s // Requantize - scvtf v17.4s, v17.4s - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v16.4s, v16.4s, v1.4s - fmul v17.4s, v17.4s, v2.4s - fmul v18.4s, v18.4s, v3.4s - 
fmul v19.4s, v19.4s, v4.4s - fmul v6.4s, v6.4s, v1.4s - fmul v7.4s, v7.4s, v2.4s - fmul v8.4s, v8.4s, v3.4s - fmul v9.4s, v9.4s, v4.4s - - fcvtns v16.4s, v16.4s - fcvtns v17.4s, v17.4s - fcvtns v18.4s, v18.4s - fcvtns v19.4s, v19.4s - fcvtns v6.4s, v6.4s - fcvtns v7.4s, v7.4s - fcvtns v8.4s, v8.4s - fcvtns v9.4s, v9.4s - - sqxtn v16.4h, v16.4s // +zp, narrow and combine - sqxtn v18.4h, v18.4s - sqxtn v6.4h, v6.4s - sqxtn v8.4h, v8.4s - sqxtn2 v16.8h, v17.4s - sqxtn2 v18.8h, v19.4s - sqxtn2 v6.8h, v7.4s - sqxtn2 v8.8h, v9.4s - sqadd v16.8h, v16.8h, v0.8h - sqadd v18.8h, v18.8h, v0.8h - sqadd v6.8h, v6.8h, v0.8h - sqadd v8.8h, v8.8h, v0.8h - sqxtun v16.8b, v16.8h - sqxtun2 v16.16b, v18.8h - sqxtun v6.8b, v6.8h - sqxtun2 v6.16b, v8.8h - - str q16, [x3, x8] - str q6, [x9, x8] - add x8, x8, #16 - umov x1, v12.d[0] // filter's beginning - cmp x8, x2 - add x1, x1, x8 - blo MlasConvSymDepthwiseKernelSize9_Channels16_Loop - -MlasConvSymDepthwiseKernelSize9_Finish_Channels16_Loop - add x3, x3, x2, LSL #1 - add x9, x9, x2, LSL #1 - cbnz x4, MlasConvSymDepthwiseKernelSize9_OutputLoop - -MlasConvSymDepthwiseKernelSize9_Exit - EPILOG_RESTORE_REG_PAIR d14, d15, #MlasConvSymDepthwiseKernelSize9_backup_d14_d15 - EPILOG_RESTORE_REG_PAIR d12, d13, #MlasConvSymDepthwiseKernelSize9_backup_d12_d13 - EPILOG_RESTORE_REG_PAIR d10, d11, #MlasConvSymDepthwiseKernelSize9_backup_d10_d11 - EPILOG_RESTORE_REG_PAIR d8, d9, #MlasConvSymDepthwiseKernelSize9_backup_d8_d9 - EPILOG_RESTORE_REG_PAIR x27, x28, #MlasConvSymDepthwiseKernelSize9_backup_x27_x28 - EPILOG_RESTORE_REG_PAIR x25, x26, #MlasConvSymDepthwiseKernelSize9_backup_x25_x26 - EPILOG_RESTORE_REG_PAIR x23, x24, #MlasConvSymDepthwiseKernelSize9_backup_x23_x24 - EPILOG_RESTORE_REG_PAIR x21, x22, #MlasConvSymDepthwiseKernelSize9_backup_x21_x22 - EPILOG_RESTORE_REG_PAIR x19, x20, #MlasConvSymDepthwiseKernelSize9_SavedRegisters ! - EPILOG_RETURN - - LEAF_END MlasConvSymDepthwiseKernelSize9Arm64U8S8 - - - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise quantized convolution - on kernel size 9 for S8S8. - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Channels (x2) - Supplies the number of input and output channels. - - Output (x3) - Supplies the address of the output buffer. - - OutputCount (x4)- Supplies the number of image pixels. - - PostProcessParams (x5) - Supplies the address of the post process parameter block. - - KernelFlags (x6) - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - LEAF_ENTRY MlasConvSymDepthwiseKernelSize9Arm64S8S8 - - PROLOG_SAVE_REG_PAIR x19, x20, #MlasConvSymDepthwiseKernelSize9_SavedRegisters_Neg ! 
- PROLOG_SAVE_REG_PAIR x21, x22, #MlasConvSymDepthwiseKernelSize9_backup_x21_x22 - PROLOG_SAVE_REG_PAIR x23, x24, #MlasConvSymDepthwiseKernelSize9_backup_x23_x24 - PROLOG_SAVE_REG_PAIR x25, x26, #MlasConvSymDepthwiseKernelSize9_backup_x25_x26 - PROLOG_SAVE_REG_PAIR x27, x28, #MlasConvSymDepthwiseKernelSize9_backup_x27_x28 - PROLOG_SAVE_REG_PAIR d8, d9, #MlasConvSymDepthwiseKernelSize9_backup_d8_d9 - PROLOG_SAVE_REG_PAIR d10, d11, #MlasConvSymDepthwiseKernelSize9_backup_d10_d11 - PROLOG_SAVE_REG_PAIR d12, d13, #MlasConvSymDepthwiseKernelSize9_backup_d12_d13 - PROLOG_SAVE_REG_PAIR d14, d15, #MlasConvSymDepthwiseKernelSize9_backup_d14_d15 - - ldr x9, [x5, #ConvSymDepthwisePostProcessParams_Bias] - ldr x8, [x5, #ConvSymDepthwisePostProcessParams_Scale] - add x5, x5, #ConvSymDepthwisePostProcessParams_ZeroPoint - ins v12.d[0], x1 // Filter - ins v13.d[0], x9 // Bias - ins v13.d[1], x8 // Scale - ld1r {v0.8h}, [x5] // v0.8h <--- vector for output zero point - - tbnz x6, #MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, MlasConvSymDepthwiseKernelSize9S8S8_SkipPerTensorScaleInit - ld1r {v1.4s}, [x8] // load and dup scale value - mov v2.16b, v1.16b - mov v3.16b, v1.16b - mov v4.16b, v1.16b - -MlasConvSymDepthwiseKernelSize9S8S8_SkipPerTensorScaleInit - - add x9, x3, x2 // x9 <---- Ouput1, x3 is Ouput0 - cbz x4, MlasConvSymDepthwiseKernelSize9S8S8_Exit - -MlasConvSymDepthwiseKernelSize9S8S8_OutputLoop - ldp x20, x21, [x0], #72 // input ptrs for Output0 - ldp x22, x23, [x0, #-56] - sub x4, x4, #1 - ldp x24, x25, [x0, #-40] - ldp x26, x27, [x0, #-24] - ldur x28, [x0, #-8] - - cbz x4, MlasConvSymDepthwiseKernelSize9S8S8_Dup_Inputs - ldp x10, x11, [x0], #72 // input ptrs for Output0 - ldp x12, x13, [x0, #-56] - sub x4, x4, #1 - ldp x14, x15, [x0, #-40] - ldp x16, x17, [x0, #-24] - ldur x19, [x0, #-8] - b MlasConvSymDepthwiseKernelSize9S8S8_Loaded_Input - -MlasConvSymDepthwiseKernelSize9S8S8_Dup_Inputs - mov x9, x3 // Output1 <-- Output0 - mov x10, x20 - mov x11, x21 - mov x12, x22 - mov x13, x23 - mov x14, x24 - mov x15, x25 - mov x16, x26 - mov x17, x27 - mov x19, x28 - -MlasConvSymDepthwiseKernelSize9S8S8_Loaded_Input - - eor x8, x8, x8 // Processed channels - umov x1, v12.d[0] // filter - umov x5, v13.d[0] // bias - umov x7, v13.d[1] // scale - - cmp x8, x2 // Save one register by not using count down to zero here - bhs MlasConvSymDepthwiseKernelSize9S8S8_Finish_Channels16_Loop - -MlasConvSymDepthwiseKernelSize9S8S8_Channels16_Loop - ld1 {v10.16b}, [x1], x2 // vk0 - ldr q16, [x20, x8] // out0 vi0 - ldr q17, [x10, x8] // out1 vi0 - ld1 {v6.4s, v7.4s, v8.4s, v9.4s}, [x5], #64 // bias vacc 0-15 for outs - ld1 {v11.16b}, [x1], x2 // vk1 - ldr q18, [x21, x8] // out0 vi1 - ldr q19, [x11, x8] // out1 vi1 - - ld1 {v14.16b}, [x1], x2 // vk2 - - ldr q20, [x22, x8] // out0 vi2 - smull v24.8h, v10.8b, v16.8b - smull2 v25.8h, v10.16b, v16.16b - ldr q21, [x12, x8] // out1 vi2 - smull v26.8h, v10.8b, v17.8b - ld1 {v15.16b}, [x1], x2 // vk3 - smull2 v27.8h, v10.16b, v17.16b - ldr q22, [x23, x8] // out0 vi3 - smull v28.8h, v11.8b, v18.8b - smull2 v29.8h, v11.16b, v18.16b - ldr q23, [x13, x8] // out1 vi3 - smull v30.8h, v11.8b, v19.8b - smull2 v31.8h, v11.16b, v19.16b - - ld1 {v10.16b}, [x1], x2 // vk4 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - ld1 {v11.16b}, [x1], x2 // vk5 - - saddw v16.4s, v6.4s, 
v24.4h // dup acc for out1 - saddw2 v17.4s, v7.4s, v24.8h - saddw v18.4s, v8.4s, v25.4h - saddw2 v19.4s, v9.4s, v25.8h - - ldr q20, [x24, x8] // out0 vi4 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x14, x8] // out1 vi4 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - ldr q22, [x25, x8] // out0 vi5 - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - ldr q23, [x15, x8] // out1 vi5 - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - ld1 {v14.16b}, [x1], x2 // vk6 - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - ld1 {v15.16b}, [x1], x2 // vk7 - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - smull v28.8h, v11.8b, v22.8b - smull2 v29.8h, v11.16b, v22.16b - smull v30.8h, v11.8b, v23.8b - smull2 v31.8h, v11.16b, v23.16b - - ldr q20, [x26, x8] // out0 vi6 - ldr q21, [x16, x8] // out1 vi6 - ldr q22, [x27, x8] // out0 vi7 - ldr q23, [x17, x8] // out1 vi7 - tbz x6, #MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE_BIT_INDEX, DonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9S8S8 - ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x7], #64 // scales 0-15 for outs - -DonePerChannelScaleLoad_MlasConvSymDepthwiseKernelSize9S8S8 - ldr q10, [x1] // vk8 - - smlal v24.8h, v14.8b, v20.8b - smlal2 v25.8h, v14.16b, v20.16b - smlal v26.8h, v14.8b, v21.8b - smlal2 v27.8h, v14.16b, v21.16b - smlal v28.8h, v15.8b, v22.8b - smlal2 v29.8h, v15.16b, v22.16b - smlal v30.8h, v15.8b, v23.8b - smlal2 v31.8h, v15.16b, v23.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - ldr q20, [x28, x8] // out0 vi8 - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - ldr q21, [x19, x8] // out1 vi8 - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - - saddw v16.4s, v16.4s, v28.4h - saddw2 v17.4s, v17.4s, v28.8h - saddw v18.4s, v18.4s, v29.4h - saddw2 v19.4s, v19.4s, v29.8h - - saddw v6.4s, v6.4s, v30.4h - saddw2 v7.4s, v7.4s, v30.8h - saddw v8.4s, v8.4s, v31.4h - saddw2 v9.4s, v9.4s, v31.8h - - smull v24.8h, v10.8b, v20.8b - smull2 v25.8h, v10.16b, v20.16b - smull v26.8h, v10.8b, v21.8b - smull2 v27.8h, v10.16b, v21.16b - - saddw v16.4s, v16.4s, v24.4h - saddw2 v17.4s, v17.4s, v24.8h - saddw v18.4s, v18.4s, v25.4h - saddw2 v19.4s, v19.4s, v25.8h - - saddw v6.4s, v6.4s, v26.4h - saddw2 v7.4s, v7.4s, v26.8h - saddw v8.4s, v8.4s, v27.4h - saddw2 v9.4s, v9.4s, v27.8h - - scvtf v16.4s, v16.4s // Requantize - scvtf v17.4s, v17.4s - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v16.4s, v16.4s, v1.4s - fmul v17.4s, v17.4s, v2.4s - fmul v18.4s, v18.4s, v3.4s - fmul v19.4s, v19.4s, v4.4s - fmul v6.4s, v6.4s, v1.4s - fmul v7.4s, v7.4s, v2.4s - fmul v8.4s, v8.4s, v3.4s - fmul v9.4s, v9.4s, v4.4s - - fcvtns v16.4s, v16.4s - fcvtns v17.4s, v17.4s - fcvtns v18.4s, v18.4s - fcvtns v19.4s, v19.4s - fcvtns v6.4s, v6.4s - fcvtns v7.4s, v7.4s - fcvtns v8.4s, v8.4s - fcvtns v9.4s, v9.4s - - sqxtn v16.4h, v16.4s // +zp, narrow and combine - sqxtn v18.4h, v18.4s - sqxtn v6.4h, v6.4s - sqxtn v8.4h, v8.4s - sqxtn2 v16.8h, v17.4s - sqxtn2 v18.8h, v19.4s - sqxtn2 v6.8h, v7.4s - sqxtn2 v8.8h, v9.4s - sqadd v16.8h, v16.8h, v0.8h - sqadd v18.8h, v18.8h, v0.8h - sqadd v6.8h, v6.8h, v0.8h - sqadd v8.8h, v8.8h, v0.8h - sqxtn v16.8b, v16.8h - sqxtn2 v16.16b, v18.8h - sqxtn v6.8b, 
v6.8h - sqxtn2 v6.16b, v8.8h - - str q16, [x3, x8] - str q6, [x9, x8] - add x8, x8, #16 - umov x1, v12.d[0] // filter's beginning - cmp x8, x2 - add x1, x1, x8 - blo MlasConvSymDepthwiseKernelSize9S8S8_Channels16_Loop - -MlasConvSymDepthwiseKernelSize9S8S8_Finish_Channels16_Loop - add x3, x3, x2, LSL #1 - add x9, x9, x2, LSL #1 - cbnz x4, MlasConvSymDepthwiseKernelSize9S8S8_OutputLoop - -MlasConvSymDepthwiseKernelSize9S8S8_Exit - EPILOG_RESTORE_REG_PAIR d14, d15, #MlasConvSymDepthwiseKernelSize9_backup_d14_d15 - EPILOG_RESTORE_REG_PAIR d12, d13, #MlasConvSymDepthwiseKernelSize9_backup_d12_d13 - EPILOG_RESTORE_REG_PAIR d10, d11, #MlasConvSymDepthwiseKernelSize9_backup_d10_d11 - EPILOG_RESTORE_REG_PAIR d8, d9, #MlasConvSymDepthwiseKernelSize9_backup_d8_d9 - EPILOG_RESTORE_REG_PAIR x27, x28, #MlasConvSymDepthwiseKernelSize9_backup_x27_x28 - EPILOG_RESTORE_REG_PAIR x25, x26, #MlasConvSymDepthwiseKernelSize9_backup_x25_x26 - EPILOG_RESTORE_REG_PAIR x23, x24, #MlasConvSymDepthwiseKernelSize9_backup_x23_x24 - EPILOG_RESTORE_REG_PAIR x21, x22, #MlasConvSymDepthwiseKernelSize9_backup_x21_x22 - EPILOG_RESTORE_REG_PAIR x19, x20, #MlasConvSymDepthwiseKernelSize9_SavedRegisters ! - EPILOG_RETURN - - LEAF_END MlasConvSymDepthwiseKernelSize9Arm64S8S8 - - END diff --git a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm deleted file mode 100644 index ba565ca587e0c..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm +++ /dev/null @@ -1,693 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DepthwiseQConvSymS8KernelNeon.asm - -Abstract: - - This module implements the kernels for the depthwise convolution - operation with symmetrically quantized integer values - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the depthwise conv kernel. -// d8-d15, x19-x30 need to be preserved if used -// - -#define ConvSymDepthwiseKernelFrame_SavedRegisters (4 * 8) -#define ConvSymDepthwiseKernelFrame_PostProcessParams 0 + ConvSymDepthwiseKernelFrame_SavedRegisters -#define ConvSymDepthwiseKernelFrame_KernelFlags 8 + ConvSymDepthwiseKernelFrame_SavedRegisters - -#define ConvSymDepthwisePostProcessParams_Bias 0 -#define ConvSymDepthwisePostProcessParams_Scale 8 -#define ConvSymDepthwisePostProcessParams_Min 16 -#define ConvSymDepthwisePostProcessParams_Max 20 -#define ConvSymDepthwisePostProcessParams_ZeroPoint 24 - -#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise convolution for the - elements of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. - - Channels (x4) - Supplies the number of input and output channels. - - ChannelOffset (x5) - Supplies the byte offset from the indirection buffer base - address for this iteration. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 16 or 8 - - OutputCount (x7)- Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 2. 
- - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - NESTED_ENTRY MlasConvSymDepthwiseS8KernelNeon - - PROLOG_SAVE_REG_PAIR d12,d13,#-ConvSymDepthwiseKernelFrame_SavedRegisters! - PROLOG_NOP ldr x8,[sp,#ConvSymDepthwiseKernelFrame_PostProcessParams] - PROLOG_SAVE_REG_PAIR d14,d15,#16 - cmp x7,2 - add x9,x0,x3,lsl#3 // x9 -> &A1 - add x14,x0,x3,lsl#4 // x14 -> &A2 - add x15,x9,x3,lsl#4 // x15 -> &A3 - ldr x16,[x8,#ConvSymDepthwisePostProcessParams_Bias] - csel x9,x0,x9,lo // x9 -> &A0 if OutputCount < 2 - csel x14,x0,x14,ls // x14 -> &A0 if OutputCount <= 2 - ldr x11,[x9],#8 // x11 -> A1 iter 0 - cmp x7,4 - ldp q24,q25,[x16],#32 // init accumulators with bias - csel x15,x0,x15,lo // x15 -> &A0 if OutputCount < 4 - cmp x6,16 - ldr x10,[x0],#8 // x10 -> A0 iter 0 - b.lo Process8Channels - -// -// Process an input block of length Channels for each element of the kernel. -// -// Filter: v0, -// v1 // unroll -// Input: -// x0 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x9 -> x11 -> v6 -// -> x13 -> v3 // unroll -// x14 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x15 -> x11 -> v6 -// -> x13 -> v3 // unroll -// - -Process16Channels - cmp x3,1 - ldp q26,q27,[x16] - b.eq ProcC16P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1],x4 // filter iter 0 - ld1 {v1.16b},[x1],x4 // filter iter 1 - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] // A0 iter 0 - mov v20.16b,v24.16b - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v21.16b,v25.16b - ldr q6,[x11,x5] // A1 iter 0 - mov v30.16b,v26.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v31.16b,v27.16b - ldr q2,[x12,x5] // A0 iter 1 - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - mov v19.16b,v27.16b - ldr q3,[x13,x5] // A1 iter 1 - mov v22.16b,v26.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v23.16b,v27.16b - -BlockLoopC16 - - // - // Process 2 pixels, and load next two pixels - // - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr q4,[x10,x5] // A2 iter 0 - b.eq EpilogueC16P2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x0],#8 // x10 -> A0 iter 2 - smull2 v15.8h,v0.16b,v6.16b - cmp x3,1 - ldr q6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - b.eq EpilogueC16P3 // 3 pixel remains - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x0],#8 // x12 -> A0 iter 3 - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A1 iter 2 - ld1 {v0.16b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - ldr q2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A1 iter 3 - saddw 
v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.16b},[x1],x4 // filter iter 3 - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b BlockLoopC16 - -EpilogueC16P2 - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - smlal v14.8h,v1.8b,v3.8b - smlal2 v15.8h,v1.16b,v3.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoad2 - ldp q4,q5,[x12],#32 // load scale vector if per channel - ldp q6,q3,[x12] -SkipScaleVecLoad2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b Dequantization - -ProcC16P1 - // - // Channel 16 kernel size 1 - // TODO!! is this reachable at all? 
- // - ldr x12,[x14],#8 // x12 -> A2 - ldr x13,[x15],#8 // x13 -> A3 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1] - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] - mov v20.16b,v24.16b - mov v21.16b,v25.16b - ldr q6,[x11,x5] - mov v30.16b,v26.16b - mov v31.16b,v27.16b - ldr q2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - mov v19.16b,v27.16b - ldr q3,[x13,x5] - mov v22.16b,v26.16b - mov v23.16b,v27.16b - b EpilogueC16P1 - -EpilogueC16P3 - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x14],#8 // x12 -> A2 iter 2 - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ld1 {v0.16b},[x1] // filter iter 2 - ldr q6,[x11,x5] // A1 iter 2 - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - ldr q2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v3.8b - smlal2 v15.8h,v1.16b,v3.16b - ldr q3,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -EpilogueC16P1 - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - smull v12.8h,v0.8b,v2.8b - smull2 v13.8h,v0.16b,v2.16b - smull v14.8h,v0.8b,v3.8b - smull2 v15.8h,v0.16b,v3.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoad - ldp q4,q5,[x12],#32 // load scale vector if per channel - ldp q6,q3,[x12] -SkipScaleVecLoad - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -Dequantization - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - b.ne SkipScaleBroadcast - mov v5.16b,v4.16b // broadcast scale val if not per channel - mov v6.16b,v4.16b - mov v3.16b,v4.16b -SkipScaleBroadcast - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v5.4s - fmul v26.4s,v26.4s,v6.4s - fmul v27.4s,v27.4s,v3.4s - fmul 
v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v5.4s - fmul v30.4s,v30.4s,v6.4s - fmul v31.4s,v31.4s,v3.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v5.4s - fmul v18.4s,v18.4s,v6.4s - fmul v19.4s,v19.4s,v3.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v5.4s - fmul v22.4s,v22.4s,v6.4s - fmul v23.4s,v23.4s,v3.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn v26.4h,v26.4s - sqxtn2 v24.8h,v25.4s - sqxtn2 v26.8h,v27.4s - sqxtn v28.4h,v28.4s - sqxtn v30.4h,v30.4s - sqxtn2 v28.8h,v29.4s - sqxtn2 v30.8h,v31.4s - dup v0.8h,w15 - sqxtn v16.4h,v16.4s - sqxtn v18.4h,v18.4s - sqxtn2 v16.8h,v17.4s - sqxtn2 v18.8h,v19.4s - sqxtn v20.4h,v20.4s - sqxtn v22.4h,v22.4s - sqxtn2 v20.8h,v21.4s - sqxtn2 v22.8h,v23.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v26.8h,v26.8h,v0.8h - sqadd v28.8h,v28.8h,v0.8h - sqadd v30.8h,v30.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v18.8h,v18.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqadd v22.8h,v22.8h,v0.8h - sqxtn v24.8b,v24.8h // shorten to int8 - sqxtn2 v24.16b,v26.8h - sqxtn v28.8b,v28.8h - sqxtn2 v28.16b,v30.8h - sqxtn v16.8b,v16.8h - sqxtn2 v16.16b,v18.8h - sqxtn v20.8b,v20.8h - sqxtn2 v20.16b,v22.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.16b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 2 - st1 {v28.16b},[x2],x4 - b.ls ExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? - st1 {v16.16b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 4 - str q20,[x2] - -ExitKernel - EPILOG_RESTORE_REG_PAIR d14,d15,#16 - EPILOG_RESTORE_REG_PAIR d12,d13,#ConvSymDepthwiseKernelFrame_SavedRegisters! 
- EPILOG_RETURN - -Process8Channels - cmp x3,1 - b.eq ProcC8P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - ld1 {v0.8b},[x1],x4 // filter iter 0 - ld1 {v1.8b},[x1],x4 // filter iter 1 - ldr d4,[x10,x5] // A0 iter 0 - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v28.16b,v24.16b - ldr d6,[x11,x5] // A1 iter 0 - mov v29.16b,v25.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v16.16b,v24.16b - ldr d2,[x12,x5] // A0 iter 1 - mov v17.16b,v25.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - subs x3,x3,2 // decrement input blocks remaining - ldr d3,[x13,x5] // A1 iter 1 - mov v20.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v21.16b,v25.16b - -BlockLoopC8 - // - // Process 2 pixels, and load next two pixels - // - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A2 iter 0 - smull v14.8h,v0.8b,v6.8b - b.eq EpilogueC8P2 - ldr x10,[x0],#8 // x10 -> A0 iter 2 - ldr d6,[x11,x5] // A3 iter 0 - cmp x3,1 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal v14.8h,v1.8b,v3.8b - ldr d2,[x12,x5] // A2 iter 1 - b.eq EpilogueC8P3 // 3 pixel remains - ldr d3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - ldr x12,[x0],#8 // x12 -> A0 iter 3 - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - ldr d6,[x11,x5] // A1 iter 2 - ld1 {v0.8b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - ldr d2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v3.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - saddw v16.4s,v16.4s,v12.4h - ldr d3,[x13,x5] // A1 iter 3 - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.8b},[x1],x4 // filter iter 3 - saddw v20.4s,v20.4s,v14.4h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw2 v21.4s,v21.4s,v14.8h - b BlockLoopC8 - -EpilogueC8P2 - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - ldr d6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v3.8b - ldr d3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - smull v12.8h,v0.8b,v4.8b - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v14.8h,v0.8b,v6.8b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smlal v12.8h,v1.8b,v2.8b - smlal v14.8h,v1.8b,v3.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoad2C8 - ldp q4,q5,[x12],#32 // load scale vector if per channel -SkipScaleVecLoad2C8 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - b DequantC8 - -ProcC8P1 - // - // Channel 8 kernel size 1 - // TODO!! is this reachable at all? 
- // - ldr x12,[x14],#8 // x12 -> A2 - mov v28.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 - mov v29.16b,v25.16b - ld1 {v0.8b},[x1] - mov v16.16b,v24.16b - ldr d4,[x10,x5] - mov v17.16b,v25.16b - ldr d6,[x11,x5] - mov v20.16b,v24.16b - ldr d2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - ldr d3,[x13,x5] - mov v21.16b,v25.16b - b EpilogueC8P1 - -EpilogueC8P3 - // - // Loop epilogue (process 2 of last 3 pixels) - // - ldr x12,[x14],#8 // x12 -> A2 iter 2 - ldr d3,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ld1 {v0.8b},[x1] // filter iter 2 - ldr d6,[x11,x5] // A1 iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v3.8b - ldr d3,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -EpilogueC8P1 - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - smull v12.8h,v0.8b,v2.8b - smull v14.8h,v0.8b,v3.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoadC8 - ldp q4,q5,[x12] // load scale vector if per channel -SkipScaleVecLoadC8 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -DequantC8 - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - b.ne SkipScaleBroadcastC8 - mov v5.16b,v4.16b // broadcast scale val if not per channel -SkipScaleBroadcastC8 - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v5.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v5.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v5.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v5.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - dup v0.8h,w15 - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn2 v24.8h,v25.4s - sqxtn v28.4h,v28.4s - sqxtn2 v28.8h,v29.4s - sqxtn v16.4h,v16.4s - sqxtn2 v16.8h,v17.4s - sqxtn v20.4h,v20.4s - sqxtn2 v20.8h,v21.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v28.8h,v28.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqxtn v24.8b,v24.8h // shorten to int8 - sqxtn v28.8b,v28.8h - sqxtn v16.8b,v16.8h - sqxtn v20.8b,v20.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.8b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 2 - st1 {v28.8b},[x2],x4 - b.ls ExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? 
- st1 {v16.8b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 4 - str d20,[x2] - b ExitKernel - NESTED_END MlasConvSymDepthwiseS8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymU8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymU8KernelNeon.asm deleted file mode 100644 index e9f75f1be5bfd..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymU8KernelNeon.asm +++ /dev/null @@ -1,745 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DepthwiseQConvSymU8KernelNeon.asm - -Abstract: - - This module implements the kernels for the depthwise convolution - operation with symmetrically quantized integer values - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the depthwise conv kernel. -// d8-d15, x19-x30 need to be preserved if used -// - -#define ConvSymDepthwiseKernelFrame_SavedNeonRegisters (8 * 8) -#define ConvSymDepthwiseKernelFrame_SavedRegisters ConvSymDepthwiseKernelFrame_SavedNeonRegisters -#define ConvSymDepthwiseKernelFrame_PostProcessParams 0 + ConvSymDepthwiseKernelFrame_SavedRegisters -#define ConvSymDepthwiseKernelFrame_KernelFlags 8 + ConvSymDepthwiseKernelFrame_SavedRegisters - -#define ConvSymDepthwisePostProcessParams_Bias 0 -#define ConvSymDepthwisePostProcessParams_Scale 8 -#define ConvSymDepthwisePostProcessParams_Min 16 -#define ConvSymDepthwisePostProcessParams_Max 20 -#define ConvSymDepthwisePostProcessParams_ZeroPoint 24 - -#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 - - TEXTAREA - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise convolution for the - elements of an output row for a set of filter rows. - -Arguments: - - Input (x0) - Supplies the address of the indirection buffer. - - Filter (x1) - Supplies the address of the filter buffer. - - Output (x2) - Supplies the address of the output buffer. - - KernelSize (x3) - Supplies the size of the kernel. - - Channels (x4) - Supplies the number of input and output channels. - - ChannelOffset (x5) - Supplies the byte offset from the indirection buffer base - address for this iteration. - - ChannelCount (x6) - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 16 or 8 - - OutputCount (x7)- Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 2. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - NESTED_ENTRY MlasConvSymDepthwiseU8KernelNeon - - PROLOG_SAVE_REG_PAIR d8,d9,#-64! 
- PROLOG_NOP ldr x8,[sp,#ConvSymDepthwiseKernelFrame_PostProcessParams] - PROLOG_NOP mov w10,#0x80808080 - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - dup v8.4s,w10 // bit flip vector - ldr x16,[x8,#ConvSymDepthwisePostProcessParams_Bias] - cmp x7,2 - add x9,x0,x3,lsl#3 // x9 -> &A1 - add x14,x0,x3,lsl#4 // x14 -> &A2 - add x15,x9,x3,lsl#4 // x15 -> &A3 - csel x9,x0,x9,lo // x9 -> &A0 if OutputCount < 2 - csel x14,x0,x14,ls // x14 -> &A0 if OutputCount <= 2 - ldr x11,[x9],#8 // x11 -> A1 iter 0 - cmp x7,4 - ldp q24,q25,[x16],#32 // init accumulators with bias - csel x15,x0,x15,lo // x15 -> &A0 if OutputCount < 4 - cmp x6,16 - ldr x10,[x0],#8 // x10 -> A0 iter 0 - b.lo Process8Channels - -// -// Process an input block of length Channels for each element of the kernel. -// -// Filter: v0, -// v1 // unroll -// Input: -// x0 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x9 -> x11 -> v6 -// -> x13 -> v10 // unroll -// x14 -> x10 -> v4 -// -> x12 -> v2 // unroll -// x15 -> x11 -> v6 -// -> x13 -> v10 // unroll -// - -Process16Channels - cmp x3,1 - ldp q26,q27,[x16] - b.eq ProcC16P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1],x4 // filter iter 0 - ld1 {v1.16b},[x1],x4 // filter iter 1 - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] // A0 iter 0 - mov v20.16b,v24.16b - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v21.16b,v25.16b - ldr q6,[x11,x5] // A1 iter 0 - mov v30.16b,v26.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v31.16b,v27.16b - ldr q2,[x12,x5] // A0 iter 1 - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - mov v19.16b,v27.16b - ldr q10,[x13,x5] // A1 iter 1 - mov v22.16b,v26.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v23.16b,v27.16b - -BlockLoopC16 - - // - // Process 2 pixels, and load next two pixels - // - eor v4.16b,v4.16b,v8.16b // fix sign bits - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr q4,[x10,x5] // A2 iter 0 - b.eq EpilogueC16P2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x0],#8 // x10 -> A0 iter 2 - smull2 v15.8h,v0.16b,v6.16b - eor v2.16b,v2.16b,v8.16b - cmp x3,1 - ldr q6,[x11,x5] // A3 iter 0 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - b.eq EpilogueC16P3 // 3 pixel remains - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x0],#8 // x12 -> A0 iter 3 - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - eor v4.16b,v4.16b,v8.16b - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A1 iter 2 - eor v2.16b,v2.16b,v8.16b - ld1 {v0.16b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - smlal2 
v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A1 iter 3 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.16b},[x1],x4 // filter iter 3 - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b BlockLoopC16 - -EpilogueC16P2 - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ldr q6,[x11,x5] // A3 iter 0 - eor v2.16b,v2.16b,v8.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.16b,v4.16b,v8.16b - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - eor v2.16b,v2.16b,v8.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - smlal v14.8h,v1.8b,v10.8b - smlal2 v15.8h,v1.16b,v10.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoad2 - ldp q4,q11,[x12],#32 // load scale vector if per channel - ldp q6,q9,[x12] -SkipScaleVecLoad2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - b Dequantization - -ProcC16P1 - // - // Channel 16 kernel size 1 - // TODO!! is this reachable at all? 
- // - ldr x12,[x14],#8 // x12 -> A2 - ldr x13,[x15],#8 // x13 -> A3 - mov v28.16b,v24.16b - mov v29.16b,v25.16b - ld1 {v0.16b},[x1] - mov v16.16b,v24.16b - mov v17.16b,v25.16b - ldr q4,[x10,x5] - mov v20.16b,v24.16b - mov v21.16b,v25.16b - ldr q6,[x11,x5] - mov v30.16b,v26.16b - mov v31.16b,v27.16b - ldr q2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - mov v18.16b,v26.16b - mov v19.16b,v27.16b - ldr q10,[x13,x5] - mov v22.16b,v26.16b - mov v23.16b,v27.16b - b EpilogueC16P1 - -EpilogueC16P3 - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x14],#8 // x12 -> A2 iter 2 - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - eor v4.16b,v4.16b,v8.16b - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr q4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - ld1 {v0.16b},[x1] // filter iter 2 - ldr q6,[x11,x5] // A1 iter 2 - eor v2.16b,v2.16b,v8.16b - smlal v12.8h,v1.8b,v2.8b - smlal2 v13.8h,v1.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - ldr q2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v10.8b - smlal2 v15.8h,v1.16b,v10.16b - ldr q10,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -EpilogueC16P1 - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.16b,v4.16b,v8.16b - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - smull2 v13.8h,v0.16b,v4.16b - eor v6.16b,v6.16b,v8.16b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - smull2 v15.8h,v0.16b,v6.16b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v26.4s,v26.4s,v13.4h - saddw2 v27.4s,v27.4s,v13.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - saddw v30.4s,v30.4s,v15.4h - saddw2 v31.4s,v31.4s,v15.8h - eor v2.16b,v2.16b,v8.16b - smull v12.8h,v0.8b,v2.8b - smull2 v13.8h,v0.16b,v2.16b - eor v10.16b,v10.16b,v8.16b - smull v14.8h,v0.8b,v10.8b - smull2 v15.8h,v0.16b,v10.16b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoad - ldp q4,q11,[x12],#32 // load scale vector if per channel - ldp q6,q9,[x12] -SkipScaleVecLoad - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v18.4s,v18.4s,v13.4h - saddw2 v19.4s,v19.4s,v13.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - saddw v22.4s,v22.4s,v15.4h - saddw2 v23.4s,v23.4s,v15.8h - -Dequantization - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v26.4s,v26.4s - scvtf v27.4s,v27.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v30.4s,v30.4s - scvtf v31.4s,v31.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v18.4s,v18.4s - scvtf v19.4s,v19.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - scvtf v22.4s,v22.4s - scvtf v23.4s,v23.4s - b.ne SkipScaleBroadcast - mov 
v11.16b,v4.16b // broadcast scale val if not per channel - mov v6.16b,v4.16b - mov v9.16b,v4.16b -SkipScaleBroadcast - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v11.4s - fmul v26.4s,v26.4s,v6.4s - fmul v27.4s,v27.4s,v9.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v11.4s - fmul v30.4s,v30.4s,v6.4s - fmul v31.4s,v31.4s,v9.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v11.4s - fmul v18.4s,v18.4s,v6.4s - fmul v19.4s,v19.4s,v9.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v11.4s - fmul v22.4s,v22.4s,v6.4s - fmul v23.4s,v23.4s,v9.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v26.4s,v26.4s - fcvtns v27.4s,v27.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v30.4s,v30.4s - fcvtns v31.4s,v31.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v18.4s,v18.4s - fcvtns v19.4s,v19.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - fcvtns v22.4s,v22.4s - fcvtns v23.4s,v23.4s - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn v26.4h,v26.4s - sqxtn2 v24.8h,v25.4s - sqxtn2 v26.8h,v27.4s - sqxtn v28.4h,v28.4s - sqxtn v30.4h,v30.4s - sqxtn2 v28.8h,v29.4s - sqxtn2 v30.8h,v31.4s - dup v0.8h,w15 - sqxtn v16.4h,v16.4s - sqxtn v18.4h,v18.4s - sqxtn2 v16.8h,v17.4s - sqxtn2 v18.8h,v19.4s - sqxtn v20.4h,v20.4s - sqxtn v22.4h,v22.4s - sqxtn2 v20.8h,v21.4s - sqxtn2 v22.8h,v23.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v26.8h,v26.8h,v0.8h - sqadd v28.8h,v28.8h,v0.8h - sqadd v30.8h,v30.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v18.8h,v18.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqadd v22.8h,v22.8h,v0.8h - sqxtun v24.8b,v24.8h // shorten to int8 - sqxtun2 v24.16b,v26.8h - sqxtun v28.8b,v28.8h - sqxtun2 v28.16b,v30.8h - sqxtun v16.8b,v16.8h - sqxtun2 v16.16b,v18.8h - sqxtun v20.8b,v20.8h - sqxtun2 v20.16b,v22.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.16b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 2 - st1 {v28.16b},[x2],x4 - b.ls ExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? - st1 {v16.16b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 4 - str q20,[x2] - -ExitKernel - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! 
- EPILOG_RETURN - -Process8Channels - cmp x3,1 - b.eq ProcC8P1 - - ldr x12,[x0],#8 // x12 -> A0 iter 1 - ldr x13,[x9],#8 // x13 -> A1 iter 1 - ld1 {v0.8b},[x1],x4 // filter iter 0 - ld1 {v1.8b},[x1],x4 // filter iter 1 - ldr d4,[x10,x5] // A0 iter 0 - ldr x10,[x14],#8 // x10 -> A2 iter 0 - mov v28.16b,v24.16b - ldr d6,[x11,x5] // A1 iter 0 - mov v29.16b,v25.16b - ldr x11,[x15],#8 // x11 -> A3 iter 0 - mov v16.16b,v24.16b - ldr d2,[x12,x5] // A0 iter 1 - mov v17.16b,v25.16b - ldr x12,[x14],#8 // x12 -> A2 iter 1 - subs x3,x3,2 // decrement input blocks remaining - ldr d10,[x13,x5] // A1 iter 1 - mov v20.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 iter 1 - mov v21.16b,v25.16b - -BlockLoopC8 - // - // Process 2 pixels, and load next two pixels - // - eor v4.8b,v4.8b,v8.8b // fix sign bits - eor v6.8b,v6.8b,v8.8b - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A2 iter 0 - smull v14.8h,v0.8b,v6.8b - b.eq EpilogueC8P2 - ldr x10,[x0],#8 // x10 -> A0 iter 2 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - ldr d6,[x11,x5] // A3 iter 0 - cmp x3,1 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x9],#8 // x11 -> A1 iter 2 - smlal v14.8h,v1.8b,v10.8b - ldr d2,[x12,x5] // A2 iter 1 - b.eq EpilogueC8P3 // 3 pixel remains - ldr d10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - ldr x12,[x0],#8 // x12 -> A0 iter 3 - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x9],#8 // x13 -> A1 iter 3 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - subs x3,x3,2 // decrement input blocks remaining - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ldr x10,[x14],#8 // x10 -> A2 iter 2 - ldr d6,[x11,x5] // A1 iter 2 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - ld1 {v0.8b},[x1],x4 // filter iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr x11,[x15],#8 // x11 -> A3 iter 2 - ldr d2,[x12,x5] // A0 iter 3 - smlal v14.8h,v1.8b,v10.8b - ldr x12,[x14],#8 // x12 -> A2 iter 3 - saddw v16.4s,v16.4s,v12.4h - ldr d10,[x13,x5] // A1 iter 3 - saddw2 v17.4s,v17.4s,v12.8h - ld1 {v1.8b},[x1],x4 // filter iter 3 - saddw v20.4s,v20.4s,v14.4h - ldr x13,[x15],#8 // x13 -> A3 iter 3 - saddw2 v21.4s,v21.4s,v14.8h - b BlockLoopC8 - -EpilogueC8P2 - // - // Loop epilogue (process last 2 pixels) mixed - // with loading of dequantization params - // - ldr d6,[x11,x5] // A3 iter 0 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 1 - smlal v14.8h,v1.8b,v10.8b - ldr d10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - smull v12.8h,v0.8b,v4.8b - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v14.8h,v0.8b,v6.8b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - smlal v12.8h,v1.8b,v2.8b - smlal v14.8h,v1.8b,v10.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoad2C8 - ldp q4,q11,[x12],#32 // load scale vector if per channel -SkipScaleVecLoad2C8 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - b DequantC8 - -ProcC8P1 - // - // Channel 8 kernel size 1 - // TODO!! is this reachable at all? 
- // - ldr x12,[x14],#8 // x12 -> A2 - mov v28.16b,v24.16b - ldr x13,[x15],#8 // x13 -> A3 - mov v29.16b,v25.16b - ld1 {v0.8b},[x1] - mov v16.16b,v24.16b - ldr d4,[x10,x5] - mov v17.16b,v25.16b - ldr d6,[x11,x5] - mov v20.16b,v24.16b - ldr d2,[x12,x5] - subs x3,x3,2 // decrement input blocks remaining - ldr d10,[x13,x5] - mov v21.16b,v25.16b - b EpilogueC8P1 - -EpilogueC8P3 - // - // Loop epilogue (process 2 of last 3 pixels) - // - ldr x12,[x14],#8 // x12 -> A2 iter 2 - ldr d10,[x13,x5] // A3 iter 1 - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - ldr x13,[x15],#8 // x13 -> A3 iter 2 - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - smull v12.8h,v0.8b,v4.8b - ldr d4,[x10,x5] // A0 iter 2 - smull v14.8h,v0.8b,v6.8b - ld1 {v0.8b},[x1] // filter iter 2 - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - ldr d6,[x11,x5] // A1 iter 2 - smlal v12.8h,v1.8b,v2.8b - ldr d2,[x12,x5] // A2 iter 2 - smlal v14.8h,v1.8b,v10.8b - ldr d10,[x13,x5] // A3 iter 2 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -EpilogueC8P1 - // - // Loop epilogue (process last single pixel) mixed with loading of dequantization params - // - ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] - eor v4.8b,v4.8b,v8.8b - eor v6.8b,v6.8b,v8.8b - ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] - smull v12.8h,v0.8b,v4.8b - ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] - smull v14.8h,v0.8b,v6.8b - saddw v24.4s,v24.4s,v12.4h - saddw2 v25.4s,v25.4s,v12.8h - saddw v28.4s,v28.4s,v14.4h - saddw2 v29.4s,v29.4s,v14.8h - eor v2.8b,v2.8b,v8.8b - eor v10.8b,v10.8b,v8.8b - smull v12.8h,v0.8b,v2.8b - smull v14.8h,v0.8b,v10.8b - tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - ld1r {v4.4s},[x12] // load scale val - b.eq SkipScaleVecLoadC8 - ldp q4,q11,[x12] // load scale vector if per channel -SkipScaleVecLoadC8 - saddw v16.4s,v16.4s,v12.4h - saddw2 v17.4s,v17.4s,v12.8h - saddw v20.4s,v20.4s,v14.4h - saddw2 v21.4s,v21.4s,v14.8h - -DequantC8 - scvtf v24.4s,v24.4s // convert to float - scvtf v25.4s,v25.4s - scvtf v28.4s,v28.4s - scvtf v29.4s,v29.4s - scvtf v16.4s,v16.4s - scvtf v17.4s,v17.4s - scvtf v20.4s,v20.4s - scvtf v21.4s,v21.4s - b.ne SkipScaleBroadcastC8 - mov v11.16b,v4.16b // broadcast scale val if not per channel -SkipScaleBroadcastC8 - fmul v24.4s,v24.4s,v4.4s // multiply by scale - fmul v25.4s,v25.4s,v11.4s - fmul v28.4s,v28.4s,v4.4s - fmul v29.4s,v29.4s,v11.4s - fmul v16.4s,v16.4s,v4.4s - fmul v17.4s,v17.4s,v11.4s - fmul v20.4s,v20.4s,v4.4s - fmul v21.4s,v21.4s,v11.4s - fcvtns v24.4s,v24.4s // convert to int - fcvtns v25.4s,v25.4s - fcvtns v28.4s,v28.4s - fcvtns v29.4s,v29.4s - fcvtns v16.4s,v16.4s - fcvtns v17.4s,v17.4s - fcvtns v20.4s,v20.4s - fcvtns v21.4s,v21.4s - dup v0.8h,w15 - sqxtn v24.4h,v24.4s // shorten to int16 - sqxtn2 v24.8h,v25.4s - sqxtn v28.4h,v28.4s - sqxtn2 v28.8h,v29.4s - sqxtn v16.4h,v16.4s - sqxtn2 v16.8h,v17.4s - sqxtn v20.4h,v20.4s - sqxtn2 v20.8h,v21.4s - sqadd v24.8h,v24.8h,v0.8h // add zero point - sqadd v28.8h,v28.8h,v0.8h - sqadd v16.8h,v16.8h,v0.8h - sqadd v20.8h,v20.8h,v0.8h - sqxtun v24.8b,v24.8h // shorten to int8 - sqxtun v28.8b,v28.8h - sqxtun v16.8b,v16.8h - sqxtun v20.8b,v20.8h - cmp x7,2 // OutputCount < 2 ? - st1 {v24.8b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 2 - st1 {v28.8b},[x2],x4 - b.ls ExitKernel // exit if OutputCount <=2 - cmp x7,4 // OutputCount < 4 ? 
- st1 {v16.8b},[x2],x4 - b.lo ExitKernel // exit if OutputCount < 4 - str d20,[x2] - b ExitKernel - NESTED_END MlasConvSymDepthwiseU8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm deleted file mode 100644 index d7b626327780c..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm +++ /dev/null @@ -1,552 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - HalfGemmKernelNeon.asm - -Abstract: - - This module implements the kernels for the half precision matrix/matrix - multiply operation (HALF GEMM). - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the half gemm kernel. -// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. -// - -#define HGemmKernelFrame_SavedRegs (2 * 8) -#define HGemmKernelFrame_B 0 + HGemmKernelFrame_SavedRegs -#define HGemmKernelFrame_ldb 8 + HGemmKernelFrame_SavedRegs -#define HGemmKernelFrame_ZeroMode 16 + HGemmKernelFrame_SavedRegs - - -/*++ - -Routine Description: - - This routine is an inner kernel to compute 6 rows of GEMM - -Arguments: - - CountM - (x0) the number of rows for matrix A and matrix C. - only process 6 rows - - CountN - (x1) the number of columns from matrix B and matrix C - - CountK - (x2/x0) the number of columns from matrix A and the - number of rows from matrix B. - - C - (x3) the address of matrix C. - - ldc - (x4) - the first dimension of matrix C. - - Bias - (x5) - the address of the Bias vector (optional) - - A - (x6) - the address of matrix A - - lda - (x7) - the first dimension of matrix A - - B - the address of matrix B - - ldb - the first dimension of matrix B - - ZeroMode - true if the output matrix must be zero initialized, else - if the output matrix is accumulated into - ---*/ - - LEAF_ENTRY MlasHalfGemmKernelNeon - - PROLOG_SAVE_REG x19,#-HGemmKernelFrame_SavedRegs! - ldr x9,[sp,#HGemmKernelFrame_ldb] - lsl x2,x2,#1 // k *= sizeof(fp16) - cmp x0,2 - add x14,x6,x7,lsl #1 // a1 = a0 + lda - add x10,x3,x4,lsl #1 // c1 = c0 + ldc - ldr x8,[sp,#HGemmKernelFrame_B] - csel x14,x6,x14,LO // M < 2 ? a1 = a0 - csel x10,x3,x10,LO // c1 = c0 - add x15,x14,x7,lsl #1 // a2 = a1 + lda - add x11,x10,x4,lsl #1 // c2 = c1 + ldc - csel x15,x14,x15,LS // M <= 2 ? a2 = a1 - csel x11,x10,x11,LS // c2 = c1 - cmp x0,4 - add x16,x15,x7,lsl #1 // a3 = a2 + lda - add x12,x11,x4,lsl #1 // c3 = c2 + ldc - csel x16,x15,x16,LO // M < 4 ? a3 = a2 - csel x12,x11,x12,LO // c3 = c2 - add x17,x16,x7,lsl #1 // a4 = a3 + lda - add x13,x12,x4,lsl #1 // c4 = c3 + ldc - csel x17,x16,x17,LS // M <= 4 ? a4 = a3 - csel x13,x12,x13,LS // c4 = c3 - cmp x0,6 - add x7,x17,x7,lsl #1 // a5 = a4 + lda - add x4,x13,x4,lsl #1 // c5 = c4 + ldc - csel x7,x17,x7,LO // M < 6 ? a5 = a4 - csel x4,x13,x4,LO // c5 = c4 - lsl x9,x9,#1 // ldb *= sizeof(fp16) - ldrb w19,[sp,#HGemmKernelFrame_ZeroMode] - sub x9,x9,16 // ldb -= 16 - -/**** -Main loop processes 6x16 tile, depth 4. 
- B 4x16 - --------------------------------------- - |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 - |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 - |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 - |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 - A 6x4 --------------------------------------- - ------------------ --------------------------------------- -x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 -x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 -x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 -x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 -x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 -x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 - ------------------ --------------------------------------- -****/ - -M6N16OutterLoopN - cbz x5,M6N16SkipBias - ldp q20,q21,[x5],32 // Load 16 Bias values - b M6N16PopulateAccumulators - -M6N16SkipBias - eor q20.16b,q20.16b,q20.16b // No bias, reset regs - eor q21.16b,q21.16b,q21.16b - -M6N16PopulateAccumulators - mov v22.16b,v20.16b - mov v23.16b,v21.16b - mov v24.16b,v20.16b - mov v25.16b,v21.16b - mov v26.16b,v20.16b - mov v27.16b,v21.16b - mov v28.16b,v20.16b - subs x0,x2,8 // k -= 4 (8 bytes) - mov v29.16b,v21.16b - mov v30.16b,v20.16b - mov v31.16b,v21.16b - b.LO M6N16RemainderK123 // remaining k 1~3 - - ldr d0,[x6],8 // A0 - ldr q16,[x8],16 // B0.l - ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row - subs x0,x0,8 // over decement k -= 4 (8 bytes) - ldr d1,[x14],8 // A1 - ldr d2,[x15],8 // A2 - ldr d3,[x16],8 // A3 - b.LO M6N16LoopK_Epilogue // need k>=8 for main loop - -M6N16InnerLoopK - fmla v20.8h,v16.8h,v0.h[0] - fmla v21.8h,v17.8h,v0.h[0] - ldr d4,[x17],8 // A4 - fmla v22.8h,v16.8h,v1.h[0] - fmla v23.8h,v17.8h,v1.h[0] - ldr d5,[x7],8 // A5 - fmla v24.8h,v16.8h,v2.h[0] - fmla v25.8h,v17.8h,v2.h[0] - ldr q18,[x8],16 // B1.low - fmla v26.8h,v16.8h,v3.h[0] - fmla v27.8h,v17.8h,v3.h[0] - ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row - fmla v28.8h,v16.8h,v4.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v31.8h,v17.8h,v5.h[0] - subs x0,x0,8 // k -= 4 - - fmla v20.8h,v18.8h,v0.h[1] - fmla v21.8h,v19.8h,v0.h[1] - ldr q16,[x8],16 // B2.low - fmla v22.8h,v18.8h,v1.h[1] - fmla v23.8h,v19.8h,v1.h[1] - ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row - fmla v24.8h,v18.8h,v2.h[1] - fmla v25.8h,v19.8h,v2.h[1] - fmla v26.8h,v18.8h,v3.h[1] - fmla v27.8h,v19.8h,v3.h[1] - fmla v28.8h,v18.8h,v4.h[1] - fmla v29.8h,v19.8h,v4.h[1] - fmla v30.8h,v18.8h,v5.h[1] - fmla v31.8h,v19.8h,v5.h[1] - - fmla v20.8h,v16.8h,v0.h[2] - fmla v21.8h,v17.8h,v0.h[2] - ldr q18,[x8],16 // B3.low - fmla v22.8h,v16.8h,v1.h[2] - fmla v23.8h,v17.8h,v1.h[2] - ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row - fmla v24.8h,v16.8h,v2.h[2] - fmla v25.8h,v17.8h,v2.h[2] - fmla v26.8h,v16.8h,v3.h[2] - fmla v27.8h,v17.8h,v3.h[2] - fmla v28.8h,v16.8h,v4.h[2] - fmla v29.8h,v17.8h,v4.h[2] - fmla v30.8h,v16.8h,v5.h[2] - fmla v31.8h,v17.8h,v5.h[2] - - ldr q16,[x8],16 // Load B0.low for next iter - fmla v20.8h,v18.8h,v0.h[3] - fmla v21.8h,v19.8h,v0.h[3] - ld1 {v17.16b},[x8],x9 // Load B0.high for next iter - fmla v22.8h,v18.8h,v1.h[3] - fmla v23.8h,v19.8h,v1.h[3] - ldr d0,[x6],8 // Load A0 for next iter - fmla v24.8h,v18.8h,v2.h[3] - fmla v25.8h,v19.8h,v2.h[3] - ldr d1,[x14],8 // Load A1 for next iter - fmla v26.8h,v18.8h,v3.h[3] - fmla v27.8h,v19.8h,v3.h[3] - ldr d2,[x15],8 // Load A2 for next iter - fmla v28.8h,v18.8h,v4.h[3] - fmla v29.8h,v19.8h,v4.h[3] - ldr d3,[x16],8 // Load A3 for 
next iter - fmla v30.8h,v18.8h,v5.h[3] - fmla v31.8h,v19.8h,v5.h[3] - b.hs M6N16InnerLoopK // k >= 8 for main loop - -M6N16LoopK_Epilogue - // last block of k >= 4, no pre-load for next iter - fmla v20.8h,v16.8h,v0.h[0] - fmla v21.8h,v17.8h,v0.h[0] - ldr d4,[x17],8 // A4 - fmla v22.8h,v16.8h,v1.h[0] - fmla v23.8h,v17.8h,v1.h[0] - ldr d5,[x7],8 // A5 - fmla v24.8h,v16.8h,v2.h[0] - fmla v25.8h,v17.8h,v2.h[0] - ldr q18,[x8],16 // B1.low - fmla v26.8h,v16.8h,v3.h[0] - fmla v27.8h,v17.8h,v3.h[0] - ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row - fmla v28.8h,v16.8h,v4.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v31.8h,v17.8h,v5.h[0] - adds x0,x0,8 // revert k over-decrement - - fmla v20.8h,v18.8h,v0.h[1] - fmla v21.8h,v19.8h,v0.h[1] - ldr q16,[x8],16 // B2.low - fmla v22.8h,v18.8h,v1.h[1] - fmla v23.8h,v19.8h,v1.h[1] - ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row - fmla v24.8h,v18.8h,v2.h[1] - fmla v25.8h,v19.8h,v2.h[1] - fmla v26.8h,v18.8h,v3.h[1] - fmla v27.8h,v19.8h,v3.h[1] - fmla v28.8h,v18.8h,v4.h[1] - fmla v29.8h,v19.8h,v4.h[1] - fmla v30.8h,v18.8h,v5.h[1] - fmla v31.8h,v19.8h,v5.h[1] - - fmla v20.8h,v16.8h,v0.h[2] - fmla v21.8h,v17.8h,v0.h[2] - ldr q18,[x8],16 // B3.low - fmla v22.8h,v16.8h,v1.h[2] - fmla v23.8h,v17.8h,v1.h[2] - ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row - fmla v24.8h,v16.8h,v2.h[2] - fmla v25.8h,v17.8h,v2.h[2] - fmla v26.8h,v16.8h,v3.h[2] - fmla v27.8h,v17.8h,v3.h[2] - fmla v28.8h,v16.8h,v4.h[2] - fmla v29.8h,v17.8h,v4.h[2] - fmla v30.8h,v16.8h,v5.h[2] - fmla v31.8h,v17.8h,v5.h[2] - - fmla v20.8h,v18.8h,v0.h[3] - fmla v21.8h,v19.8h,v0.h[3] - fmla v22.8h,v18.8h,v1.h[3] - fmla v23.8h,v19.8h,v1.h[3] - fmla v24.8h,v18.8h,v2.h[3] - fmla v25.8h,v19.8h,v2.h[3] - fmla v26.8h,v18.8h,v3.h[3] - fmla v27.8h,v19.8h,v3.h[3] - fmla v28.8h,v18.8h,v4.h[3] - fmla v29.8h,v19.8h,v4.h[3] - fmla v30.8h,v18.8h,v5.h[3] - fmla v31.8h,v19.8h,v5.h[3] - b.NE M6N16RemainderK123 // remaining k 1~3 - -M6N16OutterLoopNTail - subs x1,x1,16 // N -= 16 - ldr x8,[sp,#HGemmKernelFrame_B] - b.LO M6StoreRemainderN // remaining N < 16 - - cbnz x19,M6N16SkipAccumulateOutput - ldp q0,q1,[x3] - ldp q2,q3,[x10] - ldp q4,q5,[x11] - ldp q6,q7,[x12] - ldp q16,q17,[x13] - ldp q18,q19,[x4] - fadd v20.8h,v20.8h,v0.8h // !ZeroMode - fadd v21.8h,v21.8h,v1.8h // accumulate into C - fadd v22.8h,v22.8h,v2.8h - fadd v23.8h,v23.8h,v3.8h - fadd v24.8h,v24.8h,v4.8h - fadd v25.8h,v25.8h,v5.8h - fadd v26.8h,v26.8h,v6.8h - fadd v27.8h,v27.8h,v7.8h - fadd v28.8h,v28.8h,v16.8h - fadd v29.8h,v29.8h,v17.8h - fadd v30.8h,v30.8h,v18.8h - fadd v31.8h,v31.8h,v19.8h - -M6N16SkipAccumulateOutput - st1 {v20.16b,v21.16b},[x3],32 - sub x6,x6,x2 // restore a0 - st1 {v22.16b,v23.16b},[x10],32 - sub x14,x14,x2 // restore a1 - st1 {v24.16b,v25.16b},[x11],32 - sub x15,x15,x2 // restore a2 - st1 {v26.16b,v27.16b},[x12],32 - sub x16,x16,x2 // restore a3 - st1 {v28.16b,v29.16b},[x13],32 - sub x17,x17,x2 // restore a4 - add x8,x8,32 // B <- next 16 columns - st1 {v30.16b,v31.16b},[x4],32 - sub x7,x7,x2 // restore a5 - str x8,[sp,#HGemmKernelFrame_B] - b.HI M6N16OutterLoopN - -ExitKernel - EPILOG_RESTORE_REG x19,#HGemmKernelFrame_SavedRegs! 
- EPILOG_RETURN - -M6N16RemainderK123 - tbz x0,2,M6N16RemainderK1 - ldr s0,[x6],4 // A0 - ldr q16,[x8],16 // B0.low - ld1 {v17.16b},[x8],x9 // B0.high - ldr s1,[x14],4 // A1 - ldr s2,[x15],4 // A2 - ldr s3,[x16],4 // A3 - ldr s4,[x17],4 // A4 - ldr s5,[x7],4 // A5 - ldr q18,[x8],16 // B1.low - ld1 {v19.16b},[x8],x9 // B2.high - fmla v20.8h,v16.8h,v0.h[0] - fmla v22.8h,v16.8h,v1.h[0] - fmla v24.8h,v16.8h,v2.h[0] - fmla v26.8h,v16.8h,v3.h[0] - fmla v28.8h,v16.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v21.8h,v17.8h,v0.h[0] - fmla v23.8h,v17.8h,v1.h[0] - fmla v25.8h,v17.8h,v2.h[0] - fmla v27.8h,v17.8h,v3.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v31.8h,v17.8h,v5.h[0] - - fmla v20.8h,v18.8h,v0.h[1] - fmla v22.8h,v18.8h,v1.h[1] - fmla v24.8h,v18.8h,v2.h[1] - fmla v26.8h,v18.8h,v3.h[1] - fmla v28.8h,v18.8h,v4.h[1] - fmla v30.8h,v18.8h,v5.h[1] - fmla v21.8h,v19.8h,v0.h[1] - fmla v23.8h,v19.8h,v1.h[1] - fmla v25.8h,v19.8h,v2.h[1] - fmla v27.8h,v19.8h,v3.h[1] - fmla v29.8h,v19.8h,v4.h[1] - fmla v31.8h,v19.8h,v5.h[1] - tbz x0,1,M6N16OutterLoopNTail - -M6N16RemainderK1 - ldr h0,[x6],2 // A0 - ldr q16,[x8],16 // B0.low - ld1 {v17.16b},[x8],x9 // B0.high - ldr h1,[x14],2 // A1 - ldr h2,[x15],2 // A2 - ldr h3,[x16],2 // A3 - ldr h4,[x17],2 // A4 - ldr h5,[x7],2 // A5 - fmla v20.8h,v16.8h,v0.h[0] - fmla v22.8h,v16.8h,v1.h[0] - fmla v24.8h,v16.8h,v2.h[0] - fmla v26.8h,v16.8h,v3.h[0] - fmla v28.8h,v16.8h,v4.h[0] - fmla v30.8h,v16.8h,v5.h[0] - fmla v21.8h,v17.8h,v0.h[0] - fmla v23.8h,v17.8h,v1.h[0] - fmla v25.8h,v17.8h,v2.h[0] - fmla v27.8h,v17.8h,v3.h[0] - fmla v29.8h,v17.8h,v4.h[0] - fmla v31.8h,v17.8h,v5.h[0] - b M6N16OutterLoopNTail - -M6StoreRemainderN - cbnz x19,M6StoreRemainderNZeroMode - tbz x1,3,M6StoreRemainderN4 - ldr q0,[x3] - ldr q1,[x10] - ldr q2,[x11] - ldr q3,[x12] - ldr q4,[x13] - ldr q5,[x4] - fadd v20.8h,v20.8h,v0.8h - fadd v22.8h,v22.8h,v1.8h - fadd v24.8h,v24.8h,v2.8h - str q20,[x3],16 - mov v20.16b,v21.16b - str q22,[x10],16 - mov v22.16b,v23.16b - str q24,[x11],16 - mov v24.16b,v25.16b - fadd v26.8h,v26.8h,v3.8h - fadd v28.8h,v28.8h,v4.8h - fadd v30.8h,v30.8h,v5.8h - str q26,[x12],16 - mov v26.16b,v27.16b - str q28,[x13],16 - mov v28.16b,v29.16b - str q30,[x4],16 - mov v30.16b,v31.16b - -M6StoreRemainderN4 - tbz x1,2,M6StoreRemainderN2 - ldr d0,[x3] - ldr d1,[x10] - ldr d2,[x11] - ldr d3,[x12] - ldr d4,[x13] - ldr d5,[x4] - fadd v21.4h,v20.4h,v0.4h - dup d20,v20.d[1] - fadd v23.4h,v22.4h,v1.4h - dup d22,v22.d[1] - fadd v25.4h,v24.4h,v2.4h - dup d24,v24.d[1] - fadd v27.4h,v26.4h,v3.4h - dup d26,v26.d[1] - fadd v29.4h,v28.4h,v4.4h - dup d28,v28.d[1] - fadd v31.4h,v30.4h,v5.4h - dup d30,v30.d[1] - str d21,[x3],8 - str d23,[x10],8 - str d25,[x11],8 - str d27,[x12],8 - str d29,[x13],8 - str d31,[x4],8 - -M6StoreRemainderN2 - tbz x1,1,M6StoreRemainderN1 - ldr s0,[x3] - ldr s1,[x10] - ldr s2,[x11] - ldr s3,[x12] - ldr s4,[x13] - ldr s5,[x4] - fadd v21.4h,v20.4h,v0.4h - fadd v23.4h,v22.4h,v1.4h - fadd v25.4h,v24.4h,v2.4h - fadd v27.4h,v26.4h,v3.4h - fadd v29.4h,v28.4h,v4.4h - fadd v31.4h,v30.4h,v5.4h - str s21,[x3],4 - str s23,[x10],4 - dup s20,v20.s[1] - dup s22,v22.s[1] - str s25,[x11],4 - str s27,[x12],4 - dup s24,v24.s[1] - dup s26,v26.s[1] - str s29,[x13],4 - str s31,[x4],4 - dup s28,v28.s[1] - dup s30,v30.s[1] - -M6StoreRemainderN1 - tbz x1,0,ExitKernel - ldr h0,[x3] - ldr h1,[x10] - ldr h2,[x11] - ldr h3,[x12] - ldr h4,[x13] - ldr h5,[x4] - fadd v20.4h,v20.4h,v0.4h - fadd v22.4h,v22.4h,v1.4h - fadd v24.4h,v24.4h,v2.4h - fadd v26.4h,v26.4h,v3.4h - fadd v28.4h,v28.4h,v4.4h - fadd 
v30.4h,v30.4h,v5.4h - str h20,[x3] - str h22,[x10] - str h24,[x11] - str h26,[x12] - str h28,[x13] - str h30,[x4] - b ExitKernel - -M6StoreRemainderNZeroMode - tbz x1,3,M6StoreRemainderN4ZeroMode - str q20,[x3],16 - mov v20.16b,v21.16b - str q22,[x10],16 - mov v22.16b,v23.16b - str q24,[x11],16 - mov v24.16b,v25.16b - str q26,[x12],16 - mov v26.16b,v27.16b - str q28,[x13],16 - mov v28.16b,v29.16b - str q30,[x4],16 - mov v30.16b,v31.16b - -M6StoreRemainderN4ZeroMode - tbz x1,2,M6StoreRemainderN2ZeroMode - str d20,[x3],8 - str d22,[x10],8 - dup d20,v20.d[1] - dup d22,v22.d[1] - str d24,[x11],8 - str d26,[x12],8 - dup d24,v24.d[1] - dup d26,v26.d[1] - str d28,[x13],8 - str d30,[x4],8 - dup d28,v28.d[1] - dup d30,v30.d[1] - -M6StoreRemainderN2ZeroMode - tbz x1,1,M6StoreRemainderN1ZeroMode - str s20,[x3],4 - str s22,[x10],4 - dup s20,v20.s[1] - dup s22,v22.s[1] - str s24,[x11],4 - str s26,[x12],4 - dup s24,v24.s[1] - dup s26,v26.s[1] - str s28,[x13],4 - str s30,[x4],4 - dup s28,v28.s[1] - dup s30,v30.s[1] - -M6StoreRemainderN1ZeroMode - tbz x1,0,ExitKernel - str h20,[x3] - str h22,[x10] - str h24,[x11] - str h26,[x12] - str h28,[x13] - str h30,[x4] - b ExitKernel - - LEAF_END MlasHalfGemmKernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/QgemmS8S8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/QgemmS8S8KernelNeon.asm deleted file mode 100644 index f470fc1853e33..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/QgemmS8S8KernelNeon.asm +++ /dev/null @@ -1,696 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmS8S8KernelNeon.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the S8S8 kernel. -// - -#define GemmS8S8KernelFrame_SavedNeonRegisters (8 * 8) -#define GemmS8S8KernelFrame_SavedRegisters GemmS8S8KernelFrame_SavedNeonRegisters -#define GemmS8S8KernelFrame_ColumnSumBuffer 0 + GemmS8S8KernelFrame_SavedRegisters -#define GemmS8S8KernelFrame_ZeroPointB 8 + GemmS8S8KernelFrame_SavedRegisters -#define GemmS8S8KernelFrame_ZeroMode 16 + GemmS8S8KernelFrame_SavedRegisters - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. 
- - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - NESTED_ENTRY MlasGemmS8S8KernelNeon - - PROLOG_SAVE_REG_PAIR d8,d9,#-64! - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - ldr x8,[sp,#GemmS8S8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#GemmS8S8KernelFrame_ZeroPointB] - ldrb w13,[sp,#GemmS8S8KernelFrame_ZeroMode] - mov x14,x0 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq GemmS8S8_M1_ProcessLoop - cmp x4,#4 // CountM < 4? - blo GemmS8S8_M2_ProcessLoop - -// -// Process 4 rows of the matrices. -// -// B 16x4 -// /--------------------------------------\ -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v8.b[0] v9.b[0] v10.b[0] v11.b[0]| -// | ... ... ... ... | -// |v8.b[7] v9.b[7] v10.b[7] v11.b[7]| -// A 4x16 \--------------------------------------/ -// /---------------------------------\ /--------------------------------------\ -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v24.4s v25.4s v26.4s v27.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v28.4s v29.4s v30.4s v31.4s | -// \---------------------------------/ \--------------------------------------/ -// -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) - -GemmS8S8_M4_ProcessNextColumnLoop - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#64 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d4,d8,[x1],#64 // B - movi v18.4s,#0 - movi v19.4s,#0 - ldp d5,d9,[x1,#-48] - movi v20.4s,#0 - movi v21.4s,#0 - ldp d6,d10,[x1,#-32] - movi v22.4s,#0 - movi v23.4s,#0 - ldp d7,d11,[x1,#-16] - movi v24.4s,#0 - movi v25.4s,#0 - ldp d1,d3,[x0,#-48] - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - -GemmS8S8_M4_ComputeBlockLoop - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x0,#-32] - sadalp v16.4s,v12.8h - sadalp v17.4s,v13.8h - sadalp v18.4s,v14.8h - sadalp v19.4s,v15.8h - sub x3,x3,#1 - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - ldp d1,d3,[x0,#-16] - sadalp v20.4s,v12.8h - sadalp v21.4s,v13.8h - sadalp v22.4s,v14.8h - sadalp v23.4s,v15.8h - cbz x3,GemmS8S8_M4_ComputeBlockLoopFinish - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x0],#64 - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - ldp d4,d8,[x1],#64 // B - smlal v13.8h,v3.8b,v9.8b - ldp d5,d9,[x1,#-48] - smlal v14.8h,v3.8b,v10.8b - ldp d6,d10,[x1,#-32] - smlal v15.8h,v3.8b,v11.8b - ldp d7,d11,[x1,#-16] - sadalp v28.4s,v12.8h - ldp d1,d3,[x0,#-48] - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - b GemmS8S8_M4_ComputeBlockLoop - -GemmS8S8_M4_ComputeBlockLoopFinish - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - ld1 {v0.4s},[x7] - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - addp v24.4s,v24.4s,v25.4s - addp v26.4s,v26.4s,v27.4s - addp v28.4s,v28.4s,v29.4s - addp v30.4s,v30.4s,v31.4s - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - dup v8.4s,v0.s[0] // broadcast row fixups - dup v9.4s,v0.s[1] - dup v10.4s,v0.s[2] - dup v11.4s,v0.s[3] - cbz x9,GemmS8S8_M4_SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v30.4s},[x9],#16 // load ZeroPointB - mul v17.4s,v30.4s,v8.4s - mul v21.4s,v30.4s,v9.4s - mul v25.4s,v30.4s,v10.4s - mul v29.4s,v30.4s,v11.4s - 
add v16.4s,v16.4s,v17.4s - add v20.4s,v20.4s,v21.4s - add v24.4s,v24.4s,v25.4s - add v28.4s,v28.4s,v29.4s - add v16.4s,v16.4s,v2.4s - add v20.4s,v20.4s,v2.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v2.4s - b GemmS8S8_M4_StoreOutput - -GemmS8S8_M4_SkipScaleByZeroPointB - // accumulator = row sum A + column sum B - add v16.4s,v16.4s,v8.4s - add v20.4s,v20.4s,v9.4s - add v24.4s,v24.4s,v10.4s - add v28.4s,v28.4s,v11.4s - add v16.4s,v16.4s,v2.4s - add v20.4s,v20.4s,v2.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v2.4s - -GemmS8S8_M4_StoreOutput - add x10,x2,x6,lsl #2 - add x11,x10,x6,lsl #2 - add x12,x11,x6,lsl #2 - subs x5,x5,#4 // adjust CountN remaining - blo GemmS8S8_M4_StoreOutputPartial - cbnz x13,GemmS8S8_M4_SkipAccumulateOutput - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v3.4s - -GemmS8S8_M4_SkipAccumulateOutput - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - st1 {v24.4s},[x11] - st1 {v28.4s},[x12] - cbnz x5,GemmS8S8_M4_ProcessNextColumnLoop - -GemmS8S8_M4_ExitKernel - mov x0,#4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - - -GemmS8S8_M4_StoreOutputPartial - cbz x13,GemmS8S8_M4_StoreOutputPartial_AddMode - -GemmS8S8_M4_StoreOutputPartial_ZeroMode - tbz x5,#1,GemmS8S8_M4_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - st1 {v24.2s},[x11],#8 - dup v24.4s,v24.s[2] - st1 {v28.2s},[x12],#8 - dup v28.4s,v28.s[2] - -GemmS8S8_M4_StoreOutputPartial1_ZeroMode - tbz x5,#0,GemmS8S8_M4_ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - st1 {v24.s}[0],[x11] - st1 {v28.s}[0],[x12] - b GemmS8S8_M4_ExitKernel - -GemmS8S8_M4_StoreOutputPartial_AddMode - tbz x5,#1,GemmS8S8_M4_StoreOutputPartial1_AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v3.4s - st1 {v24.2s},[x11],#8 - dup v24.4s,v24.s[2] - st1 {v28.2s},[x12],#8 - dup v28.4s,v28.s[2] - -GemmS8S8_M4_StoreOutputPartial1_AddMode - tbz x5,#0,GemmS8S8_M4_ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v20.4s,v20.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v24.4s,v24.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - add v28.4s,v28.4s,v3.4s - st1 {v24.s}[0],[x11] - st1 {v28.s}[0],[x12] - b GemmS8S8_M4_ExitKernel - -// -// Process 2 rows of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// Each row sum replicated to all 4 elements of a vector register -// v30 v31 -// B 16x4 -// /--------------------------------------\ -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... 
| -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 2x16 \--------------------------------------/ -// /---------------------------------\ /--------------------------------------\ -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// \---------------------------------/ \--------------------------------------/ -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. (v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// -GemmS8S8_M2_ProcessLoop - -GemmS8S8_M2_ProcessNextColumnLoop - ldp d4,d24,[x1],#16 // B - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d5,d25,[x1],#16 - movi v18.4s,#0 - movi v19.4s,#0 - ldp d6,d26,[x1],#16 - movi v20.4s,#0 - movi v21.4s,#0 - ldp d7,d27,[x1],#16 - movi v22.4s,#0 - movi v23.4s,#0 - ldp d1,d3,[x0],#16 // A1 - -GemmS8S8_M2_ComputeBlockLoop - - sub x3,x3,#1 - smull v28.8h,v0.8b,v4.8b - smull v29.8h,v0.8b,v5.8b - smull v30.8h,v0.8b,v6.8b - smull v31.8h,v0.8b,v7.8b - cbz x3,GemmS8S8_M2_ComputeBlockLoopFinish - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // A0 - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v29.8h,v3.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v30.8h,v3.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v31.8h,v3.8b,v27.8b - ldp d7,d27,[x1],#16 - sadalp v20.4s,v28.8h - ldp d1,d3,[x0],#16 // A1 - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - b GemmS8S8_M2_ComputeBlockLoop - -GemmS8S8_M2_ComputeBlockLoopFinish - ld1 {v0.4s},[x8],#16 // load ColumnSumBuffer[0] - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - ldr d2,[x7] // load row sums - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - smlal v29.8h,v3.8b,v25.8b - smlal v30.8h,v3.8b,v26.8b - smlal v31.8h,v3.8b,v27.8b - sadalp v20.4s,v28.8h - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - dup v30.4s,v2.s[0] // broadcast row fixups - dup v31.4s,v2.s[1] // broadcast row fixups - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - cbz x9,GemmS8S8_M2_SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v18.4s},[x9],#16 // load ZeroPointB[0] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v0.4s - mul v17.4s,v18.4s,v30.4s - mul v21.4s,v18.4s,v31.4s - add v16.4s,v16.4s,v17.4s - add v20.4s,v20.4s,v21.4s - b GemmS8S8_M2_StoreOutput - -GemmS8S8_M2_SkipScaleByZeroPointB - // accumulator = row sum A + column sum B - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v0.4s - add v16.4s,v16.4s,v30.4s - add v20.4s,v20.4s,v31.4s - -GemmS8S8_M2_StoreOutput - add x10,x2,x6,lsl #2 - subs x5,x5,#4 // adjust CountN remaining - blo GemmS8S8_M2_StoreOutputPartial - cbnz x13,GemmS8S8_M2_SkipAccumulateOutput - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add 
v20.4s,v20.4s,v1.4s - -GemmS8S8_M2_SkipAccumulateOutput - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - cbnz x5,GemmS8S8_M2_ProcessNextColumnLoop - -GemmS8S8_M2_ExitKernel - mov x0,#2 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -GemmS8S8_M2_StoreOutputPartial - cbz x13,GemmS8S8_M2_StoreOutputPartial_AddMode - -GemmS8S8_M2_StoreOutputPartial_ZeroMode - tbz x5,#1,GemmS8S8_M2_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - -GemmS8S8_M2_StoreOutputPartial1_ZeroMode - tbz x5,#0,GemmS8S8_M2_ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - b GemmS8S8_M2_ExitKernel - -GemmS8S8_M2_StoreOutputPartial_AddMode - tbz x5,#1,GemmS8S8_M2_StoreOutputPartial1_AddMode - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - -GemmS8S8_M2_StoreOutputPartial1_AddMode - tbz x5,#0,GemmS8S8_M2_ExitKernel - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - b GemmS8S8_M2_ExitKernel - -// -// Process 1 row of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// row sum replicated to all 4 elements of a vector register -// v31 -// B 16x4 -// /--------------------------------------\ -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 1x16 \--------------------------------------/ -// /---------------------------------\ /--------------------------------------\ -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// \---------------------------------/ \--------------------------------------/ -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// -GemmS8S8_M1_ProcessLoop - ldr d31,[x7] - dup v31.4s,v31.s[0] // broadcast row fixups - -GemmS8S8_M1_ProcessNextColumnLoop - ldp d4,d24,[x1],#16 // B - ldp d5,d25,[x1],#16 - ldp d6,d26,[x1],#16 - ldp d7,d27,[x1],#16 - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - -GemmS8S8_M1_ComputeBlockLoop - sub x3,x3,#1 - smull v20.8h,v0.8b,v4.8b - smull v21.8h,v0.8b,v5.8b - cbz x3,GemmS8S8_M1_ComputeBlockLoopFinish - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v21.8h,v2.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v22.8h,v2.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v23.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // A0 - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - ldp d7,d27,[x1],#16 - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - b GemmS8S8_M1_ComputeBlockLoop - -GemmS8S8_M1_ComputeBlockLoopFinish - ld1 {v4.4s},[x8],#16 // load ColumnSumBuffer[0] - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - smlal v21.8h,v2.8b,v25.8b - smlal v22.8h,v2.8b,v26.8b - smlal v23.8h,v2.8b,v27.8b - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v16.4s,v16.4s,v18.4s - cbz x9,GemmS8S8_M1_SkipScaleByZeroPointB - - // accumulator = zero point B * row sum A + column sum B - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - mul v17.4s,v30.4s,v31.4s - add v16.4s,v16.4s,v17.4s - add v16.4s,v16.4s,v4.4s - b GemmS8S8_M1_StoreOutput -GemmS8S8_M1_SkipScaleByZeroPointB - // accumulator = row sum A + column sum B - add v16.4s,v16.4s,v31.4s - add v16.4s,v16.4s,v4.4s - -GemmS8S8_M1_StoreOutput - subs x5,x5,#4 // adjust CountN remaining - blo GemmS8S8_M1_StoreOutputPartial - cbnz x13,GemmS8S8_M1_SkipAccumulateOutput - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - -GemmS8S8_M1_SkipAccumulateOutput - st1 {v16.4s},[x2],#16 - cbnz x5,GemmS8S8_M1_ProcessNextColumnLoop - -GemmS8S8_M1_ExitKernel - mov x0,#1 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -GemmS8S8_M1_StoreOutputPartial - cbz x13,GemmS8S8_M1_StoreOutputPartial_AddMode - -GemmS8S8_M1_StoreOutputPartial_ZeroMode: - tbz x5,#1,GemmS8S8_M1_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -GemmS8S8_M1_StoreOutputPartial1_ZeroMode - tbz x5,#0,GemmS8S8_M1_ExitKernel - st1 {v16.s}[0],[x2] - b GemmS8S8_M1_ExitKernel - -GemmS8S8_M1_StoreOutputPartial_AddMode - tbz x5,#1,GemmS8S8_M1_StoreOutputPartial1_AddMode - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -GemmS8S8_M1_StoreOutputPartial1_AddMode - tbz x5,#0,GemmS8S8_M1_ExitKernel - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b GemmS8S8_M1_ExitKernel - - NESTED_END MlasGemmS8S8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/QgemmS8S8KernelSdot.asm b/onnxruntime/core/mlas/lib/arm64/QgemmS8S8KernelSdot.asm deleted file mode 100644 index 32d0e37f9654e..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/QgemmS8S8KernelSdot.asm +++ /dev/null @@ -1,1054 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
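The S8S8 QGEMM kernel ending above (and the SDOT variant whose deleted source begins here) seeds each int32 accumulator from the documented RowSumBuffer/ColumnSumBuffer/ZeroPointB fixups before running the signed 8-bit dot-product loop, then either overwrites or accumulates into C according to ZeroMode. A hedged C++ sketch of that arithmetic, using plain row-major A and B instead of the packed layouts the real kernels require:

#include <cstddef>
#include <cstdint>

// Hedged reference only: A and B are plain row-major here, not the packed
// layouts produced by MlasGemmQuantCopyPackA/B. RowSum and ColumnSum follow
// the conventions in the routine header: ColumnSum already carries A's zero
// point, and RowSum carries B's zero point unless ZeroPointB supplies
// per-column values.
void QGemmS8S8Reference(const int8_t* A, const int8_t* B, int32_t* C,
                        size_t M, size_t N, size_t K, size_t ldc,
                        const int32_t* RowSum, const int32_t* ColumnSum,
                        const int32_t* ZeroPointB,   // nullptr => per-tensor zero point
                        bool ZeroMode) {
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            // Seed the accumulator with the zero-point fixups, as the kernel does
            // before its dot-product loop.
            int32_t acc = ColumnSum[n] +
                          (ZeroPointB ? ZeroPointB[n] * RowSum[m] : RowSum[m]);
            for (size_t k = 0; k < K; ++k) {
                acc += static_cast<int32_t>(A[m * K + k]) *
                       static_cast<int32_t>(B[k * N + n]);
            }
            C[m * ldc + n] = ZeroMode ? acc : C[m * ldc + n] + acc;
        }
    }
}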
- -Licensed under the MIT License. - -Module Name: - - QgemmS8S8KernelUdot.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses ARM v8.4 dot product instructions. - ---*/ - -#include "kxarm64.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout for the S8S8 kernel. -// Defining spaces for saving 2 vector registers, and pointers to parameters -// on the stack -// - -#define GemmS8S8KernelFrame_SavedNeonRegisters (2 * 8) -#define GemmS8S8KernelFrame_SavedRegisters GemmS8S8KernelFrame_SavedNeonRegisters -#define GemmS8S8KernelFrame_ColumnSumBuffer (0 + GemmS8S8KernelFrame_SavedRegisters) -#define GemmS8S8KernelFrame_ZeroPointB (8 + GemmS8S8KernelFrame_SavedRegisters) -#define GemmS8S8KernelFrame_ZeroMode (16 + GemmS8S8KernelFrame_SavedRegisters) - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - NESTED_ENTRY MlasGemmS8S8KernelSDot - - PROLOG_SAVE_REG_PAIR d8,d9,#-16! - ldr x8,[sp,#GemmS8S8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#GemmS8S8KernelFrame_ZeroPointB] - ldrb w13,[sp,#GemmS8S8KernelFrame_ZeroMode] - mov x14,x0 - ld1 {v8.4s},[x7],#16 // load row sum 0 ~ 4 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq ProcessLoopM1 - cmp x4,#4 // CountM < 4? - blo ProcessLoopM2 - cmp x4,#8 // CountM < 8? - blo ProcessNextColumnLoopM4 - ld1 {v9.4s},[x7] // load row sum 5 ~ 8 - -// -// Process 8 rows of the matrices. -// Row Sums: v8 ~ v9 -// A 4x8 block -// /-----------------------------------------| -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// \-----------------------------------------/ -// B 8x4 block -// /---------------------\ /-----------------------------------------| -// |v4.b[0] ... v4.b[3] | |v16.s[0] .. 
v16.s[3] v17.s[0] .. v17.s[3]| -// |v4.b[4] ... v4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v4.b[8] ... v4.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v4.b[12] ... v4.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v5.b[0] ... v5.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v5.b[4] ... v5.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v5.b[8] ... v5.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v5.b[12] ... v5.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. v31.s[3]| -// \---------------------/ \-----------------------------------------/ -// -// unroll for the next 4 in k dimension -// /-----------------------------------------| -// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| -// | ... ... | -// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| -// \-----------------------------------------/ -// /---------------------\ /-----------------------------------------\ -// |v6.b[0] ... v6.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |v6.b[4] ... v6.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v6.b[8] ... v6.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v6.b[12] ... v6.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v7.b[0] ... v7.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v7.b[4] ... v7.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v7.b[8] ... v7.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v7.b[12] ... v7.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. v31.s[3]| -// \---------------------/ \-----------------------------------------/ - -// Starting the loop: initialize accumulators with scaled combination -// of row and column sums - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - dup v25.4s,v9.s[0] - dup v27.4s,v9.s[1] - dup v29.4s,v9.s[2] - dup v31.4s,v9.s[3] - -ProcessNextColumnLoopM8 - mov x0,x14 // reload matrix A - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v7.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,SkipScaleByZeroPointBM8 - - // accumulator = zero point B * row sum A + column sum B - ld1 {v0.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v0.4s,v17.4s - mul v18.4s,v0.4s,v19.4s - ld1 {v1.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v0.4s,v21.4s - mul v22.4s,v0.4s,v23.4s - mul v24.4s,v0.4s,v25.4s - mul v26.4s,v0.4s,v27.4s - mul v28.4s,v0.4s,v29.4s - mul v30.4s,v0.4s,v31.4s - mul v17.4s,v1.4s,v17.4s - mul v19.4s,v1.4s,v19.4s - mul v21.4s,v1.4s,v21.4s - mul v23.4s,v1.4s,v23.4s - mul v25.4s,v1.4s,v25.4s - mul v27.4s,v1.4s,v27.4s - mul v29.4s,v1.4s,v29.4s - mul v31.4s,v1.4s,v31.4s - - // preloading mixed with accumulator inits - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v16.4s - add v18.4s,v3.4s,v18.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v20.4s - add v22.4s,v3.4s,v22.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v24.4s - add v26.4s,v3.4s,v26.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v28.4s - add v30.4s,v3.4s,v30.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - b ComputeBlockLoopM8 - -SkipScaleByZeroPointBM8 - // accumulator = row sum A + column sum B - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v17.4s - add v18.4s,v3.4s,v19.4s - ldr 
q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v21.4s - add v22.4s,v3.4s,v23.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v25.4s - add v26.4s,v3.4s,v27.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v29.4s - add v30.4s,v3.4s,v31.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - -ComputeBlockLoopM8 - sub x3,x3,#1 - ld1 {v3.16b},[x1],#16 // load packed B1_next4k - SdotByElement 16, 0, 4, 0 - SdotByElement 18, 0, 4, 1 - ldr q7,[x0],#16 // load packed A3 - SdotByElement 20, 0, 4, 2 - SdotByElement 22, 0, 4, 3 - cbz x3,ComputeBlockLoopFinishM8 - SdotByElement 17, 1, 4, 0 - SdotByElement 19, 1, 4, 1 - SdotByElement 21, 1, 4, 2 - SdotByElement 23, 1, 4, 3 - ldr q4,[x0],#16 // load packed A0 for next iteration - SdotByElement 24, 0, 5, 0 - SdotByElement 26, 0, 5, 1 - SdotByElement 28, 0, 5, 2 - SdotByElement 30, 0, 5, 3 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - SdotByElement 25, 1, 5, 0 - SdotByElement 27, 1, 5, 1 - SdotByElement 29, 1, 5, 2 - SdotByElement 31, 1, 5, 3 - ld1 {v1.16b},[x1],#16 // load packed B1 for next iteration - - SdotByElement 16, 2, 6, 0 - SdotByElement 18, 2, 6, 1 - ldr q5,[x0],#16 // load packed A1 for next iteration - SdotByElement 20, 2, 6, 2 - SdotByElement 22, 2, 6, 3 - SdotByElement 17, 3, 6, 0 - SdotByElement 19, 3, 6, 1 - SdotByElement 21, 3, 6, 2 - SdotByElement 23, 3, 6, 3 - ldr q6,[x0],#16 // load packed A2 for next iteration - SdotByElement 24, 2, 7, 0 - SdotByElement 26, 2, 7, 1 - SdotByElement 28, 2, 7, 2 - SdotByElement 30, 2, 7, 3 - ld1 {v2.16b},[x1],#16 // load packed B0_next4k for next iteration - SdotByElement 25, 3, 7, 0 - SdotByElement 27, 3, 7, 1 - SdotByElement 29, 3, 7, 2 - SdotByElement 31, 3, 7, 3 - b ComputeBlockLoopM8 - -ComputeBlockLoopFinishM8 - // postfix, compute tail values and prepare to write results - // We are either about to go to ProcessNextColumnLoopM8 - // where x0 and x3 are about to be restored, or exit - // when x0 and x3 will not be used. 
- // x4 x7 has finished their task - // so we can use x0 x3 x4 x7 as output row pointers - - SdotByElement 17, 1, 4, 0 - SdotByElement 19, 1, 4, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add x11,x10,x6,lsl #2 // compute output row 3 - SdotByElement 21, 1, 4, 2 - SdotByElement 23, 1, 4, 3 - add x12,x11,x6,lsl #2 // compute output row 4 - add x0,x12,x6,lsl #2 // compute output row 5 - SdotByElement 24, 0, 5, 0 - SdotByElement 26, 0, 5, 1 - add x3,x0,x6,lsl #2 // compute output row 6 - add x4,x3,x6,lsl #2 // compute output row 7 - SdotByElement 28, 0, 5, 2 - SdotByElement 30, 0, 5, 3 - add x7,x4,x6,lsl #2 // compute output row 8 - subs x5,x5,#8 // adjust CountN remaining - SdotByElement 25, 1, 5, 0 - SdotByElement 27, 1, 5, 1 - SdotByElement 29, 1, 5, 2 - SdotByElement 31, 1, 5, 3 - SdotByElement 16, 2, 6, 0 - SdotByElement 18, 2, 6, 1 - SdotByElement 20, 2, 6, 2 - SdotByElement 22, 2, 6, 3 - SdotByElement 17, 3, 6, 0 - SdotByElement 19, 3, 6, 1 - SdotByElement 21, 3, 6, 2 - SdotByElement 23, 3, 6, 3 - SdotByElement 24, 2, 7, 0 - SdotByElement 26, 2, 7, 1 - SdotByElement 28, 2, 7, 2 - SdotByElement 30, 2, 7, 3 - SdotByElement 25, 3, 7, 0 - SdotByElement 27, 3, 7, 1 - SdotByElement 29, 3, 7, 2 - SdotByElement 31, 3, 7, 3 - blo StoreOutputPartialM8 - cbnz x13,SkipAccumulateOutputM8 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - ldp q0, q1, [x0] - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - ldp q2, q3, [x3] - add v24.4s,v24.4s,v0.4s - add v25.4s,v25.4s,v1.4s - ldp q4, q5, [x4] - add v26.4s,v26.4s,v2.4s - add v27.4s,v27.4s,v3.4s - ldp q6, q7, [x7] - add v28.4s,v28.4s,v4.4s - add v29.4s,v29.4s,v5.4s - add v30.4s,v30.4s,v6.4s - add v31.4s,v31.4s,v7.4s - -SkipAccumulateOutputM8 - stp q16,q17,[x2],#32 - dup v17.4s,v8.s[0] // broadcast row sums - stp q18,q19,[x10] - dup v19.4s,v8.s[1] - stp q20,q21,[x11] - dup v21.4s,v8.s[2] - stp q22,q23,[x12] - dup v23.4s,v8.s[3] - stp q24,q25,[x0] - dup v25.4s,v9.s[0] - stp q26,q27,[x3] - dup v27.4s,v9.s[1] - stp q28,q29,[x4] - dup v29.4s,v9.s[2] - stp q30,q31,[x7] - dup v31.4s,v9.s[3] - cbnz x5,ProcessNextColumnLoopM8 - -ExitKernelM8 - mov x0,#8 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
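The StoreOutputPartial* blocks that follow implement this by testing the bits of the remaining column count (1 to 7) and emitting a 4-, then 2-, then 1-element store, shifting the surviving vector lanes down after each chunk. A hedged scalar sketch of that tail-store pattern for one output row:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Hedged sketch: 'acc8' holds the eight int32 accumulators for one output row,
// 'remaining' is the leftover column count (1..7), and zero_mode selects
// overwrite versus accumulate, mirroring the ZeroMode flag.
void StorePartialRow(int32_t* dst, const int32_t* acc8, size_t remaining,
                     bool zero_mode) {
    int32_t tmp[8];
    std::memcpy(tmp, acc8, sizeof(tmp));
    size_t offset = 0;
    for (size_t chunk = 4; chunk >= 1; chunk /= 2) {   // test bits 2, 1, 0 of 'remaining'
        if (remaining & chunk) {
            for (size_t i = 0; i < chunk; ++i) {
                dst[offset + i] = zero_mode ? tmp[i] : dst[offset + i] + tmp[i];
            }
            offset += chunk;
            // Shift the surviving lanes down, as the mov/dup register shuffles do.
            std::memmove(tmp, tmp + chunk, (8 - chunk) * sizeof(int32_t));
        }
    }
}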
-// - -StoreOutputPartialM8 - cbz x13,StoreOutputPartialAddModeM8 - -StoreOutputPartialZeroModeM8 - tbz x5,#2,StoreOutputPartial2ZeroModeM8 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -StoreOutputPartial2ZeroModeM8 - tbz x5,#1,StoreOutputPartial1ZeroModeM8 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -StoreOutputPartial1ZeroModeM8 - tbz x5,#0,ExitKernelM8 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - st1 {v24.s}[0],[x0] - st1 {v26.s}[0],[x3] - st1 {v28.s}[0],[x4] - st1 {v30.s}[0],[x7] - b ExitKernelM8 - -StoreOutputPartialAddModeM8 - tbz x5,#2,StoreOutputPartial2AddModeM8 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - ld1 {v4.4s},[x0] - ld1 {v5.4s},[x3] - ld1 {v6.4s},[x4] - ld1 {v7.4s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -StoreOutputPartial2AddModeM8 - tbz x5,#1,StoreOutputPartial1AddModeM8 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - ld1 {v4.2s},[x0] - ld1 {v5.2s},[x3] - ld1 {v6.2s},[x4] - ld1 {v7.2s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -StoreOutputPartial1AddModeM8 - tbz x5,#0,ExitKernelM8 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - ld1 {v4.s}[0],[x0] - ld1 {v5.s}[0],[x3] - ld1 {v6.s}[0],[x4] - ld1 {v7.s}[0],[x7] - add v24.4s,v24.4s,v4.4s - st1 {v24.s}[0],[x0] - add v26.4s,v26.4s,v5.4s - st1 {v26.s}[0],[x3] - add v28.4s,v28.4s,v6.4s - st1 {v28.s}[0],[x4] - add v30.4s,v30.4s,v7.4s - st1 {v30.s}[0],[x7] 
- b ExitKernelM8 - - -// -// Process 4 rows of the matrices. -// -// -// The packing layout is setup to have a pair of four quad vectors from -// packed matrix A and a pair of eight quad vectors from packed matrix B. -// With this scheme, alternating loads from the packed matrices can be -// interleaved with the dot product instructions. -// -// One negative consequence of using four rows here is that the accumulator -// register tile is too small for processors with high out of order execution -// windows (such as the Apple M1). The dot product instructions for a given -// cell are too close to each other to avoid dependencies. To workaround this, -// the below loop uses a pair of accumulator registers that are then added -// together when the loop finishes. -// -// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for -// packed matrix A. At the time of this implementation, using a wider 128-bit -// load didn't affect performance for higher end cores. -// -// B 4x8 block -// /-----------------------------------------| -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// \-----------------------------------------/ -// A 4x4 block -// /---------------------\ /-----------------------------------------| -// |d4.b[0] ... d4.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |d4.b[4] ... d4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |d5.b[0] ... d5.b[3] | |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |d5.b[4] ... d5.b[7] | |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// \---------------------/ \-----------------------------------------/ -// unroll for the next 4 in k dimension -// /-----------------------------------------| -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// \-----------------------------------------/ -// /---------------------\ /-----------------------------------------\ -// |d6.b[0] ... d6.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |d6.b[4] ... d6.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |d7.b[0] ... d7.b[3] | |v28.s[0] .. v24.s[3] v29.s[0] .. v29.s[3]| -// |d7.b[4] ... d7.b[7] | |v30.s[0] .. v24.s[3] v31.s[0] .. 
v31.s[3]| -// \---------------------/ \-----------------------------------------/ - -ProcessNextColumnLoopM4 - ld1 {v0.16b},[x1],#16 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - cbz x9,SkipScaleByZeroPointBM4 - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v30.4s,v17.4s - mul v18.4s,v30.4s,v19.4s - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v30.4s,v21.4s - mul v22.4s,v30.4s,v23.4s - mul v17.4s,v31.4s,v17.4s - mul v19.4s,v31.4s,v19.4s - mul v21.4s,v31.4s,v21.4s - mul v23.4s,v31.4s,v23.4s - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - add v20.4s,v2.4s,v20.4s - add v22.4s,v2.4s,v22.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - b ComputeBlockLoopStartM4 - -SkipScaleByZeroPointBM4 - add v16.4s,v2.4s,v17.4s - add v18.4s,v2.4s,v19.4s - add v20.4s,v2.4s,v21.4s - add v22.4s,v2.4s,v23.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - -ComputeBlockLoopStartM4 - ldr d4,[x0],#32 // load packed A0.l - movi v24.4s,#0 - movi v25.4s,#0 - ldur d5,[x0,#-24] // load packed A0.h - movi v26.4s,#0 - movi v27.4s,#0 - ldur d6,[x0,#-16] // load packed A1.l - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - -ComputeBlockLoopM4 - ld1 {v1.16b},[x1],#16 // load packed B1 - SdotByElement 16, 0, 4, 0 - SdotByElement 18, 0, 4, 1 - ldur d7,[x0,#-8] // load packed A1.h - SdotByElement 20, 0, 5, 0 - SdotByElement 22, 0, 5, 1 - ld1 {v0.16b},[x1],#16 // load packed B0_next4k - SdotByElement 17, 1, 4, 0 - SdotByElement 19, 1, 4, 1 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM4 - ldr d4,[x0],#32 // load packed A0.l for next iteration - SdotByElement 21, 1, 5, 0 - SdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - SdotByElement 24, 0, 6, 0 - SdotByElement 26, 0, 6, 1 - ldur d5,[x0,#-24] // load packed A0.h for next iteration - SdotByElement 28, 0, 7, 0 - SdotByElement 30, 0, 7, 1 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - SdotByElement 25, 1, 6, 0 - SdotByElement 27, 1, 6, 1 - ldur d6,[x0,#-16] // load packed A1.l for next iteration - SdotByElement 29, 1, 7, 0 - SdotByElement 31, 1, 7, 1 - b ComputeBlockLoopM4 - -ComputeBlockLoopFinishM4 - SdotByElement 21, 1, 5, 0 - SdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - SdotByElement 24, 0, 6, 0 - SdotByElement 26, 0, 6, 1 - SdotByElement 28, 0, 7, 0 - SdotByElement 30, 0, 7, 1 - SdotByElement 25, 1, 6, 0 - SdotByElement 27, 1, 6, 1 - SdotByElement 29, 1, 7, 0 - SdotByElement 31, 1, 7, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add v16.4s,v16.4s,v24.4s // fold high results into low results - add v18.4s,v18.4s,v26.4s - add v20.4s,v20.4s,v28.4s - add v22.4s,v22.4s,v30.4s - add x11,x10,x6,lsl #2 // compute output row 3 - add v17.4s,v17.4s,v25.4s - add v19.4s,v19.4s,v27.4s - add v21.4s,v21.4s,v29.4s - add v23.4s,v23.4s,v31.4s - add x12,x11,x6,lsl #2 // compute output row 4 - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM4 - cbnz x13,SkipAccumulateOutputM4 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add 
v21.4s,v21.4s,v5.4s - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - -SkipAccumulateOutputM4 - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - stp q20,q21,[x11] - stp q22,q23,[x12] - cbnz x5,ProcessNextColumnLoopM4 - -ExitKernelM4 - mov x0,#4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -StoreOutputPartialM4 - cbz x13,StoreOutputPartialAddModeM4 - -StoreOutputPartialZeroModeM4 - tbz x5,#2,StoreOutputPartial2ZeroModeM4 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -StoreOutputPartial2ZeroModeM4 - tbz x5,#1,StoreOutputPartial1ZeroModeM4 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -StoreOutputPartial1ZeroModeM4 - tbz x5,#0,ExitKernelM4 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b ExitKernelM4 - -StoreOutputPartialAddModeM4 - tbz x5,#2,StoreOutputPartial2AddModeM4 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -StoreOutputPartial2AddModeM4 - tbz x5,#1,StoreOutputPartial1AddModeM4 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -StoreOutputPartial1AddModeM4 - tbz x5,#0,ExitKernelM4 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b ExitKernelM4 - -// -// Process 2 rows of the matrices. 
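Before the two-row path that follows, note the workaround described above for the four-row SDOT loop: each output cell is fed by two independent accumulator registers (v16..v23 and v24..v31) that are folded together only after the K loop, which breaks the dependency chain on wide out-of-order cores. A hedged scalar illustration of the same splitting idea:

#include <cstddef>
#include <cstdint>

// Hedged illustration of splitting one dot product across two independent
// accumulators and folding them after the loop, so consecutive multiply-adds
// do not form a single dependency chain.
int32_t DotProductSplitAccumulators(const int8_t* a, const int8_t* b, size_t k) {
    int32_t acc_lo = 0;   // one chain (the kernel's v16..v23)
    int32_t acc_hi = 0;   // an independent chain (the kernel's v24..v31)
    size_t i = 0;
    for (; i + 8 <= k; i += 8) {
        for (size_t j = 0; j < 4; ++j) acc_lo += int32_t(a[i + j]) * int32_t(b[i + j]);
        for (size_t j = 4; j < 8; ++j) acc_hi += int32_t(a[i + j]) * int32_t(b[i + j]);
    }
    for (; i < k; ++i) acc_lo += int32_t(a[i]) * int32_t(b[i]);   // leftover elements
    return acc_lo + acc_hi;   // fold the halves once, after the loop
}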
-// -ProcessLoopM2 - dup v9.4s, v8.s[1] - dup v8.4s, v8.s[0] - -ProcessNextColumnLoopM2 - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,SkipScaleByZeroPointBM2 - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v16.4s,v30.4s,v8.4s - mul v18.4s,v30.4s,v9.4s - mul v17.4s,v31.4s,v8.4s - mul v19.4s,v31.4s,v9.4s - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - b ComputeBlockLoopM2 - -SkipScaleByZeroPointBM2 - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v8.4s - add v18.4s,v2.4s,v9.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v8.4s - add v19.4s,v3.4s,v9.4s - -ComputeBlockLoopM2 - sub x3,x3,#1 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - SdotByElement 16, 0, 4, 0 - SdotByElement 17, 1, 4, 0 - SdotByElement 18, 0, 4, 1 - SdotByElement 19, 1, 4, 1 - cbz x3,ComputeBlockLoopFinishM2 - ldr d4,[x0],#8 // load packed A0.l for next iter - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - SdotByElement 16, 6, 5, 0 - SdotByElement 17, 7, 5, 0 - SdotByElement 18, 6, 5, 1 - SdotByElement 19, 7, 5, 1 - ldr d5,[x0],#8 // load packed A0.h for next iter - b ComputeBlockLoopM2 - -ComputeBlockLoopFinishM2 - add x10,x2,x6,lsl #2 // compute output row 2 - subs x5,x5,#8 // adjust CountN remaining - SdotByElement 16, 6, 5, 0 - SdotByElement 17, 7, 5, 0 - SdotByElement 18, 6, 5, 1 - SdotByElement 19, 7, 5, 1 - blo StoreOutputPartialM2 - cbnz x13,SkipAccumulateOutputM2 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - -SkipAccumulateOutputM2 - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - cbnz x5,ProcessNextColumnLoopM2 - -ExitKernelM2 - mov x0,#2 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM2 - cbz x13,StoreOutputPartialAddModeM2 - -StoreOutputPartialZeroModeM2 - tbz x5,#2,StoreOutputPartial2ZeroModeM2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -StoreOutputPartial2ZeroModeM2 - tbz x5,#1,StoreOutputPartial1ZeroModeM2 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -StoreOutputPartial1ZeroModeM2 - tbz x5,#0,ExitKernelM2 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b ExitKernelM2 - -StoreOutputPartialAddModeM2 - tbz x5,#2,StoreOutputPartial2AddModeM2 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -StoreOutputPartial2AddModeM2 - tbz x5,#1,StoreOutputPartial1AddModeM2 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -StoreOutputPartial1AddModeM2 - tbz x5,#0,ExitKernelM2 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b ExitKernelM2 - -// -// Process 1 row of the matrices. -// - -ProcessLoopM1 - dup v8.4s,v8.s[0] -ProcessNextColumnLoopM1 - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - cbz x9,SkipScaleByZeroPointBM1 - ld1 {v30.4s},[x9],#16 // load ZeroPointB0 - ld1 {v31.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v30.4s,v8.4s - mul v17.4s,v31.4s,v8.4s - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - b ComputeBlockLoopM1 - -SkipScaleByZeroPointBM1 - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v8.4s - add v17.4s,v3.4s,v8.4s - -ComputeBlockLoopM1 - sub x3,x3,#1 - SdotByElement 16, 0, 4, 0 - SdotByElement 17, 1, 4, 0 - cbz x3,ComputeBlockLoopFinishM1 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - SdotByElement 16, 6, 4, 1 - SdotByElement 17, 7, 4, 1 - ldr d4,[x0],#8 // load packed A0 for next iter - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k for next iter - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k for next iter - b ComputeBlockLoopM1 - -ComputeBlockLoopFinishM1 - subs x5,x5,#8 // adjust CountN remaining - SdotByElement 16, 6, 4, 1 - SdotByElement 17, 7, 4, 1 - blo StoreOutputPartialM1 - cbnz x13,SkipAccumulateOutputM1 - ldp q0,q1,[x2] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - -SkipAccumulateOutputM1 - stp q16,q17,[x2],#32 - cbnz x5,ProcessNextColumnLoopM1 - -ExitKernelM1 - mov x0,#1 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM1 - cbz x13,StoreOutputPartialAddModeM1 - -StoreOutputPartialZeroModeM1 - tbz x5,#2,StoreOutputPartial2ZeroModeM1 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -StoreOutputPartial2ZeroModeM1 - tbz x5,#1,StoreOutputPartial1ZeroModeM1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -StoreOutputPartial1ZeroModeM1 - tbz x5,#0,ExitKernelM1 - st1 {v16.s}[0],[x2] - b ExitKernelM1 - -StoreOutputPartialAddModeM1 - tbz x5,#2,StoreOutputPartial2AddModeM1 - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -StoreOutputPartial2AddModeM1 - tbz x5,#1,StoreOutputPartial1AddModeM1 - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -StoreOutputPartial1AddModeM1 - tbz x5,#0,ExitKernelM1 - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b ExitKernelM1 - - NESTED_END MlasGemmS8S8KernelSDot - - END diff --git a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm deleted file mode 100644 index 8a335517c6295..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelNeon.asm +++ /dev/null @@ -1,608 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelNeon.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the U8X8 kernel. -// - -#define GemmU8X8KernelFrame_ColumnSumBuffer 0 -#define GemmU8X8KernelFrame_ZeroPointB 8 -#define GemmU8X8KernelFrame_ZeroMode 16 - -// -// Define instruction aliases not implemented by ARMASM64. -// - - MACRO - uxtl $DestReg, $SrcReg - - ushll $DestReg.,$SrcReg.,#0 - - MEND - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackANeon. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBNeon. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A multiplied by - the zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. 
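Putting the argument list above into scalar form may help when reading the vectorized body that follows: each accumulator starts from the row and column sum corrections (with the row sums rescaled by per-column ZeroPointB values when those are supplied, as the SkipScaleByZeroPointB branches show), packed A and B are then multiply-accumulated over the K dimension, and the result is written to C either overwriting or adding according to ZeroMode. A rough C++ reference of those semantics, ignoring the packed data layout (QGemmRef and its parameter names are illustrative only):

#include <cstddef>
#include <cstdint>

// Scalar sketch of the quantized GEMM kernel contract.
// a:          countM x k bytes of matrix A (packing ignored)
// b:          k x countN bytes of matrix B (packing ignored)
// rowSums:    per-row correction; already includes B's zero point when zeroPointB is nullptr
// colSums:    per-column correction
// zeroPointB: optional per-column zero points of B
static void QGemmRef(const uint8_t* a, const uint8_t* b, int32_t* c,
                     size_t k, size_t countM, size_t countN, size_t ldc,
                     const int32_t* rowSums, const int32_t* colSums,
                     const int32_t* zeroPointB, bool zeroMode) {
    for (size_t m = 0; m < countM; ++m) {
        for (size_t n = 0; n < countN; ++n) {
            int32_t acc = colSums[n] +
                          (zeroPointB ? zeroPointB[n] * rowSums[m] : rowSums[m]);
            for (size_t i = 0; i < k; ++i) {
                acc += int32_t(a[m * k + i]) * int32_t(b[i * countN + n]);
            }
            c[m * ldc + n] = zeroMode ? acc : c[m * ldc + n] + acc;
        }
    }
}

The kernel itself only handles 1, 2, or 4 rows per call (8 in the dot-product variant) and reports how many it processed, so a caller advances A and C by the returned row count and invokes it again.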
- ---*/ - - LEAF_ENTRY MlasGemmU8X8KernelNeon - - ldr x8,[sp,#GemmU8X8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#GemmU8X8KernelFrame_ZeroPointB] - ldrb w13,[sp,#GemmU8X8KernelFrame_ZeroMode] - mov x14,x0 - ld1 {v27.4s},[x7] - mov x15,x3 - dup v24.4s,v27.s[0] // broadcast row fixups - cmp x4,#1 // CountM == 1? - beq ProcessNextColumnLoopM1 - dup v25.4s,v27.s[1] - cmp x4,#4 // CountM < 4? - blo ProcessNextColumnLoopM2 - dup v26.4s,v27.s[2] - dup v27.4s,v27.s[3] - -// -// Process 4 rows of the matrices. -// - -ProcessNextColumnLoopM4 - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - cbz x9,SkipScaleByZeroPointBM4 - ld1 {v28.4s},[x9],#16 // load ZeroPointB0 - ld1 {v29.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v24.4s,v28.4s - mul v17.4s,v24.4s,v29.4s - mul v18.4s,v25.4s,v28.4s - mul v19.4s,v25.4s,v29.4s - mul v20.4s,v26.4s,v28.4s - mul v21.4s,v26.4s,v29.4s - mul v22.4s,v27.4s,v28.4s - mul v23.4s,v27.4s,v29.4s - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - add v18.4s,v2.4s,v18.4s - add v19.4s,v3.4s,v19.4s - ld1 {v5.8b},[x0],#8 // load first packed A1 - add v20.4s,v2.4s,v20.4s - add v21.4s,v3.4s,v21.4s - add v22.4s,v2.4s,v22.4s - add v23.4s,v3.4s,v23.4s - b ComputeBlockLoopM4 - -SkipScaleByZeroPointBM4 - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v24.4s - add v17.4s,v3.4s,v24.4s - add v18.4s,v2.4s,v25.4s - add v19.4s,v3.4s,v25.4s - ld1 {v5.8b},[x0],#8 // load first packed A1 - add v20.4s,v2.4s,v26.4s - add v21.4s,v3.4s,v26.4s - add v22.4s,v2.4s,v27.4s - add v23.4s,v3.4s,v27.4s - -ComputeBlockLoopM4 - uxtl v2.8h,v4.8b - uxtl v3.8h,v5.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v16.4s,v0.4h,v2.h[0] - umlal2 v17.4s,v0.8h,v2.h[0] - umlal v18.4s,v0.4h,v2.h[4] - umlal2 v19.4s,v0.8h,v2.h[4] - uxtl v1.8h,v1.8b - umlal v20.4s,v0.4h,v3.h[0] - umlal2 v21.4s,v0.8h,v3.h[0] - umlal v22.4s,v0.4h,v3.h[4] - umlal2 v23.4s,v0.8h,v3.h[4] - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v16.4s,v1.4h,v2.h[1] - umlal2 v17.4s,v1.8h,v2.h[1] - umlal v18.4s,v1.4h,v2.h[5] - umlal2 v19.4s,v1.8h,v2.h[5] - uxtl v0.8h,v0.8b - umlal v20.4s,v1.4h,v3.h[1] - umlal2 v21.4s,v1.8h,v3.h[1] - umlal v22.4s,v1.4h,v3.h[5] - umlal2 v23.4s,v1.8h,v3.h[5] - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM4 - umlal v16.4s,v0.4h,v2.h[2] - umlal2 v17.4s,v0.8h,v2.h[2] - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - ld1 {v4.8b},[x0],#8 // load next packed A0 - umlal v20.4s,v0.4h,v3.h[2] - umlal2 v21.4s,v0.8h,v3.h[2] - umlal v22.4s,v0.4h,v3.h[6] - umlal2 v23.4s,v0.8h,v3.h[6] - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - uxtl v0.8h,v0.8b - ld1 {v5.8b},[x0],#8 // load next packed A1 - umlal v20.4s,v1.4h,v3.h[3] - umlal2 v21.4s,v1.8h,v3.h[3] - umlal v22.4s,v1.4h,v3.h[7] - umlal2 v23.4s,v1.8h,v3.h[7] - b ComputeBlockLoopM4 - -ComputeBlockLoopFinishM4 - umlal v16.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v17.4s,v0.8h,v2.h[2] - add x10,x2,x6,lsl #2 // compute output row 2 - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - umlal v20.4s,v0.4h,v3.h[2] - umlal2 v21.4s,v0.8h,v3.h[2] - umlal v22.4s,v0.4h,v3.h[6] - umlal2 v23.4s,v0.8h,v3.h[6] - add x11,x10,x6,lsl #2 // compute output 
row 3 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - umlal v20.4s,v1.4h,v3.h[3] - umlal2 v21.4s,v1.8h,v3.h[3] - add x12,x11,x6,lsl #2 // compute output row 4 - umlal v22.4s,v1.4h,v3.h[7] - umlal2 v23.4s,v1.8h,v3.h[7] - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM4 - cbnz x13,SkipAccumulateOutputM4 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - -SkipAccumulateOutputM4 - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - stp q20,q21,[x11] - stp q22,q23,[x12] - cbnz x5,ProcessNextColumnLoopM4 - -ExitKernelM4 - mov x0,#4 // return number of rows handled - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -StoreOutputPartialM4 - cbz x13,StoreOutputPartialAddModeM4 - -StoreOutputPartialZeroModeM4 - tbz x5,#2,StoreOutputPartial2ZeroModeM4 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -StoreOutputPartial2ZeroModeM4 - tbz x5,#1,StoreOutputPartial1ZeroModeM4 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -StoreOutputPartial1ZeroModeM4 - tbz x5,#0,ExitKernelM4 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b ExitKernelM4 - -StoreOutputPartialAddModeM4 - tbz x5,#2,StoreOutputPartial2AddModeM4 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -StoreOutputPartial2AddModeM4 - tbz x5,#1,StoreOutputPartial1AddModeM4 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -StoreOutputPartial1AddModeM4 - tbz x5,#0,ExitKernelM4 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b ExitKernelM4 - -// -// Process 2 rows of the matrices. 
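Before the two-row variant, note how the compute loop above supports cores without the dot-product extension: every packed 8-bit value is widened to 16 bits with uxtl, and umlal/umlal2 accumulate the widened products into 32-bit lanes, consuming four k values per iteration. A scalar C++ sketch of one such 4-deep step (WidenMacStep and its layout are illustrative):

#include <cstdint>

// One 4-deep k step of the widening multiply-accumulate path.
// a holds 4 consecutive k values from one row of packed A;
// b holds the matching 4 x 8 block of packed B (4 k values x 8 columns).
static void WidenMacStep(int32_t (&acc)[8], const uint8_t (&a)[4], const uint8_t (&b)[4][8]) {
    for (int k = 0; k < 4; ++k) {
        uint16_t widenedA = a[k];                            // uxtl: u8 -> u16
        for (int n = 0; n < 8; ++n) {
            acc[n] += int32_t(widenedA) * int32_t(b[k][n]);  // umlal/umlal2 into 32-bit lanes
        }
    }
}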
-// - -ProcessNextColumnLoopM2 - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - cbz x9,SkipScaleByZeroPointBM2 - ld1 {v28.4s},[x9],#16 // load ZeroPointB0 - ld1 {v29.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v24.4s,v28.4s - mul v17.4s,v24.4s,v29.4s - mul v18.4s,v25.4s,v28.4s - mul v19.4s,v25.4s,v29.4s - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - add v18.4s,v2.4s,v18.4s - add v19.4s,v3.4s,v19.4s - b ComputeBlockLoopM2 - -SkipScaleByZeroPointBM2 - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v16.4s,v2.4s,v24.4s - add v17.4s,v3.4s,v24.4s - add v18.4s,v2.4s,v25.4s - add v19.4s,v3.4s,v25.4s - -ComputeBlockLoopM2 - uxtl v2.8h,v4.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v16.4s,v0.4h,v2.h[0] - umlal2 v17.4s,v0.8h,v2.h[0] - umlal v18.4s,v0.4h,v2.h[4] - umlal2 v19.4s,v0.8h,v2.h[4] - uxtl v1.8h,v1.8b - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v16.4s,v1.4h,v2.h[1] - umlal2 v17.4s,v1.8h,v2.h[1] - umlal v18.4s,v1.4h,v2.h[5] - umlal2 v19.4s,v1.8h,v2.h[5] - uxtl v0.8h,v0.8b - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM2 - umlal v16.4s,v0.4h,v2.h[2] - umlal2 v17.4s,v0.8h,v2.h[2] - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - ld1 {v4.8b},[x0],#8 // load next packed A0 - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - uxtl v0.8h,v0.8b - b ComputeBlockLoopM2 - -ComputeBlockLoopFinishM2 - umlal v16.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v17.4s,v0.8h,v2.h[2] - add x10,x2,x6,lsl #2 // compute output row 2 - umlal v18.4s,v0.4h,v2.h[6] - umlal2 v19.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - umlal v18.4s,v1.4h,v2.h[7] - umlal2 v19.4s,v1.8h,v2.h[7] - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM2 - cbnz x13,SkipAccumulateOutputM2 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - -SkipAccumulateOutputM2 - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - cbnz x5,ProcessNextColumnLoopM2 - -ExitKernelM2 - mov x0,#2 // return number of rows handled - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM2 - cbz x13,StoreOutputPartialAddModeM2 - -StoreOutputPartialZeroModeM2 - tbz x5,#2,StoreOutputPartial2ZeroModeM2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -StoreOutputPartial2ZeroModeM2 - tbz x5,#1,StoreOutputPartial1ZeroModeM2 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -StoreOutputPartial1ZeroModeM2 - tbz x5,#0,ExitKernelM2 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b ExitKernelM2 - -StoreOutputPartialAddModeM2 - tbz x5,#2,StoreOutputPartial2AddModeM2 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -StoreOutputPartial2AddModeM2 - tbz x5,#1,StoreOutputPartial1AddModeM2 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -StoreOutputPartial1AddModeM2 - tbz x5,#0,ExitKernelM2 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b ExitKernelM2 - -// -// Process 1 row of the matrices. -// - -ProcessNextColumnLoopM1 - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - cbz x9,SkipScaleByZeroPointBM1 - ld1 {v28.4s},[x9],#16 // load ZeroPointB0 - ld1 {v29.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v24.4s,v28.4s - mul v17.4s,v24.4s,v29.4s - ldr s4,[x0],#4 // load first packed A0 - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - b ComputeBlockLoopM1 - -SkipScaleByZeroPointBM1 - ldr s4,[x0],#4 // load first packed A0 - add v16.4s,v2.4s,v24.4s - add v17.4s,v3.4s,v24.4s - -ComputeBlockLoopM1 - uxtl v2.8h,v4.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v16.4s,v0.4h,v2.h[0] - umlal2 v17.4s,v0.8h,v2.h[0] - uxtl v1.8h,v1.8b - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v16.4s,v1.4h,v2.h[1] - umlal2 v17.4s,v1.8h,v2.h[1] - uxtl v0.8h,v0.8b - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM1 - umlal v16.4s,v0.4h,v2.h[2] - umlal2 v17.4s,v0.8h,v2.h[2] - uxtl v1.8h,v1.8b - ldr s4,[x0],#4 // load first packed A0 - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - uxtl v0.8h,v0.8b - b ComputeBlockLoopM1 - -ComputeBlockLoopFinishM1 - umlal v16.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v17.4s,v0.8h,v2.h[2] - uxtl v1.8h,v1.8b - umlal v16.4s,v1.4h,v2.h[3] - umlal2 v17.4s,v1.8h,v2.h[3] - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM1 - cbnz x13,SkipAccumulateOutputM1 - ldp q0,q1,[x2] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - -SkipAccumulateOutputM1 - stp q16,q17,[x2],#32 - cbnz x5,ProcessNextColumnLoopM1 - -ExitKernelM1 - mov x0,#1 // return number of rows handled - ret - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM1 - cbz x13,StoreOutputPartialAddModeM1 - -StoreOutputPartialZeroModeM1 - tbz x5,#2,StoreOutputPartial2ZeroModeM1 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -StoreOutputPartial2ZeroModeM1 - tbz x5,#1,StoreOutputPartial1ZeroModeM1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -StoreOutputPartial1ZeroModeM1 - tbz x5,#0,ExitKernelM1 - st1 {v16.s}[0],[x2] - b ExitKernelM1 - -StoreOutputPartialAddModeM1 - tbz x5,#2,StoreOutputPartial2AddModeM1 - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -StoreOutputPartial2AddModeM1 - tbz x5,#1,StoreOutputPartial1AddModeM1 - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -StoreOutputPartial1AddModeM1 - tbz x5,#0,ExitKernelM1 - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b ExitKernelM1 - - LEAF_END MlasGemmU8X8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm b/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm deleted file mode 100644 index 372ade9e876bc..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/QgemmU8X8KernelUdot.asm +++ /dev/null @@ -1,1054 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelUdot.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses ARM v8.4 dot product instructions. - ---*/ - -#include "kxarm64.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout for the U8X8 kernel. -// Defining spaces for saving 2 vector registers, and pointers to parameters -// on the stack -// - -#define GemmU8X8KernelFrame_SavedNeonRegisters (2 * 8) -#define GemmU8X8KernelFrame_SavedRegisters GemmU8X8KernelFrame_SavedNeonRegisters -#define GemmU8X8KernelFrame_ColumnSumBuffer (0 + GemmU8X8KernelFrame_SavedRegisters) -#define GemmU8X8KernelFrame_ZeroPointB (8 + GemmU8X8KernelFrame_SavedRegisters) -#define GemmU8X8KernelFrame_ZeroMode (16 + GemmU8X8KernelFrame_SavedRegisters) - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmQuantCopyPackA. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values - have been pre-scaled by the zero point offset of matrix B if the offset - is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. 
- - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - NESTED_ENTRY MlasGemmU8X8KernelUdot - - PROLOG_SAVE_REG_PAIR d8,d9,#-16! - ldr x8,[sp,#GemmU8X8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#GemmU8X8KernelFrame_ZeroPointB] - ldrb w13,[sp,#GemmU8X8KernelFrame_ZeroMode] - mov x14,x0 - ld1 {v8.4s},[x7],#16 // load row sum 0 ~ 4 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq ProcessLoopM1 - cmp x4,#4 // CountM < 4? - blo ProcessLoopM2 - cmp x4,#8 // CountM < 8? - blo ProcessNextColumnLoopM4 - ld1 {v9.4s},[x7] // load row sum 5 ~ 8 - -// -// Process 8 rows of the matrices. -// Row Sums: v8 ~ v9 -// A 4x8 block -// /-----------------------------------------| -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// \-----------------------------------------/ -// B 8x4 block -// /---------------------\ /-----------------------------------------| -// |v4.b[0] ... v4.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |v4.b[4] ... v4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v4.b[8] ... v4.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v4.b[12] ... v4.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v5.b[0] ... v5.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v5.b[4] ... v5.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v5.b[8] ... v5.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v5.b[12] ... v5.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. v31.s[3]| -// \---------------------/ \-----------------------------------------/ -// -// unroll for the next 4 in k dimension -// /-----------------------------------------| -// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| -// | ... ... | -// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| -// \-----------------------------------------/ -// /---------------------\ /-----------------------------------------\ -// |v6.b[0] ... v6.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |v6.b[4] ... v6.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |v6.b[8] ... v6.b[11]| |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |v6.b[12] ... v6.b[15]| |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// |v7.b[0] ... v7.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |v7.b[4] ... v7.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |v7.b[8] ... v7.b[11]| |v28.s[0] .. v28.s[3] v29.s[0] .. v29.s[3]| -// |v7.b[12] ... v7.b[15]| |v30.s[0] .. v30.s[3] v31.s[0] .. 
v31.s[3]| -// \---------------------/ \-----------------------------------------/ - -// Starting the loop: initialize accumulators with scaled combination -// of row and column sums - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - dup v25.4s,v9.s[0] - dup v27.4s,v9.s[1] - dup v29.4s,v9.s[2] - dup v31.4s,v9.s[3] - -ProcessNextColumnLoopM8 - mov x0,x14 // reload matrix A - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v7.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,SkipScaleByZeroPointBM8 - - // accumulator = zero point B * row sum A + column sum B - ld1 {v0.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v0.4s,v17.4s - mul v18.4s,v0.4s,v19.4s - ld1 {v1.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v0.4s,v21.4s - mul v22.4s,v0.4s,v23.4s - mul v24.4s,v0.4s,v25.4s - mul v26.4s,v0.4s,v27.4s - mul v28.4s,v0.4s,v29.4s - mul v30.4s,v0.4s,v31.4s - mul v17.4s,v1.4s,v17.4s - mul v19.4s,v1.4s,v19.4s - mul v21.4s,v1.4s,v21.4s - mul v23.4s,v1.4s,v23.4s - mul v25.4s,v1.4s,v25.4s - mul v27.4s,v1.4s,v27.4s - mul v29.4s,v1.4s,v29.4s - mul v31.4s,v1.4s,v31.4s - - // preloading mixed with accumulator inits - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v16.4s - add v18.4s,v3.4s,v18.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v20.4s - add v22.4s,v3.4s,v22.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v24.4s - add v26.4s,v3.4s,v26.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v28.4s - add v30.4s,v3.4s,v30.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - b ComputeBlockLoopM8 - -SkipScaleByZeroPointBM8 - // accumulator = row sum A + column sum B - ld1 {v0.16b},[x1],#16 // load packed B0 - add v16.4s,v3.4s,v17.4s - add v18.4s,v3.4s,v19.4s - ldr q4,[x0],#16 // load packed A0 - add v20.4s,v3.4s,v21.4s - add v22.4s,v3.4s,v23.4s - ldr q5,[x0],#16 // load packed A1 - add v24.4s,v3.4s,v25.4s - add v26.4s,v3.4s,v27.4s - ld1 {v1.16b},[x1],#16 // load packed B1 - add v28.4s,v3.4s,v29.4s - add v30.4s,v3.4s,v31.4s - ldr q6,[x0],#16 // load packed A2 - add v17.4s,v7.4s,v17.4s - add v19.4s,v7.4s,v19.4s - ld1 {v2.16b},[x1],#16 // load packed B0_next4k - add v21.4s,v7.4s,v21.4s - add v23.4s,v7.4s,v23.4s - add v25.4s,v7.4s,v25.4s - add v27.4s,v7.4s,v27.4s - add v29.4s,v7.4s,v29.4s - add v31.4s,v7.4s,v31.4s - -ComputeBlockLoopM8 - sub x3,x3,#1 - ld1 {v3.16b},[x1],#16 // load packed B1_next4k - UdotByElement 16, 0, 4, 0 - UdotByElement 18, 0, 4, 1 - ldr q7,[x0],#16 // load packed A3 - UdotByElement 20, 0, 4, 2 - UdotByElement 22, 0, 4, 3 - cbz x3,ComputeBlockLoopFinishM8 - UdotByElement 17, 1, 4, 0 - UdotByElement 19, 1, 4, 1 - UdotByElement 21, 1, 4, 2 - UdotByElement 23, 1, 4, 3 - ldr q4,[x0],#16 // load packed A0 for next iteration - UdotByElement 24, 0, 5, 0 - UdotByElement 26, 0, 5, 1 - UdotByElement 28, 0, 5, 2 - UdotByElement 30, 0, 5, 3 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - UdotByElement 25, 1, 5, 0 - UdotByElement 27, 1, 5, 1 - UdotByElement 29, 1, 5, 2 - UdotByElement 31, 1, 5, 3 - ld1 {v1.16b},[x1],#16 // load packed B1 for next iteration - - UdotByElement 16, 2, 6, 0 - UdotByElement 18, 2, 6, 1 - ldr q5,[x0],#16 // load packed A1 for next iteration - UdotByElement 20, 2, 6, 2 - 
UdotByElement 22, 2, 6, 3 - UdotByElement 17, 3, 6, 0 - UdotByElement 19, 3, 6, 1 - UdotByElement 21, 3, 6, 2 - UdotByElement 23, 3, 6, 3 - ldr q6,[x0],#16 // load packed A2 for next iteration - UdotByElement 24, 2, 7, 0 - UdotByElement 26, 2, 7, 1 - UdotByElement 28, 2, 7, 2 - UdotByElement 30, 2, 7, 3 - ld1 {v2.16b},[x1],#16 // load packed B0_next4k for next iteration - UdotByElement 25, 3, 7, 0 - UdotByElement 27, 3, 7, 1 - UdotByElement 29, 3, 7, 2 - UdotByElement 31, 3, 7, 3 - b ComputeBlockLoopM8 - -ComputeBlockLoopFinishM8 - // postfix, compute tail values and prepare to write results - // We are either about to go to ProcessNextColumnLoopM8 - // where x0 and x3 are about to be restored, or exit - // when x0 and x3 will not be used. - // x4 x7 has finished their task - // so we can use x0 x3 x4 x7 as output row pointers - - UdotByElement 17, 1, 4, 0 - UdotByElement 19, 1, 4, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add x11,x10,x6,lsl #2 // compute output row 3 - UdotByElement 21, 1, 4, 2 - UdotByElement 23, 1, 4, 3 - add x12,x11,x6,lsl #2 // compute output row 4 - add x0,x12,x6,lsl #2 // compute output row 5 - UdotByElement 24, 0, 5, 0 - UdotByElement 26, 0, 5, 1 - add x3,x0,x6,lsl #2 // compute output row 6 - add x4,x3,x6,lsl #2 // compute output row 7 - UdotByElement 28, 0, 5, 2 - UdotByElement 30, 0, 5, 3 - add x7,x4,x6,lsl #2 // compute output row 8 - subs x5,x5,#8 // adjust CountN remaining - UdotByElement 25, 1, 5, 0 - UdotByElement 27, 1, 5, 1 - UdotByElement 29, 1, 5, 2 - UdotByElement 31, 1, 5, 3 - UdotByElement 16, 2, 6, 0 - UdotByElement 18, 2, 6, 1 - UdotByElement 20, 2, 6, 2 - UdotByElement 22, 2, 6, 3 - UdotByElement 17, 3, 6, 0 - UdotByElement 19, 3, 6, 1 - UdotByElement 21, 3, 6, 2 - UdotByElement 23, 3, 6, 3 - UdotByElement 24, 2, 7, 0 - UdotByElement 26, 2, 7, 1 - UdotByElement 28, 2, 7, 2 - UdotByElement 30, 2, 7, 3 - UdotByElement 25, 3, 7, 0 - UdotByElement 27, 3, 7, 1 - UdotByElement 29, 3, 7, 2 - UdotByElement 31, 3, 7, 3 - blo StoreOutputPartialM8 - cbnz x13,SkipAccumulateOutputM8 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add v21.4s,v21.4s,v5.4s - ldp q0, q1, [x0] - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - ldp q2, q3, [x3] - add v24.4s,v24.4s,v0.4s - add v25.4s,v25.4s,v1.4s - ldp q4, q5, [x4] - add v26.4s,v26.4s,v2.4s - add v27.4s,v27.4s,v3.4s - ldp q6, q7, [x7] - add v28.4s,v28.4s,v4.4s - add v29.4s,v29.4s,v5.4s - add v30.4s,v30.4s,v6.4s - add v31.4s,v31.4s,v7.4s - -SkipAccumulateOutputM8 - stp q16,q17,[x2],#32 - dup v17.4s,v8.s[0] // broadcast row sums - stp q18,q19,[x10] - dup v19.4s,v8.s[1] - stp q20,q21,[x11] - dup v21.4s,v8.s[2] - stp q22,q23,[x12] - dup v23.4s,v8.s[3] - stp q24,q25,[x0] - dup v25.4s,v9.s[0] - stp q26,q27,[x3] - dup v27.4s,v9.s[1] - stp q28,q29,[x4] - dup v29.4s,v9.s[2] - stp q30,q31,[x7] - dup v31.4s,v9.s[3] - cbnz x5,ProcessNextColumnLoopM8 - -ExitKernelM8 - mov x0,#8 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM8 - cbz x13,StoreOutputPartialAddModeM8 - -StoreOutputPartialZeroModeM8 - tbz x5,#2,StoreOutputPartial2ZeroModeM8 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -StoreOutputPartial2ZeroModeM8 - tbz x5,#1,StoreOutputPartial1ZeroModeM8 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -StoreOutputPartial1ZeroModeM8 - tbz x5,#0,ExitKernelM8 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - st1 {v24.s}[0],[x0] - st1 {v26.s}[0],[x3] - st1 {v28.s}[0],[x4] - st1 {v30.s}[0],[x7] - b ExitKernelM8 - -StoreOutputPartialAddModeM8 - tbz x5,#2,StoreOutputPartial2AddModeM8 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - ld1 {v4.4s},[x0] - ld1 {v5.4s},[x3] - ld1 {v6.4s},[x4] - ld1 {v7.4s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.4s},[x0],#16 - mov v24.16b,v25.16b - st1 {v26.4s},[x3],#16 - mov v26.16b,v27.16b - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.4s},[x4],#16 - mov v28.16b,v29.16b - st1 {v30.4s},[x7],#16 - mov v30.16b,v31.16b - -StoreOutputPartial2AddModeM8 - tbz x5,#1,StoreOutputPartial1AddModeM8 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - ld1 {v4.2s},[x0] - ld1 {v5.2s},[x3] - ld1 {v6.2s},[x4] - ld1 {v7.2s},[x7] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - add v24.4s,v24.4s,v4.4s - add v26.4s,v26.4s,v5.4s - st1 {v24.2s},[x0],#8 - dup v24.4s,v24.s[2] - st1 {v26.2s},[x3],#8 - dup v26.4s,v26.s[2] - add v28.4s,v28.4s,v6.4s - add v30.4s,v30.4s,v7.4s - st1 {v28.2s},[x4],#8 - dup v28.4s,v28.s[2] - st1 {v30.2s},[x7],#8 - dup v30.4s,v30.s[2] - -StoreOutputPartial1AddModeM8 - tbz x5,#0,ExitKernelM8 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - ld1 {v4.s}[0],[x0] - ld1 {v5.s}[0],[x3] - ld1 {v6.s}[0],[x4] - ld1 {v7.s}[0],[x7] - add v24.4s,v24.4s,v4.4s - st1 {v24.s}[0],[x0] - add v26.4s,v26.4s,v5.4s - st1 {v26.s}[0],[x3] - add v28.4s,v28.4s,v6.4s - st1 {v28.s}[0],[x4] - add v30.4s,v30.4s,v7.4s - st1 {v30.s}[0],[x7] 
- b ExitKernelM8 - - -// -// Process 4 rows of the matrices. -// -// -// The packing layout is setup to have a pair of four quad vectors from -// packed matrix A and a pair of eight quad vectors from packed matrix B. -// With this scheme, alternating loads from the packed matrices can be -// interleaved with the dot product instructions. -// -// One negative consequence of using four rows here is that the accumulator -// register tile is too small for processors with high out of order execution -// windows (such as the Apple M1). The dot product instructions for a given -// cell are too close to each other to avoid dependencies. To workaround this, -// the below loop uses a pair of accumulator registers that are then added -// together when the loop finishes. -// -// A55-based cores are optimized for 64-bit loads, so use 64-bit loads for -// packed matrix A. At the time of this implementation, using a wider 128-bit -// load didn't affect performance for higher end cores. -// -// B 4x8 block -// /-----------------------------------------| -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// \-----------------------------------------/ -// A 4x4 block -// /---------------------\ /-----------------------------------------| -// |d4.b[0] ... d4.b[3] | |v16.s[0] .. v16.s[3] v17.s[0] .. v17.s[3]| -// |d4.b[4] ... d4.b[7] | |v18.s[0] .. v18.s[3] v19.s[0] .. v19.s[3]| -// |d5.b[0] ... d5.b[3] | |v20.s[0] .. v20.s[3] v21.s[0] .. v21.s[3]| -// |d5.b[4] ... d5.b[7] | |v22.s[0] .. v22.s[3] v23.s[0] .. v23.s[3]| -// \---------------------/ \-----------------------------------------/ -// unroll for the next 4 in k dimension -// /-----------------------------------------| -// |v0.b[0] ... v0.b[12] v1.b[0] ... v1.b[12]| -// | ... ... | -// |v0.b[3] ... v0.b[15] v1.b[3] ... v1.b[15]| -// \-----------------------------------------/ -// /---------------------\ /-----------------------------------------\ -// |d6.b[0] ... d6.b[3] | |v24.s[0] .. v24.s[3] v25.s[0] .. v25.s[3]| -// |d6.b[4] ... d6.b[7] | |v26.s[0] .. v26.s[3] v27.s[0] .. v27.s[3]| -// |d7.b[0] ... d7.b[3] | |v28.s[0] .. v24.s[3] v29.s[0] .. v29.s[3]| -// |d7.b[4] ... d7.b[7] | |v30.s[0] .. v24.s[3] v31.s[0] .. 
v31.s[3]| -// \---------------------/ \-----------------------------------------/ - -ProcessNextColumnLoopM4 - ld1 {v0.16b},[x1],#16 // load packed B0 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - dup v17.4s,v8.s[0] // broadcast row sums - dup v19.4s,v8.s[1] - dup v21.4s,v8.s[2] - dup v23.4s,v8.s[3] - cbz x9,SkipScaleByZeroPointBM4 - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - mul v16.4s,v30.4s,v17.4s - mul v18.4s,v30.4s,v19.4s - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v20.4s,v30.4s,v21.4s - mul v22.4s,v30.4s,v23.4s - mul v17.4s,v31.4s,v17.4s - mul v19.4s,v31.4s,v19.4s - mul v21.4s,v31.4s,v21.4s - mul v23.4s,v31.4s,v23.4s - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - add v20.4s,v2.4s,v20.4s - add v22.4s,v2.4s,v22.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - b ComputeBlockLoopStartM4 - -SkipScaleByZeroPointBM4 - add v16.4s,v2.4s,v17.4s - add v18.4s,v2.4s,v19.4s - add v20.4s,v2.4s,v21.4s - add v22.4s,v2.4s,v23.4s - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - add v21.4s,v3.4s,v21.4s - add v23.4s,v3.4s,v23.4s - -ComputeBlockLoopStartM4 - ldr d4,[x0],#32 // load packed A0.l - movi v24.4s,#0 - movi v25.4s,#0 - ldur d5,[x0,#-24] // load packed A0.h - movi v26.4s,#0 - movi v27.4s,#0 - ldur d6,[x0,#-16] // load packed A1.l - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - -ComputeBlockLoopM4 - ld1 {v1.16b},[x1],#16 // load packed B1 - UdotByElement 16, 0, 4, 0 - UdotByElement 18, 0, 4, 1 - ldur d7,[x0,#-8] // load packed A1.h - UdotByElement 20, 0, 5, 0 - UdotByElement 22, 0, 5, 1 - ld1 {v0.16b},[x1],#16 // load packed B0_next4k - UdotByElement 17, 1, 4, 0 - UdotByElement 19, 1, 4, 1 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM4 - ldr d4,[x0],#32 // load packed A0.l for next iteration - UdotByElement 21, 1, 5, 0 - UdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - UdotByElement 24, 0, 6, 0 - UdotByElement 26, 0, 6, 1 - ldur d5,[x0,#-24] // load packed A0.h for next iteration - UdotByElement 28, 0, 7, 0 - UdotByElement 30, 0, 7, 1 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iteration - UdotByElement 25, 1, 6, 0 - UdotByElement 27, 1, 6, 1 - ldur d6,[x0,#-16] // load packed A1.l for next iteration - UdotByElement 29, 1, 7, 0 - UdotByElement 31, 1, 7, 1 - b ComputeBlockLoopM4 - -ComputeBlockLoopFinishM4 - UdotByElement 21, 1, 5, 0 - UdotByElement 23, 1, 5, 1 - ld1 {v1.16b},[x1],#16 // load packed B1_next4k - UdotByElement 24, 0, 6, 0 - UdotByElement 26, 0, 6, 1 - UdotByElement 28, 0, 7, 0 - UdotByElement 30, 0, 7, 1 - UdotByElement 25, 1, 6, 0 - UdotByElement 27, 1, 6, 1 - UdotByElement 29, 1, 7, 0 - UdotByElement 31, 1, 7, 1 - add x10,x2,x6,lsl #2 // compute output row 2 - add v16.4s,v16.4s,v24.4s // fold high results into low results - add v18.4s,v18.4s,v26.4s - add v20.4s,v20.4s,v28.4s - add v22.4s,v22.4s,v30.4s - add x11,x10,x6,lsl #2 // compute output row 3 - add v17.4s,v17.4s,v25.4s - add v19.4s,v19.4s,v27.4s - add v21.4s,v21.4s,v29.4s - add v23.4s,v23.4s,v31.4s - add x12,x11,x6,lsl #2 // compute output row 4 - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM4 - cbnz x13,SkipAccumulateOutputM4 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - ldp q4,q5,[x11] - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - ldp q6,q7,[x12] - add v20.4s,v20.4s,v4.4s - add 
v21.4s,v21.4s,v5.4s - add v22.4s,v22.4s,v6.4s - add v23.4s,v23.4s,v7.4s - -SkipAccumulateOutputM4 - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - stp q20,q21,[x11] - stp q22,q23,[x12] - cbnz x5,ProcessNextColumnLoopM4 - -ExitKernelM4 - mov x0,#4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -StoreOutputPartialM4 - cbz x13,StoreOutputPartialAddModeM4 - -StoreOutputPartialZeroModeM4 - tbz x5,#2,StoreOutputPartial2ZeroModeM4 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -StoreOutputPartial2ZeroModeM4 - tbz x5,#1,StoreOutputPartial1ZeroModeM4 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -StoreOutputPartial1ZeroModeM4 - tbz x5,#0,ExitKernelM4 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b ExitKernelM4 - -StoreOutputPartialAddModeM4 - tbz x5,#2,StoreOutputPartial2AddModeM4 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.4s},[x11],#16 - mov v20.16b,v21.16b - st1 {v22.4s},[x12],#16 - mov v22.16b,v23.16b - -StoreOutputPartial2AddModeM4 - tbz x5,#1,StoreOutputPartial1AddModeM4 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - add v20.4s,v20.4s,v2.4s - add v22.4s,v22.4s,v3.4s - st1 {v20.2s},[x11],#8 - dup v20.4s,v20.s[2] - st1 {v22.2s},[x12],#8 - dup v22.4s,v22.s[2] - -StoreOutputPartial1AddModeM4 - tbz x5,#0,ExitKernelM4 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v18.4s,v18.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v20.4s,v20.4s,v2.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - add v22.4s,v22.4s,v3.4s - st1 {v20.s}[0],[x11] - st1 {v22.s}[0],[x12] - b ExitKernelM4 - -// -// Process 2 rows of the matrices. 
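Before the two-row path, it is worth spelling out the workaround described in the comment block ahead of ProcessNextColumnLoopM4: with only four rows the accumulator tile is small, so consecutive dot products into the same registers would stall wide out-of-order cores. The loop therefore keeps a second set of accumulators (v24-v31) for the second group of k values and folds them into the primary set (v16-v23) once the k loop ends. The same idea in scalar C++ (illustrative names):

#include <cstddef>
#include <cstdint>

// Split the reduction across two independent accumulators to shorten the
// dependency chain, then fold them together at the end.
static int32_t DotSplitAccum(const int32_t* a, const int32_t* b, size_t k) {
    int32_t accLo = 0;                    // primary chain (v16..v23 in the kernel)
    int32_t accHi = 0;                    // secondary chain (v24..v31 in the kernel)
    size_t i = 0;
    for (; i + 1 < k; i += 2) {
        accLo += a[i] * b[i];             // even terms feed one chain
        accHi += a[i + 1] * b[i + 1];     // odd terms feed the other
    }
    if (i < k) {
        accLo += a[i] * b[i];
    }
    return accLo + accHi;                 // fold: add v16.4s,v16.4s,v24.4s etc.
}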
-// -ProcessLoopM2 - dup v9.4s, v8.s[1] - dup v8.4s, v8.s[0] - -ProcessNextColumnLoopM2 - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer[0] - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer[4] - cbz x9,SkipScaleByZeroPointBM2 - ld1 {v30.4s},[x9],#16 // load ZeroPointB[0] - ld1 {v31.4s},[x9],#16 // load ZeroPointB[4] - mul v16.4s,v30.4s,v8.4s - mul v18.4s,v30.4s,v9.4s - mul v17.4s,v31.4s,v8.4s - mul v19.4s,v31.4s,v9.4s - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v16.4s - add v18.4s,v2.4s,v18.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v17.4s - add v19.4s,v3.4s,v19.4s - b ComputeBlockLoopM2 - -SkipScaleByZeroPointBM2 - ldr d4,[x0],#8 // load packed A0.l - add v16.4s,v2.4s,v8.4s - add v18.4s,v2.4s,v9.4s - ldr d5,[x0],#8 // load packed A0.h - add v17.4s,v3.4s,v8.4s - add v19.4s,v3.4s,v9.4s - -ComputeBlockLoopM2 - sub x3,x3,#1 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - UdotByElement 16, 0, 4, 0 - UdotByElement 17, 1, 4, 0 - UdotByElement 18, 0, 4, 1 - UdotByElement 19, 1, 4, 1 - cbz x3,ComputeBlockLoopFinishM2 - ldr d4,[x0],#8 // load packed A0.l for next iter - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - UdotByElement 16, 6, 5, 0 - UdotByElement 17, 7, 5, 0 - UdotByElement 18, 6, 5, 1 - UdotByElement 19, 7, 5, 1 - ldr d5,[x0],#8 // load packed A0.h for next iter - b ComputeBlockLoopM2 - -ComputeBlockLoopFinishM2 - add x10,x2,x6,lsl #2 // compute output row 2 - subs x5,x5,#8 // adjust CountN remaining - UdotByElement 16, 6, 5, 0 - UdotByElement 17, 7, 5, 0 - UdotByElement 18, 6, 5, 1 - UdotByElement 19, 7, 5, 1 - blo StoreOutputPartialM2 - cbnz x13,SkipAccumulateOutputM2 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - add v18.4s,v18.4s,v2.4s - add v19.4s,v19.4s,v3.4s - -SkipAccumulateOutputM2 - stp q16,q17,[x2],#32 - stp q18,q19,[x10] - cbnz x5,ProcessNextColumnLoopM2 - -ExitKernelM2 - mov x0,#2 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM2 - cbz x13,StoreOutputPartialAddModeM2 - -StoreOutputPartialZeroModeM2 - tbz x5,#2,StoreOutputPartial2ZeroModeM2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -StoreOutputPartial2ZeroModeM2 - tbz x5,#1,StoreOutputPartial1ZeroModeM2 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -StoreOutputPartial1ZeroModeM2 - tbz x5,#0,ExitKernelM2 - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b ExitKernelM2 - -StoreOutputPartialAddModeM2 - tbz x5,#2,StoreOutputPartial2AddModeM2 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - st1 {v18.4s},[x10],#16 - mov v18.16b,v19.16b - -StoreOutputPartial2AddModeM2 - tbz x5,#1,StoreOutputPartial1AddModeM2 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v18.2s},[x10],#8 - dup v18.4s,v18.s[2] - -StoreOutputPartial1AddModeM2 - tbz x5,#0,ExitKernelM2 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v16.4s,v16.4s,v0.4s - add v18.4s,v18.4s,v1.4s - st1 {v16.s}[0],[x2] - st1 {v18.s}[0],[x10] - b ExitKernelM2 - -// -// Process 1 row of the matrices. -// - -ProcessLoopM1 - dup v8.4s,v8.s[0] -ProcessNextColumnLoopM1 - ld1 {v0.16b},[x1],#16 // load packed B0 - ld1 {v1.16b},[x1],#16 // load packed B1 - mov x0,x14 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x15 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - cbz x9,SkipScaleByZeroPointBM1 - ld1 {v30.4s},[x9],#16 // load ZeroPointB0 - ld1 {v31.4s},[x9],#16 // load ZeroPointB1 - mul v16.4s,v30.4s,v8.4s - mul v17.4s,v31.4s,v8.4s - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v16.4s - add v17.4s,v3.4s,v17.4s - b ComputeBlockLoopM1 - -SkipScaleByZeroPointBM1 - ldr d4,[x0],#8 // load packed A0 - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k - add v16.4s,v2.4s,v8.4s - add v17.4s,v3.4s,v8.4s - -ComputeBlockLoopM1 - sub x3,x3,#1 - UdotByElement 16, 0, 4, 0 - UdotByElement 17, 1, 4, 0 - cbz x3,ComputeBlockLoopFinishM1 - ld1 {v0.16b},[x1],#16 // load packed B0 for next iter - ld1 {v1.16b},[x1],#16 // load packed B1 for next iter - UdotByElement 16, 6, 4, 1 - UdotByElement 17, 7, 4, 1 - ldr d4,[x0],#8 // load packed A0 for next iter - ld1 {v6.16b},[x1],#16 // load packed B0 next 4 k for next iter - ld1 {v7.16b},[x1],#16 // load packed B1 next 4 k for next iter - b ComputeBlockLoopM1 - -ComputeBlockLoopFinishM1 - subs x5,x5,#8 // adjust CountN remaining - UdotByElement 16, 6, 4, 1 - UdotByElement 17, 7, 4, 1 - blo StoreOutputPartialM1 - cbnz x13,SkipAccumulateOutputM1 - ldp q0,q1,[x2] - add v16.4s,v16.4s,v0.4s - add v17.4s,v17.4s,v1.4s - -SkipAccumulateOutputM1 - stp q16,q17,[x2],#32 - cbnz x5,ProcessNextColumnLoopM1 - -ExitKernelM1 - mov x0,#1 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d8,d9,#16! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartialM1 - cbz x13,StoreOutputPartialAddModeM1 - -StoreOutputPartialZeroModeM1 - tbz x5,#2,StoreOutputPartial2ZeroModeM1 - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -StoreOutputPartial2ZeroModeM1 - tbz x5,#1,StoreOutputPartial1ZeroModeM1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -StoreOutputPartial1ZeroModeM1 - tbz x5,#0,ExitKernelM1 - st1 {v16.s}[0],[x2] - b ExitKernelM1 - -StoreOutputPartialAddModeM1 - tbz x5,#2,StoreOutputPartial2AddModeM1 - ld1 {v0.4s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.4s},[x2],#16 - mov v16.16b,v17.16b // shift remaining elements down - -StoreOutputPartial2AddModeM1 - tbz x5,#1,StoreOutputPartial1AddModeM1 - ld1 {v0.2s},[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -StoreOutputPartial1AddModeM1 - tbz x5,#0,ExitKernelM1 - ld1 {v0.s}[0],[x2] - add v16.4s,v16.4s,v0.4s - st1 {v16.s}[0],[x2] - b ExitKernelM1 - - NESTED_END MlasGemmU8X8KernelUdot - - END diff --git a/onnxruntime/core/mlas/lib/arm64/SgemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/SgemmKernelNeon.asm deleted file mode 100644 index 33b1ed1e641aa..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/SgemmKernelNeon.asm +++ /dev/null @@ -1,502 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelNeon.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -;-- - -#include "kxarm64.h" - - TEXTAREA - -; -; ClearRowAccumulators -; -; Generates the code to clear the accumulators for a single row of the output -; block. -; - - MACRO - ClearRowAccumulators $Columns, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - movi $Vec1Reg..16b,#0 - movi $Vec2Reg..16b,#0 - IF $Columns > 8 - movi $Vec3Reg..16b,#0 - movi $Vec4Reg..16b,#0 - ENDIF - - MEND - -; -; ClearBlockAccumulators -; -; Generates the code to clear the accumulators for a single row of the output -; block. -; - - MACRO - ClearBlockAccumulators $Columns, $Rows - - ClearRowAccumulators $Columns, v16, v17, v18, v19 - IF $Rows >= 2 - ClearRowAccumulators $Columns, v20, v21, v22, v23 - ENDIF - IF $Rows >= 4 - ClearRowAccumulators $Columns, v24, v25, v26, v27 - ClearRowAccumulators $Columns, v28, v29, v30, v31 - ENDIF - - MEND - -; -; LoadMatrixAElementsBy4 -; LoadMatrixAElementsBy1 -; -; Generates the code to load 1 or 4 elements from matrix A. -; - - MACRO - LoadMatrixAElementsBy4 $Rows - - ldr v8,[x0],#16 - IF $Rows >= 2 - ldr v9,[x10],#16 - ENDIF - IF $Rows >= 4 - ldr v10,[x11],#16 - ldr v11,[x12],#16 - ENDIF - - MEND - - MACRO - LoadMatrixAElementsBy1 $Rows - - ldr s8,[x0],#4 - IF $Rows >= 2 - ldr s9,[x10],#4 - ENDIF - IF $Rows >= 4 - ldr s10,[x11],#4 - ldr s11,[x12],#4 - ENDIF - - MEND - -; -; MultiplyAccumulateRow -; -; Generates the code to multiply and accumulate a single row of the output -; block. -; - - MACRO - MultiplyAccumulateRow $Columns, $MatrixAReg, $Broadcast, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - fmla $Vec1Reg..4s,v4.4s,$MatrixAReg..s[$Broadcast] - fmla $Vec2Reg..4s,v5.4s,$MatrixAReg..s[$Broadcast] - IF $Columns > 8 - fmla $Vec3Reg..4s,v6.4s,$MatrixAReg..s[$Broadcast] - fmla $Vec4Reg..4s,v7.4s,$MatrixAReg..s[$Broadcast] - ENDIF - - MEND - -; -; MultiplyAccumulateBlock -; -; Generates the code to multiply and accumulate into the output block. 
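In scalar terms, MultiplyAccumulateRow and MultiplyAccumulateBlock are a broadcast fused multiply-add: one element of matrix A is multiplied against 8 or 16 consecutive columns of packed B and added into that row's accumulators. A C++ sketch (illustrative names):

#include <cstddef>

// acc[0..cols) += aElement * b[0..cols), i.e. one fmla-by-element pass over a
// row of the output block; cols is 8 or 16 in the kernel.
static void MultiplyAccumulateRowRef(float* acc, float aElement, const float* b, size_t cols) {
    for (size_t n = 0; n < cols; ++n) {
        acc[n] += aElement * b[n];        // fmla vAcc.4s, vB.4s, vA.s[lane]
    }
}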
-; - - MACRO - MultiplyAccumulateBlock $Columns, $Rows, $Broadcast - - MultiplyAccumulateRow $Columns, v8, $Broadcast, v16, v17, v18, v19 - IF $Rows >= 2 - MultiplyAccumulateRow $Columns, v9, $Broadcast, v20, v21, v22, v23 - ENDIF - IF $Rows >= 4 - MultiplyAccumulateRow $Columns, v10, $Broadcast, v24, v25, v26, v27 - MultiplyAccumulateRow $Columns, v11, $Broadcast, v28, v29, v30, v31 - ENDIF - - MEND - -; -; ComputeBlockLoop -; -; Generates the code to loop over K entries of the input matrices to produce -; the output block. -; - - MACRO - ComputeBlockLoop $Mode, $Columns, $Rows - - ClearBlockAccumulators $Columns, $Rows - - IF $Rows >= 2 - add x10,x0,x6 lsl #2 ; compute matrix A plus 1 row - ENDIF - IF $Rows >= 4 - add x11,x10,x6 lsl #2 ; compute matrix A plus 2 rows - add x12,x11,x6 lsl #2 ; compute matrix A plus 3 rows - ENDIF - - sub x9,x3,#4 ; decrement block count to process - tbnz x9,#63,$Mode.ProcessRemaining$Columns.x$Rows.Blocks - -$Mode.Compute$Columns.x$Rows.BlockBy4Loop - LoadMatrixAElementsBy4 $Rows - ldp v4,v5,[x1],#64*4 - IF $Columns > 8 - ldp v6,v7,[x1,#-56*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,0 - ldp v4,v5,[x1,#-48*4] - IF $Columns > 8 - ldp v6,v7,[x1,#-40*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,1 - ldp v4,v5,[x1,#-32*4] - IF $Columns > 8 - ldp v6,v7,[x1,#-24*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,2 - ldp v4,v5,[x1,#-16*4] - IF $Columns > 8 - ldp v6,v7,[x1,#-8*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,3 - sub x9,x9,#4 - tbz x9,#63,$Mode.Compute$Columns.x$Rows.BlockBy4Loop - -$Mode.ProcessRemaining$Columns.x$Rows.Blocks - add x9,x9,#4 ; correct for over-subtract above - cbz x9,$Mode.Output$Columns.x$Rows.Block - -$Mode.Compute$Columns.x$Rows.BlockBy1Loop - LoadMatrixAElementsBy1 $Rows - ldp v4,v5,[x1],#16*4 - IF $Columns > 8 - ldp v6,v7,[x1,#-8*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,0 - sub x9,x9,#1 - cbnz x9,$Mode.Compute$Columns.x$Rows.BlockBy1Loop - -$Mode.Output$Columns.x$Rows.Block - - MEND - -; -; MultiplyAlphaRow -; -; Generates the code to multiply a single row of the output block by the alpha -; value. -; - - MACRO - MultiplyAlphaRow $Columns, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF $Columns <= 4 - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - ELIF $Columns <= 8 - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - fmul $Vec2Reg..4s,$Vec2Reg..4s,v0.s[0] - ELIF $Columns <= 12 - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - fmul $Vec2Reg..4s,$Vec2Reg..4s,v0.s[0] - fmul $Vec3Reg..4s,$Vec3Reg..4s,v0.s[0] - ELSE - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - fmul $Vec2Reg..4s,$Vec2Reg..4s,v0.s[0] - fmul $Vec3Reg..4s,$Vec3Reg..4s,v0.s[0] - fmul $Vec4Reg..4s,$Vec4Reg..4s,v0.s[0] - ENDIF - - MEND - -; -; MultiplyAlphaBlock -; -; Generates the code to multiply the output block by the alpha value. -; - - MACRO - MultiplyAlphaBlock $Columns, $Rows - - MultiplyAlphaRow $Columns, v16, v17, v18, v19 - IF $Rows >= 2 - MultiplyAlphaRow $Columns, v20, v21, v22, v23 - ENDIF - IF $Rows >= 4 - MultiplyAlphaRow $Columns, v24, v25, v26, v27 - MultiplyAlphaRow $Columns, v28, v29, v30, v31 - ENDIF - - MEND - -; -; OutputRow1Element -; OutputRow2Element -; OutputRow4Element -; OutputRow8Element -; OutputRow16Element -; -; Generates the code to store elements to the output block. 
-; - - MACRO - OutputRow1Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ld1 {v4.s}[0],[$AddrReg] - fmla v4.2s,$Vec1Reg..2s,v0.s[0] - st1 {v4.s}[0],[$AddrReg] ; post-increment not needed for last element - ELSE - st1 {$Vec1Reg..s}[0],[$AddrReg] ; post-increment not needed for last element - ENDIF - - MEND - - MACRO - OutputRow2Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ld1 {v4.2s},[$AddrReg] - fmla v4.2s,$Vec1Reg..2s,v0.s[0] - st1 {v4.2s},[$AddrReg],#2*4 - ELSE - st1 {$Vec1Reg..2s},[$AddrReg],#2*4 - ENDIF - dup $Vec1Reg..4s,$Vec1Reg..s[2] ; shift remaining elements down - - MEND - - MACRO - OutputRow4Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ld1 {v4.4s},[$AddrReg] - fmla v4.4s,$Vec1Reg..4s,v0.s[0] - st1 {v4.4s},[$AddrReg],#4*4 - ELSE - st1 {$Vec1Reg..4s},[$AddrReg],#4*4 - ENDIF - mov $Vec1Reg..16b,$Vec2Reg..16b ; shift remaining elements down - - MEND - - MACRO - OutputRow8Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ldp v4,v5,[$AddrReg] - fmla v4.4s,$Vec1Reg..4s,v0.s[0] - fmla v5.4s,$Vec2Reg..4s,v0.s[0] - stp v4,v5,[$AddrReg],#8*4 - ELSE - stp $Vec1Reg.,$Vec2Reg.,[$AddrReg],#8*4 - ENDIF - mov $Vec1Reg..16b,$Vec3Reg..16b ; shift remaining elements down - mov $Vec2Reg..16b,$Vec4Reg..16b - - MEND - - MACRO - OutputRow16Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ldp v4,v5,[$AddrReg] - ldp v6,v7,[$AddrReg,#8*4] - fmla v4.4s,$Vec1Reg..4s,v0.s[0] - fmla v5.4s,$Vec2Reg..4s,v0.s[0] - fmla v6.4s,$Vec3Reg..4s,v0.s[0] - fmla v7.4s,$Vec4Reg..4s,v0.s[0] - stp v4,v5,[$AddrReg],#16*4 - stp v6,v7,[$AddrReg,#-8*4] - ELSE - stp $Vec1Reg.,$Vec2Reg.,[$AddrReg],#16*4 - stp $Vec3Reg.,$Vec4Reg.,[$AddrReg,#-8*4] - ENDIF - - MEND - -; -; OutputBlock -; -; Generates the code to store the output block. -; - - MACRO - OutputBlock $Mode, $Columns, $Rows - - OutputRow$Columns.Element $Mode, x2, v16, v17, v18, v19 - IF $Rows >= 2 - OutputRow$Columns.Element $Mode, x13, v20, v21, v22, v23 - ENDIF - IF $Rows >= 4 - OutputRow$Columns.Element $Mode, x14, v24, v25, v26, v27 - OutputRow$Columns.Element $Mode, x15, v28, v29, v30, v31 - ENDIF - - MEND - -; -; ProcessRows -; -; Generates the code to process a compute and store the output block for a -; fixed number of rows. 
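(Editorial note, not part of the patch.) The macros above — ComputeBlockLoop, MultiplyAlphaBlock and the OutputRow/OutputBlock family — amount to the following scalar reference for one output block: in Zero mode the accumulators are scaled by alpha and then overwrite C, while in Add mode alpha times the accumulators is added to the existing contents of C (the Zero and Add kernel variants generated later in this file). The sketch is illustrative only; the function name is hypothetical and B is treated as plain row-major rather than in the packed layout the real kernel consumes.

// Editorial sketch, not part of the deleted file: scalar reference for one
// SGEMM output block, showing the Zero/Add store modes produced by the macros
// above. ReferenceSgemmBlock is a hypothetical name.
#include <cstddef>

void ReferenceSgemmBlock(const float* A, const float* B, float* C,
                         size_t CountM, size_t CountN, size_t CountK,
                         size_t lda, size_t ldb, size_t ldc,
                         float alpha, bool ZeroMode) {
    for (size_t m = 0; m < CountM; m++) {
        for (size_t n = 0; n < CountN; n++) {
            float acc = 0.0f;
            for (size_t k = 0; k < CountK; k++) {
                acc += A[m * lda + k] * B[k * ldb + n];
            }
            if (ZeroMode) {
                C[m * ldc + n] = alpha * acc;    // Zero variant: overwrite C
            } else {
                C[m * ldc + n] += alpha * acc;   // Add variant: accumulate into C
            }
        }
    }
}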
-; - - MACRO - ProcessRows $Mode, $Rows - - mov x4,#$Rows ; return number of rows handled - cmp x5,#8 - ble $Mode.ProcessRemainingCountN$Rows - -$Mode.ProcessNextColumnLoop16x$Rows - ComputeBlockLoop $Mode,16,$Rows - IF "$Mode"=="Zero" - MultiplyAlphaBlock 16,$Rows - ENDIF - sub x5,x5,#16 - tbnz x5,#63,$Mode.OutputMasked16x$Rows.Block - OutputBlock $Mode,16,$Rows - mov x0,x8 ; reload matrix A - cmp x5,#8 - bgt $Mode.ProcessNextColumnLoop16x$Rows - cbz x5,$Mode.ExitKernel - -$Mode.ProcessRemainingCountN$Rows - ComputeBlockLoop $Mode,8,$Rows - IF "$Mode"=="Zero" - MultiplyAlphaBlock 8,$Rows - ENDIF - -$Mode.OutputMasked16x$Rows.Block - tbz x5,#3,$Mode.OutputRemaining7x$Rows.Block - OutputBlock $Mode,8,$Rows - -$Mode.OutputRemaining7x$Rows.Block - tbz x5,#2,$Mode.OutputRemaining3x$Rows.Block - OutputBlock $Mode,4,$Rows - -$Mode.OutputRemaining3x$Rows.Block - tbz x5,#1,$Mode.OutputRemaining1x$Rows.Block - OutputBlock $Mode,2,$Rows - -$Mode.OutputRemaining1x$Rows.Block - tbz x5,#0,$Mode.ExitKernel - OutputBlock $Mode,1,$Rows - - MEND - - SUBT "SGEMM kernel" -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (x0) - Supplies the address of matrix A. -; -; B (x1) - Supplies the address of matrix B. The matrix data has been packed -; using MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C (x2) - Supplies the address of matrix C. -; -; CountK (x3) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountM (x4) - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN (x5) - Supplies the number of columns from matrix B and matrix C to -; iterate over. -; -; lda (x6) - Supplies the first dimension of matrix A. -; -; ldc (x7) - Supplies the first dimension of matrix C. -; -; Alpha (s0) - Supplies the scalar multiplier (see SGEMM definition). -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - MACRO - SgemmKernelNeonFunction $Mode - - NESTED_ENTRY MlasSgemmKernel$Mode - - PROLOG_SAVE_REG_PAIR d8,d9,#-32! - PROLOG_SAVE_REG_PAIR d10,d11,#16 - - add x13,x2,x7 lsl #2 ; compute matrix C plus 1 row - add x14,x13,x7 lsl #2 ; compute matrix C plus 2 rows - add x15,x14,x7 lsl #2 ; compute matrix C plus 3 rows - mov x8,x0 ; save matrix A - -; -; Process 4 rows of the matrices. -; - - cmp x4,#4 - blt $Mode.ProcessCountMLessThan4 - ProcessRows $Mode,4 - -; -; Restore non-volatile registers and return. -; - -$Mode.ExitKernel - mov x0,x4 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#32! - EPILOG_RETURN - -; -; Process 2 rows of the matrices. -; - -$Mode.ProcessCountMLessThan4 - cmp x4,#2 - blt $Mode.ProcessCountMLessThan2 - ProcessRows $Mode,2 - b $Mode.ExitKernel - -; -; Process 1 row of the matrices. -; - -$Mode.ProcessCountMLessThan2 - ProcessRows $Mode,1 - b $Mode.ExitKernel - - NESTED_END - - MEND - - SgemmKernelNeonFunction Zero - SgemmKernelNeonFunction Add - - END diff --git a/onnxruntime/core/mlas/lib/arm64/SgemvKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/SgemvKernelNeon.asm deleted file mode 100644 index 9cda05114b56d..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/SgemvKernelNeon.asm +++ /dev/null @@ -1,305 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. 
-; -; Module Name: -; -; SgemvKernelNeon.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/vector -; multiply operation (SGEMV). -; -;-- - -#include "kxarm64.h" - - TEXTAREA - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. This handles the special case of M=1. -; -; The elements in matrix B are not transposed. -; -; Arguments: -; -; A (x0) - Supplies the address of matrix A. -; -; B (x1) - Supplies the address of matrix B. -; -; C (x2) - Supplies the address of matrix C. -; -; CountK (x3) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountN (x4) - Supplies the number of columns from matrix B and matrix C to -; iterate over. -; -; ldb (x5) - Supplies the first dimension of matrix B. -; -; ZeroMode (x6) - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; None. -; -;-- - - LEAF_ENTRY MlasGemvFloatKernel - - cmp x4,#64 - blo ProcessRemainingCountN - mov x14,x0 ; preserve vector A - -; -; Process 64 columns at a time in a loop. -; - -ProcessColumnLoopBy64 - ldr q4,[x1] - add x15,x1,#256 ; compute next matrix B - ldr q5,[x1,#16] - tst w6,0xFF ; ZeroMode? - mov x13,x3 ; reload CountK - ldr q6,[x1,#32] - beq LoadOutputBy64 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - movi v20.4s,#0 - movi v21.4s,#0 - movi v22.4s,#0 - movi v23.4s,#0 - movi v24.4s,#0 - movi v25.4s,#0 - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - b MultiplyAccumulateBy64 - -LoadOutputBy64 - ldp q16,q17,[x2] - ldp q18,q19,[x2,#32] - ldp q20,q21,[x2,#64] - ldp q22,q23,[x2,#96] - ldp q24,q25,[x2,#128] - ldp q26,q27,[x2,#160] - ldp q28,q29,[x2,#192] - ldp q30,q31,[x2,#224] - -MultiplyAccumulateBy64 - ld1r {v0.4s},[x0] ; broadcast next vector A element - add x0,x0,4 ; advance vector A by 1 element - sub x13,x13,#1 ; decrement K remaining - fmla v16.4s,v4.4s,v0.4s - ldr q7,[x1,#48] - fmla v17.4s,v5.4s,v0.4s - ldr q4,[x1,#64] - fmla v18.4s,v6.4s,v0.4s - ldr q5,[x1,#80] - fmla v19.4s,v7.4s,v0.4s - ldr q6,[x1,#96] - fmla v20.4s,v4.4s,v0.4s - ldr q7,[x1,#112] - fmla v21.4s,v5.4s,v0.4s - ldr q4,[x1,#128] - fmla v22.4s,v6.4s,v0.4s - ldr q5,[x1,#144] - fmla v23.4s,v7.4s,v0.4s - ldr q6,[x1,#160] - fmla v24.4s,v4.4s,v0.4s - ldr q7,[x1,#176] - fmla v25.4s,v5.4s,v0.4s - ldr q4,[x1,#192] - fmla v26.4s,v6.4s,v0.4s - ldr q5,[x1,#208] - fmla v27.4s,v7.4s,v0.4s - ldr q6,[x1,#224] - fmla v28.4s,v4.4s,v0.4s - ldr q7,[x1,#240] - add x1,x1,x5,lsl #2 ; compute next matrix B row address - cbz x13,StoreOutputBy64 - ldr q4,[x1] ; load data for next iteration - fmla v29.4s,v5.4s,v0.4s - ldr q5,[x1,#16] - fmla v30.4s,v6.4s,v0.4s - ldr q6,[x1,#32] - fmla v31.4s,v7.4s,v0.4s - b MultiplyAccumulateBy64 - -StoreOutputBy64 - stp q16,q17,[x2] - fmla v29.4s,v5.4s,v0.4s ; finish computing tail vectors - stp q18,q19,[x2,#32] - fmla v30.4s,v6.4s,v0.4s - stp q20,q21,[x2,#64] - fmla v31.4s,v7.4s,v0.4s - stp q22,q23,[x2,#96] - sub x4,x4,#64 ; subtract 64 columns - stp q24,q25,[x2,#128] - mov x0,x14 ; reload vector A - stp q26,q27,[x2,#160] - mov x1,x15 ; load next matrix B - stp q28,q29,[x2,#192] - stp q30,q31,[x2,#224] - add x2,x2,#256 ; advance vector C by 64 columns - cbz x4,ExitKernel - cmp x4,#64 - bhs ProcessColumnLoopBy64 - -; -; Process the remaining 1 to 63 columns. 
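(Editorial note, not part of the patch.) The 64-column main loop above and the 1-to-63 column remainder that follows both compute the single-row product described in this routine's header: one row of A times an untransposed B, written to C either freshly (ZeroMode true) or accumulated into the existing contents of C. A minimal scalar equivalent — illustrative only, with a hypothetical function name — is:

// Editorial sketch, not part of the deleted file: scalar reference for the
// M=1 GEMV computed by MlasGemvFloatKernel. ReferenceGemvFloat is a
// hypothetical name.
#include <cstddef>

void ReferenceGemvFloat(const float* A, const float* B, float* C,
                        size_t CountK, size_t CountN, size_t ldb,
                        bool ZeroMode) {
    for (size_t n = 0; n < CountN; n++) {
        float sum = ZeroMode ? 0.0f : C[n];   // overwrite vs. accumulate
        for (size_t k = 0; k < CountK; k++) {
            sum += A[k] * B[k * ldb + n];     // B is not transposed
        }
        C[n] = sum;
    }
}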
-; - -ProcessRemainingCountN - tst w6,0xFF ; ZeroMode? - beq LoadOutputPartial32 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - movi v20.4s,#0 - movi v21.4s,#0 - movi v22.4s,#0 - movi v23.4s,#0 - movi v24.4s,#0 - movi v25.4s,#0 - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 ; trailing float[2] - movi v1.4s,#0 ; trailing float[1] - b ProcessNextPartialRow - -LoadOutputPartial32 - mov x15,x2 - tbz x4,#5,LoadOutputPartial16 - ldp q16,q17,[x15],#128 - ldp q18,q19,[x15,#-96] - ldp q20,q21,[x15,#-64] - ldp q22,q23,[x15,#-32] - -LoadOutputPartial16 - tbz x4,#4,LoadOutputPartial8 - ldp q24,q25,[x15],#64 - ldp q26,q27,[x15,#-32] - -LoadOutputPartial8 - tbz x4,#3,LoadOutputPartial4 - ldp q28,q29,[x15],#32 - -LoadOutputPartial4 - tbz x4,#2,LoadOutputPartial2 - ldr q30,[x15],#16 - -LoadOutputPartial2 - tbz x4,#1,LoadOutputPartial1 - ldr d31,[x15],#8 - -LoadOutputPartial1 - tbz x4,#0,ProcessNextPartialRow - ldr s1,[x15] - -ProcessNextPartialRow - ld1r {v0.4s},[x0] - add x0,x0,4 - sub x3,x3,#1 ; decrement K remaining - mov x15,x1 - -MultiplyAccumulatePartial32 - tbz x4,#5,MultiplyAccumulatePartial16 - ldp q4,q5,[x15],#128 - fmla v16.4s,v4.4s,v0.4s - ldp q6,q7,[x15,#-96] - fmla v17.4s,v5.4s,v0.4s - ldp q4,q5,[x15,#-64] - fmla v18.4s,v6.4s,v0.4s - fmla v19.4s,v7.4s,v0.4s - ldp q6,q7,[x15,#-32] - fmla v20.4s,v4.4s,v0.4s - fmla v21.4s,v5.4s,v0.4s - fmla v22.4s,v6.4s,v0.4s - fmla v23.4s,v7.4s,v0.4s - -MultiplyAccumulatePartial16 - tbz x4,#4,MultiplyAccumulatePartial8 - ldp q4,q5,[x15],#64 - fmla v24.4s,v4.4s,v0.4s - ldp q6,q7,[x15,#-32] - fmla v25.4s,v5.4s,v0.4s - fmla v26.4s,v6.4s,v0.4s - fmla v27.4s,v7.4s,v0.4s - -MultiplyAccumulatePartial8 - tbz x4,#3,MultiplyAccumulatePartial4 - ldp q4,q5,[x15],#32 - fmla v28.4s,v4.4s,v0.4s - fmla v29.4s,v5.4s,v0.4s - -MultiplyAccumulatePartial4 - tbz x4,#2,MultiplyAccumulatePartial2 - ldr q4,[x15],#16 - fmla v30.4s,v4.4s,v0.4s - -MultiplyAccumulatePartial2 - tbz x4,#1,MultiplyAccumulatePartial1 - ldr d4,[x15],#8 - fmla v31.4s,v4.4s,v0.4s - -MultiplyAccumulatePartial1 - tbz x4,#0,AdvancePartialRow - ldr s4,[x15] - fmla v1.4s,v4.4s,v0.4s - -AdvancePartialRow - add x1,x1,x5,lsl #2 ; compute next matrix B row address - cbnz x3,ProcessNextPartialRow - -StoreOutputPartial32 - tbz x4,#5,StoreOutputPartial16 - stp q16,q17,[x2],#128 - stp q18,q19,[x2,#-96] - stp q20,q21,[x2,#-64] - stp q22,q23,[x2,#-32] - -StoreOutputPartial16 - tbz x4,#4,StoreOutputPartial8 - stp q24,q25,[x2],#64 - stp q26,q27,[x2,#-32] - -StoreOutputPartial8 - tbz x4,#3,StoreOutputPartial4 - stp q28,q29,[x2],#32 - -StoreOutputPartial4 - tbz x4,#2,StoreOutputPartial2 - str q30,[x2],#16 - -StoreOutputPartial2 - tbz x4,#1,StoreOutputPartial1 - str d31,[x2],#8 - -StoreOutputPartial1 - tbz x4,#0,ExitKernel - str s1,[x2] - -ExitKernel - ret - - LEAF_END MlasGemvFloatKernel - - END diff --git a/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelNeon.asm deleted file mode 100644 index 4770b071dd84d..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelNeon.asm +++ /dev/null @@ -1,538 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SymQgemmS8KernelNeon.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM), where the right hand side is symmetrically quantized, - i.e. zero point being zero. 
- - This kernel only requires prepacking of the right hand side, which is usually - constant. When the packed right hand side is cached, we achieves higher performance - by avoid packing all together. - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the S8S8 kernel. -// - -#define SQGemmS8Frame_SavedNeonRegisters (8 * 8) -#define SQGemmS8Frame_SavedRegisters SQGemmS8Frame_SavedNeonRegisters -#define SQGemmS8Frame_ColumnSumBuffer 0 + SQGemmS8Frame_SavedRegisters - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - lda (x7) - Supplies the first dimension of matrix A. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - -Return Value: - - Returns the number of rows handled. - ---*/ - - NESTED_ENTRY MlasSymQgemmS8KernelNeon - - PROLOG_SAVE_REG_PAIR d8,d9,#-SQGemmS8Frame_SavedRegisters! - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - ldr x13,[sp,#SQGemmS8Frame_ColumnSumBuffer] - mov x14,x0 - mov x15,x3 - cmp x4,#1 // CountM == 1? - beq M1_ProcessLoop - cmp x4,#4 // CountM < 4? - blo M2_ProcessLoop - -// -// Process 4 rows of the matrices. -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v8.b[0] v9.b[0] v10.b[0] v11.b[0]| -// | ... ... ... ... | -// |v8.b[7] v9.b[7] v10.b[7] v11.b[7]| -// A 4x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v24.4s v25.4s v26.4s v27.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v28.4s v29.4s v30.4s v31.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// - -M4_ProcessNextColumnLoop - mov x0,x14 // reload matrix A0 - mov x3,x15 // reload PackedCountK - ldr d0,[x0],#8 // Load A0 - add x9,x14,x7 // A1 - ldr d2,[x0],#8 // Load A0 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d4,d8,[x1],#64 // B - movi v18.4s,#0 - movi v19.4s,#0 - ldp d5,d9,[x1,#-48] - movi v20.4s,#0 - movi v21.4s,#0 - ldp d6,d10,[x1,#-32] - movi v22.4s,#0 - movi v23.4s,#0 - ldp d7,d11,[x1,#-16] - movi v24.4s,#0 - movi v25.4s,#0 - add x10,x9,x7 // A2 - ldp d1,d3,[x9],#16 // Load A1 - movi v26.4s,#0 - movi v27.4s,#0 - movi v28.4s,#0 - movi v29.4s,#0 - movi v30.4s,#0 - movi v31.4s,#0 - add x11,x10,x7 // A3 - -M4_ComputeBlockLoop - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x10],#16 // Load A2 - sadalp v16.4s,v12.8h - sadalp v17.4s,v13.8h - sadalp v18.4s,v14.8h - sadalp v19.4s,v15.8h - sub x3,x3,#1 - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - ldp d1,d3,[x11],#16 // Load A3 - sadalp v20.4s,v12.8h - sadalp v21.4s,v13.8h - sadalp v22.4s,v14.8h - sadalp v23.4s,v15.8h - cbz x3,M4_ComputeBlockLoopFinish - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ldp d0,d2,[x0],#16 // Load A0 next iter - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - ldp d4,d8,[x1],#64 // B - smlal v13.8h,v3.8b,v9.8b - ldp d5,d9,[x1,#-48] - smlal v14.8h,v3.8b,v10.8b - ldp d6,d10,[x1,#-32] - smlal v15.8h,v3.8b,v11.8b - ldp d7,d11,[x1,#-16] - sadalp v28.4s,v12.8h - ldp d1,d3,[x9],#16 // Load A1 next iter - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - b M4_ComputeBlockLoop - -M4_ComputeBlockLoopFinish - smull v12.8h,v0.8b,v4.8b - smull v13.8h,v0.8b,v5.8b - smull v14.8h,v0.8b,v6.8b - smull v15.8h,v0.8b,v7.8b - smlal v12.8h,v2.8b,v8.8b - smlal v13.8h,v2.8b,v9.8b - smlal v14.8h,v2.8b,v10.8b - smlal v15.8h,v2.8b,v11.8b - ld1 {v2.4s},[x13],#16 // load ColumnSumBuffer[0] - sadalp v24.4s,v12.8h - sadalp v25.4s,v13.8h - sadalp v26.4s,v14.8h - sadalp v27.4s,v15.8h - smull v12.8h,v1.8b,v4.8b - smull v13.8h,v1.8b,v5.8b - smull v14.8h,v1.8b,v6.8b - smull v15.8h,v1.8b,v7.8b - smlal v12.8h,v3.8b,v8.8b - smlal v13.8h,v3.8b,v9.8b - smlal v14.8h,v3.8b,v10.8b - smlal v15.8h,v3.8b,v11.8b - sadalp v28.4s,v12.8h - sadalp v29.4s,v13.8h - sadalp v30.4s,v14.8h - sadalp v31.4s,v15.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - addp v24.4s,v24.4s,v25.4s - addp v26.4s,v26.4s,v27.4s - addp v28.4s,v28.4s,v29.4s - addp v30.4s,v30.4s,v31.4s - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - addp v24.4s,v24.4s,v26.4s - addp v28.4s,v28.4s,v30.4s - - // accumulator += column sum B - add v16.4s,v16.4s,v2.4s - add v20.4s,v20.4s,v2.4s - add v24.4s,v24.4s,v2.4s - add v28.4s,v28.4s,v2.4s - -M4_StoreOutput - add x10,x2,x6,lsl #2 - add x11,x10,x6,lsl #2 - add x12,x11,x6,lsl #2 - subs x5,x5,#4 // 
adjust CountN remaining - blo M4_StoreOutputPartial - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - st1 {v24.4s},[x11] - st1 {v28.4s},[x12] - cbnz x5,M4_ProcessNextColumnLoop - -M4_ExitKernel - mov x0,#4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -M4_StoreOutputPartial - -M4_StoreOutputPartial_ZeroMode - tbz x5,#1,M4_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - st1 {v24.2s},[x11],#8 - dup v24.4s,v24.s[2] - st1 {v28.2s},[x12],#8 - dup v28.4s,v28.s[2] - -M4_StoreOutputPartial1_ZeroMode - tbz x5,#0,M4_ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - st1 {v24.s}[0],[x11] - st1 {v28.s}[0],[x12] - b M4_ExitKernel - -// -// Process 2 rows of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// Each row sum replicated to all 4 elements of a vector register -// v30 v31 -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 2x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// |v1.b[0]..v1.b[7] v3.b[0]..v3.b[7]| |v20.4s v21.4s v22.4s v23.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. (v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) - -M2_ProcessLoop - -M2_ProcessNextColumnLoop - ldp d4,d24,[x1],#16 // B - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // Load A0 - add x9,x14,x7 // A1 - movi v16.4s,#0 - movi v17.4s,#0 - ldp d5,d25,[x1],#16 - movi v18.4s,#0 - movi v19.4s,#0 - ldp d6,d26,[x1],#16 - movi v20.4s,#0 - movi v21.4s,#0 - ldp d7,d27,[x1],#16 - movi v22.4s,#0 - movi v23.4s,#0 - ldp d1,d3,[x9],#16 // Load A1 - -M2_ComputeBlockLoop - sub x3,x3,#1 - smull v28.8h,v0.8b,v4.8b - smull v29.8h,v0.8b,v5.8b - smull v30.8h,v0.8b,v6.8b - smull v31.8h,v0.8b,v7.8b - cbz x3,M2_ComputeBlockLoopFinish - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // Load A0 - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal v28.8h,v3.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v29.8h,v3.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v30.8h,v3.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v31.8h,v3.8b,v27.8b - ldp d7,d27,[x1],#16 - sadalp v20.4s,v28.8h - ldp d1,d3,[x9],#16 // Load A1 - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - b M2_ComputeBlockLoop - -M2_ComputeBlockLoopFinish - ld1 {v0.4s},[x13],#16 // load ColumnSumBuffer[0] - smlal v28.8h,v2.8b,v24.8b - smlal v29.8h,v2.8b,v25.8b - smlal v30.8h,v2.8b,v26.8b - smlal v31.8h,v2.8b,v27.8b - sadalp v16.4s,v28.8h - sadalp v17.4s,v29.8h - sadalp v18.4s,v30.8h - sadalp v19.4s,v31.8h - smull v28.8h,v1.8b,v4.8b - smull v29.8h,v1.8b,v5.8b - smull v30.8h,v1.8b,v6.8b - smull v31.8h,v1.8b,v7.8b - smlal 
v28.8h,v3.8b,v24.8b - smlal v29.8h,v3.8b,v25.8b - smlal v30.8h,v3.8b,v26.8b - smlal v31.8h,v3.8b,v27.8b - sadalp v20.4s,v28.8h - sadalp v21.4s,v29.8h - sadalp v22.4s,v30.8h - sadalp v23.4s,v31.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v20.4s,v20.4s,v21.4s - addp v22.4s,v22.4s,v23.4s - addp v16.4s,v16.4s,v18.4s - addp v20.4s,v20.4s,v22.4s - - // accumulator = column sum B - add v16.4s,v16.4s,v0.4s - add v20.4s,v20.4s,v0.4s - -M2_StoreOutput - add x10,x2,x6,lsl #2 - subs x5,x5,#4 // adjust CountN remaining - blo M2_StoreOutputPartial - st1 {v16.4s},[x2],#16 - st1 {v20.4s},[x10] - cbnz x5,M2_ProcessNextColumnLoop - -M2_ExitKernel - mov x0,#2 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -M2_StoreOutputPartial - -M2_StoreOutputPartial_ZeroMode - tbz x5,#1,M2_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v20.2s},[x10],#8 - dup v20.4s,v20.s[2] - -M2_StoreOutputPartial1_ZeroMode - tbz x5,#0,M2_ExitKernel - st1 {v16.s}[0],[x2] - st1 {v20.s}[0],[x10] - b M2_ExitKernel - -// -// Process 1 row of the matrices. -// -// Column Sum v2.s[0] v2.s[4] -// row sum replicated to all 4 elements of a vector register -// v31 -// B 16x4 -// ---------------------------------------- -// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] | -// | ... ... ... ... | -// |v4.b[7] v5.b[7] v6.b[7] v7.b[7] | -// |v24.b[0] v25.b[0] v26.b[0] v27.b[0]| -// | ... ... ... ... | -// |v24.b[7] v25.b[7] v26.b[7] v27.b[7]| -// A 1x16 ---------------------------------------- -// ----------------------------------- ---------------------------------------- -// |v0.b[0]..v0.b[7] v2.b[0]..v2.b[7]| |v16.4s v17.4s v18.4s v19.4s | -// ----------------------------------- ---------------------------------------- -// -// Accumulators are horizontally aggregated to the left most register -// for each row. e.g. 
(v16.s[0], v16.s[1], v16.s[2], v16.s[3]) <- (v16, v17, v18, v19) -// -M1_ProcessLoop - -M1_ProcessNextColumnLoop - ldp d4,d24,[x1],#16 // B - ldp d5,d25,[x1],#16 - ldp d6,d26,[x1],#16 - ldp d7,d27,[x1],#16 - mov x0,x14 // reload matrix A - mov x3,x15 // reload PackedCountK - ldp d0,d2,[x0],#16 // A0 - movi v16.4s,#0 - movi v17.4s,#0 - movi v18.4s,#0 - movi v19.4s,#0 - -M1_ComputeBlockLoop - sub x3,x3,#1 - smull v20.8h,v0.8b,v4.8b - smull v21.8h,v0.8b,v5.8b - cbz x3,M1_ComputeBlockLoopFinish - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - ldp d4,d24,[x1],#16 // B - smlal v21.8h,v2.8b,v25.8b - ldp d5,d25,[x1],#16 - smlal v22.8h,v2.8b,v26.8b - ldp d6,d26,[x1],#16 - smlal v23.8h,v2.8b,v27.8b - ldp d0,d2,[x0],#16 // A0 - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - ldp d7,d27,[x1],#16 - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - b M1_ComputeBlockLoop - -M1_ComputeBlockLoopFinish - ld1 {v4.4s},[x13],#16 // load ColumnSumBuffer[0] - smull v22.8h,v0.8b,v6.8b - smull v23.8h,v0.8b,v7.8b - smlal v20.8h,v2.8b,v24.8b - smlal v21.8h,v2.8b,v25.8b - smlal v22.8h,v2.8b,v26.8b - smlal v23.8h,v2.8b,v27.8b - sadalp v16.4s,v20.8h - sadalp v17.4s,v21.8h - sadalp v18.4s,v22.8h - sadalp v19.4s,v23.8h - addp v16.4s,v16.4s,v17.4s - addp v18.4s,v18.4s,v19.4s - addp v16.4s,v16.4s,v18.4s - - // accumulator += column sum B - add v16.4s,v16.4s,v4.4s - -M1_StoreOutput - subs x5,x5,#4 // adjust CountN remaining - blo M1_StoreOutputPartial - st1 {v16.4s},[x2],#16 - cbnz x5,M1_ProcessNextColumnLoop - -M1_ExitKernel - mov x0,#1 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -M1_StoreOutputPartial - -M1_StoreOutputPartial_ZeroMode - tbz x5,#1,M1_StoreOutputPartial1_ZeroMode - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - -M1_StoreOutputPartial1_ZeroMode - tbz x5,#0,M1_ExitKernel - st1 {v16.s}[0],[x2] - b M1_ExitKernel - - NESTED_END MlasSymQgemmS8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelSdot.asm b/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelSdot.asm deleted file mode 100644 index 2d4f6ea52e5c2..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelSdot.asm +++ /dev/null @@ -1,391 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SymQgemmS8KernelSdot.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM), where the right hand side is symmetrically quantized, - i.e. zero point being zero. - - This kernel only requires prepacking of the right hand side, which is usually - constant. When the packed right hand side is cached, we achieves higher performance - by avoid packing all together. - ---*/ - -#include "kxarm64.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout for the S8S8 kernel. -// - - -#define GemmS8S8KernelFrame_SavedRegisters (4 * 8) -#define GemmS8S8KernelFrame_ColumnSumBuffer (0 + GemmS8S8KernelFrame_SavedRegisters) - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. 
- - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - Packed K should be 16x - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - lda (x7) - Supplies the first dimension of matrix A. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - -Return Value: - - Returns the number of rows handled. - ---*/ - - NESTED_ENTRY MlasSymQgemmS8KernelSdot - - PROLOG_SAVE_REG_PAIR d8,d9,#-GemmS8S8KernelFrame_SavedRegisters! - PROLOG_NOP ldr x8,[sp,#GemmS8S8KernelFrame_ColumnSumBuffer] - PROLOG_SAVE_REG_PAIR d10,d11,#16 - - // compute C pointers: x2, x16, x17, x6 - cmp x4,#2 // M < 2 ? - add x16,x2,x6,lsl #2 // x16 -> C1 - add x17,x2,x6,lsl #3 // x17 -> C2 - csel x16,x2,x16,lo // if M < 2 x16/C1 -> C0 - csel x17,x16,x17,ls // if M <= 2 x17/C2 -> C1 - cmp x4,#4 // M < 4 ? - mov x12,#4 // set max M to 4 - add x6,x16,x6,lsl #3 // x6 -> C3 - mov x9,x0 // save A0 - mov x10,x3 // save K - csel x6,x17,x6,lo // if M < 4 x6/C3 -> C2 - csel x4,x12,x4,hi // if M > 4 M = 4; - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... | -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x0 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x12 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x13 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v22.s[3] v26.s[0]_v26.s[3] v30.s[0]_v30.s[3]| x17 -// x14 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x6 -// ------------------ ---------------------------------------------------------------------------- - -ProcessNextColumnLoop - ldr q16,[x8],#16 // Init accumulators with column sums - ldr q20,[x8],#16 - ldr q24,[x8],#16 - ldr q28,[x8],#16 - mov x0,x9 // reload A0 - cmp x4,#2 // M < 2 ? - add x12,x9,x7 // x12 -> A1 - add x13,x0,x7,lsl #1 // x13 -> A2 - ldr q4,[x1],#16 // Load B - csel x12,x0,x12,lo // if M < 2 A1 -> A0 - csel x13,x12,x13,ls // if M <= 2 A2 -> A1 - cmp x4,4 // M < 4 ? 
- add x14,x12,x7,lsl #1 // x14 -> A3 - ldr q5,[x1],#16 - csel x14,x13,x14,lo // if M < 4 A3 -> A2 - ldr d0,[x0],#8 // Load A0 1st/2nd block of 4 - mov v17.16b,v16.16b - mov v18.16b,v16.16b - ldr d1,[x12],#8 // Load A1 - mov v19.16b,v16.16b - mov v21.16b,v20.16b - ldr d2,[x13],#8 // Load A2 - mov v22.16b,v20.16b - mov v23.16b,v20.16b - ldr d3,[x14],#8 // Load A3 - mov v25.16b,v24.16b - mov v26.16b,v24.16b - ldr q6,[x1],#16 - mov v27.16b,v24.16b - mov v29.16b,v28.16b - ldr q7,[x1],#16 - subs x3,x10,#2 // one loop iteration and epilogue consume k = 32 - mov v30.16b,v28.16b - mov v31.16b,v28.16b - b.lo BlockLoopEpilogue // Need 32 k for main loop - -BlockLoop - sdot v16.4s,v4.16b,v0.4b[0] - sdot v17.4s,v4.16b,v1.4b[0] - ldr d8,[x0],#8 // Load A0 3rd/4th block of 4 - sdot v18.4s,v4.16b,v2.4b[0] - sdot v19.4s,v4.16b,v3.4b[0] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v0.4b[0] - sdot v21.4s,v5.16b,v1.4b[0] - ldr d9,[x12],#8 - sdot v22.4s,v5.16b,v2.4b[0] - sdot v23.4s,v5.16b,v3.4b[0] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v0.4b[0] - sdot v25.4s,v6.16b,v1.4b[0] - ldr d10,[x13],#8 - sdot v26.4s,v6.16b,v2.4b[0] - sdot v27.4s,v6.16b,v3.4b[0] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v0.4b[0] - sdot v29.4s,v7.16b,v1.4b[0] - ldr d11,[x14],#8 - sdot v30.4s,v7.16b,v2.4b[0] - sdot v31.4s,v7.16b,v3.4b[0] - ldr q7,[x1],#16 - sdot v16.4s,v4.16b,v0.4b[1] - sdot v17.4s,v4.16b,v1.4b[1] - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v0.4b[1] - sdot v21.4s,v5.16b,v1.4b[1] - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v0.4b[1] - sdot v25.4s,v6.16b,v1.4b[1] - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v0.4b[1] - sdot v29.4s,v7.16b,v1.4b[1] - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - ldr q7,[x1],#16 - sdot v16.4s,v4.16b,v8.4b[0] - sdot v17.4s,v4.16b,v9.4b[0] - ldr d0,[x0],#8 - sdot v18.4s,v4.16b,v10.4b[0] - sdot v19.4s,v4.16b,v11.4b[0] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v8.4b[0] - sdot v21.4s,v5.16b,v9.4b[0] - ldr d1,[x12],#8 - sdot v22.4s,v5.16b,v10.4b[0] - sdot v23.4s,v5.16b,v11.4b[0] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v8.4b[0] - sdot v25.4s,v6.16b,v9.4b[0] - ldr d2,[x13],#8 - sdot v26.4s,v6.16b,v10.4b[0] - sdot v27.4s,v6.16b,v11.4b[0] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v8.4b[0] - sdot v29.4s,v7.16b,v9.4b[0] - ldr d3,[x14],#8 - sdot v30.4s,v7.16b,v10.4b[0] - sdot v31.4s,v7.16b,v11.4b[0] - ldr q7,[x1],#16 - sdot v16.4s,v4.16b,v8.4b[1] - sdot v17.4s,v4.16b,v9.4b[1] - sdot v18.4s,v4.16b,v10.4b[1] - sdot v19.4s,v4.16b,v11.4b[1] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v8.4b[1] - sdot v21.4s,v5.16b,v9.4b[1] - sdot v22.4s,v5.16b,v10.4b[1] - sdot v23.4s,v5.16b,v11.4b[1] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v8.4b[1] - sdot v25.4s,v6.16b,v9.4b[1] - sdot v26.4s,v6.16b,v10.4b[1] - sdot v27.4s,v6.16b,v11.4b[1] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v8.4b[1] - sdot v29.4s,v7.16b,v9.4b[1] - subs x3,x3,#1 // k -= 16 - sdot v30.4s,v7.16b,v10.4b[1] - sdot v31.4s,v7.16b,v11.4b[1] - ldr q7,[x1],#16 - b.hs BlockLoop - -BlockLoopEpilogue - sdot v16.4s,v4.16b,v0.4b[0] - sdot v17.4s,v4.16b,v1.4b[0] - ldr d8,[x0],#8 - sdot v18.4s,v4.16b,v2.4b[0] - sdot v19.4s,v4.16b,v3.4b[0] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v0.4b[0] - sdot v21.4s,v5.16b,v1.4b[0] - ldr d9,[x12],#8 - sdot v22.4s,v5.16b,v2.4b[0] - sdot v23.4s,v5.16b,v3.4b[0] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v0.4b[0] - sdot v25.4s,v6.16b,v1.4b[0] - ldr d10,[x13],#8 - sdot v26.4s,v6.16b,v2.4b[0] - sdot 
v27.4s,v6.16b,v3.4b[0] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v0.4b[0] - sdot v29.4s,v7.16b,v1.4b[0] - ldr d11,[x14],#8 - sdot v30.4s,v7.16b,v2.4b[0] - sdot v31.4s,v7.16b,v3.4b[0] - ldr q7,[x1],#16 - sdot v16.4s,v4.16b,v0.4b[1] - sdot v17.4s,v4.16b,v1.4b[1] - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v0.4b[1] - sdot v21.4s,v5.16b,v1.4b[1] - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v0.4b[1] - sdot v25.4s,v6.16b,v1.4b[1] - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v0.4b[1] - sdot v29.4s,v7.16b,v1.4b[1] - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - ldr q7,[x1],#16 - sdot v16.4s,v4.16b,v8.4b[0] - sdot v17.4s,v4.16b,v9.4b[0] - sdot v18.4s,v4.16b,v10.4b[0] - sdot v19.4s,v4.16b,v11.4b[0] - ldr q4,[x1],#16 - sdot v20.4s,v5.16b,v8.4b[0] - sdot v21.4s,v5.16b,v9.4b[0] - sdot v22.4s,v5.16b,v10.4b[0] - sdot v23.4s,v5.16b,v11.4b[0] - ldr q5,[x1],#16 - sdot v24.4s,v6.16b,v8.4b[0] - sdot v25.4s,v6.16b,v9.4b[0] - sdot v26.4s,v6.16b,v10.4b[0] - sdot v27.4s,v6.16b,v11.4b[0] - ldr q6,[x1],#16 - sdot v28.4s,v7.16b,v8.4b[0] - sdot v29.4s,v7.16b,v9.4b[0] - sdot v30.4s,v7.16b,v10.4b[0] - sdot v31.4s,v7.16b,v11.4b[0] - ldr q7,[x1],#16 - sdot v16.4s,v4.16b,v8.4b[1] - sdot v17.4s,v4.16b,v9.4b[1] - sdot v18.4s,v4.16b,v10.4b[1] - sdot v19.4s,v4.16b,v11.4b[1] - sdot v20.4s,v5.16b,v8.4b[1] - sdot v21.4s,v5.16b,v9.4b[1] - sdot v22.4s,v5.16b,v10.4b[1] - sdot v23.4s,v5.16b,v11.4b[1] - sdot v24.4s,v6.16b,v8.4b[1] - sdot v25.4s,v6.16b,v9.4b[1] - sdot v26.4s,v6.16b,v10.4b[1] - sdot v27.4s,v6.16b,v11.4b[1] - sdot v28.4s,v7.16b,v8.4b[1] - sdot v29.4s,v7.16b,v9.4b[1] - subs x5,x5,#16 // adjust CountN remaining - sdot v30.4s,v7.16b,v10.4b[1] - sdot v31.4s,v7.16b,v11.4b[1] - blo StoreOutputPartial - stp q16,q20,[x2],#32 - stp q24,q28,[x2],#32 - stp q17,q21,[x16],#32 - stp q25,q29,[x16],#32 - stp q18,q22,[x17],#32 - stp q26,q30,[x17],#32 - stp q19,q23,[x6],#32 - stp q27,q31,[x6],#32 - cbnz x5,ProcessNextColumnLoop - -ExitKernel - mov x0,x4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#GemmS8S8KernelFrame_SavedRegisters! - EPILOG_RETURN - -// -// Store the partial 1 to 15 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartial - tbz x5,#3,StoreOutputPartial4 - stp q16,q20,[x2],#32 - mov v16.16b,v24.16b // shift remaining elements down - mov v20.16b,v28.16b - stp q17,q21,[x16],#32 - mov v17.16b,v25.16b - mov v21.16b,v29.16b - stp q18,q22,[x17],#32 - mov v18.16b,v26.16b - mov v22.16b,v30.16b - stp q19,q23,[x6],#32 - mov v19.16b,v27.16b - mov v23.16b,v31.16b - -StoreOutputPartial4 - tbz x5,#2,StoreOutputPartial2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v20.16b // shift remaining elements down - st1 {v17.4s},[x16],#16 - mov v17.16b,v21.16b - st1 {v18.4s},[x17],#16 - mov v18.16b,v22.16b - st1 {v19.4s},[x6],#16 - mov v19.16b,v23.16b - -StoreOutputPartial2 - tbz x5,#1,StoreOutputPartial1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v17.2s},[x16],#8 - dup v17.4s,v17.s[2] - st1 {v18.2s},[x17],#8 - dup v18.4s,v18.s[2] - st1 {v19.2s},[x6],#8 - dup v19.4s,v19.s[2] - -StoreOutputPartial1 - tbz x5,#0,ExitKernel - st1 {v16.s}[0],[x2] - st1 {v17.s}[0],[x16] - st1 {v18.s}[0],[x17] - st1 {v19.s}[0],[x6] - b ExitKernel - - NESTED_END MlasSymQgemmS8KernelSdot - - END diff --git a/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelSdotLd64.asm b/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelSdotLd64.asm deleted file mode 100644 index 4f5078610cb76..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64/SymQgemmS8KernelSdotLd64.asm +++ /dev/null @@ -1,456 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SymQgemmS8KernelSdot.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM), where the right hand side is symmetrically quantized, - i.e. zero point being zero. - - This kernel only requires prepacking of the right hand side, which is usually - constant. When the packed right hand side is cached, we achieves higher performance - by avoid packing all together. - - This version utilizes dot product instructions, and uses only 64b loads that performs - better on cores with narrow memory interface such as A55 - ---*/ - -#include "kxarm64.h" -#include "AssembleDotProduct.h" - -// -// Stack frame layout for the S8S8 kernel. -// - - -#define GemmS8S8KernelFrame_SavedRegisters (6 * 8) -#define GemmS8S8KernelFrame_ColumnSumBuffer (0 + GemmS8S8KernelFrame_SavedRegisters) - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmQuantCopyPackB. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - Packed K should be 16x - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - lda (x7) - Supplies the first dimension of matrix A. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - -Return Value: - - Returns the number of rows handled. 
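(Editorial note, not part of the patch.) Both SDOT variants being removed here implement the same symmetric QGEMM contract described above: signed 8-bit A times a prepacked, symmetrically quantized B (zero point zero), with each block of accumulators initialized from ColumnSumBuffer and the int32 result stored directly to C. A scalar reference — illustrative only; the function name is hypothetical, B is shown unpacked and row-major, and CountK is the plain rather than packed K — is:

// Editorial sketch, not part of the deleted file: scalar reference for the
// MlasSymQgemmS8Kernel* family. ColumnSumBuffer is assumed to already carry
// the signed zero-point correction, since the kernels simply add it into
// every row of C.
#include <cstddef>
#include <cstdint>

void ReferenceSymQgemmS8(const int8_t* A, const int8_t* B, int32_t* C,
                         size_t CountM, size_t CountN, size_t CountK,
                         size_t lda, size_t ldc,
                         const int32_t* ColumnSumBuffer) {
    for (size_t m = 0; m < CountM; m++) {
        for (size_t n = 0; n < CountN; n++) {
            int32_t acc = ColumnSumBuffer[n];  // "Init accumulators with column sums"
            for (size_t k = 0; k < CountK; k++) {
                acc += int32_t(A[m * lda + k]) * int32_t(B[k * CountN + n]);
            }
            C[m * ldc + n] = acc;              // stored directly; no accumulate mode
        }
    }
}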
- ---*/ - - NESTED_ENTRY MlasSymQgemmS8KernelSdotLd64 - - PROLOG_SAVE_REG_PAIR d8,d9,#-GemmS8S8KernelFrame_SavedRegisters! - PROLOG_NOP ldr x8,[sp,#GemmS8S8KernelFrame_ColumnSumBuffer] - PROLOG_NOP cmp x4,#2 // M < 2 ? - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_NOP add x16,x2,x6,lsl #2 // x16 -> C1 - PROLOG_NOP add x17,x2,x6,lsl #3 // x17 -> C2 - PROLOG_SAVE_REG_PAIR x20,x21,#32 - csel x16,x2,x16,lo // if M < 2 x16/C1 -> C0 - mov x12,#4 // set max M to 4 - csel x17,x16,x17,ls // if M <= 2 x17/C2 -> C1 - cmp x4,#4 // M < 4 ? - add x6,x16,x6,lsl #3 // x6 -> C3 - mov x9,x0 // save A0 - mov x10,x3 // save K - csel x6,x17,x6,lo // if M < 4 x6/C3 -> C2 - csel x4,x12,x4,hi // if M > 4 M = 4; - -// Register Usage -// B (x1) -> 4x16 -// ---------------------------------------------------------------------------- -// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| -// | ... ... ... ... ... ... ... ... | -// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| -// A 4x4 ---------------------------------------------------------------------------- -// ------------------ ---------------------------------------------------------------------------- -// x0 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 -// x12 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 -// x13 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v22.s[3] v26.s[0]_v26.s[3] v30.s[0]_v30.s[3]| x17 -// x14 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x6 -// ------------------ ---------------------------------------------------------------------------- - -ProcessNextColumnLoop - ldr q16,[x8],#16 // Init accumulators with column sums - ldr q20,[x8],#16 - ldr q24,[x8],#16 - ldr q28,[x8],#16 - mov x0,x9 // reload A0 - cmp x4,#2 // M < 2 ? - ldr q4,[x1],#16 // Load B - add x12,x9,x7 // x12 -> A1 - add x13,x0,x7,lsl #1 // x13 -> A2 - csel x12,x0,x12,lo // if M < 2 A1 -> A0 - ldr d0,[x0],#8 // Load A0 1st/2nd block of 4 - csel x13,x12,x13,ls // if M <= 2 A2 -> A1 - cmp x4,4 // M < 4 ? 
- ldr d5,[x1],#8 - add x14,x12,x7,lsl #1 // x14 -> A3 - ldr d1,[x12],#8 // Load A1 - csel x14,x13,x14,lo // if M < 4 A3 -> A2 - ldr d2,[x13],#8 // Load A2 - mov v17.16b,v16.16b - ldr d3,[x14],#8 // Load A3 - mov v18.16b,v16.16b - ldr x15,[x1],#8 - mov v19.16b,v16.16b - ldr d6,[x1],#8 - mov v21.16b,v20.16b - ldr x20,[x1],#8 - mov v22.16b,v20.16b - mov v23.16b,v20.16b - mov v25.16b,v24.16b - mov v26.16b,v24.16b - mov v27.16b,v24.16b - mov v29.16b,v28.16b - subs x3,x10,#2 // one loop iteration and epilogue consume k = 32 - mov v30.16b,v28.16b - mov v31.16b,v28.16b - b.lo BlockLoopEpilogue // Need 32 k for main loop - -BlockLoop - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v0.4b[0] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v1.4b[0] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v2.4b[0] - ldr d8,[x0],#8 // Load A0 3rd/4th block of 4 - sdot v19.4s,v4.16b,v3.4b[0] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v0.4b[0] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v1.4b[0] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v2.4b[0] - ldr d9,[x12],#8 - sdot v23.4s,v5.16b,v3.4b[0] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v0.4b[0] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v1.4b[0] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v2.4b[0] - ldr d10,[x13],#8 - sdot v27.4s,v6.16b,v3.4b[0] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v0.4b[0] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v1.4b[0] - ins v4.d[1],x11 - sdot v30.4s,v7.16b,v2.4b[0] - ldr d11,[x14],#8 - sdot v31.4s,v7.16b,v3.4b[0] - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v0.4b[1] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v1.4b[1] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v0.4b[1] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v1.4b[1] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v0.4b[1] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v1.4b[1] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v0.4b[1] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v1.4b[1] - ins v4.d[1],x11 - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v8.4b[0] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v9.4b[0] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v10.4b[0] - ldr d0,[x0],#8 - sdot v19.4s,v4.16b,v11.4b[0] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v8.4b[0] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v9.4b[0] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v10.4b[0] - ldr d1,[x12],#8 - sdot v23.4s,v5.16b,v11.4b[0] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v8.4b[0] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v9.4b[0] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v10.4b[0] - ldr d2,[x13],#8 - sdot v27.4s,v6.16b,v11.4b[0] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v8.4b[0] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v9.4b[0] - ins v4.d[1],x11 - sdot v30.4s,v7.16b,v10.4b[0] - ldr d3,[x14],#8 - sdot v31.4s,v7.16b,v11.4b[0] - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v8.4b[1] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v9.4b[1] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v10.4b[1] - sdot v19.4s,v4.16b,v11.4b[1] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v8.4b[1] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v9.4b[1] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v10.4b[1] - sdot v23.4s,v5.16b,v11.4b[1] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v8.4b[1] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v9.4b[1] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v10.4b[1] - subs x3,x3,#1 // k -= 16 - sdot v27.4s,v6.16b,v11.4b[1] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v8.4b[1] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v9.4b[1] - ins 
v4.d[1],x11 - sdot v30.4s,v7.16b,v10.4b[1] - sdot v31.4s,v7.16b,v11.4b[1] - b.hs BlockLoop - -BlockLoopEpilogue - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v0.4b[0] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v1.4b[0] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v2.4b[0] - ldr d8,[x0],#8 - sdot v19.4s,v4.16b,v3.4b[0] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v0.4b[0] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v1.4b[0] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v2.4b[0] - ldr d9,[x12],#8 - sdot v23.4s,v5.16b,v3.4b[0] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v0.4b[0] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v1.4b[0] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v2.4b[0] - ldr d10,[x13],#8 - sdot v27.4s,v6.16b,v3.4b[0] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v0.4b[0] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v1.4b[0] - ins v4.d[1],x11 - sdot v30.4s,v7.16b,v2.4b[0] - ldr d11,[x14],#8 - sdot v31.4s,v7.16b,v3.4b[0] - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v0.4b[1] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v1.4b[1] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v2.4b[1] - sdot v19.4s,v4.16b,v3.4b[1] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v0.4b[1] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v1.4b[1] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v2.4b[1] - sdot v23.4s,v5.16b,v3.4b[1] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v0.4b[1] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v1.4b[1] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v2.4b[1] - sdot v27.4s,v6.16b,v3.4b[1] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v0.4b[1] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v1.4b[1] - ins v4.d[1],x11 - sdot v30.4s,v7.16b,v2.4b[1] - sdot v31.4s,v7.16b,v3.4b[1] - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v8.4b[0] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v9.4b[0] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v10.4b[0] - sdot v19.4s,v4.16b,v11.4b[0] - ldr d4,[x1],#8 - sdot v20.4s,v5.16b,v8.4b[0] - ldr x11,[x1],#8 - sdot v21.4s,v5.16b,v9.4b[0] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v10.4b[0] - sdot v23.4s,v5.16b,v11.4b[0] - ldr d5,[x1],#8 - sdot v24.4s,v6.16b,v8.4b[0] - ldr x15,[x1],#8 - sdot v25.4s,v6.16b,v9.4b[0] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v10.4b[0] - sdot v27.4s,v6.16b,v11.4b[0] - ldr d6,[x1],#8 - sdot v28.4s,v7.16b,v8.4b[0] - ldr x20,[x1],#8 - sdot v29.4s,v7.16b,v9.4b[0] - ins v4.d[1],x11 - sdot v30.4s,v7.16b,v10.4b[0] - sdot v31.4s,v7.16b,v11.4b[0] - ldr d7,[x1],#8 - sdot v16.4s,v4.16b,v8.4b[1] - ldr x21,[x1],#8 - sdot v17.4s,v4.16b,v9.4b[1] - ins v5.d[1],x15 - sdot v18.4s,v4.16b,v10.4b[1] - sdot v19.4s,v4.16b,v11.4b[1] - sdot v20.4s,v5.16b,v8.4b[1] - sdot v21.4s,v5.16b,v9.4b[1] - ins v6.d[1],x20 - sdot v22.4s,v5.16b,v10.4b[1] - sdot v23.4s,v5.16b,v11.4b[1] - sdot v24.4s,v6.16b,v8.4b[1] - sdot v25.4s,v6.16b,v9.4b[1] - ins v7.d[1],x21 - sdot v26.4s,v6.16b,v10.4b[1] - sdot v27.4s,v6.16b,v11.4b[1] - sdot v28.4s,v7.16b,v8.4b[1] - sdot v29.4s,v7.16b,v9.4b[1] - subs x5,x5,#16 // adjust CountN remaining - sdot v30.4s,v7.16b,v10.4b[1] - sdot v31.4s,v7.16b,v11.4b[1] - blo StoreOutputPartial - stp q16,q20,[x2],#32 - stp q24,q28,[x2],#32 - stp q17,q21,[x16],#32 - stp q25,q29,[x16],#32 - stp q18,q22,[x17],#32 - stp q26,q30,[x17],#32 - stp q19,q23,[x6],#32 - stp q27,q31,[x6],#32 - cbnz x5,ProcessNextColumnLoop - -ExitKernel - mov x0,x4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR x20,x21,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#GemmS8S8KernelFrame_SavedRegisters! - EPILOG_RETURN - -// -// Store the partial 1 to 15 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
-// - -StoreOutputPartial - tbz x5,#3,StoreOutputPartial4 - stp q16,q20,[x2],#32 - mov v16.16b,v24.16b // shift remaining elements down - mov v20.16b,v28.16b - stp q17,q21,[x16],#32 - mov v17.16b,v25.16b - mov v21.16b,v29.16b - stp q18,q22,[x17],#32 - mov v18.16b,v26.16b - mov v22.16b,v30.16b - stp q19,q23,[x6],#32 - mov v19.16b,v27.16b - mov v23.16b,v31.16b - -StoreOutputPartial4 - tbz x5,#2,StoreOutputPartial2 - st1 {v16.4s},[x2],#16 - mov v16.16b,v20.16b // shift remaining elements down - st1 {v17.4s},[x16],#16 - mov v17.16b,v21.16b - st1 {v18.4s},[x17],#16 - mov v18.16b,v22.16b - st1 {v19.4s},[x6],#16 - mov v19.16b,v23.16b - -StoreOutputPartial2 - tbz x5,#1,StoreOutputPartial1 - st1 {v16.2s},[x2],#8 - dup v16.4s,v16.s[2] // shift remaining elements down - st1 {v17.2s},[x16],#8 - dup v17.4s,v17.s[2] - st1 {v18.2s},[x17],#8 - dup v18.4s,v18.s[2] - st1 {v19.2s},[x6],#8 - dup v19.4s,v19.s[2] - -StoreOutputPartial1 - tbz x5,#0,ExitKernel - st1 {v16.s}[0],[x2] - st1 {v17.s}[0],[x16] - st1 {v18.s}[0],[x17] - st1 {v19.s}[0],[x6] - b ExitKernel - - NESTED_END MlasSymQgemmS8KernelSdotLd64 - - END diff --git a/onnxruntime/core/mlas/lib/arm64ec/QgemmU8X8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64ec/QgemmU8X8KernelNeon.asm deleted file mode 100644 index 64cff406620e0..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64ec/QgemmU8X8KernelNeon.asm +++ /dev/null @@ -1,629 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelNeon.asm - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - ---*/ - -#include "kxarm64.h" - -// -// Stack frame layout for the U8X8 kernel. -// - -#define GemmU8X8KernelFrame_SavedNeonRegisters (8 * 8) -#define GemmU8X8KernelFrame_SavedRegisters GemmU8X8KernelFrame_SavedNeonRegisters -#define GemmU8X8KernelFrame_ColumnSumBuffer (0 + GemmU8X8KernelFrame_SavedRegisters) -#define GemmU8X8KernelFrame_ZeroPointB (8 + GemmU8X8KernelFrame_SavedRegisters) -#define GemmU8X8KernelFrame_ZeroMode (16 + GemmU8X8KernelFrame_SavedRegisters) - -// -// Define instruction aliases not implemented by ARMASM64. -// - - MACRO - uxtl $DestReg, $SrcReg - - ushll $DestReg.,$SrcReg.,#0 - - MEND - - TEXTAREA - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (x0) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackANeon. - - B (x1) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBNeon. - - C (x2) - Supplies the address of matrix C. - - PackedCountK (x3) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (x4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (x5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc (x6) - Supplies the first dimension of matrix C. - - RowSumBuffer (x7) - Supplies the sum of each row from matrix A multiplied by - the zero point offset of matrix B. These values are accumulated into every - row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. 
- - ZeroMode - Supplies true if the output matrix must be zero initialized, else - false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - NESTED_ENTRY_COMDAT A64NAME(MlasGemmU8X8KernelNeon) - - PROLOG_SAVE_REG_PAIR d8,d9,#-64! - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - ldr x8,[sp,#GemmU8X8KernelFrame_ColumnSumBuffer] - ldr x9,[sp,#GemmU8X8KernelFrame_ZeroPointB] - ldrb w15,[sp,#GemmU8X8KernelFrame_ZeroMode] - mov x16,x0 - ld1 {v7.4s},[x7] // load RowSumBuffer - mov x17,x3 - cmp x4,#1 // CountM == 1? - beq ProcessNextColumnLoopM1 - cmp x4,#4 // CountM < 4? - blo ProcessNextColumnLoopM2 - -// -// Process 4 rows of the matrices. -// - -ProcessNextColumnLoopM4 - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x16 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x17 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - dup v9.4s,v7.s[0] - dup v11.4s,v7.s[1] - dup v13.4s,v7.s[2] - dup v15.4s,v7.s[3] - cbz x9,SkipScaleByZeroPointBM4 - ld1 {v6.4s},[x9],#16 // load ZeroPointB0 - mul v8.4s,v9.4s,v6.4s - mul v10.4s,v11.4s,v6.4s - mul v12.4s,v13.4s,v6.4s - mul v14.4s,v15.4s,v6.4s - ld1 {v6.4s},[x9],#16 // load ZeroPointB1 - mul v9.4s,v9.4s,v6.4s - mul v11.4s,v11.4s,v6.4s - mul v13.4s,v13.4s,v6.4s - mul v15.4s,v15.4s,v6.4s - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v8.4s,v8.4s,v2.4s - add v9.4s,v9.4s,v3.4s - add v10.4s,v10.4s,v2.4s - add v11.4s,v11.4s,v3.4s - ld1 {v5.8b},[x0],#8 // load first packed A1 - add v12.4s,v12.4s,v2.4s - add v13.4s,v13.4s,v3.4s - add v14.4s,v14.4s,v2.4s - add v15.4s,v15.4s,v3.4s - b ComputeBlockLoopM4 - -SkipScaleByZeroPointBM4 - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v8.4s,v9.4s,v2.4s - add v9.4s,v9.4s,v3.4s - add v10.4s,v11.4s,v2.4s - add v11.4s,v11.4s,v3.4s - ld1 {v5.8b},[x0],#8 // load first packed A1 - add v12.4s,v13.4s,v2.4s - add v13.4s,v13.4s,v3.4s - add v14.4s,v15.4s,v2.4s - add v15.4s,v15.4s,v3.4s - -ComputeBlockLoopM4 - uxtl v2.8h,v4.8b - uxtl v3.8h,v5.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v8.4s,v0.4h,v2.h[0] - umlal2 v9.4s,v0.8h,v2.h[0] - umlal v10.4s,v0.4h,v2.h[4] - umlal2 v11.4s,v0.8h,v2.h[4] - uxtl v1.8h,v1.8b - umlal v12.4s,v0.4h,v3.h[0] - umlal2 v13.4s,v0.8h,v3.h[0] - umlal v14.4s,v0.4h,v3.h[4] - umlal2 v15.4s,v0.8h,v3.h[4] - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v8.4s,v1.4h,v2.h[1] - umlal2 v9.4s,v1.8h,v2.h[1] - umlal v10.4s,v1.4h,v2.h[5] - umlal2 v11.4s,v1.8h,v2.h[5] - uxtl v0.8h,v0.8b - umlal v12.4s,v1.4h,v3.h[1] - umlal2 v13.4s,v1.8h,v3.h[1] - umlal v14.4s,v1.4h,v3.h[5] - umlal2 v15.4s,v1.8h,v3.h[5] - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM4 - umlal v8.4s,v0.4h,v2.h[2] - umlal2 v9.4s,v0.8h,v2.h[2] - umlal v10.4s,v0.4h,v2.h[6] - umlal2 v11.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - ld1 {v4.8b},[x0],#8 // load next packed A0 - umlal v12.4s,v0.4h,v3.h[2] - umlal2 v13.4s,v0.8h,v3.h[2] - umlal v14.4s,v0.4h,v3.h[6] - umlal2 v15.4s,v0.8h,v3.h[6] - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v8.4s,v1.4h,v2.h[3] - umlal2 v9.4s,v1.8h,v2.h[3] - umlal v10.4s,v1.4h,v2.h[7] - umlal2 v11.4s,v1.8h,v2.h[7] - uxtl v0.8h,v0.8b - ld1 {v5.8b},[x0],#8 // load next packed A1 - umlal v12.4s,v1.4h,v3.h[3] - umlal2 v13.4s,v1.8h,v3.h[3] - umlal v14.4s,v1.4h,v3.h[7] - umlal2 v15.4s,v1.8h,v3.h[7] - b ComputeBlockLoopM4 - -ComputeBlockLoopFinishM4 - umlal v8.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 
v9.4s,v0.8h,v2.h[2] - add x10,x2,x6,lsl #2 // compute output row 2 - umlal v10.4s,v0.4h,v2.h[6] - umlal2 v11.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - umlal v12.4s,v0.4h,v3.h[2] - umlal2 v13.4s,v0.8h,v3.h[2] - umlal v14.4s,v0.4h,v3.h[6] - umlal2 v15.4s,v0.8h,v3.h[6] - add x11,x10,x6,lsl #2 // compute output row 3 - umlal v8.4s,v1.4h,v2.h[3] - umlal2 v9.4s,v1.8h,v2.h[3] - umlal v10.4s,v1.4h,v2.h[7] - umlal2 v11.4s,v1.8h,v2.h[7] - umlal v12.4s,v1.4h,v3.h[3] - umlal2 v13.4s,v1.8h,v3.h[3] - add x12,x11,x6,lsl #2 // compute output row 4 - umlal v14.4s,v1.4h,v3.h[7] - umlal2 v15.4s,v1.8h,v3.h[7] - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM4 - cbnz x15,SkipAccumulateOutputM4 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v8.4s,v8.4s,v0.4s - add v9.4s,v9.4s,v1.4s - ldp q0,q1,[x11] - add v10.4s,v10.4s,v2.4s - add v11.4s,v11.4s,v3.4s - ldp q2,q3,[x12] - add v12.4s,v12.4s,v0.4s - add v13.4s,v13.4s,v1.4s - add v14.4s,v14.4s,v2.4s - add v15.4s,v15.4s,v3.4s - -SkipAccumulateOutputM4 - stp q8,q9,[x2],#32 - stp q10,q11,[x10] - stp q12,q13,[x11] - stp q14,q15,[x12] - cbnz x5,ProcessNextColumnLoopM4 - -ExitKernelM4 - mov x0,#4 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. -// - -StoreOutputPartialM4 - cbz x15,StoreOutputPartialAddModeM4 - -StoreOutputPartialZeroModeM4 - tbz x5,#2,StoreOutputPartial2ZeroModeM4 - st1 {v8.4s},[x2],#16 - mov v8.16b,v9.16b // shift remaining elements down - st1 {v10.4s},[x10],#16 - mov v10.16b,v11.16b - st1 {v12.4s},[x11],#16 - mov v12.16b,v13.16b - st1 {v14.4s},[x12],#16 - mov v14.16b,v15.16b - -StoreOutputPartial2ZeroModeM4 - tbz x5,#1,StoreOutputPartial1ZeroModeM4 - st1 {v8.2s},[x2],#8 - dup v8.4s,v8.s[2] // shift remaining elements down - st1 {v10.2s},[x10],#8 - dup v10.4s,v10.s[2] - st1 {v12.2s},[x11],#8 - dup v12.4s,v12.s[2] - st1 {v14.2s},[x12],#8 - dup v14.4s,v14.s[2] - -StoreOutputPartial1ZeroModeM4 - tbz x5,#0,ExitKernelM4 - st1 {v8.s}[0],[x2] - st1 {v10.s}[0],[x10] - st1 {v12.s}[0],[x11] - st1 {v14.s}[0],[x12] - b ExitKernelM4 - -StoreOutputPartialAddModeM4 - tbz x5,#2,StoreOutputPartial2AddModeM4 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - ld1 {v2.4s},[x11] - ld1 {v3.4s},[x12] - add v8.4s,v8.4s,v0.4s - add v10.4s,v10.4s,v1.4s - st1 {v8.4s},[x2],#16 - mov v8.16b,v9.16b // shift remaining elements down - st1 {v10.4s},[x10],#16 - mov v10.16b,v11.16b - add v12.4s,v12.4s,v2.4s - add v14.4s,v14.4s,v3.4s - st1 {v12.4s},[x11],#16 - mov v12.16b,v13.16b - st1 {v14.4s},[x12],#16 - mov v14.16b,v15.16b - -StoreOutputPartial2AddModeM4 - tbz x5,#1,StoreOutputPartial1AddModeM4 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - ld1 {v2.2s},[x11] - ld1 {v3.2s},[x12] - add v8.4s,v8.4s,v0.4s - add v10.4s,v10.4s,v1.4s - st1 {v8.2s},[x2],#8 - dup v8.4s,v8.s[2] // shift remaining elements down - st1 {v10.2s},[x10],#8 - dup v10.4s,v10.s[2] - add v12.4s,v12.4s,v2.4s - add v14.4s,v14.4s,v3.4s - st1 {v12.2s},[x11],#8 - dup v12.4s,v12.s[2] - st1 {v14.2s},[x12],#8 - dup v14.4s,v14.s[2] - -StoreOutputPartial1AddModeM4 - tbz x5,#0,ExitKernelM4 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v8.4s,v8.4s,v0.4s - ld1 {v2.s}[0],[x11] - add v10.4s,v10.4s,v1.4s - ld1 {v3.s}[0],[x12] - add v12.4s,v12.4s,v2.4s - st1 {v8.s}[0],[x2] - st1 {v10.s}[0],[x10] - add v14.4s,v14.4s,v3.4s - st1 {v12.s}[0],[x11] - 
st1 {v14.s}[0],[x12] - b ExitKernelM4 - -// -// Process 2 rows of the matrices. -// - -ProcessNextColumnLoopM2 - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x16 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x17 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - dup v9.4s,v7.s[0] - dup v11.4s,v7.s[1] - cbz x9,SkipScaleByZeroPointBM2 - ld1 {v14.4s},[x9],#16 // load ZeroPointB0 - ld1 {v15.4s},[x9],#16 // load ZeroPointB1 - mul v8.4s,v9.4s,v14.4s - mul v10.4s,v11.4s,v14.4s - mul v9.4s,v9.4s,v15.4s - mul v11.4s,v11.4s,v15.4s - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v8.4s,v8.4s,v2.4s - add v9.4s,v9.4s,v3.4s - add v10.4s,v10.4s,v2.4s - add v11.4s,v11.4s,v3.4s - b ComputeBlockLoopM2 - -SkipScaleByZeroPointBM2 - ld1 {v4.8b},[x0],#8 // load first packed A0 - add v8.4s,v9.4s,v2.4s - add v9.4s,v9.4s,v3.4s - add v10.4s,v11.4s,v2.4s - add v11.4s,v11.4s,v3.4s - -ComputeBlockLoopM2 - uxtl v2.8h,v4.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v8.4s,v0.4h,v2.h[0] - umlal2 v9.4s,v0.8h,v2.h[0] - umlal v10.4s,v0.4h,v2.h[4] - umlal2 v11.4s,v0.8h,v2.h[4] - uxtl v1.8h,v1.8b - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v8.4s,v1.4h,v2.h[1] - umlal2 v9.4s,v1.8h,v2.h[1] - umlal v10.4s,v1.4h,v2.h[5] - umlal2 v11.4s,v1.8h,v2.h[5] - uxtl v0.8h,v0.8b - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM2 - umlal v8.4s,v0.4h,v2.h[2] - umlal2 v9.4s,v0.8h,v2.h[2] - umlal v10.4s,v0.4h,v2.h[6] - umlal2 v11.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - ld1 {v4.8b},[x0],#8 // load next packed A0 - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v8.4s,v1.4h,v2.h[3] - umlal2 v9.4s,v1.8h,v2.h[3] - umlal v10.4s,v1.4h,v2.h[7] - umlal2 v11.4s,v1.8h,v2.h[7] - uxtl v0.8h,v0.8b - b ComputeBlockLoopM2 - -ComputeBlockLoopFinishM2 - umlal v8.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v9.4s,v0.8h,v2.h[2] - add x10,x2,x6,lsl #2 // compute output row 2 - umlal v10.4s,v0.4h,v2.h[6] - umlal2 v11.4s,v0.8h,v2.h[6] - uxtl v1.8h,v1.8b - umlal v8.4s,v1.4h,v2.h[3] - umlal2 v9.4s,v1.8h,v2.h[3] - umlal v10.4s,v1.4h,v2.h[7] - umlal2 v11.4s,v1.8h,v2.h[7] - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM2 - cbnz x15,SkipAccumulateOutputM2 - ldp q0,q1,[x2] - ldp q2,q3,[x10] - add v8.4s,v8.4s,v0.4s - add v9.4s,v9.4s,v1.4s - add v10.4s,v10.4s,v2.4s - add v11.4s,v11.4s,v3.4s - -SkipAccumulateOutputM2 - stp q8,q9,[x2],#32 - stp q10,q11,[x10] - cbnz x5,ProcessNextColumnLoopM2 - -ExitKernelM2 - mov x0,#2 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
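Editor's note (illustration only, not part of the change): before the multiply-accumulate loops above, both the 4-row and 2-row paths seed every accumulator from the row-sum and column-sum buffers, optionally scaled by the per-column ZeroPointB. A scalar sketch of that initialization, assuming RowSumBuffer and ColumnSumBuffer already hold the zero-point corrections produced by the packing code (the kernel itself only combines them):

#include <cstddef>
#include <cstdint>

// Accum is a CountM x CountN block of int32 accumulators.
void InitAccumulators(int32_t* Accum,
                      const int32_t* RowSumBuffer,    // one entry per row of A
                      const int32_t* ColumnSumBuffer, // one entry per column of B
                      const int32_t* ZeroPointB,      // per-column zero points, may be null
                      size_t CountM, size_t CountN)
{
    for (size_t m = 0; m < CountM; m++) {
        for (size_t n = 0; n < CountN; n++) {
            int32_t value = RowSumBuffer[m];
            if (ZeroPointB != nullptr) {
                value *= ZeroPointB[n];   // mirrors the "mul" against the loaded ZeroPointB vectors
            }
            Accum[m * CountN + n] = value + ColumnSumBuffer[n];
        }
    }
}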
-// - -StoreOutputPartialM2 - cbz x15,StoreOutputPartialAddModeM2 - -StoreOutputPartialZeroModeM2 - tbz x5,#2,StoreOutputPartial2ZeroModeM2 - st1 {v8.4s},[x2],#16 - mov v8.16b,v9.16b // shift remaining elements down - st1 {v10.4s},[x10],#16 - mov v10.16b,v11.16b - -StoreOutputPartial2ZeroModeM2 - tbz x5,#1,StoreOutputPartial1ZeroModeM2 - st1 {v8.2s},[x2],#8 - dup v8.4s,v8.s[2] // shift remaining elements down - st1 {v10.2s},[x10],#8 - dup v10.4s,v10.s[2] - -StoreOutputPartial1ZeroModeM2 - tbz x5,#0,ExitKernelM2 - st1 {v8.s}[0],[x2] - st1 {v10.s}[0],[x10] - b ExitKernelM2 - -StoreOutputPartialAddModeM2 - tbz x5,#2,StoreOutputPartial2AddModeM2 - ld1 {v0.4s},[x2] - ld1 {v1.4s},[x10] - add v8.4s,v8.4s,v0.4s - add v10.4s,v10.4s,v1.4s - st1 {v8.4s},[x2],#16 - mov v8.16b,v9.16b // shift remaining elements down - st1 {v10.4s},[x10],#16 - mov v10.16b,v11.16b - -StoreOutputPartial2AddModeM2 - tbz x5,#1,StoreOutputPartial1AddModeM2 - ld1 {v0.2s},[x2] - ld1 {v1.2s},[x10] - add v8.4s,v8.4s,v0.4s - add v10.4s,v10.4s,v1.4s - st1 {v8.2s},[x2],#8 - dup v8.4s,v8.s[2] // shift remaining elements down - st1 {v10.2s},[x10],#8 - dup v10.4s,v10.s[2] - -StoreOutputPartial1AddModeM2 - tbz x5,#0,ExitKernelM2 - ld1 {v0.s}[0],[x2] - ld1 {v1.s}[0],[x10] - add v8.4s,v8.4s,v0.4s - add v10.4s,v10.4s,v1.4s - st1 {v8.s}[0],[x2] - st1 {v10.s}[0],[x10] - b ExitKernelM2 - -// -// Process 1 row of the matrices. -// - -ProcessNextColumnLoopM1 - ld1 {v0.8b},[x1],#8 // load packed B0 - mov x0,x16 // reload matrix A - ld1 {v2.4s},[x8],#16 // load ColumnSumBuffer0 - mov x3,x17 // reload PackedCountK - ld1 {v3.4s},[x8],#16 // load ColumnSumBuffer1 - uxtl v0.8h,v0.8b - dup v9.4s,v7.s[0] - cbz x9,SkipScaleByZeroPointBM1 - ld1 {v14.4s},[x9],#16 // load ZeroPointB0 - ld1 {v15.4s},[x9],#16 // load ZeroPointB1 - mul v8.4s,v9.4s,v14.4s - mul v9.4s,v9.4s,v15.4s - ldr s4,[x0],#4 // load first packed A0 - add v8.4s,v8.4s,v2.4s - add v9.4s,v9.4s,v3.4s - b ComputeBlockLoopM1 - -SkipScaleByZeroPointBM1 - ldr s4,[x0],#4 // load first packed A0 - add v8.4s,v9.4s,v2.4s - add v9.4s,v9.4s,v3.4s - -ComputeBlockLoopM1 - uxtl v2.8h,v4.8b - ld1 {v1.8b},[x1],#8 // load packed B1 - umlal v8.4s,v0.4h,v2.h[0] - umlal2 v9.4s,v0.8h,v2.h[0] - uxtl v1.8h,v1.8b - ld1 {v0.8b},[x1],#8 // load packed B2 - umlal v8.4s,v1.4h,v2.h[1] - umlal2 v9.4s,v1.8h,v2.h[1] - uxtl v0.8h,v0.8b - ld1 {v1.8b},[x1],#8 // load packed B3 - sub x3,x3,#1 - cbz x3,ComputeBlockLoopFinishM1 - umlal v8.4s,v0.4h,v2.h[2] - umlal2 v9.4s,v0.8h,v2.h[2] - uxtl v1.8h,v1.8b - ldr s4,[x0],#4 // load first packed A0 - ld1 {v0.8b},[x1],#8 // load packed B0 - umlal v8.4s,v1.4h,v2.h[3] - umlal2 v9.4s,v1.8h,v2.h[3] - uxtl v0.8h,v0.8b - b ComputeBlockLoopM1 - -ComputeBlockLoopFinishM1 - umlal v8.4s,v0.4h,v2.h[2] // finish computing tail vectors - umlal2 v9.4s,v0.8h,v2.h[2] - uxtl v1.8h,v1.8b - umlal v8.4s,v1.4h,v2.h[3] - umlal2 v9.4s,v1.8h,v2.h[3] - subs x5,x5,#8 // adjust CountN remaining - blo StoreOutputPartialM1 - cbnz x15,SkipAccumulateOutputM1 - ldp q0,q1,[x2] - add v8.4s,v8.4s,v0.4s - add v9.4s,v9.4s,v1.4s - -SkipAccumulateOutputM1 - stp q8,q9,[x2],#32 - cbnz x5,ProcessNextColumnLoopM1 - -ExitKernelM1 - mov x0,#1 // return number of rows handled - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -// -// Store the partial 1 to 7 columns either overwriting the output matrix or -// accumulating into the existing contents of the output matrix. 
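Editor's note (illustration only, not part of the change): the StoreOutputPartial blocks handle a remainder of 1 to 7 columns by testing bits 2, 1 and 0 of the remaining count and emitting a 4-, 2- and/or 1-wide store, shifting the surviving accumulator lanes down after each store. In scalar form, for one output row:

#include <cstddef>
#include <cstdint>

void StorePartialRow(int32_t* C, const int32_t* Accum, size_t Remaining, bool ZeroMode)
{
    size_t lane = 0;
    for (size_t width = 4; width != 0; width /= 2) {
        if (Remaining & width) {                  // mirrors "tbz x5,#2 / #1 / #0"
            for (size_t i = 0; i < width; i++) {
                int32_t value = Accum[lane + i];
                C[i] = ZeroMode ? value : C[i] + value;  // overwrite vs. accumulate
            }
            C += width;                            // advance the output pointer
            lane += width;                         // "shift remaining elements down"
        }
    }
}

The same decomposition is used for the 4-row, 2-row and 1-row paths; only the number of accumulator rows differs.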
-// - -StoreOutputPartialM1 - cbz x15,StoreOutputPartialAddModeM1 - -StoreOutputPartialZeroModeM1 - tbz x5,#2,StoreOutputPartial2ZeroModeM1 - st1 {v8.4s},[x2],#16 - mov v8.16b,v9.16b // shift remaining elements down - -StoreOutputPartial2ZeroModeM1 - tbz x5,#1,StoreOutputPartial1ZeroModeM1 - st1 {v8.2s},[x2],#8 - dup v8.4s,v8.s[2] // shift remaining elements down - -StoreOutputPartial1ZeroModeM1 - tbz x5,#0,ExitKernelM1 - st1 {v8.s}[0],[x2] - b ExitKernelM1 - -StoreOutputPartialAddModeM1 - tbz x5,#2,StoreOutputPartial2AddModeM1 - ld1 {v0.4s},[x2] - add v8.4s,v8.4s,v0.4s - st1 {v8.4s},[x2],#16 - mov v8.16b,v9.16b // shift remaining elements down - -StoreOutputPartial2AddModeM1 - tbz x5,#1,StoreOutputPartial1AddModeM1 - ld1 {v0.2s},[x2] - add v8.4s,v8.4s,v0.4s - st1 {v8.2s},[x2],#8 - dup v8.4s,v8.s[2] // shift remaining elements down - -StoreOutputPartial1AddModeM1 - tbz x5,#0,ExitKernelM1 - ld1 {v0.s}[0],[x2] - add v8.4s,v8.4s,v0.4s - st1 {v8.s}[0],[x2] - b ExitKernelM1 - - NESTED_END MlasGemmU8X8KernelNeon - - END diff --git a/onnxruntime/core/mlas/lib/arm64ec/SgemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64ec/SgemmKernelNeon.asm deleted file mode 100644 index 3c546b90510a0..0000000000000 --- a/onnxruntime/core/mlas/lib/arm64ec/SgemmKernelNeon.asm +++ /dev/null @@ -1,466 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelNeon.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -;-- - -#include "kxarm64.h" - - TEXTAREA - -; -; ClearRowAccumulators -; -; Generates the code to clear the accumulators for a single row of the output -; block. -; - - MACRO - ClearRowAccumulators $Columns, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - movi $Vec1Reg..16b,#0 - movi $Vec2Reg..16b,#0 - IF $Columns > 8 - movi $Vec3Reg..16b,#0 - movi $Vec4Reg..16b,#0 - ENDIF - - MEND - -; -; ClearBlockAccumulators -; -; Generates the code to clear the accumulators for a single row of the output -; block. -; - - MACRO - ClearBlockAccumulators $Columns, $Rows - - ClearRowAccumulators $Columns, v8, v9, v10, v11 - IF $Rows >= 2 - ClearRowAccumulators $Columns, v12, v13, v14, v15 - ENDIF - - MEND - -; -; LoadMatrixAElementsBy4 -; LoadMatrixAElementsBy1 -; -; Generates the code to load 1 or 4 elements from matrix A. -; - - MACRO - LoadMatrixAElementsBy4 $Rows - - ldr v2,[x0],#16 - IF $Rows >= 2 - ldr v3,[x10],#16 - ENDIF - - MEND - - MACRO - LoadMatrixAElementsBy1 $Rows - - ldr s2,[x0],#4 - IF $Rows >= 2 - ldr s3,[x10],#4 - ENDIF - - MEND - -; -; MultiplyAccumulateRow -; -; Generates the code to multiply and accumulate a single row of the output -; block. -; - - MACRO - MultiplyAccumulateRow $Columns, $MatrixAReg, $Broadcast, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - fmla $Vec1Reg..4s,v4.4s,$MatrixAReg..s[$Broadcast] - fmla $Vec2Reg..4s,v5.4s,$MatrixAReg..s[$Broadcast] - IF $Columns > 8 - fmla $Vec3Reg..4s,v6.4s,$MatrixAReg..s[$Broadcast] - fmla $Vec4Reg..4s,v7.4s,$MatrixAReg..s[$Broadcast] - ENDIF - - MEND - -; -; MultiplyAccumulateBlock -; -; Generates the code to multiply and accumulate into the output block. 
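Editor's note (illustration only, not part of the change): each MultiplyAccumulateBlock expansion broadcasts one element of matrix A per output row and multiplies it against a 16-wide slice of packed B with fused multiply-adds. Ignoring the register allocation (v8-v15 accumulators, v2/v3 broadcasts), one step is equivalent to:

#include <cstddef>

void MultiplyAccumulateBlock16x2(float Accum[2][16], const float A[2], const float B[16])
{
    for (size_t row = 0; row < 2; row++) {        // MultiplyAccumulateRow, once per row
        for (size_t col = 0; col < 16; col++) {   // four 4-wide fmla instructions per row
            Accum[row][col] += A[row] * B[col];
        }
    }
}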
-; - - MACRO - MultiplyAccumulateBlock $Columns, $Rows, $Broadcast - - MultiplyAccumulateRow $Columns, v2, $Broadcast, v8, v9, v10, v11 - IF $Rows >= 2 - MultiplyAccumulateRow $Columns, v3, $Broadcast, v12, v13, v14, v15 - ENDIF - - MEND - -; -; ComputeBlockLoop -; -; Generates the code to loop over K entries of the input matrices to produce -; the output block. -; - - MACRO - ComputeBlockLoop $Mode, $Columns, $Rows - - ClearBlockAccumulators $Columns, $Rows - - IF $Rows >= 2 - add x10,x0,x6 lsl #2 ; compute matrix A plus 1 row - ENDIF - - sub x9,x3,#4 ; decrement block count to process - tbnz x9,#63,$Mode.ProcessRemaining$Columns.x$Rows.Blocks - -$Mode.Compute$Columns.x$Rows.BlockBy4Loop - LoadMatrixAElementsBy4 $Rows - ldp v4,v5,[x1],#64*4 - IF $Columns > 8 - ldp v6,v7,[x1,#-56*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,0 - ldp v4,v5,[x1,#-48*4] - IF $Columns > 8 - ldp v6,v7,[x1,#-40*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,1 - ldp v4,v5,[x1,#-32*4] - IF $Columns > 8 - ldp v6,v7,[x1,#-24*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,2 - ldp v4,v5,[x1,#-16*4] - IF $Columns > 8 - ldp v6,v7,[x1,#-8*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,3 - sub x9,x9,#4 - tbz x9,#63,$Mode.Compute$Columns.x$Rows.BlockBy4Loop - -$Mode.ProcessRemaining$Columns.x$Rows.Blocks - add x9,x9,#4 ; correct for over-subtract above - cbz x9,$Mode.Output$Columns.x$Rows.Block - -$Mode.Compute$Columns.x$Rows.BlockBy1Loop - LoadMatrixAElementsBy1 $Rows - ldp v4,v5,[x1],#16*4 - IF $Columns > 8 - ldp v6,v7,[x1,#-8*4] - ENDIF - MultiplyAccumulateBlock $Columns,$Rows,0 - sub x9,x9,#1 - cbnz x9,$Mode.Compute$Columns.x$Rows.BlockBy1Loop - -$Mode.Output$Columns.x$Rows.Block - - MEND - -; -; MultiplyAlphaRow -; -; Generates the code to multiply a single row of the output block by the alpha -; value. -; - - MACRO - MultiplyAlphaRow $Columns, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF $Columns <= 4 - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - ELIF $Columns <= 8 - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - fmul $Vec2Reg..4s,$Vec2Reg..4s,v0.s[0] - ELIF $Columns <= 12 - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - fmul $Vec2Reg..4s,$Vec2Reg..4s,v0.s[0] - fmul $Vec3Reg..4s,$Vec3Reg..4s,v0.s[0] - ELSE - fmul $Vec1Reg..4s,$Vec1Reg..4s,v0.s[0] - fmul $Vec2Reg..4s,$Vec2Reg..4s,v0.s[0] - fmul $Vec3Reg..4s,$Vec3Reg..4s,v0.s[0] - fmul $Vec4Reg..4s,$Vec4Reg..4s,v0.s[0] - ENDIF - - MEND - -; -; MultiplyAlphaBlock -; -; Generates the code to multiply the output block by the alpha value. -; - - MACRO - MultiplyAlphaBlock $Columns, $Rows - - MultiplyAlphaRow $Columns, v8, v9, v10, v11 - IF $Rows >= 2 - MultiplyAlphaRow $Columns, v12, v13, v14, v15 - ENDIF - - MEND - -; -; OutputRow1Element -; OutputRow2Element -; OutputRow4Element -; OutputRow8Element -; OutputRow16Element -; -; Generates the code to store elements to the output block. 
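Editor's note (illustration only, not part of the change): the two kernel variants generated at the bottom of this file differ only in how the output block is written back. The "Zero" variant scales the accumulated A*B block by alpha and overwrites C; the "Add" variant folds alpha into a multiply-add against the existing C values. For a single row:

#include <cstddef>

void OutputRow(float* C, const float* Accum, size_t Count, float Alpha, bool AddMode)
{
    for (size_t i = 0; i < Count; i++) {
        if (AddMode) {
            C[i] += Accum[i] * Alpha;   // OutputRow*Element with $Mode=="Add"
        } else {
            C[i] = Accum[i] * Alpha;    // MultiplyAlphaBlock followed by a plain store
        }
    }
}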
-; - - MACRO - OutputRow1Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ld1 {v4.s}[0],[$AddrReg] - fmla v4.2s,$Vec1Reg..2s,v0.s[0] - st1 {v4.s}[0],[$AddrReg] ; post-increment not needed for last element - ELSE - st1 {$Vec1Reg..s}[0],[$AddrReg] ; post-increment not needed for last element - ENDIF - - MEND - - MACRO - OutputRow2Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ld1 {v4.2s},[$AddrReg] - fmla v4.2s,$Vec1Reg..2s,v0.s[0] - st1 {v4.2s},[$AddrReg],#2*4 - ELSE - st1 {$Vec1Reg..2s},[$AddrReg],#2*4 - ENDIF - dup $Vec1Reg..4s,$Vec1Reg..s[2] ; shift remaining elements down - - MEND - - MACRO - OutputRow4Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ld1 {v4.4s},[$AddrReg] - fmla v4.4s,$Vec1Reg..4s,v0.s[0] - st1 {v4.4s},[$AddrReg],#4*4 - ELSE - st1 {$Vec1Reg..4s},[$AddrReg],#4*4 - ENDIF - mov $Vec1Reg..16b,$Vec2Reg..16b ; shift remaining elements down - - MEND - - MACRO - OutputRow8Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ldp v4,v5,[$AddrReg] - fmla v4.4s,$Vec1Reg..4s,v0.s[0] - fmla v5.4s,$Vec2Reg..4s,v0.s[0] - stp v4,v5,[$AddrReg],#8*4 - ELSE - stp $Vec1Reg.,$Vec2Reg.,[$AddrReg],#8*4 - ENDIF - mov $Vec1Reg..16b,$Vec3Reg..16b ; shift remaining elements down - mov $Vec2Reg..16b,$Vec4Reg..16b - - MEND - - MACRO - OutputRow16Element $Mode, $AddrReg, $Vec1Reg, $Vec2Reg, $Vec3Reg, $Vec4Reg - - IF "$Mode"=="Add" - ldp v4,v5,[$AddrReg] - ldp v6,v7,[$AddrReg,#8*4] - fmla v4.4s,$Vec1Reg..4s,v0.s[0] - fmla v5.4s,$Vec2Reg..4s,v0.s[0] - fmla v6.4s,$Vec3Reg..4s,v0.s[0] - fmla v7.4s,$Vec4Reg..4s,v0.s[0] - stp v4,v5,[$AddrReg],#16*4 - stp v6,v7,[$AddrReg,#-8*4] - ELSE - stp $Vec1Reg.,$Vec2Reg.,[$AddrReg],#16*4 - stp $Vec3Reg.,$Vec4Reg.,[$AddrReg,#-8*4] - ENDIF - - MEND - -; -; OutputBlock -; -; Generates the code to store the output block. -; - - MACRO - OutputBlock $Mode, $Columns, $Rows - - OutputRow$Columns.Element $Mode, x2, v8, v9, v10, v11 - IF $Rows >= 2 - OutputRow$Columns.Element $Mode, x11, v12, v13, v14, v15 - ENDIF - - MEND - -; -; ProcessRows -; -; Generates the code to process a compute and store the output block for a -; fixed number of rows. -; - - MACRO - ProcessRows $Mode, $Rows - - mov x4,#$Rows ; return number of rows handled - cmp x5,#8 - ble $Mode.ProcessRemainingCountN$Rows - -$Mode.ProcessNextColumnLoop16x$Rows - ComputeBlockLoop $Mode,16,$Rows - IF "$Mode"=="Zero" - MultiplyAlphaBlock 16,$Rows - ENDIF - sub x5,x5,#16 - tbnz x5,#63,$Mode.OutputMasked16x$Rows.Block - OutputBlock $Mode,16,$Rows - mov x0,x8 ; reload matrix A - cmp x5,#8 - bgt $Mode.ProcessNextColumnLoop16x$Rows - cbz x5,$Mode.ExitKernel - -$Mode.ProcessRemainingCountN$Rows - ComputeBlockLoop $Mode,8,$Rows - IF "$Mode"=="Zero" - MultiplyAlphaBlock 8,$Rows - ENDIF - -$Mode.OutputMasked16x$Rows.Block - tbz x5,#3,$Mode.OutputRemaining7x$Rows.Block - OutputBlock $Mode,8,$Rows - -$Mode.OutputRemaining7x$Rows.Block - tbz x5,#2,$Mode.OutputRemaining3x$Rows.Block - OutputBlock $Mode,4,$Rows - -$Mode.OutputRemaining3x$Rows.Block - tbz x5,#1,$Mode.OutputRemaining1x$Rows.Block - OutputBlock $Mode,2,$Rows - -$Mode.OutputRemaining1x$Rows.Block - tbz x5,#0,$Mode.ExitKernel - OutputBlock $Mode,1,$Rows - - MEND - - SUBT "SGEMM kernel" -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A (x0) - Supplies the address of matrix A. 
-; -; B (x1) - Supplies the address of matrix B. The matrix data has been packed -; using MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C (x2) - Supplies the address of matrix C. -; -; CountK (x3) - Supplies the number of columns from matrix A and the number -; of rows from matrix B to iterate over. -; -; CountM (x4) - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN (x5) - Supplies the number of columns from matrix B and matrix C to -; iterate over. -; -; lda (x6) - Supplies the first dimension of matrix A. -; -; ldc (x7) - Supplies the first dimension of matrix C. -; -; Alpha (s0) - Supplies the scalar multiplier (see SGEMM definition). -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - - MACRO - SgemmKernelNeonFunction $Mode - - NESTED_ENTRY_COMDAT A64NAME(MlasSgemmKernel$Mode) - - PROLOG_SAVE_REG_PAIR d8,d9,#-64! - PROLOG_SAVE_REG_PAIR d10,d11,#16 - PROLOG_SAVE_REG_PAIR d12,d13,#32 - PROLOG_SAVE_REG_PAIR d14,d15,#48 - - add x11,x2,x7 lsl #2 ; compute matrix C plus 1 row - mov x8,x0 ; save matrix A - -; -; Process 2 rows of the matrices. -; - - cmp x4,#2 - blt $Mode.ProcessCountMLessThan2 - ProcessRows $Mode,2 - -; -; Restore non-volatile registers and return. -; - -$Mode.ExitKernel - mov x0,x4 - EPILOG_RESTORE_REG_PAIR d14,d15,#48 - EPILOG_RESTORE_REG_PAIR d12,d13,#32 - EPILOG_RESTORE_REG_PAIR d10,d11,#16 - EPILOG_RESTORE_REG_PAIR d8,d9,#64! - EPILOG_RETURN - -; -; Process 1 row of the matrices. -; - -$Mode.ProcessCountMLessThan2 - ProcessRows $Mode,1 - b $Mode.ExitKernel - - NESTED_END - - MEND - - SgemmKernelNeonFunction Zero - SgemmKernelNeonFunction Add - - END diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp deleted file mode 100644 index 9b5800b08edbc..0000000000000 --- a/onnxruntime/core/mlas/lib/cast.cpp +++ /dev/null @@ -1,52 +0,0 @@ -/*++ - -Copyright (c) Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - cast.cpp - -Abstract: - - This module implements Half (F16) to Single (F32) precision casting. - ---*/ -#include "mlasi.h" - -void -MLASCALL -MlasConvertHalfToFloatBuffer( - const MLAS_FP16* Source, - float* Destination, - size_t Count -) -{ - if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { - for (size_t i = 0; i < Count; ++i) { - Destination[i] = Source[i].ToFloat(); - } - } else { - // If the kernel is available, use it to perform the conversion. - GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast(Source), Destination, Count); - } -} - -void -MLASCALL -MlasConvertFloatToHalfBuffer( - const float* Source, - MLAS_FP16* Destination, - size_t Count -) -{ - if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) { - for (size_t i = 0; i < Count; ++i) { - Destination[i] = MLAS_FP16(Source[i]); - } - } else { - // If the kernel is available, use it to perform the conversion. - GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast(Destination), Count); - } -} diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp deleted file mode 100644 index 73df23e64ca1f..0000000000000 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ /dev/null @@ -1,989 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - compute.cpp - -Abstract: - - This module implements miscellaneous computation routines. 
- - Our usage requires building platform specific versions of the algorithm to - target different instruction sets. The implementation below targets the - base instruction set (typically SSE2) while assembly implementations target - newer instruction sets (such as FMA3). - ---*/ - -#include "mlasi.h" - -// -// Bundles the constants for use by kernels written in assembly. -// - -MLAS_INTERNAL_DATA const struct { - float LowerRange; - float UpperRange; - float LowerRangeSumExp; - float UpperRangeSumExp; - float RoundingBias; - float Log2Reciprocal; - float Log2High; - float Log2Low; - float poly_0; - float poly_1; - float poly_2; - float poly_3; - float poly_4; - float poly_56; - int32_t MinimumExponent; - int32_t MaximumExponent; -} MlasExpConstants = { - -103.9720840454f, - 88.7762626647950f, - -88.3762626647949f, - 88.3762626647949f, - MLAS_ROUNDING_BIAS_MAGIC, - 1.44269504088896341f, - -6.93145752e-1f, - -1.42860677e-6f, - 0x1.694000p-10, - 0x1.125edcp-7, - 0x1.555b5ap-5, - 0x1.555450p-3, - 0x1.fffff6p-2, - 0x1.000000p+0, - int32_t(0xC1000000), - int32_t(0x3F800000), -}; - -MLAS_INTERNAL_DATA const float MlasMinimumF32Value = std::numeric_limits::lowest(); - -// -// Define the parameters to execute segments of a softmax operation on worker -// threads. -// - -struct MLAS_SOFTMAX_WORK_BLOCK { - ptrdiff_t ThreadCountN; - bool LogSoftmax; - bool SmoothSoftmax; - const float* Input; - float* Output; - size_t N; - size_t D; -}; - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasComputeExpVector( - MLAS_FLOAT32X4 Vector -) -/*++ - -Routine Description: - - This routine computes the exponential function for the supplied vector. - - This merges ideas from multiple vectorized expf() implementations: - - 1. The original polynomials of expf() are extracted from MlasComputeErf, which - was based on an answer to the following Stack Overflow post: - - https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff - - 2. The author of the answer further refined the polynomials at: - - https://forums.developer.nvidia.com/t/a-more-accurate-performance-competitive-implementation-of-expf/47528/5 - - Using these polynomials yields even closer results to the Microsoft - UCRT version of std::expf() than the values from the above post. - - 3. XNNPACK has a further useful refinement to extend the effective - range of results from [-88.376, 88.376] to [-103.972, 88.776] by - splitting the step of exponent reconstruction into two pieces. This - yields results similar to an AVX512 implementation using VSCALEFPS. - -Arguments: - - Vector - Supplies the values to operate on. - -Return Value: - - Returns the exponential function of the input. - ---*/ -{ - Vector = MlasClampFloat32x4(Vector, MlasExpConstants.LowerRange, MlasExpConstants.UpperRange); - - // - // Range reduction of the input by computing "(2 ^ m) * exp(reduced)". - // - - const auto RoundingBias = MlasBroadcastFloat32x4(MlasExpConstants.RoundingBias); - - auto biased = MlasMultiplyAddFloat32x4(Vector, MlasExpConstants.Log2Reciprocal, RoundingBias); - auto m = MlasSubtractFloat32x4(biased, RoundingBias); - - Vector = MlasMultiplyAddFloat32x4(m, MlasExpConstants.Log2High, Vector); - Vector = MlasMultiplyAddFloat32x4(m, MlasExpConstants.Log2Low, Vector); - - // - // Compute the scaling factors used to reconstruct the "(2 ^ m)" value - // from above. To cover the entire single precision floating point range, - // two scaling factors are needed to handle exponents [-150, 128]. 
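Editor's note (illustration only, not part of the change): the reconstruction step above evaluates exp(x) as 2^m * exp(r) with r reduced into a narrow range, and splits 2^m into two factors so that exponents down to about -150 (subnormal results) remain representable. A scalar sketch of the same idea, using the constants from MlasExpConstants but a simple ldexp split instead of the kernel's exponent clamping, and the standard exp in place of the degree-6 polynomial:

#include <cmath>

float ExpViaRangeReduction(float x)
{
    const float Log2Reciprocal = 1.44269504088896341f;
    const float Log2High = -6.93145752e-1f;   // -ln(2), split into high and low parts
    const float Log2Low = -1.42860677e-6f;

    float m = std::nearbyint(x * Log2Reciprocal);  // the vector code uses a rounding-bias trick instead
    float r = x + m * Log2High;                    // the split keeps the reduction nearly exact
    r = r + m * Log2Low;

    // Rebuild 2^m from two representable factors so m outside a single float
    // exponent still works.
    float s1 = std::ldexp(1.0f, (int)m / 2);
    float s2 = std::ldexp(1.0f, (int)m - (int)m / 2);

    return std::exp(r) * s1 * s2;
}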
- // - - const auto MinimumExponent = MlasBroadcastInt32x4(MlasExpConstants.MinimumExponent); - const auto MaximumExponent = MlasBroadcastInt32x4(MlasExpConstants.MaximumExponent); - - auto overflow = MlasShiftLeftInt32x4<23>(MlasReinterpretAsInt32x4(biased)); - auto normal = overflow; -#if defined(MLAS_SSE2_INTRINSICS) - // N.B. PMINSD/PMAXSD were not added until SSE 4.1, but the lower 16 bits - // are zero, so they can be ignored for this computation, so use PMINSW/PMAXSW - // instead. - normal = _mm_min_epi16(normal, MaximumExponent); - normal = _mm_max_epi16(normal, MinimumExponent); -#elif defined(MLAS_LSX_INTRINSICS) - normal = __lsx_vmin_h(normal, MaximumExponent); - normal = __lsx_vmax_h(normal, MinimumExponent); -#else - normal = MlasMinimumInt32x4(normal, MaximumExponent); - normal = MlasMaximumInt32x4(normal, MinimumExponent); -#endif - overflow = MlasSubtractInt32x4(overflow, normal); - overflow = MlasAddInt32x4(overflow, MaximumExponent); - normal = MlasAddInt32x4(normal, MaximumExponent); - - // - // Compute the polynomial approximation of exp(reduced) and reconstruct - // the final result using the above scaling factors. The final term of - // the polynomial (poly_6=1.0f) is merged as the multiply/add of the - // overflow exponent (reference XNNPACK). - // - - auto p = MlasBroadcastFloat32x4(MlasExpConstants.poly_0); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_1); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_2); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_3); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_4); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_56); - - Vector = MlasMultiplyFloat32x4(Vector, MlasReinterpretAsFloat32x4(overflow)); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasReinterpretAsFloat32x4(overflow)); - p = MlasMultiplyFloat32x4(p, MlasReinterpretAsFloat32x4(normal)); - - return p; -} - -void -MLASCALL -MlasComputeExpF32Kernel( - const float* Input, - float* Output, - size_t N -) -/*++ - -Routine Description: - - This routine implements the generic kernel for the exponential function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ - while (N > 0) { - MLAS_FLOAT32X4 Vector; - - if (N >= 4) { - Vector = MlasLoadFloat32x4(Input); - } else { -#if defined(MLAS_SSE2_INTRINSICS) - // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle - // and use zeroes for the upper elements. - Vector = _mm_load_ss(Input); -#elif defined(MLAS_LSX_INTRINSICS) - Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0); -#else - Vector = MlasBroadcastFloat32x4(Input); -#endif - } - - Vector = MlasComputeExpVector(Vector); - - if (N >= 4) { - MlasStoreFloat32x4(Output, Vector); - - Input += 4; - Output += 4; - N -= 4; - - } else { - MlasStoreLaneFloat32x4<0>(Output, Vector); - - Input += 1; - Output += 1; - N -= 1; - } - } -} - -void -MLASCALL -MlasComputeExp( - const float* Input, - float* Output, - size_t N -) -/*++ - -Routine Description: - - This routine computes the exponential function. - - N.B. This implementation supports in place updates of the output buffer. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. 
- ---*/ -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().ComputeExpF32Kernel(Input, Output, N); -#else - MlasComputeExpF32Kernel(Input, Output, N); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasComputeSumExpVector( - MLAS_FLOAT32X4 Vector, - MLAS_FLOAT32X4 NegativeMaximumVector -) -/*++ - -Routine Description: - - This routine computes the exponential function for the supplied vector. - - This function handles a narrower range of inputs compared to - MlasComputeExpVector in order to improve efficiency. - -Arguments: - - Vector - Supplies the values to operate on. - - NegativeMaximumVector - Supplies the broadcasted negative maximum - value that is added to each element before computing the exponential - function. - -Return Value: - - Returns the exponential function of the input. - ---*/ -{ - // - // Subtract the maximum value from every element. - // - // N.B. For each of use by the assembly kernels, this value has been negated - // so add the value instead. - // - - Vector = MlasAddFloat32x4(Vector, NegativeMaximumVector); - - // - // Clamp to the lower range of this function. - // - // The value should already be negative or equal to zero as every value has - // been reduced by the maximum value. - // - -#if defined(MLAS_SSE2_INTRINSICS) - // N.B. MINPS and MAXPS propagates the value from the second vector if the - // value is a NaN. -#endif - Vector = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasExpConstants.LowerRangeSumExp), Vector); - - // - // Range reduction of the input by computing "(2 ^ m) * exp(reduced)". - // - - const auto RoundingBias = MlasBroadcastFloat32x4(MlasExpConstants.RoundingBias); - - auto biased = MlasMultiplyAddFloat32x4(Vector, MlasExpConstants.Log2Reciprocal, RoundingBias); - auto m = MlasSubtractFloat32x4(biased, RoundingBias); - - Vector = MlasMultiplyAddFloat32x4(m, MlasExpConstants.Log2High, Vector); - Vector = MlasMultiplyAddFloat32x4(m, MlasExpConstants.Log2Low, Vector); - - // - // Compute the scaling factor used to reconstruct the "(2 ^ m)" value - // from above. The effective range of this function is smaller than - // MlasComputeExp to reduce the number of operations. - // - - auto normal = MlasShiftLeftInt32x4<23>(MlasReinterpretAsInt32x4(biased)); - normal = MlasAddInt32x4(normal, MlasBroadcastInt32x4(MlasExpConstants.MaximumExponent)); - - // - // Compute the polynomial approximation of exp(reduced) and reconstruct - // the final result using the above scale factor. - // - - auto p = MlasBroadcastFloat32x4(MlasExpConstants.poly_0); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_1); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_2); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_3); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_4); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_56); - p = MlasMultiplyAddFloat32x4(p, Vector, MlasExpConstants.poly_56); - - p = MlasMultiplyFloat32x4(p, MlasReinterpretAsFloat32x4(normal)); - - return p; -} - -float -MLASCALL -MlasComputeSumExpF32Kernel( - const float* Input, - float* Output, - size_t N, - const float* NegativeMaximum -) -/*++ - -Routine Description: - - This routine implements the generic kernel for the sum of exponential - functions. - -Arguments: - - Input - Supplies the input buffer. - - Output - Optionally supplies the output buffer. When used for Softmax, - the output buffer is used to store the intermediate exp() results. When - used for LogSoftmax, the intermediate exp() results are not required. 
- - N - Supplies the number of elements to process. - - NegativeMaximum - Supplies the address of the negative maximum - value that is added to each element before computing the exponential - function. - -Return Value: - - Returns the sum of the exponential functions. - ---*/ -{ - MLAS_FLOAT32X4 NegativeMaximumVector = MlasBroadcastFloat32x4(*NegativeMaximum); - float Accumulator = 0.0f; - - if (N >= 4) { - MLAS_FLOAT32X4 AccumulatorVector = MlasZeroFloat32x4(); - -#if !defined(MLAS_SSE2_INTRINSICS) - - // - // Unroll the loop for architectures that can benefit from improved - // instruction level parallelism. - // - // N.B. The extra code size is not worth the benefit for SSE2 as the - // MLAS_TARGET_AMD64 build already has specialized AVX2/AVX512F kernels - // that do this. - // - - while (N >= 8) { - MLAS_FLOAT32X4 Vector0 = MlasLoadFloat32x4(Input); - MLAS_FLOAT32X4 Vector1 = MlasLoadFloat32x4(Input + 4); - - Vector0 = MlasComputeSumExpVector(Vector0, NegativeMaximumVector); - Vector1 = MlasComputeSumExpVector(Vector1, NegativeMaximumVector); - AccumulatorVector = MlasAddFloat32x4(AccumulatorVector, Vector0); - AccumulatorVector = MlasAddFloat32x4(AccumulatorVector, Vector1); - - if (Output != nullptr) { - MlasStoreFloat32x4(Output, Vector0); - MlasStoreFloat32x4(Output + 4, Vector1); - Output += 8; - } - - Input += 8; - N -= 8; - } - -#endif - - while (N >= 4) { - MLAS_FLOAT32X4 Vector = MlasLoadFloat32x4(Input); - - Vector = MlasComputeSumExpVector(Vector, NegativeMaximumVector); - AccumulatorVector = MlasAddFloat32x4(AccumulatorVector, Vector); - - if (Output != nullptr) { - MlasStoreFloat32x4(Output, Vector); - Output += 4; - } - - Input += 4; - N -= 4; - } - - Accumulator = MlasReduceAddFloat32x4(AccumulatorVector); - } - - while (N > 0) { -#if defined(MLAS_SSE2_INTRINSICS) - // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and - // use zeroes for the upper elements. - MLAS_FLOAT32X4 Vector = _mm_load_ss(Input); -#elif defined(MLAS_LSX_INTRINSICS) - MLAS_FLOAT32X4 Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0); -#else - MLAS_FLOAT32X4 Vector = MlasBroadcastFloat32x4(Input); -#endif - - Vector = MlasComputeSumExpVector(Vector, NegativeMaximumVector); - Accumulator += MlasExtractLaneFloat32x4<0>(Vector); - - if (Output != nullptr) { - MlasStoreLaneFloat32x4<0>(Output, Vector); - Output += 1; - } - - Input += 1; - N -= 1; - } - - return Accumulator; -} - -float -MLASCALL -MlasReduceMaximumF32Kernel( - const float* Input, - size_t N -) -/*++ - -Routine Description: - - This routine implements the generic kernel to find the maximum value of - the supplied buffer. - -Arguments: - - Input - Supplies the input buffer. - - N - Supplies the number of elements to process. - -Return Value: - - Returns the maximum value of the supplied buffer. 
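Editor's note (illustration only, not part of the change): stripped of the vector unrolling, the sum-of-exponentials kernel above performs the numerically stable softmax denominator pass: remove the row maximum (passed in negated, hence the addition), exponentiate, optionally store the intermediate, and accumulate the sum.

#include <cmath>
#include <cstddef>

float SumExpRow(const float* Input, float* Output, size_t N, float NegativeMaximum)
{
    float Accumulator = 0.0f;
    for (size_t i = 0; i < N; i++) {
        float value = std::exp(Input[i] + NegativeMaximum);  // exp(x - max)
        if (Output != nullptr) {
            Output[i] = value;   // softmax keeps the intermediates; log-softmax does not
        }
        Accumulator += value;
    }
    return Accumulator;
}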
- ---*/ -{ - float Maximum = MlasMinimumF32Value; - - if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(Maximum); - - if (N >= 16) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; - MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; - MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; - - while (N >= 16) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); - MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4)); - MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8)); - MaximumVector3 = MlasMaximumFloat32x4(MaximumVector3, MlasLoadFloat32x4(Input + 12)); - - Input += 16; - N -= 16; - } - - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector1); - MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MaximumVector3); - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector2); - } - - while (N >= 4) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); - - Input += 4; - N -= 4; - } - - Maximum = MlasReduceMaximumFloat32x4(MaximumVector0); - } - - while (N > 0) { - Maximum = std::max(Maximum, *Input); - - Input += 1; - N -= 1; - } - - return Maximum; -} - -void -MLASCALL -MlasReduceMinimumMaximumF32Kernel( - const float* Input, - float* Min, - float* Max, - size_t N -) -{ - float tmp_min = std::numeric_limits::max(); - float tmp_max = std::numeric_limits::lowest(); - - if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(tmp_max); - MLAS_FLOAT32X4 MinimumVector0 = MlasBroadcastFloat32x4(tmp_min); - - if (N >= 16) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; - MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; - MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; - - MLAS_FLOAT32X4 MinimumVector1 = MinimumVector0; - MLAS_FLOAT32X4 MinimumVector2 = MinimumVector0; - MLAS_FLOAT32X4 MinimumVector3 = MinimumVector0; - - while (N >= 16) { - MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(Input); - MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(Input + 4); - MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(Input + 8); - MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(Input + 12); - - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, InputVector0); - MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, InputVector1); - MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, InputVector2); - MaximumVector3 = MlasMaximumFloat32x4(MaximumVector3, InputVector3); - - MinimumVector0 = MlasMinimumFloat32x4(MinimumVector0, InputVector0); - MinimumVector1 = MlasMinimumFloat32x4(MinimumVector1, InputVector1); - MinimumVector2 = MlasMinimumFloat32x4(MinimumVector2, InputVector2); - MinimumVector3 = MlasMinimumFloat32x4(MinimumVector3, InputVector3); - - Input += 16; - N -= 16; - } - - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector1); - MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MaximumVector3); - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector2); - - MinimumVector0 = MlasMinimumFloat32x4(MinimumVector0, MinimumVector1); - MinimumVector2 = MlasMinimumFloat32x4(MinimumVector2, MinimumVector3); - MinimumVector0 = MlasMinimumFloat32x4(MinimumVector0, MinimumVector2); - } - - while (N >= 4) { - MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(Input); - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, InputVector0); - - MinimumVector0 = MlasMinimumFloat32x4(MinimumVector0, InputVector0); - - Input += 4; - N -= 4; - } - - tmp_min = MlasReduceMinimumFloat32x4(MinimumVector0); - 
tmp_max = MlasReduceMaximumFloat32x4(MaximumVector0); - } - - while (N > 0) { - tmp_max = std::max(tmp_max, *Input); - tmp_min = std::min(tmp_min, *Input); - - Input += 1; - N -= 1; - } - - *Min = tmp_min; - *Max = tmp_max; -} - -void -MLASCALL -MlasComputeSoftmaxOutputF32Kernel( - float* Output, - size_t N, - const float* Parameters -) -/*++ - -Routine Description: - - This routine implements the generic kernel to produce the final output for - the softmax operation. - -Arguments: - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Parameters - Supplies an array containing the scale value. - -Return Value: - - None. - ---*/ -{ - const float Scale = Parameters[0]; - - const MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale); - - while (N >= 16) { - MLAS_FLOAT32X4 Vector0 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output)); - MLAS_FLOAT32X4 Vector1 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 4)); - MLAS_FLOAT32X4 Vector2 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 8)); - MLAS_FLOAT32X4 Vector3 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 12)); - - MlasStoreFloat32x4(Output, Vector0); - MlasStoreFloat32x4(Output + 4, Vector1); - MlasStoreFloat32x4(Output + 8, Vector2); - MlasStoreFloat32x4(Output + 12, Vector3); - - Output += 16; - N -= 16; - } - - while (N >= 4) { - MlasStoreFloat32x4(Output, MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output))); - - Output += 4; - N -= 4; - } - - while (N > 0) { - *Output *= Scale; - - Output += 1; - N -= 1; - } -} - -void -MLASCALL -MlasComputeLogSoftmaxOutputF32Kernel( - const float* Input, - float* Output, - size_t N, - const float* Parameters -) -/*++ - -Routine Description: - - This routine implements the generic kernel to produce the final output for - the log softmax operation. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Parameters - Supplies an array containing the negative maximum and - logarithm values. - -Return Value: - - None. 
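Editor's note (illustration only, not part of the change): with NegativeMaximum = -max(x) and Logarithm = log(sum_j exp(x_j - max(x))), the log-softmax output step reduces to log_softmax(x_i) = (x_i - max(x)) - log(sum_j exp(x_j - max(x))), which is what the scalar tail of the kernel below computes:

#include <cstddef>

void LogSoftmaxOutputRow(const float* Input, float* Output, size_t N,
                         float NegativeMaximum, float Logarithm)
{
    for (size_t i = 0; i < N; i++) {
        Output[i] = Input[i] + NegativeMaximum - Logarithm;
    }
}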
- ---*/ -{ - const float NegativeMaximum = Parameters[0]; - const float Logarithm = Parameters[1]; - - const MLAS_FLOAT32X4 NegativeMaximumVector = MlasBroadcastFloat32x4(NegativeMaximum); - const MLAS_FLOAT32X4 LogarithmVector = MlasBroadcastFloat32x4(Logarithm); - - while (N >= 16) { - MLAS_FLOAT32X4 Vector0 = MlasLoadFloat32x4(Input); - MLAS_FLOAT32X4 Vector1 = MlasLoadFloat32x4(Input + 4); - MLAS_FLOAT32X4 Vector2 = MlasLoadFloat32x4(Input + 8); - MLAS_FLOAT32X4 Vector3 = MlasLoadFloat32x4(Input + 12); - - Vector0 = MlasAddFloat32x4(Vector0, NegativeMaximumVector); - Vector1 = MlasAddFloat32x4(Vector1, NegativeMaximumVector); - Vector2 = MlasAddFloat32x4(Vector2, NegativeMaximumVector); - Vector3 = MlasAddFloat32x4(Vector3, NegativeMaximumVector); - - Vector0 = MlasSubtractFloat32x4(Vector0, LogarithmVector); - Vector1 = MlasSubtractFloat32x4(Vector1, LogarithmVector); - Vector2 = MlasSubtractFloat32x4(Vector2, LogarithmVector); - Vector3 = MlasSubtractFloat32x4(Vector3, LogarithmVector); - - MlasStoreFloat32x4(Output, Vector0); - MlasStoreFloat32x4(Output + 4, Vector1); - MlasStoreFloat32x4(Output + 8, Vector2); - MlasStoreFloat32x4(Output + 12, Vector3); - - Input += 16; - Output += 16; - N -= 16; - } - - while (N >= 4) { - MLAS_FLOAT32X4 Vector = MlasLoadFloat32x4(Input); - Vector = MlasAddFloat32x4(Vector, NegativeMaximumVector); - Vector = MlasSubtractFloat32x4(Vector, LogarithmVector); - MlasStoreFloat32x4(Output, Vector); - - Input += 4; - Output += 4; - N -= 4; - } - - while (N > 0) { - *Output = *Input + NegativeMaximum - Logarithm; - - Input += 1; - Output += 1; - N -= 1; - } -} - -void -MlasComputeSoftmaxThreaded( - void* Context, - ptrdiff_t Index -) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - softmax or log softmax operation. - -Arguments: - - Context - Supplies the pointer to the context for the threaded operation. - - ThreadId - Supplies the current index of the threaded operation. - -Return Value: - - None. - ---*/ -{ - const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; - - // - // Partition the operation along the N dimension. - // - - size_t n; - size_t CountN; - - MlasPartitionWork(Index, WorkBlock->ThreadCountN, WorkBlock->N, &n, &CountN); - - // - // Compute the softmax or log softmax function. - // - - const size_t D = WorkBlock->D; - const bool LogSoftmax = WorkBlock->LogSoftmax; - const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; - - const float* Input = WorkBlock->Input + n * D; - float* Output = WorkBlock->Output + n * D; - -#if defined(MLAS_SSE2_INTRINSICS) - // TODO: Use std::hardware_constructive_interference_size - constexpr size_t CacheLineSize = 64; - constexpr size_t ElementsPerCacheLine = CacheLineSize / sizeof(float); -#endif - - while (CountN > 0) { -#if defined(MLAS_SSE2_INTRINSICS) - // - // Prefetch the next row of the input buffer. - // - - for (size_t i = 0; i * ElementsPerCacheLine < D; i++) { - _mm_prefetch((char*)(Input + D) + i * CacheLineSize, _MM_HINT_T0); - } -#endif - - // - // Find the maximum value for the row. 
- // - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); -#else - float Maximum = MlasReduceMaximumF32Kernel(Input, D); -#endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; - } - - // - // Compute the exponential function for each element of the row (save to Temp if provided) and - // compute the sum of these exponential functions. - // - float* Temp = LogSoftmax ? nullptr : Output; -#if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); -#else - float Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); -#endif - - if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); - } - - if (LogSoftmax) { - // - // Compute the log softmax output. - // - float Parameters[] = {NegativeMaximum, std::log(Accumulation)}; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); -#else - MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); -#endif - - } else { - // - // Normalize the softmax output. - // - float Parameters[] = {1.0f / Accumulation}; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); -#else - MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters); -#endif - } - - Input += D; - Output += D; - CountN--; - } -} - -void -MLASCALL -MlasComputeSoftmax( - const float* Input, - float* Output, - size_t N, - size_t D, - bool LogSoftmax, - bool SmoothSoftmax, - MLAS_THREADPOOL* ThreadPool -) -/*++ - -Routine Description: - - This routine computes the softmax or log softmax function. - - N.B. This implementation supports in place updates of the output buffer. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of rows to process. - - D - Supplies the number of columns per row to process. - - LogSoftmax - Supplies true if this is a log softmax operation, else false - if this is a softmax operation. - - SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_SOFTMAX_WORK_BLOCK WorkBlock; - - // - // Capture the softmax parameters to the work block. - // - - WorkBlock.LogSoftmax = LogSoftmax; - WorkBlock.SmoothSoftmax = SmoothSoftmax; - WorkBlock.Input = Input; - WorkBlock.Output = Output; - WorkBlock.N = N; - WorkBlock.D = D; - - // - // Compute the number of target threads given the complexity of the softmax - // operation. Limit the number of threads to the number of rows and try to - // keep each thread processing a minimum number of elements before using - // another thread. 
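Editor's note (illustration only, not part of the change): the heuristic described above caps the thread count at the number of rows and at one thread per ~16384 elements, exactly as the code that follows does. As a standalone sketch:

#include <algorithm>
#include <cstddef>

ptrdiff_t PickSoftmaxThreadCount(ptrdiff_t MaximumThreads, size_t N, size_t D)
{
    constexpr size_t MinimumElementsPerThread = 16384;

    size_t Threads = std::min<size_t>(static_cast<size_t>(MaximumThreads), N);
    size_t BlockCount = (N * D) / MinimumElementsPerThread + 1;
    Threads = std::min(Threads, BlockCount);
    return static_cast<ptrdiff_t>(Threads);
}

For example, 8 rows of 1000 elements give BlockCount = 8000/16384 + 1 = 1, so the whole operation stays on a single thread regardless of how many threads the pool offers.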
- // - - ptrdiff_t ThreadCountN = MlasGetMaximumThreadCount(ThreadPool); - - if (size_t(ThreadCountN) > N) { - ThreadCountN = ptrdiff_t(N); - } - - constexpr size_t MinimumElementsPerThread = 16384; - - size_t BlockCount = ((N * D) / MinimumElementsPerThread) + 1; - - if (size_t(ThreadCountN) > BlockCount) { - ThreadCountN = ptrdiff_t(BlockCount); - } - - WorkBlock.ThreadCountN = ThreadCountN; - - MlasExecuteThreaded(MlasComputeSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp deleted file mode 100644 index ec79641559c6b..0000000000000 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ /dev/null @@ -1,1302 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - convolve.cpp - -Abstract: - - This module implements the convolution operation. - ---*/ - -#include "mlasi.h" - -// -// Define the number of working buffer elements required per thread. -// - -#define MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD \ - (MLAS_SGEMM_STRIDEN * MLAS_SGEMM_STRIDEK) - -// -// Define the parameters to execute segments of a convolution operation on -// worker threads. -// - -struct MLAS_CONV_WORK_BLOCK { - const MLAS_CONV_PARAMETERS* Parameters; - const float* Input; - const float* Filter; - const float* Bias; - float* WorkingBuffer; - float* Output; - struct SEGMENT { - size_t StartN; - size_t CountN; - } Segments[MLAS_MAXIMUM_THREAD_COUNT]; - ptrdiff_t TargetThreadCount; -}; - -void -MlasConvIm2Col( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - float* ColumnBuffer, - size_t k, - size_t CountK, - size_t n, - size_t CountN - ) -/*++ - -Routine Description: - - This routine converts the input image to a set of convolution patches - appropriate for use with a GEMM operation. - - This implementation supports sampling a portion of the convolution - patches. This avoids the need to allocate very large buffers to store - all of the convolution patches at once, when the underlying GEMM - implementation will already break up the operation into panels. Multiple - threads can also be used to process different portions of the image. - -Arguments: - - Parameters - Supplies the structure that contains the convolution - parameters. - - Input - Supplies the input tensor. - - ColumnBuffer - Supplies the buffer to receive the convolution patches. - - k - Supplies the K to begin sampling the convolution patches. - - CountK - Supplies the count of K to sample for the convolution patches. - - n - Supplies the N to begin sampling the convolution patches. - - CountN - Supplies the count of N to sample for the convolution patches. - -Return Value: - - None. 
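Editor's note (illustration only, not part of the change): a reference formulation of the 2-D im2col sampling this routine performs incrementally. Entry (k, n) of the column buffer is the input pixel found by decomposing k into (channel, ky, kx) and n into (ny, nx), or zero when the pixel falls into the padding region; the real routine walks these indices without redoing the divisions per element.

#include <cstddef>

float Im2ColElement(const float* Input, size_t k, size_t n,
                    size_t InputHeight, size_t InputWidth, size_t OutputWidth,
                    size_t KernelHeight, size_t KernelWidth,
                    size_t StrideH, size_t StrideW,
                    size_t DilationH, size_t DilationW,
                    size_t PadTop, size_t PadLeft)
{
    const size_t kx = k % KernelWidth;
    const size_t ky = (k / KernelWidth) % KernelHeight;
    const size_t channel = k / (KernelWidth * KernelHeight);

    const size_t nx = n % OutputWidth;
    const size_t ny = n / OutputWidth;

    // Unsigned wrap-around turns a "negative" coordinate into a huge value, so a
    // single upper-bound test also covers the top/left padding (the same trick the
    // routine uses for InputX and InputY).
    const size_t y = ny * StrideH + ky * DilationH - PadTop;
    const size_t x = nx * StrideW + kx * DilationW - PadLeft;

    if (y >= InputHeight || x >= InputWidth) {
        return 0.0f;
    }
    return Input[(channel * InputHeight + y) * InputWidth + x];
}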
- ---*/ -{ - constexpr size_t HeightShapeIndex = 0; - constexpr size_t WidthShapeIndex = 1; - - const size_t OutputWidth = Parameters->OutputShape[WidthShapeIndex]; - - const size_t StrideHeight = Parameters->StrideShape[HeightShapeIndex]; - const size_t StrideWidth = Parameters->StrideShape[WidthShapeIndex]; - - const size_t nx = (n % OutputWidth); - const size_t ny = (n / OutputWidth); - - const size_t OriginInputX = nx * StrideWidth; - const size_t OriginInputY = ny * StrideHeight; - - size_t OutputCountX = OutputWidth - nx; - - const size_t InputHeight = Parameters->InputShape[HeightShapeIndex]; - const size_t InputWidth = Parameters->InputShape[WidthShapeIndex]; - const size_t InputSize = Parameters->InputSize; - - const size_t KernelHeight = Parameters->KernelShape[HeightShapeIndex]; - const size_t KernelWidth = Parameters->KernelShape[WidthShapeIndex]; - - size_t kx = (k % KernelWidth); - size_t ky = (k / KernelWidth) % KernelHeight; - - Input = Input + (k / (KernelHeight * KernelWidth)) * InputSize; - - const size_t DilationHeight = Parameters->DilationShape[HeightShapeIndex]; - const size_t DilationWidth = Parameters->DilationShape[WidthShapeIndex]; - - const size_t PaddingLeftY = Parameters->Padding[HeightShapeIndex]; - const size_t PaddingLeftX = Parameters->Padding[WidthShapeIndex]; - - for (size_t EndingK = k + CountK; k < EndingK; k++) { - - size_t CountX = OutputCountX; - size_t InputY = (ky * DilationHeight) + OriginInputY - PaddingLeftY; - const size_t RowInitialInputX = (kx * DilationWidth) - PaddingLeftX; - size_t InitialInputX = RowInitialInputX + OriginInputX; - size_t RemainingN = CountN; - - do { - - if (CountX > RemainingN) { - CountX = RemainingN; - } - - RemainingN -= CountX; - - // - // Check if the input is in the top/bottom padding region. - // - - if (InputY < InputHeight) { - - size_t InputX = InitialInputX; - const float* InputRow = &Input[InputY * InputWidth]; - - do { - - // - // Check if the input is in the left/right padding region. - // - - if (InputX >= InputWidth) { - - *ColumnBuffer++ = 0; - InputX += StrideWidth; - CountX--; - - } else if (StrideWidth == 1) { - - // - // Copy input elements to the column buffer. - // - - size_t CountCopyX = InputWidth - InputX; - - if (CountCopyX > CountX) { - CountCopyX = CountX; - } - - CountX -= CountCopyX; - - while (CountCopyX >= 4) { - MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(&InputRow[InputX])); - ColumnBuffer += 4; - InputX += 4; - CountCopyX -= 4; - } - - while (CountCopyX > 0) { - *ColumnBuffer++ = InputRow[InputX++]; - CountCopyX--; - } - - } else if (InputX + CountX * StrideWidth <= InputWidth) { - - do { - *ColumnBuffer++ = InputRow[InputX]; - InputX += StrideWidth; - } while (--CountX > 0); - - } else { - - do { - *ColumnBuffer++ = (InputX < InputWidth) ? InputRow[InputX] : 0; - InputX += StrideWidth; - } while (--CountX > 0); - } - - } while (CountX > 0); - - } else { - - // - // The entire input row is in the padding region. - // - - MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - while (CountX >= 4) { - MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4); - ColumnBuffer += 4; - CountX -= 4; - } - - while (CountX > 0) { - MlasStoreLaneFloat32x4<0>(ColumnBuffer, ZeroFloat32x4); - ColumnBuffer++; - CountX--; - } - } - - CountX = OutputWidth; - InputY += StrideHeight; - InitialInputX = RowInitialInputX; - - } while (RemainingN > 0); - - // - // Advance the kernel indices and advance to the next channel if the - // entire kernel is complete. 
- // - - if (++kx == KernelWidth) { - - if (++ky == KernelHeight) { - - Input += InputSize; - - ky = 0; - } - - kx = 0; - } - } -} - -void -MlasConvVol2Col( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - float* ColumnBuffer, - size_t k, - size_t CountK, - size_t n, - size_t CountN - ) -/*++ - -Routine Description: - - This routine converts the input volume to a set of convolution patches - appropriate for use with a GEMM operation. - - This implementation supports sampling a portion of the convolution - patches. This avoids the need to allocate very large buffers to store - all of the convolution patches at once, when the underlying GEMM - implementation will already break up the operation into panels. Multiple - threads can also be used to process different portions of the image. - -Arguments: - - Parameters - Supplies the structure that contains the convolution - parameters. - - Input - Supplies the input tensor. - - ColumnBuffer - Supplies the buffer to receive the convolution patches. - - k - Supplies the K to begin sampling the convolution patches. - - CountK - Supplies the count of K to sample for the convolution patches. - - n - Supplies the N to begin sampling the convolution patches. - - CountN - Supplies the count of N to sample for the convolution patches. - -Return Value: - - None. - ---*/ -{ - constexpr size_t DepthShapeIndex = 0; - constexpr size_t HeightShapeIndex = 1; - constexpr size_t WidthShapeIndex = 2; - - const size_t OutputHeight = Parameters->OutputShape[HeightShapeIndex]; - const size_t OutputWidth = Parameters->OutputShape[WidthShapeIndex]; - - const size_t StrideDepth = Parameters->StrideShape[DepthShapeIndex]; - const size_t StrideHeight = Parameters->StrideShape[HeightShapeIndex]; - const size_t StrideWidth = Parameters->StrideShape[WidthShapeIndex]; - - const size_t nx = (n % OutputWidth); - const size_t ny = ((n / OutputWidth) % OutputHeight); - const size_t nz = ((n / OutputWidth) / OutputHeight); - - size_t OutputCountX = OutputWidth - nx; - size_t OutputCountY = OutputHeight - ny; - - const size_t OriginInputX = nx * StrideWidth; - const size_t OriginInputY = ny * StrideHeight; - const size_t OriginInputZ = nz * StrideDepth; - - const size_t InputDepth = Parameters->InputShape[DepthShapeIndex]; - const size_t InputHeight = Parameters->InputShape[HeightShapeIndex]; - const size_t InputWidth = Parameters->InputShape[WidthShapeIndex]; - const size_t InputSize = Parameters->InputSize; - - const size_t KernelDepth = Parameters->KernelShape[DepthShapeIndex]; - const size_t KernelHeight = Parameters->KernelShape[HeightShapeIndex]; - const size_t KernelWidth = Parameters->KernelShape[WidthShapeIndex]; - - size_t kx = (k % KernelWidth); - size_t ky = (k / KernelWidth) % KernelHeight; - size_t kz = ((k / KernelWidth) / KernelHeight) % KernelDepth; - - Input = Input + (k / (KernelDepth * KernelHeight * KernelWidth)) * InputSize; - - const size_t DilationDepth = Parameters->DilationShape[DepthShapeIndex]; - const size_t DilationHeight = Parameters->DilationShape[HeightShapeIndex]; - const size_t DilationWidth = Parameters->DilationShape[WidthShapeIndex]; - - const size_t PaddingLeftZ = Parameters->Padding[DepthShapeIndex]; - const size_t PaddingLeftY = Parameters->Padding[HeightShapeIndex]; - const size_t PaddingLeftX = Parameters->Padding[WidthShapeIndex]; - - for (size_t EndingK = k + CountK; k < EndingK; k++) { - - size_t CountY = OutputCountY; - size_t CountX = OutputCountX; - size_t InputZ = (kz * DilationDepth) + OriginInputZ - PaddingLeftZ; - 
const size_t RowInitialInputY = (ky * DilationHeight) - PaddingLeftY; - size_t InputY = RowInitialInputY + OriginInputY; - const size_t RowInitialInputX = (kx * DilationWidth) - PaddingLeftX; - size_t InitialInputX = RowInitialInputX + OriginInputX; - size_t RemainingN = CountN; - - do { - - if (CountX > RemainingN) { - CountX = RemainingN; - } - - RemainingN -= CountX; - - // - // Check if the input is in the top/bottom or front/back padding region. - // - - if (InputY < InputHeight && InputZ < InputDepth) { - - size_t InputX = InitialInputX; - const float* InputRow = &Input[InputZ * (InputHeight * InputWidth) + InputY * InputWidth]; - - do { - - // - // Check if the input is in the left/right padding region. - // - - if (InputX >= InputWidth) { - - *ColumnBuffer++ = 0; - InputX += StrideWidth; - CountX--; - - } else if (StrideWidth == 1) { - - // - // Copy input elements to the column buffer. - // - - size_t CountCopyX = InputWidth - InputX; - - if (CountCopyX > CountX) { - CountCopyX = CountX; - } - - CountX -= CountCopyX; - - while (CountCopyX >= 4) { - MlasStoreFloat32x4(ColumnBuffer, MlasLoadFloat32x4(&InputRow[InputX])); - ColumnBuffer += 4; - InputX += 4; - CountCopyX -= 4; - } - - while (CountCopyX > 0) { - *ColumnBuffer++ = InputRow[InputX++]; - CountCopyX--; - } - - } else if (InputX + CountX * StrideWidth <= InputWidth) { - - do { - *ColumnBuffer++ = InputRow[InputX]; - InputX += StrideWidth; - } while (--CountX > 0); - - } else { - - do { - *ColumnBuffer++ = (InputX < InputWidth) ? InputRow[InputX] : 0; - InputX += StrideWidth; - } while (--CountX > 0); - } - - } while (CountX > 0); - - } else { - - // - // The entire input row is in the padding region. - // - - MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - while (CountX >= 4) { - MlasStoreFloat32x4(ColumnBuffer, ZeroFloat32x4); - ColumnBuffer += 4; - CountX -= 4; - } - - while (CountX > 0) { - MlasStoreLaneFloat32x4<0>(ColumnBuffer, ZeroFloat32x4); - ColumnBuffer++; - CountX--; - } - } - - CountX = OutputWidth; - InputY += StrideHeight; - InitialInputX = RowInitialInputX; - - if (--CountY == 0) { - - InputY = RowInitialInputY; - InputZ += StrideDepth; - - CountY = OutputHeight; - } - - } while (RemainingN > 0); - - // - // Advance the kernel indices and advance to the next channel if the - // entire kernel is complete. - // - - if (++kx == KernelWidth) { - - if (++ky == KernelHeight) { - - if (++kz == KernelDepth) { - - Input += InputSize; - - kz = 0; - } - - ky = 0; - } - - kx = 0; - } - } -} - -void -MlasConvOperation( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - const float* Bias, - float* ColumnBuffer, - float* Output, - size_t SegmentStartN, - size_t SegmentCountN - ) -/*++ - -Routine Description: - - This routine implements the convolution operation. - -Arguments: - - Parameters - Supplies the structure that contains the convolution - parameters. - - Input - Supplies the input tensor. - - Filter - Supplies the filter tensor. - - Bias - Optionally supplies the bias vector. - - ColumnBuffer - Supplies the thread local slice of the working buffer. - - Output - Supplies the output tensor. - - SegmentStartN - Supplies the N to begin sampling the convolution patches. - - SegmentCountN - Supplies the count of N to sample for the convolution - patches. - -Return Value: - - None. 
- ---*/ -{ - const size_t FilterCount = Parameters->FilterCount; - const size_t OutputSize = Parameters->OutputSize; - const size_t K = Parameters->K; - - // - // Compute the strides to step through slices of the local segment. - // - // See MlasSgemmOperation. - // - - uint32_t StrideN = MLAS_SGEMM_STRIDEN; - uint32_t StrideK = MLAS_SGEMM_STRIDEK; - - if (SegmentCountN >= K) { - - while (StrideK / 2 >= K) { - StrideN *= 2; - StrideK /= 2; - } - - } else { - - while (StrideN > 16 && StrideN / 2 >= SegmentCountN) { - StrideK *= 2; - StrideN /= 2; - } - } - - // - // Step through each slice of the input tensor along the N dimension. - // - - size_t CountN; - - for (size_t n = 0; n < SegmentCountN; n += CountN) { - - CountN = SegmentCountN - n; - - if (CountN > StrideN) { - CountN = StrideN; - } - - // - // Step through each slice of the input tensor along the K dimension. - // - - size_t CountK; - float beta = Parameters->Beta; - float* SegmentOutput = Output + SegmentStartN + n; - - for (size_t k = 0; k < K; k += CountK) { - - CountK = K - k; - - if (CountK > StrideK) { - CountK = StrideK; - } - - if (Parameters->Dimensions == 2) { - MlasConvIm2Col(Parameters, Input, ColumnBuffer, k, CountK, - SegmentStartN + n, CountN); - } else { - MlasConvVol2Col(Parameters, Input, ColumnBuffer, k, CountK, - SegmentStartN + n, CountN); - } - - MlasSgemmOperation(CblasNoTrans, CblasNoTrans, FilterCount, CountN, - CountK, 1.0f, Filter + k, K, ColumnBuffer, CountN, beta, - SegmentOutput, OutputSize); - - beta = 1.0f; - } - - // - // Apply the activation with optional bias. - // - - MlasActivation(Parameters->Activation, SegmentOutput, Bias, FilterCount, - CountN, OutputSize); - } -} - -void -MlasConvOperationThreaded( - void* Context, - ptrdiff_t Index - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - convolution operation. - -Arguments: - - Context - Supplies the pointer to the context for the threaded operation. - - Index - Supplies the current index of the threaded operation. - -Return Value: - - None. - ---*/ -{ - MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; - - MLAS_CONV_WORK_BLOCK::SEGMENT* Segment = &WorkBlock->Segments[Index]; - - float* ColumnBuffer = - WorkBlock->WorkingBuffer + Index * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; - - MlasConvOperation(WorkBlock->Parameters, WorkBlock->Input, WorkBlock->Filter, - WorkBlock->Bias, ColumnBuffer, WorkBlock->Output, Segment->StartN, - Segment->CountN); -} - -void -MlasConvGemmDirectThreaded( - void* Context, - ptrdiff_t Index - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - convolution operation. - -Arguments: - - Context - Supplies the pointer to the context for the threaded operation. - - Index - Supplies the current index of the threaded operation. - -Return Value: - - None. - ---*/ -{ - MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; - - const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters; - - // - // Compute the range of indices to use for this thread. 
- // - - const size_t GroupCount = Parameters->GroupCount; - const size_t BatchGroupCount = Parameters->BatchCount * GroupCount; - const float Beta = Parameters->Beta; - - size_t BatchGroupStart; - size_t BatchGroupRemaining; - - MlasPartitionWork(Index, WorkBlock->TargetThreadCount, BatchGroupCount, - &BatchGroupStart, &BatchGroupRemaining); - - size_t BatchGroupEnd = BatchGroupStart + BatchGroupRemaining; - - // - // Iterate over the batch and groups allocated to this thread. - // - - const size_t FilterCount = Parameters->FilterCount; - const size_t OutputSize = Parameters->OutputSize; - const size_t K = Parameters->K; - - const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize; - const size_t OutputGroupSize = FilterCount * OutputSize; - const size_t FilterGroupSize = FilterCount * K; - - for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) { - - size_t group = bg % GroupCount; - - const float* input = WorkBlock->Input + bg * InputGroupSize; - const float* filter = WorkBlock->Filter + group * FilterGroupSize; - float* output = WorkBlock->Output + bg * OutputGroupSize; - - // - // Invoke the non-threaded GEMM directly with the input tensor. - // - - MlasSgemmOperation(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, OutputSize, - K, 1.0f, filter, K, input, Parameters->u.GemmDirect.ldb, Beta, output, - OutputSize); - - // - // Apply the activation with optional bias. - // - - const float* bias = WorkBlock->Bias; - - if (bias != nullptr) { - bias += group * FilterCount; - } - - MlasActivation(Parameters->Activation, output, bias, FilterCount, - OutputSize, OutputSize); - } -} - -inline -bool -MlasConvTryMultithread( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - const float* Bias, - float* WorkingBuffer, - float* Output, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine attempts to launch a convolution operation across multiple - threads. - -Arguments: - - Parameters - Supplies the structure that contains the convolution - parameters. - - Input - Supplies the input tensor. - - Filter - Supplies the filter tensor. - - Bias - Optionally supplies the bias vector. - - WorkingBuffer - Supplies a working buffer sized to the number of elements - returned by MlasConvPrepare. - - Output - Supplies the output tensor. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - Returns true if the operation was completed across multiple threads, else - false if the operation should fall back to a single thread. - ---*/ -{ - MLAS_CONV_WORK_BLOCK WorkBlock; - - const size_t OutputSize = Parameters->OutputSize; - const size_t ThreadStrideN = Parameters->u.ExpandThenGemmSegmented.ThreadStrideN; - - if (ThreadStrideN >= OutputSize) { - return false; - } - - // - // Initialize the common fields of the work block. - // - - WorkBlock.Parameters = Parameters; - WorkBlock.Input = Input; - WorkBlock.Filter = Filter; - WorkBlock.Bias = Bias; - WorkBlock.WorkingBuffer = WorkingBuffer; - WorkBlock.Output = Output; - - // - // Segment the operation across multiple threads. 
- // - - int32_t Index = 0; - size_t SegmentCountN; - - for (size_t SegmentStartN = 0; SegmentStartN < OutputSize; SegmentStartN += SegmentCountN) { - - SegmentCountN = OutputSize - SegmentStartN; - - if (SegmentCountN > ThreadStrideN) { - SegmentCountN = ThreadStrideN; - } - - WorkBlock.Segments[Index].StartN = SegmentStartN; - WorkBlock.Segments[Index].CountN = SegmentCountN; - - Index++; - } - - MlasExecuteThreaded(MlasConvOperationThreaded, &WorkBlock, Index, ThreadPool); - - return true; -} - -void -MLASCALL -MlasConv( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - const float* Bias, - float* WorkingBuffer, - float* Output, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the convolution operation. - -Arguments: - - Parameters - Supplies the structure that contains the convolution - parameters. - - Input - Supplies the input tensor. - - Filter - Supplies the filter tensor. - - Bias - Optionally supplies the bias vector. - - WorkingBuffer - Supplies a working buffer sized to the number of elements - returned by MlasConvPrepare. - - Output - Supplies the output tensor. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - const size_t FilterCount = Parameters->FilterCount; - const size_t OutputSize = Parameters->OutputSize; - const size_t K = Parameters->K; - - const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize; - const size_t OutputGroupSize = FilterCount * OutputSize; - const size_t FilterGroupSize = FilterCount * K; - - const size_t BatchCount = Parameters->BatchCount; - const size_t GroupCount = Parameters->GroupCount; - - const MLAS_CONV_ALGORITHM Algorithm = Parameters->Algorithm; - - // - // Schedule batches of GEMMs across multiple threads. - // - - if (Algorithm == MlasConvAlgorithmGemmDirect && ((BatchCount > 1) || (GroupCount > 1))) { - - const size_t BatchGroupCount = BatchCount * GroupCount; - - ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (size_t(TargetThreadCount) >= BatchGroupCount) { - TargetThreadCount = ptrdiff_t(BatchGroupCount); - } - - MLAS_CONV_WORK_BLOCK WorkBlock; - - WorkBlock.Parameters = Parameters; - WorkBlock.Input = Input; - WorkBlock.Filter = Filter; - WorkBlock.Bias = Bias; - WorkBlock.WorkingBuffer = nullptr; - WorkBlock.Output = Output; - WorkBlock.TargetThreadCount = TargetThreadCount; - - MlasExecuteThreaded(MlasConvGemmDirectThreaded, &WorkBlock, TargetThreadCount, ThreadPool); - - return; - } - -#if defined(MLAS_TARGET_WASM_SCALAR) - - if (Algorithm == MlasConvAlgorithmDepthwise) { - // Fill the Working Buffer with Zero for use by the depthwise kernel. - // The length for the zeros are input image wide + 2 currently. - std::fill_n(WorkingBuffer, Parameters->InputShape[1] + 2, 0.0f); - } - -#endif - - // - // Iterate over each batch and group. - // - for (size_t batch = 0; batch < BatchCount; batch++) { - - const float* filter = Filter; - const float* bias = Bias; - - for (size_t group = 0; group < GroupCount; group++) { - - // - // Dispatch the convolution. - // - - switch (Algorithm) { - - case MlasConvAlgorithmGemmDirect: - { - // - // Invoke the threaded GEMM directly with the input tensor. 
- // - - MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, OutputSize, - K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb, - Parameters->Beta, Output, OutputSize, ThreadPool); - - // - // Apply the activation with optional bias. - // - - MlasActivation(Parameters->Activation, Output, bias, FilterCount, - OutputSize, OutputSize); - - break; - } - - case MlasConvAlgorithmExpandThenGemm: - { - // - // Expand the input tensor to the working buffer and then invoke the - // threaded GEMM. - // - - if (Parameters->Dimensions == 2) { - MlasConvIm2Col(Parameters, Input, WorkingBuffer, 0, K, 0, OutputSize); - } else { - MlasConvVol2Col(Parameters, Input, WorkingBuffer, 0, K, 0, OutputSize); - } - - MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, OutputSize, K, 1.0f, filter, - K, WorkingBuffer, OutputSize, Parameters->Beta, Output, OutputSize, - ThreadPool); - - // - // Apply the activation with optional bias. - // - - MlasActivation(Parameters->Activation, Output, bias, FilterCount, - OutputSize, OutputSize); - - break; - } - -#if defined(MLAS_TARGET_WASM_SCALAR) - - case MlasConvAlgorithmDepthwise: - { - MlasConvDepthwiseFloat_CHW(Parameters, Input, filter, Output, WorkingBuffer); - MlasActivation(Parameters->Activation, Output, bias, FilterCount, OutputSize, OutputSize); - break; - } - -#endif - - case MlasConvAlgorithmExpandThenGemmSegmented: - { - // - // Attempt to launch the convolution across multiple threads or fall - // back to a single thread. - // - - if (!MlasConvTryMultithread(Parameters, Input, filter, bias, WorkingBuffer, - Output, ThreadPool)) { - MlasConvOperation(Parameters, Input, filter, bias, WorkingBuffer, - Output, 0, OutputSize); - } - - break; - } - } - - // - // Advance the buffer pointers. - // - - if (bias != nullptr) { - bias += FilterCount; - } - - filter += FilterGroupSize; - Input += InputGroupSize; - Output += OutputGroupSize; - } - } -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// Chance of arithmetic overflow could be reduced -#pragma warning(disable : 26451) -#endif -void -MLASCALL -MlasConvPrepare( - MLAS_CONV_PARAMETERS* Parameters, - size_t Dimensions, - size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - size_t FilterCount, - const MLAS_ACTIVATION* Activation, - size_t* WorkingBufferSize, - float Beta, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine prepares for a convolution operation by computing required - parameters including the required working buffer size for intermediate - results. - -Arguments: - - Parameters - Supplies the structure that stores the provided and computed - parameters for the convolution operation. - - Dimensions - Supplies the number of dimensions (must be between 1 and 3). - - BatchCount - Supplies the number of batches to the processed. - - GroupCount - Supplies the number of channel groups. - - InputChannels - Supplies the number of input channels per group. - - InputShape - Supplies the shape of the input tensor. - - KernelShape - Supplies the shape of the kernel transform. - - DilationShape - Supplies the shape of the dilation. - - Padding - Supplies the number of zero padding elements at the edge of the - input tensor. - - StrideShape - Supplies the shape of the stride. - - OutputShape - Supplies the shape of the output tensor. 
- - FilterCount - Supplies the number of rows of the filter matrix per group. - - Activation - Supplies the parameters for the activation to apply to the - convolution output. - - WorkingBufferSize - Receives the number of elements to allocate for the - working buffer for intermediate results. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - // - // Save the convolution parameters. - // - - Parameters->Activation = Activation; - Parameters->BatchCount = BatchCount; - Parameters->GroupCount = GroupCount; - Parameters->InputChannels = InputChannels; - Parameters->FilterCount = FilterCount; - Parameters->Beta = Beta; - - size_t InputSize = 1; - size_t OutputSize = 1; - size_t K = InputChannels; - - bool AllStridesAreOne = true; - bool AllDilationsAreOne = true; - bool AllPaddingIsZero = true; - - for (size_t dim = 0; dim < Dimensions; dim++) { - - Parameters->InputShape[dim] = size_t(InputShape[dim]); - Parameters->OutputShape[dim] = size_t(OutputShape[dim]); - Parameters->KernelShape[dim] = size_t(KernelShape[dim]); - Parameters->DilationShape[dim] = size_t(DilationShape[dim]); - Parameters->Padding[dim] = size_t(Padding[dim]); - Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]); - Parameters->StrideShape[dim] = size_t(StrideShape[dim]); - - InputSize *= Parameters->InputShape[dim]; - OutputSize *= Parameters->OutputShape[dim]; - K *= Parameters->KernelShape[dim]; - - AllStridesAreOne &= (Parameters->StrideShape[dim] == 1); - AllDilationsAreOne &= (Parameters->DilationShape[dim] == 1); - AllPaddingIsZero &= (Parameters->Padding[dim] == 0 && Parameters->Padding[dim + Dimensions] == 0); - } - - Parameters->InputSize = InputSize; - Parameters->OutputSize = OutputSize; - Parameters->K = K; - - // - // Promote 1D convolutions to 2D convolutions. - // - - if (Dimensions == 1) { - - Parameters->InputShape[1] = Parameters->InputShape[0]; - Parameters->InputShape[0] = 1; - Parameters->OutputShape[1] = Parameters->OutputShape[0]; - Parameters->OutputShape[0] = 1; - Parameters->KernelShape[1] = Parameters->KernelShape[0]; - Parameters->KernelShape[0] = 1; - Parameters->DilationShape[1] = Parameters->DilationShape[0]; - Parameters->DilationShape[0] = 1; - Parameters->Padding[3] = Parameters->Padding[1]; - Parameters->Padding[2] = 0; - Parameters->Padding[1] = Parameters->Padding[0]; - Parameters->Padding[0] = 0; - Parameters->StrideShape[1] = Parameters->StrideShape[0]; - Parameters->StrideShape[0] = 1; - - Dimensions = 2; - } - - Parameters->Dimensions = Dimensions; - - // - // Evaluate how the convolution will be performed. - // - - *WorkingBufferSize = 0; - - if (AllStridesAreOne && AllPaddingIsZero) { - - // - // Detect a pointwise convolution. - // - - if (K == InputChannels) { - - Parameters->Algorithm = MlasConvAlgorithmGemmDirect; - Parameters->u.GemmDirect.TransB = CblasNoTrans; - Parameters->u.GemmDirect.ldb = OutputSize; - - return; - } - - if (Dimensions == 2 && AllDilationsAreOne && InputChannels == 1) { - - // - // Detect convolutions where the kernel is using the entire input - // width or height. 
- // - - if (Parameters->KernelShape[1] == Parameters->InputShape[1]) { - - Parameters->Algorithm = MlasConvAlgorithmGemmDirect; - Parameters->u.GemmDirect.TransB = CblasTrans; - Parameters->u.GemmDirect.ldb = Parameters->InputShape[1]; - - return; - } - - if (Parameters->KernelShape[0] == Parameters->InputShape[0] && - Parameters->KernelShape[1] == 1) { - - Parameters->Algorithm = MlasConvAlgorithmGemmDirect; - Parameters->u.GemmDirect.TransB = CblasNoTrans; - Parameters->u.GemmDirect.ldb = Parameters->InputShape[1]; - - return; - } - } - } - - if (FilterCount > OutputSize) { - - // - // The filter count is larger than the output dimensions, so perform the - // full matrix expansion and then invoke the threaded GEMM. - // - - Parameters->Algorithm = MlasConvAlgorithmExpandThenGemm; - - *WorkingBufferSize = OutputSize * K; - - } else { - -#if defined(MLAS_TARGET_WASM_SCALAR) - - // Scalar direct conv for depthwise convolution. - // Currently only support 3x3 kernel with padding <=1 and dilations = 1. - // TODO: support more general depthwise convolution. - - if (Dimensions == 2 - && Parameters->FilterCount == 1 && Parameters->InputChannels == 1 - && Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3 - && Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1 - && Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1 - && Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) { - - *WorkingBufferSize = Parameters->InputShape[1] + 2; - Parameters->Algorithm = MlasConvAlgorithmDepthwise; - return; - } - -#endif - - // - // Segment the operation across multiple threads by slicing the N - // dimension (see MlasSgemmTryMultithread). - // - // Compute the number of target threads given the complexity of the - // convolution operation. Small requests should run using the single - // threaded path. - // - - ptrdiff_t TargetThreadCount; - double Complexity = double(FilterCount) * double(OutputSize) * double(K); - - if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT; - } - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - // - // Compute the thread stride for slicing the N dimension. - // - - size_t StrideN = OutputSize / TargetThreadCount; - - if ((StrideN * TargetThreadCount) != OutputSize) { - StrideN++; - } - - if (TargetThreadCount > 1) { - - StrideN = (StrideN + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); - - if (StrideN >= OutputSize) { - TargetThreadCount = 1; - } else if (StrideN * (TargetThreadCount - 1) >= OutputSize) { - TargetThreadCount--; - } - } - - Parameters->ThreadCount = TargetThreadCount; - - Parameters->Algorithm = MlasConvAlgorithmExpandThenGemmSegmented; - Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN; - - *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; - } -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/convsym.cpp b/onnxruntime/core/mlas/lib/convsym.cpp deleted file mode 100644 index 5f8be3580bb72..0000000000000 --- a/onnxruntime/core/mlas/lib/convsym.cpp +++ /dev/null @@ -1,652 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
- -Licensed under the MIT License. - -Module Name: - - convsym.cpp - -Abstract: - - This module implements the symmetric quantized integer convolution - operation. - ---*/ - -#include "mlasi.h" - -// -// Define the prototypes of the platform optimized routines. -// - -typedef -void -(MLASCALL MLAS_CONV_SYM_KERNEL)( - const void* Input, - const void* Filter, - void* Output, - size_t KernelSize, - size_t InputChannels, - size_t OutputChannels, - unsigned ChannelCount, - unsigned OutputCount, - const struct MLAS_CONV_SYM_POST_PROCESS_PARAMS* PostProcessParams, - unsigned KernelFlags - ); - -typedef -void -(MLASCALL MLAS_CONV_SYM_DEPTHWISE_KERNEL)( - const void* Input, - const void* Filter, - void* Output, - size_t KernelSize, - size_t Channels, - size_t ChannelOffset, - unsigned ChannelCount, - unsigned OutputCount, - const struct MLAS_CONV_SYM_POST_PROCESS_PARAMS* PostProcessParams, - unsigned KernelFlags - ); - -// -// Processor for common kernel sized (e.g. 3x3, 5x5) -// -typedef -void -(MLASCALL MLAS_SYMM_QCONV_DEPTHWISE_FIXFILTER_PROC)( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - - -extern "C" { - -#if defined(MLAS_TARGET_AMD64) - MLAS_CONV_SYM_KERNEL MlasConvSymKernelAvx2; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseKernelAvx2; - MLAS_CONV_SYM_KERNEL MlasConvSymKernelAvxVnni; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseKernelAvxVnni; - MLAS_CONV_SYM_KERNEL MlasConvSymKernelAvx512Core; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseKernelAvx512Core; - MLAS_CONV_SYM_KERNEL MlasConvSymKernelAvx512Vnni; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseKernelAvx512Vnni; -#elif defined(MLAS_TARGET_ARM64) - MLAS_CONV_SYM_KERNEL MlasConvSymS8KernelNeon; - MLAS_CONV_SYM_KERNEL MlasConvSymU8KernelNeon; - MLAS_CONV_SYM_KERNEL MlasConvSymS8KernelDot; - MLAS_CONV_SYM_KERNEL MlasConvSymS8KernelDotLd64; - MLAS_CONV_SYM_KERNEL MlasConvSymU8KernelDot; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseU8KernelNeon; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseS8KernelNeon; - -// -// Specialized depthwise conv kernels for 3x3 and 5x5 filters -// - -void -MLASCALL -MlasConvSymDepthwiseKernelSize9Arm64U8S8( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - -void -MLASCALL -MlasConvSymDepthwiseKernelSize9Arm64S8S8( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - -void -MLASCALL -MlasConvSymDepthwiseKernelSize25ArmS8S8( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - -void -MLASCALL -MlasConvSymDepthwiseKernelSize25ArmU8S8( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - -#endif -} - -struct MLAS_CONV_SYM_DISPATCH { - MLAS_CONV_SYM_KERNEL* Kernel; -#if defined(MLAS_TARGET_ARM64) - MLAS_CONV_SYM_KERNEL* KernelLittle; // kernel for little core -#endif - 
MLAS_CONV_SYM_DEPTHWISE_KERNEL* DepthwiseKernel; - MLAS_SYMM_QCONV_DEPTHWISE_FIXFILTER_PROC* Depthwise3x3Proc; - MLAS_SYMM_QCONV_DEPTHWISE_FIXFILTER_PROC* Depthwise5x5Proc; - uint8_t FilterInputChannelPackCount; - uint8_t FilterOutputChannelPackCount; - uint8_t KernelChannelCount; - uint8_t KernelOutputCount; - uint8_t KernelInputChannelAlignment; - uint8_t KernelOutputChannelAlignment; - uint8_t KernelDepthwiseChannelCount; - uint8_t KernelDepthwiseOutputCount; - bool FixupInputZeroPoint; -}; - -#if defined(MLAS_TARGET_AMD64) - -const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx2 = { - MlasConvSymKernelAvx2, - MlasConvSymDepthwiseKernelAvx2, - nullptr, - nullptr, - 4, // FilterInputChannelPackCount - 16, // FilterOutputChannelPackCount - 16, // KernelChannelCount - 4, // KernelOutputCount - 4, // KernelInputChannelAlignment - 8, // KernelOutputChannelAlignment - 16, // KernelDepthwiseChannelCount - 4, // KernelDepthwiseOutputCount - false, // FixupInputZeroPoint -}; - -const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvxVnni = { - MlasConvSymKernelAvxVnni, - MlasConvSymDepthwiseKernelAvxVnni, - nullptr, - nullptr, - 4, // FilterInputChannelPackCount - 16, // FilterOutputChannelPackCount - 16, // KernelChannelCount - 6, // KernelOutputCount - 4, // KernelInputChannelAlignment - 8, // KernelOutputChannelAlignment - 16, // KernelDepthwiseChannelCount - 4, // KernelDepthwiseOutputCount - false, // FixupInputZeroPoint -}; - -#if !defined(ORT_MINIMAL_BUILD) - -const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Core = { - MlasConvSymKernelAvx512Core, - MlasConvSymDepthwiseKernelAvx512Core, - nullptr, - nullptr, - 4, // FilterInputChannelPackCount - 16, // FilterOutputChannelPackCount - 64, // KernelChannelCount - 6, // KernelOutputCount - 4, // KernelInputChannelAlignment - 4, // KernelOutputChannelAlignment - 64, // KernelDepthwiseChannelCount - 6, // KernelDepthwiseOutputCount - false, // FixupInputZeroPoint -}; - -const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Vnni = { - MlasConvSymKernelAvx512Vnni, - MlasConvSymDepthwiseKernelAvx512Vnni, - nullptr, - nullptr, - 4, // FilterInputChannelPackCount - 16, // FilterOutputChannelPackCount - 64, // KernelChannelCount - 6, // KernelOutputCount - 4, // KernelInputChannelAlignment - 4, // KernelOutputChannelAlignment - 64, // KernelDepthwiseChannelCount - 6, // KernelDepthwiseOutputCount - false, // FixupInputZeroPoint -}; - -#endif // ORT_MINIMAL_BUILD - -#elif defined(MLAS_TARGET_ARM64) -const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchNeon = { - MlasConvSymU8KernelNeon, - MlasConvSymU8KernelNeon, - MlasConvSymDepthwiseU8KernelNeon, - MlasConvSymDepthwiseKernelSize9Arm64U8S8, - MlasConvSymDepthwiseKernelSize25ArmU8S8, - 8, // FilterInputChannelPackCount - 8, // FilterOutputChannelPackCount - 8, // KernelChannelCount - 2, // KernelOutputCount - 8, // KernelInputChannelAlignment - 8, // KernelOutputChannelAlignment - 16, // KernelDepthwiseChannelCount - 4, // KernelDepthwiseOutputCount - true -}; - -const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon = { - MlasConvSymS8KernelNeon, - MlasConvSymS8KernelNeon, - MlasConvSymDepthwiseS8KernelNeon, - MlasConvSymDepthwiseKernelSize9Arm64S8S8, - MlasConvSymDepthwiseKernelSize25ArmS8S8, - 8, // FilterInputChannelPackCount - 8, // FilterOutputChannelPackCount - 8, // KernelChannelCount - 2, // KernelOutputCount - 8, // KernelInputChannelAlignment - 8, // KernelOutputChannelAlignment - 16, // KernelDepthwiseChannelCount - 4, // KernelDepthwiseOutputCount - false -}; - -const MLAS_CONV_SYM_DISPATCH 
MlasConvSymU8DispatchDot = { - MlasConvSymU8KernelDot, - MlasConvSymU8KernelDot, - MlasConvSymDepthwiseU8KernelNeon, - MlasConvSymDepthwiseKernelSize9Arm64U8S8, - MlasConvSymDepthwiseKernelSize25ArmU8S8, - 4, // FilterInputChannelPackCount - 16, // FilterOutputChannelPackCount - 0, // KernelChannelCount - 4, // KernelOutputCount - 4, // KernelInputChannelAlignment - 16, // KernelOutputChannelAlignment - 16, // KernelDepthwiseChannelCount - 4, // KernelDepthwiseOutputCount - true -}; - -const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot = { - MlasConvSymS8KernelDot, - MlasConvSymS8KernelDotLd64, - MlasConvSymDepthwiseS8KernelNeon, - MlasConvSymDepthwiseKernelSize9Arm64S8S8, - MlasConvSymDepthwiseKernelSize25ArmS8S8, - 4, // FilterInputChannelPackCount - 16, // FilterOutputChannelPackCount - 0, // KernelChannelCount - 4, // KernelOutputCount - 4, // KernelInputChannelAlignment - 16, // KernelOutputChannelAlignment - 16, // KernelDepthwiseChannelCount - 4, // KernelDepthwiseOutputCount - false -}; -#endif // MLAS_TARGET_AMD64 - -MLAS_FORCEINLINE -void -MlasConvSymSetOutputZeroPoint( - MLAS_CONV_SYM_POST_PROCESS_PARAMS& PostProcessParams, - int32_t OutputZeroPoint, - bool InputIsSigned - ) -{ - int32_t minimum = InputIsSigned ? std::numeric_limits::lowest() - : std::numeric_limits::lowest(); - int32_t maximum = InputIsSigned ? std::numeric_limits::max() - : std::numeric_limits::max(); - PostProcessParams.MinimumValue = static_cast(minimum - OutputZeroPoint); - PostProcessParams.MaximumValue = static_cast(maximum - OutputZeroPoint); - PostProcessParams.OutputZeroPoint = OutputZeroPoint; -} - -MLAS_FORCEINLINE -const -MLAS_CONV_SYM_DISPATCH* -GetConvSymDispatch(bool InputIsSigned){ - return InputIsSigned ? GetMlasPlatform().ConvSymS8S8Dispatch : GetMlasPlatform().ConvSymU8S8Dispatch; -} - -size_t -MlasConvSymPackWSize( - size_t GroupCount, - size_t InputChannels, - size_t OutputChannels, - size_t KernelSize, - bool InputIsSigned - ) -{ - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); - - if (ConvSymDispatch == nullptr) { - return 0; - } - - if (GroupCount > 1) { - - if (ConvSymDispatch->DepthwiseKernel != nullptr && - InputChannels == 1 && OutputChannels == 1) { -#ifdef MLAS_TARGET_ARM64 - constexpr size_t GroupAlign = 8; -#else - constexpr size_t GroupAlign = 16; -#endif - size_t AlignedGroupCount = (GroupCount + GroupAlign - 1) & ~(GroupAlign - 1); - - if (AlignedGroupCount != GroupCount) { - return 0; - } - - return AlignedGroupCount * KernelSize; - - } else { - return 0; - } - - } else { - -#ifdef MLAS_TARGET_ARM64 - if (KernelSize <= 1) { - // im2col not needed, indirected buffer not needed - // just use qgemm path for pointwise - return 0; - } - if (InputChannels < 64) { - // Shallow indirect conv runs slower. - // TODO!! remove this for functional testing! - // TODO!! is there a way to know whether this is called by tests? 
- return 0; - } -#endif - - size_t OutputChannelPackCount = ConvSymDispatch->FilterOutputChannelPackCount; - - if (ConvSymDispatch->Kernel == nullptr || - OutputChannels < OutputChannelPackCount || - (InputChannels % ConvSymDispatch->KernelInputChannelAlignment) != 0 || - (OutputChannels % ConvSymDispatch->KernelOutputChannelAlignment) != 0 - ) { - return 0; - } - - size_t AlignedOutputChannels = (OutputChannels + OutputChannelPackCount - 1) / OutputChannelPackCount * OutputChannelPackCount; - return AlignedOutputChannels * InputChannels * KernelSize; - } -} - -void -MlasConvSymPackW( - size_t GroupCount, - size_t InputChannels, - size_t OutputChannels, - size_t KernelSize, - const int8_t* W, - int8_t* PackedW, - size_t PackedWSize, - bool InputIsSigned - ) -{ - memset(PackedW, 0, PackedWSize); - - if (GroupCount > 1) { - - for (size_t gc = 0; gc < GroupCount; gc++) { - - for (size_t k = 0; k < KernelSize; k++) { - - PackedW[k * GroupCount + gc] = W[gc * KernelSize + k]; - - } - } - - } else { - - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); - size_t InputChannelPackCount = ConvSymDispatch->FilterInputChannelPackCount; - size_t OutputChannelPackCount = ConvSymDispatch->FilterOutputChannelPackCount; - - size_t kernel_dim = InputChannels * KernelSize; - - for (size_t oc = 0; oc < OutputChannels; oc += OutputChannelPackCount) { - - const size_t oc_pack_size = std::min(OutputChannels - oc, OutputChannelPackCount); - - for (size_t ki = 0; ki < KernelSize; ki++) { - - for (size_t ic = 0; ic < InputChannels; ic += InputChannelPackCount) { - - const size_t ic_pack_size = std::min(InputChannels - ic, InputChannelPackCount); - - for (size_t oc_pack = 0; oc_pack < oc_pack_size; oc_pack++) { - - for (size_t ic_pack = 0; ic_pack < ic_pack_size; ic_pack++) { - - *(PackedW++) = W[(oc + oc_pack) * kernel_dim + (ic + ic_pack) * KernelSize + ki]; - - } - - PackedW += InputChannelPackCount - ic_pack_size; - - } - - PackedW += (OutputChannelPackCount - oc_pack_size) * InputChannelPackCount; - - } - } - } - - } -} - -int32_t -MlasConvSymFixupInputZeroPoint( - int32_t zero_point_value, - bool InputIsSigned - ) -{ - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); - - if (ConvSymDispatch != nullptr && ConvSymDispatch->FixupInputZeroPoint) { - return zero_point_value - 128; - } - return zero_point_value; -} - -int32_t -MlasConvSymGetKernelOutputCount( - bool InputIsSigned - ) -{ - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); - return ConvSymDispatch->KernelOutputCount; -} - -int32_t -MlasConvSymDepthwiseGetKernelOutputCnt( - bool InputIsSigned - ) -{ - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(InputIsSigned); - return ConvSymDispatch->KernelDepthwiseOutputCount; -} - - -void -MlasConvSym( - const MLAS_CONV_SYM_PARAMS& Params - ) -{ - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(Params.InputIsSigned); - - // Pick the suitable kernel for current core. Currently we only have specialized core for - // s8s8 under ARM64 -#if defined(MLAS_TARGET_ARM64) - const auto Kernel = - (Params.InputIsSigned && MLAS_CPUIDINFO::GetCPUIDInfo().IsCurrentCoreArmv8NarrowLd()) - ? 
ConvSymDispatch->KernelLittle - : ConvSymDispatch->Kernel; -#else - const auto Kernel = ConvSymDispatch->Kernel; -#endif - - int32_t KernelFlags = 0; - - if (Params.PerChannelScale) { - KernelFlags |= MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE; - } - - if (Params.InputIndirection == nullptr) { - KernelFlags |= MLAS_CONV_SYM_FLAG_INPUT_DIRECT; - } - - MLAS_CONV_SYM_POST_PROCESS_PARAMS PostProcessParams = {}; - - MlasConvSymSetOutputZeroPoint(PostProcessParams, Params.OutputZeroPoint, Params.InputIsSigned); - - const size_t KernelChannelCount = (ConvSymDispatch->KernelChannelCount == 0) - ? std::numeric_limits::max() - : ConvSymDispatch->KernelChannelCount; - const size_t KernelOutputCount = ConvSymDispatch->KernelOutputCount; - - const size_t KernelSize = Params.KernelSize; - const size_t InputChannels = Params.InputChannels; - const size_t OutputChannels = Params.OutputChannels; - - for (size_t oc_outside = 0; oc_outside < Params.OutputCount;) { - - const size_t oc_outside_block_size = std::min(Params.OutputCount - oc_outside, 240); - const int8_t* pwb = static_cast(Params.Filter); - - for (size_t co = 0; co < OutputChannels;) { - - const size_t ChannelCount = std::min(OutputChannels - co, KernelChannelCount); - void* conv_out = static_cast(Params.Output) + (oc_outside * OutputChannels) + co; - - PostProcessParams.Bias = Params.Bias + co; - PostProcessParams.Scale = Params.Scale + (Params.PerChannelScale ? co : 0); - - for (size_t oc = 0; oc < oc_outside_block_size;) { - - const void* Input; - if (Params.InputIndirection) { - Input = Params.InputIndirection + (oc_outside + oc) * KernelSize; - } else { - Input = static_cast(Params.InputDirect) + (oc_outside + oc) * InputChannels; - } - size_t OutputCount = std::min(oc_outside_block_size - oc, KernelOutputCount); - - Kernel( - Input, - pwb, - conv_out, - KernelSize, - InputChannels, - OutputChannels, - static_cast(ChannelCount), - static_cast(OutputCount), - &PostProcessParams, - KernelFlags); - oc += OutputCount; - conv_out = static_cast(conv_out) + OutputCount * OutputChannels; - } - - co += ChannelCount; - pwb += ChannelCount * InputChannels * KernelSize; - } - - oc_outside += oc_outside_block_size; - } -} - -void -MlasConvSymDepthwise( - const MLAS_CONV_SYM_PARAMS& Params - ) -{ - const MLAS_CONV_SYM_DISPATCH* ConvSymDispatch = GetConvSymDispatch(Params.InputIsSigned); - - unsigned KernelFlags = 0; - - if (Params.PerChannelScale) { - KernelFlags |= MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE; - } - - MLAS_CONV_SYM_POST_PROCESS_PARAMS PostProcessParams = {}; - - MlasConvSymSetOutputZeroPoint(PostProcessParams, Params.OutputZeroPoint, Params.InputIsSigned); - - if ((Params.OutputChannels & 15) == 0) { - PostProcessParams.Bias = Params.Bias; - PostProcessParams.Scale = Params.Scale; - if (ConvSymDispatch->Depthwise3x3Proc && Params.KernelSize == 9) { - ConvSymDispatch->Depthwise3x3Proc(Params.InputIndirection, (int8_t const*)Params.Filter, - Params.OutputChannels, Params.Output, - Params.OutputCount, &PostProcessParams, KernelFlags); - return; - } - if (ConvSymDispatch->Depthwise5x5Proc && Params.KernelSize == 25) { - ConvSymDispatch->Depthwise5x5Proc(Params.InputIndirection, (int8_t const*)Params.Filter, - Params.OutputChannels, Params.Output, - Params.OutputCount, &PostProcessParams, KernelFlags); - return; - } - } - - const size_t KernelChannelCount = ConvSymDispatch->KernelDepthwiseChannelCount; - const size_t KernelOutputCount = ConvSymDispatch->KernelDepthwiseOutputCount; - - const size_t KernelSize = Params.KernelSize; - const size_t 
OutputChannels = Params.OutputChannels; - - const auto* InputIndirection = Params.InputIndirection; - void* Output = Params.Output; - - for (size_t OutputCountRemaining = Params.OutputCount; OutputCountRemaining > 0;) { - - const size_t OutputCount = std::min(OutputCountRemaining, KernelOutputCount); - - for (size_t ChannelOffset = 0; ChannelOffset < OutputChannels;) { - - const size_t ChannelCount = std::min(OutputChannels - ChannelOffset, KernelChannelCount); - - PostProcessParams.Bias = Params.Bias + ChannelOffset; - PostProcessParams.Scale = Params.Scale + (Params.PerChannelScale ? ChannelOffset : 0); - - ConvSymDispatch->DepthwiseKernel( - InputIndirection, - static_cast(Params.Filter) + ChannelOffset, - static_cast(Output) + ChannelOffset, - KernelSize, - OutputChannels, - ChannelOffset, - static_cast(ChannelCount), - static_cast(OutputCount), - &PostProcessParams, - KernelFlags); - - ChannelOffset += ChannelCount; - } - - InputIndirection += OutputCount * KernelSize; - Output = static_cast(Output) + OutputCount * OutputChannels; - OutputCountRemaining -= OutputCount; - } -} diff --git a/onnxruntime/core/mlas/lib/dgemm.cpp b/onnxruntime/core/mlas/lib/dgemm.cpp deleted file mode 100644 index 50c62744f1d8e..0000000000000 --- a/onnxruntime/core/mlas/lib/dgemm.cpp +++ /dev/null @@ -1,890 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - dgemm.cpp - -Abstract: - - This module implements the double precision matrix/matrix multiply - operation (DGEMM). - ---*/ - -#include "mlasi.h" - -// -// Define the number of rows from matrix A to transpose to a local buffer. -// -// N.B. AVX processes a maximum of 4 rows, FMA3 processes a maximum of 6 -// rows, and AVX512F processes a maximum of 12 rows. -// - -#define MLAS_DGEMM_TRANSA_ROWS 12 - -#if defined (MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) - -void -MlasDgemmMultiplyBeta( - double* C, - size_t CountM, - size_t CountN, - size_t ldc, - double beta - ) -/*++ - -Routine Description: - - This routine multiplies all elements of the output matrix by the beta - scalar value. - -Arguments: - - C - Supplies the address of matrix C. - - CountM - Supplies the number of rows from matrix C. - - CountN - Supplies the number of columns from matrix C. - - ldc - Supplies the first dimension of matrix C. - - beta - Supplies the scalar beta multiplier (see DGEMM definition). - -Return Value: - - None. - ---*/ -{ - MLAS_FLOAT64X2 BetaBroadcast = MlasBroadcastFloat64x2(beta); - - while (CountM-- > 0) { - - double* c = C; - size_t n = CountN; - - while (n >= 2) { - MlasStoreFloat64x2(c, MlasMultiplyFloat64x2(MlasLoadFloat64x2(c), BetaBroadcast)); - c += 2; - n -= 2; - } - - if (n > 0) { -#if defined(MLAS_SSE2_INTRINSICS) - _mm_store_sd(c, _mm_mul_sd(_mm_load_sd(c), BetaBroadcast)); -#else - *c = *c * beta; -#endif - } - - C += ldc; - } -} - -void -MlasDgemmTransposeA( - double* D, - const double* A, - size_t lda, - size_t CountY, - size_t CountX - ) -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - buffer. - -Arguments: - - D - Supplies the address of the destination buffer. - - A - Supplies the address of the source matrix. - - lda - Supplies the number of elements per row of the source matrix. - - CountY - Supplies the number of columns of the source matrix to transpose. - - CountX - Supplies the number of rows of the source matrix to transpose. - -Return Value: - - None. 
- ---*/ -{ - size_t ldd = CountX; - - // - // Transpose elements from matrix A into the destination buffer 4 columns - // at a time. - // - - while (CountX >= 4) { - - double* d = D; - const double* a = A; - size_t y = CountY; - - do { - - double t0 = a[0]; - double t1 = a[lda]; - double t2 = a[lda * 2]; - double t3 = a[lda * 3]; - - d[0] = t0; - d[1] = t1; - d[2] = t2; - d[3] = t3; - - d += ldd; - a += 1; - y--; - - } while (y > 0); - - D += 4; - A += lda * 4; - CountX -= 4; - } - - // - // Transpose elements from matrix A into the destination buffer for the - // remaining columns. - // - - if (CountX >= 2) { - - double* d = D; - const double* a = A; - size_t y = CountY; - - do { - - double t0 = a[0]; - double t1 = a[lda]; - - d[0] = t0; - d[1] = t1; - - d += ldd; - a += 1; - y--; - - } while (y > 0); - - D += 2; - A += lda * 2; - CountX -= 2; - } - - if (CountX >= 1) { - - double* d = D; - const double* a = A; - size_t y = CountY; - - do { - - d[0] = a[0]; - - d += ldd; - a += 1; - y--; - - } while (y > 0); - } -} - -void -MlasDgemmCopyPackB( - double* D, - const double* B, - size_t ldb, - size_t CountX, - size_t CountY - ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - - Columns of 8 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the DGEMM kernels. Any remaining - columns less than 8 elements wide are zero-padded. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountX - Supplies the number of columns of the source matrix to copy. - - CountY - Supplies the number of rows of the source matrix to copy. - -Return Value: - - None. - ---*/ -{ - // - // Copy data from matrix B into the destination buffer 16 columns at a - // time. - // - - while (CountX >= 8) { - - const double* b = B; - size_t y = CountY; - - do { - -#if defined(MLAS_NEON64_INTRINSICS) - vst4q_f64(D, vld4q_f64(b)); -#else - MLAS_FLOAT64X2 t0 = MlasLoadFloat64x2(&b[0]); - MLAS_FLOAT64X2 t1 = MlasLoadFloat64x2(&b[2]); - MLAS_FLOAT64X2 t2 = MlasLoadFloat64x2(&b[4]); - MLAS_FLOAT64X2 t3 = MlasLoadFloat64x2(&b[6]); - - MlasStoreAlignedFloat64x2(&D[0], t0); - MlasStoreAlignedFloat64x2(&D[2], t1); - MlasStoreAlignedFloat64x2(&D[4], t2); - MlasStoreAlignedFloat64x2(&D[6], t3); -#endif - - D += 8; - b += ldb; - y--; - - } while (y > 0); - - B += 8; - CountX -= 8; - } - - // - // Special case the handling of the remaining columns less than 16 elements - // wide. 
- // - - if (CountX > 0) { - - MLAS_FLOAT64X2 ZeroFloat64x2 = MlasZeroFloat64x2(); - -#if defined(MLAS_NEON64_INTRINSICS) - float64x2x4_t ZeroFloat64x2x4 = { ZeroFloat64x2, ZeroFloat64x2, ZeroFloat64x2, ZeroFloat64x2 }; -#endif - - size_t y = CountY; - - do { - - double* d = D; - const double* b = B; - -#if defined(MLAS_NEON64_INTRINSICS) - vst4q_f64(d, ZeroFloat64x2x4); -#else - MlasStoreAlignedFloat64x2(&d[0], ZeroFloat64x2); - MlasStoreAlignedFloat64x2(&d[2], ZeroFloat64x2); - MlasStoreAlignedFloat64x2(&d[4], ZeroFloat64x2); - MlasStoreAlignedFloat64x2(&d[6], ZeroFloat64x2); -#endif - - if ((CountX & 4) != 0) { - - MLAS_FLOAT64X2 t0 = MlasLoadFloat64x2(&b[0]); - MLAS_FLOAT64X2 t1 = MlasLoadFloat64x2(&b[2]); - - MlasStoreAlignedFloat64x2(&d[0], t0); - MlasStoreAlignedFloat64x2(&d[2], t1); - - d += 4; - b += 4; - } - - if ((CountX & 2) != 0) { - - MlasStoreAlignedFloat64x2(&d[0], MlasLoadFloat64x2(&b[0])); - - d += 2; - b += 2; - } - - if ((CountX & 1) != 0) { - d[0] = b[0]; - } - - D += 8; - B += ldb; - y--; - - } while (y > 0); - } -} - -void -MlasDgemmTransposePackB( - double* D, - const double* B, - size_t ldb, - size_t CountY, - size_t CountX - ) -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - Columns of 8 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the DGEMM kernels. Any remaining - columns less than 8 elements wide are zero-padded. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountY - Supplies the number of rows of the source matrix to transpose. - - CountX - Supplies the number of columns of the source matrix to transpose. - -Return Value: - - None. - ---*/ -{ - // - // Transpose elements from matrix B into the packed buffer 8 rows at a - // time. - // - - while (CountY >= 8) { - - const double* b = B; - size_t x = CountX; - - while (x > 0) { - - double t0 = b[0]; - double t1 = b[ldb]; - double t2 = b[ldb * 2]; - double t3 = b[ldb * 3]; - double t4 = b[ldb * 4]; - double t5 = b[ldb * 5]; - double t6 = b[ldb * 6]; - double t7 = b[ldb * 7]; - - D[0] = t0; - D[1] = t1; - D[2] = t2; - D[3] = t3; - D[4] = t4; - D[5] = t5; - D[6] = t6; - D[7] = t7; - - D += 8; - b += 1; - x--; - } - - B += ldb * 8; - CountY -= 8; - } - - // - // Special case the handling of the less than 8 remaining rows. 
- // - - if (CountY > 0) { - - MLAS_FLOAT64X2 ZeroFloat64x2 = MlasZeroFloat64x2(); - - size_t x = CountX; - - while (x > 0) { - - double* d = D; - const double* b = B; - - MlasStoreAlignedFloat64x2(&d[0], ZeroFloat64x2); - MlasStoreAlignedFloat64x2(&d[2], ZeroFloat64x2); - MlasStoreAlignedFloat64x2(&d[4], ZeroFloat64x2); - MlasStoreAlignedFloat64x2(&d[6], ZeroFloat64x2); - - if ((CountY & 4) != 0) { - - double t0 = b[0]; - double t1 = b[ldb]; - double t2 = b[ldb * 2]; - double t3 = b[ldb * 3]; - - d[0] = t0; - d[1] = t1; - d[2] = t2; - d[3] = t3; - - d += 4; - b += ldb * 4; - } - - if ((CountY & 2) != 0) { - - double t0 = b[0]; - double t1 = b[ldb]; - - d[0] = t0; - d[1] = t1; - - d += 2; - b += ldb * 2; - } - - if ((CountY & 1) != 0) { - d[0] = b[0]; - } - - D += 8; - B += 1; - x--; - } - } -} - -MLAS_FORCEINLINE -double* -MlasDgemmKernelLoop( - const double* A, - const double* B, - double* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - double alpha, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine steps through the rows of the input and output matrices calling - the kernel until all rows have been processed. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasDgemmCopyPackB or MlasDgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the number of rows from matrix A and matrix C to iterate - over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar alpha multiplier (see DGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the next address of matrix C. - ---*/ -{ - while (CountM > 0) { - - size_t RowsHandled; - -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); -#else - if (ZeroMode) { - RowsHandled = MlasDgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasDgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); - } -#endif - - C += ldc * RowsHandled; - A += lda * RowsHandled; - CountM -= RowsHandled; - } - - return C; -} - -void -MlasDgemmOperation( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - double alpha, - const double* A, - size_t lda, - const double* B, - size_t ldb, - double beta, - double* C, - size_t ldc - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (DGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see DGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. 
- - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see DGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - -Return Value: - - None. - ---*/ -{ - double PanelA[MLAS_DGEMM_TRANSA_ROWS * MLAS_DGEMM_STRIDEK]; - MLAS_DECLSPEC_ALIGN(double PanelB[MLAS_DGEMM_STRIDEN * MLAS_DGEMM_STRIDEK], 8 * sizeof(double)); - - // - // Handle the special case of K equals zero. Apply the beta multiplier to - // the output matrix and exit. - // - - if (K == 0) { - MlasDgemmMultiplyBeta(C, M, N, ldc, beta); - return; - } - - // - // Compute the strides to step through slices of the input matrices. - // - // Expand the N stride if K is small or expand the K stride if N is small - // for better utilization of the B panel. Avoid changing the K stride if - // the A panel needs to be used for transposing. - // - - size_t StrideN = MLAS_DGEMM_STRIDEN; - size_t StrideK = MLAS_DGEMM_STRIDEK; - - if (N >= K) { - - while (StrideK / 2 >= K) { - StrideN *= 2; - StrideK /= 2; - } - - } else if (TransA == CblasNoTrans) { - - while (StrideN > 16 && StrideN / 2 >= N) { - StrideK *= 2; - StrideN /= 2; - } - } - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - - for (size_t n = 0; n < N; n += CountN) { - - CountN = std::min(N - n, StrideN); - - // - // Multiply the output matrix by beta as needed. - // - - if (beta != 0.0f && beta != 1.0f) { - MlasDgemmMultiplyBeta(C + n, M, CountN, ldc, beta); - } - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - bool ZeroMode = (beta == 0.0f); - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, StrideK); - - // - // Copy or transpose a panel of matrix B to a local packed buffer. - // - - if (TransB == CblasNoTrans) { - MlasDgemmCopyPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); - } else { - MlasDgemmTransposePackB(PanelB, B + k + n * ldb, ldb, CountN, CountK); - } - - // - // Step through each slice of matrix A along the M dimension. - // - - double* c = C + n; - - if (TransA == CblasNoTrans) { - - MlasDgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); - - } else { - - const double* a = A + k * lda; - size_t RowsRemaining = M; - - while (RowsRemaining > 0) { - - // - // Transpose elements from matrix A into a local buffer. - // - - size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_DGEMM_TRANSA_ROWS)); - - MlasDgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); - - RowsRemaining -= RowsTransposed; - a += RowsTransposed; - - // - // Step through the rows of the local buffer. - // - - c = MlasDgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); - } - } - - ZeroMode = false; - } - } -} - -void -MlasDgemmThreaded( - const ptrdiff_t ThreadCountM, - const ptrdiff_t ThreadCountN, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const size_t M, - const size_t N, - const size_t K, - const MLAS_DGEMM_DATA_PARAMS* Data, - const ptrdiff_t ThreadId - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - DGEMM operation. - -Arguments: - - Context - Supplies the pointer to the context for the threaded operation. - - ThreadId - Supplies the current index of the threaded operation. - -Return Value: - - None. 
- ---*/ -{ - - const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; - const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; - - // - // Partition the operation along the M dimension. - // - - size_t RangeStartM; - size_t RangeCountM; - - MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); - - // - // Partition the operation along the N dimension. - // - - size_t RangeStartN; - size_t RangeCountN; - - const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_DGEMM_STRIDEN_THREAD_ALIGN; - - MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, - &RangeCountN); - - RangeStartN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN; - RangeCountN *= MLAS_DGEMM_STRIDEN_THREAD_ALIGN; - - RangeCountN = std::min(N - RangeStartN, RangeCountN); - - // - // Dispatch the partitioned operation. - // - - const size_t lda = Data->lda; - const size_t ldb = Data->ldb; - const size_t ldc = Data->ldc; - - const double* A = Data->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); - const double* B = Data->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - double* C = Data->C + RangeStartM * ldc + RangeStartN; - - MlasDgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K, - Data->alpha, A, lda, B, ldb, Data->beta, C, ldc); -} - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// Chance of arithmetic overflow could be reduced -#pragma warning(disable : 26451) -#endif -void -MLASCALL -MlasGemmBatch( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_DGEMM_DATA_PARAMS* Data, - size_t BatchSize, - MLAS_THREADPOOL* ThreadPool - ) -{ - // - // Compute the number of target threads given the complexity of the DGEMM - // operation. Small requests should run using the single threaded path. - // - - const double Complexity = double(M) * double(N) * double(K); - - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - // - // Segment the operation across multiple threads. - // - // N.B. Currently, the operation is segmented as a 1D partition, which - // works okay for operations involving skinny matrices. 
- // - - ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize; - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - - if (N > M) { - - const size_t BlockedN = (N + MLAS_DGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_DGEMM_STRIDEN_THREAD_ALIGN; - - if (size_t(ThreadsPerGemm) > BlockedN) { - ThreadsPerGemm = ptrdiff_t(BlockedN); - } - - ThreadCountM = 1; - ThreadCountN = ThreadsPerGemm; - - } else { - - if (size_t(ThreadsPerGemm) > M) { - ThreadsPerGemm = ptrdiff_t(M); - } - - ThreadCountM = ThreadsPerGemm; - ThreadCountN = 1; - } - - const ptrdiff_t TotalThreads = ThreadsPerGemm * static_cast(BatchSize); - MlasTrySimpleParallel(ThreadPool, TotalThreads, [=](ptrdiff_t tid) { - const ptrdiff_t GemmIdx = tid / ThreadsPerGemm; - const ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; - MlasDgemmThreaded(ThreadCountM, ThreadCountN, TransA, TransB, - M, N, K, &(Data[GemmIdx]), ThreadIdx); - }); - -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif -#endif diff --git a/onnxruntime/core/mlas/lib/dwconv.cpp b/onnxruntime/core/mlas/lib/dwconv.cpp deleted file mode 100644 index d48d9cbb17502..0000000000000 --- a/onnxruntime/core/mlas/lib/dwconv.cpp +++ /dev/null @@ -1,156 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - dwconv.cpp - -Abstract: - - This module implements the half precision floating point depthwise convolution routines. - ---*/ - -#include "fp16_common.h" - -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - -MLAS_FORCEINLINE -void -MlasConvDepthwiseKernel( - const _mlas_fp16_* const* Input, - const _mlas_fp16_* Filter, - const _mlas_fp16_* Bias, - _mlas_fp16_* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize, - MLAS_HALF_GEMM_POSTPROCESSOR* PostProc -) -{ - while (OutputCount > 0) { - size_t ChannelOffset = 0; - size_t c = Channels; - - while (c >= 8) { - MLAS_FLOAT16X8 Accumulator = Bias == nullptr ? MlasZeroFloat16x8() : MlasLoadFloat16x8(&Bias[ChannelOffset]); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - MLAS_FLOAT16X8 InputVector = MlasLoadFloat16x8(&Input[k][ChannelOffset]); - MLAS_FLOAT16X8 FilterVector = MlasLoadFloat16x8(&Filter[ChannelKernelOffset]); - - Accumulator = MlasMultiplyAddFloat16x8(InputVector, FilterVector, Accumulator); - ChannelKernelOffset += Channels; - } - MlasStoreFloat16x8(Output, Accumulator); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } - - if (c >= 4) { - MLAS_FLOAT16X4 Accumulator = Bias == nullptr ? MlasZeroFloat16x4() : MlasLoadFloat16x4(&Bias[ChannelOffset]); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - MLAS_FLOAT16X4 InputVector = MlasLoadFloat16x4(&Input[k][ChannelOffset]); - MLAS_FLOAT16X4 FilterVector = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - - Accumulator = MlasMultiplyAddFloat16x4(InputVector, FilterVector, Accumulator); - ChannelKernelOffset += Channels; - } - MlasStoreFloat16x4(Output, Accumulator); - Output += 4; - - ChannelOffset += 4; - c -= 4; - } - - if (c > 0) { - MLAS_FLOAT16X4 Accumulator = - Bias == nullptr ? 
MlasZeroFloat16x4() : MlasLoadPartialFloat16x4(&Bias[ChannelOffset], c); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - MLAS_FLOAT16X4 InputValue = MlasLoadFloat16x4(&Input[k][ChannelOffset]); - MLAS_FLOAT16X4 FilterValue = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - - Accumulator = MlasMultiplyAddFloat16x4(InputValue, FilterValue, Accumulator); - ChannelKernelOffset += Channels; - } - MlasStorePartialFloat16x4(Output, Accumulator, c); - Output += c; - } - if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); - } - Input += KernelSize; - OutputCount -= 1; - } -} - -#else - -MLAS_FORCEINLINE -void -MlasConvDepthwiseKernel( - const _mlas_fp16_* const* Input, - const _mlas_fp16_* Filter, - const _mlas_fp16_* Bias, - _mlas_fp16_* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize, - MLAS_HALF_GEMM_POSTPROCESSOR* PostProc -) -{ - while (OutputCount > 0) { - for (size_t ChannelOffset = 0; ChannelOffset < Channels; ChannelOffset++) { - float Accumulator = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[ChannelOffset]); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - Accumulator += MLAS_Half2Float(Input[k][ChannelOffset]) * MLAS_Half2Float(Filter[ChannelKernelOffset]); - ChannelKernelOffset += Channels; - } - *Output++ = MLAS_Float2Half(Accumulator); - } - if (PostProc) { - PostProc->Process(reinterpret_cast(Output - Channels), 0, 0, 1, Channels, Channels); - } - Input += KernelSize; - OutputCount -= 1; - } -} - -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED - -void -MLASCALL -MlasConvDepthwise( - const MLAS_FP16* const* Input, - const MLAS_FP16* Filter, - const MLAS_FP16* Bias, - MLAS_FP16* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize, - MLAS_HALF_GEMM_POSTPROCESSOR* PostProc -) -{ - MlasConvDepthwiseKernel( - reinterpret_cast(Input), - reinterpret_cast(Filter), - reinterpret_cast(Bias), - reinterpret_cast<_mlas_fp16_*>(Output), - Channels, - OutputCount, - KernelSize, - PostProc - ); -} diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp deleted file mode 100644 index b45bd5162a02c..0000000000000 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ /dev/null @@ -1,269 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - erf.cpp - -Abstract: - - This module implements routines to compute the hyperbolic tangent function. - - This implementation uses the same polynomial coefficients and algorithm as - found in: https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff - Our usage requires building platform specific versions of - the algorithm to target different instruction sets. The implementation below - targets the base instruction set (typically SSE2) while assembly - implementations target newer instruction sets (such as FMA3). - ---*/ - -#include "mlasi.h" - -// -// Bundles the constants for use by kernels written in assembly. 
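The table below bundles the erf polynomial coefficients together with the coefficients of the exp() approximation used on the large-|x| branch, so the assembly kernels can address everything from one base pointer. In scalar form the evaluation those constants drive is roughly the following; this sketch substitutes std::exp for the bundled exp approximation and omits the vectorized masking:

    #include <algorithm>
    #include <cmath>

    // Two-range erf approximation: a polynomial in x^2 below ErfSplitBoundary,
    // and 1 - exp(-p(|x|)) above it, with the sign reapplied at the end.
    float ErfOutline(float x)
    {
        float ax = std::min(std::fabs(x), 3.925f);      // ErfUpperAbsRange
        if (ax <= 0.921875f) {                          // ErfSplitBoundary
            float s = ax * ax;
            float p = -5.99104969e-4f;                  // ErfSMALL_P0 ...
            p = p * s + 4.99339588e-3f;
            p = p * s - 2.67667342e-2f;
            p = p * s + 1.12818025e-1f;
            p = p * s - 3.76124859e-1f;
            p = p * s + 1.28379151e-1f;                 // ... ErfSMALL_P5_Minus_One
            return p * x + x;                           // x * (1 + p(x^2)), sign kept
        }
        float q = 1.72948930e-5f;                       // ErfBIG_P0 ...
        q = q * ax - 3.83208680e-4f;
        q = q * ax + 3.88393435e-3f;
        q = q * ax - 2.42545605e-2f;
        q = q * ax + 1.06777847e-1f;
        q = q * ax + 6.34846687e-1f;
        q = q * ax + 1.28717512e-1f;                    // ... ErfBIG_P6_Minus_One
        float r = 1.0f - std::exp(-(q * ax + ax));
        return std::copysign(r, x);
    }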
-// - -MLAS_INTERNAL_DATA const struct { - float ErfUpperAbsRange; - float ErfSplitBoundary; - float ErfSMALL_P0; - float ErfSMALL_P1; - float ErfSMALL_P2; - float ErfSMALL_P3; - float ErfSMALL_P4; - float ErfSMALL_P5_Minus_One; - float ErfReserved0; - float ErfBIG_P0; - float ErfBIG_P1; - float ErfBIG_P2; - float ErfBIG_P3; - float ErfBIG_P4; - float ErfBIG_P5; - float ErfBIG_P6_Minus_One; - float ErfNegZero; - float ErfOne; - - float Exp_UpperRange; - float Exp_LowerRange; - float Exp_Log2Reciprocal; - float Exp_log2_hi; - float Exp_log2_lo; - float Exp_P0; - float Exp_P1; - float Exp_P2; - float Exp_P3; - float Exp_P4; - float Exp_P5; - float Exp_P6; - float Exp_C; - int32_t Exp_X7F; -} MlasErfConstants = { - 3.925f, - 0.921875f, - -5.99104969e-4f, - 4.99339588e-3f, - -2.67667342e-2f, - 1.12818025e-1f, - -3.76124859e-1f, - 1.28379151e-1f, - 0.0f, - 1.72948930e-5f, - -3.83208680e-4f, - 3.88393435e-3f, - -2.42545605e-2f, - 1.06777847e-1f, - 6.34846687e-1f, - 1.28717512e-1f, - -0.0f, - 1.0f, - - // Independent parameters to calculate Exp for Erff() - 88.3762626647950f, - -88.3762626647949f, - 1.44269504088896341f, - -6.93145752e-1f, - -1.42860677e-6f, - 1.38319808e-3f, - 8.37550033e-3f, - 4.16689515e-2f, - 1.66664466e-1f, - 4.99999851e-1f, - 1.00000000e+0f, - 1.00000000e+0f, - 1.25829120e+7f, - 127, -}; - -void -MLASCALL -MlasErfKernel( - const float* Input, - float* Output, - size_t N - ) -/*++ - -Routine Description: - - This routine implements the generic kernel for the error function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ - while (N >= 4) { - MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input); - MLAS_FLOAT32X4 NegZero = MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero); - MLAS_FLOAT32X4 SignMask = MlasAndFloat32x4(Value, NegZero); - MLAS_FLOAT32X4 AbsValue = MlasAndNotFloat32x4(NegZero, Value); - AbsValue = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.ErfUpperAbsRange), AbsValue); - MLAS_FLOAT32X4 SquareValue = MlasMultiplyFloat32x4(AbsValue, AbsValue); - - MLAS_FLOAT32X4 r_small = MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P0); - r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P1)); - r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P2)); - r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P3)); - r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P4)); - r_small = MlasMultiplyAddFloat32x4(r_small, SquareValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSMALL_P5_Minus_One)); - r_small = MlasMultiplyAddFloat32x4(r_small, AbsValue, AbsValue); - MLAS_FLOAT32X4 split_mask = MlasGreaterThanFloat32x4(AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfSplitBoundary)); - r_small = MlasAndNotFloat32x4(split_mask, r_small); - - AbsValue = MlasAndFloat32x4(split_mask, AbsValue); // clear smaller value into zero for bigger number calculation - MLAS_FLOAT32X4 r_big = MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P0); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P1)); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P2)); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, 
MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P3)); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P4)); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P5)); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, MlasBroadcastFloat32x4(MlasErfConstants.ErfBIG_P6_Minus_One)); - r_big = MlasMultiplyAddFloat32x4(r_big, AbsValue, AbsValue); - - // 1.0 - exp(-r_big), no need to do min() - r_big = MlasXorFloat32x4(r_big, MlasBroadcastFloat32x4(MlasErfConstants.ErfNegZero)); // -r_big - r_big = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.Exp_LowerRange), r_big); - MLAS_FLOAT32X4 exp_c = MlasBroadcastFloat32x4(MlasErfConstants.Exp_C); - MLAS_FLOAT32X4 r = MlasMultiplyAddFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.Exp_Log2Reciprocal), r_big, exp_c); - r = MlasSubtractFloat32x4(r, exp_c); - - MLAS_FLOAT32X4 fx = MlasMultiplyAddFloat32x4(r, MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_hi), r_big); - fx = MlasMultiplyAddFloat32x4(r, MlasBroadcastFloat32x4(MlasErfConstants.Exp_log2_lo), fx); - // y = exp(fx) - MLAS_FLOAT32X4 y = MlasBroadcastFloat32x4(MlasErfConstants.Exp_P0); - y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P1)); - y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P2)); - y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P3)); - y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P4)); - y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P5)); - y = MlasMultiplyAddFloat32x4(y, fx, MlasBroadcastFloat32x4(MlasErfConstants.Exp_P6)); - // 1.0 - exp(fx) * 2^INT(r) - y = MlasMultiplyFloat32x4(y, MlasPowerOf2Float32x4(r)); - y = MlasSubtractFloat32x4(MlasBroadcastFloat32x4(MlasErfConstants.ErfOne), y); - - // merge two splits results - y = MlasOrFloat32x4(r_small, y); - y = MlasOrFloat32x4(y, SignMask); - - MlasStoreFloat32x4(Output, y); - - Input += 4; - Output += 4; - N -= 4; - } - - while (N > 0) { - float Value = *Input++; - float AbsValue = fabsf(Value); - - float r; - if (AbsValue > MlasErfConstants.ErfSplitBoundary) { - AbsValue = std::min(MlasErfConstants.ErfUpperAbsRange, AbsValue); - float r_big = MlasErfConstants.ErfBIG_P0; - r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P1; - r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P2; - r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P3; - r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P4; - r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P5; - r_big = r_big * AbsValue + MlasErfConstants.ErfBIG_P6_Minus_One; - r_big = r_big * AbsValue + AbsValue; - - r_big = std::max(-r_big, MlasErfConstants.Exp_LowerRange); - r = MlasErfConstants.Exp_Log2Reciprocal * r_big + MlasErfConstants.Exp_C; - r -= MlasErfConstants.Exp_C; - float fx = r * MlasErfConstants.Exp_log2_hi + r_big; - fx = r * MlasErfConstants.Exp_log2_lo + fx; - - float y = MlasErfConstants.Exp_P0; - y = y * fx + MlasErfConstants.Exp_P1; - y = y * fx + MlasErfConstants.Exp_P2; - y = y * fx + MlasErfConstants.Exp_P3; - y = y * fx + MlasErfConstants.Exp_P4; - y = y * fx + MlasErfConstants.Exp_P5; - y = y * fx + MlasErfConstants.Exp_P6; - - r = 1.0f - ldexpf(y, (int)r); - r = (Value <= -0.0f) ? 
-r : r; - } - else { - float SquareValue = AbsValue * AbsValue; - r = MlasErfConstants.ErfSMALL_P0; - r = r * SquareValue + MlasErfConstants.ErfSMALL_P1; - r = r * SquareValue + MlasErfConstants.ErfSMALL_P2; - r = r * SquareValue + MlasErfConstants.ErfSMALL_P3; - r = r * SquareValue + MlasErfConstants.ErfSMALL_P4; - r = r * SquareValue + MlasErfConstants.ErfSMALL_P5_Minus_One; - r = r * Value + Value; - } - - *Output++ = r; - N -= 1; - } -} - -void -MLASCALL -MlasComputeErf( - const float* Input, - float* Output, - size_t N - ) -/*++ - -Routine Description: - - This routine computes the error function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().ErfKernelRoutine(Input, Output, N); -#else - MlasErfKernel(Input, Output, N); -#endif -} diff --git a/onnxruntime/core/mlas/lib/flashattn.cpp b/onnxruntime/core/mlas/lib/flashattn.cpp deleted file mode 100644 index fe5402ed144aa..0000000000000 --- a/onnxruntime/core/mlas/lib/flashattn.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include - -#include "mlasi.h" - -void -MlasFlashAttentionThreaded( - void* argptr, - std::ptrdiff_t thread_id -) -{ - const MlasFlashAttentionThreadedArgs* args = reinterpret_cast(argptr); - ptrdiff_t q_block_size = static_cast(args->q_block_size); - ptrdiff_t kv_block_size = static_cast(args->kv_block_size); - ptrdiff_t batch_size = static_cast(args->batch_size); - ptrdiff_t num_heads = static_cast(args->num_heads); - ptrdiff_t q_sequence_length = static_cast(args->q_sequence_length); - ptrdiff_t kv_sequence_length = static_cast(args->kv_sequence_length); - ptrdiff_t qk_head_size = static_cast(args->qk_head_size); - ptrdiff_t v_head_size = static_cast(args->v_head_size); - float* buffer = args->buffer; - ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); - ptrdiff_t thread_count = static_cast(args->thread_count); - const float* query = args->query; - const float* key = args->key; - const float* value = args->value; - float* output = args->output; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - auto&& mlas_platform = GetMlasPlatform(); -#endif - - ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size; - - ptrdiff_t task_start = 0; - ptrdiff_t task_end = 0; - ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; - ptrdiff_t quotient = total_task_count / thread_count; - ptrdiff_t remainder = total_task_count % thread_count; - if (thread_id < remainder) { - task_start = (quotient + 1) * thread_id; - task_end = task_start + quotient + 1; - } else { - task_start = quotient * thread_id + remainder; - task_end = task_start + quotient; - } - - for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { - ptrdiff_t batch_idx = task_index; - ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; - batch_idx /= q_chunk_count; - ptrdiff_t head_idx = batch_idx % num_heads; - batch_idx /= num_heads; - - char* buffer_current_thread = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; - float* l = reinterpret_cast(buffer_current_thread); - float* m = l + q_block_size; - for (ptrdiff_t t = 0; t < q_block_size; ++t) { - m[t] = std::numeric_limits::lowest(); - } - float* intermediate = m + q_block_size; - float* temp_output = intermediate + q_block_size * kv_block_size; - float negmax = 0; - - for (ptrdiff_t ir = 0; ir < 
kv_sequence_length; ir += kv_block_size) { - /* - S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T - old_m = m - m = max(m, rowmax(S)) - diff = old_m - m - S = exp(S - m) - l = exp(diff) * l + rowsum(S) - O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :] - */ - ptrdiff_t h = batch_idx * num_heads + head_idx; - const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size; - const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size; - const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size; - - size_t row_size_q_capped = static_cast(std::min(q_block_size, q_sequence_length - q_idx)); - size_t row_size_kv_capped = static_cast(std::min(kv_block_size, kv_sequence_length - ir)); - - MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans, - CBLAS_TRANSPOSE::CblasTrans, - row_size_q_capped, - row_size_kv_capped, - static_cast(qk_head_size), - args->scale, - inputQ, - static_cast(qk_head_size), - inputK, - static_cast(qk_head_size), - 0.0f, - intermediate, - row_size_kv_capped); - - for (ptrdiff_t irow = 0; irow < static_cast(row_size_q_capped); ++irow) { - float* p = intermediate + irow * row_size_kv_capped; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped); -#else - float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped); -#endif - float m_diff = m[irow]; - m[irow] = std::max(m[irow], rowmax); // new m - negmax = -m[irow]; - m_diff -= m[irow]; // old - new (less than 0) - -#if defined(MLAS_TARGET_AMD64) - float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax); -#else - float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax); -#endif - - // Note: for ir == 0, there is actually no need to calculate exp_diff - if (ir != 0) { - float exp_diff = std::exp(m_diff); - l[irow] = exp_diff * l[irow] + rowsum; - - for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { - temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol]; - } - } else { - l[irow] = rowsum; - // When ir == 0, there is no need to scale the old result because it is zero. - } - } - MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans, - CBLAS_TRANSPOSE::CblasNoTrans, - row_size_q_capped, - static_cast(v_head_size), - row_size_kv_capped, - 1.0f, - intermediate, - row_size_kv_capped, - inputV, - static_cast(v_head_size), - ir == 0 ? 
0.0f : 1.0f, - temp_output, - static_cast(v_head_size)); - } - - float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size; - ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx); - // TODO: leverage advanced instruction sets - for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) { - for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { - output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow]; - } - output_row += num_heads * v_head_size; - } - } -} - -void -MLASCALL -MlasFlashAttention( - MlasFlashAttentionThreadedArgs* args, - MLAS_THREADPOOL* ThreadPool -) -{ - MlasExecuteThreaded( - MlasFlashAttentionThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h deleted file mode 100644 index 30b66cdb2ea78..0000000000000 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ /dev/null @@ -1,335 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - fp16_common.h - -Abstract: - - Intrinsic and inline functions for fp16 processing. - ---*/ - -#pragma once - -#include "mlas_float16.h" -#include "mlasi.h" - -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - -// TODO!! Add intel fp16 implementations - -typedef float16x8_t MLAS_FLOAT16X8; -typedef float16x4_t MLAS_FLOAT16X4; -typedef uint16x8_t MLAS_UINT16X8; -typedef uint16x4_t MLAS_UINT16X4; - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasReinterpretAsFloat16x8(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasBroadcastFloat16x8(_mlas_fp16_ Value) { return vreinterpretq_f16_p16(vdupq_n_p16(Value)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasBroadcastFloat16x4(_mlas_fp16_ Value) { return vreinterpret_f16_p16(vdup_n_p16(Value)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasBroadcastFloat16x8(const _mlas_fp16_* Value) { return vreinterpretq_f16_u16(vld1q_dup_u16(Value)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasBroadcastFloat16x4(const _mlas_fp16_* Value) { return vreinterpret_f16_u16(vld1_dup_u16(Value)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasZeroFloat16x8(void) { return vreinterpretq_f16_f32(vdupq_n_f32(0.0f)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasZeroFloat16x4(void) { return vreinterpret_f16_f32(vdup_n_f32(0.0f)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasLoadFloat16x8(const _mlas_fp16_* Buffer) { return vreinterpretq_f16_u16(vld1q_u16(Buffer)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) -{ - MLAS_FLOAT16X4 Vector = MlasZeroFloat16x4(); - if ((len & 1) != 0) { - Vector = vreinterpret_f16_u16(vld1_lane_u16(Buffer + (len - 1), vreinterpret_u16_f16(Vector), 0)); - } - if ((len & 2) != 0) { - Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 0)); - Vector = vreinterpret_f16_f32( - vld1_lane_f32(reinterpret_cast(Buffer), vreinterpret_f32_f16(Vector), 0) - ); - } - return Vector; -} - -MLAS_FORCEINLINE -void -MlasStoreFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) -{ - vst1q_u16(Buffer, vreinterpretq_u16_f16(Vector)); -} - -MLAS_FORCEINLINE -void -MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) -{ - vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); -} - -MLAS_FORCEINLINE -void 
-MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) -{ - if ((len & 2) != 0) { - vst1_lane_f32(reinterpret_cast(Buffer), vreinterpret_f32_f16(Vector), 0); - Vector = vreinterpret_f16_f32(vdup_lane_f32(vreinterpret_f32_f16(Vector), 1)); - Buffer += 2; - } - if ((len & 1) != 0) { - vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), 0); - } -} - -template -MLAS_FORCEINLINE void -MlasStoreLaneFloat16x8(_mlas_fp16_* Buffer, MLAS_FLOAT16X8 Vector) -{ - vst1q_lane_u16(Buffer, vreinterpretq_u16_f16(Vector), Lane); -} - -MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V) -{ - // vget_low should be compiled to nothing - return vget_low_f16(V); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vaddq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) -{ - return vadd_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasSubtractFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vsubq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasSubtractFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) -{ - return vsub_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasMultiplyFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vmulq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasMultiplyFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) -{ - return vmul_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasDivFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vdivq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasDivFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) -{ - return vdiv_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Vector3) -{ - return vfmaq_f16(Vector3, Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasMultiplyAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Vector3) -{ - return vfma_f16(Vector3, Vector1, Vector2); -} - - -MLAS_FORCEINLINE -void -MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, _mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3) -{ - MlasMultiplyAddFloat16x8(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); -} - -MLAS_FORCEINLINE -void -MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3) -{ - MlasMultiplyAddFloat16x8(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasDivideFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vdivq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasGreaterThanFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vreinterpretq_f16_u16(vcgtq_f16(Vector1, Vector2)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasAndFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vreinterpretq_f16_s64(vandq_s64(vreinterpretq_s64_f16(Vector1), vreinterpretq_s64_f16(Vector2))); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasOrFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vreinterpretq_f16_s64(vorrq_s64(vreinterpretq_s64_f16(Vector1), vreinterpretq_s64_f16(Vector2))); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasAndNotFloat16x8(MLAS_FLOAT16X8 VectorNot, MLAS_FLOAT16X8 Vector) -{ - return 
vreinterpretq_f16_s32(vandq_s32(vmvnq_s32(vreinterpretq_s32_f16(VectorNot)), vreinterpretq_s32_f16(Vector))); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasXorFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vreinterpretq_f16_s32(veorq_s32(vreinterpretq_s32_f16(Vector1), vreinterpretq_s32_f16(Vector2))); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Selection) -{ - return MlasOrFloat16x8(MlasAndFloat16x8(Vector2, Selection), - MlasAndNotFloat16x8(Selection, Vector1)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasMaximumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vmaxq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasMaximumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) -{ - return vmax_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasMinimumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vminq_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasMinimumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) -{ - return vmin_f16(Vector1, Vector2); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange) -{ - Value = MlasMaximumFloat16x8(MlasBroadcastFloat16x8(LowerRange), Value); - Value = MlasMinimumFloat16x8(MlasBroadcastFloat16x8(UpperRange), Value); - return Value; -} - -MLAS_FORCEINLINE -_mlas_fp16_ -MlasReduceAddFloat16x8(MLAS_FLOAT16X8 Vector) -{ - Vector = vpaddq_f16(Vector, Vector); - Vector = vpaddq_f16(Vector, Vector); - return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0); -} - -MLAS_FORCEINLINE -MLAS_UINT16X8 -MlasCmpLessEqualFloat16x8(MLAS_FLOAT16X8 left, MLAS_FLOAT16X8 right) -{ - return vcleq_f16(left, right); -} - -MLAS_FORCEINLINE -MLAS_UINT16X4 -MlasCmpLessEqualFloat16x4(MLAS_FLOAT16X4 left, MLAS_FLOAT16X4 right) -{ - return vcle_f16(left, right); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasBitwiseSelectFloat16x8(MLAS_UINT16X8 select, MLAS_FLOAT16X8 ones, MLAS_FLOAT16X8 zeros) -{ - return vbslq_f16(select, ones, zeros); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT16X4 zeros) -{ - return vbsl_f16(select, ones, zeros); -} - -#endif // fp16 vector intrinsic supported diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp deleted file mode 100644 index 29734c2277667..0000000000000 --- a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp +++ /dev/null @@ -1,164 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - fp16_neon_common.cpp - -Abstract: - - This module implements the common kernels for ARM NEON specific to float16. - ---*/ - -#include "mlasi.h" - -#include "arm_neon.h" - -// This file is enabled in cmake only if ARM64 is defined and not on Apple platforms -// The cmake condition is equivalent to MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64. -// Therefore omit the MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 macro in this file. 
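The partial load and store helpers above shuffle lanes so that only the first len (1 to 3) half-precision values are read from or written to memory, never touching the fourth slot. Functionally they behave like a bounded memcpy against a zeroed 4-lane register; a non-NEON model of that contract, using plain uint16_t for the raw halves:

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Only the first `len` lanes are read; the remaining lanes stay zero.
    std::array<uint16_t, 4> LoadPartialFp16x4(const uint16_t* buffer, size_t len)
    {
        std::array<uint16_t, 4> lanes{};
        std::memcpy(lanes.data(), buffer, len * sizeof(uint16_t));
        return lanes;
    }

    // Only the first `len` lanes are written; memory past them is untouched.
    void StorePartialFp16x4(uint16_t* buffer, const std::array<uint16_t, 4>& lanes, size_t len)
    {
        std::memcpy(buffer, lanes.data(), len * sizeof(uint16_t));
    }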
- -MLAS_FORCEINLINE -size_t -StoreFp32Lane(float* dest, float32x4_t src, size_t count) -{ - if (count == 3) { - vst1q_lane_f32(dest + 0, src, 0); - vst1q_lane_f32(dest + 1, src, 1); - vst1q_lane_f32(dest + 2, src, 2); - return 3; - } else if (count == 2) { - vst1q_lane_f32(dest + 0, src, 0); - vst1q_lane_f32(dest + 1, src, 1); - return 2; - } else if (count == 1) { - vst1q_lane_f32(dest + 0, src, 0); - return 1; - } - - return 0; -} - -void -MlasCastF16ToF32KernelNeon(const unsigned short* src, float* dest, size_t count) -{ - // 4 float16 alignment - auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 7) & ~7); - auto pre_count = std::min(static_cast(src_aligned - src), count); - size_t i = 0; - - // Handle leading unaligned src - if (pre_count > 0) { - float16x4_t fp16v4; - std::memcpy(&fp16v4, src, pre_count * sizeof(unsigned short)); - float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); - - i = StoreFp32Lane(dest, fp32v4, pre_count); - } - - // aligned src - for (; i + 7 < count; i += 8) - { - float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); - float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); - vst1q_f32(dest + i, fp32v4_0); - - float16x4_t fp16v4_1 = vreinterpret_f16_u16(vld1_u16(src + i + 4)); - float32x4_t fp32v4_1 = vcvt_f32_f16(fp16v4_1); - vst1q_f32(dest + i + 4, fp32v4_1); - } - - if (i + 3 < count) - { - float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); - float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); - vst1q_f32(dest + i, fp32v4_0); - i += 4; - } - - // Handle trailing unaligned src - auto post_count = count - i; - if (post_count > 0) - { - float16x4_t fp16v4; - std::memcpy(&fp16v4, src + i, post_count * sizeof(unsigned short)); - float32x4_t fp32v4 = vcvt_f32_f16(fp16v4); - - StoreFp32Lane(dest + i, fp32v4, post_count); - } -} - -MLAS_FORCEINLINE -size_t -StoreU16Lane(unsigned short* dest, uint16x4_t src, size_t count) -{ - if (count == 3) { - vst1_lane_u16(dest + 0, src, 0); - vst1_lane_u16(dest + 1, src, 1); - vst1_lane_u16(dest + 2, src, 2); - return 3; - } else if (count == 2) { - vst1_lane_u16(dest + 0, src, 0); - vst1_lane_u16(dest + 1, src, 1); - return 2; - } else if (count == 1) { - vst1_lane_u16(dest + 0, src, 0); - return 1; - } - - return 0; -} - -void -MlasCastF32ToF16KernelNeon(const float* src, unsigned short* dest, size_t count) -{ - // 4 float32 alignment - auto* src_aligned = reinterpret_cast((reinterpret_cast(src) + 15) & ~15); - auto pre_count = std::min(static_cast(src_aligned - src), count); - size_t i = 0; - - // Handle leading unaligned src - if (pre_count > 0) - { - float32x4_t fp32v4; - std::memcpy(&fp32v4, src, pre_count * sizeof(float)); - uint16x4_t u16v4 = vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); - - i = StoreU16Lane(dest, u16v4, pre_count); - } - - // aligned src - for (; i + 7 < count; i += 8) - { - float32x4_t fp32v4_0 = vld1q_f32(src + i); - float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); - vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); - - float32x4_t fp32v4_1 = vld1q_f32(src + i + 4); - float16x4_t fp16v4_1 = vcvt_f16_f32(fp32v4_1); - vst1_u16(dest + i + 4, vreinterpret_u16_f16(fp16v4_1)); - } - - if (i + 3 < count) - { - float32x4_t fp32v4_0 = vld1q_f32(src + i); - float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); - vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); - i += 4; - } - - // Handle trailing unaligned src - auto post_count = count - i; - if (post_count > 0) - { - float32x4_t fp32v4; - std::memcpy(&fp32v4, src + i, post_count * sizeof(float)); - uint16x4_t u16v4 = 
vreinterpret_u16_f16(vcvt_f16_f32(fp32v4)); - - StoreU16Lane(dest + i, u16v4, post_count); - } -} diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp deleted file mode 100644 index 49387d2fc998f..0000000000000 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ /dev/null @@ -1,335 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - half gemm.cpp - -Abstract: - - This module implements the half precision (fp16) matrix/matrix multiply - operation (QGEMM). - ---*/ - -#include "mlasi.h" -#include "mlas_float16.h" - -#include "halfgemm.h" - -#include - -bool MLASCALL -MlasFp16AccelerationSupported() -{ -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); -#else - return false; -#endif -} - - -void -MLASCALL -MlasHalfGemmBatch( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool - ) -{ - const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); - MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation; - - if (ThreadPool == nullptr) { - for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { - auto Data = &DataParams[gemm_i]; - operation(N, K, Data, 0, M, 0, N); - } - return; - } - - // - // Compute the number of target threads given the complexity of the SGEMM - // operation. Small requests should run using the single threaded path. - // - - const double Complexity = double(M) * double(N) * double(K) * double(BatchN); - - ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; - if (ThreadsPerGemm < 1) { - ThreadsPerGemm = 1; - } - - const size_t StrideM = dispatch->StrideM; - - size_t nc = N; - if ((size_t)MlasGetMaximumThreadCount(ThreadPool) > BatchN) { - // more than one thread per GEMM - - const size_t BlockedM = MlasDivRoundup(M, StrideM); - const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); - if (max_nc < nc) { - nc = std::min(nc, MlasDivRoundup(nc, max_nc * MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * - MLAS_QGEMM_STRIDEN_THREAD_ALIGN); - } - } - const size_t StrideN = nc; - - const size_t ThreadCountM = MlasDivRoundup(M, StrideM); - const size_t ThreadCountN = MlasDivRoundup(N, StrideN); - ThreadsPerGemm = ThreadCountM * ThreadCountN; - - MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { - const auto gemm_i = tid / ThreadsPerGemm; - const auto blk_i = tid % ThreadsPerGemm; - auto Data = &DataParams[gemm_i]; - - const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; - const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; - - const size_t RangeStartM = ThreadIdM * StrideM; - const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); - - const size_t RangeStartN = ThreadIdN * StrideN; - const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - - operation(N, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); - }); -} - - -size_t -MLASCALL -MlasHalfGemmPackBSize( - size_t N, - size_t K, - bool float2half - ) -{ - const auto* dispatch = MlasHalfGemmGetDispatch(); - const auto padding = dispatch->BufOverRead; - const auto PackedK = dispatch->PackededK; - if (!float2half && 
dispatch->CopyPackBRoutine == nullptr) { - // No packing routine provided - return 0; - } - const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding; - const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = - (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); - return AlignedBytesRequired; -} - -void -MLASCALL -MlasHalfGemmPackB( - size_t N, - size_t K, - const MLAS_FP16* B, - size_t ldb, - void* PackedB - ) -{ - const auto* dispatch = MlasHalfGemmGetDispatch(); - dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const _mlas_fp16_*)B, ldb, N, K); -} - -void -MLASCALL -MlasHalfGemmConvertPackB( - size_t N, - size_t K, - const float* B, - size_t ldb, - void* PackedB - ) -{ - const auto* dispatch = MlasHalfGemmGetDispatch(); - dispatch->ConvertPackBRoutine((_mlas_fp16_*)PackedB, B, ldb, N, K); -} - - -// -// Post Processor Implementations -// - -MLAS_FORCEINLINE -void -CvtHalf2Float( - float* dest, - const _mlas_fp16_* src, - size_t len -) -{ -#ifdef MLAS_TARGET_ARM64 - while (len >= 4) { - const auto* srcPtr = reinterpret_cast(src); - auto* dstPtr = reinterpret_cast(dest); - *dstPtr = vcvt_f32_f16(*srcPtr); - src += 4; - dest += 4; - len -= 4; - } - - if (0 == len) { - return; - } - - float16x4_t buf; - std::memcpy(&buf, src, len * sizeof(_mlas_fp16_)); - float32x4_t res = vcvt_f32_f16(buf); - - if ((len & 2) != 0) { - auto wide = vreinterpretq_f64_f32(res); - vst1q_lane_f64((float64_t*)dest, wide, 0); - res = vreinterpretq_f32_f64(vdupq_laneq_f64(wide, 1)); - dest += 2; - } - if ((len & 1) != 0) { - vst1q_lane_f32(dest, res, 0); - } -#else - for (size_t i = 0; i < len; i++) { - *dest++ = MLAS_Half2Float(*src++); - } -#endif // MLAS_TARGET_ARM64 -} - -void -MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process( - MLAS_FP16* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const -{ - float* Output = Output_; - const auto* CRow = reinterpret_cast(C); - CRow += StartM * ldc + StartN; - Output += StartM * RowStride_ + StartN; - - while (CountM-- > 0) { - CvtHalf2Float(Output, CRow, CountN); - MlasActivation(&Activation_, Output, nullptr, 1, CountN, ldc); - CRow += ldc; - Output += RowStride_; - } -} - - -// -// Dummy C++ implementation that runs very slowly -// - -struct MLAS_HALF_GEMM_KERNEL_DEFAULT { - - static constexpr bool PackNeeded = false; - static constexpr size_t KernelMaxM = 128; // max # rows the vectorized kernel can process - static constexpr size_t PackedK = 1; - - static constexpr MLAS_HALF_GEMM_STRIDES Strides{8, 16, 32}; -}; - -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmConvertPackA( - _mlas_fp16_* D, - const float* A, - size_t lda, - size_t CountM, - size_t CountK -) -{ - for (size_t m = 0; m < CountM; m++) { - for (size_t k = 0; k < CountK; k++) { - *D++ = MLAS_Float2Half(*(A + m * lda + k)); - } - } -} - -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmConvertPackB( - _mlas_fp16_* D, - const float* B, - size_t ldb, - size_t CountN, - size_t CountK -) -{ - for (size_t k = 0; k < CountK; k++) { - for (size_t n = 0; n < CountN; n++) { - *D++ = MLAS_Float2Half(*(B + k * ldb + n)); - } - } -} - - -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmKernel( - size_t CountM, - size_t CountN, - size_t CountK, - _mlas_fp16_* C, - size_t ldc, - const _mlas_fp16_* Bias, - const _mlas_fp16_* A, - size_t lda, - const _mlas_fp16_* B, - size_t ldb, - const bool ZeroMode) -{ - for (size_t m = 0; m < CountM; m++) { - for (size_t n = 0; 
n < CountN; n++) { - const auto* a = A + (m * lda); - const auto* b = B + n; - auto* c = C + (m * ldc) + n; - - float sum = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[n]); - if (!ZeroMode) { - sum += MLAS_Half2Float(*c); - } - - for (size_t k = 0; k < CountK; k++) { - auto down = MLAS_Float2Half(MLAS_Half2Float(*a) * MLAS_Half2Float(*b) + sum); - sum = MLAS_Half2Float(down); - b += ldb; - a += 1; - } - - *c = MLAS_Float2Half(sum); - } - } -} - - -const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { - MlasHalfGemmOperation, - nullptr, - MlasHalfGemmConvertPackB, - MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, - MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM, - 0 -}; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h deleted file mode 100644 index 61e2fbb0afc6a..0000000000000 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ /dev/null @@ -1,515 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - halfgemm.h - -Abstract: - - This module defines the set of template functions to implement half - precision matrix/matrix multiply operation (QGEMM). - - To implement a new kernel, template functions below need to be specialized: - MlasHalfGemmCopyPackB - MlasHalfGemmConvertPackA - MlasHalfGemmConvertPackB - MlasHalfGemmPackedBOffset - MlasHalfGemmPackedBLeadingDim - MlasHalfGemmKernel - - MlasHalfGemmOperation is the shared kernel driver. - - A kernel type should define the following constants: - bool PackNeeded; Whether fp16 B needs to be packed - size_t KernelMaxM; Max # rows the vectorized kernel can process - size_t PackedK; Packed alignment on the K dim (power of 2) - MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; ---*/ - -#pragma once - -#include -#include -#include - -#include "mlasi.h" -#include "mlas_float16.h" - - -/** - * @brief Define the default striding parameters for - * the half precision gemm operation - */ -struct MLAS_HALF_GEMM_STRIDES { - size_t M; - size_t N; - size_t K; -}; - -/** - * @brief Packing function for fp16 B matrix - * - * @tparam KernelType - * @param[out] D Address of packing buffer - * @param[in] B Address of source matrix B - * @param[in] ldb Leading dimension of B - * @param[in] CountN # of column to pack - * @param[in] CountK # of rows to pack -*/ -template -MLAS_FORCEINLINE -void -MlasHalfGemmCopyPackB( - _mlas_fp16_* D, - const _mlas_fp16_* B, - size_t ldb, - size_t CountN, - size_t CountK -) -{ - MLAS_UNREFERENCED_PARAMETER(D); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(CountK); - // No packing needed by default -} - -/** - * @brief Convert fp32 matrix A to fp16 and pack the data - * - * @tparam KernelType - * @param[out] D Address of the packing buffer - * @param[in] A Address of fp32 matrix A - * @param[in] lda leading dimension of A - * @param[in] CountM # of rows to pack - * @param[in] CountK # of columns to pack -*/ -template -void -MlasHalfGemmConvertPackA( - _mlas_fp16_* D, - const float* A, - size_t lda, - size_t CountM, - size_t CountK -); - -/** - * @brief Convert fp32 matrix B to fp16 and pack the data - * - * @tparam KernelType - * @param[out] D Address of packing buffer - * @param[in] B Address of source matrix B in fp32 - * @param[in] ldb Leading dimension of B - * @param[in] CountN # of column to pack - * @param[in] CountK # of rows to pack - */ -template -void -MlasHalfGemmConvertPackB( - _mlas_fp16_* D, - const float* B, - size_t ldb, - 
size_t CountN, - size_t CountK -); - -/** - * @brief Find the location of PackedB[StartK, StartN] - * - * @tparam KernelType - * @param PackedB - * @param DimN Total columns of the packing buffer - * @param DimK Total rows of the packing buffer - * @param StartN - * @param StartK - * @return Address of PackedB[StartK, StartN] -*/ -template -MLAS_FORCEINLINE -const _mlas_fp16_* -MlasHalfGemmPackedBOffset( - const _mlas_fp16_* PackedB, - size_t DimN, - size_t DimK, - size_t StartN, - size_t StartK) -{ - // By default the packed buffer is just a row major - // K row by N column buffer - MLAS_UNREFERENCED_PARAMETER(DimK); - return PackedB + StartK * DimN + StartN; -} - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -/*No it can NOT be constexpr!.*/ -#pragma warning(disable : 26497) -#endif - -/** - * @brief leading dimension of the packed B buffer - * Related to how B is packed - * @tparam KernelType - * @param DimN - * @param DimK - * @return leading dimension of the packed B buffer -*/ -template -MLAS_FORCEINLINE -size_t -MlasHalfGemmPackedBLeadingDim( - size_t DimN, - size_t DimK) -{ - // By default the packed buffer is just a row major - // K row by N column buffer - MLAS_UNREFERENCED_PARAMETER(DimK); - return DimN; -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif - -template -void -MlasHalfGemmKernel( - const size_t CountM, - const size_t CountN, - const size_t CountK, - _mlas_fp16_* C, - size_t ldc, - const _mlas_fp16_* Bias, - const _mlas_fp16_* A, - const size_t lda, - const _mlas_fp16_* B, - const size_t ldb, - const bool ZeroMode -); - - -template -MLAS_FORCEINLINE -void -MlasHalfGemmNoPackOperation( - const size_t N, - const size_t K, - const MLAS_HALF_GEMM_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ) -{ - // - // Optimize for the special case where no packing is needed. - // Simpler tiling as we are not restricted by packing panel size - // - - const size_t lda = Data->lda; - size_t ldb = Data->ldb; // 0 if prepacked - const size_t ldc = Data->ldc; - - const auto* pa = reinterpret_cast(Data->A) - + RangeStartM * lda; - const _mlas_fp16_* pb; - if (ldb == 0) { - pb = MlasHalfGemmPackedBOffset( - reinterpret_cast(Data->B), - N, - K, - RangeStartN, - 0); - ldb = MlasHalfGemmPackedBLeadingDim(N, K); - } else { - pb = reinterpret_cast(Data->B) + RangeStartN; - } - - const _mlas_fp16_* Bias = (nullptr == Data->Bias) - ? 
nullptr - : reinterpret_cast(Data->Bias) + RangeStartN; - _mlas_fp16_* c = reinterpret_cast<_mlas_fp16_*>(Data->C) - + RangeStartM * ldc + RangeStartN; - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - MlasHalfGemmKernel( - RowsRemaining, - RangeCountN, - K, - c, - ldc, - Bias, - pa, - lda, - pb, - ldb, - true); - - size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); - - if (Data->OutputProcessor != nullptr) { - Data->OutputProcessor->Process( - Data->C, - RangeStartM + RangeCountM - RowsRemaining, - RangeStartN, - RowsHandled, - RangeCountN, - Data->ldc); - } - - c += ldc * RowsHandled; - pa += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } -} - - -template -void -MlasHalfGemmOperation( - const size_t N, - const size_t K, - const MLAS_HALF_GEMM_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ) -{ - const size_t lda = Data->lda; - const size_t ldb = Data->ldb; - const size_t ldc = Data->ldc; - - if (!Data->AIsfp32 && (ldb == 0 || (!KernelType::PackNeeded && !Data->BIsfp32))) { - // !Data->AIsfp32 => A is fp16, no packing on the left hand side - // ldb == 0 => B is already packed, no packing on the right hand side - // !KernelType::PackNeeded && !Data->BIsfp32 => B is fp16 and the kernel - // does not require packing - // - // So no packing needed on either A or B, use a simpler driver instead - - MlasHalfGemmNoPackOperation( - N, - K, - Data, - RangeStartM, - RangeCountM, - RangeStartN, - RangeCountN); - return; - } - - const auto* Bias = reinterpret_cast(Data->Bias); - _mlas_fp16_* C = reinterpret_cast<_mlas_fp16_*>(Data->C) - + RangeStartM * ldc + RangeStartN; - - // - // Three dimensional tiling due to limited packing panel size - // - constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; - constexpr size_t packASize = UpAlignSize(Strides.M * Strides.K * FP16_SIZE); - constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * FP16_SIZE); - MlasThreadedBufAlloc(packASize + packBSize); - - uint8_t* p = ThreadedBufHolder.get(); - auto* PanelA = reinterpret_cast<_mlas_fp16_*>(p); - p += packASize; - auto* PanelB = reinterpret_cast<_mlas_fp16_*>(p); - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - for (size_t k = 0; k < K; k += CountK) { - CountK = std::min(K - k, Strides.K); - const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, Strides.N); - - // - // Copy a panel of matrix B to a local packed buffer. 
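Both the no-pack driver above and the packed driver below step through rows the same way: the kernel handles at most KernelType::KernelMaxM rows per call, so the driver loops, advancing the A and C pointers by the rows just handled until the tile is finished. Reduced to the control flow alone (StepRows and RunKernel are illustrative stand-ins):

    #include <algorithm>
    #include <cstddef>
    #include <functional>

    // Call the kernel in chunks of at most KernelMaxM rows; the callback is
    // expected to advance its own A/C pointers by the row count it is given.
    void StepRows(size_t RangeCountM, size_t KernelMaxM,
                  const std::function<void(size_t)>& RunKernel)
    {
        size_t RowsRemaining = RangeCountM;
        while (RowsRemaining > 0) {
            size_t RowsHandled = std::min(RowsRemaining, KernelMaxM);
            RunKernel(RowsHandled);
            RowsRemaining -= RowsHandled;
        }
    }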
- // - size_t ld_pb; - const _mlas_fp16_* pb; - if (ldb == 0) { - // Already packed - pb = MlasHalfGemmPackedBOffset( - reinterpret_cast(Data->B), - N, - K, - RangeStartN + n, - k); - ld_pb = MlasHalfGemmPackedBLeadingDim(N, K); - } else if (Data->BIsfp32) { - // fp32, need conversion and packing - MlasHalfGemmConvertPackB( - PanelB, - reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, - ldb, - CountN, - CountK); - pb = PanelB; - ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); - } else if (KernelType::PackNeeded) { - // fp16, need packing - MlasHalfGemmCopyPackB( - PanelB, - reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, - ldb, - CountN, - CountK); - pb = PanelB; - ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); - } else { - // fp16, and no packing needed - pb = reinterpret_cast(Data->B) + ldb * k + RangeStartN + n; - ld_pb = ldb; - } - - // - // Step through each slice of matrix A along the M dimension. - // - - auto* c = C + n; - const auto* pbias = (nullptr == Bias) ? nullptr : Bias + RangeStartN + n; - size_t CountM; - for (size_t m = 0; m < RangeCountM; m += CountM) { - CountM = std::min(RangeCountM - m, Strides.M); - - // - // Copy a panel of matrix A to a local packed buffer. - // - const _mlas_fp16_* pa; - size_t ld_pa; - if (Data->AIsfp32) { - MlasHalfGemmConvertPackA( - PanelA, - reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k, - lda, - CountM, - CountK); - pa = PanelA; - ld_pa = KernelType::PackedK * PackedCountK; - } else { - pa = reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k; - ld_pa = lda; - } - - size_t RowsRemaining = CountM; - bool ZeroMode = (k == 0); - bool PostProcess = (k + CountK == K); - - while (RowsRemaining > 0) { - MlasHalfGemmKernel( - RowsRemaining, - CountN, - CountK, - c, - ldc, - ZeroMode ? pbias : nullptr, - pa, - ld_pa, - pb, - ld_pb, - ZeroMode); - - size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); - - if (PostProcess && Data->OutputProcessor != nullptr) { - Data->OutputProcessor->Process( - Data->C, - RangeStartM + m + CountM - RowsRemaining, - RangeStartN + n, - RowsHandled, - CountN, - Data->ldc); - } - - c += ldc * RowsHandled; - pa += ld_pa * RowsHandled; - RowsRemaining -= RowsHandled; - } - } - } - } -} - - -// -// dispatch structure. 
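One detail worth spelling out from the operation above: because K is tiled, C is produced in several passes. ZeroMode is true only for the first K slice, so the kernel overwrites C (and applies the bias) exactly once and accumulates on every later slice, and the output processor runs only once the final slice has landed. A scalar model of that accumulation order, with naive row-major matrices standing in for the packed panels:

    #include <cstddef>

    // A is M x K, B is K x N, C is M x N, all row major. The first K slice
    // initializes C from the bias; later slices accumulate into it.
    void AccumulateKSlices(const float* A, const float* B, const float* Bias,
                           float* C, size_t M, size_t N, size_t K, size_t StrideK)
    {
        for (size_t k = 0; k < K; k += StrideK) {
            const size_t CountK = (K - k < StrideK) ? (K - k) : StrideK;
            const bool ZeroMode = (k == 0);
            for (size_t m = 0; m < M; m++) {
                for (size_t n = 0; n < N; n++) {
                    float sum = ZeroMode ? (Bias != nullptr ? Bias[n] : 0.0f) : C[m * N + n];
                    for (size_t kk = 0; kk < CountK; kk++) {
                        sum += A[m * K + k + kk] * B[(k + kk) * N + n];
                    }
                    C[m * N + n] = sum;
                }
            }
            // Any post-processing would run only after the last slice,
            // i.e. when k + CountK == K and C holds the complete result.
        }
    }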
-// - -typedef -void -(MLAS_HALFGEMM_OPERATION)( - const size_t N, - const size_t K, - const MLAS_HALF_GEMM_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ); - - -typedef -void -(MLAS_HALFGEMM_COPYPACKB_ROUTINE)( - _mlas_fp16_* D, - const _mlas_fp16_* B, - size_t ldb, - size_t CountN, - size_t CountK - ); - -typedef -void -(MLAS_HALFGEMM_CONVERTPACKB_ROUTINE)( - _mlas_fp16_* D, - const float* B, - size_t ldb, - size_t CountN, - size_t CountK - ); - -/** - * @brief Hardware dependent dispatch for half precision GEMM -*/ -struct MLAS_HALFGEMM_DISPATCH { - MLAS_HALFGEMM_OPERATION* Operation; /**< HalfGemm driver */ - MLAS_HALFGEMM_COPYPACKB_ROUTINE* CopyPackBRoutine; /**< Pack function for B */ - MLAS_HALFGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ - size_t PackededK; - size_t StrideM; - size_t BufOverRead; -}; - -extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault; - -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) -extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon; -#endif - -MLAS_FORCEINLINE -const MLAS_HALFGEMM_DISPATCH* -MlasHalfGemmGetDispatch() -{ -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - return &MlasHalfGemmDispatchNeon; -#else - return &MlasHalfGemmDispatchDefault; -#endif -} diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp deleted file mode 100644 index d7f5a90b00589..0000000000000 --- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp +++ /dev/null @@ -1,187 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - halfgemm_kernel_neon.cpp - -Abstract: - - This module implements half precision GEMM kernel for neon. - ---*/ - -#include "mlasi.h" -#include "halfgemm.h" - -#include "arm_neon.h" - -// -// Define the prototypes of the NEON routines written in assembly. -// -// N.B. The kernel has not been ported to build with the Windows ARM32 toolset. 
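The dispatch struct declared above is what keeps the shared half-GEMM driver architecture-neutral: MlasHalfGemmBatch fetches a table of function pointers plus a few layout constants and routes every call through it, and MlasHalfGemmGetDispatch picks the NEON or default table at build time. A reduced sketch of that pattern, with hypothetical names:

    #include <cstddef>

    typedef void (GemmOperationFn)(size_t N, size_t K);

    struct GemmDispatch {
        GemmOperationFn* Operation;   // architecture-specific driver
        size_t PackedK;               // packing alignment along K
        size_t BufOverRead;           // bytes the kernel may read past a buffer
    };

    void GenericOperation(size_t /*N*/, size_t /*K*/) { /* portable fallback */ }
    void NeonOperation(size_t /*N*/, size_t /*K*/) { /* ARM64 fast path */ }

    const GemmDispatch* GetDispatch()
    {
    #if defined(__aarch64__)
        static const GemmDispatch NeonDispatch{NeonOperation, 1, 32};
        return &NeonDispatch;
    #else
        static const GemmDispatch DefaultDispatch{GenericOperation, 1, 0};
        return &DefaultDispatch;
    #endif
    }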
-// - -extern "C" { - - size_t - MLASCALL - MlasHalfGemmKernelNeon( - const size_t CountM, - const size_t CountN, - const size_t CountK, - _mlas_fp16_* C, - size_t ldc, - const _mlas_fp16_* Bias, - const _mlas_fp16_* A, - const size_t lda, - const _mlas_fp16_* B, - const size_t ldb, - const bool ZeroMode - ); - -} - - -struct MLAS_HALF_GEMM_KERNEL_NEON { - static constexpr bool PackNeeded = false; - static constexpr size_t KernelMaxM = 6; // max # rows the vectorized kernel can process - static constexpr size_t PackedK = 1; - - static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 512}; -}; - - -MLAS_FORCEINLINE -void -CvtFloat2Half( - _mlas_fp16_* dest, - const float* src, - size_t len -) -{ - while (len >= 4) { - const auto* srcPtr = reinterpret_cast(src); - auto* dstPtr = reinterpret_cast(dest); - *dstPtr = vcvt_f16_f32(*srcPtr); - src += 4; - dest += 4; - len -= 4; - } - - if (0 == len) { - return; - } - - float32x4_t buf; - std::memcpy(&buf, src, len * sizeof(float)); - float16x4_t res = vcvt_f16_f32(buf); - - if ((len & 2) != 0) { - auto wide = vreinterpret_f32_f16(res); - vst1_lane_f32((float32_t*)dest, wide, 0); - res = vreinterpret_f16_f32(vdup_lane_f32(wide, 1)); - dest += 2; - } - if ((len & 1) != 0) { - vst1_lane_u16(dest, vreinterpret_u16_f16(res), 0); - } -} - -/** - * @brief Convert a 2D matrix from float to fp16 -*/ -MLAS_FORCEINLINE -void -CvtFloat2Half2D( - _mlas_fp16_* dest, - const float* src, - size_t stride, - size_t CntRow, - size_t CntCol - ) -{ - if (stride == CntCol) { - const size_t len = CntRow * CntCol; - CvtFloat2Half(dest, src, len); - return; - } - while (CntRow > 0) { - CvtFloat2Half(dest, src, CntCol); - src += stride; - dest += CntCol; - CntRow--; - } -} - -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmConvertPackA( - _mlas_fp16_* D, - const float* A, - size_t lda, - size_t CountM, - size_t CountK -) -{ - CvtFloat2Half2D(D, A, lda, CountM, CountK); -} - -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmConvertPackB( - _mlas_fp16_* D, - const float* B, - size_t ldb, - size_t CountN, - size_t CountK -) -{ - CvtFloat2Half2D(D, B, ldb, CountK, CountN); -} - - -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmKernel( - size_t CountM, - size_t CountN, - size_t CountK, - _mlas_fp16_* C, - size_t ldc, - const _mlas_fp16_* Bias, - const _mlas_fp16_* A, - size_t lda, - const _mlas_fp16_* B, - size_t ldb, - const bool ZeroMode) -{ - MlasHalfGemmKernelNeon( - CountM, - CountN, - CountK, - C, - ldc, - Bias, - A, - lda, - B, - ldb, - ZeroMode); -} - - -const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = { - MlasHalfGemmOperation, - nullptr, - MlasHalfGemmConvertPackB, - MLAS_HALF_GEMM_KERNEL_NEON::PackedK, - MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM, - 32 // kernel may read beyond buffer end by 32 bytes -}; diff --git a/onnxruntime/core/mlas/lib/i386/SgemmKernelAvx.asm b/onnxruntime/core/mlas/lib/i386/SgemmKernelAvx.asm deleted file mode 100644 index cbc88c3c8d40d..0000000000000 --- a/onnxruntime/core/mlas/lib/i386/SgemmKernelAvx.asm +++ /dev/null @@ -1,413 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelAvx.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses AVX instructions. 
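A small point from the conversion helpers above before the assembly starts: CvtFloat2Half2D collapses the 2-D copy into a single 1-D conversion whenever the source stride equals the row width, which is what lets whole contiguous panels convert in one pass. The same shape-collapsing idea in isolation, with a stub in place of the real float-to-fp16 conversion:

    #include <cstddef>
    #include <cstdint>

    // Stand-in for the real vcvt_f16_f32-based conversion, kept trivial so the
    // sketch is self-contained; only the control flow is the point here.
    static uint16_t ToHalfStub(float v) { return static_cast<uint16_t>(v); }

    void Convert1D(uint16_t* dst, const float* src, size_t len)
    {
        for (size_t i = 0; i < len; i++) dst[i] = ToHalfStub(src[i]);
    }

    void Convert2D(uint16_t* dst, const float* src, size_t stride, size_t rows, size_t cols)
    {
        if (stride == cols) {
            Convert1D(dst, src, rows * cols);    // contiguous: one long pass
            return;
        }
        for (size_t r = 0; r < rows; r++) {      // strided: row by row
            Convert1D(dst + r * cols, src + r * stride, cols);
        }
    }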
-; -;-- - - .686 - .xmm - - .xlist -INCLUDE mlasi.inc -INCLUDE SgemmKernelCommon.inc - .list - - ASSUME DS:FLAT,ES:FLAT,SS:NOTHING,FS:NOTHING,GS:NOTHING - - EXTERN _MlasMaskMoveTableAvx:NEAR - -_TEXT SEGMENT DWORD PUBLIC 'CODE' - -; -; Macro Description: -; -; This macro multiplies and accumulates for a 16xN block of the output matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; ebx - Supplies the length in bytes of a row from matrix A. -; -; ecx - Supplies the address into the matrix A data. -; -; edx - Supplies the address into the matrix B data. -; -; ymm4-ymm7 - Supplies the block accumulators. -; - -ComputeBlockAvxBy16 MACRO RowCount, VectorOffset, BroadcastOffset - -IF RowCount EQ 1 - vbroadcastss ymm3,DWORD PTR [ecx+BroadcastOffset] - vmulps ymm1,ymm3,YMMWORD PTR [edx+VectorOffset] - vaddps ymm4,ymm1,ymm4 - vmulps ymm3,ymm3,YMMWORD PTR [edx+VectorOffset+32] - vaddps ymm5,ymm3,ymm5 -ELSE - vmovaps ymm0,YMMWORD PTR [edx+VectorOffset] - vmovaps ymm1,YMMWORD PTR [edx+VectorOffset+32] - vbroadcastss ymm3,DWORD PTR [ecx+BroadcastOffset] - vmulps ymm2,ymm3,ymm0 - vaddps ymm4,ymm2,ymm4 - vmulps ymm2,ymm3,ymm1 - vaddps ymm5,ymm2,ymm5 - vbroadcastss ymm3,DWORD PTR [ecx+ebx+BroadcastOffset] - vmulps ymm2,ymm3,ymm0 - vaddps ymm6,ymm2,ymm6 - vmulps ymm2,ymm3,ymm1 - vaddps ymm7,ymm2,ymm7 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro multiplies and accumulates for a 8xN block of the output matrix. -; -; Arguments: -; -; RowCount - Supplies the number of rows to process. -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. -; -; Implicit Arguments: -; -; ebx - Supplies the length in bytes of a row from matrix A. -; -; ecx - Supplies the address into the matrix A data. -; -; edx - Supplies the address into the matrix B data. -; -; ymm4-ymm7 - Supplies the block accumulators. -; - -ComputeBlockAvxBy8 MACRO RowCount, VectorOffset, BroadcastOffset - -IF RowCount EQ 1 - vbroadcastss ymm3,DWORD PTR [ecx+BroadcastOffset] - vmulps ymm3,ymm3,YMMWORD PTR [edx+VectorOffset] - vaddps ymm5,ymm3,ymm5 -ELSE - vmovaps ymm0,YMMWORD PTR [edx+VectorOffset] - vbroadcastss ymm3,DWORD PTR [ecx+BroadcastOffset] - vmulps ymm3,ymm3,ymm0 - vaddps ymm5,ymm3,ymm5 - vbroadcastss ymm3,DWORD PTR [ecx+ebx+BroadcastOffset] - vmulps ymm3,ymm3,ymm0 - vaddps ymm7,ymm3,ymm7 -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ComputeBlock - Supplies the macro to compute a single block. -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; ebx - Supplies the number of bytes to the next row of matrix A. -; -; ecx - Supplies the address into the matrix A data. -; -; edx - Supplies the address into the matrix B data. -; -; edi - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; ymm4-ymm7 - Supplies the block accumulators. 
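In scalar terms, one expansion of ComputeBlockAvxBy16 above broadcasts a single element of A per row being processed and accumulates it against 16 consecutive packed columns of B, with ymm4/ymm5 holding the first row's partial sums and ymm6/ymm7 the second row's. The same step written out in C for the two-row case:

    #include <cstddef>

    // One k-step of a 16-wide, 2-row output tile: acc[r][j] += a[r][k] * b[j].
    void ComputeBlock16x2(const float* a_row0, const float* a_row1,
                          const float* b_block, float acc0[16], float acc1[16],
                          size_t k)
    {
        for (size_t j = 0; j < 16; j++) {
            acc0[j] += a_row0[k] * b_block[j];   // ymm4/ymm5
            acc1[j] += a_row1[k] * b_block[j];   // ymm6/ymm7
        }
    }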
-; - -ComputeBlockAvxLoop MACRO ComputeBlock, RowCount - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL OutputBlock - - sub edi,4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - ComputeBlock RowCount, 0, 0 - ComputeBlock RowCount, 16*4, 4 - sub edx,-32*4 ; advance matrix B by 32 columns - ComputeBlock RowCount, 0, 8 - ComputeBlock RowCount, 16*4, 12 - sub edx,-32*4 ; advance matrix B by 32 columns - add ecx,4*4 ; advance matrix A by 4 columns - sub edi,4 - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add edi,4 ; correct for over-subtract above - jz OutputBlock - -ComputeBlockBy1Loop: - ComputeBlock RowCount, 0, 0 - add edx,16*4 ; advance matrix B by 16 columns - add ecx,4 ; advance matrix A by 1 column - dec edi - jne ComputeBlockBy1Loop - -OutputBlock: - - ENDM - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A - Supplies the address of matrix A. -; -; B - Supplies the address of matrix B. The matrix data has been packed using -; MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C - Supplies the address of matrix C. -; -; CountK - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; lda - Supplies the first dimension of matrix A. -; -; ldc - Supplies the first dimension of matrix C. -; -; Alpha - Supplies the scalar alpha multiplier (see SGEMM definition). -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - -cPublicProc _MlasGemmFloatKernelAvx,10 - - SgemmKernelEntry - -; -; Process 2 rows of the matrices. 
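As a reading aid for the register-level code that follows, here is a plain scalar restatement of the contract in the routine description above: C = Alpha * A * B, overwritten when ZeroMode is true and accumulated into otherwise. It deliberately treats B as a simple row-major matrix with an assumed leading dimension ldb; the real kernel consumes B in the packed layout produced by MlasSgemmCopyPackB or MlasSgemmTransposePackB.

#include <cstddef>

// Minimal scalar reference, not the optimized kernel.
void GemmFloatReferenceSketch(const float* A, const float* B, float* C,
                              size_t CountK, size_t CountM, size_t CountN,
                              size_t lda, size_t ldb, size_t ldc,
                              float Alpha, bool ZeroMode) {
    for (size_t m = 0; m < CountM; ++m) {
        for (size_t n = 0; n < CountN; ++n) {
            float sum = 0.0f;
            for (size_t k = 0; k < CountK; ++k) {
                sum += A[m * lda + k] * B[k * ldb + n];
            }
            const float scaled = Alpha * sum;
            C[m * ldc + n] = ZeroMode ? scaled : C[m * ldc + n] + scaled;
        }
    }
}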
-; - - cmp SgemmKernelFrame.CountM[esp],2 - jb ProcessCountMLessThan2 - mov BYTE PTR SgemmKernelFrame.CountM[esp],2 - mov eax,SgemmKernelFrame.ldc[esp] - mov ebx,SgemmKernelFrame.lda[esp] - shl eax,2 ; convert ldc to bytes - shl ebx,2 ; convert lda to bytes - cmp ebp,8 - jbe ProcessRemainingCountN2 - -ProcessNextColumnLoop16x2: - mov edi,SgemmKernelFrame.CountK[esp] - mov ecx,SgemmKernelFrame.MatrixA[esp] - vxorps xmm4,xmm4,xmm4 ; clear block accumulators - vxorps xmm5,xmm5,xmm5 - vxorps xmm6,xmm6,xmm6 - vxorps xmm7,xmm7,xmm7 - ComputeBlockAvxLoop ComputeBlockAvxBy16,2 - vbroadcastss ymm2,DWORD PTR SgemmKernelFrame.Alpha[esp] - vmulps ymm4,ymm4,ymm2 ; multiply by alpha - vmulps ymm5,ymm5,ymm2 - vmulps ymm6,ymm6,ymm2 - vmulps ymm7,ymm7,ymm2 - sub ebp,16 - jb OutputMasked16x2Block - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateOutput16x2 - vaddps ymm4,ymm4,YMMWORD PTR [esi] - vaddps ymm5,ymm5,YMMWORD PTR [esi+32] - vaddps ymm6,ymm6,YMMWORD PTR [esi+eax] - vaddps ymm7,ymm7,YMMWORD PTR [esi+eax+32] - -SkipAccumulateOutput16x2: - vmovups YMMWORD PTR [esi],ymm4 - vmovups YMMWORD PTR [esi+32],ymm5 - vmovups YMMWORD PTR [esi+eax],ymm6 - vmovups YMMWORD PTR [esi+eax+32],ymm7 - add esi,16*4 ; advance matrix C by 16 columns - cmp ebp,8 - ja ProcessNextColumnLoop16x2 - test ebp,ebp - jz ExitKernel - -ProcessRemainingCountN2: - mov edi,SgemmKernelFrame.CountK[esp] - mov ecx,SgemmKernelFrame.MatrixA[esp] - vxorps xmm5,xmm5,xmm5 ; clear block accumulators - vxorps xmm7,xmm7,xmm7 - ComputeBlockAvxLoop ComputeBlockAvxBy8,2 - vbroadcastss ymm2,DWORD PTR SgemmKernelFrame.Alpha[esp] - vmulps ymm5,ymm5,ymm2 ; multiply by alpha - vmulps ymm7,ymm7,ymm2 - cmp ebp,8 - jb OutputMasked8x2Block - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateOutput8x2 - vaddps ymm5,ymm5,YMMWORD PTR [esi] - vaddps ymm7,ymm7,YMMWORD PTR [esi+eax] - -SkipAccumulateOutput8x2: - vmovups YMMWORD PTR [esi],ymm5 - vmovups YMMWORD PTR [esi+eax],ymm7 - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - movzx eax,BYTE PTR SgemmKernelFrame.CountM[esp] - vzeroupper - SgemmKernelExit - stdRET _MlasGemmFloatKernelAvx - -OutputMasked16x2Block: - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateMasked16x2Block - vaddps ymm4,ymm4,YMMWORD PTR [esi] - vaddps ymm6,ymm6,YMMWORD PTR [esi+eax] - -SkipAccumulateMasked16x2Block: - vmovups YMMWORD PTR [esi],ymm4 - vmovups YMMWORD PTR [esi+eax],ymm6 - add esi,8*4 ; advance matrix C by 8 columns - add ebp,8 ; correct for over-subtract above - -OutputMasked8x2Block: - neg ebp - vmovdqu ymm0,YMMWORD PTR [_MlasMaskMoveTableAvx+ebp*4+8*4] - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateMasked8x2Block - vmaskmovps ymm4,ymm0,YMMWORD PTR [esi] - vmaskmovps ymm6,ymm0,YMMWORD PTR [esi+eax] - vaddps ymm5,ymm5,ymm4 - vaddps ymm7,ymm7,ymm6 - -SkipAccumulateMasked8x2Block: - vmaskmovps YMMWORD PTR [esi],ymm0,ymm5 - vmaskmovps YMMWORD PTR [esi+eax],ymm0,ymm7 - jmp ExitKernel - -; -; Process 1 row of the matrices. 
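The OutputMasked* paths above handle a trailing group of fewer than eight columns by loading a per-lane mask from a sliding window into a mask table and storing through vmaskmovps. A minimal sketch of the same idea with C intrinsics; the table name and layout here are illustrative, not the actual _MlasMaskMoveTableAvx contents.

#include <immintrin.h>
#include <cstddef>
#include <cstdint>

// A window starting at index (8 - remaining) yields `remaining` leading -1
// lanes; lanes with a zero mask leave memory untouched.
alignas(32) static const int32_t MaskTableSketch[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};

static void StorePartialColumns(float* dst, __m256 values, size_t remaining /* 1..7 */) {
    const __m256i mask = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(&MaskTableSketch[8 - remaining]));
    _mm256_maskstore_ps(dst, mask, values);
}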
-; - -ProcessCountMLessThan2: - mov BYTE PTR SgemmKernelFrame.CountM[esp],1 - mov ebx,SgemmKernelFrame.MatrixA[esp] - vbroadcastss ymm2,DWORD PTR SgemmKernelFrame.Alpha[esp] - cmp ebp,8 - jbe ProcessRemainingCountN1 - -ProcessNextColumnLoop16x1: - mov edi,SgemmKernelFrame.CountK[esp] - mov ecx,ebx ; reload matrix A - vxorps xmm4,xmm4,xmm4 ; clear block accumulators - vxorps xmm5,xmm5,xmm5 - ComputeBlockAvxLoop ComputeBlockAvxBy16,1 - vmulps ymm4,ymm4,ymm2 ; multiply by alpha - vmulps ymm5,ymm5,ymm2 - sub ebp,16 - jb OutputMasked16x1Block - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulate16x1Block - vaddps ymm4,ymm4,YMMWORD PTR [esi] - vaddps ymm5,ymm5,YMMWORD PTR [esi+32] - -SkipAccumulate16x1Block: - vmovups YMMWORD PTR [esi],ymm4 - vmovups YMMWORD PTR [esi+32],ymm5 - add esi,16*4 ; advance matrix C by 16 columns - cmp ebp,8 - ja ProcessNextColumnLoop16x1 - test ebp,ebp - jz ExitKernel - -ProcessRemainingCountN1: - mov edi,SgemmKernelFrame.CountK[esp] - mov ecx,ebx ; reload matrix A - vxorps xmm5,xmm5,xmm5 ; clear block accumulators - ComputeBlockAvxLoop ComputeBlockAvxBy8,1 - vmulps ymm5,ymm5,ymm2 ; multiply by alpha - cmp ebp,8 - jb OutputMasked8x1Block - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulate8x1Block - vaddps ymm5,ymm5,YMMWORD PTR [esi] - -SkipAccumulate8x1Block: - vmovups YMMWORD PTR [esi],ymm5 - jmp ExitKernel - -OutputMasked16x1Block: - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateMasked16x1Block - vaddps ymm4,ymm4,YMMWORD PTR [esi] - -SkipAccumulateMasked16x1Block: - vmovups YMMWORD PTR [esi],ymm4 - add esi,8*4 ; advance matrix C by 8 columns - add ebp,8 ; correct for over-subtract above - -OutputMasked8x1Block: - neg ebp - vmovdqu ymm0,YMMWORD PTR [_MlasMaskMoveTableAvx+ebp*4+8*4] - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateMasked8x1Block - vmaskmovps ymm4,ymm0,YMMWORD PTR [esi] - vaddps ymm5,ymm5,ymm4 - -SkipAccumulateMasked8x1Block: - vmaskmovps YMMWORD PTR [esi],ymm0,ymm5 - jmp ExitKernel - -stdENDP _MlasGemmFloatKernelAvx - -_TEXT ENDS - - END diff --git a/onnxruntime/core/mlas/lib/i386/SgemmKernelCommon.inc b/onnxruntime/core/mlas/lib/i386/SgemmKernelCommon.inc deleted file mode 100644 index 686bd35007b91..0000000000000 --- a/onnxruntime/core/mlas/lib/i386/SgemmKernelCommon.inc +++ /dev/null @@ -1,94 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelCommon.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the single -; precision matrix/matrix multiply operation (SGEMM). -; -;-- - -; -; Stack frame layout for the SGEMM kernels. -; - -SgemmKernelFrame STRUCT - - SavedEdi DWORD ? - SavedEsi DWORD ? - SavedEbx DWORD ? - SavedEbp DWORD ? - ReturnAddress DWORD ? - MatrixA DWORD ? - MatrixB DWORD ? - MatrixC DWORD ? - CountK DWORD ? - CountM DWORD ? - CountN DWORD ? - lda DWORD ? - ldc DWORD ? - Alpha DWORD ? - ZeroMode DWORD ? - -SgemmKernelFrame ENDS - -; -; Macro Description: -; -; This macro implements the common prologue code for the SGEMM kernels. -; -; Arguments: -; -; None. -; -; Return Registers: -; -; ecx - Stores the address of the matrix A data from the stack frame. -; -; edx - Stores the address of the matrix B data from the stack frame. -; -; ebp - Stores the CountN argument from the stack frame. -; -; ebx, esi, edi - Previous values stored on the stack and the registers are -; available as temporaries. 
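The SgemmKernelFrame structure above describes the 32-bit stack as seen inside the kernel, after the prologue has pushed ebp, ebx, esi and edi (see SgemmKernelEntry below). Spelled out as byte offsets from esp, purely as an illustration:

#include <cstdint>

// Offsets from esp inside the kernel body (32-bit, arguments on the stack).
struct SgemmKernelFrameSketch {
    uint32_t SavedEdi;       // +0, pushed last by the prologue
    uint32_t SavedEsi;       // +4
    uint32_t SavedEbx;       // +8
    uint32_t SavedEbp;       // +12
    uint32_t ReturnAddress;  // +16
    uint32_t MatrixA;        // +20, first stack argument
    uint32_t MatrixB;        // +24
    uint32_t MatrixC;        // +28
    uint32_t CountK;         // +32
    uint32_t CountM;         // +36
    uint32_t CountN;         // +40
    uint32_t lda;            // +44
    uint32_t ldc;            // +48
    uint32_t Alpha;          // +52
    uint32_t ZeroMode;       // +56
};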
-; - -SgemmKernelEntry MACRO - - push ebp - push ebx - push esi - push edi - mov edx,SgemmKernelFrame.MatrixB[esp] - mov esi,SgemmKernelFrame.MatrixC[esp] - mov ebp,SgemmKernelFrame.CountN[esp] - -cPublicFpo ((SgemmKernelFrame.ReturnAddress)/4),10 - - ENDM - -; -; Macro Description: -; -; This macro implements the common epilogue code for the SGEMM kernels. -; -; Arguments: -; -; None. -; - -SgemmKernelExit MACRO - - pop edi - pop esi - pop ebx - pop ebp - - ENDM diff --git a/onnxruntime/core/mlas/lib/i386/SgemmKernelSse2.asm b/onnxruntime/core/mlas/lib/i386/SgemmKernelSse2.asm deleted file mode 100644 index 7251beae1e62c..0000000000000 --- a/onnxruntime/core/mlas/lib/i386/SgemmKernelSse2.asm +++ /dev/null @@ -1,388 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; SgemmKernelSse2.asm -; -; Abstract: -; -; This module implements the kernels for the single precision matrix/matrix -; multiply operation (SGEMM). -; -; This implementation uses SSE2 instructions. -; -;-- - - .686 - .xmm - - .xlist -INCLUDE mlasi.inc -INCLUDE SgemmKernelCommon.inc - .list - - ASSUME DS:FLAT,ES:FLAT,SS:NOTHING,FS:NOTHING,GS:NOTHING - -_TEXT SEGMENT DWORD PUBLIC 'CODE' - -; -; Macro Description: -; -; This macro multiplies and accumulates for a Nx1 block of the output matrix. -; -; Arguments: -; -; VectorOffset - Supplies the byte offset from matrix B to fetch elements. -; -; Shuffle - Supplies the shuffle mask to extract the element from matrix A. -; -; Implicit Arguments: -; -; ebx - Supplies the length in bytes of a row from matrix A. -; -; ecx - Supplies the address into the matrix A data. -; -; edx - Supplies the address into the matrix B data. -; -; xmm2 - Supplies up to four elements loaded from matrix A. -; -; xmm4-xmm7 - Supplies the block accumulators. -; - -ComputeBlockSseBy4 MACRO VectorOffset, Shuffle - - pshufd xmm3,xmm1,Shuffle - movaps xmm0,XMMWORD PTR [edx+VectorOffset] - mulps xmm0,xmm3 - addps xmm4,xmm0 - movaps xmm0,XMMWORD PTR [edx+VectorOffset+16] - mulps xmm0,xmm3 - addps xmm5,xmm0 - movaps xmm0,XMMWORD PTR [edx+VectorOffset+32] - mulps xmm0,xmm3 - addps xmm6,xmm0 - movaps xmm0,XMMWORD PTR [edx+VectorOffset+48] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - ENDM - -ComputeBlockSseBy3 MACRO VectorOffset, Shuffle - - pshufd xmm3,xmm1,Shuffle - movaps xmm0,XMMWORD PTR [edx+VectorOffset] - mulps xmm0,xmm3 - addps xmm5,xmm0 - movaps xmm0,XMMWORD PTR [edx+VectorOffset+16] - mulps xmm0,xmm3 - addps xmm6,xmm0 - movaps xmm0,XMMWORD PTR [edx+VectorOffset+32] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - ENDM - -ComputeBlockSseBy2 MACRO VectorOffset, Shuffle - - pshufd xmm3,xmm1,Shuffle - movaps xmm0,XMMWORD PTR [edx+VectorOffset] - mulps xmm0,xmm3 - addps xmm6,xmm0 - movaps xmm0,XMMWORD PTR [edx+VectorOffset+16] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - ENDM - -ComputeBlockSseBy1 MACRO VectorOffset, Shuffle - - pshufd xmm3,xmm1,Shuffle - movaps xmm0,XMMWORD PTR [edx+VectorOffset] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - ENDM - -; -; Macro Description: -; -; This macro generates code to execute the block compute macro multiple -; times and advancing the matrix A and matrix B data pointers. -; -; Arguments: -; -; ComputeBlock - Supplies the macro to compute a single block. -; -; RowCount - Supplies the number of rows to process. -; -; Implicit Arguments: -; -; ebx - Supplies the number of bytes to the next row of matrix A. -; -; ecx - Supplies the address into the matrix A data. 
-; -; edx - Supplies the address into the matrix B data. -; -; edi - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; xmm4-xmm7 - Supplies the block accumulators. -; - -ComputeBlockSseLoop MACRO RowCount - - LOCAL ComputeBlockBy4Loop - LOCAL ProcessRemainingBlocks - LOCAL ComputeBlockBy1Loop - LOCAL OutputBlock - - sub edi,4 - jb ProcessRemainingBlocks - -ComputeBlockBy4Loop: - movups xmm1,XMMWORD PTR [ecx] - ComputeBlockSseBy&RowCount 0,000h - ComputeBlockSseBy&RowCount 16*4,055h - sub edx,-32*4 ; advance matrix B by 32 columns - ComputeBlockSseBy&RowCount 0,0AAh - ComputeBlockSseBy&RowCount 16*4,0FFh - sub edx,-32*4 ; advance matrix B by 32 columns - add ecx,4*4 ; advance matrix A by 4 columns - sub edi,4 - jae ComputeBlockBy4Loop - -ProcessRemainingBlocks: - add edi,4 ; correct for over-subtract above - jz OutputBlock - -ComputeBlockBy1Loop: - movss xmm1,DWORD PTR [ecx] - ComputeBlockSseBy&RowCount 0,000h - add edx,16*4 ; advance matrix B by 16 columns - add ecx,4 ; advance matrix A by 1 column - dec edi - jne ComputeBlockBy1Loop - -OutputBlock: - - ENDM - -;++ -; -; Routine Description: -; -; This routine is an inner kernel to compute matrix multiplication for a -; set of rows. -; -; Arguments: -; -; A - Supplies the address of matrix A. -; -; B - Supplies the address of matrix B. The matrix data has been packed using -; MlasSgemmCopyPackB or MlasSgemmTransposePackB. -; -; C - Supplies the address of matrix C. -; -; CountK - Supplies the number of columns from matrix A and the number of rows -; from matrix B to iterate over. -; -; CountM - Supplies the maximum number of rows that can be processed for -; matrix A and matrix C. The actual number of rows handled for this -; invocation depends on the kernel implementation. -; -; CountN - Supplies the number of columns from matrix B and matrix C to iterate -; over. -; -; lda - Supplies the first dimension of matrix A. -; -; ldc - Supplies the first dimension of matrix C. -; -; Alpha - Supplies the scalar alpha multiplier (see SGEMM definition). -; -; ZeroMode - Supplies true if the output matrix must be zero initialized, -; else false if the output matrix is accumulated into. -; -; Return Value: -; -; Returns the number of rows handled. -; -;-- - -cPublicProc _MlasGemmFloatKernelSse,10 - - SgemmKernelEntry - -; -; Process 1 row of the matrices. 
-; - - mov eax,SgemmKernelFrame.CountK[esp] - mov ebx,SgemmKernelFrame.MatrixA[esp] - cmp ebp,12 - jbe ProcessRemainingCountN - -ProcessNextColumnLoop16x1: - mov edi,eax ; reload CountK - mov ecx,ebx ; reload matrix A - xorps xmm4,xmm4 ; clear block accumulators - xorps xmm5,xmm5 - xorps xmm6,xmm6 - xorps xmm7,xmm7 - ComputeBlockSseLoop 4 - movss xmm2,DWORD PTR SgemmKernelFrame.Alpha[esp] - shufps xmm2,xmm2,0 - mulps xmm4,xmm2 ; multiply by alpha - mulps xmm5,xmm2 - mulps xmm6,xmm2 - mulps xmm7,xmm2 - sub ebp,16 - jb OutputMasked16x1Block - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateOutput16x1 - movups xmm0,XMMWORD PTR [esi] - movups xmm1,XMMWORD PTR [esi+16] - movups xmm2,XMMWORD PTR [esi+32] - movups xmm3,XMMWORD PTR [esi+48] - addps xmm4,xmm0 - addps xmm5,xmm1 - addps xmm6,xmm2 - addps xmm7,xmm3 - -SkipAccumulateOutput16x1: - movups XMMWORD PTR [esi],xmm4 - movups XMMWORD PTR [esi+16],xmm5 - movups XMMWORD PTR [esi+32],xmm6 - movups XMMWORD PTR [esi+48],xmm7 - add esi,16*4 ; advance matrix C by 16 columns - cmp ebp,12 - ja ProcessNextColumnLoop16x1 - test ebp,ebp - jnz ProcessRemainingCountN - -; -; Restore non-volatile registers and return. -; - -ExitKernel: - mov eax,1 ; return 1 row handled - SgemmKernelExit - stdRET _MlasGemmFloatKernelSse - -; -; Process the remaining 1 to 12 columns of the matrices. -; - -ProcessRemainingCountN: - mov edi,eax ; reload CountK - mov ecx,ebx ; reload matrix A - movss xmm4,DWORD PTR SgemmKernelFrame.Alpha[esp] - shufps xmm4,xmm4,0 - xorps xmm5,xmm5 ; clear block accumulators - xorps xmm6,xmm6 - xorps xmm7,xmm7 - cmp ebp,4 - jbe ProcessRemainingCountN4OrLess - cmp ebp,8 - jbe ProcessRemainingCountN8OrLess - -ProcessRemainingCountN12OrLess: - ComputeBlockSseLoop 3 - mulps xmm5,xmm4 ; multiply by alpha - mulps xmm6,xmm4 - mulps xmm7,xmm4 - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateLeadingN12OrLess - movups xmm0,XMMWORD PTR [esi] - movups xmm1,XMMWORD PTR [esi+16] - addps xmm5,xmm0 - addps xmm6,xmm1 - -SkipAccumulateLeadingN12OrLess: - movups XMMWORD PTR [esi],xmm5 - movups XMMWORD PTR [esi+16],xmm6 - add esi,8*4 ; advance matrix C by 8 columns - jmp OutputTrailingBlock - -ProcessRemainingCountN8OrLess: - ComputeBlockSseLoop 2 - mulps xmm6,xmm4 ; multiply by alpha - mulps xmm7,xmm4 - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateLeadingN8OrLess - movups xmm0,XMMWORD PTR [esi] - addps xmm6,xmm0 - -SkipAccumulateLeadingN8OrLess: - movups XMMWORD PTR [esi],xmm6 - add esi,4*4 ; advance matrix C by 4 columns - jmp OutputTrailingBlock - -ProcessRemainingCountN4OrLess: - ComputeBlockSseLoop 1 - mulps xmm7,xmm4 ; multiply by alpha - jmp OutputTrailingBlock - -OutputMasked16x1Block: - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateLeading16x1Block - movups xmm0,XMMWORD PTR [esi] - movups xmm1,XMMWORD PTR [esi+16] - movups xmm2,XMMWORD PTR [esi+32] - addps xmm4,xmm0 - addps xmm5,xmm1 - addps xmm6,xmm2 - -SkipAccumulateLeading16x1Block: - movups XMMWORD PTR [esi],xmm4 - movups XMMWORD PTR [esi+16],xmm5 - movups XMMWORD PTR [esi+32],xmm6 - add esi,12*4 ; advance matrix C by 12 columns - -OutputTrailingBlock: - test ebp,3 - jz OutputTrailingBlock4Elements - test ebp,2 - jz OutputTrailingBlock1Element - -OutputTrailingBlock2Elements: - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateTrailingBlock2Elements - movsd xmm0,MMWORD PTR [esi] - addps xmm7,xmm0 - -SkipAccumulateTrailingBlock2Elements: - movsd MMWORD PTR [esi],xmm7 - test ebp,1 - jz ExitKernel - shufps 
xmm7,xmm7,0AAh ; shuffle third float down - add esi,2*4 ; advance matrix C by 2 columns - -OutputTrailingBlock1Element: - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateTrailingBlock1Element - movss xmm0,DWORD PTR [esi] - addss xmm7,xmm0 - -SkipAccumulateTrailingBlock1Element: - movss DWORD PTR [esi],xmm7 - jmp ExitKernel - -OutputTrailingBlock4Elements: - cmp BYTE PTR SgemmKernelFrame.ZeroMode[esp],0 - jnz SkipAccumulateTrailingBlock4Elements - movups xmm0,XMMWORD PTR [esi] - addps xmm7,xmm0 - -SkipAccumulateTrailingBlock4Elements: - movups XMMWORD PTR [esi],xmm7 - jmp ExitKernel - -stdENDP _MlasGemmFloatKernelSse - -_TEXT ENDS - - END diff --git a/onnxruntime/core/mlas/lib/i386/mlasi.inc b/onnxruntime/core/mlas/lib/i386/mlasi.inc deleted file mode 100644 index 7c8ae537ff291..0000000000000 --- a/onnxruntime/core/mlas/lib/i386/mlasi.inc +++ /dev/null @@ -1,70 +0,0 @@ -;++ -; -; Copyright (c) Microsoft Corporation. All rights reserved. -; -; Licensed under the MIT License. -; -; Module Name: -; -; mlasi.inc -; -; Abstract: -; -; This module contains common kernel macros and structures for the Microsoft -; Machine Learning algebra subprogram library. -; -;-- - - .xlist -INCLUDE callconv.inc - .list - -; -; Macro Description: -; -; This macro conditionally emits the statement if Count is greater than or -; equal to Value. -; -; Arguments: -; -; Count - Supplies the variable used in the comparison. -; -; Value - Supplies the static used in the comparison. -; -; Statement - Supplies the statement to conditionally emit. -; - -EmitIfCountGE MACRO Count, Value, Statement - -IF (Count GE Value) - Statement -ENDIF - - ENDM - -; -; Macro Description: -; -; This macro conditionally emits the statement if Count1 is greater than or -; equal to Value1 and Count2 is greater than or equal to Value2. -; -; Arguments: -; -; Count1 - Supplies the variable used in the comparison. -; -; Value1 - Supplies the static used in the comparison. -; -; Count2 - Supplies the variable used in the comparison. -; -; Value2 - Supplies the static used in the comparison. -; -; Statement - Supplies the statement to conditionally emit. -; - -EmitIfCount2GE MACRO Count1, Value1, Count2, Value2, Statement - -IF (Count1 GE Value1) AND (Count2 GE Value2) - Statement -ENDIF - - ENDM diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx/min_max_elements.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx/min_max_elements.cpp deleted file mode 100644 index 0339f65eacfcf..0000000000000 --- a/onnxruntime/core/mlas/lib/intrinsics/avx/min_max_elements.cpp +++ /dev/null @@ -1,106 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - min_max_elements.cpp - -Abstract: - - This module implements the logic to find min and max elements with AVX instructions. 
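The reduction kernel below keeps several independent vector accumulators to hide instruction latency and only folds them down to scalars at the end. One way such a final horizontal fold can be written is sketched here; the MLAS helpers (MlasReduceMaximumFloat32x4 and friends) may implement it differently.

#include <immintrin.h>

// Fold an 8-lane float max down to a scalar.
static float HorizontalMax256Sketch(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 m = _mm_max_ps(lo, hi);                  // 8 lanes -> 4 lanes
    m = _mm_max_ps(m, _mm_movehl_ps(m, m));         // 4 lanes -> 2 lanes
    m = _mm_max_ss(m, _mm_shuffle_ps(m, m, 0x55));  // 2 lanes -> 1 lane
    return _mm_cvtss_f32(m);
}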
- ---*/ - -#include "mlasi.h" - -void -MLASCALL -MlasReduceMinimumMaximumF32KernelAvx( - const float* Input, - float* Min, - float* Max, - size_t N - ) -{ - float tmp_min = std::numeric_limits::max(); - float tmp_max = std::numeric_limits::lowest(); - - if (N >= 8) { - - __m256 MaximumVector0 = _mm256_set1_ps(tmp_max); - __m256 MinimumVector0 = _mm256_set1_ps(tmp_min); - - if (N >= 32) { - - __m256 MaximumVector1 = MaximumVector0; - __m256 MaximumVector2 = MaximumVector0; - __m256 MaximumVector3 = MaximumVector0; - - __m256 MinimumVector1 = MinimumVector0; - __m256 MinimumVector2 = MinimumVector0; - __m256 MinimumVector3 = MinimumVector0; - - while (N >= 32) { - - __m256 InputVector0 = _mm256_loadu_ps(Input); - __m256 InputVector1 = _mm256_loadu_ps(Input + 8); - __m256 InputVector2 = _mm256_loadu_ps(Input + 16); - __m256 InputVector3 = _mm256_loadu_ps(Input + 24); - - MaximumVector0 = _mm256_max_ps(MaximumVector0, InputVector0); - MaximumVector1 = _mm256_max_ps(MaximumVector1, InputVector1); - MaximumVector2 = _mm256_max_ps(MaximumVector2, InputVector2); - MaximumVector3 = _mm256_max_ps(MaximumVector3, InputVector3); - - MinimumVector0 = _mm256_min_ps(MinimumVector0, InputVector0); - MinimumVector1 = _mm256_min_ps(MinimumVector1, InputVector1); - MinimumVector2 = _mm256_min_ps(MinimumVector2, InputVector2); - MinimumVector3 = _mm256_min_ps(MinimumVector3, InputVector3); - - Input += 32; - N -= 32; - } - - MaximumVector0 = _mm256_max_ps(MaximumVector0, MaximumVector1); - MaximumVector2 = _mm256_max_ps(MaximumVector2, MaximumVector3); - MaximumVector0 = _mm256_max_ps(MaximumVector0, MaximumVector2); - - MinimumVector0 = _mm256_min_ps(MinimumVector0, MinimumVector1); - MinimumVector2 = _mm256_min_ps(MinimumVector2, MinimumVector3); - MinimumVector0 = _mm256_min_ps(MinimumVector0, MinimumVector2); - } - - while (N >= 8) { - - __m256 InputVector0 = _mm256_loadu_ps(Input); - MaximumVector0 = _mm256_max_ps(MaximumVector0, InputVector0); - MinimumVector0 = _mm256_min_ps(MinimumVector0, InputVector0); - - Input += 8; - N -= 8; - } - - __m128 low = _mm256_castps256_ps128(MaximumVector0); - __m128 high = _mm256_extractf128_ps(MaximumVector0, 1); - tmp_max = MlasReduceMaximumFloat32x4(MlasMaximumFloat32x4(low, high)); - - low = _mm256_castps256_ps128(MinimumVector0); - high = _mm256_extractf128_ps(MinimumVector0, 1); - tmp_min = MlasReduceMinimumFloat32x4(MlasMinimumFloat32x4(low, high)); - } - - while (N > 0) { - - tmp_max = std::max(tmp_max, *Input); - tmp_min = std::min(tmp_min, *Input); - - Input += 1; - N -= 1; - } - - *Min = tmp_min; - *Max = tmp_max; -} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx2/qdwconv_avx2.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx2/qdwconv_avx2.cpp deleted file mode 100644 index 4dab50a27a5b9..0000000000000 --- a/onnxruntime/core/mlas/lib/intrinsics/avx2/qdwconv_avx2.cpp +++ /dev/null @@ -1,221 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qdwconv_avx2.cpp - -Abstract: - - This module implements the quantized integer depthwise convolution kernels. - - This implementation uses AVX2 instructions. 
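Per output position and channel, the depthwise kernel below accumulates zero-point-adjusted products over the kernel taps; the AVX2 paths simply process 16 or 8 channels at once. A scalar restatement for one channel (the input/filter signedness handled generically by the template is fixed here for brevity):

#include <cstddef>
#include <cstdint>

// Scalar reference for one output channel at one output position.
int32_t DepthwiseAccumulateSketch(const uint8_t* const* input, int32_t input_zero_point,
                                  const int8_t* filter, int32_t filter_zero_point,
                                  size_t channel, size_t channels, size_t kernel_size) {
    int32_t acc = 0;
    for (size_t k = 0; k < kernel_size; ++k) {
        const int32_t x = int32_t(input[k][channel]) - input_zero_point;
        const int32_t w = int32_t(filter[k * channels + channel]) - filter_zero_point;
        acc += x * w;
    }
    return acc;
}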
- ---*/ - -#include "mlasi.h" - -template -void -MLASCALL -MlasConvDepthwiseKernelAvx2( - const InputType* const* Input, - InputType InputZeroPoint, - const FilterType* Filter, - FilterType FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -{ - const __m256i InputZeroPointVector = _mm256_set1_epi16(InputZeroPoint); - const __m256i FilterZeroPointVector = _mm256_set1_epi16(FilterZeroPoint); - - while (OutputCount > 0) { - - size_t ChannelOffset = 0; - size_t c = Channels; - - while (c >= 16) { - - __m256i Accumulator0 = _mm256_setzero_si256(); - __m256i Accumulator1 = _mm256_setzero_si256(); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - - __m256i InputVector; - if (std::is_signed::value) { - InputVector = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)&Input[k][ChannelOffset])); - } else { - InputVector = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*)&Input[k][ChannelOffset])); - } - - __m256i FilterVector; - - if (std::is_signed::value) { - FilterVector = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)&Filter[ChannelKernelOffset])); - } else { - FilterVector = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*)&Filter[ChannelKernelOffset])); - } - - InputVector = _mm256_sub_epi16(InputVector, InputZeroPointVector); - FilterVector = _mm256_sub_epi16(FilterVector, FilterZeroPointVector); - - // N.B. The original SSE2 implementation used PMULLW/PMULHW in - // order to emulate the SSE 4.1 PMULLD instruction, however this - // implementation ends up being faster for some CPUs than - // extending to 32-bits and using PMULLD. - __m256i MultiplyLowWords = _mm256_mullo_epi16(InputVector, FilterVector); - __m256i MultiplyHighWords = _mm256_mulhi_epi16(InputVector, FilterVector); - __m256i Multiply0 = _mm256_unpacklo_epi16(MultiplyLowWords, MultiplyHighWords); - __m256i Multiply1 = _mm256_unpackhi_epi16(MultiplyLowWords, MultiplyHighWords); - - Accumulator0 = _mm256_add_epi32(Accumulator0, Multiply0); - Accumulator1 = _mm256_add_epi32(Accumulator1, Multiply1); - ChannelKernelOffset += Channels; - } - - // N.B. The above interleaving of the intermediate results leaves - // the accumulators in a swizzled layout, because the interleaving - // is per 128-bit half of the __m256i register. Reorder the results - // now to get the expected sequential order. 
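                // To spell out the idiom described above: full 32-bit products are
                // recovered from 16-bit lanes without PMULLD by pairing the low and
                // high halves of the widening multiply:
                //   lo = _mm256_mullo_epi16(a, b);       // low 16 bits of each product
                //   hi = _mm256_mulhi_epi16(a, b);       // high (signed) 16 bits
                //   p0 = _mm256_unpacklo_epi16(lo, hi);  // full 32-bit products
                //   p1 = _mm256_unpackhi_epi16(lo, hi);
                // Because the unpacks operate on each 128-bit half independently,
                // the products land in the swizzled order that the permutes below
                // restore to sequential order.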
- __m256i Reorder0 = _mm256_permute2x128_si256(Accumulator0, Accumulator1, 0x20); - __m256i Reorder1 = _mm256_permute2x128_si256(Accumulator0, Accumulator1, 0x31); - - _mm256_storeu_si256((__m256i*)&Output[0], Reorder0); - _mm256_storeu_si256((__m256i*)&Output[8], Reorder1); - Output += 16; - - ChannelOffset += 16; - c -= 16; - } - - if (c >= 8) { - - __m128i Accumulator0 = _mm_setzero_si128(); - __m128i Accumulator1 = _mm_setzero_si128(); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i InputVector = _mm_loadl_epi64((const __m128i*)&Input[k][ChannelOffset]); - __m128i FilterVector = _mm_loadl_epi64((const __m128i*)&Filter[ChannelKernelOffset]); - - if (std::is_signed::value) { - InputVector = _mm_cvtepi8_epi16(InputVector); - } else { - InputVector = _mm_cvtepu8_epi16(InputVector); - } - - if (std::is_signed::value) { - FilterVector = _mm_cvtepi8_epi16(FilterVector); - } else { - FilterVector = _mm_cvtepu8_epi16(FilterVector); - } - - InputVector = _mm_sub_epi16(InputVector, _mm256_castsi256_si128(InputZeroPointVector)); - FilterVector = _mm_sub_epi16(FilterVector, _mm256_castsi256_si128(FilterZeroPointVector)); - - __m128i MultiplyLowWords = _mm_mullo_epi16(InputVector, FilterVector); - __m128i MultiplyHighWords = _mm_mulhi_epi16(InputVector, FilterVector); - __m128i Multiply0 = _mm_unpacklo_epi16(MultiplyLowWords, MultiplyHighWords); - __m128i Multiply1 = _mm_unpackhi_epi16(MultiplyLowWords, MultiplyHighWords); - - Accumulator0 = _mm_add_epi32(Accumulator0, Multiply0); - Accumulator1 = _mm_add_epi32(Accumulator1, Multiply1); - ChannelKernelOffset += Channels; - } - - _mm_storeu_si128((__m128i*)&Output[0], Accumulator0); - _mm_storeu_si128((__m128i*)&Output[4], Accumulator1); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } - - while (c > 0) { - - int32_t Accumulator = 0; - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - - int32_t InputValue = int32_t(Input[k][ChannelOffset]) - InputZeroPoint; - int32_t FilterValue = int32_t(Filter[ChannelKernelOffset]) - FilterZeroPoint; - - Accumulator += InputValue * FilterValue; - ChannelKernelOffset += Channels; - } - - *Output++ = Accumulator; - - ChannelOffset += 1; - c -= 1; - } - - Input += KernelSize; - OutputCount -= 1; - } -} - -template -void -MLASCALL -MlasConvDepthwiseKernelAvx2( - const uint8_t* const* Input, - uint8_t InputZeroPoint, - const int8_t* Filter, - int8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernelAvx2( - const uint8_t* const* Input, - uint8_t InputZeroPoint, - const uint8_t* Filter, - uint8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernelAvx2( - const int8_t* const* Input, - int8_t InputZeroPoint, - const int8_t* Filter, - int8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernelAvx2( - const int8_t* const* Input, - int8_t InputZeroPoint, - const uint8_t* Filter, - uint8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp deleted file mode 100644 index 8fbee04659271..0000000000000 --- 
a/onnxruntime/core/mlas/lib/intrinsics/avx2/qladd_avx2.cpp +++ /dev/null @@ -1,251 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qladd_avx2.cpp - -Abstract: - - This module implements routines to quantize linear add using avx2 intrinsics. - - For quantization formula as specified in the ONNX operator documentation is: - - Output = Saturate(RoundToEven(Input / Scale) + ZeroPoint) - ---*/ - -#include "../../qladd.h" - -template -MLAS_FORCEINLINE -static -__m256i -MlasShiftRight24Epi32( - __m256i v - ); - -template <> -MLAS_FORCEINLINE -__m256i -MlasShiftRight24Epi32( - __m256i v - ) -{ - return _mm256_srai_epi32(v, 24); -} - -template <> -MLAS_FORCEINLINE -__m256i -MlasShiftRight24Epi32( - __m256i v - ) -{ - return _mm256_srli_epi32(v, 24); -} - -template -MLAS_FORCEINLINE -static -__m256i -MlasPackS16_256( - __m256i a, - __m256i b - ); - -template <> -MLAS_FORCEINLINE -__m256i -MlasPackS16_256( - __m256i a, - __m256i b - ) -{ - return _mm256_packus_epi16(a, b); -} - -template <> -MLAS_FORCEINLINE -__m256i -MlasPackS16_256( - __m256i a, - __m256i b - ) -{ - return _mm256_packs_epi16(a, b); -} - -MLAS_FORCEINLINE -static -__m256i -MlasLoad32Bytes(const uint8_t* buffer, int64_t N) -{ - if (N >= 32) { - return _mm256_lddqu_si256((const __m256i*)buffer); - } else { - uint8_t dup[32]; - MlasCopyTailBytes(dup, buffer, (size_t)N); - return _mm256_lddqu_si256((const __m256i*)dup); - } -} - -template -static -void -MlasQLinearAddKernelAvx2Helper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const float ScaleRatio_AC = ScaleA / ScaleC; - const float ScaleRatio_BC = ScaleB / ScaleC; - const __m256 VectorScaleRatio_AC = _mm256_set1_ps(ScaleRatio_AC); - const __m256 VectorScaleRatio_BC = _mm256_set1_ps(ScaleRatio_BC); - __m256 VectorFixedPart = _mm256_set1_ps((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); - - if (IsScalarB) { - const auto vb_f32x8 = _mm256_set1_ps((float)(int32_t)*InputB); - VectorFixedPart = _mm256_add_ps(VectorFixedPart, _mm256_mul_ps(vb_f32x8, VectorScaleRatio_BC)); - } - - int64_t n = static_cast(N); - __m256i vc = _mm256_setzero_si256(); - while (n > 0) { - __m256i va_i8x32, vb_i8x32; - va_i8x32 = MlasLoad32Bytes((const uint8_t*)InputA, n); - InputA += 32; - - if (!IsScalarB) { - vb_i8x32 = MlasLoad32Bytes((const uint8_t*)InputB, n); - InputB += 32; - } - - __m256 lolo_f32x8, lohi_f32x8, hilo_f32x8, hihi_f32x8; - if (IsScalarB) { - const auto alo_i16x16 = _mm256_unpacklo_epi8(va_i8x32, va_i8x32); - const auto ahi_i16x16 = _mm256_unpackhi_epi8(va_i8x32, va_i8x32); - lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(alo_i16x16, alo_i16x16))); - lohi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(alo_i16x16, alo_i16x16))); - hilo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(ahi_i16x16, ahi_i16x16))); - hihi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(ahi_i16x16, ahi_i16x16))); - lolo_f32x8 = _mm256_fmadd_ps(lolo_f32x8, VectorScaleRatio_AC, VectorFixedPart); - lohi_f32x8 = _mm256_fmadd_ps(lohi_f32x8, VectorScaleRatio_AC, VectorFixedPart); - hilo_f32x8 = _mm256_fmadd_ps(hilo_f32x8, VectorScaleRatio_AC, VectorFixedPart); - hihi_f32x8 = _mm256_fmadd_ps(hihi_f32x8, VectorScaleRatio_AC, VectorFixedPart); - } else { - 
const auto blo_i16x16 = _mm256_unpacklo_epi8(vb_i8x32, vb_i8x32); - const auto bhi_i16x16 = _mm256_unpackhi_epi8(vb_i8x32, vb_i8x32); - lolo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(blo_i16x16, blo_i16x16))); - lohi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(blo_i16x16, blo_i16x16))); - hilo_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(bhi_i16x16, bhi_i16x16))); - hihi_f32x8 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(bhi_i16x16, bhi_i16x16))); - lolo_f32x8 = _mm256_fmadd_ps(lolo_f32x8, VectorScaleRatio_BC, VectorFixedPart); - lohi_f32x8 = _mm256_fmadd_ps(lohi_f32x8, VectorScaleRatio_BC, VectorFixedPart); - hilo_f32x8 = _mm256_fmadd_ps(hilo_f32x8, VectorScaleRatio_BC, VectorFixedPart); - hihi_f32x8 = _mm256_fmadd_ps(hihi_f32x8, VectorScaleRatio_BC, VectorFixedPart); - - const auto alo_i16x16 = _mm256_unpacklo_epi8(va_i8x32, va_i8x32); - const auto alolo_8xfp32 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(alo_i16x16, alo_i16x16))); - const auto alohi_8xfp32 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(alo_i16x16, alo_i16x16))); - const auto ahi_i16x16 = _mm256_unpackhi_epi8(va_i8x32, va_i8x32); - const auto ahilo_8xfp32 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpacklo_epi16(ahi_i16x16, ahi_i16x16))); - const auto ahihi_8xfp32 = _mm256_cvtepi32_ps(MlasShiftRight24Epi32(_mm256_unpackhi_epi16(ahi_i16x16, ahi_i16x16))); - lolo_f32x8 = _mm256_fmadd_ps(alolo_8xfp32, VectorScaleRatio_AC, lolo_f32x8); - lohi_f32x8 = _mm256_fmadd_ps(alohi_8xfp32, VectorScaleRatio_AC, lohi_f32x8); - hilo_f32x8 = _mm256_fmadd_ps(ahilo_8xfp32, VectorScaleRatio_AC, hilo_f32x8); - hihi_f32x8 = _mm256_fmadd_ps(ahihi_8xfp32, VectorScaleRatio_AC, hihi_f32x8); - } - - const auto vc02 = _mm256_packs_epi32(_mm256_cvtps_epi32(lolo_f32x8), _mm256_cvtps_epi32(lohi_f32x8)); - const auto vc13 = _mm256_packs_epi32(_mm256_cvtps_epi32(hilo_f32x8), _mm256_cvtps_epi32(hihi_f32x8)); - vc = MlasPackS16_256(vc02, vc13); - - n -= 32; - if (n < 0) break; - - _mm256_storeu_si256((__m256i*)OutputC, vc); - OutputC += 32; - } - - if (n < 0) { - n += 32; - int k = static_cast(n / 4); - if (k > 0) { - const __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(k), _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0)); - _mm256_maskstore_epi32((int*)OutputC, mask, vc); - OutputC += static_cast(k) * 4; - } - - int r = static_cast(n % 4); - if (r > 0) { - auto permuted = _mm256_permutevar8x32_epi32(vc, _mm256_set1_epi32(k)); - uint32_t PackedValueC = (uint32_t)_mm256_extract_epi32(permuted, 0); - for (int i = 0; i < r; ++i) { - *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; - PackedValueC >>= 8; - } - } - } -} - -void -MLASCALL -MlasQLinearAddS8KernelAvx2( - const int8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const int8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - int8_t* OutputC, - size_t N, - bool IsScalarB - ) -{ - if (IsScalarB) { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } else { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } -} - -void -MLASCALL -MlasQLinearAddU8KernelAvx2( - const uint8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const uint8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - uint8_t* OutputC, - size_t N, - bool IsScalarB - ) 
-{ - if (IsScalarB) { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } else { - MlasQLinearAddKernelAvx2Helper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } -} diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp deleted file mode 100644 index 47d2dce45cc57..0000000000000 --- a/onnxruntime/core/mlas/lib/intrinsics/avx512/quantize_avx512f.cpp +++ /dev/null @@ -1,170 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - quantize_avx512f.cpp - -Abstract: - - This module implements routines to quantize buffers with AVX512F instructions. - - For quantization formula as specified in the ONNX operator documentation is: - - Output = Saturate(RoundToEven(Input / Scale) + ZeroPoint) - ---*/ - -#include "mlasi.h" - -#ifndef _MM_K0_REG16 -#define _MM_K0_REG16 0xffff -#endif - -// -// QuantizeLinear implementation using AVX512 intrinsics. -// - -template -void -MLASCALL -MlasQuantizeLinearAvx512F( - const float* Input, - OutputType* Output, - size_t N, - float Scale, - OutputType ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer using the supplied quantization - parameters with AVX512 instructions. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ -{ - constexpr int32_t MinimumValue = std::numeric_limits::min(); - constexpr int32_t MaximumValue = std::numeric_limits::max(); - - auto ScaleVector = _mm512_set1_ps(Scale); - auto MinimumValueVector = _mm512_set1_ps(float(MinimumValue - ZeroPoint)); - auto MaximumValueVector = _mm512_set1_ps(float(MaximumValue - ZeroPoint)); - auto ZeroPointVector = _mm512_set1_epi32(ZeroPoint); - - while (N >= 64) { - - auto FloatVector0 = _mm512_loadu_ps(Input); - auto FloatVector1 = _mm512_loadu_ps(Input + 16); - auto FloatVector2 = _mm512_loadu_ps(Input + 32); - auto FloatVector3 = _mm512_loadu_ps(Input + 48); - - FloatVector0 = _mm512_div_ps(FloatVector0, ScaleVector); - FloatVector1 = _mm512_div_ps(FloatVector1, ScaleVector); - FloatVector2 = _mm512_div_ps(FloatVector2, ScaleVector); - FloatVector3 = _mm512_div_ps(FloatVector3, ScaleVector); - - FloatVector0 = _mm512_max_ps(FloatVector0, MinimumValueVector); - FloatVector1 = _mm512_max_ps(FloatVector1, MinimumValueVector); - FloatVector2 = _mm512_max_ps(FloatVector2, MinimumValueVector); - FloatVector3 = _mm512_max_ps(FloatVector3, MinimumValueVector); - - FloatVector0 = _mm512_min_ps(FloatVector0, MaximumValueVector); - FloatVector1 = _mm512_min_ps(FloatVector1, MaximumValueVector); - FloatVector2 = _mm512_min_ps(FloatVector2, MaximumValueVector); - FloatVector3 = _mm512_min_ps(FloatVector3, MaximumValueVector); - - auto IntegerVector0 = _mm512_cvtps_epi32(FloatVector0); - auto IntegerVector1 = _mm512_cvtps_epi32(FloatVector1); - auto IntegerVector2 = _mm512_cvtps_epi32(FloatVector2); - auto IntegerVector3 = _mm512_cvtps_epi32(FloatVector3); - - IntegerVector0 = _mm512_add_epi32(IntegerVector0, ZeroPointVector); - IntegerVector1 = _mm512_add_epi32(IntegerVector1, ZeroPointVector); - IntegerVector2 = _mm512_add_epi32(IntegerVector2, ZeroPointVector); - 
IntegerVector3 = _mm512_add_epi32(IntegerVector3, ZeroPointVector); - - _mm512_mask_cvtepi32_storeu_epi8(Output, _MM_K0_REG16, IntegerVector0); - _mm512_mask_cvtepi32_storeu_epi8(Output + 16, _MM_K0_REG16, IntegerVector1); - _mm512_mask_cvtepi32_storeu_epi8(Output + 32, _MM_K0_REG16, IntegerVector2); - _mm512_mask_cvtepi32_storeu_epi8(Output + 48, _MM_K0_REG16, IntegerVector3); - - Input += 64; - Output += 64; - N -= 64; - } - - while (N >= 16) { - auto FloatVector = _mm512_loadu_ps(Input); - FloatVector = _mm512_div_ps(FloatVector, ScaleVector); - FloatVector = _mm512_max_ps(FloatVector, MinimumValueVector); - FloatVector = _mm512_min_ps(FloatVector, MaximumValueVector); - - auto IntegerVector = _mm512_cvtps_epi32(FloatVector); - IntegerVector = _mm512_add_epi32(IntegerVector, ZeroPointVector); - - _mm512_mask_cvtepi32_storeu_epi8(Output, _MM_K0_REG16, IntegerVector); - - Input += 16; - Output += 16; - N -= 16; - } - - if (N > 0) { - __mmask16 mask = uint16_t((uint32_t(1) << N) - uint32_t(1)); - auto FloatVector = _mm512_maskz_loadu_ps(mask, Input); - FloatVector = _mm512_div_ps(FloatVector, ScaleVector); - FloatVector = _mm512_max_ps(FloatVector, MinimumValueVector); - FloatVector = _mm512_min_ps(FloatVector, MaximumValueVector); - - auto IntegerVector = _mm512_cvtps_epi32(FloatVector); - IntegerVector = _mm512_add_epi32(IntegerVector, ZeroPointVector); - - _mm512_mask_cvtepi32_storeu_epi8(Output, mask, IntegerVector); - } -} - -void -MLASCALL -MlasQuantizeLinearU8KernelAvx512F( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - MlasQuantizeLinearAvx512F(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS8KernelAvx512F( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearAvx512F(Input, Output, N, Scale, ZeroPoint); -} diff --git a/onnxruntime/core/mlas/lib/logistic.cpp b/onnxruntime/core/mlas/lib/logistic.cpp deleted file mode 100644 index ecca39f974155..0000000000000 --- a/onnxruntime/core/mlas/lib/logistic.cpp +++ /dev/null @@ -1,186 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - logistic.cpp - -Abstract: - - This module implements routines to compute the logistic function. - - This implementation uses the same polynomial coefficients and algorithm as - found in Eigen. Our usage requires building platform specific versions of - the algorithm to target different instruction sets. The implementation below - targets the base instruction set (typically SSE2) while assembly - implementations target newer instruction sets (such as FMA3). - ---*/ - -#include "mlasi.h" - -// -// Bundles the floating point constants for use by kernels written in assembly. 
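Looping back to the AVX512 quantize kernel above: per element it implements the formula quoted in that file's abstract, Output = Saturate(RoundToEven(Input / Scale) + ZeroPoint), clamping in the float domain before rounding so that the final zero-point add cannot overflow. A scalar sketch, assuming the default round-to-nearest-even floating-point environment:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

template <typename OutputType>
OutputType QuantizeLinearScalarSketch(float input, float scale, OutputType zero_point) {
    float v = input / scale;
    // Clamp to the representable range shifted by the zero point, mirroring
    // the MinimumValueVector/MaximumValueVector clamp in the kernel above.
    const float lo = float(std::numeric_limits<OutputType>::min()) - float(zero_point);
    const float hi = float(std::numeric_limits<OutputType>::max()) - float(zero_point);
    v = std::min(std::max(v, lo), hi);
    // std::nearbyint rounds to nearest-even under the default FE_TONEAREST mode.
    const int32_t q = int32_t(std::nearbyint(v)) + int32_t(zero_point);
    return OutputType(q);
}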
-// - -MLAS_INTERNAL_DATA const struct { - float LowerRange; - float UpperRange; - float alpha_9; - float alpha_7; - float alpha_5; - float alpha_3; - float alpha_1; - float beta_10; - float beta_8; - float beta_6; - float beta_4; - float beta_2; - float beta_0; - float one_half; -} MlasLogisticConstants = { - -18.0f, - 18.0f, - 4.37031012579801e-11f, - 1.15627324459942e-07f, - 6.08574864600143e-05f, - 8.51377133304701e-03f, - 2.48287947061529e-01f, - 6.10247389755681e-13f, - 5.76102136993427e-09f, - 6.29106785017040e-06f, - 1.70198817374094e-03f, - 1.16817656904453e-01f, - 9.93151921023180e-01f, - 0.5f, -}; - -void -MLASCALL -MlasLogisticKernel( - const float* Input, - float* Output, - size_t N - ) -/*++ - -Routine Description: - - This routine implements the generic kernel for the logistic function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ - while (N >= 4) { - - MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input); - - Value = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasLogisticConstants.LowerRange), Value); - Value = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasLogisticConstants.UpperRange), Value); - - MLAS_FLOAT32X4 ValueSquared = MlasMultiplyFloat32x4(Value, Value); - - MLAS_FLOAT32X4 p; - p = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.alpha_9), - MlasBroadcastFloat32x4(MlasLogisticConstants.alpha_7)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.alpha_5)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.alpha_3)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.alpha_1)); - p = MlasMultiplyFloat32x4(p, Value); - - MLAS_FLOAT32X4 q; - q = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_10), - MlasBroadcastFloat32x4(MlasLogisticConstants.beta_8)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_6)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_4)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_2)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_0)); - - MlasStoreFloat32x4(Output, MlasAddFloat32x4(MlasDivideFloat32x4(p, q), MlasBroadcastFloat32x4(0.5f))); - - Input += 4; - Output += 4; - N -= 4; - } - - while (N > 0) { - - float Value = *Input++; - - // This odd two-step process exists to ensure an input value of NaN carries through - // without modification because "std::min" and "std::max" return unreliable results - // when NaNs are involved, and it's clear from the test's reference outputs that - // they want a NaN on output whenever the input is a NaN. - float v_tmp; - v_tmp = (Value < MlasLogisticConstants.LowerRange) ? MlasLogisticConstants.LowerRange : Value; - Value = (v_tmp > MlasLogisticConstants.UpperRange) ? 
MlasLogisticConstants.UpperRange : v_tmp; - - float ValueSquared = Value * Value; - - float p; - p = ValueSquared * MlasLogisticConstants.alpha_9 + MlasLogisticConstants.alpha_7; - p = p * ValueSquared + MlasLogisticConstants.alpha_5; - p = p * ValueSquared + MlasLogisticConstants.alpha_3; - p = p * ValueSquared + MlasLogisticConstants.alpha_1; - p = p * Value; - - float q; - q = ValueSquared * MlasLogisticConstants.beta_10 + MlasLogisticConstants.beta_8; - q = q * ValueSquared + MlasLogisticConstants.beta_6; - q = q * ValueSquared + MlasLogisticConstants.beta_4; - q = q * ValueSquared + MlasLogisticConstants.beta_2; - q = q * ValueSquared + MlasLogisticConstants.beta_0; - - *Output++ = (p / q) + 0.5f; - - N -= 1; - } -} - -void -MLASCALL -MlasComputeLogistic( - const float* Input, - float* Output, - size_t N - ) -/*++ - -Routine Description: - - This routine computes the logistic function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().LogisticKernelRoutine(Input, Output, N); -#else - MlasLogisticKernel(Input, Output, N); -#endif -} diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h deleted file mode 100644 index 8d812baabdf9d..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelCommon.h +++ /dev/null @@ -1,27 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the double - precision matrix/matrix multiply operation (DGEMM). - ---*/ - -#define LFgemmElementShift 3 -#define LFgemmElementSize (1 << LFgemmElementShift) -#define LFgemmYmmElementCount (32/LFgemmElementSize) - -#include "FgemmKernelCommon.h" - -FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.d) -FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.d) -FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.d) -FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.d) diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S deleted file mode 100644 index 2f197d6891579..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLasx.S +++ /dev/null @@ -1,32 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelLasx.s - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - - This implementation uses Lasx instructions. - ---*/ - -#include "asmmacro.h" -#include "DgemmKernelCommon.h" -#include "FgemmKernelLasxCommon.h" - - .text - -// -// Generate the GEMM kernel. -// - -FgemmKernelLasxFunction MlasGemmDoubleKernelLasx - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S deleted file mode 100644 index 63395631a9bc5..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/DgemmKernelLsx.S +++ /dev/null @@ -1,217 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. 
- -Module Name: - - DgemmKernelLsx.s - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - - This implementation uses Lsx instructions. - ---*/ - -#include "asmmacro.h" -#include "FgemmKernelLsxCommon.h" - -FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.d) -/*++ - -Macro Description: - - This macro multiplies and accumulates for a 8xN block of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - a1 (rsi) - Supplies the address into the matrix B data. - - vr0-vr1 - Supplies up to two elements loaded from matrix A and matrix A - plus one row. - - vr8-vr15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockSseBy8 RowCount - - vld $vr4, $a1, 0 - vld $vr5, $a1, 16 -.if \RowCount\() == 2 - vmove $vr6, $vr4 - vmove $vr7, $vr5 -.endif - vfmadd.d $vr8, $vr4, $vr0, $vr8 - vfmadd.d $vr9, $vr5, $vr0, $vr9 -.if \RowCount\() == 2 - vfmadd.d $vr12, $vr6, $vr1, $vr12 - vfmadd.d $vr13, $vr7, $vr1, $vr13 -.endif - vld $vr4, $a1, 32 - vld $vr5, $a1, 48 -.if \RowCount\() == 2 - vmove $vr6, $vr4 - vmove $vr7, $vr5 -.endif - vfmadd.d $vr10, $vr4, $vr0, $vr10 - vfmadd.d $vr11, $vr5, $vr0, $vr11 -.if \RowCount\() == 2 - vfmadd.d $vr14, $vr6, $vr1, $vr14 - vfmadd.d $vr15, $vr7, $vr1, $vr15 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - a0 - Supplies the address of matrix A. - - a1 - Supplies the address of matrix B. - - t8 - Supplies the address of matrix A. - - a5 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - a2 - Supplies the address of matrix C. - - a3 - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - t7 - Supplies the length in bytes of a row from matrix A. - - t5 - Supplies the length in bytes of a row from matrix C. - - s3 - Stores the ZeroMode argument from the stack frame. 
- ---*/ - - .macro ProcessCountM RowCount, Fallthrough -.LProcessNextColumnLoop8xN\@: - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8,$vr8,$vr8" - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9,$vr9,$vr9" - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10,$vr10,$vr10" - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11,$vr11,$vr11" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12,$vr12,$vr12" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13,$vr13,$vr13" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14,$vr14,$vr14" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15,$vr15,$vr15" - move $t7,$a3 # reload CountK -.LCompute8xNBlockBy1Loop\@: - EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a0, 0" - EmitIfCountGE \RowCount\(), 1, "vreplgr2vr.d $vr0, $s0" - EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a0, $t0" - EmitIfCountGE \RowCount\(), 2, "vreplgr2vr.d $vr1, $s0" - ComputeBlockSseBy8 \RowCount\() - addi.d $a1, $a1, 8*8 # advance matrix B by 8 columns - addi.d $a0, $a0, 8 # advance matrix A by 1 column - addi.d $t7, $t7, -1 - bnez $t7, .LCompute8xNBlockBy1Loop\@ - -.LOutput8xNBlock\@: - movfr2gr.d $s0, $f24 - vreplgr2vr.d $vr2, $s0 - # multiply by alpha - EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr8, $vr8, $vr2" - EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr9, $vr9, $vr2" - EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr10,$vr10, $vr2" - EmitIfCountGE \RowCount\(), 1, "vfmul.d $vr11,$vr11, $vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr12,$vr12, $vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr13,$vr13, $vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr14,$vr14, $vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.d $vr15,$vr15, $vr2" - li.d $s0, 8 - blt $a5, $s0, .LOutputPartial8xNBlock\@ - sub.d $a5, $a5, $s0 - AccumulateAndStoreBlock \RowCount\(), 4 - addi.d $a2, $a2, 8*8 # advance matrix C by 8 columns - move $a0, $t1 # reload matrix A - bnez $a5, .LProcessNextColumnLoop8xN\@ - b .LExitKernel - -// -// Output a partial 8xN block to the matrix. -// - -.LOutputPartial8xNBlock\@: - li.d $s0, 2 - blt $a5, $s0, .LOutputPartial1xNBlock\@ - li.d $s0, 4 - blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ - li.d $s0, 6 - blt $a5, $s0, .LOutputPartialLessThan6xNBlock\@ - AccumulateAndStoreBlock \RowCount\(), 3 - andi $s0, $a5, 1 # check if remaining count is small - beqz $s0, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr11" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr15" - addi.d $a2, $a2, 6*8 # advance matrix C by 6 columns - b .LOutputPartial1xNBlock\@ - -.LOutputPartialLessThan6xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 2 - andi $s0, $a5,1 # check if remaining count is small - beqz $s0, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr10" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr14" - addi.d $a2, $a2, 4*8 # advance matrix C by 4 columns - b .LOutputPartial1xNBlock\@ - -.LOutputPartialLessThan4xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 1 - andi $s0, $a5,1 # check if remaining count is small - beqz $s0, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vmove $vr8,$vr9" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "vmove $vr12,$vr13" - addi.d $a2, $a2, 2*8 # advance matrix C by 2 columns - -.LOutputPartial1xNBlock\@: - bnez $t5, .LSkipAccumulateOutput1xN\@ # ZeroMode? 
- - EmitIfCountGE \RowCount\(), 1, "fld.d $f15, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "fadd.d $f15, $f15, $f8" - EmitIfCountGE \RowCount\(), 2, "fldx.d $f16, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "fadd.d $f16, $f16, $f12" - -.LSkipAccumulateOutput1xN\@: - EmitIfCountGE \RowCount\(), 1, "fst.d $f15, $a2, 0" - EmitIfCountGE \RowCount\(), 2, "fstx.d $f16, $a2, $t6" -.ifb \Fallthrough\() - b .LExitKernel -.endif - - .endm - -// -// Generate the GEMM kernel. -// - -FgemmKernelLsxFunction MlasGemmDoubleKernelLSX - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h deleted file mode 100644 index 777a592590ec4..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelCommon.h +++ /dev/null @@ -1,100 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the floating - point matrix/matrix multiply operation (SGEMM and DGEMM). - ---*/ - -// -// Define the typed instruction template. -// - -#define FGEMM_TYPED_INSTRUCTION(Untyped, Typed) \ - .macro Untyped Operand:vararg; Typed \Operand\(); .endm; - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - - AdvanceMatrixAPlusRows - Supplies a non-zero value if the data pointer - in rbx should also be advanced as part of the loop. - -Implicit Arguments: - - a0 - Supplies the address into the matrix A data. - - t7 - Supplies the address into the matrix A data plus 3 rows. - - a1 - Supplies the address into the matrix B data. - - a3 - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - vr4-vr15 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockLoop ComputeBlock, RowCount, AdvanceMatrixAPlusRows - - move $t8, $a3 # reload CountK - li.d $s0, 4 - blt $t8, $s0, .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*0, 64*4 - \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*1, 64*4 - addi.d $a1, $a1, 2*2*32 # advance matrix B by 128 bytes - \ComputeBlock\() \RowCount\(), 0, LFgemmElementSize*2, 64*4 - \ComputeBlock\() \RowCount\(), 2*32, LFgemmElementSize*3, 64*4 - addi.d $a1, $a1, 2*2*32 # advance matrix B by 128 bytes - addi.d $a0, $a0, 4*LFgemmElementSize # advance matrix A by 4 elements -.if \RowCount\() > 3 - addi.d $t7, $t7, 4*LFgemmElementSize # advance matrix A plus rows by 4 elements -.if \RowCount\() == 12 - addi.d $t3, $t3, 4*LFgemmElementSize - addi.d $t4,, $t4, 4*LFgemmElementSize -.endif -.endif - addi.d $t8, $t8, -4 - li.d $s0, 4 - bge $t8, $s0, .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - beqz $t8, .LOutputBlock\@ - -.LComputeBlockBy1Loop\@: - \ComputeBlock\() \RowCount\(), 0, 0 - addi.d $a1, $a1, 2*32 # advance matrix B by 64 bytes - addi.d $a0, $a0, LFgemmElementSize # advance matrix A by 1 element -.if \RowCount\() > 3 - addi.d $t7, $t7, LFgemmElementSize # advance matrix A plus rows by 1 element -.if \RowCount\() == 12 - addi.d $t3, $t3, LFgemmElementSize - addi.d $t4, $t4, LFgemmElementSize -.endif -.endif - addi.d $t8, $t8, -1 - bnez $t8, .LComputeBlockBy1Loop\@ - -.LOutputBlock\@: - - .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h deleted file mode 100644 index b96db848617bf..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLasxCommon.h +++ /dev/null @@ -1,546 +0,0 @@ - -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelLasxCommon.h - -Abstract: - - This module implements the kernels for the floating point matrix/matrix - multiply operation (SGEMM and DGEMM). - - This implementation uses LASX instructions. - ---*/ - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - a0 - Supplies the address into the matrix A data. - - t7 - Supplies the address into the matrix A data plus 2 rows. - - a1 - Supplies the address into the matrix B data. - - t0 - Supplies the length in bytes of a row from matrix A. - - xr8-xr15 - Supplies the block accumulators. 
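The ComputeBlockLasxBy16 macro described above performs one step of that reduction for up to four rows: it broadcasts a single element of A per row and multiply-accumulates it against two 8-float LASX vectors (16 packed columns) of B, keeping the partial sums in xr8-xr15. A scalar sketch of one such step, with illustrative names and the register block flattened to a small array:

    #include <stddef.h>

    /* One scalar step equivalent to ComputeBlockLasxBy16: broadcast a[r]
     * (xvldrepl) and multiply-accumulate it against 16 packed columns of B
     * (two xvfmadd per row in the assembly).  lda is in elements. */
    static void compute_block_lasx_by16(const float *a, size_t lda,
                                        const float *b, size_t rows,
                                        float acc[][16])
    {
        for (size_t r = 0; r < rows; ++r) {
            float a_bcast = a[r * lda];        /* one element of row r */
            for (size_t j = 0; j < 16; ++j)
                acc[r][j] += a_bcast * b[j];
        }
    }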
- ---*/ - - .macro ComputeBlockLasxBy16 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.if \RowCount\() == 1 - xvldrepl.w $xr3, $a0, \BroadcastOffset\() - xvld $xr4, $a1, \VectorOffset\() - xvfmadd $xr8, $xr4, $xr3, $xr8 - xvld $xr5, $a1, \VectorOffset\()+32 - xvfmadd $xr9, $xr5, $xr3, $xr9 -.else - xvld $xr0, $a1, \VectorOffset\() - xvld $xr1, $a1, \VectorOffset\()+32 - EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3,$a0, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr8, $xr3, $xr0, $xr8" - EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr1, $xr9" - EmitIfCountGE \RowCount\(), 2, "add.d $s0,$a0, $t0" - EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3,$s0, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr10, $xr3, $xr0, $xr10" - EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr1, $xr11" - - EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3,$t7, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr12, $xr3, $xr0, $xr12" - EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr1, $xr13" - EmitIfCountGE \RowCount\(), 4, "add.d $s0,$t7, $t0" - EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3,$s0, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr14, $xr3, $xr0, $xr14" - EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr1, $xr15" -.endif - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 1 YMMWORD by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - a0 - Supplies the address into the matrix A data. - - t7 - Supplies the address into the matrix A data plus 2 rows. - - a1 - Supplies the address into the matrix B data. - - t0 - Supplies the length in bytes of a row from matrix A. - - xr8-xr15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLasxBy8 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.if \RowCount\() == 1 - xvldrepl.w $xr3, $a0, \BroadcastOffset\() - xvld $xr5, $a1, \VectorOffset\() - xvfmadd.s $xr9, $xr5, $xr3, $xr9 -.else - xvld $xr0, $a1, \VectorOffset\() - EmitIfCountGE \RowCount\(), 1, "xvldrepl $xr3, $a0, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 1, "xvfmadd $xr9, $xr3, $xr0, $xr9" - - EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a0, $t0" - EmitIfCountGE \RowCount\(), 2, "xvldrepl $xr3, $s0, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 2, "xvfmadd $xr11, $xr3, $xr0, $xr11" - EmitIfCountGE \RowCount\(), 3, "xvldrepl $xr3, $t7, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 3, "xvfmadd $xr13, $xr3, $xr0, $xr13" - EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t0" - EmitIfCountGE \RowCount\(), 4, "xvldrepl $xr3, $s0, \BroadcastOffset\()" - EmitIfCountGE \RowCount\(), 4, "xvfmadd $xr15, $xr3, $xr0, $xr15" -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - a0 - Supplies the address into the matrix A data. - - a1 - Supplies the address into the matrix B data. 
- - a3 - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - t0 - Supplies the length in bytes of a row from matrix A. - - vr4-vr15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLasxLoop ComputeBlock, RowCount - -.if \RowCount\() > 2 - # compute matrix A plus 2 rows - slli.d $s0, $t0, 1 - add.d $t7, $a0, $s0 -.endif - ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 2 -.if \RowCount\() > 2 - # compute matrix C plus 2 rows - slli.d $s0, $t6, 1 - add.d $t7, $a2, $s0 -.endif - - .endm - - .macro store_n src, num, dst - move $s2, \num\() - beqz $s2, .Lstore_exit\@ - xvstelm.w \src\(), \dst\(), 0, 0 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - - xvstelm.w \src\(), \dst\(), 4, 1 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - - xvstelm.w \src\(), \dst\(), 8, 2 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - - xvstelm.w \src\(), \dst\(), 12, 3 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - - xvstelm.w \src\(), \dst\(), 16, 4 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - - xvstelm.w \src\(), \dst\(), 20, 5 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - - xvstelm.w \src\(), \dst\(), 24, 6 - addi.d $s2, $s2, -1 - beqz $s2, .Lstore_exit\@ - -.Lstore_exit\@: - .endm -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - a0 - Supplies the address of matrix A. - - a1 - Supplies the address of matrix B. - - t1 - Supplies the address of matrix A. - - a5 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - a2 - Supplies the address of matrix C. - - a3 - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - t0 - Supplies the length in bytes of a row from matrix A. - - t6 - Supplies the length in bytes of a row from matrix C. - - t5 - Stores the ZeroMode argument from the stack frame. - ---*/ - - .macro ProcessCountM RowCount, Fallthrough - - ori $s1, $r0, LFgemmYmmElementCount - bgeu $s1, $a5, .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop2xN\@: - EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr8, $xr8, $xr8" - EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" - EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr10, $xr10, $xr10" - EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" - EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr12, $xr12, $xr12" - EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" - EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr14, $xr14, $xr14" - EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" - - ComputeBlockLasxLoop ComputeBlockLasxBy16, \RowCount\() - EmitIfCountGE \RowCount\(), 1, "xvfmul $xr8, $xr8, $xr2" - EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" - EmitIfCountGE \RowCount\(), 2, "xvfmul $xr10, $xr10, $xr2" - EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" - EmitIfCountGE \RowCount\(), 3, "xvfmul $xr12, $xr12, $xr2" - EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" - EmitIfCountGE \RowCount\(), 4, "xvfmul $xr14, $xr14, $xr2" - EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" - - sub.d $a5, $a5, $s1 - sub.d $a5, $a5, $s1 - blt $a5, $zero, .LOutputMasked2xNBlock\@ - andi $s0, $t5, 0xff # ZeroMode? 
- bnez $s0, .LStore2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" - EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0x20" - EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" - EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" - EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "xvld $xr16, $s0, 0x20" - EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" - EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" - EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" - EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0x20" - EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" - EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" - EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "xvld $xr16, $s0, 0x20" - EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" - -.LStore2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0x20" - EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "add.d $s0, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "xvst $xr11, $s0, 0x20" - EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" - EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0x20" - EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "add.d $s0, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "xvst $xr15, $s0, 0x20" - - addi.d $a2, $a2, 0x40 # advance matrix C by 2 XRWORDs - move $a0, $t1 # reload matrix A - bltu $s1, $a5, .LProcessNextColumnLoop2xN\@ - beqz $a5, .LExitKernel - -.LProcessRemainingCountN\@: - EmitIfCountGE \RowCount\(), 1, "xvxor.v $xr9, $xr9, $xr9" - EmitIfCountGE \RowCount\(), 2, "xvxor.v $xr11, $xr11, $xr11" - EmitIfCountGE \RowCount\(), 3, "xvxor.v $xr13, $xr13, $xr13" - EmitIfCountGE \RowCount\(), 4, "xvxor.v $xr15, $xr15, $xr15" - - - ComputeBlockLasxLoop ComputeBlockLasxBy8, \RowCount\() - EmitIfCountGE \RowCount\(), 1, "xvfmul $xr9, $xr9, $xr2" - EmitIfCountGE \RowCount\(), 2, "xvfmul $xr11, $xr11, $xr2" - EmitIfCountGE \RowCount\(), 3, "xvfmul $xr13, $xr13, $xr2" - EmitIfCountGE \RowCount\(), 4, "xvfmul $xr15, $xr15, $xr2" - bltu $a5, $s1, .LOutputMasked1xNBlock\@ - andi $s0, $t5, 0xff # ZeroMode? - bnez $s0, .LStore1xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr16" - EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr16" - EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" - EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr16" - EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr16" - -.LStore1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "xvst $xr9, $a2, 0" - EmitIfCountGE \RowCount\(), 2, "xvstx $xr11, $a2, $t6" - EmitIfCountGE \RowCount\(), 3, "xvst $xr13, $t7, 0" - EmitIfCountGE \RowCount\(), 4, "xvstx $xr15, $t7, $t6" - b .LExitKernel - -.LOutputMasked2xNBlock\@: - andi $s0, $t5, 0xff # ZeroMode? 
- bnez $s0, .LStoreMasked2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "xvfadd $xr8, $xr8, $xr16" - EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "xvfadd $xr10, $xr10, $xr16" - EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" - EmitIfCountGE \RowCount\(), 3, "xvfadd $xr12, $xr12, $xr16" - EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "xvfadd $xr14, $xr14, $xr16" - -.LStoreMasked2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "xvst $xr8, $a2, 0" - EmitIfCountGE \RowCount\(), 2, "xvstx $xr10, $a2, $t6" - EmitIfCountGE \RowCount\(), 3, "xvst $xr12, $t7, 0" - EmitIfCountGE \RowCount\(), 4, "xvstx $xr14, $t7, $t6" - addi.d $a2, $a2, 0x20 # advance matrix C by YMMWORD -.if \RowCount\() > 2 - addi.d $t7, $t7, 0x20 # advance matrix C plus 2 rows by YMMWORD - -.endif - addi.d $a5, $a5, LFgemmYmmElementCount # correct for over-subtract above - - -.LOutputMasked1xNBlock\@: - -.if \RowCount\() > 2 - slli.d $s0, $t0, 1 - add.d $t7, $a0, $s0 -.endif - -.if \RowCount\() == 1 -.else -.endif - -.if \RowCount\() > 2 - slli.d $s0, $t6, 1 - add.d $t7, $a2, $s0 -.endif - - sub.d $a5, $zero, $a5 - la.global $a0, MlasMaskMoveTableLasx - ori $s0, $r0, LFgemmElementSize - mul.d $s0, $a5, $s0 - addi.d $s0, $s0, 8*4 - xvldx $xr0, $a0, $s0 - andi $s0, $t5, 0xff - - sub.d $a5, $zero, $a5 - - bnez $s0, .LStoreMasked1xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "xvld $xr16, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "xvand.v $xr8, $xr16, $xr0" - EmitIfCountGE \RowCount\(), 2, "xvldx $xr16, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "xvand.v $xr10, $xr16, $xr0" - EmitIfCountGE \RowCount\(), 3, "xvld $xr16, $t7, 0" - EmitIfCountGE \RowCount\(), 3, "xvand.v $xr12, $xr16, $xr0" - EmitIfCountGE \RowCount\(), 4, "xvldx $xr16, $t7, $t6" - EmitIfCountGE \RowCount\(), 4, "xvand.v $xr14, $xr16, $xr0" - - EmitIfCountGE \RowCount\(), 1, "xvfadd $xr9, $xr9, $xr8" - EmitIfCountGE \RowCount\(), 2, "xvfadd $xr11, $xr11, $xr10" - EmitIfCountGE \RowCount\(), 3, "xvfadd $xr13, $xr13, $xr12" - EmitIfCountGE \RowCount\(), 4, "xvfadd $xr15, $xr15, $xr14" -.LStoreMasked1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "store_n $xr9, $a5, $a2" - - add.d $s3, $a2, $t6 - EmitIfCountGE \RowCount\(), 2, "store_n $xr11, $a5, $s3" - - EmitIfCountGE \RowCount\(), 3, "store_n $xr13, $a5, $t7" - - add.d $s3, $t7, $t6 - EmitIfCountGE \RowCount\(), 4, "store_n $xr15, $a5, $s3" - sub.d $a5, $zero, $a5 -.ifb \Fallthrough\() - b .LExitKernel -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates the inner kernel to compute matrix multiplication. - -Arguments: - - FunctionName - Supplies the name for the generated function. - ---*/ - - .macro FgemmKernelLasxFunction FunctionName - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A a0 - Supplies the address of matrix A. - - B a1 - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C a2 - Supplies the address of matrix C. - - CountK a3 - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM a4 - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. 
- - CountN a5 - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda a6 - Supplies the first dimension of matrix A. - - ldc a7 - Supplies the first dimension of matrix C. - - Alpha f0 - Supplies the scalar alpha multiplier (see GEMM definition). - - ZeroMode (sp + 0)- Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY \FunctionName\() - - addi.d $sp, $sp, -64 - st.d $ra, $sp, 56 - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - fst.s $f0, $sp, 2*8 - fst.d $f16, $sp,3*8 - st.d $s2, $sp, 4*8 - st.d $s3, $sp, 5*8 - - move $t1, $a0 - slli.d $t0, $a6, 2 # convert lda to bytes - slli.d $t6, $a7, 2 # convert ldc to bytes - ld.d $t5, $sp, 64 # get zeromode - fst.s $f0, $sp, 2*8 - xvldrepl.w $xr2, $sp, 0x10 - -// -// Process 4 rows of the matrices. -// - - ori $s0, $zero, 4 - bltu $a4, $s0, .LProcessCountMLessThan4 - li.d $a4, 4 # return 4 rows handled - ProcessCountM 4, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - bstrpick.d $a0, $a4, 31, 0 - ld.d $s0, $sp, 0 - ld.d $s1, $sp, 8 - fld.d $f16, $sp,3*8 - ld.d $s2, $sp, 4*8 - ld.d $s3, $sp, 5*8 - ld.d $ra, $sp, 7*8 - addi.d $sp, $sp, 64 - jr $ra - -// -// Process 2 rows of the matrices. -// - -.LProcessCountMLessThan4: - ori $s0, $r0, 2 - bltu $a4, $s0, .LProcessCountMLessThan2 - li.d $a4, 2 # return 2 rows handled - ProcessCountM 2 - -// -// Process 1 row of the matrices. -// - -.LProcessCountMLessThan2: - ProcessCountM 1 - - .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h deleted file mode 100644 index 0333af792ba70..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/FgemmKernelLsxCommon.h +++ /dev/null @@ -1,170 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelLsxCommon.h - -Abstract: - - This module implements the kernels for the floating point matrix/matrix - multiply operation (SGEMM and DGEMM). - - This implementation uses Lsx instructions. - ---*/ - -#include "FgemmKernelCommon.h" -/*++ - -Macro Description: - - This stores the block accumulators to the output matrix with an optional - accumulation of the existing contents of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorCount - Supplies the number of vector columns to process. - -Implicit Arguments: - - t5 - Supplies the length in bytes of a row from matrix C. - - a2 - Supplies the address of matrix C. - - s3 - Stores the ZeroMode argument from the stack frame. - - vr8-vr15 - Supplies the block accumulators. - ---*/ - - .macro AccumulateAndStoreBlock RowCount, VectorCount - - and $s0, $t5,$t5 # ZeroMode? 
- bnez $s0 , .LSkipAccumulateOutput\@ - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vld $vr0, $a2, 0" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vld $vr1, $a2, 16" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vld $vr2, $a2, 32" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vld $vr3, $a2, 48" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vldx $vr4, $a2, $t6" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vldx $vr5, $a2, $s0" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vldx $vr6, $a2, $s0" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vldx $vr7, $a2, $s0" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vfadd $vr8, $vr8, $vr0" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vfadd $vr9, $vr9, $vr1" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vfadd $vr10,$vr10,$vr2" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vfadd $vr11,$vr11,$vr3" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vfadd $vr12,$vr12,$vr4" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vfadd $vr13,$vr13,$vr5" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vfadd $vr14,$vr14,$vr6" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vfadd $vr15,$vr15,$vr7" - -.LSkipAccumulateOutput\@: - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "vst $vr8, $a2, 0" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "vst $vr9, $a2, 16" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "vst $vr10, $a2, 32" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "vst $vr11, $a2, 48" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "vstx $vr12, $a2, $t6" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addi.d $s0, $t6, 16" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "vstx $vr13, $a2, $s0" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addi.d $s0, $t6, 32" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "vstx $vr14, $a2, $s0" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addi.d $s0, $t6, 48" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "vstx $vr15, $a2, $s0" - - .endm -/*++ - -Macro Description: - - This macro generates the inner kernel to compute matrix multiplication. - -Arguments: - - FunctionName - Supplies the name for the generated function. - ---*/ - - .macro FgemmKernelLsxFunction FunctionName - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (a0) - Supplies the address of matrix A. - - B (a1) - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C (a2) - Supplies the address of matrix C. - - CountK (a3) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (a4) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (a5) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda (a6) Supplies the first dimension of matrix A. - - ldc (a7) Supplies the first dimension of matrix C. 
- - Alpha (f0) - Supplies the scalar alpha multiplier (see GEMM definition). - - ZeroMode (sp 0) - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - -FUNCTION_ENTRY \FunctionName\() - addi.d $sp, $sp, -64 - st.d $t5, $sp, 0 - st.d $s0, $sp, 1*8 - st.d $s1, $sp, 2*8 - st.d $s2, $sp, 3*8 - st.d $s3, $sp, 4*8 - move $t1, $a0 - slli.d $t0, $a6, 2 //convert lda to bytes - slli.d $t6, $a7, 2 //convert ldc to bytes - ld.d $t5, $sp, 64 - fmov.s $f24, $f0 //f0 destroyed by lsx - - li.d $s0, 2 - blt $a4, $s0, .LProcessCountM1 - - li.d $a4, 2 - ProcessCountM 2, Fallthrough - -.LExitKernel: - ld.d $t5, $sp, 0 - ld.d $s0, $sp, 1*8 - ld.d $s1, $sp, 2*8 - ld.d $s2, $sp, 3*8 - ld.d $s3, $sp, 4*8 - addi.d $sp, $sp, 64 - move $a0, $a4 - jr $ra - -.LProcessCountM1: - ProcessCountM 1 - .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S deleted file mode 100644 index e03503521912a..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasx.S +++ /dev/null @@ -1,412 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelLasx.S - -Abstract: - - This module implements the kernels for the single precision convolution - operation. - - This implementation uses Lasx instructions. - ---*/ - -#include "asmmacro.h" -#include "SconvKernelLasxCommon.h" - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for FilterCount by OutputCount block - of the output buffer. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - - VectorOffset - Supplies the byte offset from the filter buffer to fetch - elements. - - BroadcastOffset - Supplies the byte offset from the input buffer to fetch - elements. - -Implicit Arguments: - - a3 - Supplies the address of the input buffer. - - a2 - Supplies the address of the filter buffer. - - a1 - Supplies the FilterStride parameter (see function description). - - t7 - Supplies the address of the filter buffer plus 2 * FilterStride. - - a5 - Supplies the StrideWidth parameter (see function description). - - xr0-xr7 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -.ifeqs "\KernelType\()","Depthwise" - xvld $xr12, $a2, 0 - EmitIfCountGE \OutputCount\(), 1, "xvld $xr8, $a3, 0" - EmitIfCountGE \OutputCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr12, $xr0" - EmitIfCountGE \OutputCount\(), 2, "xvldx $xr9, $a3, $a5" - EmitIfCountGE \OutputCount\(), 2, "xvfmadd.s $xr4, $xr9, $xr12, $xr4" - -.else - EmitIfCountGE \OutputCount\(), 1, "xvldrepl.w $xr13, $a3, \BroadcastOffset\()" - EmitIfCountGE \OutputCount\(), 2, "add.d $s0, $a3, $a5" - EmitIfCountGE \OutputCount\(), 2, "xvldrepl.w $xr14, $s0, \BroadcastOffset\()" -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "xvld $xr8, $a2, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 1, "xvfmadd.s $xr0, $xr8, $xr13, $xr0" - EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" - EmitIfCountGE \FilterCount\(), 2, "xvld $xr9, $s0, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 2, "xvfmadd.s $xr1, $xr9, $xr13, $xr1" - EmitIfCountGE \FilterCount\(), 3, "xvld $xr10, $t7, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 3, "xvfmadd.s $xr2, $xr10, $xr13, $xr2" - EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" - EmitIfCountGE \FilterCount\(), 4, "xvld $xr11, $s0, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 4, "xvfmadd.s $xr3, $xr11, $xr13, $xr3" -.else - EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a2, \VectorOffset\()" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmadd.s $xr0, $xr12, $xr13, $xr0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmadd.s $xr4, $xr12, $xr14, $xr4" - EmitIfCountGE \FilterCount\(), 2, "add.d $s0, $a2, $a1" - EmitIfCountGE \FilterCount\(), 2, "xvld $xr12, $s0, \VectorOffset\()" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmadd.s $xr1, $xr13, $xr12, $xr1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmadd.s $xr5, $xr14, $xr12, $xr5" - EmitIfCountGE \FilterCount\(), 3, "xvld $xr12, $t7, \VectorOffset\()" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmadd.s $xr2, $xr13, $xr12, $xr2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmadd.s $xr6, $xr14, $xr12, $xr6" - EmitIfCountGE \FilterCount\(), 4, "add.d $s0, $t7, $a1" - EmitIfCountGE \FilterCount\(), 4, "xvld $xr12, $s0, \VectorOffset\()" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmadd.s $xr3, $xr13, $xr12, $xr3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmadd.s $xr7, $xr14, $xr12, $xr7" -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - t7 - Supplies the DilationWidth parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). - - t5 - Supplies the InputStride parameter (see function description). 
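ProcessFilterCountN, documented above, splits an output row into three phases: output blocks that overlap the left padding go through the single-output helper, the interior is processed two output blocks at a time, and whatever remains is folded into the right-padding pass. A control-flow sketch in C, where process_one and process_two are hypothetical stand-ins for the ProcessOutputCountN expansions:

    #include <stddef.h>

    /* Control-flow sketch of ProcessFilterCountN: left-padded outputs one at
     * a time, the interior two at a time, then the remainder together with
     * the right-padded outputs one at a time. */
    static void process_filter_rows(size_t left_pad, size_t count,
                                    size_t right_pad,
                                    void (*process_one)(void),
                                    void (*process_two)(void))
    {
        for (size_t i = 0; i < left_pad; ++i)
            process_one();
        while (count >= 2) {                   /* ProcessNextOutputCountBy2 */
            process_two();
            count -= 2;
        }
        for (size_t i = 0; i < count + right_pad; ++i)
            process_one();
    }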
- ---*/ - - .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount - -// -// Process the output blocks that include left padding. -// - - ld.d $t0, $sp, OutputCountLeftPad_arg - beqz $t0, .L\KernelType\().\FilterCount\().ProcessOutputCount - bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() - -// -// Process the output blocks that do not include any padding. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCount: - ld.d $t0, $sp, OutputCount_arg - li.d $s0, 2 - bltu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount - -.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2: - ProcessOutputCountN Lasx, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 2 - slli.d $s0, $a5, 1 # advance input by 2 elements - add.d $a0, $a0, $s0 - addi.d $t0, $t0, -2 - li.d $s0, 2 - bgeu $t0, $s0, .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2 - -.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: - -// -// Process the output blocks that include right padding plus any remaining output -// blocks from above. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: - ld.d $s0, $sp, OutputCountRightPad_arg - add.d $t0, $t0, $s0 - beqz $t0, .L\KernelType\().ExitKernel - bl MlasConv\KernelType\()FloatSingleLasxFilter\FilterCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows for a pointwise convolution. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description). - - t8 - Supplies the InputStride parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). - - t0 - Supplies the OutputCount parameter (see function description). - - t2 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseFilterCountN FilterCount - li.d $s0, 2 - bltu $t0, $s0, .LPointwise.\FilterCount\().ProcessRemainingOutputCount - -.LPointwise.\FilterCount\().ProcessNextOutputCountBy2: - ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 2 - slli.d $s0, $a5, 1 # advance input by 2 elements - add.d $a0, $a0, $s0 - addi.d $t0, $t0, -2 - li.d $s0, 2 - bgeu $t0, $s0, .LPointwise.\FilterCount\().ProcessNextOutputCountBy2 - -.LPointwise.\FilterCount\().ProcessRemainingOutputCount: - beqz $t0, .LPointwise.ExitKernel - ProcessPointwiseOutputCountN Lasx, 8, \FilterCount\(), 1 - - .endm - -// -// Generate the convolution kernels. -// - - SconvKernelFunction Nchw, 8, Lasx - SconvKernelFunction Nchwc, 8, Lasx, BiasFilter - SconvKernelDepthwiseFunction 8, Lasx - SconvKernelPointwiseFunction Lasx, BiasFilter - -/*++ - -Macro Description: - - This macro generates code to process an output block after the inner - convolution kernel has executed and then stores the output block to the - output buffer. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. 
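The PostProcessBlock helpers generated below apply the optional steps selected by the Flags argument before the accumulators are written out: accumulate into the existing output, add the per-channel bias, and clamp with ReLU. For one 8-wide output block this reduces to roughly the following C, where the flag values mirror the MLAS_CONV_KERNEL_FLAG_* definitions that appear later in this patch:

    #include <stddef.h>

    #define FLAG_ACCUMULATE_OUTPUT 0x00000001  /* MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT */
    #define FLAG_BIAS_ADDITION     0x00000002  /* MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION */
    #define FLAG_RELU_ACTIVATION   0x00000004  /* MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION */

    /* Post-process one 8-wide output block: optional accumulate, optional
     * bias addition, optional ReLU, then store. */
    static void post_process_block(float *output, const float *bias,
                                   const float acc[8], unsigned flags)
    {
        for (size_t c = 0; c < 8; ++c) {
            float v = acc[c];
            if (flags & FLAG_ACCUMULATE_OUTPUT)
                v += output[c];
            if (flags & FLAG_BIAS_ADDITION)
                v += bias[c];
            if (flags & FLAG_RELU_ACTIVATION)
                v = v > 0.0f ? v : 0.0f;       /* xvfmax against zero */
            output[c] = v;
        }
    }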
- ---*/ - - .macro PostProcessBlock FilterCount, OutputCount - - .globl MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() - .hidden MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\() -MlasConvPostProcessFloatLasxFilter\FilterCount\()Output\OutputCount\(): - - .globl MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() - .hidden MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() -MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\(): - -.if \FilterCount\() > 2 - slli.d $s0, $t6, 1 # compute output plus 2 rows - add.d $t7, $a4, $s0 -.endif - -// -// Test if the existing contents of the output buffer should be accumulated -// with the output block. -// - - andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvld $xr16, $a4, 0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvld $xr16, $a4, 32" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr16" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvld $xr16, $a4, 0x40" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr16" - - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvldx $xr16, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr16" - - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvld $xr16, $s0, 0x20" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr16" - - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvld $xr16, $s0, 0x40" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr16" - - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvld $xr16,$t7, 0" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr16" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvld $xr16,$t7, 0x20" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr16" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvld $xr16,$t7, 0x40" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr16" - - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvldx $xr16,$t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr16" - - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvld $xr16,$s0, 0x20" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr16" - - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvld $xr16,$s0, 0x40" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr16" - - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: - -// -// Test if the bias buffer should be accumulated with the output block. 
-// - - andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "xvld $xr16, $a3, 0" - EmitIfCountGE \FilterCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" - EmitIfCountGE \FilterCount\(), 2, "xvld $xr16, $a3, 0x20" - EmitIfCountGE \FilterCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" - EmitIfCountGE \FilterCount\(), 3, "xvld $xr16, $a3, 0x40" - EmitIfCountGE \FilterCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" - EmitIfCountGE \FilterCount\(), 4, "xvld $xr16, $a3, 0x60" - EmitIfCountGE \FilterCount\(), 4, "xvfadd.s $xr3, $xr3, $xr16" -.else - EmitIfCountGE \FilterCount\(), 1, "xvld $xr12, $a3, 0" - EmitIfCountGE \FilterCount\(), 2, "xvld $xr13, $a3, 0x20" - EmitIfCountGE \FilterCount\(), 3, "xvld $xr14, $a3, 0x40" - EmitIfCountGE \FilterCount\(), 4, "xvld $xr15, $a3, 0x60" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfadd.s $xr4, $xr4, $xr12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfadd.s $xr8, $xr8, $xr12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfadd.s $xr1, $xr1, $xr13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfadd.s $xr5, $xr5, $xr13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfadd.s $xr9, $xr9, $xr13" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfadd.s $xr2, $xr2, $xr14" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfadd.s $xr6, $xr6, $xr14" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfadd.s $xr10, $xr10, $xr14" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfadd.s $xr3, $xr3, $xr15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfadd.s $xr7, $xr7, $xr15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfadd.s $xr11, $xr11, $xr15" - -.endif - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: - -// -// Test for fused ReLU activation. -// - - andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation - xvxor.v $xr15, $xr15, $xr15 - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvfmax.s $xr0, $xr15, $xr0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvfmax.s $xr4, $xr15, $xr4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvfmax.s $xr8, $xr15, $xr8" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvfmax.s $xr1, $xr15, $xr1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvfmax.s $xr5, $xr15, $xr5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvfmax.s $xr9, $xr15, $xr9" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvfmax.s $xr2, $xr15, $xr2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvfmax.s $xr6, $xr15, $xr6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvfmax.s $xr10, $xr15, $xr10" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvfmax.s $xr3, $xr15, $xr3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvfmax.s $xr7, $xr15, $xr7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvfmax.s $xr11, $xr15, $xr11" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: - -// -// Store the output block in the output buffer. 
-// - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvst $xr0, $a4, 0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvst $xr4, $a4, 0x20" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvst $xr8, $a4, 0x40" - - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvstx $xr1, $a4, $t6" - - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "add.d $s0, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvst $xr5, $s0, 0x20" - - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "add.d $s0, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvst $xr9, $s0, 0x40" - - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvst $xr2, $t7, 0" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvst $xr6, $t7, 0x20" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvst $xr10, $t7, 0x40" - - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvstx $xr3, $t7, $t6" - - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "add.d $s0, $t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvst $xr7, $s0, 0x20" - - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "add.d $s0, $t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvst $xr11, $s0, 0x40" - - add_immed $a4,\OutputCount\()*8*4 # advance output by N nchw8c blocks - jr $ra - - .endm - - .irp FilterCount, 1, 2, 3, 4 - .irp OutputCount, 1, 2, 3 - PostProcessBlock \FilterCount\(), \OutputCount\() - .endr - .endr - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h deleted file mode 100644 index bd2db816ed9ab..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLasxCommon.h +++ /dev/null @@ -1,868 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelLasxCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision convolution operation for the Lasx kernels. - ---*/ - - -#define SP_SIZE 32*8 - -#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 -#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 -#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 -#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 - -#define OutputStride_arg 6*8 -#define KernelHeight_arg 7*8 -#define KernelWidth_arg 8*8 -#define InputBase_arg 9*8 -#define InputWidth_arg 10*8 -#define DilatedInputWidth_arg 11*8 -#define OutputCountLeftPad_arg 12*8 -#define OutputCount_arg 13*8 -#define OutputCountRightPad_arg 14*8 -#define Bias_arg 15*8 -#define Flags_arg 16*8 -#define InputChannels_arg 17*8 -#define Filter_save_offset 18*8 - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a vector of input - blocks and a vector of filter blocks to produce a matrix of output blocks. - - OutputCount=1 generates special case code to handle padding blocks. All - other output counts assume no padding. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - BlockSize - Supplies the number of elements per block. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. 
- -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - s8 - Supplies the DilationWidth parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). - - t5 - Supplies the InputStride parameter (see function description). ---*/ - .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount - - move $a3, $a0 -.ifeqs "\KernelType\()","Depthwise" - move $a2, $a1 -.else - ld.d $a2, $sp, Filter_save_offset -.endif - ld.d $t1, $sp, KernelHeight_arg - ld.d $t2, $sp, KernelWidth_arg -.if \OutputCount\() == 1 - ld.d $t3, $sp, InputBase_arg - ld.d $t4, $sp, InputWidth_arg - sub.d $t3, $zero, $t3 -.endif - ClearBlock \FilterCount\(), \OutputCount\() - beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing - -.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: - move $t6, $t2 # reload kernel width remaining - -.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: -.if \OutputCount\() == 1 - add.d $t7, $a3, $t3 # compute (Input - InputBase) - # (Input - InputBase) >= InputWidth? - bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding -.endif -.if \OutputCount\() > 3 - slli.d $s0, $a5, 1 - add.d $s0, $s0, $a5 - add.d $t4, $a3, $s0 # compute input plus 3 blocks -.endif -.if \FilterCount\() > 2 - slli.d $s0, $a1, 1 # compute filter plus 2 rows - add.d $t7, $a2, $s0 -.endif -.ifeqs "\KernelType\()","Nchwc" -.if \BlockSize\() == 16 - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 - .endr -.else - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 - .endr -.endif -.else - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 -.endif - -.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: - # advance input by dilation width - add.d $a3, $a3, $t8 -.ifeqs "\KernelType\()","Nchwc" - # advance filter by 8i8o/16i16o block - addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 -.else - addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block -.endif - addi.d $t6, $t6, -1 - bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn - add.d $a3, $a3, $t5 # advance input to next row -.if \OutputCount\() == 1 - ld.d $s0, $sp, DilatedInputWidth_arg - # advance input base to next row - sub.d $t3, $t3, $s0 -.endif - addi.d $t1, $t1, -1 # decrement rows remaining - bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow - -// -// Handle post processing of the output block. -// - -.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: - ld.w $a2, $sp, Flags_arg -.if \FilterCount\() > 1 - ld.d $t6, $sp, OutputStride_arg -.endif - ld.d $a3, $sp, Bias_arg - bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - BlockSize - Supplies the number of elements per block. - - Isa - Supplies the instruction set architecture string for function tags. 
- - BiasFilter - Supplies a non-blank value if the address of the filter buffer - should be biased to point to the middle of a OIhw8i8o block in order to - reduce the code size from relative byte offsets. - ---*/ - - .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Filter (a1) - Supplies the address of the filter buffer. - - Output (a2) - Supplies the address of the output buffer. - - StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (a4) - Supplies the length in bytes of the blocked dilation - width. - - FilterCount (a5) - Supplies the number of filters to process in this - iteration. - - InputStride (a6)- Supplies the length in bytes to advance the input buffer to - the next input row. - - FilterStride (a7) - Supplies the length in bytes to advance the filter buffer - to the next set of filters. - - OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer - to the next output address associated with the next set of filters. - - KernelHeight (sp + 8)- Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth (sp + 0x10)- Supplies the width of the kernel to apply. - - InputBase (sp + 0x18)- Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth (sp + 0x20)- Supplies the length in bytes of the blocked input width. - - DilatedInputWidth (sp + 0x28)- Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad (sp + 0x30)- Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount (sp + 0x38)- Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad (sp + 0x40)- Supplies the number of output elements that include - one or more padding elements from the right edge. - - Bias (sp + 0x48)- Supplies the address of the bias buffer. - - Flags (sp + 0x50)- Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. 
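The InputBase/InputWidth pair described above lets a single comparison detect both left and right padding: the kernel forms (Input - InputBase) and treats it as an unsigned quantity, so addresses below InputBase wrap around and fail the same bound check as addresses past the valid width (the sub.d $t3, $zero, $t3 / bgeu sequence in ProcessOutputCountN). A small C illustration of the idea:

    #include <stdbool.h>
    #include <stdint.h>

    /* Unsigned bound check used to skip padding blocks: positions left of
     * InputBase wrap to a huge unsigned value, so one comparison rejects
     * both left and right padding. */
    static bool input_is_padding(const char *input, const char *input_base,
                                 uintptr_t input_width_bytes)
    {
        return (uintptr_t)(input - input_base) >= input_width_bytes;
    }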
- ---*/ - - FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() - - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0 - st.d $s1, $sp, 8 - st.d $s2, $sp, 2*8 - st.d $ra, $sp, 5*8 - - ld.d $t0, $sp, SP_SIZE+0*8 - ld.d $t1, $sp, SP_SIZE+1*8 - ld.d $t2, $sp, SP_SIZE+2*8 - ld.d $t3, $sp, SP_SIZE+3*8 - st.d $t0, $sp, OutputStride_arg - st.d $t1, $sp, KernelHeight_arg - st.d $t2, $sp, KernelWidth_arg - st.d $t3, $sp, InputBase_arg - ld.d $t0, $sp, SP_SIZE+4*8 - ld.d $t1, $sp, SP_SIZE+5*8 - ld.d $t2, $sp, SP_SIZE+6*8 - ld.d $t3, $sp, SP_SIZE+7*8 - st.d $t0, $sp, InputWidth_arg - st.d $t1, $sp, DilatedInputWidth_arg - st.d $t2, $sp, OutputCountLeftPad_arg - st.d $t3, $sp, OutputCount_arg - ld.d $t0, $sp, SP_SIZE+8*8 - ld.d $t1, $sp, SP_SIZE+9*8 - ld.d $t2, $sp, SP_SIZE+10*8 - st.d $t0, $sp, OutputCountRightPad_arg - st.d $t1, $sp, Bias_arg - st.d $t2, $sp, Flags_arg - -.ifeqs "\BiasFilter\()","BiasFilter" - addi.d $a1, $a1, 4*8*4 -.endif - st.d $a1, $sp, Filter_save_offset - move $a1, $a7 - move $t5, $a6 - move $t8, $a4 - move $t1, $a5 - move $a4, $a2 - move $a5, $a3 - -// -// Process the specified number of filter rows. -// - - ori $s0, $zero, 3 - beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 - bltu $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 - ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 4 - b .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCount3: - ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 3 - b .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCountLessThan3: - ori $s0, $zero, 2 - bltu $t1, $s0, .L\KernelType\().ProcessFilterCount1 - ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 2 - b .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCount1: - ProcessFilterCountN LSconvKernelFrame, \KernelType\(), 1 - -// -// Restore non-volatile registers and return. -// - -.L\KernelType\().ExitKernel: -.ifnes "\Isa\()","LSX" - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 -.endif - ld.d $s0, $sp, 0 - ld.d $s1, $sp, 8 - ld.d $s2, $sp, 2*8 - ld.d $ra, $sp, 5*8 - addi.d $sp, $sp, SP_SIZE - jirl $zero, $ra, 0 - -.ifnes "\Isa\()","LSX" - -// -// Generate out-of-band helpers for handling output blocks involving padding. 
-// - - .irp FilterCount, 1, 2, 3, 4 - -MlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): - st.d $ra, $sp, 19*8 -loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): - ProcessOutputCountN \Isa\(), LSconvKernelSingleFrame, \KernelType\(), \BlockSize\(), \FilterCount\(), 1 - add.d $a0, $a0, $a5 # advance input by 1 element - addi.d $t0, $t0, -1 # decrement output count remaining - bnez $t0, loopMlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\() - ld.d $ra, $sp, 19*8 - jr $ra - - .endr - -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel for the special - case of a depthwise separable convolution. - -Arguments: - - BlockSize - Supplies the number of elements per block. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SconvKernelDepthwiseFunction BlockSize, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - - Depthwise separable convolutions are a form of grouped convolution where - the number of input and output channels per group are one. - -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Filter (a1) - Supplies the address of the filter buffer. - - Output (a2) - Supplies the address of the output buffer. - - StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (a4) - Supplies the length in bytes of the blocked dilation - width. - - InputStride (a5) - Supplies the length in bytes to advance the input buffer - to the next input row. - - KernelHeight (a6)- Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth (a7)- Supplies the width of the kernel to apply. - - InputBase (sp + 0 )- Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth (sp + 8 )- Supplies the length in bytes of the blocked input width. - - DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount (sp + 0x20)- Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include - one or more padding elements from the right edge. - - Bias (sp + 0x30)- Supplies the address of the bias buffer. - - Flags (sp + 0x38)- Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. 
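As the description above notes, the depthwise kernel convolves each channel of an 8-channel block independently, so its inner step is an element-wise multiply-accumulate of an input block with a filter block rather than a broadcast against several filter rows. A minimal scalar reference for one output block, with strides expressed in elements and padding handling omitted:

    #include <stddef.h>

    /* Reference for one depthwise output block of 8 channels: walk the
     * kernel window and accumulate input[c] * filter[c] per channel. */
    static void depthwise_output_block(const float *input,
                                       size_t dilation_w,    /* blocked dilation, in elements */
                                       size_t input_stride,  /* next input row, in elements */
                                       const float *filter,
                                       size_t kernel_h, size_t kernel_w,
                                       float acc[8])
    {
        for (size_t kh = 0; kh < kernel_h; ++kh) {
            for (size_t kw = 0; kw < kernel_w; ++kw) {
                const float *x = input + kh * input_stride + kw * dilation_w;
                const float *w = filter + (kh * kernel_w + kw) * 8;
                for (size_t c = 0; c < 8; ++c)
                    acc[c] += x[c] * w[c];
            }
        }
    }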
- ---*/ - - FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() - - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0 - st.d $s1, $sp, 8 - st.d $s2, $sp, 2*8 - st.d $ra, $sp, 5*8 - - st.d $a6, $sp, KernelHeight_arg - st.d $a7, $sp, KernelWidth_arg - - ld.d $t0, $sp, SP_SIZE+0*8 - ld.d $t1, $sp, SP_SIZE+1*8 - ld.d $t2, $sp, SP_SIZE+2*8 - ld.d $t3, $sp, SP_SIZE+3*8 - st.d $t0, $sp, InputBase_arg - st.d $t1, $sp, InputWidth_arg - st.d $t2, $sp, DilatedInputWidth_arg - st.d $t3, $sp, OutputCountLeftPad_arg - ld.d $t0, $sp, SP_SIZE+4*8 - ld.d $t1, $sp, SP_SIZE+5*8 - ld.d $t2, $sp, SP_SIZE+6*8 - ld.d $t3, $sp, SP_SIZE+7*8 - st.d $t0, $sp, OutputCount_arg - st.d $t1, $sp, OutputCountRightPad_arg - st.d $t2, $sp, Bias_arg - st.d $t3, $sp, Flags_arg - - move $t8, $a4 - move $t5, $a5 - move $a4, $a2 - move $a5, $a3 - -// -// Process the specified number of filter rows. -// - - ProcessFilterCountN LSconvKernelDepthwiseFrame, Depthwise, 1 - -// -// Restore non-volatile registers and return. -// - -.LDepthwise.ExitKernel: -.ifnes "\Isa\()","LSX" - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 -.endif - ld.d $s0, $sp, 0 - ld.d $s1, $sp, 8 - ld.d $s2, $sp, 2*8 - ld.d $ra, $sp, 5*8 - addi.d $sp, $sp, SP_SIZE - jr $ra - -.ifnes "\Isa\()","LSX" - -// -// Generate out-of-band helpers for handling output blocks involving padding. -// - -MlasConvDepthwiseFloatSingle\Isa\()Filter1: - st.d $ra, $sp, 20*8 -MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop: - ProcessOutputCountN \Isa\(), LSconvKernelDepthwiseSingleFrame, Depthwise, \BlockSize\(), 1, 1 - add.d $a0, $a0, $a5 # advance input by 1 element - addi.d $t0, $t0, -1 # decrement output count remaining - - bnez $t0, MlasConvDepthwiseFloatSingle\Isa\()Filter1_loop - ld.d $ra, $sp, 20*8 - jr $ra - -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a vector of input - blocks and a vector of filter blocks to produce a matrix of output blocks - for a pointwise convolution. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - BlockSize - Supplies the number of elements per block. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description). - - t8 - Supplies the InputStride parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). 
- - t2 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount - - move $a3, $a0 - move $a2, $t2 - ld.d $t1, $sp, InputChannels_arg - ClearBlock \FilterCount\(), \OutputCount\() - -.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: -.if \OutputCount\() > 3 - slli.d $s0, $a5, 1 - add.d $s0, $s0, $a5 - add.d $t4, $s0, $a3 -.endif -.if \FilterCount\() > 2 - slli.d $s0, $a1, 1 - add.d $t7, $a2, $s0 -.endif -.if \BlockSize\() == 16 - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 - .endr -.else - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 - ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 - .endr -.endif - add.d $a3, $a3, $t8 # advance input to next channel block - - addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 # advance filter by 8i8o/16i16o block - addi.d $t1, $t1, -1 # decrement input blocks remaining - - bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock - -// -// Handle post processing of the output block. -// - - ld.w $a2, $sp, Flags_arg -.if \FilterCount\() > 1 - ld.d $t6, $sp, OutputStride_arg -.endif - ld.d $a3, $sp, Bias_arg - bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel for the special - case where the kernel dimensions are 1. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - BiasFilter - Supplies a non-blank value if the address of the filter buffer - should be biased to point to the middle of a OIhw8i8o block in order to - reduce the code size from relative byte offsets. - ---*/ - - .macro SconvKernelPointwiseFunction Isa, BiasFilter - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - - Pointwise convolutions have a kernel size of one. To simplify this - implementation, no input padding is allowed, which matches typical usage in - models. - -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - Filter (a1) - Supplies the address of the filter buffer. - - Output (a2) - Supplies the address of the output buffer. - - StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. - - InputChannels (a4) - Supplies the number of input channels to process. - - FilterCount (a5) - Supplies the number of rows from the filter to process. - - InputStride (a6) - Supplies the length in bytes to advance the input buffer to - the next input channel of the same input row. - - FilterStride (a7) - Supplies the length in bytes to advance the filter buffer - to the next set of filters. - - OutputStride (sp + 0)- Supplies the length in bytes to advance the output buffer - to the next output address associated with the next set of filters. - - OutputCount (sp + 8)- Supplies the number of output elements. - - Bias (sp + 0x10)- Supplies the address of the bias buffer. - - Flags (sp + 0x18)- Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. 
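(Editorial aside: a hedged C sketch of the pointwise (1x1) computation described above, for a single output position. The OIhw8i8o/16i16o indexing and the assumption of contiguous channel blocks are illustrative simplifications; the assembly advances the input by InputStride bytes per channel block and dispatches on FilterCount and OutputCount.)

    #include <stddef.h>

    /* Hedged sketch: pointwise convolution over channel-blocked (NCHWc) data,
     * one output position, FilterCount filter sets. BlockSize would be 8 for
     * the LSX kernels and 16 for the LASX kernels. */
    static void pointwise_conv_one_output(
        const float *input,          /* InputChannels blocks of BlockSize values      */
        const float *filter,         /* per filter set: blocks of BlockSize x BlockSize */
        float *output,               /* FilterCount blocks of BlockSize values        */
        size_t input_channel_blocks, /* InputChannels argument, in blocks             */
        size_t filter_count,         /* FilterCount argument                          */
        size_t block_size)
    {
        for (size_t f = 0; f < filter_count; f++) {
            for (size_t o = 0; o < block_size; o++) {
                float acc = 0.0f;
                for (size_t c = 0; c < input_channel_blocks; c++) {
                    const float *filter_block =
                        filter + (f * input_channel_blocks + c) * block_size * block_size;
                    for (size_t i = 0; i < block_size; i++) {
                        /* Assumed 8i8o ordering: input channel i (outer), output
                         * channel o (inner) within the block. */
                        acc += input[c * block_size + i] * filter_block[i * block_size + o];
                    }
                }
                output[f * block_size + o] = acc;  /* bias/ReLU follow in post-processing */
            }
        }
    }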
- ---*/ - - FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() - - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - st.d $s2, $sp, 2*8 - st.d $ra, $sp, 5*8 - - ld.d $t0, $sp, SP_SIZE+0*8 - ld.d $t1, $sp, SP_SIZE+1*8 - ld.d $t2, $sp, SP_SIZE+2*8 - ld.d $t3, $sp, SP_SIZE+3*8 - st.d $t0, $sp, OutputStride_arg - st.d $t1, $sp, OutputCount_arg - st.d $t2, $sp, Bias_arg - st.d $t3, $sp, Flags_arg - st.d $a4, $sp, InputChannels_arg - -.ifeqs "\BiasFilter\()","BiasFilter" - addi.d $t2, $a1, 4*8*4 -.else - move $t2, $a1 -.endif - ld.d $t0, $sp, OutputCount_arg - move $a1, $a7 - move $t8, $a6 - move $t1, $a5 - move $a4, $a2 - move $a5, $a3 - -// -// Process the specified number of filter rows. -// - - ori $s0, $zero, 3 - beq $t1, $s0, .LPointwise.ProcessFilterCount3 - bltu $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 - ProcessPointwiseFilterCountN 4 - b .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCount3: - ProcessPointwiseFilterCountN 3 - b .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCountLessThan3: - ori $s0, $zero, 2 - bltu $t1, $s0, .LPointwise.ProcessFilterCount1 - ProcessPointwiseFilterCountN 2 - b .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCount1: - ProcessPointwiseFilterCountN 1 - -// -// Restore non-volatile registers and return. -// - -.LPointwise.ExitKernel: -.ifnes "\Isa\()","LSX" - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 -.endif - ld.d $s0, $sp, 0*8 - ld.d $s1, $sp, 1*8 - ld.d $s2, $sp, 2*8 - ld.d $ra, $sp, 5*8 - addi.d $sp, $sp, SP_SIZE - jr $ra - - .endm - -/*++ - -Macro Description: - - This macro generates code to clear the block accumulators. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - xr0-xr11 - Supplies the block accumulators. 
- ---*/ - - .macro ClearBlock FilterCount, OutputCount - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "xvxor.v $xr4, $xr4, $xr4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "xvxor.v $xr8, $xr8, $xr8" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xvxor.v $xr1, $xr1, $xr1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "xvxor.v $xr5, $xr5, $xr5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "xvxor.v $xr9, $xr9, $xr9" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xvxor.v $xr2, $xr2, $xr2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "xvxor.v $xr6, $xr6, $xr6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "xvxor.v $xr10, $xr10, $xr10" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xvxor.v $xr3, $xr3, $xr3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "xvxor.v $xr7, $xr7, $xr7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "xvxor.v $xr11, $xr11, $xr11" - - .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S deleted file mode 100644 index 04b8dc14d067d..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsx.S +++ /dev/null @@ -1,339 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelLsx.S - -Abstract: - - This module implements the kernels for the single precision convolution - operation. - - This implementation uses Lsx instructions. - ---*/ - -#include "asmmacro.h" -#include "SconvKernelLsxCommon.h" - -/*++ - -Macro Description: - - This macro generates code to clear the block accumulators. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - vr0-vr7 - Supplies the block accumulators. - ---*/ - - .macro ClearBlock FilterCount, OutputCount - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr0,$vr0,$vr0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxor.v $vr1,$vr1,$vr1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr2,$vr2,$vr2" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxor.v $vr3,$vr3,$vr3" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr4,$vr4,$vr4" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxor.v $vr5,$vr5,$vr5" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr6,$vr6,$vr6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxor.v $vr7,$vr7,$vr7" - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for FilterCount by OutputCount block - of the output buffer. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - - VectorOffset - Supplies the byte offset from the filter buffer to fetch - elements. - - BroadcastOffset - Supplies the byte offset from the input buffer to fetch - elements. - -Implicit Arguments: - - a3 - Supplies the address of the input buffer. - - a2 - Supplies the address of the filter buffer. - - a1 - Supplies the FilterStride parameter (see function description). 
- - t6 - Supplies the address of the filter buffer plus 2 * FilterStride. - - a5 - Supplies the StrideWidth parameter (see function description). - - vr0-vr7 - Supplies the block accumulators. - ---*/ - .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -.ifeqs "\KernelType\()","Depthwise" - vld $vr8, $a2, 0 - vld $vr9, $a2, 16 - vld $vr10, $a3, 0 - vld $vr11, $a3, 16 - vfmadd.s $vr0, $vr8, $vr10, $vr0 - vfmadd.s $vr1, $vr9, $vr11, $vr1 -.else - EmitIfCountGE \OutputCount\(), 1, "ld.w $s0, $a3, \BroadcastOffset\()" - EmitIfCountGE \OutputCount\(), 1, "vreplgr2vr.w $vr12, $s0" - EmitIfCountGE \FilterCount\(), 1, "vld $vr8, $a2, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 1, "vld $vr9, $a2, \VectorOffset\()+16" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr0, $vr8, $vr12, $vr0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd.s $vr1, $vr9, $vr12, $vr1" - EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, +\VectorOffset\()" - EmitIfCountGE \FilterCount\(), 2, "vldx $vr8, $a2, $s0" - EmitIfCountGE \FilterCount\(), 2, "addi.d $s0, $a1, +\VectorOffset\()+16" - EmitIfCountGE \FilterCount\(), 2, "vldx $vr9, $a2, $s0" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr2, $vr8, $vr12, $vr2" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd.s $vr3, $vr9, $vr12, $vr3" - EmitIfCountGE \FilterCount\(), 3, "vld $vr8, $t7, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 3, "vld $vr9, $t7, \VectorOffset\()+16" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr4, $vr8, $vr12, $vr4" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd.s $vr5, $vr9, $vr12, $vr5" - EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()" - EmitIfCountGE \FilterCount\(), 4, "vldx $vr8, $t7, $s0" - EmitIfCountGE \FilterCount\(), 4, "addi.d $s0, $a1, \VectorOffset\()+16" - EmitIfCountGE \FilterCount\(), 4, "vldx $vr9, $t7, $s0" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr6, $vr8, $vr12, $vr6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd.s $vr7, $vr9, $vr12, $vr7" -.endif - .endm -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - s8 - Supplies the DilationWidth parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). - - s3 - Supplies the InputStride parameter (see function description). 
- ---*/ - - .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount - ld.d $s0, $sp, OutputCountLeftPad_arg //OutputCountLeftPad - ld.d $s1, $sp, OutputCount_arg //OutputCount - add.d $s0, $s0, $s1 - ld.d $s1, $sp, OutputCountRightPad_arg //OutputCountRightPad - add.d $t0, $s0, $s1 -.L\KernelType\().\FilterCount\().ProcessNextOutputCount: - ProcessOutputCountN Sse, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 1 - add.d $a0, $a0, $a5 - addi.d $t0, $t0, -1 - bnez $t0, .L\KernelType\().\FilterCount\().ProcessNextOutputCount - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows for a pointwise convolution. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description). - - s8 - Supplies the InputStride parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). - - t7 - Supplies the OutputCount parameter (see function description). - - s5 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseFilterCountN FilterCount -.LPointwise.\FilterCount\().ProcessNextOutputCount: - ProcessPointwiseOutputCountN Sse, 8, \FilterCount\(), 1 - add.d $a0, $a0, $a5 - addi.d $t0, $t0, -1 - bnez $t0, .LPointwise.\FilterCount\().ProcessNextOutputCount - .endm - -// -// Generate the convolution kernels. -// - - SconvKernelFunction Nchw, 8, LSX - SconvKernelFunction Nchwc, 8, LSX, BiasFilter - SconvKernelDepthwiseFunction 8, LSX - SconvKernelPointwiseFunction LSX, BiasFilter - -/*++ - -Macro Description: - - This macro generates code to process an output block after the inner - convolution kernel has executed and then stores the output block to the - output buffer. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. 
---*/ - - .macro PostProcessBlock FilterCount, OutputCount - - .globl MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() -#if !defined(__APPLE__) - .hidden MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() -#endif -MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\(): - -.if \FilterCount\() > 2 - li.d $s0, 2 - mul.d $s0, $s0, $t6 - add.d $t7, $a4, $s0 -.endif - andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - andi $s0, $s0, 0xff - beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a4, 0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a4, 16" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr10, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vldx $vr11, $a4, $s0" - - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $t7, 0" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $t7, 16" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr14, $t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vldx $vr15, $t7, $s0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: -// -// Test if the bias buffer should be accumulated with the output block. 
-// - - andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - andi $s0, $s0, 0xff - beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr8, $a3, 0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vld $vr9, $a3, 16" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr10, $a3, 32" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vld $vr11, $a3, 48" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr12, $a3, 64" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vld $vr13, $a3, 80" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr14, $a3, 96" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vld $vr15, $a3, 112" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr0, $vr0, $vr8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfadd.s $vr1, $vr1, $vr9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr2, $vr2, $vr10" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfadd.s $vr3, $vr3, $vr11" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr4, $vr4, $vr12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfadd.s $vr5, $vr5, $vr13" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr6, $vr6, $vr14" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfadd.s $vr7, $vr7, $vr15" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: - -// -// Test for fused ReLU activation. -// - - andi $s0, $a2, MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - andi $s0, $s0, 0xff - beqz $s0, .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation - vxor.v $vr15,$vr15, $vr15 - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr0, $vr0, $vr15" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmax.s $vr1, $vr1, $vr15" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr2, $vr2, $vr15" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmax.s $vr3, $vr3, $vr15" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr4, $vr4, $vr15" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmax.s $vr5, $vr5, $vr15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr6, $vr6, $vr15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmax.s $vr7, $vr7, $vr15" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: - -// -// Store the output block in the output buffer. 
-// - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr0, $a4,0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vst $vr1, $a4, 16" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr2, $a4, $t6" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addi.d $s0, $t6, 16" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vstx $vr3, $a4, $s0" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr4, $t7, 0" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vst $vr5, $t7, 16" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr6, $t7, $t6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addi.d $s0, $t6, 16" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vstx $vr7, $t7, $s0" - add_immed $a4, \OutputCount\()*8*4 # advance output by N nchw8c blocks - jr $ra - - .endm - - .irp FilterCount, 1, 2, 3, 4 - .irp OutputCount, 1 - PostProcessBlock \FilterCount\(), \OutputCount\() - .endr - .endr - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h deleted file mode 100644 index d03714f654500..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SconvKernelLsxCommon.h +++ /dev/null @@ -1,669 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelLsxCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision convolution operation for the Lsx kernels. - ---*/ - -#define SP_SIZE 32*8 - -#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 -#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 -#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 -#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 - -#define Filter_save_offset 18*8 - -#define OutputStride_arg 6*8 -#define KernelHeight_arg 7*8 -#define KernelWidth_arg 8*8 -#define InputBase_arg 9*8 -#define InputWidth_arg 10*8 -#define DilatedInputWidth_arg 11*8 -#define OutputCountLeftPad_arg 12*8 -#define OutputCount_arg 13*8 -#define OutputCountRightPad_arg 14*8 -#define Bias_arg 15*8 -#define Flags_arg 16*8 -#define InputChannels_arg 17*8 - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a vector of input - blocks and a vector of filter blocks to produce a matrix of output blocks. - - OutputCount=1 generates special case code to handle padding blocks. All - other output counts assume no padding. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - BlockSize - Supplies the number of elements per block. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a1 - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - s8 - Supplies the DilationWidth parameter (see function description). - - a4 - Supplies the address of the output buffer. - - a5 - Supplies the StrideWidth parameter (see function description). - - s3 - Supplies the InputStride parameter (see function description). 
---*/ - - .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount - move $a3, $a0 -.ifeqs "\KernelType\()","Depthwise" - move $a2, $a1 -.else - ld.d $a2, $sp, Filter_save_offset -.endif - ld.d $t1, $sp, KernelHeight_arg //KernelHeight - ld.d $t2, $sp, KernelWidth_arg //KernelWidth -.if \OutputCount\() == 1 - ld.d $t3, $sp, InputBase_arg //InputBase - ld.d $t4, $sp, InputWidth_arg //InputWidth - sub.d $t3, $zero, $t3 # keep negative for lea usage below -.endif - ClearBlock \FilterCount\(), \OutputCount\() - beqz $t1, .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing - -.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: - move $t6, $t2 # reload kernel width remaining -.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: -.if \OutputCount\() == 1 - add.d $t7, $a3, $t3 - bgeu $t7, $t4, .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding -.endif -.if \OutputCount\() > 3 - li.d $s2, 2 - mul.d $s2, $a5, $s2 - add.d $t4, $a5, $s2 - - add.d $t4, $t4, $a3 # compute input plus 3 blocks -.endif -.if \FilterCount\() > 2 - li.d $s2, 2 - mul.d $s2, $s2, $a1 - add.d $t7, $a2, $s2 //t6 is rbx used by ComputeBlock -.endif -.ifeqs "\KernelType\()","Nchwc" -.if \BlockSize\() == 16 - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 - .endr -.else - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 - .endr -.endif -.else - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 -.endif -.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: - add.d $a3, $a3, $t8 # advance input by dilation width -.ifeqs "\KernelType\()","Nchwc" - addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 - # advance filter by 8i8o/16i16o block -.else - addi.d $a2, $a2, \BlockSize\()*4 # advance filter by 8o/16o block -.endif - addi.d $t6, $t6, -1 # decrement columns remaining - bnez $t6, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn - add.d $a3, $a3, $t5 -.if \OutputCount\() == 1 - ld.d $s0, $sp, DilatedInputWidth_arg #DilatedInputWidth - sub.d $t3, $t3, $s0 - # advance input base to next row -.endif - addi.d $t1, $t1, -1 # decrement rows remaining - bnez $t1, .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow - -// -// Handle post processing of the output block. -// -.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: - ld.w $a2, $sp, Flags_arg - -.if \FilterCount\() > 1 - ld.d $t6, $sp, OutputStride_arg -.endif - ld.d $a3, $sp, Bias_arg - bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() -.endm -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - BlockSize - Supplies the number of elements per block. - - Isa - Supplies the instruction set architecture string for function tags. - - BiasFilter - Supplies a non-blank value if the address of the filter buffer - should be biased to point to the middle of a OIhw8i8o block in order to - reduce the code size from relative byte offsets. - ---*/ - - .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. 
- -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Filter (a1) - Supplies the address of the filter buffer. - - Output (a2) - Supplies the address of the output buffer. - - StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (a4) - Supplies the length in bytes of the blocked dilation - width. - - FilterCount (a5) - Supplies the number of filters to process in this - iteration. - - InputStride (a6) - Supplies the length in bytes to advance the input buffer to - the next input row. - - FilterStride (a7)- Supplies the length in bytes to advance the filter buffer - to the next set of filters. - - OutputStride (sp,8*0) - Supplies the length in bytes to advance the output buffer - to the next output address associated with the next set of filters. - - KernelHeight (sp,8*1)- Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth (sp, 8*2)- Supplies the width of the kernel to apply. - - InputBase (sp, 8*3)- Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth (sp, 8*4)- Supplies the length in bytes of the blocked input width. - - DilatedInputWidth (sp, 8*5)- Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad (sp, 8*6)- Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount (sp, 8*7)- Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad (sp, 8*8)- Supplies the number of output elements that include - one or more padding elements from the right edge. - - Bias (sp, 8*9)- Supplies the address of the bias buffer. - - Flags (sp, 8*10)- Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. 
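(Editorial aside: the Flags argument described above selects the work done by the shared PostProcessBlock / MlasConvPostProcessFloat* path. A minimal C sketch of that sequence, using the MLAS_CONV_KERNEL_FLAG_* bit values defined in this header; the function and parameter names are illustrative, not the actual MLAS API.)

    #include <stddef.h>

    #define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001
    #define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION     0x00000002
    #define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION   0x00000004

    /* Hedged sketch of per-block post-processing: optional accumulate into the
     * existing output, optional bias addition, optional fused ReLU, then the
     * store back to the output row (the assembly does this with vector loads
     * and stores over the block accumulators). */
    static void conv_post_process_block(float *output, const float *accum,
                                        const float *bias, size_t count,
                                        unsigned flags)
    {
        for (size_t i = 0; i < count; i++) {
            float value = accum[i];
            if (flags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT)
                value += output[i];                   /* accumulate with a previous pass */
            if (flags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION)
                value += bias[i];                     /* per output channel bias         */
            if (flags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION)
                value = value > 0.0f ? value : 0.0f;  /* fused ReLU                      */
            output[i] = value;
        }
    }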
- ---*/ - - FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - st.d $s2, $sp, 2*8 - st.d $s3, $sp, 3*8 - st.d $s4, $sp, 4*8 - st.d $ra, $sp, 5*8 - ld.d $s0, $sp, SP_SIZE+0*8 - ld.d $s1, $sp, SP_SIZE+1*8 - ld.d $s2, $sp, SP_SIZE+2*8 - ld.d $s3, $sp, SP_SIZE+3*8 - st.d $s0, $sp, OutputStride_arg - st.d $s1, $sp, KernelHeight_arg - st.d $s2, $sp, KernelWidth_arg - st.d $s3, $sp, InputBase_arg - ld.d $s0, $sp, SP_SIZE+4*8 - ld.d $s1, $sp, SP_SIZE+5*8 - ld.d $s2, $sp, SP_SIZE+6*8 - ld.d $s3, $sp, SP_SIZE+7*8 - st.d $s0, $sp, InputWidth_arg - st.d $s1, $sp, DilatedInputWidth_arg - st.d $s2, $sp, OutputCountLeftPad_arg - st.d $s3, $sp, OutputCount_arg - ld.d $s0, $sp, SP_SIZE+8*8 - ld.d $s1, $sp, SP_SIZE+9*8 - ld.d $s2, $sp, SP_SIZE+10*8 - st.d $s0, $sp, OutputCountRightPad_arg - st.d $s1, $sp, Bias_arg - st.d $s2, $sp, Flags_arg - -.ifeqs "\BiasFilter\()","BiasFilter" - addi.d $a1, $a1,4*8*4 -.endif - st.d $a1, $sp, Filter_save_offset //store Filter - move $a1, $a7 - move $t5, $a6 - move $t8, $a4 # shuffle to Win64 register usage - move $t1, $a5 - move $a4, $a2 - move $a5, $a3 - - li.d $s0, 3 - beq $t1, $s0, .L\KernelType\().ProcessFilterCount3 - blt $t1, $s0, .L\KernelType\().ProcessFilterCountLessThan3 - ProcessFilterCountN SconvKernelFrame, \KernelType\(), 4 - b .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCount3: - ProcessFilterCountN SconvKernelFrame, \KernelType\(), 3 - b .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCountLessThan3: - li.d $s0,2 - blt $t1, $s0, .L\KernelType\().ProcessFilterCount1 - ProcessFilterCountN SconvKernelFrame, \KernelType\(), 2 - b .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCount1: - ProcessFilterCountN SconvKernelFrame, \KernelType\(), 1 - -// -// Restore non-volatile registers and return. -// - -.L\KernelType\().ExitKernel: - ld.d $a1, $sp, Filter_save_offset //restore Filter - ld.d $s0, $sp, 0*8 - ld.d $s1, $sp, 1*8 - ld.d $s2, $sp, 2*8 - ld.d $s3, $sp, 3*8 - ld.d $s4, $sp, 4*8 - ld.d $ra, $sp, 5*8 - - addi.d $sp, $sp, SP_SIZE - jr $ra -.endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel for the special - case of a depthwise separable convolution. - -Arguments: - - BlockSize - Supplies the number of elements per block. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SconvKernelDepthwiseFunction BlockSize, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - - Depthwise separable convolutions are a form of grouped convolution where - the number of input and output channels per group are one. - -Arguments: - - Input a0 - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Filter a1 - Supplies the address of the filter buffer. - - Output a2 - Supplies the address of the output buffer. - - StrideWidth a3 - Supplies the length in bytes of the blocked stride width. - - DilationWidth a4 - Supplies the length in bytes of the blocked dilation - width. - - InputStride a5 - Supplies the length in bytes to advance the input buffer - to the next input row. - - KernelHeight a6 - Supplies the height of the kernel to apply. 
This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth a7- Supplies the width of the kernel to apply. - - InputBase (sp, 0*8)- Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth (sp, 1*8)- Supplies the length in bytes of the blocked input width. - - DilatedInputWidth (sp, 2*8)- Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad (sp, 3*8)- Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount (sp, 4*8)- Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad (sp, 5*8)- Supplies the number of output elements that include - one or more padding elements from the right edge. - - Bias (sp, 6*8)- Supplies the address of the bias buffer. - - Flags (sp, 7*8)- Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - st.d $s2, $sp, 2*8 - st.d $s3, $sp, 3*8 - st.d $s4, $sp, 4*8 - st.d $ra, $sp, 5*8 - - st.d $a6, $sp, KernelHeight_arg - st.d $a7, $sp, KernelWidth_arg - - ld.d $s0, $sp, SP_SIZE+0*8 - ld.d $s1, $sp, SP_SIZE+1*8 - ld.d $s2, $sp, SP_SIZE+2*8 - ld.d $s3, $sp, SP_SIZE+3*8 - st.d $s0, $sp, InputBase_arg - st.d $s1, $sp, InputWidth_arg - st.d $s2, $sp, DilatedInputWidth_arg - st.d $s3, $sp, OutputCountLeftPad_arg - ld.d $s0, $sp, SP_SIZE+4*8 - ld.d $s1, $sp, SP_SIZE+5*8 - ld.d $s2, $sp, SP_SIZE+6*8 - ld.d $s3, $sp, SP_SIZE+7*8 - st.d $s0, $sp, OutputCount_arg - st.d $s1, $sp, OutputCountRightPad_arg - st.d $s2, $sp, Bias_arg - st.d $s3, $sp, Flags_arg -// -// Process the specified number of filter rows. -// - move $t8, $a4 // shuffle to Win64 register usage - move $t5, $a5 - move $a4, $a2 - move $a5, $a3 - ProcessFilterCountN SconvKernelDepthwiseFrame, Depthwise, 1 - -// -// Restore non-volatile registers and return. - ld.d $s0, $sp, 0*8 - ld.d $s1, $sp, 1*8 - ld.d $s2, $sp, 2*8 - ld.d $s3, $sp, 3*8 - ld.d $s4, $sp, 4*8 - ld.d $ra, $sp, 5*8 - addi.d $sp, $sp, SP_SIZE -// - jr $ra -.endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a vector of input - blocks and a vector of filter blocks to produce a matrix of output blocks - for a pointwise convolution. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - BlockSize - Supplies the number of elements per block. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - (a0) - Supplies the address of the input buffer. - - (a1) - Supplies the FilterStride parameter (see function description). - - (s8) - Supplies the InputStride parameter (see function description). - - (a4) - Supplies the address of the output buffer. - - (a5) - Supplies the StrideWidth parameter (see function description). - - (s5) - Supplies the address of the filter buffer. 
- ---*/ - - .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount - - move $a3, $a0 - move $a2, $t2 - ld.d $t1, $sp, InputChannels_arg - ClearBlock \FilterCount\(), \OutputCount\() - -.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: -.if \OutputCount\() > 3 - li.d $s0, 2 - mul $s0, $s0, $a5 - add.d $t4, $a5, $s0 - add.d $t4, $t4, $a3 # compute input plus 3 blocks -.endif -.if \FilterCount\() > 2 - li.d $s0, 2 # compute filter plus 2 rows - mul.d $s0, $s0, $a1 - add.d $t7, $a2, $s0 -.endif - -.if \BlockSize\() == 16 - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 - .endr -.else - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 - ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 - .endr -.endif - add.d $a3, $a3, $t8 # advance input to next channel block - addi.d $a2, $a2, \BlockSize\()*\BlockSize\()*4 - # advance filter by 8i8o/16i16o block - addi.d $t1, $t1, -1 //InputChannels decrement input blocks remaining - bnez $t1, .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock - -// -// Handle post processing of the output block. -// - ld.w $a2, $sp, Flags_arg #load flag -.if \FilterCount\() > 1 - ld.d $t6 ,$sp, OutputStride_arg #load .LSconvKernelPointwiseFrame_OutputStride -.endif - ld.d $a3, $sp, Bias_arg # load .LSconvKernelPointwiseFrame_Bias - bl MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() -.endm - - .macro SconvKernelPointwiseFunction Isa, BiasFilter - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - - Pointwise convolutions have a kernel size of one. To simplify this - implementation, no input padding is allowed, which matches typical usage in - models. - -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - Filter (a1) - Supplies the address of the filter buffer. - - Output (a2) - Supplies the address of the output buffer. - - StrideWidth (a3) - Supplies the length in bytes of the blocked stride width. - - InputChannels (a4) - Supplies the number of input channels to process. - - FilterCount (a5) - Supplies the number of rows from the filter to process. - - InputStride (a6) - Supplies the length in bytes to advance the input buffer to - the next input channel of the same input row. - - FilterStride (a7) - Supplies the length in bytes to advance the filter buffer - to the next set of filters. - - OutputStride (sp+0) - Supplies the length in bytes to advance the output buffer - to the next output address associated with the next set of filters. - - OutputCount (sp+8) - Supplies the number of output elements. - - Bias (sp+16) - Supplies the address of the bias buffer. - - Flags (sp+24) - Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. 
- ---*/ - - FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - st.d $s2, $sp, 2*8 - st.d $s3, $sp, 3*8 - st.d $s4, $sp, 4*8 - st.d $ra, $sp, 5*8 - - ld.d $s0, $sp, SP_SIZE+0*8 - ld.d $s1, $sp, SP_SIZE+1*8 - ld.d $s2, $sp, SP_SIZE+2*8 - ld.d $s3, $sp, SP_SIZE+3*8 - st.d $s0, $sp, OutputStride_arg - st.d $s1, $sp, OutputCount_arg - st.d $s2, $sp, Bias_arg - st.d $s3, $sp, Flags_arg - st.d $a4, $sp, InputChannels_arg - -.ifeqs "\BiasFilter\()","BiasFilter" - addi.d $t2, $a1, 4*8*4 -.else - move $t2, $a1 -.endif - - ld.d $t0, $sp, OutputCount_arg //OutputCount - move $a1, $a7 // FilterStride - move $t8, $a6 // InputStride - move $t1, $a5 // shuffle to Win64 register usage - move $a4, $a2 - move $a5, $a3 - -// -// Process the specified number of filter rows. -// - li.d $s0, 3 - beq $t1, $s0, .LPointwise.ProcessFilterCount3 - blt $t1, $s0, .LPointwise.ProcessFilterCountLessThan3 - ProcessPointwiseFilterCountN 4 - b .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCount3: - ProcessPointwiseFilterCountN 3 - b .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCountLessThan3: - li.d $s0, 2 - blt $t1, $s0, .LPointwise.ProcessFilterCount1 - ProcessPointwiseFilterCountN 2 - b .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCount1: - ProcessPointwiseFilterCountN 1 - -// -// Restore non-volatile registers and return. -// -.LPointwise.ExitKernel: - - ld.d $s0, $sp, 0*8 - ld.d $s1, $sp, 1*8 - ld.d $s2, $sp, 2*8 - ld.d $s3, $sp, 3*8 - ld.d $s4, $sp, 4*8 - ld.d $ra, $sp, 5*8 - addi.d $sp, $sp, SP_SIZE - jr $ra -.endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h deleted file mode 100644 index 93b109c90ae4f..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelCommon.h +++ /dev/null @@ -1,35 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision matrix/matrix multiply operation (SGEMM). - ---*/ - -// -// Define the single precision parameters. -// - -#define LFgemmElementShift 2 -#define LFgemmElementSize (1 << LFgemmElementShift) -#define LFgemmYmmElementCount (32/LFgemmElementSize) - -#include "FgemmKernelCommon.h" - -// -// Define the typed instructions for single precision. -// - -FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.s) -FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.s) -FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.w) -FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.s) diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S deleted file mode 100644 index d537742016d01..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLasx.S +++ /dev/null @@ -1,33 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelLasx.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses LASX instructions. - ---*/ - -#include "asmmacro.h" -#include "SgemmKernelCommon.h" -#include "FgemmKernelLasxCommon.h" - - - .text - -// -// Generate the GEMM kernel. 
-// - -FgemmKernelLasxFunction MlasGemmFloatKernelLasx - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S deleted file mode 100644 index 86b5ef8b51b00..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SgemmKernelLsx.S +++ /dev/null @@ -1,267 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelLsx.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses Lsx instructions. - ---*/ - -#include "asmmacro.h" -#include "FgemmKernelLsxCommon.h" - -FGEMM_TYPED_INSTRUCTION(vfadd, vfadd.s) - -/*++ - -Macro Description: - - This macro multiplies and accumulates for a 16xN block of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - Shuffle - Supplies the shuffle mask to extract the element from matrix A. - -Implicit Arguments: - - a1 - Supplies the address into the matrix B data. - - vr0-vr1 - Supplies up to four elements loaded from matrix A and matrix A - plus one row. - - vr8-vr15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockSseBy16 RowCount, VectorOffset, Shuffle - vld $vr4, $a1, \VectorOffset - vld $vr5, $a1, \VectorOffset + 16 - vreplvei.w $vr2, $vr0, \Shuffle -.if \RowCount\() == 2 - vreplvei.w $vr3, $vr1, \Shuffle - vmove $vr6, $vr4 - vmove $vr7, $vr5 -.endif - vfmadd.s $vr8, $vr4, $vr2, $vr8 - vfmadd.s $vr9, $vr5, $vr2, $vr9 -.if \RowCount\() == 2 - vfmadd.s $vr12, $vr6, $vr3, $vr12 - vfmadd.s $vr13, $vr7, $vr3, $vr13 -.endif - vld $vr4, $a1, \VectorOffset + 32 - vld $vr5, $a1, \VectorOffset + 48 -.if \RowCount\() == 2 - vmove $vr6, $vr4 - vmove $vr7, $vr5 -.endif - vfmadd.s $vr10, $vr4, $vr2, $vr10 - vfmadd.s $vr11, $vr5, $vr2, $vr11 -.if \RowCount\() == 2 - vfmadd.s $vr14, $vr6, $vr3, $vr14 - vfmadd.s $vr15, $vr7, $vr3, $vr15 -.endif - .endm - - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - a0 - Supplies the address of matrix A. - - a1 - Supplies the address of matrix B. - - t8 - Supplies the address of matrix A. - - a5 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - a2 - Supplies the address of matrix C. - - a3 - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - t7 - Supplies the length in bytes of a row from matrix A. - - t5 - Supplies the length in bytes of a row from matrix C. - - s3 - Stores the ZeroMode argument from the stack frame. 
- ---*/ - - .macro ProcessCountM RowCount, Fallthrough -.LProcessNextColumnLoop16xN\@: - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr8, $vr8,$vr8" - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr9, $vr9,$vr9" - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr10, $vr10,$vr10" - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr11, $vr11,$vr11" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr12, $vr12,$vr12" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr13, $vr13,$vr13" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr14, $vr14,$vr14" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr15, $vr15,$vr15" - move $t8, $a3 - li.d $s0, 4 - blt $t8, $s0, .LProcessRemaining16xNBlocks\@ -.LCompute16xNBlockBy4Loop\@: - EmitIfCountGE \RowCount\(), 1, "vld $vr0, $a0, 0" - EmitIfCountGE \RowCount\(), 2, "vldx $vr1, $a0, $t0" #second line of A - ComputeBlockSseBy16 2, 0, 0x0 - ComputeBlockSseBy16 2, 16*4, 0x1 - addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns - ComputeBlockSseBy16 2, 0, 0x2 - ComputeBlockSseBy16 2, 16*4, 0x3 - addi.d $a1, $a1, 32*4 # advance matrix B by 32 columns - addi.d $a0, $a0, 4*4 # advance matrix A by 4 columns - addi.d $t8, $t8, -4 - li.d $s0, 4 #check matrix A remaining less than 4 - bge $t8, $s0, .LCompute16xNBlockBy4Loop\@ - -.LProcessRemaining16xNBlocks\@: - beqz $t8, .LOutput16xNBlock\@ - -.LCompute16xNBlockBy1Loop\@: - EmitIfCountGE \RowCount\(), 1, "ld.w $s0, $a0, 0" - EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.w $vr0, $s0, 0" - EmitIfCountGE \RowCount\(), 2, "ldx.w $s0,$a0, $t0" - EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.w $vr1,$s0, 0" - ComputeBlockSseBy16 2, 0, 0x00 - addi.d $a1, $a1, 16*4 #advance matrix B by 16 columns - addi.d $a0, $a0, 1*4 #advance matrix A by 1 column - addi.d $t8, $t8, -1 - bnez $t8, .LCompute16xNBlockBy1Loop\@ - -.LOutput16xNBlock\@: - movfr2gr.s $s0, $f24 - vreplgr2vr.w $vr2, $s0 - EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr8,$vr8,$vr2" - # multiply by alpha - EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr9,$vr9,$vr2" - EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr10,$vr10,$vr2" - EmitIfCountGE \RowCount\(), 1, "vfmul.s $vr11,$vr11,$vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr12,$vr12,$vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr13,$vr13,$vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr14,$vr14,$vr2" - EmitIfCountGE \RowCount\(), 2, "vfmul.s $vr15,$vr15,$vr2" - li.d $s0, 16 - blt $a5, $s0, .LOutputPartial16xNBlock\@ - sub.d $a5, $a5, $s0 - AccumulateAndStoreBlock \RowCount\(), 4 - addi.d $a2, $a2, 16*4 # advance matrix C by 16 columns - move $a0, $t1 # reload matrix A - bnez $a5, .LProcessNextColumnLoop16xN\@ - b .LExitKernel - -// -// Output a partial 16xN block to the matrix. 
-// - -.LOutputPartial16xNBlock\@: - li.d $s0, 4 - blt $a5, $s0, .LOutputPartialLessThan4xNBlock\@ - li.d $s0, 8 - blt $a5, $s0, .LOutputPartialLessThan8xNBlock\@ - li.d $s0, 12 - blt $a5, $s0, .LOutputPartialLessThan12xNBlock\@ - AccumulateAndStoreBlock \RowCount\(), 3 - andi $a5, $a5, 3 - beqz $a5, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr11" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr15" - addi.d $a2, $a2,12*4 # advance matrix C by 12 columns - b .LOutputPartialLessThan4xNBlock\@ - -.LOutputPartialLessThan12xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 2 - andi $a5, $a5, 3 - beqz $a5, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr10" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr14" - addi.d $a2, $a2,8*4 # advance matrix C by 8 columns - b .LOutputPartialLessThan4xNBlock\@ - -.LOutputPartialLessThan8xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 1 - andi $a5, $a5, 3 - beqz $a5, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vmove $vr8, $vr9" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "vmove $vr12, $vr13" - addi.d $a2, $a2, 4*4 # advance matrix C by 4 columns - -.LOutputPartialLessThan4xNBlock\@: - andi $s0, $a5, 2 - beqz $s0, .LOutputPartial1xNBlock\@ - and $s0, $t5, $t5 # ZeroMode? - bnez $s0, .LSkipAccumulateOutput2xN\@ - EmitIfCountGE \RowCount\(), 1, "vxor.v $vr0, $vr0, $vr0" - EmitIfCountGE \RowCount\(), 1, "ld.d $s0, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "vinsgr2vr.d $vr0, $s0, 0" - EmitIfCountGE \RowCount\(), 2, "vxor.v $vr1, $vr1, $vr1" - EmitIfCountGE \RowCount\(), 2, "ldx.d $s0, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "vinsgr2vr.d $vr1, $s0, 0" - EmitIfCountGE \RowCount\(), 1, "vfadd.s $vr8, $vr8, $vr0" - EmitIfCountGE \RowCount\(), 2, "vfadd.s $vr12, $vr12, $vr1" - -.LSkipAccumulateOutput2xN\@: - EmitIfCountGE \RowCount\(), 1, "vstelm.d $vr8, $a2, 0, 0" - EmitIfCountGE \RowCount\(), 2, "vpickve2gr.d $s0, $vr12, 0" - EmitIfCountGE \RowCount\(), 2, "stx.d $s0, $a2, $t6" - andi $s0, $a5, 1 - beqz $s0, .LExitKernel - EmitIfCountGE \RowCount\(), 1, "vpermi.w $vr8, $vr8, 0xee" - # shift third element down - EmitIfCountGE \RowCount\(), 2, "vpermi.w $vr12, $vr12, 0xee" - addi.d $a2, $a2, 2*4 # advance matrix C by 2 columns - -.LOutputPartial1xNBlock\@: - and $s0, $t5, $t5 # ZeroMode? - bnez $s0, .LSkipAccumulateOutput1xN\@ - - EmitIfCountGE \RowCount\(), 1, "fld.s $f16, $a2, 0" - EmitIfCountGE \RowCount\(), 1, "fadd.s $f8, $f16, $f8" - EmitIfCountGE \RowCount\(), 2, "fldx.s $f17, $a2, $t6" - EmitIfCountGE \RowCount\(), 2, "fadd.s $f12, $f12, $f17" - -.LSkipAccumulateOutput1xN\@: - EmitIfCountGE \RowCount\(), 1, "fst.s $f8, $a2, 0" - EmitIfCountGE \RowCount\(), 2, "fstx.s $f12, $a2, $t6" -.ifb \Fallthrough\() - b .LExitKernel -.endif - .endm - -// -// Generate the GEMM kernel. -// - -FgemmKernelLsxFunction MlasGemmFloatKernelLSX - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S deleted file mode 100644 index cd1747745d2a4..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4LSX.S +++ /dev/null @@ -1,89 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. 
- -Module Name: - - SgemmTransposePackB16x4LSX.s - -Abstract: - - This module implements routines for packing buffers for the single precision - matrix/matrix multiply operation (SGEMM). - - This implementation uses Lsx instructions. - ---*/ - -#include "asmmacro.h" - - .text - -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 - rows in the destination packed buffer. - -Arguments: - - D (a0) - Supplies the address of the destination packed buffer. - - B (a1) - Supplies the address of the source matrix. - - ldb (a2) - Supplies the number of elements per row of the source matrix. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasSgemmTransposePackB16x4LSX - addi.d $sp, $sp, -64 - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - slli.d $a2, $a2, 2 # convert ldb to bytes - ori $a3, $zero, 4 # transpose four 4x4 blocks - vxor.v $vr7, $vr7, $vr7 -.LTransposeBlockLoop: - slli.d $s0, $a2, 1 - add.d $s1, $a1, $s0 - vld $vr0, $a1, 0 - vldx $vr1, $a1, $a2 - vld $vr2, $s1, 0 - vldx $vr3, $s1, $a2 - - vor.v $vr4, $vr0, $vr7 - vilvl.w $vr4, $vr1, $vr4 - vilvh.w $vr0, $vr1, $vr0 - vor.v $vr5, $vr2, $vr7 - vilvl.w $vr5, $vr3, $vr5 - vilvh.w $vr2, $vr3, $vr2 - vor.v $vr1, $vr4, $vr7 - vilvl.d $vr1, $vr5, $vr1 - vilvh.d $vr4, $vr5, $vr4 - vor.v $vr3, $vr0, $vr7 - vilvl.d $vr3, $vr2, $vr3 - vilvh.d $vr0, $vr2, $vr0 - vst $vr1, $a0, 0 - vst $vr4, $a0, 0x40 - vst $vr3, $a0, 0x80 - vst $vr0, $a0, 0xc0 - addi.d $a0, $a0, 0x10 - slli.d $s0, $a2, 1 - add.d $a1, $s0, $s1 - addi.d $a3, $a3, -1 - bnez $a3, .LTransposeBlockLoop - ld.d $s0, $sp, 0*8 - ld.d $s1, $sp, 1*8 - addi.d $sp, $sp, 64 - jr $ra - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S b/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S deleted file mode 100644 index e617419989c4d..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SgemmTransposePackB16x4Lasx.S +++ /dev/null @@ -1,126 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmTransposePackB16x4Lasx.s - -Abstract: - - This module implements routines for packing buffers for the single precision - matrix/matrix multiply operation (SGEMM). - - This implementation uses Lasx instructions. - ---*/ - -#include "asmmacro.h" - - .text - -/*++ - -Macro Description: - - 4 columns of 8 rows from the source matrix are transposed to 8 columns of 4 - rows in the destination packed buffer. - -Arguments: - - StoreOffset - Supplies the relative byte offset into the destination packed - buffer. - -Implicit Arguments: - - a0 - Supplies the address of the destination packed buffer. - - a1 - Supplies the address of the source matrix. - - a2 - Supplies the number of elements per row of the source matrix. - ---*/ - - .macro TransposePackB8x4BlockLasx StoreOffset - -// -// Load 4 columns from 8 rows of the source matrix into the lower and upper -// halves of 4 XR registers. 
-// - - add.d $t0, $a2, $a2 - add.d $t6, $a1, $t0 - vld $vr0, $a1, 0 - vldx $vr1, $a1, $a2 - add.d $t0, $a2, $a2 - add.d $a1, $t6, $t0 - vld $vr2, $t6, 0 - vldx $vr3, $t6, $a2 - add.d $t0, $a2, $a2 - add.d $t6, $a1, $t0 - - vld $vr4, $a1, 0 - xvpermi.q $xr0, $xr4, 0x2 - vldx $vr5, $a1, $a2 - xvpermi.q $xr1, $xr5, 0x2 - vld $vr4, $t6, 0 - xvpermi.q $xr2, $xr4, 0x2 - vldx $vr5, $t6, $a2 - xvpermi.q $xr3, $xr5, 0x2 - -// -// Transpose the lower and upper halves of the 4 XR registers as two 4x4 -// matrices and store the output to the destination packed buffer. -// - - xvilvl.w $xr4, $xr1, $xr0 - xvilvh.w $xr5, $xr1, $xr0 - xvilvl.w $xr0, $xr3, $xr2 - xvilvh.w $xr1, $xr3, $xr2 - xvilvl.d $xr2, $xr0, $xr4 - xvilvh.d $xr3, $xr0, $xr4 - xvst $xr2, $a0, \StoreOffset\() - xvst $xr3, $a0, 0x40+\StoreOffset\() - xvilvl.d $xr0, $xr1, $xr5 - xvilvh.d $xr4, $xr1, $xr5 - xvst $xr0, $a0, 0x80+\StoreOffset\() - xvst $xr4, $a0, 0xc0+\StoreOffset\() - - .endm - -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 - rows in the destination packed buffer. - -Arguments: - - D (a0) - Supplies the address of the destination packed buffer. - - B (a1) - Supplies the address of the source matrix. - - ldb (a2) - Supplies the number of elements per row of the source matrix. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasSgemmTransposePackB16x4Lasx - - slli.d $a2, $a2, 2 # convert ldb to bytes - TransposePackB8x4BlockLasx 0*4 - add.d $t0, $a2, $a2 - add.d $a1, $t0, $t6 - TransposePackB8x4BlockLasx 8*4 - jr $ra - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S deleted file mode 100644 index aaaa3cbf9138d..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SoftmaxKernelLasx.S +++ /dev/null @@ -1,357 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SoftmaxKernelLasx.s - -Abstract: - - This module implements the kernels for the single precision softmax - operation. - - This implementation uses Lasx instructions. - ---*/ - -#include "asmmacro.h" - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to find the maximum value of - the supplied buffer. - -Arguments: - - Input (a0) - Supplies the input buffer. - - N (a1) - Supplies the number of elements to process. - -Return Value: - - Returns the maximum value of the supplied buffer. 
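(Editorial aside: stripped of the LASX unrolling by 32, 8, and single elements, the reduction performed by this kernel is a running maximum seeded with the smallest representable value. A scalar C sketch with illustrative names follows; -FLT_MAX stands in for MlasMinimumF32Value.)

    #include <float.h>
    #include <stddef.h>

    /* Scalar equivalent of the vectorized maximum reduction: returns the
     * seed value when n is zero, matching the kernel's early exit. */
    static float reduce_maximum_f32(const float *input, size_t n)
    {
        float maximum = -FLT_MAX;  /* stand-in for MlasMinimumF32Value */
        for (size_t i = 0; i < n; i++) {
            if (input[i] > maximum)
                maximum = input[i];
        }
        return maximum;
    }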
- ---*/ - - FUNCTION_ENTRY MlasReduceMaximumF32KernelLasx - addi.d $sp, $sp, -32 - - la.global $t0, MlasMinimumF32Value - ld.w $t0, $t0, 0 - xvreplgr2vr.w $xr0, $t0 - beqz $a1, .LReduceMaximum.ExitKernel - ori $t0, $zero, 8 - bltu $a1, $t0, .LReduceMaximum.ProcessRemainingCountBy1 - ori $t1, $zero, 32 - bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy8 - xvreplgr2vr.w $xr16, $zero - xvor.v $xr1, $xr0, $xr16 - xvor.v $xr2, $xr0, $xr16 - xvor.v $xr3, $xr0, $xr16 - -.LReduceMaximum.ProcessRemainingCountBy32: - xvld $xr16, $a0, 0 - xvfmax.s $xr0, $xr0, $xr16 - xvld $xr16, $a0, 8*4 - xvfmax.s $xr1, $xr1, $xr16 - addi.d $a1, $a1, -0x20 - xvld $xr16, $a0, 16*4 - xvfmax.s $xr2, $xr2, $xr16 - xvld $xr16, $a0, 24*4 - xvfmax.s $xr3, $xr3, $xr16 - addi.d $a0, $a0, 32*4 # advance input by 32 elements - ori $t1, $zero, 32 - bgeu $a1, $t1, .LReduceMaximum.ProcessRemainingCountBy32 - xvfmax.s $xr0, $xr0, $xr1 - xvfmax.s $xr2, $xr2, $xr3 - xvfmax.s $xr0, $xr0, $xr2 - -.LReduceMaximum.ProcessRemainingCountBy8: - ori $t1, $zero, 8 - bltu $a1, $t1, .LReduceMaximum.ProcessRemainingCountLessThan8 - xvld $xr16, $a0, 0 - xvfmax.s $xr0, $xr0, $xr16 - addi.d $a1, $a1, -8 - addi.d $a0, $a0, 8*4 - b .LReduceMaximum.ProcessRemainingCountBy8 - -.LReduceMaximum.ProcessRemainingCountLessThan8: - xvst $xr0, $sp, 0 - vld $vr1, $sp, 0x10 - vld $vr0, $sp, 0 - vfmax.s $vr0, $vr0, $vr1 - vshuf4i.w $vr1, $vr0, 0xee - vfmax.s $vr0, $vr0, $vr1 - vshuf4i.w $vr1, $vr0, 0x55 - vfmax.s $vr0, $vr0, $vr1 - beqz $a1, .LReduceMaximum.ExitKernel - -.LReduceMaximum.ProcessRemainingCountBy1: - vld $vr16, $a0, 0 - vfmax.s $vr0, $vr0, $vr16 - addi.d $a0, $a0, 4 # advance input by 1 element - addi.d $a1, $a1, -1 - bnez $a1, .LReduceMaximum.ProcessRemainingCountBy1 - -.LReduceMaximum.ExitKernel: - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 - addi.d $sp, $sp, 32 - jr $ra - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to produce the final output for - the softmax operation. - -Arguments: - - Output (a0) - Supplies the output buffer. - - N (a1) - Supplies the number of elements to process. - - Parameters (a2) - Supplies an array containing the scale value. - -Return Value: - - None. 
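The softmax output kernel documented above reduces to a single in-place scaling; a scalar sketch (illustrative only, with the meaning of the scale inferred from the usual softmax formulation rather than stated in the source):

    #include <cstddef>

    // Scalar reference: multiply every output element by Parameters[0], the
    // scale value (for softmax, typically 1 / sum of the exponentials).
    static void ReferenceComputeSoftmaxOutputF32(float* Output, size_t N,
                                                 const float* Parameters)
    {
        const float Scale = Parameters[0];
        for (size_t i = 0; i < N; i++) {
            Output[i] *= Scale;
        }
    }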
- ---*/ - - FUNCTION_ENTRY MlasComputeSoftmaxOutputF32KernelLasx - - ld.w $t0, $a2, 0 - xvreplgr2vr.w $xr4, $t0 - ori $t1, $zero, 0x20 - bltu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeSoftmaxOutput.ProcessRemainingCountBy32: - xvld $xr16, $a0, 0 - xvfmul.s $xr0, $xr4, $xr16 - xvld $xr16, $a0, 8*4 - xvfmul.s $xr1, $xr4, $xr16 - addi.d $a1, $a1, -0x20 - xvld $xr16, $a0, 16*4 - xvfmul.s $xr2, $xr4, $xr16 - xvld $xr16, $a0, 24*4 - xvfmul.s $xr3, $xr4, $xr16 - xvst $xr0, $a0, 0 - xvst $xr1, $a0, 8*4 - xvst $xr2, $a0, 16*4 - xvst $xr3, $a0, 24*4 - addi.d $a0, $a0, 0x80 # advance output by 32 elements - bgeu $a1, $t1, .LComputeSoftmaxOutput.ProcessRemainingCountBy32 - -.LComputeSoftmaxOutput.ProcessRemainingCountBy8: - ori $t2, $zero, 8 - bltu $a1, $t2, .LComputeSoftmaxOutput.ProcessRemainingCountLessThan8 - xvld $xr16, $a0, 0 - xvfmul.s $xr0, $xr4, $xr16 - addi.d $a1, $a1, -8 - xvst $xr0, $a0, 0 - addi.d $a0, $a0, 8*4 # advance output by 8 elements - b .LComputeSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeSoftmaxOutput.ProcessRemainingCountLessThan8: - beqz $a1, .LComputeSoftmaxOutput.ExitKernel - -.LComputeSoftmaxOutput.ProcessRemainingCountBy1: - fld.s $f16, $a0, 0 - fmul.s $f0, $f4, $f16 - fst.s $f0, $a0, 0 - addi.d $a0, $a0, 4 # advance output by 1 element - addi.d $a1, $a1, -1 - bnez $a1, .LComputeSoftmaxOutput.ProcessRemainingCountBy1 - -.LComputeSoftmaxOutput.ExitKernel: - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 - jr $ra - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to produce the final output for - the log softmax operation. - -Arguments: - - Input (a0) - Supplies the output buffer. - - Output (a1) - Supplies the output buffer. - - N (a2) - Supplies the number of elements to process. - - Parameters (a3) - Supplies an array containing the negative maximum and - logarithm values. - -Return Value: - - None. 
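Likewise, the log softmax output kernel applies the negative maximum and the logarithm of the exponential sum as two separate steps; a scalar sketch preserving that order (illustrative only, not part of the patch):

    #include <cstddef>

    // Scalar reference: Output[i] = (Input[i] + NegativeMaximum) - LogSumExp.
    // The subtraction is kept as a second step, matching the assembly's
    // "do as two steps for numeric stability" comments.
    static void ReferenceComputeLogSoftmaxOutputF32(const float* Input, float* Output,
                                                    size_t N, const float* Parameters)
    {
        const float NegativeMaximum = Parameters[0];
        const float LogSumExp = Parameters[1];
        for (size_t i = 0; i < N; i++) {
            float Shifted = Input[i] + NegativeMaximum;
            Output[i] = Shifted - LogSumExp;
        }
    }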
- ---*/ - - FUNCTION_ENTRY MlasComputeLogSoftmaxOutputF32KernelLasx - - ld.w $t0, $a3, 0 - ld.w $t1, $a3, 4 - ori $t2, $zero, 0x20 - xvreplgr2vr.w $xr4, $t0 # broadcast negative minimum value - xvreplgr2vr.w $xr5, $t1 # broadcast log(SumExp) - bltu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeLogSoftmaxOutput.ProcessRemainingCountBy32: - xvld $xr16, $a0, 0 - xvfadd.s $xr0, $xr4, $xr16 - xvld $xr16, $a0, 0x20 - xvfadd.s $xr1, $xr4, $xr16 - addi.d $a2, $a2, -0x20 - xvld $xr16, $a0, 0x40 - xvfadd.s $xr2, $xr4, $xr16 - xvld $xr16, $a0, 0x60 - xvfadd.s $xr3, $xr4, $xr16 - addi.d $a0, $a0, 0x80 # advance input by 32 elements - xvfsub.s $xr0, $xr0, $xr5 # do as two steps for numeric stability - xvfsub.s $xr1, $xr1, $xr5 # do as two steps for numeric stability - xvfsub.s $xr2, $xr2, $xr5 # do as two steps for numeric stability - xvfsub.s $xr3, $xr3, $xr5 # do as two steps for numeric stability - xvst $xr0, $a1, 0 - xvst $xr1, $a1, 0x20 - xvst $xr2, $a1, 0x40 - xvst $xr3, $a1, 0x60 - addi.d $a1, $a1, 0x80 # advance output by 32 elements - bgeu $a2, $t2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy32 - -.LComputeLogSoftmaxOutput.ProcessRemainingCountBy8: - ori $t3, $zero, 8 - bltu $a2, $t3, .LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8 - xvld $xr16, $a0, 0 - xvfadd.s $xr0, $xr4, $xr16 - addi.d $a0, $a0, 0x20 - xvfsub.s $xr0, $xr0, $xr5 - addi.d $a2, $a2, -8 - xvst $xr0, $a1, 0 - addi.d $a1, $a1, 0x20 # advance output by 8 elements - b .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8: - beqz $a2, .LComputeLogSoftmaxOutput.ExitKernel - -.LComputeLogSoftmaxOutput.ProcessRemainingCountBy1: - fld.s $f16, $a0, 0 - fadd.s $f0, $f4, $f16 - - addi.d $a0, $a0, 4 - fsub.s $f0, $f0, $f5 - fst.s $f0, $a1, 0 - - addi.d $a1, $a1, 4 - addi.d $a2, $a2, -1 - bnez $a2, .LComputeLogSoftmaxOutput.ProcessRemainingCountBy1 - -.LComputeLogSoftmaxOutput.ExitKernel: - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 - jr $ra - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S deleted file mode 100644 index 96bda3bb12c6f..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLSX.S +++ /dev/null @@ -1,460 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelLSX.s - -Abstract: - - This module implements the kernels for the single precision pooling - operation. - - This implementation uses LSX instructions. 
- ---*/ - -#define SP_SIZE 32*8 -#define InputBase_arg SP_SIZE+0*8 -#define InputWidth_arg SP_SIZE+1*8 -#define DilatedInputWidth_arg SP_SIZE+2*8 -#define OutputCountLeftPad_arg SP_SIZE+3*8 -#define OutputCount_arg SP_SIZE+4*8 -#define OutputCountRightPad_arg SP_SIZE+5*8 - - .macro FUNCTION_ENTRY FunctionName - - .p2align 4 - .globl \FunctionName\() - .type \FunctionName\(),@function -\FunctionName\(): - - .endm - - - .text - -/*++ - -Macro Description: - - This macro generates code to initialize registers used across the kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - ---*/ - - .macro InitializeKernel PoolingType - -.ifeqs "\PoolingType\()","Maximum" - li.w $s0, 0xFF7FFFFF - vreplgr2vr.w $vr5, $s0 -.endif - -.ifeqs "\PoolingType\()","AverageIncludePad" - vreplgr2vr.w $vr5, $a5 - vffint.s.w $vr5, $vr5 -.endif - - .endm -/*++ - -Macro Description: - - This macro generates the common prologue code for the pooling kernels. - -Arguments: - - PoolingType - Supplies the pooling type string. - ---*/ - - .macro SpoolKernelEntry PoolingType - - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0*8 - st.d $s1, $sp, 1*8 - st.d $s2, $sp, 2*8 - st.d $s3, $sp, 3*8 - st.d $s4, $sp, 4*8 - st.d $ra, $sp, 5*8 - fst.d $f24,$sp, 6*8 - - InitializeKernel \PoolingType\() - # move InputStride to s8 - or $t8, $a4, $r0 - # move StrideWidth to a4 - or $a4, $a2, $r0 - # move DilationWidth to a5 - or $a5, $a3, $r0 - # move Output to a2 - or $a2, $a1, $r0 - - .endm - -/*++ - -Macro Description: - - This macro generates the common epilogue code for the pooling kernels. - -Arguments: - - None. - ---*/ - - .macro SpoolKernelExit - - ld.d $s0, $sp, 0*8 - ld.d $s1, $sp, 1*8 - ld.d $s2, $sp, 2*8 - ld.d $s3, $sp, 3*8 - ld.d $s4, $sp, 4*8 - ld.d $ra, $sp, 5*8 - fld.d $f24,$sp, 6*8 - - addi.d $sp, $sp, SP_SIZE - jr $ra - - .endm - - -/*++ - -Macro Description: - - This macro generates code to clear the pooling intermediates. - - For PoolingType==Maximum, the pooling intermediates are set to the minimum - float value. Otherwise, the pooling intermediates are cleared to zero. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a1 - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - vr0-vr1 - Supplies the pooling intermediates. - - vr2 - Supplies a vector containing the minimum float value broadcasted, - if PoolingType==Maximum. - ---*/ - - .macro ClearBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - vor.v $vr0, $vr5, $vr5 - vor.v $vr1, $vr5, $vr5 -.else - vxor.v $vr0, $vr0, $vr0 - vxor.v $vr1, $vr1, $vr1 -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" - xor $a1, $a1, $a1 # reset valid block counter -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to sample the input buffer and update the pooling - intermediates as appropriate. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a3 - Supplies the address of the input buffer. - - a1 - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - a4 - Supplies the StrideWidth parameter (see function description). - - vr0-vr1 - Supplies the pooling intermediates. 
- ---*/ - - .macro ComputeBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - vld $vr24, $a3, 0 - vfmax.s $vr0, $vr0, $vr24 - vld $vr24, $a3, 16 - vfmax.s $vr1, $vr1, $vr24 -.else - vld $vr24, $a3, 0 - vfadd.s $vr0, $vr0, $vr24 - vld $vr24, $a3, 16 - vfadd.s $vr1, $vr1, $vr24 -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" - # increment valid block counter - addi.d $a1, $a1, 1 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to process and store the pooling intermediates. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a2 - Supplies the address of the output buffer. - - a1 - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - vr0-vr1 - Supplies the pooling intermediates. - - vr5 - Supplies the kernel size computed by InitializeKernel, if - PoolingType=AverageExcludePad, else the actual kernel size, if - PoolingType=AverageIncludePad. - ---*/ - - .macro PostProcessBlock PoolingType, OutputCount - -// -// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -// blocks. -// - -.ifeqs "\PoolingType\()","AverageExcludePad" - # convert valid block counter - vreplgr2vr.w $vr4, $a1 - vffint.s.w $vr4, $vr4 - vfdiv.s $vr0, $vr0, $vr4 - vfdiv.s $vr1, $vr1, $vr4 -.endif - -// -// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -// - -.ifeqs "\PoolingType\()","AverageIncludePad" - vfdiv.s $vr0, $vr0, $vr5 - vfdiv.s $vr1, $vr1, $vr5 -.endif - -// -// Store the output block in the output buffer. -// - - vst $vr0, $a2, 0 - vst $vr1, $a2, 16 - # advance output by 1 nchw8c block - addi.d $a2, $a2, 8*4 - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute pooling for a vector of input blocks - to produce a matrix of output blocks. - - OutputCount=1 generates special case code to handle padding blocks. All - other output counts assume no padding. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a2 - Supplies the address of the output buffer. - - a4 - Supplies the StrideWidth parameter (see function description). - - a5 - Supplies the DilationWidth parameter (see function description). - - s8 - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount - - move $a3, $a0 - move $t1, $a6 - move $t2, $a7 -.if \OutputCount\() == 1 - ld.d $t3, $sp, InputBase_arg - ld.d $t4, $sp, InputWidth_arg - sub.d $t3, $r0, $t3 # keep negative for lea usage below -.endif - ClearBlock \PoolingType\(), \OutputCount\() - beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing - -.L\PoolingType\().\OutputCount\().ProcessNextRow: - or $t6, $t2, $t2 - -.L\PoolingType\().\OutputCount\().ProcessNextColumn: -.if \OutputCount\() == 1 - # (Input - InputBase) >= InputWidth? 
- add.d $t7, $a3, $t3 - bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding -.endif - ComputeBlock \PoolingType\(), \OutputCount\() - -.L\PoolingType\().\OutputCount\().SkipOverPadding: - add.d $a3, $a3, $a5 # advance input by dilation width - # decrement columns remaining - addi.d $t6, $t6, -1 - bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn - add.d $a3, $a3, $t8 # advance input to next row -.if \OutputCount\() == 1 - ld.d $s0, $sp, DilatedInputWidth_arg - # advance input base to next row - sub.d $t3, $t3, $s0 -.endif - addi.d $t1, $t1, -1 - bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow - -.L\PoolingType\().\OutputCount\().HandlePostProcessing: - PostProcessBlock \PoolingType\(), \OutputCount\() - - .endm -/*++ - -Macro Description: - - This macro generates code for the inner pooling kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SpoolKernelFunction PoolingType, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute pooling for the elements of an - output row for a set of filter rows. - -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Output (a1) - Supplies the address of the output buffer. - - StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (a3) - Supplies the length in bytes of the blocked dilation - width. - - InputStride (a4) - Supplies the length in bytes to advance the input buffer to - the next input row. - - ActualKernelSize (a5) - Supplies the size of the kernel based on the original - kernel dimensions, used for PoolingType=AverageIncludePad. - - KernelHeight (a6) - Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth (a7) - Supplies the width of the kernel to apply. - - InputBase (0)- Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth (1*8)- Supplies the length in bytes of the blocked input width. - - DilatedInputWidth (2*8)- Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad (3*8)- Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount (4*8)- Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad (5*8)- Supplies the number of output elements that include - one or more padding elements from the right edge. - -Return Value: - - None. 
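The argument list above is easier to follow next to a scalar model of one output block. The sketch below covers only the Maximum pooling type for a single nchw8c block and is illustrative; the helper name and the use of -FLT_MAX in place of MlasMinimumF32Value are assumptions, while the byte-based padding test mirrors the (Input - InputBase) >= InputWidth comparison in the macro code:

    #include <cfloat>
    #include <cstddef>

    // Scalar model of the inner loop for one nchw8c output block of the
    // Maximum pooling kernel. All strides and widths are byte counts, as in
    // the assembly; positions outside [InputBase, InputBase + InputWidth) are
    // padding and are skipped by the unsigned pointer-difference test.
    static void ReferenceMaxPoolSingleBlock(
        const float* Input, float* Output,
        size_t DilationWidth, size_t InputStride,
        size_t KernelHeight, size_t KernelWidth,
        const float* InputBase, size_t InputWidth, size_t DilatedInputWidth)
    {
        float Maximum[8];
        for (size_t c = 0; c < 8; c++) {
            Maximum[c] = -FLT_MAX;                       // stand-in for MlasMinimumF32Value
        }

        const char* Position = reinterpret_cast<const char*>(Input);
        const char* RowBase = reinterpret_cast<const char*>(InputBase);

        for (size_t kh = 0; kh < KernelHeight; kh++) {
            for (size_t kw = 0; kw < KernelWidth; kw++) {
                if (size_t(Position - RowBase) < InputWidth) {   // not a padding block
                    const float* Block = reinterpret_cast<const float*>(Position);
                    for (size_t c = 0; c < 8; c++) {
                        if (Block[c] > Maximum[c]) {
                            Maximum[c] = Block[c];
                        }
                    }
                }
                Position += DilationWidth;               // advance input by the dilation width
            }
            Position += InputStride;                     // advance input to the next row
            RowBase += DilatedInputWidth;                // advance the valid-region base with it
        }

        for (size_t c = 0; c < 8; c++) {
            Output[c] = Maximum[c];
        }
    }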
- ---*/ - - FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() - SpoolKernelEntry \PoolingType\() - - ld.d $s0, $sp, OutputCountLeftPad_arg - ld.d $s1, $sp, OutputCount_arg - add.d $t0, $s0, $s1 - ld.d $s0, $sp, OutputCountRightPad_arg - add.d $t0, $t0, $s0 - beqz $t0, .L\PoolingType\().ExitKernel - -.L\PoolingType\().ProcessNextOutputCount: - ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 1 - add.d $a0, $a0, $a4 - addi.d $t0, $t0, -1 - bnez $t0, .L\PoolingType\().ProcessNextOutputCount - -.L\PoolingType\().ExitKernel: - SpoolKernelExit - - .endm - -// -// Generate the pooling kernels. -// - - SpoolKernelFunction Maximum, LSX - SpoolKernelFunction AverageExcludePad, LSX - SpoolKernelFunction AverageIncludePad, LSX - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S deleted file mode 100644 index 6e5f0136cd4ab..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasx.S +++ /dev/null @@ -1,238 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelLasx.s - -Abstract: - - This module implements the kernels for the single precision pooling - operation. - - This implementation uses Lasx instructions. - ---*/ - -#include "asmmacro.h" -#include "SpoolKernelLasxCommon.h" - - .text - -/*++ - -Macro Description: - - This macro generates code to initialize registers used across the kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - -Implicit Arguments: - - a5 - Supplies the ActualKernelSize parameter (see function description). - ---*/ - - .macro InitializeKernel PoolingType - -.ifeqs "\PoolingType\()","Maximum" - li.w $s0, 0xFF7FFFFF - xvreplgr2vr.w $xr5, $s0 -.else - xvxor.v $xr5, $xr5, $xr5 -.ifeqs "\PoolingType\()","AverageExcludePad" - move $t6, $a6 - mul.d $t6, $t6, $a7 - xvreplgr2vr.w $xr5, $t6 -.else - xvreplgr2vr.w $xr5, $a5 -.endif - xvffint.s.w $xr5, $xr5 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to clear the pooling intermediates. - - For PoolingType==Maximum, the pooling intermediates are set to the minimum - float value. Otherwise, the pooling intermediates are cleared to zero. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a1 - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - xr0-xr2 - Supplies the pooling intermediates. - - xr5 - Supplies a vector containing the minimum float value broadcasted, - if PoolingType==Maximum. - ---*/ - - .macro ClearBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - EmitIfCountGE \OutputCount\(), 1, "xvor.v $xr0, $xr5, $xr5" - EmitIfCountGE \OutputCount\(), 2, "xvor.v $xr1, $xr5, $xr5" - EmitIfCountGE \OutputCount\(), 3, "xvor.v $xr2, $xr5, $xr5" -.else - EmitIfCountGE \OutputCount\(), 1, "xvxor.v $xr0, $xr0, $xr0" - EmitIfCountGE \OutputCount\(), 2, "xvxor.v $xr1, $xr1, $xr1" - EmitIfCountGE \OutputCount\(), 3, "xvxor.v $xr2, $xr2, $xr2" -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - xor $a1, $a1, $a1 # reset valid block counter -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to sample the input buffer and update the pooling - intermediates as appropriate. 
- -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a3 - Supplies the address of the input buffer. - - a1 - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - a4 - Supplies the StrideWidth parameter (see function description). - - xr0-xr2 - Supplies the pooling intermediates. - ---*/ - - .macro ComputeBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" - EmitIfCountGE \OutputCount\(), 1, "xvfmax.s $xr0, $xr0, $xr16" - EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" - EmitIfCountGE \OutputCount\(), 2, "xvfmax.s $xr1, $xr1, $xr16" - EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" - EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" - EmitIfCountGE \OutputCount\(), 3, "xvfmax.s $xr2, $xr2, $xr16" -.else - EmitIfCountGE \OutputCount\(), 1, "xvld $xr16, $a3, 0" - EmitIfCountGE \OutputCount\(), 1, "xvfadd.s $xr0, $xr0, $xr16" - EmitIfCountGE \OutputCount\(), 2, "xvldx $xr16, $a3, $a4" - EmitIfCountGE \OutputCount\(), 2, "xvfadd.s $xr1, $xr1, $xr16" - EmitIfCountGE \OutputCount\(), 3, "slli.d $s0, $a4, 1" - EmitIfCountGE \OutputCount\(), 3, "xvldx $xr16, $a3, $s0" - EmitIfCountGE \OutputCount\(), 3, "xvfadd.s $xr2, $xr2, $xr16" -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - addi.d $a1, $a1, 1 # increment valid block counter -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to process and store the pooling intermediates. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a2 - Supplies the address of the output buffer. - - a1 - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - xr0-xr2 - Supplies the pooling intermediates. - - xr5 - Supplies the kernel size computed by InitializeKernel, if - PoolingType=AverageExcludePad, else the actual kernel size, if - PoolingType=AverageIncludePad. - ---*/ - - .macro PostProcessBlock PoolingType, OutputCount - -// -// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -// blocks. OutputCount=1 generates code to count the number of blocks accessed by -// ComputeBlock. Other cases use the kernel size computed by InitializeKernel. -// - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - xvxor.v $xr4, $xr4, $xr4 - xvreplgr2vr.w $xr4, $a1 - xvffint.s.w $xr4, $xr4 - xvfdiv.s $xr0, $xr0, $xr4 -.else - EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" - EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" - EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" -.endif -.endif - -// -// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -// - -.ifeqs "\PoolingType\()","AverageIncludePad" - EmitIfCountGE \OutputCount\(), 1, "xvfdiv.s $xr0, $xr0, $xr5" - EmitIfCountGE \OutputCount\(), 2, "xvfdiv.s $xr1, $xr1, $xr5" - EmitIfCountGE \OutputCount\(), 3, "xvfdiv.s $xr2, $xr2, $xr5" -.endif - -// -// Store the output block in the output buffer. 
-// - - EmitIfCountGE \OutputCount\(), 1, "xvst $xr0, $a2, 0" - EmitIfCountGE \OutputCount\(), 2, "xvst $xr1, $a2, 0x20" - EmitIfCountGE \OutputCount\(), 3, "xvst $xr2, $a2, 0x40" - add_immed $a2,\OutputCount\()*8*4 # advance output by N nchw8c blocks - - .endm - -// -// Generate the pooling kernels. -// - - SpoolKernelFunction Maximum, Lasx - SpoolKernelFunction AverageExcludePad, Lasx - SpoolKernelFunction AverageIncludePad, Lasx - - .end diff --git a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h b/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h deleted file mode 100644 index 066c75d34f3f9..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/SpoolKernelLasxCommon.h +++ /dev/null @@ -1,311 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelasxCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision pooling operation for the Lasx kernels. - ---*/ - -// -// Stack frame layout for the pooling kernels. -// - -#define SP_SIZE 8*8 -#define InputBase_arg SP_SIZE+0*8 -#define InputWidth_arg SP_SIZE+1*8 -#define DilatedInputWidth_arg SP_SIZE+2*8 -#define OutputCountLeftPad_arg SP_SIZE+3*8 -#define OutputCount_arg SP_SIZE+4*8 -#define OutputCountRightPad_arg SP_SIZE+5*8 -/*++ - -Macro Description: - - This macro generates the common prologue code for the pooling kernels. - -Arguments: - - PoolingType - Supplies the pooling type string. - ---*/ - - .macro SpoolKernelEntry PoolingType - - addi.d $sp, $sp, -SP_SIZE - st.d $s0, $sp, 0 - st.d $s1, $sp, 1*8 - fst.d $f16, $sp, 2*8 - st.d $ra, $sp, 5*8 - - InitializeKernel \PoolingType\() - move $t8, $a4 - move $a4, $a2 - move $a5, $a3 - move $a2, $a1 - - .endm - -/*++ - -Macro Description: - - This macro generates the common epilogue code for the pooling kernels. - -Arguments: - - None. - ---*/ - - .macro SpoolKernelExit - - ld.d $s0, $sp, 0 - ld.d $s1, $sp, 1*8 - fld.d $f16, $sp, 2*8 - ld.d $ra, $sp, 5*8 - addi.d $sp, $sp, SP_SIZE - jr $ra - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute pooling for a vector of input blocks - to produce a matrix of output blocks. - - OutputCount=1 generates special case code to handle padding blocks. All - other output counts assume no padding. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - a0 - Supplies the address of the input buffer. - - a2 - Supplies the address of the output buffer. - - a4 - Supplies the StrideWidth parameter (see function description). - - a5 - Supplies the DilationWidth parameter (see function description). - - t8 - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount - - move $a3, $a0 - move $t1, $a6 - move $t2, $a7 -.if \OutputCount\() == 1 - ld.d $t3, $sp, InputBase_arg - ld.d $t4, $sp, InputWidth_arg - sub.d $t3, $zero, $t3 -.endif - ClearBlock \PoolingType\(), \OutputCount\() - beqz $t1, .L\PoolingType\().\OutputCount\().HandlePostProcessing - -.L\PoolingType\().\OutputCount\().ProcessNextRow: - move $t6, $t2 - -.L\PoolingType\().\OutputCount\().ProcessNextColumn: -.if \OutputCount\() == 1 - add.d $t7, $a3, $t3 # compute (Input - InputBase) - # (Input - InputBase) >= InputWidth? 
- bgeu $t7, $t4, .L\PoolingType\().\OutputCount\().SkipOverPadding -.endif - ComputeBlock \PoolingType\(), \OutputCount\() - -.L\PoolingType\().\OutputCount\().SkipOverPadding: - add.d $a3, $a3, $a5 # advance input by dilation width - addi.d $t6, $t6, -1 # decrement columns remaining - bnez $t6, .L\PoolingType\().\OutputCount\().ProcessNextColumn - add.d $a3, $a3, $t8 # advance input to next row -.if \OutputCount\() == 1 - ld.d $s0, $sp, DilatedInputWidth_arg - sub.d $t3, $t3, $s0 - # advance input base to next row -.endif - addi.d $t1, $t1, -1 - bnez $t1, .L\PoolingType\().\OutputCount\().ProcessNextRow - -.L\PoolingType\().\OutputCount\().HandlePostProcessing: - PostProcessBlock \PoolingType\(), \OutputCount\() - - .endm -/*++ - -Macro Description: - - This macro generates code for the inner pooling kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SpoolKernelFunction PoolingType, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute pooling for the elements of an - output row for a set of filter rows. - -Arguments: - - Input (a0) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Output (a1) - Supplies the address of the output buffer. - - StrideWidth (a2) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (a3) - Supplies the length in bytes of the blocked dilation - width. - - InputStride (a4) - Supplies the length in bytes to advance the input buffer to - the next input row. - - ActualKernelSize (a5) - Supplies the size of the kernel based on the original - kernel dimensions, used for PoolingType=AverageIncludePad. - - KernelHeight (a6) - Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth (a7)- Supplies the width of the kernel to apply. - - InputBase (sp + 0)- Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth (sp + 0x8)- Supplies the length in bytes of the blocked input width. - - DilatedInputWidth (sp + 0x10)- Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad (sp + 0x18)- Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount (sp + 0x20)- Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad (sp + 0x28)- Supplies the number of output elements that include - one or more padding elements from the right edge. - -Return Value: - - None. 
- ---*/ - - FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() - - SpoolKernelEntry \PoolingType\() - -.L\PoolingType\().ProcessOutputCountLeftPad: - ld.d $t0, $sp, OutputCountLeftPad_arg - - beqz $t0, .L\PoolingType\().ProcessOutputCount - bl MlasPool\PoolingType\()FloatSingle\Isa\() - -.L\PoolingType\().ProcessOutputCount: - ld.d $t0, $sp, OutputCount_arg - li.d $s0, 3 - bltu $t0, $s0, .L\PoolingType\().ProcessRemainingOutputCount - -.L\PoolingType\().ProcessNextOutputCountBy3: - ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 3 - slli.d $s0, $a4, 1 - add.d $t6, $s0, $a4 - add.d $a0, $a0, $t6 # advance input by 3 elements - addi.d $t0, $t0, -3 - li.d $s0, 3 - bgeu $t0, $s0, .L\PoolingType\().ProcessNextOutputCountBy3 - -.L\PoolingType\().ProcessRemainingOutputCount: - -.L\PoolingType\().ProcessOutputCountRightPad: - ld.d $s0, $sp, OutputCountRightPad_arg - add.d $t0, $t0, $s0 - beqz $t0, .L\PoolingType\().ExitKernel - bl MlasPool\PoolingType\()FloatSingle\Isa\() - -.L\PoolingType\().ExitKernel: - xvinsgr2vr.d $xr0, $zero, 2 - xvinsgr2vr.d $xr0, $zero, 3 - xvinsgr2vr.d $xr1, $zero, 2 - xvinsgr2vr.d $xr1, $zero, 3 - xvinsgr2vr.d $xr2, $zero, 2 - xvinsgr2vr.d $xr2, $zero, 3 - xvinsgr2vr.d $xr3, $zero, 2 - xvinsgr2vr.d $xr3, $zero, 3 - xvinsgr2vr.d $xr4, $zero, 2 - xvinsgr2vr.d $xr4, $zero, 3 - xvinsgr2vr.d $xr5, $zero, 2 - xvinsgr2vr.d $xr5, $zero, 3 - xvinsgr2vr.d $xr6, $zero, 2 - xvinsgr2vr.d $xr6, $zero, 3 - xvinsgr2vr.d $xr7, $zero, 2 - xvinsgr2vr.d $xr7, $zero, 3 - xvinsgr2vr.d $xr8, $zero, 2 - xvinsgr2vr.d $xr8, $zero, 3 - xvinsgr2vr.d $xr9, $zero, 2 - xvinsgr2vr.d $xr9, $zero, 3 - xvinsgr2vr.d $xr10, $zero, 2 - xvinsgr2vr.d $xr10, $zero, 3 - xvinsgr2vr.d $xr11, $zero, 2 - xvinsgr2vr.d $xr11, $zero, 3 - xvinsgr2vr.d $xr12, $zero, 2 - xvinsgr2vr.d $xr12, $zero, 3 - xvinsgr2vr.d $xr13, $zero, 2 - xvinsgr2vr.d $xr13, $zero, 3 - xvinsgr2vr.d $xr14, $zero, 2 - xvinsgr2vr.d $xr14, $zero, 3 - xvinsgr2vr.d $xr15, $zero, 2 - xvinsgr2vr.d $xr15, $zero, 3 - SpoolKernelExit - -// -// Generate out-of-band helpers for handling output blocks involving padding. -// - -MlasPool\PoolingType\()FloatSingle\Isa\(): - st.d $ra, $sp, 6*8 -loopMlasPool\PoolingType\()FloatSingle\Isa\(): - ProcessOutputCountN .LSpoolKernelSingleFrame, \PoolingType\(), 1 - add.d $a0, $a0, $a4 # advance input by 1 element - addi.d $t0, $t0, -1 # decrement output count remaining - bnez $t0, loopMlasPool\PoolingType\()FloatSingle\Isa\() - ld.d $ra, $sp, 6*8 - jr $ra - - .endm diff --git a/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h b/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h deleted file mode 100644 index 837aca77dd883..0000000000000 --- a/onnxruntime/core/mlas/lib/loongarch64/asmmacro.h +++ /dev/null @@ -1,144 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - asmmacro.h - -Abstract: - - This module implements common macros for the assembly modules. - ---*/ - -#define C_UNDERSCORE(symbol) symbol - -.macro vmove dst src - vand.v \dst, \src, \src -.endm - -/*++ - -Macro Description: - - This macro emits the assembler directives to annotate a new function. - -Arguments: - - FunctionName - Supplies the name of the function. 
- ---*/ - - .macro FUNCTION_ENTRY FunctionName - .align 2 - .globl \FunctionName\() - .type \FunctionName\(),@function -\FunctionName\(): - - .endm - -/*++ - -Macro Description: - - This macro generates an optimization for "add reg,128" which can instead - be encoded as "sub reg,-128" to reduce code size by using a signed 8-bit - value. - -Arguments: - - Register - Supplies the register to be added to. - - Immediate - Supplies the immediate to add to the register. - ---*/ - - .macro add_immed Register, Immediate - -.if (\Immediate\() != 128) - addi.d \Register\(),\Register\(),\Immediate\() -.else - addi.d \Register\(),\Register\(),\Immediate\() # smaller encoding -.endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count is greater than or - equal to Value. - -Arguments: - - Count - Supplies the variable used in the comparison. - - Value - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCountGE Count1, Value1, Statement - -.if (\Count1\() >= \Value1\()) - \Statement\() -.endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count1 is greater than or - equal to Value1 and Count2 is greater than or equal to Value2. - -Arguments: - - Count1 - Supplies the variable used in the comparison. - - Value1 - Supplies the static used in the comparison. - - Count2 - Supplies the variable used in the comparison. - - Value2 - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement - -.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) - \Statement\() -.endif - - .endm - -/*++ - -Macro Description: - - This macro emits the statement for each register listed in the register - list. The statement can use RegItem to access the current register. - -Arguments: - - RegList - Supplies the list of registers. - - Statement - Supplies the statement to emit. - ---*/ - - .macro EmitForEachRegister RegList, Statement - - .irp RegItem, \RegList\() - \Statement\() - .endr - - .endm diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h deleted file mode 100644 index 13ea8d96c20e4..0000000000000 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ /dev/null @@ -1,2609 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlasi.h - -Abstract: - - This module contains the private data structures and procedure prototypes - for the Microsoft Machine Learning algebra subprogram library. 
- ---*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef MLAS_NO_EXCEPTION -#if defined(__ANDROID__) -#include -#else -#include -#endif -#endif // MLAS_NO_EXCEPTION - -#include "mlas.h" - -#if defined(_WIN32) -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#include -#else -#if defined(__arm__) || defined(__aarch64__) -#include -#endif -#if defined(__x86_64__) || defined(__i386__) -#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)//workaround for Bug 96238 - [i386] cpuid.h header needs include guards -#include -#endif -#if defined(__GNUC__) && __GNUC__ >= 12 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h. -#include -#pragma GCC diagnostic pop -#else -#include -#endif -#endif -#if defined(__VSX__) -#include -// Undefine unwanted aliases from altivec.h. -#undef vector -#undef pixel -#undef bool -#endif -#if defined(__loongarch64) -#include -#endif -#if defined(MLAS_TARGET_WASM_SIMD) -#include -#endif -#endif - -// -// Macro to place variables at a specified alignment. -// - -#ifdef _WIN32 -#define MLAS_DECLSPEC_ALIGN(variable, alignment) DECLSPEC_ALIGN(alignment) variable -#else -#define MLAS_DECLSPEC_ALIGN(variable, alignment) variable __attribute__ ((aligned(alignment))) -#endif - -// -// Macro to force inline expansion of a function. -// - -#if defined(_MSC_VER) -#define MLAS_FORCEINLINE __forceinline -#else -#define MLAS_FORCEINLINE __attribute__ ((always_inline)) inline -#endif - -// -// Macro to tag globals as internal data shared with kernels written in -// assembly. These globals are marked with having hidden visibility to avoid -// needing to access the data through the global object table. -// - -#if defined(_MSC_VER) -#define MLAS_INTERNAL_DATA extern "C" -#else -#define MLAS_INTERNAL_DATA extern "C" __attribute ((visibility("hidden"))) -#endif - -// -// Macro to suppress unreferenced parameter warnings. -// - -#define MLAS_UNREFERENCED_PARAMETER(parameter) ((void)(parameter)) - -#ifdef MLAS_NO_EXCEPTION - -MLAS_FORCEINLINE -void -MlasPrintFinalMessage(const std::string& msg) -{ -#if defined(__ANDROID__) - __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str()); -#else - // TODO, consider changing the output of the error message from std::cerr to logging when the - // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr - // output might not be easily accesible on some systems such as mobile - // TODO, see if we need to change the output of the error message from std::cerr to NSLog for - // iOS - std::cerr << msg << std::endl; -#endif -} - -#define MLAS_THROW_EX(ex, what) \ - do { \ - std::string msg = #ex; \ - msg.append(what); \ - MlasPrintFinalMessage(msg); \ - abort(); \ - } while (false) - -#else - -#define MLAS_THROW_EX(ex, ...) throw ex(__VA_ARGS__) - -#endif // MLAS_NO_EXCEPTION - -// -// Select the threading model. -// -// N.B. BUILD_MLAS_NO_ONNXRUNTIME is used to build MLAS test code outside -// of the ONNX Runtime source tree. OpenMP may or may not be enabled in this -// configuration. 
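A hedged usage sketch of the MLAS_THROW_EX macro defined above (assuming that definition is in scope; the function name and message are hypothetical): with MLAS_NO_EXCEPTION defined the macro logs the exception name plus message and calls abort(), otherwise it expands to a normal C++ throw.

    #include <stdexcept>
    #include <cstddef>

    // Hypothetical caller showing the intended error-reporting pattern.
    static void CheckedElementCount(size_t N)
    {
        if (N == 0) {
            // Becomes either a logged abort() or throw std::invalid_argument(...).
            MLAS_THROW_EX(std::invalid_argument, "element count must be non-zero");
        }
    }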
-// - -#if !defined(BUILD_MLAS_NO_ONNXRUNTIME) -#include "core/platform/threadpool.h" - -#include "core/common/cpuid_info.h" -using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo; - -#include "core/framework/float16.h" - -#else // BUILD_MLAS_NO_ONNXRUNTIME - -class MLASCPUIDInfo -{ - public: - static const MLASCPUIDInfo& GetCPUIDInfo() - { - static MLASCPUIDInfo cpuid_info; - return cpuid_info; - } - - // ARM - bool HasArmNeonDot() const { return has_arm_neon_dot_; } - - bool HasFp16VectorAcceleration() const { return has_fp16_; } - - uint32_t GetCurrentCoreIdx() const { return 0xFFFFFFFF; } - - int32_t GetCurrentUarch() const { return -1; } - - int32_t GetCoreUarch(uint32_t coreId) const { return -1; } - - bool IsCoreArmv8NarrowLd(uint32_t coreId) const { return false; } - - bool IsCurrentCoreArmv8NarrowLd() const { return false; } - - bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } - - bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } - - bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } - - private: - MLASCPUIDInfo(); - - bool has_arm_neon_dot_{false}; - bool has_fp16_{false}; - bool has_arm_neon_i8mm_{false}; - bool has_arm_sve_i8mm_{false}; - bool has_arm_neon_bf16_{false}; -}; -using MLAS_CPUIDINFO = MLASCPUIDInfo; - -#if defined(MLAS_TARGET_ARM64) -/** - * @brief IDs for cpu microarchitectures. - * - * Copied from python cpuinfo package. Can't use the definition - * from cpuinfo directly as it causes lots of compilation issues - * in many platforms that we support. - */ -enum MlasUArch { - cpuinfo_uarch_unknown = 0, - - /** ARM Cortex-A32. */ - cpuinfo_uarch_cortex_a32 = 0x00300332, - /** ARM Cortex-A35. */ - cpuinfo_uarch_cortex_a35 = 0x00300335, - /** ARM Cortex-A53. */ - cpuinfo_uarch_cortex_a53 = 0x00300353, - /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */ - cpuinfo_uarch_cortex_a55r0 = 0x00300354, - /** ARM Cortex-A55. */ - cpuinfo_uarch_cortex_a55 = 0x00300355, - /** ARM Cortex-A57. */ - cpuinfo_uarch_cortex_a57 = 0x00300357, - /** ARM Cortex-A65. */ - cpuinfo_uarch_cortex_a65 = 0x00300365, - /** ARM Cortex-A72. */ - cpuinfo_uarch_cortex_a72 = 0x00300372, - /** ARM Cortex-A73. */ - cpuinfo_uarch_cortex_a73 = 0x00300373, - /** ARM Cortex-A75. */ - cpuinfo_uarch_cortex_a75 = 0x00300375, - /** ARM Cortex-A76. */ - cpuinfo_uarch_cortex_a76 = 0x00300376, - /** ARM Cortex-A77. */ - cpuinfo_uarch_cortex_a77 = 0x00300377, - /** ARM Cortex-A78. */ - cpuinfo_uarch_cortex_a78 = 0x00300378, -}; - -#endif // MLAS_TARGET_ARM64 - -// -// Define MLAS_FP16 -// -#include "mlas_float16.h" - -namespace onnxruntime -{ -struct MLFloat16 { - uint16_t val{0}; - - MLFloat16() = default; - explicit constexpr MLFloat16(uint16_t x) : val(x) {} - explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {} - - float ToFloat() const { return MLAS_Half2Float(val); } - - operator float() const { return ToFloat(); } - - MLFloat16& operator=(float ff) - { - val = MLAS_Float2Half(ff); - return *this; - } -}; - -inline bool -operator==(const MLFloat16& left, const MLFloat16& right) -{ - return left.val == right.val; -} - -inline bool -operator!=(const MLFloat16& left, const MLFloat16& right) -{ - return left.val != right.val; -} - -} - -#endif // BUILD_MLAS_NO_ONNXRUNTIME - -static_assert(sizeof(MLAS_FP16) == FP16_SIZE); - - -// -// Define the maximum number of threads supported by this implementation. -// - -#define MLAS_MAXIMUM_THREAD_COUNT 16 - -// -// Define the default strides to step through slices of the input matrices. 
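A hedged usage sketch of the fallback onnxruntime::MLFloat16 wrapper above (assuming the surrounding mlasi.h definitions, including mlas_float16.h, are in scope; the function is hypothetical): values convert through MLAS_Float2Half / MLAS_Half2Float and compare by their raw 16-bit pattern.

    // Hypothetical round-trip through the BUILD_MLAS_NO_ONNXRUNTIME MLFloat16.
    static bool Float16RoundTripExample()
    {
        onnxruntime::MLFloat16 h(1.5f);      // encoded with MLAS_Float2Half
        float f = h;                         // decoded with MLAS_Half2Float
        onnxruntime::MLFloat16 g(h.val);     // copy the raw bit pattern
        return (h == g) && (f == 1.5f);      // 1.5 is exactly representable in fp16
    }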
-// - -#define MLAS_SGEMM_STRIDEN 128 -#define MLAS_SGEMM_STRIDEK 128 -#define MLAS_SGEMM_PACKED_STRIDEN 128 -#define MLAS_SGEMM_PACKED_STRIDEK 256 -#define MLAS_DGEMM_STRIDEN 64 -#define MLAS_DGEMM_STRIDEK 128 - -// -// Define the alignment for segmenting a GEMM operation across multiple -// threads. -// -// All of the SGEMM kernels can efficiently handle 16 elements. AVX512F can -// efficiently handle 32 elements, but making this value dynamic is not worth -// the effort at this time. -// - -#define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 -#define MLAS_DGEMM_STRIDEN_THREAD_ALIGN 8 -#define MLAS_QGEMM_STRIDEN_THREAD_ALIGN 16 - -// -// Define the prototypes of the platform optimized routines. -// - -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_TARGET_LARCH64) - -typedef -size_t -(MLASCALL MLAS_GEMM_FLOAT_KERNEL)( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha, - bool ZeroMode - ); - -typedef -size_t -(MLASCALL MLAS_GEMM_DOUBLE_KERNEL)( - const double* A, - const double* B, - double* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - double alpha, - bool ZeroMode - ); - -#else - -#if defined(__aarch64__) && defined(__linux__) -typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( - const float* A, - const bfloat16_t* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - const float* Bias -); -#endif - -typedef -size_t -(MLASCALL MLAS_GEMM_FLOAT_KERNEL)( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ); - -typedef -size_t -(MLASCALL MLAS_GEMM_DOUBLE_KERNEL)( - const double* A, - const double* B, - double* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - double alpha - ); - -#endif - -typedef -void -(MLASCALL MLAS_GEMV_FLOAT_KERNEL)( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t ldb, - bool ZeroMode - ); - -typedef -void -(MLASCALL MLAS_SGEMM_KERNEL_M1_ROUTINE)( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t ldb, - float beta - ); - -typedef -void -(MLASCALL MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE)( - float* D, - const float* B, - size_t ldb - ); - -typedef -size_t -(MLASCALL MLAS_GEMM_U8S8_KERNEL)( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode - ); - -typedef -size_t -(MLASCALL MLAS_GEMV_U8S8_KERNEL)( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t CountK, - size_t CountN, - size_t ldb - ); - -typedef -size_t -(MLASCALL MLAS_GEMM_U8U8_KERNEL)( - const int16_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode - ); - -typedef -void -(MLASCALL MLAS_CONV_FLOAT_KERNEL)( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t FilterCount, - size_t InputStride, - size_t FilterStride, - size_t OutputStride, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - 
size_t OutputCount, - size_t OutputCountRightPad, - const float* Bias, - unsigned KernelFlags - ); - -typedef -void -(MLASCALL MLAS_CONV_DEPTHWISE_FLOAT_KERNEL)( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t InputStride, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad, - const float* Bias, - unsigned KernelFlags - ); - -typedef -void -(MLASCALL MLAS_CONV_POINTWISE_FLOAT_KERNEL)( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t InputChannels, - size_t FilterCount, - size_t InputStride, - size_t FilterStride, - size_t OutputStride, - size_t OutputCount, - const float* Bias, - unsigned KernelFlags - ); - -typedef -void -(MLASCALL MLAS_POOL_FLOAT_KERNEL)( - const float* Input, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t InputStride, - size_t ActualKernelSize, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad - ); - -typedef -void -(MLASCALL MLAS_COMPUTE_UNARY_FLOAT_KERNEL)( - const float* Input, - float* Output, - size_t N - ); - -typedef -float -(MLASCALL MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL)( - const float* Input, - float* Output, - size_t N, - const float* NegativeMaximum - ); - -typedef -void -(MLASCALL MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL)( - float* Output, - size_t N, - const float* Parameters - ); - -typedef -void -(MLASCALL MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL)( - const float* Input, - float* Output, - size_t N, - const float* Parameters - ); - -typedef -float -(MLASCALL MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL)( - const float* Input, - size_t N - ); - -typedef -void -(MLASCALL MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL)( - const float* Input, - float* Min, - float* Max, - size_t N - ); - -typedef -void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( - const unsigned short* Source, - float* Destination, - size_t Count -); - -typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( - const float* Source, - unsigned short* Destination, - size_t Count -); - -typedef -void -(MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( - const int8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const int8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - int8_t* OutputC, - size_t N, - bool IsScalarB - ); - -typedef -void -(MLASCALL MLAS_QLINEAR_BINARY_OP_U8_KERNEL)( - const uint8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const uint8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - uint8_t* OutputC, - size_t N, - bool IsScalarB - ); - -typedef -void -(MLASCALL MLAS_QUANTIZE_LINEAR_U8_KERNEL)( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ); - -typedef -void -(MLASCALL MLAS_QUANTIZE_LINEAR_S8_KERNEL)( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ); - -typedef -void -(MLASCALL MLAS_QUANTIZE_LINEAR_U16_KERNEL)( - const float* Input, - uint16_t* Output, - size_t N, - float Scale, - uint16_t ZeroPoint); - -typedef -void -(MLASCALL MLAS_QUANTIZE_LINEAR_S16_KERNEL)( - const float* Input, - int16_t* Output, - size_t N, - float Scale, - int16_t ZeroPoint); - -typedef -void -(MLASCALL MLAS_QUANTIZE_LINEAR_U4_KERNEL)( - 
const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint); - -typedef -void -(MLASCALL MLAS_QUANTIZE_LINEAR_S4_KERNEL)( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint); - -template -struct MLAS_QUANT_KERNEL -{ - typedef - void - (MLASCALL DepthwiseKernel)( - const InputType* const* Input, - InputType InputZeroPoint, - const FilterType* Filter, - FilterType FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); -}; - -extern "C" { - -#if defined(MLAS_TARGET_AMD64_IX86) - MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelSse; - MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx; -#if defined(MLAS_TARGET_AMD64) - MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelFma3; - MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx512F; - MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelSse; - MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelAvx; - MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelFma3; - MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelAvx512F; -#endif -#elif defined(MLAS_TARGET_POWER) - MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel; - MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelPOWER10; - MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernel; - MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10; - MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelVSX; - MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelVSX; -#elif defined(MLAS_TARGET_LARCH64) - MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLSX; - MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLasx; - MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLSX; - MLAS_GEMM_DOUBLE_KERNEL MlasGemmDoubleKernelLasx; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLSX; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLSX; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLSX; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLSX; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelLasx; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelLasx; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelLasx; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelLasx; - MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLSX; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLSX; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLSX; - MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelLasx; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelLasx; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelLasx; - MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4LSX; - MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Lasx; - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelLasx; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelLasx; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelLasx; -#else - MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; - MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; -#if defined(__aarch64__) && defined(__linux__) - MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; - MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; -#endif - MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; - MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; -#endif - -#if defined(MLAS_TARGET_AMD64) - MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1Avx; - MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1TransposeBAvx; -#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_WASM) - MLAS_GEMV_FLOAT_KERNEL MlasGemvFloatKernel; -#endif - -#if defined(MLAS_TARGET_AMD64) - 
MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Sse; - MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE MlasSgemmTransposePackB16x4Avx; -#endif - -#if defined(MLAS_TARGET_AMD64) - MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx2; - MLAS_GEMV_U8S8_KERNEL MlasGemvU8S8KernelAvx2; - MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512Core; - MLAS_GEMV_U8S8_KERNEL MlasGemvU8S8KernelAvx512Core; - MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvx512Vnni; - MLAS_GEMV_U8S8_KERNEL MlasGemvU8S8KernelAvx512Vnni; - MLAS_GEMM_U8S8_KERNEL MlasGemmU8S8KernelAvxVnni; - MLAS_GEMV_U8S8_KERNEL MlasGemvU8S8KernelAvxVnni; - MLAS_GEMM_U8S8_KERNEL MlasGemmU8U8KernelAvx2Vnni; - MLAS_GEMM_U8S8_KERNEL MlasGemmS8S8KernelAvx2Vnni; - MLAS_GEMM_U8S8_KERNEL MlasGemmS8U8KernelAvx2Vnni; - MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx2; - MLAS_GEMM_U8U8_KERNEL MlasGemmU8U8KernelAvx512Core; -#endif - -#if defined(MLAS_TARGET_AMD64) - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelSse; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelSse; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelSse; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelSse; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelAvx; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelAvx; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelAvx; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelAvx; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelFma3; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelFma3; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelFma3; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelFma3; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelAvx512F; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelAvx512F; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelAvx512F; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelAvx512F; - MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelSse; - MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelAvx; - MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelAvx512F; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelSse; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelAvx; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelAvx512F; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelSse; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelAvx; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelAvx512F; -#else - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernel; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernel; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernel; - MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernel; - MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernel; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernel; - MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernel; -#endif - - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32Kernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasTanhKernel; - MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL MlasComputeSumExpF32Kernel; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32Kernel; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32Kernel; - MLAS_QLINEAR_BINARY_OP_S8_KERNEL MlasQLinearAddS8Kernel; - MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8Kernel; - MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8Kernel; - MLAS_QUANTIZE_LINEAR_U8_KERNEL 
MlasQuantizeLinearU8Kernel; - MLAS_QUANTIZE_LINEAR_S16_KERNEL MlasQuantizeLinearS16Kernel; - MLAS_QUANTIZE_LINEAR_U16_KERNEL MlasQuantizeLinearU16Kernel; - MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; - MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; -#if defined(MLAS_TARGET_AMD64) - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeLogisticF32KernelFma3; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeTanhF32KernelFma3; - MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL MlasComputeSumExpF32KernelFma3; - MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL MlasComputeSumExpF32KernelAvx512F; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelAvx; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelAvx; - MLAS_QLINEAR_BINARY_OP_S8_KERNEL MlasQLinearAddS8KernelAvx2; - MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; - MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; - MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; -#endif - - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; - MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel; -#if defined(MLAS_TARGET_AMD64) - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx; - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx512F; - MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx; -#endif - -#if defined(MLAS_TARGET_AMD64) - MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; - MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; - MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2; - MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; -#endif - -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelNeon; - MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelNeon; -#endif -} - -// -// Define the default preferred byte alignment for buffers. -// -// MLAS_TARGET_AMD64_IX86: The typical architecture uses AVX instructions -// accessing 256-bit vectors. MLAS_TARGET_AMD64 returns a larger value if the -// platform supports 512-bit vectors to ensure that vectors are not split. -// -// MLAS_TARGET_ARM64: The kernels use "load pair" instructions to access 128-bit -// vectors, so this value keeps both vectors in the same cache line. -// -// MLAS_TARGET_ARM: Using 16 for a single 128-bit vector may be sufficient for -// this architecture, but the ONNX Runtime has historically used this larger -// value. -// - -#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 64 - -// -// Define the target number of per-thread multiplies before using another -// thread to perform additional work. -// - -#define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) -#define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) -#define MLAS_QGEMM_THREAD_COMPLEXITY 65536 - -#if defined(__aarch64__) && defined(__linux__) -#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) -#endif - -// -// Single-threaded single precision matrix/matrix multiply operation. 
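The *_THREAD_COMPLEXITY constants above are per-thread multiply budgets: a GEMM driver only fans out across threads once the problem is large enough to give every thread at least that many multiply-accumulates. A hedged sketch of the sizing rule (the real logic lives in the GEMM drivers such as sgemm.cpp; this only illustrates the intent and assumes the mlasi.h context):

    #include <algorithm>
    #include <cstddef>

    // Pick a thread count for an M x N x K SGEMM so each thread gets at least
    // MLAS_SGEMM_THREAD_COMPLEXITY multiplies. Sketch only.
    std::ptrdiff_t TargetSgemmThreadCount(size_t M, size_t N, size_t K,
                                          std::ptrdiff_t MaximumThreadCount)
    {
        const double Complexity = double(M) * double(N) * double(K);
        const std::ptrdiff_t TargetThreadCount =
            std::ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
        return std::min(TargetThreadCount, MaximumThreadCount);
    }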
-// - -void -MlasSgemmOperation( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - float beta, - float* C, - size_t ldc - ); - -// -// Quantized integer matrix/matrix dispatch structure. -// - -struct MLAS_GEMM_QUANT_DISPATCH; - -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAvx2; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2Vnni; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchAvx2Vnni; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchAvx2Vnni; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; -extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; - -// -// Symmetric quantized qgemm dispatch structure -// -struct MLAS_SYMM_QGEMM_DISPATCH; -extern const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon; -extern const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchSdot; - -// -// Symmetric quantized integer convolution dispatch structure. -// - -struct MLAS_CONV_SYM_DISPATCH; - -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx2; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvxVnni; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Core; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Vnni; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchNeon; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; - -// -// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure. -// - -struct MLAS_Q8Q4GEMM_DISPATCH; - -extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; - -// -// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure. -// - -struct MLAS_FPQ4GEMM_DISPATCH; - -extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; - -// -// Float/quantized n-bit integer matrix/matrix multiply dispatch structure. -// - -struct MLAS_SQNBIT_GEMM_DISPATCH; - -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; - -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; - -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; - -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; - -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; - -// -// Quantized depthwise convolution kernels. 
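Before the MlasConvDepthwiseKernel templates that follow, here is a scalar sketch of what these depthwise kernels compute. It assumes the indirection-buffer layout I read out of the portable kernel: Input carries KernelSize pointers per output position, each pointing at Channels contiguous values, and Filter is stored as [KernelSize][Channels]. Reference only, not the tuned code:

    #include <cstddef>
    #include <cstdint>

    template <typename InputType, typename FilterType>
    void ConvDepthwiseReference(
        const InputType* const* Input, InputType InputZeroPoint,
        const FilterType* Filter, FilterType FilterZeroPoint,
        int32_t* Output, size_t Channels, size_t OutputCount, size_t KernelSize)
    {
        while (OutputCount-- > 0) {
            for (size_t c = 0; c < Channels; c++) {
                int32_t Accumulator = 0;
                for (size_t k = 0; k < KernelSize; k++) {
                    const int32_t InputValue = int32_t(Input[k][c]) - int32_t(InputZeroPoint);
                    const int32_t FilterValue = int32_t(Filter[k * Channels + c]) - int32_t(FilterZeroPoint);
                    Accumulator += InputValue * FilterValue;
                }
                *Output++ = Accumulator;
            }
            Input += KernelSize;   // next group of kernel-position pointers
        }
    }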
-// - -template -void -MLASCALL -MlasConvDepthwiseKernel( - const InputType* const* Input, - InputType InputZeroPoint, - const FilterType* Filter, - FilterType FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernelAvx2( - const InputType* const* Input, - InputType InputZeroPoint, - const FilterType* Filter, - FilterType FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -// -// Define the kernel flags for conv sym -// - -#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 0x00000001 -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 0x00000002 - -// -// Define the post-processing parameters for conv sym: bias and re-quant params -// - -struct MLAS_CONV_SYM_POST_PROCESS_PARAMS { - const int32_t* Bias; - const float* Scale; - float MinimumValue; - float MaximumValue; - int32_t OutputZeroPoint; -}; - -// -// Environment information class. -// - -enum MlasCoreType { mlas_core_unknown = 0, mlas_core_little = 2, mlas_core_big = 3 }; - - -struct MLAS_PLATFORM { - - MLAS_PLATFORM(void); - -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) - MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; -#endif -#if defined(MLAS_TARGET_LARCH64) - const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; - const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; - MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; - MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; - MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; - MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; - MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; - MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; - MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; - uint32_t NchwcBlockSize; -#endif -#if defined(MLAS_TARGET_AMD64_IX86) - const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; - const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; - const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch{&MlasGemmQuantDispatchDefault}; - const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch{&MlasGemmQuantDispatchDefault}; -#elif defined(MLAS_TARGET_ARM64) - const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; - const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; - const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; -#endif - const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; - - const MLAS_CONV_SYM_DISPATCH* ConvSymU8S8Dispatch{nullptr}; - const MLAS_CONV_SYM_DISPATCH* ConvSymS8S8Dispatch{nullptr}; - - MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseU8S8Kernel; - MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseU8U8Kernel; - MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseS8S8Kernel; - MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseS8U8Kernel; - -#if defined(MLAS_TARGET_POWER) - MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; - const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; - MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; - MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; - MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; - MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; - MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; - MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; -#endif -#if defined(MLAS_TARGET_AMD64) - 
MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; - MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; - MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* TransposePackB16x4Routine; - MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; - MLAS_GEMM_U8S8_KERNEL* GemmU8S8Kernel; - MLAS_GEMM_U8S8_KERNEL* GemmS8S8Kernel; - MLAS_GEMM_U8S8_KERNEL* GemmS8U8Kernel; - MLAS_GEMV_U8S8_KERNEL* GemvU8S8Kernel; - MLAS_GEMM_U8U8_KERNEL* GemmU8U8Kernel; - MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; - MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; - MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; - MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ErfKernelRoutine; - MLAS_QLINEAR_BINARY_OP_S8_KERNEL* QLinearAddS8Kernel; - MLAS_QLINEAR_BINARY_OP_U8_KERNEL* QLinearAddU8Kernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ComputeExpF32Kernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* LogisticKernelRoutine; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* TanhKernelRoutine; - MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; - MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel; - MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; - MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; - MLAS_QUANTIZE_LINEAR_S16_KERNEL* QuantizeLinearS16Kernel; - MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; - MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; - MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; - uint32_t NchwcBlockSize; - uint32_t PreferredBufferAlignment; - int32_t MaximumThreadCount; -#elif defined(MLAS_TARGET_ARM64) - static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4; -#else - static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; -#endif - - const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; - const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; - - const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; - - MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; - MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; -}; - -inline -MLAS_PLATFORM& GetMlasPlatform(){ - static MLAS_PLATFORM MlasPlatform; - return MlasPlatform; -} - -// -// Threading support. -// - -typedef -void -(MLAS_THREADED_ROUTINE)( - void* Context, - ptrdiff_t Index - ); - -void -MlasExecuteThreaded( - MLAS_THREADED_ROUTINE* ThreadedRoutine, - void* Context, - ptrdiff_t Iterations, - MLAS_THREADPOOL* ThreadPool - ); - -constexpr -size_t -MlasDivRoundup(size_t up, size_t down) -{ - return (up + down - 1) / down; -} - -/** - * @brief Distribute multiple iterations of work over a thread pool if supported - * - * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP - * @param Iterations [IN] Total number of iterations - * @param Work [IN] Logic for computing a range of iterations [begin, end) - */ -void -MlasTrySimpleParallel( - MLAS_THREADPOOL* ThreadPool, - const std::ptrdiff_t Iterations, - const std::function& Work - ); - - -/** - * @brief Distribute many iterations of work over a thread pool if supported. - * This function is for small workloads in non-performance critical situation. - * - * @param ThreadPool [IN] Optional thread pool. 
Ignored when using OpenMP - * @param Iterations [IN] Total number of iterations - * @param Work [IN] Logic for computing a range of iterations [begin, end) - */ -void -MlasTryBatchParallel( - MLAS_THREADPOOL * ThreadPool, - const std::ptrdiff_t Iterations, - const std::function& Work - ); - - -inline -ptrdiff_t -MlasGetMaximumThreadCount( - MLAS_THREADPOOL* ThreadPool - ) -{ -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) - MLAS_UNREFERENCED_PARAMETER(ThreadPool); - return 1; -#else - return onnxruntime::concurrency::ThreadPool::DegreeOfParallelism(ThreadPool); -#endif -} - -inline -void -MlasPartitionWork( - ptrdiff_t ThreadId, - ptrdiff_t ThreadCount, - size_t TotalWork, - size_t* WorkIndex, - size_t* WorkRemaining - ) -{ - const size_t WorkPerThread = TotalWork / ThreadCount; - const size_t WorkPerThreadExtra = TotalWork % ThreadCount; - - if (size_t(ThreadId) < WorkPerThreadExtra) { - *WorkIndex = (WorkPerThread + 1) * ThreadId; - *WorkRemaining = WorkPerThread + 1; - } else { - *WorkIndex = WorkPerThread * ThreadId + WorkPerThreadExtra; - *WorkRemaining = WorkPerThread; - } -} - -// -// Define the minimum floating point value (and its bit value equivalent) that -// has no fractional bits. This number can be used for fast rounding of floating -// point numbers to integers. -// - -#define MLAS_ROUNDING_BIAS_MAGIC 12582912.f -#define MLAS_ROUNDING_BIAS_MAGIC_BITS 0x4B400000 - -// -// Helpers to cast a floating point type to and from an integer bit format. -// -#if defined(_MSC_VER) && !defined(__clang__) - #pragma warning(push) - // VC++ suggests we can attempt to make 'MlasBitsOfFp32' constexpr, but it is not valid. - #pragma warning(disable:26497) -#endif - -MLAS_FORCEINLINE -uint32_t -MlasBitsOfFp32( - float FloatValue - ) -{ - union { - uint32_t IntegerValue; - float FloatValue; - } u; - u.FloatValue = FloatValue; - return u.IntegerValue; -} - -MLAS_FORCEINLINE -float -MlasFp32FromBits( - uint32_t IntegerValue - ) -{ - union { - uint32_t IntegerValue; - float FloatValue; - } u; - u.IntegerValue = IntegerValue; - return u.FloatValue; -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif - -#if defined(MLAS_TARGET_WASM_SCALAR) - -void -MLASCALL -MlasConvDepthwiseFloat_CHW( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - float* Output, - const float* Zeros - ); - -#endif - - -// -// Define the missing ARM64 NEON intrinsic macros from arm64_neon.h that enable -// cross-compiler support. -// -// Also define additional standard NEON intrinsics using the MSVC aliases. -// - -#if defined(_M_ARM64) -#ifndef vmaxvq_f32 -#define vmaxvq_f32(src) neon_fmaxv(src) -#endif -#ifndef vminvq_f32 -#define vminvq_f32(src) neon_fminv(src) -#endif -#endif - -// -// Cross-platform wrappers for 32-bit vector intrinsics. 
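MlasPartitionWork above hands each thread a contiguous slice of TotalWork, with the first TotalWork % ThreadCount threads receiving one extra item. A usage sketch combining it with MlasTrySimpleParallel (ParallelOverRows and ProcessRow are illustrative names, not MLAS APIs; the sketch assumes the mlasi.h context):

    // With TotalRows = 10 and 4 threads the slices come out as 3, 3, 2, 2 rows.
    void ParallelOverRows(MLAS_THREADPOOL* ThreadPool, size_t TotalRows)
    {
        const std::ptrdiff_t ThreadCount = MlasGetMaximumThreadCount(ThreadPool);

        MlasTrySimpleParallel(ThreadPool, ThreadCount, [&](std::ptrdiff_t ThreadId) {
            size_t RowStart;
            size_t RowCount;
            MlasPartitionWork(ThreadId, ThreadCount, TotalRows, &RowStart, &RowCount);

            for (size_t Row = RowStart; Row < RowStart + RowCount; Row++) {
                // ProcessRow(Row);  // per-row work goes here
            }
        });
    }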
-// - -#if defined(MLAS_TARGET_ARM) -#define MLAS_NEON_INTRINSICS -#define MLAS_NEON32_INTRINSICS -#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) -#define MLAS_NEON_INTRINSICS -#define MLAS_NEON64_INTRINSICS -#elif defined(MLAS_TARGET_POWER) -#define MLAS_VSX_INTRINSICS -#elif defined(MLAS_TARGET_AMD64_IX86) -#define MLAS_SSE2_INTRINSICS -#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) -#define MLAS_SSE41_INTRINSICS -#endif -#if defined(__AVX__) -#define MLAS_AVX_INTRINSICS -#endif -#if defined(__AVX2__) -#define MLAS_AVX2_INTRINSICS -#endif -#if defined(__FMA__) || (defined(_MSC_VER) && defined(__AVX2__)) -#define MLAS_FMA3_INTRINSICS -#endif -#elif defined(MLAS_TARGET_WASM_SIMD) -#define MLAS_WASM_SIMD_INTRINSICS -#elif defined(MLAS_TARGET_LARCH64) -#define MLAS_LSX_INTRINSICS -#endif - -#if defined(MLAS_NEON_INTRINSICS) -typedef float32x4_t MLAS_FLOAT32X4; -typedef int32x4_t MLAS_INT32X4; -#elif defined(MLAS_SSE2_INTRINSICS) -typedef __m128 MLAS_FLOAT32X4; -typedef __m128i MLAS_INT32X4; -#elif defined(MLAS_VSX_INTRINSICS) -typedef __vector float MLAS_FLOAT32X4; -typedef __vector int MLAS_INT32X4; -typedef __vector unsigned MLAS_UINT32X4; -#elif defined(MLAS_WASM_SIMD_INTRINSICS) -typedef v128_t MLAS_FLOAT32X4; -typedef v128_t MLAS_INT32X4; -#elif defined(MLAS_LSX_INTRINSICS) -typedef __m128 MLAS_FLOAT32X4; -typedef __m128i MLAS_INT32X4; -#else -typedef float MLAS_FLOAT32X4 __attribute__ ((vector_size(16))); -typedef int32_t MLAS_INT32X4 __attribute__ ((vector_size(16))); -#endif - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasReinterpretAsInt32x4(MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vreinterpretq_s32_f32(Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_castps_si128(Vector); -#elif defined(MLAS_LSX_INTRINSICS) - return (MLAS_INT32X4)Vector; -#else - return MLAS_INT32X4(Vector); -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasCastToInt32x4(MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vcvtq_s32_f32(Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_cvttps_epi32(Vector); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_cts(Vector, 0); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vftint_w_s(Vector); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return (MLAS_INT32X4)__builtin_convertvector((__f32x4)Vector, __i32x4); -#else - return MLAS_INT32X4{int32_t(Vector[0]), int32_t(Vector[1]), int32_t(Vector[2]), int32_t(Vector[3])}; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasCastToFloat32x4(MLAS_INT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vcvtq_f32_s32(Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_cvtepi32_ps(Vector); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_ctf(Vector, 0); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_convert_i32x4(Vector); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vffint_s_w(Vector); -#else - return MLAS_FLOAT32X4{float(Vector[0]), float(Vector[1]), float(Vector[2]), float(Vector[3])}; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasBroadcastInt32x4(int32_t Value) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vdupq_n_s32(Value); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_set1_epi32(Value); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_splats(Value); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vreplgr2vr_w(Value); -#else - return MLAS_INT32X4{Value, Value, Value, Value}; -#endif -} - -MLAS_FORCEINLINE 
-MLAS_INT32X4 -MlasLoadInt32x4(const int32_t* Buffer) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vld1q_s32(Buffer); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_loadu_si128((const __m128i*)Buffer); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_vsx_ld(0, Buffer); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_load(Buffer); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vld((const MLAS_INT32X4*)Buffer, 0); -#else - return *((MLAS_INT32X4*)Buffer); -#endif -} - -MLAS_FORCEINLINE -void -MlasStoreInt32x4(int32_t* Buffer, MLAS_INT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - vst1q_s32(Buffer, Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - _mm_storeu_si128((__m128i*)Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) - vec_vsx_st(Vector, 0, Buffer); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - wasm_v128_store(Buffer, Vector); -#elif defined(MLAS_LSX_INTRINSICS) - __lsx_vst(Vector, (MLAS_INT32X4 *)Buffer, 0); -#else - *((MLAS_INT32X4*)Buffer) = Vector; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasAddInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vaddq_s32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_add_epi32(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_add(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_add(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vadd_w(Vector1, Vector2); -#else - return Vector1 + Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasSubtractInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vsubq_s32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_sub_epi32(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_sub(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vsub_w(Vector1, Vector2); -#else - return Vector1 - Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasAndInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vandq_s32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_and_si128(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_and(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vand_v(Vector1, Vector2); -#else - return Vector1 & Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasOrInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vorrq_s32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_or_si128(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_or(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vor_v(Vector1, Vector2); -#else - return Vector1 | Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasAndNotInt32x4(MLAS_INT32X4 VectorNot, MLAS_INT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vandq_s32(vmvnq_s32(VectorNot), Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_andnot_si128(VectorNot, Vector); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_andnot(Vector, VectorNot); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vandn_v(VectorNot, Vector); -#else - return (~VectorNot) & Vector; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasXorInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return veorq_s32(Vector1, 
Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_xor_si128(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_xor(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_xor(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vxor_v(Vector1, Vector2); -#else - return Vector1 ^ Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasBlendInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2, MLAS_INT32X4 Selection) -{ - return MlasOrInt32x4(MlasAndInt32x4(Vector2, Selection), MlasAndNotInt32x4(Selection, Vector1)); -} - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftLeftInt32x4(MLAS_INT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vshlq_n_s32(Vector, ShiftCount); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_slli_epi32(Vector, ShiftCount); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_shl(Vector, ShiftCount); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vslli_w(Vector, ShiftCount); -#else - return Vector << ShiftCount; -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasMaximumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vmaxq_s32(Vector1, Vector2); -#elif defined(MLAS_SSE41_INTRINSICS) - return _mm_max_epi32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return MlasBlendInt32x4(Vector2, Vector1, _mm_cmpgt_epi32(Vector1, Vector2)); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_vmaxsw(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_max(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vmax_w(Vector1, Vector2); -#else - return MlasBlendInt32x4(Vector2, Vector1, Vector1 > Vector2); -#endif -} - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasMinimumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vminq_s32(Vector1, Vector2); -#elif defined(MLAS_SSE41_INTRINSICS) - return _mm_min_epi32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return MlasBlendInt32x4(Vector2, Vector1, _mm_cmpgt_epi32(Vector2, Vector1)); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_vminsw(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_min(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vmin_w(Vector1, Vector2); -#else - return MlasBlendInt32x4(Vector2, Vector1, Vector2 > Vector1); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasReinterpretAsFloat32x4(MLAS_INT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vreinterpretq_f32_s32(Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_castsi128_ps(Vector); -#elif defined(MLAS_LSX_INTRINSICS) - return MLAS_FLOAT32X4(Vector); -#else - return MLAS_FLOAT32X4(Vector); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasBroadcastFloat32x4(float Value) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vdupq_n_f32(Value); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_set1_ps(Value); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) - // Suppress wrong GCC warnings - MLAS_UNREFERENCED_PARAMETER(Value); - return vec_splats(Value); -#elif defined(MLAS_LSX_INTRINSICS) - return MLAS_FLOAT32X4{Value, Value, Value, Value}; -#else - return MLAS_FLOAT32X4{Value, Value, Value, Value}; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasBroadcastFloat32x4(const float* Value) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vld1q_dup_f32(Value); -#elif 
defined(MLAS_SSE2_INTRINSICS) - return _mm_load_ps1(Value); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_load32_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_splats(*Value); -#elif defined(MLAS_LSX_INTRINSICS) - return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; -#else - return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasZeroFloat32x4(void) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vdupq_n_f32(0.0f); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_setzero_ps(); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_const(0.0f, 0.0f, 0.0f, 0.0f); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasBroadcastFloat32x4(0.0f); -#else - return MlasBroadcastFloat32x4(0.0f); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasLoadFloat32x4(const float* Buffer) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vld1q_f32(Buffer); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_loadu_ps(Buffer); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_vsx_ld(0, Buffer); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_load(Buffer); -#elif defined(MLAS_LSX_INTRINSICS) - // return MlasReinterpretAsFloat32x4(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); - return (MLAS_FLOAT32X4)__lsx_vld((const MLAS_INT32X4 *)Buffer, 0); -#else - return *((MLAS_FLOAT32X4*)Buffer); -#endif -} - -MLAS_FORCEINLINE -void -MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - vst1q_f32(Buffer, Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - _mm_storeu_ps(Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) - vec_vsx_st(Vector, 0, Buffer); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - wasm_v128_store(Buffer, Vector); -#elif defined(MLAS_LSX_INTRINSICS) - __lsx_vst(MlasReinterpretAsInt32x4(Vector), Buffer, 0); -#else - *((MLAS_FLOAT32X4*)Buffer) = Vector; -#endif -} - -MLAS_FORCEINLINE -void -MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - vst1q_f32(Buffer, Vector); -#elif defined(MLAS_SSE2_INTRINSICS) - _mm_store_ps(Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) - // Workaround for bad GCC warning that these parameters are set but not used. - MLAS_UNREFERENCED_PARAMETER(Buffer); - MLAS_UNREFERENCED_PARAMETER(Vector); - vec_st(Vector, 0, Buffer); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - wasm_v128_store(Buffer, Vector); -#elif defined(MLAS_LSX_INTRINSICS) - MlasStoreFloat32x4(Buffer, Vector); -#else - MlasStoreFloat32x4(Buffer, Vector); -#endif -} - -template -MLAS_FORCEINLINE -void -MlasStoreLaneFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - vst1q_lane_f32(Buffer, Vector, Lane); -#elif defined(MLAS_SSE2_INTRINSICS) - // N.B. When building with AVX instructions, compilers optimize the following - // to a single vextractps instruction. 
- _mm_store_ss(Buffer, _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - *Buffer = ((__f32x4)(Vector))[Lane]; -#elif defined(MLAS_LSX_INTRINSICS) - *Buffer = Vector[Lane]; -#else - *Buffer = Vector[Lane]; -#endif -} - -MLAS_FORCEINLINE -void -MlasStoreLowHalfFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - vst1_f32(Buffer, vget_low_f32(Vector)); -#elif defined(MLAS_SSE2_INTRINSICS) - _mm_storel_pi((__m64*)Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) - *((long long*)Buffer) = ((__vector long long)Vector)[0]; -#elif defined(MLAS_LSX_INTRINSICS) - MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); - MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); -#else - MlasStoreLaneFloat32x4<0>(&Buffer[0], Vector); - MlasStoreLaneFloat32x4<1>(&Buffer[1], Vector); -#endif -} - -template -MLAS_FORCEINLINE -float -MlasExtractLaneFloat32x4(MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vgetq_lane_f32(Vector, Lane); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_cvtss_f32(_mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Lane, Lane, Lane, Lane))); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_extract_lane(Vector, Lane); -#elif defined(MLAS_LSX_INTRINSICS) - return Vector[Lane]; -#else - return Vector[Lane]; -#endif -} - -#if defined(MLAS_SSE2_INTRINSICS) - -template<> -MLAS_FORCEINLINE -void -MlasStoreLaneFloat32x4<0>(float* Buffer, MLAS_FLOAT32X4 Vector) -{ - _mm_store_ss(Buffer, Vector); -} - -template<> -MLAS_FORCEINLINE -float -MlasExtractLaneFloat32x4<0>(MLAS_FLOAT32X4 Vector) -{ - return _mm_cvtss_f32(Vector); -} - -template -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasShuffleFloat32x4(MLAS_FLOAT32X4 Vector) -{ - return _mm_shuffle_ps(Vector, Vector, _MM_SHUFFLE(Index3, Index2, Index1, Index0)); -} - -#endif - -#if !defined(MLAS_SSE2_INTRINSICS) && !defined(_MSC_VER) - -template -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasShuffleFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_i32x4_shuffle(Vector1, Vector2, Index0, Index1, Index2, Index3); -#elif defined(__clang__) - return __builtin_shufflevector(Vector1, Vector2, Index0, Index1, Index2, Index3); -#elif defined(MLAS_LSX_INTRINSICS) - typedef int32_t GEN_INT32X4 __attribute__ ((vector_size(16))); - return __builtin_shuffle(Vector1, Vector2, GEN_INT32X4{Index0, Index1, Index2, Index3}); -#else - return __builtin_shuffle(Vector1, Vector2, MLAS_INT32X4{Index0, Index1, Index2, Index3}); -#endif -} - -template -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasShuffleFloat32x4(MLAS_FLOAT32X4 Vector) -{ - return MlasShuffleFloat32x4(Vector, Vector); -} - -#endif - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasInterleaveLowFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON64_INTRINSICS) - return vzip1q_f32(Vector1, Vector2); -#elif defined(MLAS_NEON32_INTRINSICS) - float32x4x2_t zipped = vzipq_f32(Vector1, Vector2); - return zipped.val[0]; -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_unpacklo_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_mergeh(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return (MLAS_FLOAT32X4)__lsx_vilvl_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); -#else - return MlasShuffleFloat32x4<0, 4, 1, 5>(Vector1, Vector2); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasInterleaveHighFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if 
defined(MLAS_NEON64_INTRINSICS) - return vzip2q_f32(Vector1, Vector2); -#elif defined(MLAS_NEON32_INTRINSICS) - float32x4x2_t zipped = vzipq_f32(Vector1, Vector2); - return zipped.val[1]; -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_unpackhi_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_mergel(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return (MLAS_FLOAT32X4)__lsx_vilvh_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); -#else - return MlasShuffleFloat32x4<2, 6, 3, 7>(Vector1, Vector2); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vaddq_f32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_add_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_add(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_add(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfadd_s(Vector1, Vector2); -#else - return Vector1 + Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasSubtractFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vsubq_f32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_sub_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_sub(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_sub(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfsub_s(Vector1, Vector2); -#else - return Vector1 - Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasMultiplyFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vmulq_f32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_mul_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_mul(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - // Suppress wrong GCC warnings - MLAS_UNREFERENCED_PARAMETER(Vector1); - MLAS_UNREFERENCED_PARAMETER(Vector2); - return vec_mul(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfmul_s(Vector1, Vector2); -#else - return Vector1 * Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Vector3) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vmlaq_f32(Vector3, Vector1, Vector2); -#elif defined(MLAS_FMA3_INTRINSICS) - return _mm_fmadd_ps(Vector1, Vector2, Vector3); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_add_ps(_mm_mul_ps(Vector1, Vector2), Vector3); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_madd(Vector1, Vector2, Vector3); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_add(wasm_f32x4_mul(Vector1, Vector2), Vector3); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfmadd_s(Vector1, Vector2, Vector3); -#else - return Vector1 * Vector2 + Vector3; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, float Scalar2, MLAS_FLOAT32X4 Vector3) -{ - return MlasMultiplyAddFloat32x4(Vector1, MlasBroadcastFloat32x4(Scalar2), Vector3); -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, float Scalar3) -{ - return MlasMultiplyAddFloat32x4(Vector1, Vector2, MlasBroadcastFloat32x4(Scalar3)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, 
MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON64_INTRINSICS) - return vdivq_f32(Vector1, Vector2); -#elif defined(MLAS_NEON32_INTRINSICS) - Vector1 = vsetq_lane_f32(vgetq_lane_f32(Vector1, 0) / vgetq_lane_f32(Vector2, 0), Vector1, 0); - Vector1 = vsetq_lane_f32(vgetq_lane_f32(Vector1, 1) / vgetq_lane_f32(Vector2, 1), Vector1, 1); - Vector1 = vsetq_lane_f32(vgetq_lane_f32(Vector1, 2) / vgetq_lane_f32(Vector2, 2), Vector1, 2); - Vector1 = vsetq_lane_f32(vgetq_lane_f32(Vector1, 3) / vgetq_lane_f32(Vector2, 3), Vector1, 3); - return Vector1; -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_div_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_div(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfdiv_s(Vector1, Vector2); -#else - return Vector1 / Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vreinterpretq_f32_u32(vcgtq_f32(Vector1, Vector2)); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_cmpgt_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_gt(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); -#elif defined(MLAS_LSX_INTRINSICS) - return (MLAS_FLOAT32X4)__lsx_vfcmp_clt_s(Vector2, Vector1); -#else - return Vector1 > Vector2; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasAndFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_and_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_and(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); -#else - return MlasReinterpretAsFloat32x4(MlasAndInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasOrFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_or_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_or(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); -#else - return MlasReinterpretAsFloat32x4(MlasOrInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasAndNotFloat32x4(MLAS_FLOAT32X4 VectorNot, MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_andnot_ps(VectorNot, Vector); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_andnot(Vector, VectorNot); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); -#else - return MlasReinterpretAsFloat32x4(MlasAndNotInt32x4(MlasReinterpretAsInt32x4(VectorNot), MlasReinterpretAsInt32x4(Vector))); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasXorFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_xor_ps(Vector1, Vector2); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_v128_xor(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), 
MlasReinterpretAsInt32x4(Vector2))); -#else - return MlasReinterpretAsFloat32x4(MlasXorInt32x4(MlasReinterpretAsInt32x4(Vector1), MlasReinterpretAsInt32x4(Vector2))); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasBlendFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Selection) -{ - return MlasOrFloat32x4(MlasAndFloat32x4(Vector2, Selection), MlasAndNotFloat32x4(Selection, Vector1)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vmaxq_f32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_max_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - // Don't use vec_max to avoid undefined behavior if NAN - return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_max(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfmax_s(Vector1, Vector2); -#else - return MlasBlendFloat32x4(Vector2, Vector1, Vector1 > Vector2); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) -{ -#if defined(MLAS_NEON_INTRINSICS) - return vminq_f32(Vector1, Vector2); -#elif defined(MLAS_SSE2_INTRINSICS) - return _mm_min_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - // Don't use vec_min to avoid undefined behavior if NAN - return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); -#elif defined(MLAS_WASM_SIMD_INTRINSICS) - return wasm_f32x4_min(Vector1, Vector2); -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfmin_s(Vector1, Vector2); -#else - return MlasBlendFloat32x4(Vector2, Vector1, Vector2 > Vector1); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasClampFloat32x4(MLAS_FLOAT32X4 Value, float LowerRange, float UpperRange) -{ -#if defined(MLAS_SSE2_INTRINSICS) - // N.B. MINPS and MAXPS propagates the value from the second vector if the - // value is a NaN. 
-#endif - Value = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(LowerRange), Value); - Value = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(UpperRange), Value); - return Value; -} - -MLAS_FORCEINLINE -float -MlasReduceAddFloat32x4(MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON64_INTRINSICS) - Vector = vpaddq_f32(Vector, Vector); - Vector = vpaddq_f32(Vector, Vector); - return vgetq_lane_f32(Vector, 0); -#elif defined(MLAS_NEON32_INTRINSICS) - float32x2_t VectorLow = vget_low_f32(Vector); - float32x2_t VectorHigh = vget_high_f32(Vector); - VectorLow = vpadd_f32(VectorLow, VectorHigh); - VectorLow = vpadd_f32(VectorLow, VectorHigh); - return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) - Vector = MlasAddFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); - Vector = MlasAddFloat32x4(Vector, vec_splat(Vector, 1)); - return Vector[0]; -#else - Vector = MlasAddFloat32x4(Vector, MlasShuffleFloat32x4<2, 3, 2, 3>(Vector)); - Vector = MlasAddFloat32x4(Vector, MlasShuffleFloat32x4<1, 1, 1, 1>(Vector)); - return MlasExtractLaneFloat32x4<0>(Vector); -#endif -} - -MLAS_FORCEINLINE -float -MlasReduceMaximumFloat32x4(MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON64_INTRINSICS) - return vmaxvq_f32(Vector); -#elif defined(MLAS_NEON32_INTRINSICS) - float32x2_t VectorLow = vget_low_f32(Vector); - float32x2_t VectorHigh = vget_high_f32(Vector); - VectorLow = vpmax_f32(VectorLow, VectorHigh); - VectorLow = vpmax_f32(VectorLow, VectorHigh); - return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) - Vector = MlasMaximumFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); - Vector = MlasMaximumFloat32x4(Vector, vec_splat(Vector, 1)); - return Vector[0]; -#else - Vector = MlasMaximumFloat32x4(Vector, MlasShuffleFloat32x4<2, 3, 2, 3>(Vector)); - Vector = MlasMaximumFloat32x4(Vector, MlasShuffleFloat32x4<1, 1, 1, 1>(Vector)); - return MlasExtractLaneFloat32x4<0>(Vector); -#endif -} - -MLAS_FORCEINLINE -float -MlasReduceMinimumFloat32x4(MLAS_FLOAT32X4 Vector) -{ -#if defined(MLAS_NEON64_INTRINSICS) - return vminvq_f32(Vector); -#elif defined(MLAS_NEON32_INTRINSICS) - float32x2_t VectorLow = vget_low_f32(Vector); - float32x2_t VectorHigh = vget_high_f32(Vector); - VectorLow = vpmin_f32(VectorLow, VectorHigh); - VectorLow = vpmin_f32(VectorLow, VectorHigh); - return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) - Vector = MlasMinimumFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); - Vector = MlasMinimumFloat32x4(Vector, vec_splat(Vector, 1)); - return Vector[0]; -#else - Vector = MlasMinimumFloat32x4(Vector, MlasShuffleFloat32x4<2, 3, 2, 3>(Vector)); - Vector = MlasMinimumFloat32x4(Vector, MlasShuffleFloat32x4<1, 1, 1, 1>(Vector)); - return MlasExtractLaneFloat32x4<0>(Vector); -#endif -} - -// calc 2^int(N) -MLAS_FORCEINLINE -MLAS_FLOAT32X4 -MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) -{ - MLAS_INT32X4 emm0 = MlasAddInt32x4(MlasCastToInt32x4(Vector), MlasBroadcastInt32x4(127)); - return MlasReinterpretAsFloat32x4(MlasShiftLeftInt32x4<23>(emm0)); -} - -// -// Cross-platform wrappers for 64-bit vector intrinsics. 
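The 32-bit wrappers above are enough to write small kernels that compile unchanged on every target selected at the top of this section; the 64-bit wrappers introduced next cover far fewer operations. A minimal usage sketch, assuming the mlasi.h context (DotProductF32 is an illustrative name, not an MLAS routine):

    // Length-N dot product over the portable MLAS_FLOAT32X4 wrappers.
    float DotProductF32(const float* A, const float* B, size_t N)
    {
        MLAS_FLOAT32X4 Accumulator = MlasZeroFloat32x4();
        size_t n = 0;

        for (; n + 4 <= N; n += 4) {
            Accumulator = MlasMultiplyAddFloat32x4(MlasLoadFloat32x4(&A[n]),
                                                   MlasLoadFloat32x4(&B[n]),
                                                   Accumulator);
        }

        float Sum = MlasReduceAddFloat32x4(Accumulator);

        for (; n < N; n++) {
            Sum += A[n] * B[n];    // scalar tail
        }

        return Sum;
    }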
-// - -#if defined(MLAS_SSE2_INTRINSICS) -typedef __m128d MLAS_FLOAT64X2; -#elif defined(MLAS_VSX_INTRINSICS) -typedef __vector double MLAS_FLOAT64X2; -#elif defined(MLAS_LSX_INTRINSICS) -typedef __m128d MLAS_FLOAT64X2; -#else -#define MLAS_FLOAT64X2_UNSUPPORTED -#endif - -#ifndef MLAS_FLOAT64X2_UNSUPPORTED - -#if defined(MLAS_VSX_INTRINSICS) -template -MLAS_FORCEINLINE -double -MlasExtractLaneFloat64x2(MLAS_FLOAT64X2 Vector) -{ - return Vector[Lane]; -} -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FLOAT64X2 Vector3) -{ - return vec_madd(Vector1, Vector2, Vector3); -} - -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasBroadcastFloat64x2(const double *Value) -{ - return MLAS_FLOAT64X2{*Value, *Value}; -} -#elif defined(MLAS_LSX_INTRINSICS) -template -MLAS_FORCEINLINE -double -MlasExtractLaneFloat64x2(MLAS_FLOAT64X2 Vector) -{ - return Vector[Lane]; -} -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasMultiplyAddFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2, MLAS_FLOAT64X2 Vector3) -{ - return __lsx_vfmadd_d(Vector1, Vector2, Vector3); -} - -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasBroadcastFloat64x2(const double *Value) -{ - return MLAS_FLOAT64X2{*Value, *Value}; -} -#endif -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasBroadcastFloat64x2(double Value) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_set1_pd(Value); -#elif defined(MLAS_VSX_INTRINSICS) - return MLAS_FLOAT64X2{Value, Value}; -#elif defined(MLAS_LSX_INTRINSICS) - return MLAS_FLOAT64X2{Value, Value}; -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasZeroFloat64x2(void) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_setzero_pd(); -#elif defined(MLAS_VSX_INTRINSICS) - return MlasBroadcastFloat64x2(0.0f); -#elif defined(MLAS_LSX_INTRINSICS) - return MlasBroadcastFloat64x2(0.0f); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasLoadFloat64x2(const double* Buffer) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_loadu_pd(Buffer); -#elif defined(MLAS_VSX_INTRINSICS) - return vec_vsx_ld(0, Buffer); -#elif defined(MLAS_LSX_INTRINSICS) - return MLAS_FLOAT64X2(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); -#endif -} - -MLAS_FORCEINLINE -void -MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) -{ -#if defined(MLAS_SSE2_INTRINSICS) - _mm_storeu_pd(Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) - vec_vsx_st(Vector, 0, Buffer); -#elif defined(MLAS_LSX_INTRINSICS) - (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); -#endif -} - -MLAS_FORCEINLINE -void -MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) -{ -#if defined(MLAS_SSE2_INTRINSICS) - _mm_store_pd(Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) - *((MLAS_FLOAT64X2*)Buffer) = Vector; -#elif defined(MLAS_LSX_INTRINSICS) - (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); -#endif -} - -MLAS_FORCEINLINE -MLAS_FLOAT64X2 -MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) -{ -#if defined(MLAS_SSE2_INTRINSICS) - return _mm_mul_pd(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) - return Vector1 * Vector2; -#elif defined(MLAS_LSX_INTRINSICS) - return __lsx_vfmul_d(Vector1, Vector2); -#endif -} - -#endif // !MLAS_FLOAT64X2_UNSUPPORTED - -// -// Reads a platform specific time stamp counter. 
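Only a handful of double-precision wrappers exist, and MlasMultiplyAddFloat64x2 is provided solely for the VSX and LSX targets, so portable code should stick to load, store, broadcast and multiply. An in-place scaling sketch under those constraints (ScaleF64 is an illustrative name; assumes the mlasi.h context):

    #ifndef MLAS_FLOAT64X2_UNSUPPORTED

    // Buffer[i] *= Alpha using the portable MLAS_FLOAT64X2 wrappers.
    void ScaleF64(double* Buffer, size_t N, double Alpha)
    {
        const MLAS_FLOAT64X2 AlphaVector = MlasBroadcastFloat64x2(Alpha);
        size_t n = 0;

        for (; n + 2 <= N; n += 2) {
            MlasStoreFloat64x2(&Buffer[n],
                MlasMultiplyFloat64x2(MlasLoadFloat64x2(&Buffer[n]), AlphaVector));
        }

        for (; n < N; n++) {
            Buffer[n] *= Alpha;    // scalar tail
        }
    }

    #endif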
-// - -MLAS_FORCEINLINE -uint64_t -MlasReadTimeStampCounter(void) -{ -#ifdef _WIN32 -#if defined(MLAS_TARGET_AMD64_IX86) - return ReadTimeStampCounter(); -#else - LARGE_INTEGER PerformanceCounter; - - QueryPerformanceCounter(&PerformanceCounter); - - return (ULONG64)PerformanceCounter.QuadPart; -#endif -#else -#if defined(MLAS_TARGET_AMD64) - uint32_t eax, edx; - - __asm__ __volatile__ - ( - "rdtsc" - : "=a" (eax), "=d" (edx) - ); - - return ((uint64_t)edx << 32) | eax; -#elif defined(MLAS_TARGET_LARCH64) - uint64_t time_cnt, id; - - __asm__ __volatile__ - ( - "rdtime.d %0, %1\n\t" - : "=r" (time_cnt), "=r" (id) - :: - ); - - return time_cnt; -#else - return 0; -#endif -#endif -} - -// -// Aligned buffer for GEMM packing, etc. -// - - -constexpr size_t ThreadedBufAlignment = 64; -extern thread_local size_t ThreadedBufSize; -#ifdef _MSC_VER -extern thread_local std::unique_ptr ThreadedBufHolder; -#else -extern thread_local std::unique_ptr ThreadedBufHolder; -#endif - -MLAS_FORCEINLINE -constexpr size_t -UpAlignSize(size_t size) -{ - size = (size + ThreadedBufAlignment - 1) / ThreadedBufAlignment; - return size * ThreadedBufAlignment; -} - - -MLAS_FORCEINLINE -void -MlasThreadedBufAlloc(size_t size) -{ - if (size > ThreadedBufSize) { -#ifdef _MSC_VER - ThreadedBufHolder.reset( - reinterpret_cast(_aligned_malloc(size, ThreadedBufAlignment))); -#elif (__STDC_VERSION__ >= 201112L) && !defined(__APPLE__) - ThreadedBufHolder.reset( - reinterpret_cast(aligned_alloc(ThreadedBufAlignment, size))); -#else - // aligned_alloc unavailable macos 10.14 or earlier - void* ptr; - int err = posix_memalign(&ptr, ThreadedBufAlignment, size); - if (err != 0) { - ptr = nullptr; - } - ThreadedBufHolder.reset(reinterpret_cast(ptr)); -#endif - - ThreadedBufSize = size; - } -} - -// -// Utilities for INT4 quantization. -// - -template -struct Int4Traits; - -template<> -struct Int4Traits { - using UnpackedType = int8_t; - static constexpr int8_t Min = -8; - static constexpr int8_t Max = 7; -}; - -template<> -struct Int4Traits { - using UnpackedType = uint8_t; - static constexpr int8_t Min = 0; - static constexpr int8_t Max = 15; -}; - -template -MLAS_FORCEINLINE -void -MlasSetInt4Element(uint8_t* Output, size_t ElemIndex, UnpackedType Value) -{ - static_assert(std::is_same_v || std::is_same_v); - - const size_t OutputIndex = ElemIndex >> 1; // which byte - const size_t NibbleIndex = ElemIndex & 0x1; // which 4-bit elem in the byte - const uint8_t Shift = static_cast(NibbleIndex << 2); // Either 0 or 4 - const uint8_t Mask = static_cast(0xF0 >> Shift); - uint8_t* Dst = &Output[OutputIndex]; - - *Dst &= Mask; // Clear 4-bit lane - *Dst |= static_cast((Value & 0xF) << Shift); // Set 4-bit lane -} - -template -MLAS_FORCEINLINE -void -MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueHigh) -{ - static_assert(std::is_same_v || std::is_same_v); - *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); -} diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp deleted file mode 100644 index 23d29fd02fa5a..0000000000000 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ /dev/null @@ -1,732 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - platform.cpp - -Abstract: - - This module implements logic to select the best configuration for the - this platform. 
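Stepping back to the INT4 helpers at the end of mlasi.h above: MlasSetInt4Element and MlasPackInt4Elements store even elements in the low nibble and odd elements in the high nibble of each byte. The matching unpack direction, which the header does not spell out, looks roughly like this (GetInt4Element is a hypothetical helper, not an MLAS function):

    #include <cstddef>
    #include <cstdint>

    // Extract element ElemIndex from a packed nibble stream; sign-extend when
    // the stream holds signed 4-bit values, matching Int4Traits<true> ([-8, 7]).
    template <bool Signed>
    int8_t GetInt4Element(const uint8_t* Packed, size_t ElemIndex)
    {
        const uint8_t Byte = Packed[ElemIndex >> 1];
        const uint8_t Shift = static_cast<uint8_t>((ElemIndex & 0x1) << 2);  // 0 or 4
        int8_t Value = static_cast<int8_t>((Byte >> Shift) & 0xF);

        if (Signed && Value > 7) {
            Value -= 16;  // sign-extend the 4-bit lane
        }

        return Value;
    }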
- ---*/ - -#include "mlasi.h" - -#include -#include - -#if defined(MLAS_TARGET_POWER) -#if defined(__linux__) -#include -#elif defined(_AIX) -#define POWER_10 0x40000 -#define POWER_10_ANDUP (POWER_10) -#include -#define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) -#endif -#endif - -#if defined(MLAS_TARGET_ARM64) -#if defined(_WIN32) - -// N.B. Support building with downlevel versions of the Windows SDK. -#ifndef PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE -#define PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE 43 -#endif - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() -{ - has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); - - // raw hack! Need CPUIDInfo implementation for more precise detection - has_fp16_ = has_arm_neon_dot_; -} -#endif - -#elif defined(__linux__) - -#include -#include -// N.B. Support building with older versions of asm/hwcap.h that do not define -// this capability bit. -#ifndef HWCAP_ASIMDDP -#define HWCAP_ASIMDDP (1 << 20) -#endif - -#ifndef HWCAP2_I8MM -#define HWCAP2_I8MM (1 << 13) -#endif - -#ifndef HWCAP2_SVEI8MM -#define HWCAP2_SVEI8MM (1 << 9) -#endif - -#ifndef HWCAP2_BF16 -#define HWCAP2_BF16 (1 << 14) -#endif - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() -{ - has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); - - // raw hack! Need CPUIDInfo implementation for more precise detection - has_fp16_ = has_arm_neon_dot_; - - has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); - has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); - - has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); -} -#endif - -#else - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() {} -#endif - -#endif // Windows vs Linux vs Unknown -#else // not MLAS_TARGET_ARM64 - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() {} -#endif - -#endif // MLAS_TARGET_ARM64 - -#ifdef MLAS_TARGET_AMD64_IX86 - -// -// Stores a vector to build a conditional load/store mask for vmaskmovps. -// - -MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 }; - -// -// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks. -// - -MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableAvx[16], 32) = { - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, -}; - -// -// Stores a table of AVX512 opmask register values. -// - -MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const int16_t MlasOpmask16BitTableAvx512[16], 32) = { - 0x0000, 0x0001, 0x0003, 0x0007, 0x000F, 0x001F, 0x003F, 0x007F, - 0x00FF, 0x01FF, 0x03FF, 0x07FF, 0x0FFF, 0x1FFF, 0x3FFF, 0x7FFF, -}; - -// -// Reads the processor extended control register to determine platform -// capabilities. 
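MlasMaskMoveTableAvx above is sixteen dwords, eight all-ones followed by eight zeros, and the assembly kernels index into it to build a lane mask for partial vectors. A hedged sketch of the same trick written with intrinsics (the table is re-declared locally because the real symbol is internal data; LoadPartialF32 is a hypothetical helper):

    #include <immintrin.h>
    #include <cstddef>
    #include <cstdint>

    static const uint32_t MaskMoveTable[16] = {
        0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
        0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
        0, 0, 0, 0, 0, 0, 0, 0,
    };

    // Load an N-element tail (N <= 8) with a masked AVX load. Starting the
    // 8-entry window at index 8 - N leaves exactly N all-ones lanes at the front.
    __m256 LoadPartialF32(const float* Buffer, size_t N)
    {
        const __m256i Mask = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(&MaskMoveTable[8 - N]));
        return _mm256_maskload_ps(Buffer, Mask);
    }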
-// - -#if !defined(_XCR_XFEATURE_ENABLED_MASK) -#define _XCR_XFEATURE_ENABLED_MASK 0 -#endif - -#if !defined(XFEATURE_MASK_XTILE) -#define XFEATURE_XTILECFG 17 -#define XFEATURE_XTILEDATA 18 -#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) -#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) -#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) -#endif - -inline -uint64_t -MlasReadExtendedControlRegister( - unsigned int ext_ctrl_reg -) -{ -#if defined(_WIN32) - return _xgetbv(ext_ctrl_reg); -#else - uint32_t eax, edx; - - __asm__ - ( - "xgetbv" - : "=a" (eax), "=d" (edx) - : "c" (ext_ctrl_reg) - ); - - return ((uint64_t)edx << 32) | eax; -#endif -} - -#if defined(__linux__) -#include -#endif - -bool -MlasInitAMX() -{ -#if defined(__linux__) - -#define ARCH_GET_XCOMP_PERM 0x1022 -#define ARCH_REQ_XCOMP_PERM 0x1023 - - unsigned long bitmask = 0; - long rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); - if (rc) { - return false; - } - rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); - if (rc) { - return false; - } - if (bitmask & XFEATURE_MASK_XTILE) { - return true; - } - return false; -#else - return true; -#endif -} - -#endif // MLAS_TARGET_AMD64_IX86 - -#ifdef MLAS_TARGET_LARCH64 - -#if defined(__linux__) -#include -#include -#endif -// -// Stores a vector to build a conditional load/store mask for vmaskmovps. -// - -MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveLasx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 }; - -// -// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks. -// - -MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveTableLasx[16], 32) = { - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, -}; - -#endif -MLAS_PLATFORM::MLAS_PLATFORM( - void - ) -/*++ - -Routine Description: - - This routine initializes the platform support for this library. - -Arguments: - - None. - -Return Value: - - None. - ---*/ -{ - - this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernel; - this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel; - this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; - this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; - this->CastF16ToF32Kernel = nullptr; - this->CastF32ToF16Kernel = nullptr; - -#if defined(MLAS_TARGET_AMD64_IX86) - - // - // Default to the baseline SSE2 support. 
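MlasReadExtendedControlRegister above returns the raw XCR0 value; the constructor below tests individual XSAVE state bits against it before enabling wider kernels. A decoding sketch using the architectural bit positions (the struct and function names are mine, not MLAS identifiers):

    #include <cstdint>

    struct XStateSupport {
        bool Sse;      // bit 1: XMM state
        bool Avx;      // bits 1-2: XMM and YMM state
        bool Avx512;   // bits 5-7: opmask, ZMM_Hi256, Hi16_ZMM state
        bool AmxTile;  // bits 17-18: TILECFG and TILEDATA state
    };

    inline XStateSupport DecodeXcr0(uint64_t Xcr0)
    {
        XStateSupport s;
        s.Sse     = (Xcr0 & 0x2) != 0;
        s.Avx     = (Xcr0 & 0x6) == 0x6;            // the AVX gate used below
        s.Avx512  = (Xcr0 & 0xE6) == 0xE6;          // AVX plus the AVX-512 state bits
        s.AmxTile = (Xcr0 & (0x3ULL << 17)) == (0x3ULL << 17);  // XFEATURE_MASK_XTILE
        return s;
    }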
- // - - this->GemmFloatKernel = MlasGemmFloatKernelSse; - this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchSse; - this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse; - -#if defined(MLAS_TARGET_AMD64) - - this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse; - this->GemmDoubleKernel = MlasGemmDoubleKernelSse; - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse; - this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelSse; - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelSse; - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelSse; - this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelSse; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelSse; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelSse; - this->ComputeExpF32Kernel = MlasComputeExpF32Kernel; - this->LogisticKernelRoutine = MlasLogisticKernel; - this->TanhKernelRoutine = MlasTanhKernel; - this->ErfKernelRoutine = MlasErfKernel; - this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; - this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; - this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; - this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; - this->ReduceMinimumMaximumF32Kernel = MlasReduceMinimumMaximumF32Kernel; - this->QLinearAddS8Kernel = MlasQLinearAddS8Kernel; - this->QLinearAddU8Kernel = MlasQLinearAddU8Kernel; - this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; - this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; - this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; - this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; - this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; - this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; -#ifndef __APPLE__ - this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; -#endif // __APPLE__ - - this->NchwcBlockSize = 8; - this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; - - this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; - -#endif - - unsigned Cpuid1[4]; -#if defined(_WIN32) - __cpuid((int*)Cpuid1, 1); -#else - __cpuid(1, Cpuid1[0], Cpuid1[1], Cpuid1[2], Cpuid1[3]); -#endif - -#if defined(_MSC_VER) - - // - // Check if the processor supports SSE 4.1 instructions. - // - - if ((Cpuid1[2] & 0x80000) != 0) { - this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41; - } - -#endif - - // - // Check if the processor supports the AVX and OSXSAVE features. - // - - if ((Cpuid1[2] & 0x18000000) == 0x18000000) { - - // - // Check if the operating system supports saving SSE and AVX states. 
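// Illustrative sketch: the CPUID leaf 7 bits behind the masks tested in the
// AVX2/AVX-512/AMX detection that follows. The names are hypothetical, not MLAS
// definitions; the bit positions follow the CPUID documentation.

constexpr unsigned kCpuid7EbxAvx2       = 1u << 5;    // 0x20
constexpr unsigned kCpuid7EbxAvx512F    = 1u << 16;   // 0x10000
constexpr unsigned kCpuid7EbxAvx512Dq   = 1u << 17;
constexpr unsigned kCpuid7EbxAvx512Bw   = 1u << 30;
constexpr unsigned kCpuid7EbxAvx512Vl   = 1u << 31;   // DQ | BW | VL == 0xC0020000
constexpr unsigned kCpuid7EcxAvx512Vnni = 1u << 11;   // 0x800
constexpr unsigned kCpuid7EdxHybrid     = 1u << 15;   // 0x8000, hybrid (P/E core) processor
constexpr unsigned kCpuid7EdxAmxTile    = 1u << 24;
constexpr unsigned kCpuid7EdxAmxInt8    = 1u << 25;

// Leaf 7, sub-leaf 1 (queried as Cpuid7_1 further below):
constexpr unsigned kCpuid7s1EaxAvxVnni      = 1u << 4;  // 0x10
constexpr unsigned kCpuid7s1EdxAvxVnniInt8  = 1u << 4;
constexpr unsigned kCpuid7s1EdxAvxNeConvert = 1u << 5;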
- // - - uint64_t xcr0 = MlasReadExtendedControlRegister(_XCR_XFEATURE_ENABLED_MASK); - - if ((xcr0 & 0x6) == 0x6) { - - this->GemmFloatKernel = MlasGemmFloatKernelAvx; - -#if defined(MLAS_TARGET_AMD64) - - this->KernelM1Routine = MlasSgemmKernelM1Avx; - this->KernelM1TransposeBRoutine = MlasSgemmKernelM1TransposeBAvx; - this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Avx; - this->GemmDoubleKernel = MlasGemmDoubleKernelAvx; - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx; - this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelAvx; - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelAvx; - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelAvx; - this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelAvx; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelAvx; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx; - this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelAvx; - this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelAvx; - this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx; - this->ReduceMinimumMaximumF32Kernel = MlasReduceMinimumMaximumF32KernelAvx; - this->GemmU8U8Kernel = nullptr; - - // - // Check if the processor supports AVX2/FMA3 features. - // - - unsigned Cpuid7[4]; -#if defined(_WIN32) - __cpuidex((int*)Cpuid7, 7, 0); -#else - __cpuid_count(7, 0, Cpuid7[0], Cpuid7[1], Cpuid7[2], Cpuid7[3]); -#endif - - if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { - - this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2; - this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; - this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2; - this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2; - this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2; - this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx2; - - this->GemmFloatKernel = MlasGemmFloatKernelFma3; - this->GemmDoubleKernel = MlasGemmDoubleKernelFma3; - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelFma3; - this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelFma3; - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelFma3; - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelFma3; - this->ComputeExpF32Kernel = MlasComputeExpF32KernelFma3; - this->LogisticKernelRoutine = MlasComputeLogisticF32KernelFma3; - this->TanhKernelRoutine = MlasComputeTanhF32KernelFma3; - this->ErfKernelRoutine = MlasErfKernelFma3; - this->QLinearAddS8Kernel = MlasQLinearAddS8KernelAvx2; - this->QLinearAddU8Kernel = MlasQLinearAddU8KernelAvx2; - this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernelAvx2; - this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernelAvx2; - this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; - this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; - this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; - this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; - this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; - - - // - // Check if the processor supports Hybrid core architecture. - // - - if ((Cpuid7[3] & 0x8000) != 0) { - this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4; - } - - // - // Check if the processor supports AVXVNNI features. 
- // - - unsigned Cpuid7_1[4]; -#if defined(_WIN32) - __cpuidex((int*)Cpuid7_1, 7, 1); -#else - __cpuid_count(7, 1, Cpuid7_1[0], Cpuid7_1[1], Cpuid7_1[2], Cpuid7_1[3]); -#endif - - if ((Cpuid7_1[0] & 0x10) != 0) { - - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; - this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; - this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; - this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; - } - -#if !defined(ORT_MINIMAL_BUILD) - - // - // Check if the processor supports AVX512F features and the - // operating system supports saving AVX512F state. - // - - if (((Cpuid7[1] & 0x10000) != 0) && ((xcr0 & 0xE0) == 0xE0)) { - - this->GemmFloatKernel = MlasGemmFloatKernelAvx512F; - this->GemmDoubleKernel = MlasGemmDoubleKernelAvx512F; - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelAvx512F; - this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelAvx512F; - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelAvx512F; - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelAvx512F; - this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelAvx512F; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelAvx512F; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelAvx512F; - this->ComputeExpF32Kernel = MlasComputeExpF32KernelAvx512F; - this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelAvx512F; - this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx512F; - this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelAvx512F; - this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelAvx512F; - this->NchwcBlockSize = 16; - this->PreferredBufferAlignment = 64; - - // - // Check if the processor supports AVX512 core features - // (AVX512BW/AVX512DQ/AVX512VL). - // - - if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) { - - this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core; - this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core; - this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; - this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; - this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; - - // - // Check if the processor supports AVX512VNNI. - // - - if ((Cpuid7[2] & 0x800) != 0) { - - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; - this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; - this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; - this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; - this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; - } - } - } - - // - // Check if the processor supports AVX-VNNI-INT8 - // - if ((Cpuid7_1[3] & 0x10) != 0) { - this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2Vnni; - this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchAvx2Vnni; - this->GemmS8S8Kernel = MlasGemmS8S8KernelAvx2Vnni; - this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchAvx2Vnni; - this->GemmS8U8Kernel = MlasGemmS8U8KernelAvx2Vnni; - } - -#ifndef __APPLE__ -#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) - // - // Check if the processor supports AVX NE CONVERT. 
- // - if ((Cpuid7_1[3] & (0b1 << 5)) != 0) { - this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx; - } -#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) - - - // - // Check if the processor supports AMX-TILE and AMX-INT8 - // features. - // - if ((Cpuid7[3] & 0b1 << 24) != 0 && - (Cpuid7[3] & 0b1 << 25) != 0 && - (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { - if (MlasInitAMX()) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; - this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; - } - } -#endif // __APPLE__ - -#endif // ORT_MINIMAL_BUILD - - } - -#endif // MLAS_TARGET_AMD64 - - } - } - -#endif // MLAS_TARGET_AMD64_IX86 - -#if defined(MLAS_TARGET_ARM64) - - this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchNeon; - this->GemmU8S8Dispatch = &MlasGemmX8S8DispatchNeon; - this->GemmS8S8Dispatch = &MlasGemmX8S8DispatchNeon; - this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; - this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; - this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; - - // - // Check if the processor supports ASIMD dot product instructions. - // - - bool HasDotProductInstructions; - -#if defined(_WIN32) - HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#else - // Use the cpuinfo value which is read from sysctl and has some additional special cases. - // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 - // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips - // as well as failing on other ARM chips as it is an EL1 level register that requires extra - // privileges to read. - // - // uint64_t isar0_el1; - // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; - HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); -#endif - - if (HasDotProductInstructions) { - this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUdot; - this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUdot; - this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSdot; - this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; - this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; - this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; - - // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; - } - -#if defined(__linux__) - // - // Check if the processor supports ASIMD I8MM instructions. 
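// Illustrative sketch: in the BUILD_MLAS_NO_ONNXRUNTIME configuration shown near
// the top of this file, the HasArmNeonDot()/HasArmNeon_I8MM() probes used to pick
// the dispatch tables reduce to getauxval() checks on Linux. A standalone version,
// with hypothetical helper names and assuming <sys/auxv.h> is available:

#include <sys/auxv.h>

#ifndef HWCAP_ASIMDDP
#define HWCAP_ASIMDDP (1 << 20)
#endif
#ifndef HWCAP2_I8MM
#define HWCAP2_I8MM (1 << 13)
#endif

inline bool CpuHasNeonDotProduct() {
    return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;   // gates the Udot/Sdot dispatch above
}

inline bool CpuHasNeonI8mm() {
    return (getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0;    // gates the Ummla/Smmla dispatch below
}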
- // - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { - this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; - this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; - this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; - } -#endif - -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; - this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelNeon; -#endif - -#endif // MLAS_TARGET_ARM64 -#if defined(MLAS_TARGET_POWER) - this->GemmFloatKernel = MlasSgemmKernel; - this->GemmDoubleKernel = MlasDgemmKernel; - this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; - this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; - this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; - this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; - this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; - this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; - -#if defined(__linux__) - unsigned long hwcap2 = getauxval(AT_HWCAP2); - - bool HasP9Instructions = hwcap2 & PPC_FEATURE2_ARCH_3_00; -#elif defined(_AIX) - bool HasP9Instructions = __power_9_andup(); -#endif // __linux__ - if (HasP9Instructions) { - this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelVSX; - this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelVSX; - } - -#if defined(POWER10) -#if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ - (defined(__clang__) && (__clang_major__ >= 12)) -#if defined(__linux__) - bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); -#elif defined(_AIX) - bool HasP10Instructions = (__power_10_andup() && __power_mma_version() == MMA_V31); -#endif // __linux__ - if (HasP10Instructions) { - this->GemmFloatKernel = MlasSgemmKernelPOWER10; - this->GemmDoubleKernel = MlasDgemmKernelPOWER10; - this->GemmU8X8Dispatch = &MlasGemm8X8DispatchPOWER10; - } -#endif -#endif - -#endif // MLAS_TARGET_POWER - -#if defined(MLAS_TARGET_LARCH64) - - // - // Default to the baseline LSX support. 
- // - - int hwcap = getauxval(AT_HWCAP); - bool cap_lasx = hwcap & HWCAP_LOONGARCH_LASX; - bool cap_lsx = hwcap & HWCAP_LOONGARCH_LSX; - - if( cap_lasx ){ - this->GemmFloatKernel = MlasGemmFloatKernelLasx; - this->GemmDoubleKernel = MlasGemmDoubleKernelLasx; - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLasx; - this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLasx; - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLasx; - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLasx; - this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLasx; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLasx; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLasx; - this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelLasx; - this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelLasx; - this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; - this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; - - this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; - this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; - }else if( cap_lsx ){ - this->GemmFloatKernel = MlasGemmFloatKernelLSX; - this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; - this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; - this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4LSX; - this->GemmDoubleKernel = MlasGemmDoubleKernelLSX; - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelLSX; - this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelLSX; - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelLSX; - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelLSX; - - this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelLSX; - this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelLSX; - this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelLSX; - this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; - this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; - this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; - }else{ - this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; - this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; - this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; - } - - this->NchwcBlockSize = 8; - // this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; - - // this->MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; - -#endif // MLAS_TARGET_LARCH64 - -} - -size_t -MLASCALL -MlasGetPreferredBufferAlignment( - void - ) -/*++ - -Routine Description: - - This routine returns the preferred byte alignment for buffers that are used - with this library. Buffers that are not byte aligned to this value will - function, but will not achieve best performance. - -Arguments: - - None. - -Return Value: - - Returns the preferred byte alignment for buffers. 
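// Illustrative usage sketch: callers allocating scratch buffers for MLAS can honor
// this value with an aligned allocator. AllocateMlasBuffer is a hypothetical helper,
// not an MLAS API; std::aligned_alloc is C++17 (use _aligned_malloc/_aligned_free on
// MSVC) and requires the size to be a multiple of the alignment, hence the round-up.

#include <cstddef>
#include <cstdlib>

inline void* AllocateMlasBuffer(std::size_t bytes) {
    const std::size_t alignment = MlasGetPreferredBufferAlignment();
    const std::size_t rounded = (bytes + alignment - 1) / alignment * alignment;
    return std::aligned_alloc(alignment, rounded);   // release with std::free()
}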
- ---*/ -{ -#if defined(MLAS_TARGET_AMD64) - return GetMlasPlatform().PreferredBufferAlignment; -#else - return MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; -#endif -} - -#ifdef MLAS_TARGET_AMD64_IX86 - -bool -MLASCALL -MlasPlatformU8S8Overflow( - void - ) -{ - const auto& p = GetMlasPlatform(); - return p.GemmU8U8Dispatch != p.GemmU8S8Dispatch; -} - -#endif -thread_local size_t ThreadedBufSize = 0; -#ifdef _MSC_VER -thread_local std::unique_ptr ThreadedBufHolder(nullptr, &_aligned_free); -#else -thread_local std::unique_ptr ThreadedBufHolder(nullptr, &free); -#endif diff --git a/onnxruntime/core/mlas/lib/pooling.cpp b/onnxruntime/core/mlas/lib/pooling.cpp deleted file mode 100644 index 50dcf19224510..0000000000000 --- a/onnxruntime/core/mlas/lib/pooling.cpp +++ /dev/null @@ -1,1703 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - pooling.cpp - -Abstract: - - This module implements the pooling operation. - ---*/ - -#include "mlasi.h" - -// -// Define the parameters to execute segments of a pooling operation on worker -// threads. -// - -struct MLAS_POOL_WORK_BLOCK -{ - MLAS_POOLING_KIND PoolingKind; - size_t InputShape[3]; - size_t InputSize; - size_t OutputShape[3]; - int64_t KernelShape[3]; - int64_t Padding[6]; - int64_t StrideShape[3]; -}; - -// -// Define the prototype of the pooling kernel routine. -// - -typedef -void -(MLAS_POOL_KERNEL_ROUTINE)( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ); - -// -// Define the number of elements to allocate on the stack for the reduction -// buffer in the vectorized kernels. -// - -#define MLAS_POOL_REDUCTION_BUFFER_STACK 2048 - -// -// Define the number of reduction buffer elements reserved for over-reading -// an entire vector to avoid special handling at the right edge of the -// buffer. -// - -#define MLAS_POOL_REDUCTION_BUFFER_PADDING ((sizeof(MLAS_FLOAT32X4) / sizeof(float)) - 1) - -// -// Abstraction for maximum pooling. -// - -struct MLAS_MAXIMUM_POOLING -{ - static constexpr float InitialValue() - { - return std::numeric_limits::lowest(); - } - - static MLAS_FLOAT32X4 InitialVector() - { - return MlasBroadcastFloat32x4(InitialValue()); - } - - static constexpr float Reduce(float Reduction, float Value) - { - return std::max(Reduction, Value); - } - - static MLAS_FLOAT32X4 Reduce(MLAS_FLOAT32X4 Reduction, MLAS_FLOAT32X4 Value) - { - return MlasMaximumFloat32x4(Reduction, Value); - } - - static float Reduce(MLAS_FLOAT32X4 Reduction) - { - return MlasReduceMaximumFloat32x4(Reduction); - } - - static constexpr float AveragePool(float Reduction, float Size) - { - MLAS_UNREFERENCED_PARAMETER(Size); - - return Reduction; - } - - struct DividerVectorContext - { - void PrepareExcludePad(size_t PaddingLeftWidth, size_t InputWidth, size_t KernelWidth) - { - MLAS_UNREFERENCED_PARAMETER(PaddingLeftWidth); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - } - - void PrepareIncludePad(size_t KernelSize) - { - MLAS_UNREFERENCED_PARAMETER(KernelSize); - } - - void StartNextOutputRow(size_t InputRowsCount) - { - MLAS_UNREFERENCED_PARAMETER(InputRowsCount); - } - - MLAS_FLOAT32X4 DivideExcludePad(MLAS_FLOAT32X4 Reduction) - { - return Reduction; - } - - MLAS_FLOAT32X4 DivideIncludePad(MLAS_FLOAT32X4 Reduction) - { - return Reduction; - } - }; -}; - -// -// Abstraction for average pooling. 
-// - -MLAS_DECLSPEC_ALIGN(static const float MlasInitialReductionInputIndex[], sizeof(MLAS_FLOAT32X4)) = { 0.0f, 1.0f, 2.0f, 3.0f }; - -struct MLAS_AVERAGE_POOLING -{ - static float InitialValue() - { - return 0.0f; - } - - static MLAS_FLOAT32X4 InitialVector() - { - return MlasZeroFloat32x4(); - } - - static constexpr float Reduce(float Reduction, float Value) - { - return Reduction + Value; - } - - static MLAS_FLOAT32X4 Reduce(MLAS_FLOAT32X4 Reduction, MLAS_FLOAT32X4 Value) - { - return MlasAddFloat32x4(Reduction, Value); - } - - static float Reduce(MLAS_FLOAT32X4 Reduction) - { - return MlasReduceAddFloat32x4(Reduction); - } - - static constexpr float AveragePool(float Reduction, float Size) - { - return Reduction / Size; - } - - struct DividerVectorContext - { - MLAS_FLOAT32X4 KernelSizeBroadcast; - MLAS_FLOAT32X4 KernelWidthBroadcast; - MLAS_FLOAT32X4 PaddingLowerBound; - MLAS_FLOAT32X4 PaddingUpperBound; - MLAS_FLOAT32X4 ReductionInputIndex; - MLAS_FLOAT32X4 InputRowsBroadcast; - - void PrepareExcludePad(size_t PaddingLeftWidth, size_t InputWidth, size_t KernelWidth) - { - KernelWidthBroadcast = MlasBroadcastFloat32x4(float(unsigned(KernelWidth))); - PaddingLowerBound = MlasBroadcastFloat32x4(float(unsigned(PaddingLeftWidth))); - PaddingUpperBound = MlasBroadcastFloat32x4(float(unsigned(PaddingLeftWidth + InputWidth))); - } - - void PrepareIncludePad(size_t KernelSize) - { - KernelSizeBroadcast = MlasBroadcastFloat32x4(float(unsigned(KernelSize))); - } - - void StartNextOutputRow(size_t InputRowsCount) - { - ReductionInputIndex = MlasLoadFloat32x4(MlasInitialReductionInputIndex); - InputRowsBroadcast = MlasBroadcastFloat32x4(float(unsigned(InputRowsCount))); - } - - MLAS_FLOAT32X4 DivideExcludePad(MLAS_FLOAT32X4 Reduction) - { - MLAS_FLOAT32X4 Divisor; - - // - // Compute the ending input index for each column and bound the index - // range by the padding indices, then compute the number of input - // column contributions from the delta. - // - - MLAS_FLOAT32X4 ReductionInputEndingIndex = - MlasAddFloat32x4(ReductionInputIndex, KernelWidthBroadcast); - - MLAS_FLOAT32X4 LowerInputIndex = - MlasMaximumFloat32x4(ReductionInputIndex, PaddingLowerBound); - MLAS_FLOAT32X4 UpperInputIndex = - MlasMinimumFloat32x4(ReductionInputEndingIndex, PaddingUpperBound); - - MLAS_FLOAT32X4 InputIndexDelta = - MlasSubtractFloat32x4(UpperInputIndex, LowerInputIndex); - - // - // Advance the input index vector for the next iteration. - // - - ReductionInputIndex = - MlasAddFloat32x4(ReductionInputIndex, MlasBroadcastFloat32x4(4.0f)); - - // - // Compute the per-column number of input elements used for the sum. - // - // At the end of the input row, the index range computed above may be - // zero for unused trailing vector elements, so avoid any divide by zero - // penalty by enforcing a minimum of 1.0f. - // - - Divisor = MlasMultiplyFloat32x4(InputIndexDelta, InputRowsBroadcast); - Divisor = MlasMaximumFloat32x4(Divisor, MlasBroadcastFloat32x4(1.0f)); - - return MlasDivideFloat32x4(Reduction, Divisor); - } - - MLAS_FLOAT32X4 DivideIncludePad(MLAS_FLOAT32X4 Reduction) - { - return MlasDivideFloat32x4(Reduction, KernelSizeBroadcast); - } - }; -}; - -template -void -MlasPool1DKernel( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements the 1D pooling operation using generic constructs. - -Arguments: - - WorkBlock - Supplies the structure that contains the pooling parameters. 
- - ChannelCount - Supplies the number of channels to process. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. - ---*/ -{ - constexpr size_t WidthShapeIndex = 0; - - const MLAS_POOLING_KIND PoolingKind = WorkBlock->PoolingKind; - - const size_t InputWidth = WorkBlock->InputShape[WidthShapeIndex]; - const size_t OutputWidth = WorkBlock->OutputShape[WidthShapeIndex]; - - const int64_t KernelWidth = WorkBlock->KernelShape[WidthShapeIndex]; - const int64_t PaddingLeftWidth = WorkBlock->Padding[WidthShapeIndex]; - const int64_t StrideWidth = WorkBlock->StrideShape[WidthShapeIndex]; - - for (size_t c = 0; c < ChannelCount; c++) { - - for (size_t pw = 0; pw < OutputWidth; pw++) { - - const int64_t iwStart64 = pw * StrideWidth - PaddingLeftWidth; - const int64_t iwEnd64 = iwStart64 + KernelWidth; - - const size_t iwStart = size_t(std::max(iwStart64, int64_t(0))); - const size_t iwEnd = size_t(std::min(iwEnd64, int64_t(InputWidth))); - - float m = PoolingType::InitialValue(); - - for (size_t iw = size_t(iwStart); iw < size_t(iwEnd); iw++) { - m = PoolingType::Reduce(m, Input[iw]); - } - - if (PoolingKind == MlasAveragePoolingExcludePad) { - m = PoolingType::AveragePool(m, float(iwEnd - iwStart)); - } else { - m = PoolingType::AveragePool(m, float(KernelWidth)); - } - - *Output++ = m; - } - - Input += InputWidth; - } -} - -template -void -MlasPool2DKernel( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements the 2D pooling operation using generic constructs. - -Arguments: - - WorkBlock - Supplies the structure that contains the pooling parameters. - - ChannelCount - Supplies the number of channels to process. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. 
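// Worked example (hypothetical helper, not an MLAS API): each output position clamps
// its kernel window to the input, and the divisor depends on the pooling kind. With
// InputWidth = 5, KernelWidth = 3, StrideWidth = 2, PaddingLeftWidth = 1 and pw = 0:
// iwStart = 0*2 - 1 = -1 -> clamped to 0, iwEnd = -1 + 3 = 2, so two elements
// contribute; exclude-pad averaging divides by 2, include-pad averaging divides by
// KernelWidth = 3, and maximum pooling ignores the divisor entirely.

#include <algorithm>
#include <cstddef>
#include <cstdint>

inline void ClampPoolWindow(int64_t pw, int64_t Stride, int64_t PaddingLeft,
                            int64_t Kernel, int64_t InputWidth,
                            size_t& iwStart, size_t& iwEnd) {
    const int64_t start = pw * Stride - PaddingLeft;
    iwStart = size_t(std::max(start, int64_t(0)));
    iwEnd = size_t(std::min(start + Kernel, InputWidth));
}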
- ---*/ -{ - constexpr size_t HeightShapeIndex = 0; - constexpr size_t WidthShapeIndex = 1; - - const MLAS_POOLING_KIND PoolingKind = WorkBlock->PoolingKind; - - const size_t InputHeight = WorkBlock->InputShape[HeightShapeIndex]; - const size_t InputWidth = WorkBlock->InputShape[WidthShapeIndex]; - const size_t InputSize = WorkBlock->InputSize; - const size_t OutputHeight = WorkBlock->OutputShape[HeightShapeIndex]; - const size_t OutputWidth = WorkBlock->OutputShape[WidthShapeIndex]; - - const int64_t KernelHeight = WorkBlock->KernelShape[HeightShapeIndex]; - const int64_t KernelWidth = WorkBlock->KernelShape[WidthShapeIndex]; - const int64_t PaddingLeftHeight = WorkBlock->Padding[HeightShapeIndex]; - const int64_t PaddingLeftWidth = WorkBlock->Padding[WidthShapeIndex]; - const int64_t StrideHeight = WorkBlock->StrideShape[HeightShapeIndex]; - const int64_t StrideWidth = WorkBlock->StrideShape[WidthShapeIndex]; - - for (size_t c = 0; c < ChannelCount; c++) { - - for (size_t ph = 0; ph < OutputHeight; ph++) { - - const int64_t ihStart64 = ph * StrideHeight - PaddingLeftHeight; - const int64_t ihEnd64 = ihStart64 + KernelHeight; - - const size_t ihStart = size_t(std::max(ihStart64, int64_t(0))); - const size_t ihEnd = size_t(std::min(ihEnd64, int64_t(InputHeight))); - - for (size_t pw = 0; pw < OutputWidth; pw++) { - - const int64_t iwStart64 = pw * StrideWidth - PaddingLeftWidth; - const int64_t iwEnd64 = iwStart64 + KernelWidth; - - const size_t iwStart = size_t(std::max(iwStart64, int64_t(0))); - const size_t iwEnd = size_t(std::min(iwEnd64, int64_t(InputWidth))); - - float m = PoolingType::InitialValue(); - - for (size_t ih = ihStart; ih < ihEnd; ih++) { - for (size_t iw = iwStart; iw < iwEnd; iw++) { - m = PoolingType::Reduce(m, Input[ih * InputWidth + iw]); - } - } - - if (PoolingKind == MlasAveragePoolingExcludePad) { - m = PoolingType::AveragePool(m, float((ihEnd - ihStart) * (iwEnd - iwStart))); - } else { - m = PoolingType::AveragePool(m, float(KernelHeight * KernelWidth)); - } - - *Output++ = m; - } - } - - Input += InputSize; - } -} - -template -void -MlasPool2DVectorKernel( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements an optimized 2D pooling operation using vector - instructions. - -Arguments: - - WorkBlock - Supplies the structure that contains the pooling parameters. - - ChannelCount - Supplies the number of channels to process. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. 
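// Illustrative sketch (hypothetical helper, not an MLAS API): the vectorized kernels
// stage one padded row of partial reductions in a fixed stack buffer, so they are
// only usable when that row fits. This mirrors the ReductionBufferRemaining check
// made in MlasPool() further below.

inline bool PoolRowFitsReductionBuffer(size_t PaddingLeftWidth, size_t InputWidth,
                                       size_t PaddingRightWidth) {
    return PaddingLeftWidth + InputWidth + PaddingRightWidth +
               MLAS_POOL_REDUCTION_BUFFER_PADDING <=
           MLAS_POOL_REDUCTION_BUFFER_STACK;
}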
- ---*/ -{ - constexpr size_t Dimensions = 2; - - constexpr size_t HeightShapeIndex = 0; - constexpr size_t WidthShapeIndex = 1; - - const MLAS_POOLING_KIND PoolingKind = WorkBlock->PoolingKind; - - const size_t InputHeight = WorkBlock->InputShape[HeightShapeIndex]; - const size_t InputWidth = WorkBlock->InputShape[WidthShapeIndex]; - const size_t InputSize = WorkBlock->InputSize; - const size_t OutputHeight = WorkBlock->OutputShape[HeightShapeIndex]; - const size_t OutputWidth = WorkBlock->OutputShape[WidthShapeIndex]; - - const size_t KernelHeight = size_t(WorkBlock->KernelShape[HeightShapeIndex]); - const size_t KernelWidth = size_t(WorkBlock->KernelShape[WidthShapeIndex]); - const size_t PaddingLeftHeight = size_t(WorkBlock->Padding[HeightShapeIndex]); - const size_t PaddingLeftWidth = size_t(WorkBlock->Padding[WidthShapeIndex]); - const size_t PaddingRightWidth = size_t(WorkBlock->Padding[Dimensions + WidthShapeIndex]); - const size_t StrideHeight = size_t(WorkBlock->StrideShape[HeightShapeIndex]); - const size_t StrideWidth = size_t(WorkBlock->StrideShape[WidthShapeIndex]); - - float ReductionBuffer[MLAS_POOL_REDUCTION_BUFFER_STACK]; - - // - // Fill the edges of the reduction buffer with the padding value. - // - - float* FillReductionBuffer = ReductionBuffer; - float* FillReductionBufferEnd = FillReductionBuffer + PaddingLeftWidth; - - while (FillReductionBuffer < FillReductionBufferEnd) { - *FillReductionBuffer++ = PoolingType::InitialValue(); - } - - FillReductionBuffer = FillReductionBuffer + InputWidth; - FillReductionBufferEnd = FillReductionBuffer + PaddingRightWidth + MLAS_POOL_REDUCTION_BUFFER_PADDING; - - while (FillReductionBuffer < FillReductionBufferEnd) { - *FillReductionBuffer++ = PoolingType::InitialValue(); - } - - // - // Apply the pooling operation to each channel. - // - - typename PoolingType::DividerVectorContext divider; - divider.PrepareExcludePad(PaddingLeftWidth, InputWidth, KernelWidth); - divider.PrepareIncludePad(KernelHeight * KernelWidth); - - for (size_t c = 0; c < ChannelCount; c++) { - - for (size_t ph = 0; ph < OutputHeight; ph++) { - - size_t ihStart = ph * StrideHeight - PaddingLeftHeight; - size_t ihEnd = ihStart + KernelHeight; - - if (ihStart >= InputHeight) { - ihStart = 0; - } - - if (ihEnd > InputHeight) { - ihEnd = InputHeight; - } - - divider.StartNextOutputRow(ihEnd - ihStart); - - // - // Reduce the input across the kernel height and store in a local - // reduction buffer. 
- // - - const float* InputRowStart = &Input[ihStart * InputWidth]; - const size_t InputRowsCount = ihEnd - ihStart - 1; - size_t InputWidthRemaining = InputWidth; - float* ReductionOutput = &ReductionBuffer[PaddingLeftWidth]; - - while (InputWidthRemaining >= 4) { - - const float* InputRow = InputRowStart; - size_t InputRowsRemaining = InputRowsCount; - MLAS_FLOAT32X4 Reduction = MlasLoadFloat32x4(InputRow); - - while (InputRowsRemaining > 0) { - InputRow += InputWidth; - Reduction = PoolingType::Reduce(Reduction, MlasLoadFloat32x4(InputRow)); - InputRowsRemaining--; - } - - MlasStoreFloat32x4(ReductionOutput, Reduction); - ReductionOutput += 4; - - InputRowStart += 4; - InputWidthRemaining -= 4; - } - - while (InputWidthRemaining > 0) { - - const float* InputRow = InputRowStart; - size_t InputRowsRemaining = InputRowsCount; - float Reduction = *InputRow; - - while (InputRowsRemaining > 0) { - InputRow += InputWidth; - Reduction = PoolingType::Reduce(Reduction, *InputRow); - InputRowsRemaining--; - } - - *ReductionOutput++ = Reduction; - - InputRowStart += 1; - InputWidthRemaining -= 1; - } - - // - // Reduce the input across the kernel width and store to the output - // tensor. - // - - size_t OutputWidthRemaining = OutputWidth; - const float* ReductionInputStart = ReductionBuffer; - - do { - - const float* ReductionInput = ReductionInputStart; - const float* ReductionInputEnd = ReductionInput + KernelWidth; - MLAS_FLOAT32X4 Reduction = MlasLoadFloat32x4(ReductionInput++); - - while (ReductionInput < ReductionInputEnd) { - Reduction = PoolingType::Reduce(Reduction, MlasLoadFloat32x4(ReductionInput++)); - } - - if (PoolingKind == MlasAveragePoolingExcludePad) { - Reduction = divider.DivideExcludePad(Reduction); - } else { - Reduction = divider.DivideIncludePad(Reduction); - } - - if (StrideWidth == 1) { - - if (OutputWidthRemaining < 4) { - - if (OutputWidthRemaining >= 2) { - - MlasStoreLowHalfFloat32x4(Output, Reduction); - - if (OutputWidthRemaining > 2) { - MlasStoreLaneFloat32x4<2>(Output + 2, Reduction); - } - - } else { - MlasStoreLaneFloat32x4<0>(Output, Reduction); - } - - Output += OutputWidthRemaining; - - break; - } - - MlasStoreFloat32x4(Output, Reduction); - - Output += 4; - OutputWidthRemaining -= 4; - - } else { - - if (OutputWidthRemaining == 1) { - MlasStoreLaneFloat32x4<0>(Output++, Reduction); - break; - } - -#if defined(MLAS_SSE2_INTRINSICS) - Reduction = _mm_shuffle_ps(Reduction, Reduction, _MM_SHUFFLE(2, 0, 2, 0)); - MlasStoreLowHalfFloat32x4(Output, Reduction); -#else - MlasStoreLaneFloat32x4<0>(Output, Reduction); - MlasStoreLaneFloat32x4<2>(Output + 1, Reduction); -#endif - - Output += 2; - OutputWidthRemaining -= 2; - } - - ReductionInputStart += 4; - - } while (OutputWidthRemaining > 0); - } - - Input += InputSize; - } -} - -template -void -MlasPool3DKernel( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements the 3D pooling operation using generic constructs. - -Arguments: - - WorkBlock - Supplies the structure that contains the pooling parameters. - - ChannelCount - Supplies the number of channels to process. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. 
- ---*/ -{ - constexpr size_t DepthShapeIndex = 0; - constexpr size_t HeightShapeIndex = 1; - constexpr size_t WidthShapeIndex = 2; - - const MLAS_POOLING_KIND PoolingKind = WorkBlock->PoolingKind; - - const size_t InputDepth = WorkBlock->InputShape[DepthShapeIndex]; - const size_t InputHeight = WorkBlock->InputShape[HeightShapeIndex]; - const size_t InputWidth = WorkBlock->InputShape[WidthShapeIndex]; - const size_t InputSize = WorkBlock->InputSize; - const size_t OutputDepth = WorkBlock->OutputShape[DepthShapeIndex]; - const size_t OutputHeight = WorkBlock->OutputShape[HeightShapeIndex]; - const size_t OutputWidth = WorkBlock->OutputShape[WidthShapeIndex]; - - const int64_t KernelDepth = WorkBlock->KernelShape[DepthShapeIndex]; - const int64_t KernelHeight = WorkBlock->KernelShape[HeightShapeIndex]; - const int64_t KernelWidth = WorkBlock->KernelShape[WidthShapeIndex]; - const int64_t PaddingLeftDepth = WorkBlock->Padding[DepthShapeIndex]; - const int64_t PaddingLeftHeight = WorkBlock->Padding[HeightShapeIndex]; - const int64_t PaddingLeftWidth = WorkBlock->Padding[WidthShapeIndex]; - const int64_t StrideDepth = WorkBlock->StrideShape[DepthShapeIndex]; - const int64_t StrideHeight = WorkBlock->StrideShape[HeightShapeIndex]; - const int64_t StrideWidth = WorkBlock->StrideShape[WidthShapeIndex]; - - for (size_t c = 0; c < ChannelCount; c++) { - - for (size_t pd = 0; pd < OutputDepth; pd++) { - - const int64_t idStart64 = pd * StrideDepth - PaddingLeftDepth; - const int64_t idEnd64 = idStart64 + KernelDepth; - - const size_t idStart = size_t(std::max(idStart64, int64_t(0))); - const size_t idEnd = size_t(std::min(idEnd64, int64_t(InputDepth))); - - for (size_t ph = 0; ph < OutputHeight; ph++) { - - const int64_t ihStart64 = ph * StrideHeight - PaddingLeftHeight; - const int64_t ihEnd64 = ihStart64 + KernelHeight; - - const size_t ihStart = size_t(std::max(ihStart64, int64_t(0))); - const size_t ihEnd = size_t(std::min(ihEnd64, int64_t(InputHeight))); - - for (size_t pw = 0; pw < OutputWidth; pw++) { - - const int64_t iwStart64 = pw * StrideWidth - PaddingLeftWidth; - const int64_t iwEnd64 = iwStart64 + KernelWidth; - - const size_t iwStart = size_t(std::max(iwStart64, int64_t(0))); - const size_t iwEnd = size_t(std::min(iwEnd64, int64_t(InputWidth))); - - float m = PoolingType::InitialValue(); - - for (size_t id = idStart; id < idEnd; id++) { - for (size_t ih = ihStart; ih < ihEnd; ih++) { - for (size_t iw = iwStart; iw < iwEnd; iw++) { - m = PoolingType::Reduce(m, Input[id * InputHeight * InputWidth + ih * InputWidth + iw]); - } - } - } - - if (PoolingKind == MlasAveragePoolingExcludePad) { - m = PoolingType::AveragePool(m, float((idEnd - idStart) * (ihEnd - ihStart) * (iwEnd - iwStart))); - } else { - m = PoolingType::AveragePool(m, float(KernelDepth * KernelHeight * KernelWidth)); - } - - *Output++ = m; - } - } - } - - Input += InputSize; - } -} - -template -void -MlasPool3DVectorKernel( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements an optimized 2D pooling operation using vector - instructions. - -Arguments: - - WorkBlock - Supplies the structure that contains the pooling parameters. - - ChannelCount - Supplies the number of channels to process. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. 
- ---*/ -{ - constexpr size_t Dimensions = 3; - - constexpr size_t DepthShapeIndex = 0; - constexpr size_t HeightShapeIndex = 1; - constexpr size_t WidthShapeIndex = 2; - - const MLAS_POOLING_KIND PoolingKind = WorkBlock->PoolingKind; - - const size_t InputDepth = WorkBlock->InputShape[DepthShapeIndex]; - const size_t InputHeight = WorkBlock->InputShape[HeightShapeIndex]; - const size_t InputWidth = WorkBlock->InputShape[WidthShapeIndex]; - const size_t InputSize = WorkBlock->InputSize; - const size_t OutputDepth = WorkBlock->OutputShape[DepthShapeIndex]; - const size_t OutputHeight = WorkBlock->OutputShape[HeightShapeIndex]; - const size_t OutputWidth = WorkBlock->OutputShape[WidthShapeIndex]; - - const size_t KernelDepth = size_t(WorkBlock->KernelShape[DepthShapeIndex]); - const size_t KernelHeight = size_t(WorkBlock->KernelShape[HeightShapeIndex]); - const size_t KernelWidth = size_t(WorkBlock->KernelShape[WidthShapeIndex]); - const size_t PaddingLeftDepth = size_t(WorkBlock->Padding[DepthShapeIndex]); - const size_t PaddingLeftHeight = size_t(WorkBlock->Padding[HeightShapeIndex]); - const size_t PaddingLeftWidth = size_t(WorkBlock->Padding[WidthShapeIndex]); - const size_t PaddingRightWidth = size_t(WorkBlock->Padding[Dimensions + WidthShapeIndex]); - const size_t StrideDepth = size_t(WorkBlock->StrideShape[DepthShapeIndex]); - const size_t StrideHeight = size_t(WorkBlock->StrideShape[HeightShapeIndex]); - const size_t StrideWidth = size_t(WorkBlock->StrideShape[WidthShapeIndex]); - - float ReductionBuffer[MLAS_POOL_REDUCTION_BUFFER_STACK]; - - // - // Fill the edges of the reduction buffer with the padding value. - // - - float* FillReductionBuffer = ReductionBuffer; - float* FillReductionBufferEnd = FillReductionBuffer + PaddingLeftWidth; - - while (FillReductionBuffer < FillReductionBufferEnd) { - *FillReductionBuffer++ = PoolingType::InitialValue(); - } - - FillReductionBuffer = FillReductionBuffer + InputWidth; - FillReductionBufferEnd = FillReductionBuffer + PaddingRightWidth + MLAS_POOL_REDUCTION_BUFFER_PADDING; - - while (FillReductionBuffer < FillReductionBufferEnd) { - *FillReductionBuffer++ = PoolingType::InitialValue(); - } - - // - // Apply the pooling operation to each channel. - // - - typename PoolingType::DividerVectorContext divider; - divider.PrepareExcludePad(PaddingLeftWidth, InputWidth, KernelWidth); - divider.PrepareIncludePad(KernelDepth * KernelHeight * KernelWidth); - - for (size_t c = 0; c < ChannelCount; c++) { - - for (size_t pd = 0; pd < OutputDepth; pd++) { - - size_t idStart = pd * StrideDepth - PaddingLeftDepth; - size_t idEnd = idStart + KernelDepth; - - if (idStart >= InputDepth) { - idStart = 0; - } - - if (idEnd > InputDepth) { - idEnd = InputDepth; - } - - for (size_t ph = 0; ph < OutputHeight; ph++) { - - size_t ihStart = ph * StrideHeight - PaddingLeftHeight; - size_t ihEnd = ihStart + KernelHeight; - - if (ihStart >= InputHeight) { - ihStart = 0; - } - - if (ihEnd > InputHeight) { - ihEnd = InputHeight; - } - - divider.StartNextOutputRow((idEnd - idStart) * (ihEnd - ihStart)); - - // - // Reduce the input across the kernel height and store in a local - // reduction buffer. 
- // - - const float* InputRowStart = &Input[idStart * InputHeight * InputWidth + ihStart * InputWidth]; - const size_t InputPlanesCount = idEnd - idStart; - const size_t InputRowsCount = ihEnd - ihStart; - size_t InputWidthRemaining = InputWidth; - float* ReductionOutput = &ReductionBuffer[PaddingLeftWidth]; - const size_t InputAdvancePlane = (InputHeight - InputRowsCount) * InputWidth; - - while (InputWidthRemaining >= 4) { - - const float* InputRow = InputRowStart; - size_t InputPlanesRemaining = InputPlanesCount; - MLAS_FLOAT32X4 Reduction = PoolingType::InitialVector(); - - do { - - size_t InputRowsRemaining = InputRowsCount; - - do { - - Reduction = PoolingType::Reduce(Reduction, MlasLoadFloat32x4(InputRow)); - InputRow += InputWidth; - InputRowsRemaining--; - - } while (InputRowsRemaining > 0); - - InputRow += InputAdvancePlane; - InputPlanesRemaining--; - - } while (InputPlanesRemaining > 0); - - MlasStoreFloat32x4(ReductionOutput, Reduction); - ReductionOutput += 4; - - InputRowStart += 4; - InputWidthRemaining -= 4; - } - - while (InputWidthRemaining > 0) { - - const float* InputRow = InputRowStart; - size_t InputPlanesRemaining = InputPlanesCount; - float Reduction = PoolingType::InitialValue(); - - do { - - size_t InputRowsRemaining = InputRowsCount; - - do { - - Reduction = PoolingType::Reduce(Reduction, *InputRow); - InputRow += InputWidth; - InputRowsRemaining--; - - } while (InputRowsRemaining > 0); - - InputRow += InputAdvancePlane; - InputPlanesRemaining--; - - } while (InputPlanesRemaining > 0); - - *ReductionOutput++ = Reduction; - - InputRowStart += 1; - InputWidthRemaining -= 1; - } - - // - // Reduce the input across the kernel width and store to the output - // tensor. - // - - size_t OutputWidthRemaining = OutputWidth; - const float* ReductionInputStart = ReductionBuffer; - - do { - - const float* ReductionInput = ReductionInputStart; - const float* ReductionInputEnd = ReductionInput + KernelWidth; - MLAS_FLOAT32X4 Reduction = MlasLoadFloat32x4(ReductionInput++); - - while (ReductionInput < ReductionInputEnd) { - Reduction = PoolingType::Reduce(Reduction, MlasLoadFloat32x4(ReductionInput++)); - } - - if (PoolingKind == MlasAveragePoolingExcludePad) { - Reduction = divider.DivideExcludePad(Reduction); - } else { - Reduction = divider.DivideIncludePad(Reduction); - } - - if (StrideWidth == 1) { - - if (OutputWidthRemaining < 4) { - - if (OutputWidthRemaining >= 2) { - - MlasStoreLowHalfFloat32x4(Output, Reduction); - - if (OutputWidthRemaining > 2) { - MlasStoreLaneFloat32x4<2>(Output + 2, Reduction); - } - - } else { - MlasStoreLaneFloat32x4<0>(Output, Reduction); - } - - Output += OutputWidthRemaining; - - break; - } - - MlasStoreFloat32x4(Output, Reduction); - - Output += 4; - OutputWidthRemaining -= 4; - - } else { - - if (OutputWidthRemaining == 1) { - MlasStoreLaneFloat32x4<0>(Output++, Reduction); - break; - } - -#if defined(MLAS_SSE2_INTRINSICS) - Reduction = _mm_shuffle_ps(Reduction, Reduction, _MM_SHUFFLE(2, 0, 2, 0)); - MlasStoreLowHalfFloat32x4(Output, Reduction); -#else - MlasStoreLaneFloat32x4<0>(Output, Reduction); - MlasStoreLaneFloat32x4<2>(Output + 1, Reduction); -#endif - - Output += 2; - OutputWidthRemaining -= 2; - } - - ReductionInputStart += 4; - - } while (OutputWidthRemaining > 0); - } - } - - Input += InputSize; - } -} - -template -void -MlasPoolGlobalKernel( - const MLAS_POOL_WORK_BLOCK* WorkBlock, - size_t ChannelCount, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements a global pooling 
operation. - -Arguments: - - WorkBlock - Supplies the structure that contains the pooling parameters. - - ChannelCount - Supplies the number of channels to process. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. - ---*/ -{ - const size_t InputSize = WorkBlock->InputSize; - const float InputSizeFloat = float(InputSize); - - // - // Apply the pooling operation to each channel. - // - - for (size_t c = 0; c < ChannelCount; c++) { - - size_t InputSizeRemaining = InputSize; - - // - // Iterate over the input buffer a vector at a time. - // - - MLAS_FLOAT32X4 Reduction = PoolingType::InitialVector(); - - while (InputSizeRemaining >= 4) { - Reduction = PoolingType::Reduce(Reduction, MlasLoadFloat32x4(Input)); - Input += 4; - InputSizeRemaining -= 4; - } - - // - // Reduce the vector to a single float value. - // - - float ReductionValue = PoolingType::Reduce(Reduction); - - // - // Iterate over the remaining input buffer an element at a time. - // - - while (InputSizeRemaining > 0) { - ReductionValue = PoolingType::Reduce(ReductionValue, *Input++); - InputSizeRemaining -= 1; - } - - // - // Apply average pooling if necessary. - // - - ReductionValue = PoolingType::AveragePool(ReductionValue, InputSizeFloat); - - *Output++ = ReductionValue; - } -} - -// -// Stores pointers to the pooling kernel routines. -// - -static MLAS_POOL_KERNEL_ROUTINE* const MlasPoolGenericKernels[][3] = -{ - { - MlasPool1DKernel, - MlasPool2DKernel, - MlasPool3DKernel, - }, - { - MlasPool1DKernel, - MlasPool2DKernel, - MlasPool3DKernel, - }, - { - MlasPool1DKernel, - MlasPool2DKernel, - MlasPool3DKernel, - }, -}; - -static MLAS_POOL_KERNEL_ROUTINE* const MlasPoolGlobalKernels[] = -{ - MlasPoolGlobalKernel, - MlasPoolGlobalKernel, - MlasPoolGlobalKernel, -}; - -static MLAS_POOL_KERNEL_ROUTINE* const MlasPoolVectorKernels[][2] = -{ - { - MlasPool2DVectorKernel, - MlasPool3DVectorKernel, - }, - { - MlasPool2DVectorKernel, - MlasPool3DVectorKernel, - }, - { - MlasPool2DVectorKernel, - MlasPool3DVectorKernel, - }, -}; - -void -MLASCALL -MlasPool( - MLAS_POOLING_KIND PoolingKind, - size_t Dimensions, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - const float* Input, - float* Output, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the pooling operation. - -Arguments: - - PoolingKind - Supplies the kind of pooling operation to perform. - - Dimensions - Supplies the number of dimensions. - - InputShape - Supplies the shape of the input tensor. - - KernelShape - Supplies the shape of the kernel transform. - - Padding - Supplies the number of padding elements at the edge of the input - tensor. - - StrideShape - Supplies the shape of the stride. - - OutputShape - Supplies the shape of the output tensor. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_POOL_WORK_BLOCK WorkBlock; - - WorkBlock.PoolingKind = PoolingKind; - - // - // Compute the total number of channels to process and advance the input - // and output shapes over the batch and channel counts. - // - - size_t TotalChannelCount = size_t(InputShape[0]) * size_t(InputShape[1]); - - InputShape += 2; - OutputShape += 2; - - // - // Save the pooling parameters. 
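// Worked usage example (hypothetical caller, not an MLAS API): a 2-D max pool over
// an NCHW tensor with N = 1, C = 3, H = W = 4, a 2x2 kernel, stride 2 and no padding.
// The batch and channel counts ride in the two leading entries of InputShape and
// OutputShape, which is why the code above multiplies them into TotalChannelCount
// and then advances both shape pointers by two.

inline void ExampleMaxPool2D(const float* input, float* output, MLAS_THREADPOOL* threadpool) {
    const int64_t InputShape[]  = {1, 3, 4, 4};
    const int64_t OutputShape[] = {1, 3, 2, 2};
    const int64_t KernelShape[] = {2, 2};
    const int64_t Padding[]     = {0, 0, 0, 0};   // leading then trailing pads per spatial dim
    const int64_t StrideShape[] = {2, 2};

    MlasPool(MlasMaximumPooling, 2, InputShape, KernelShape, Padding,
             StrideShape, OutputShape, input, output, threadpool);
}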
- // - - size_t InputSize = 1; - size_t OutputSize = 1; - - bool InputAndKernelShapeMatch = true; - bool AllStridesAreOne = true; - bool AllPaddingIsZero = true; - bool AllKernelsAreSmall = true; - - if (Dimensions > 3) { - MLAS_THROW_EX(std::runtime_error, "bad dimensions"); - } - - for (size_t dim = 0; dim < Dimensions; dim++) { - WorkBlock.InputShape[dim] = size_t(InputShape[dim]); - WorkBlock.OutputShape[dim] = size_t(OutputShape[dim]); - - if (KernelShape != nullptr) { - WorkBlock.KernelShape[dim] = KernelShape[dim]; - } else { - WorkBlock.KernelShape[dim] = InputShape[dim]; - } - - if (Padding != nullptr) { - WorkBlock.Padding[dim] = Padding[dim]; - WorkBlock.Padding[dim + Dimensions] = Padding[dim + Dimensions]; - } else { - WorkBlock.Padding[dim] = 0; - WorkBlock.Padding[dim + Dimensions] = 0; - } - - if (StrideShape != nullptr) { - WorkBlock.StrideShape[dim] = StrideShape[dim]; - } else { - WorkBlock.StrideShape[dim] = 1; - } - - InputSize *= WorkBlock.InputShape[dim]; - OutputSize *= WorkBlock.OutputShape[dim]; - - InputAndKernelShapeMatch &= (WorkBlock.KernelShape[dim] == int64_t(WorkBlock.InputShape[dim])); - AllStridesAreOne &= (WorkBlock.StrideShape[dim] == 1); - AllPaddingIsZero &= (WorkBlock.Padding[dim] == 0 && WorkBlock.Padding[dim + Dimensions] == 0); - AllKernelsAreSmall &= (WorkBlock.KernelShape[dim] <= 32); - } - - WorkBlock.InputSize = InputSize; - - // - // Determine which pooling kernel routine to use. - // - // The vectorized kernels only support strides of 1 or 2. The kernel size - // should be kept low in order to keep the divisors for average pooling to - // be exactly representable as float. The input width plus padding must fit - // in the reduction buffer. - // - - MLAS_POOL_KERNEL_ROUTINE* PoolKernelRoutine = MlasPoolGenericKernels[PoolingKind][Dimensions - 1]; - - if (InputAndKernelShapeMatch && AllStridesAreOne && AllPaddingIsZero) { - - PoolKernelRoutine = MlasPoolGlobalKernels[PoolingKind]; - - } else if (Dimensions >= 2 && WorkBlock.StrideShape[Dimensions - 1] <= 2 && AllKernelsAreSmall) { - - int64_t ReductionBufferRemaining = MLAS_POOL_REDUCTION_BUFFER_STACK - MLAS_POOL_REDUCTION_BUFFER_PADDING; - - if (ReductionBufferRemaining >= WorkBlock.Padding[Dimensions - 1]) { - ReductionBufferRemaining -= WorkBlock.Padding[Dimensions - 1]; - } else { - ReductionBufferRemaining = 0; - } - - if (ReductionBufferRemaining >= WorkBlock.Padding[Dimensions * 2 - 1]) { - ReductionBufferRemaining -= WorkBlock.Padding[Dimensions * 2 - 1]; - } else { - ReductionBufferRemaining = 0; - } - - if (ReductionBufferRemaining >= int64_t(WorkBlock.InputShape[Dimensions - 1])) { - PoolKernelRoutine = MlasPoolVectorKernels[PoolingKind][Dimensions - 2]; - } - } - -#ifdef BUILD_MLAS_NO_ONNXRUNTIME - MLAS_UNREFERENCED_PARAMETER(ThreadPool); - // - // Execute the pooling kernel routine. - // - - MLAS_UNREFERENCED_PARAMETER(OutputSize); - PoolKernelRoutine(&WorkBlock, TotalChannelCount, Input, Output); -#else - // - // Use an external thread pool if one is provided. - // TODO: change to use MlasExecuteThreaded - onnxruntime::concurrency::ThreadPool::TryBatchParallelFor(ThreadPool, static_cast(TotalChannelCount), [&](ptrdiff_t c) { - PoolKernelRoutine(&WorkBlock, 1, Input + c * InputSize, Output + c * OutputSize); - }, 0); - return; -#endif -} - -template -void -MLASCALL -MlasMaximumPool( - const T8Bits* const* Input, - T8Bits* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -/*++ - -Routine Description: - - This routine implements the maximum pooling operation. 
- - The input is supplied as an indirection buffer. Every pointer in the - indirection buffer points at a Channels length vector (either from the - input tensor or a vector of padding values). These are grouped in batches - of length KernelSize that are processed by the kernel to produce a single - output of length Channels. These batches are then repeated OutputCount - times. - -Arguments: - - Input - Supplies an indirection buffer to the elements of the input tensor. - - Output - Supplies the output tensor in channels last format. - - Channels - Supplies the number of channels. - - OutputCount - Supplies the number of channel sized output elements to - produce. - - KernelSize - Supplies the total number of channel sized kernel elements to - consume. - -Return Value: - - None. - ---*/ -{ - while (OutputCount > 0) { - - size_t ChannelOffset = 0; - size_t c = Channels; - -#if defined(MLAS_SSE2_INTRINSICS) - const __m128i BitFlipVector = _mm_set1_epi32(0x80808080); - if constexpr (std::is_unsigned::value) { - MLAS_UNREFERENCED_PARAMETER(BitFlipVector); - } - - while (c >= 32) { - - __m128i MaximumVector0 = _mm_setzero_si128(); - __m128i MaximumVector1 = _mm_setzero_si128(); - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i InputVector0 = _mm_loadu_si128((const __m128i*)&Input[k][ChannelOffset]); - __m128i InputVector1 = _mm_loadu_si128((const __m128i*)&Input[k][ChannelOffset + 16]); - - if constexpr (std::is_signed::value) { - InputVector0 = _mm_xor_si128(InputVector0, BitFlipVector); - InputVector1 = _mm_xor_si128(InputVector1, BitFlipVector); - } - - MaximumVector0 = _mm_max_epu8(MaximumVector0, InputVector0); - MaximumVector1 = _mm_max_epu8(MaximumVector1, InputVector1); - } - - if constexpr (std::is_signed::value) { - MaximumVector0 = _mm_xor_si128(MaximumVector0, BitFlipVector); - MaximumVector1 = _mm_xor_si128(MaximumVector1, BitFlipVector); - } - - _mm_storeu_si128((__m128i*)&Output[0], MaximumVector0); - _mm_storeu_si128((__m128i*)&Output[16], MaximumVector1); - Output += 32; - - ChannelOffset += 32; - c -= 32; - } - - while (c >= 16) { - - __m128i MaximumVector0 = _mm_setzero_si128(); - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i InputVector0 = _mm_loadu_si128((const __m128i*)&Input[k][ChannelOffset]); - - if constexpr (std::is_signed::value){ - InputVector0 = _mm_xor_si128(InputVector0, BitFlipVector); - } - - MaximumVector0 = _mm_max_epu8(MaximumVector0, InputVector0); - } - - if constexpr (std::is_signed::value) { - MaximumVector0 = _mm_xor_si128(MaximumVector0, BitFlipVector); - } - - _mm_storeu_si128((__m128i*)&Output[0], MaximumVector0); - Output += 16; - - ChannelOffset += 16; - c -= 16; - } - - if (c >= 8) { - - __m128i MaximumVector0 = _mm_setzero_si128(); - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i InputVector0 = _mm_loadl_epi64((const __m128i*)&Input[k][ChannelOffset]); - - if constexpr (std::is_signed::value){ - InputVector0 = _mm_xor_si128(InputVector0, BitFlipVector); - } - - MaximumVector0 = _mm_max_epu8(MaximumVector0, InputVector0); - } - - if constexpr (std::is_signed::value) { - MaximumVector0 = _mm_xor_si128(MaximumVector0, BitFlipVector); - } - - _mm_storel_epi64((__m128i*)&Output[0], MaximumVector0); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } - -#elif defined(MLAS_NEON_INTRINSICS) - - while (c >= 32) { - - if constexpr (std::is_signed::value){ - - int8x16_t MaximumVector0 = vdupq_n_s8(-128); - int8x16_t MaximumVector1 = vdupq_n_s8(-128); - - for (size_t k = 0; k < KernelSize; k++) { - - int8x16_t InputVector0 = 
vld1q_s8(&Input[k][ChannelOffset]); - int8x16_t InputVector1 = vld1q_s8(&Input[k][ChannelOffset + 16]); - - MaximumVector0 = vmaxq_s8(MaximumVector0, InputVector0); - MaximumVector1 = vmaxq_s8(MaximumVector1, InputVector1); - } - - vst1q_s8(&Output[0], MaximumVector0); - vst1q_s8(&Output[16], MaximumVector1); - } else { - - uint8x16_t MaximumVector0 = vdupq_n_u8(0); - uint8x16_t MaximumVector1 = vdupq_n_u8(0); - - for (size_t k = 0; k < KernelSize; k++) { - - uint8x16_t InputVector0 = vld1q_u8(&Input[k][ChannelOffset]); - uint8x16_t InputVector1 = vld1q_u8(&Input[k][ChannelOffset + 16]); - - MaximumVector0 = vmaxq_u8(MaximumVector0, InputVector0); - MaximumVector1 = vmaxq_u8(MaximumVector1, InputVector1); - } - - vst1q_u8(&Output[0], MaximumVector0); - vst1q_u8(&Output[16], MaximumVector1); - } - - Output += 32; - - ChannelOffset += 32; - c -= 32; - } - - while (c >= 16) { - - if constexpr (std::is_signed::value){ - - int8x16_t MaximumVector0 = vdupq_n_s8(-128); - - for (size_t k = 0; k < KernelSize; k++) { - - int8x16_t InputVector0 = vld1q_s8(&Input[k][ChannelOffset]); - MaximumVector0 = vmaxq_s8(MaximumVector0, InputVector0); - } - - vst1q_s8(&Output[0], MaximumVector0); - } else { - - uint8x16_t MaximumVector0 = vdupq_n_u8(0); - - for (size_t k = 0; k < KernelSize; k++) { - - uint8x16_t InputVector0 = vld1q_u8(&Input[k][ChannelOffset]); - MaximumVector0 = vmaxq_u8(MaximumVector0, InputVector0); - } - - vst1q_u8(&Output[0], MaximumVector0); - } - - Output += 16; - - ChannelOffset += 16; - c -= 16; - } - - if (c >= 8) { - - if constexpr (std::is_signed::value){ - - int8x8_t MaximumVector0 = vdup_n_s8(-128); - - for (size_t k = 0; k < KernelSize; k++) { - - int8x8_t InputVector0 = vld1_s8(&Input[k][ChannelOffset]); - MaximumVector0 = vmax_s8(MaximumVector0, InputVector0); - } - - vst1_s8(&Output[0], MaximumVector0); - } else { - - uint8x8_t MaximumVector0 = vdup_n_u8(0); - - for (size_t k = 0; k < KernelSize; k++) { - - uint8x8_t InputVector0 = vld1_u8(&Input[k][ChannelOffset]); - MaximumVector0 = vmax_u8(MaximumVector0, InputVector0); - } - vst1_u8(&Output[0], MaximumVector0); - } - - Output += 8; - - ChannelOffset += 8; - c -= 8; - } - -#elif defined(MLAS_TARGET_POWER) - - while (c >= 32) { - auto MaximumVector0 = vec_splats(std::numeric_limits::lowest()); - auto MaximumVector1 = vec_splats(std::numeric_limits::lowest()); - - for (size_t k = 0; k < KernelSize; k++) { - auto InputVector0 = vec_xl(0, &Input[k][ChannelOffset]); - auto InputVector1 = vec_xl(16, &Input[k][ChannelOffset]); - - MaximumVector0 = vec_max(MaximumVector0, InputVector0); - MaximumVector1 = vec_max(MaximumVector1, InputVector1); - } - - vec_xst(MaximumVector0, 0, (T8Bits *) Output); - vec_xst(MaximumVector1, 16, (T8Bits *) Output); - - Output += 32; - ChannelOffset += 32; - c -= 32; - } - - while (c >= 16) { - auto MaximumVector = vec_splats(std::numeric_limits::lowest()); - - for (size_t k = 0; k < KernelSize; k++) { - auto InputVector = vec_xl(0, &Input[k][ChannelOffset]); - MaximumVector = vec_max(MaximumVector, InputVector); - } - vec_xst(MaximumVector, 0, (T8Bits *) Output); - - Output += 16; - ChannelOffset += 16; - c -= 16; - } - -#elif defined(MLAS_LSX_INTRINSICS) - uint32_t val = 0x80808080; - const __m128i BitFlipVector = __lsx_vreplgr2vr_w(val); - if constexpr (std::is_unsigned::value) { - MLAS_UNREFERENCED_PARAMETER(BitFlipVector); - } - - while (c >= 32) { - - __m128i MaximumVector0 = __lsx_vldi(0); - __m128i MaximumVector1 = __lsx_vldi(0); - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i 
InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); - __m128i InputVector1 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset + 16], 0); - - if constexpr (std::is_signed::value) { - InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); - InputVector1 = __lsx_vxor_v(InputVector1, BitFlipVector); - } - - MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); - MaximumVector1 = __lsx_vmax_bu(MaximumVector1, InputVector1); - } - - if constexpr (std::is_signed::value) { - MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); - MaximumVector1 = __lsx_vxor_v(MaximumVector1, BitFlipVector); - } - - __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); - __lsx_vst(MaximumVector1, (__m128i*)&Output[16], 0); - Output += 32; - - ChannelOffset += 32; - c -= 32; - } - - while (c >= 16) { - - __m128i MaximumVector0 = __lsx_vldi(0); - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i InputVector0 = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); - - if constexpr (std::is_signed::value){ - InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); - } - - MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); - } - - if constexpr (std::is_signed::value) { - MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); - } - - __lsx_vst(MaximumVector0, (__m128i*)&Output[0], 0); - Output += 16; - - ChannelOffset += 16; - c -= 16; - } - - if (c >= 8) { - - __m128i MaximumVector0 = __lsx_vldi(0); - - for (size_t k = 0; k < KernelSize; k++) { - - __m128i InputVector0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0), 0, 1); - - if constexpr (std::is_signed::value){ - InputVector0 = __lsx_vxor_v(InputVector0, BitFlipVector); - } - - MaximumVector0 = __lsx_vmax_bu(MaximumVector0, InputVector0); - } - - if constexpr (std::is_signed::value) { - MaximumVector0 = __lsx_vxor_v(MaximumVector0, BitFlipVector); - } - - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i*)&Output[0] , 0), __lsx_vpickve2gr_d(MaximumVector0, 0), 0), (__m128i*)&Output[0], 0); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } -#endif - - while (c > 0) { - - int32_t MaximumValue = std::numeric_limits::lowest(); - - for (size_t k = 0; k < KernelSize; k++) { - MaximumValue = std::max(MaximumValue, int32_t(Input[k][ChannelOffset])); - } - - *Output++ = T8Bits(MaximumValue); - - ChannelOffset += 1; - c -= 1; - } - - Input += KernelSize; - OutputCount -= 1; - } -} - -template -void -MLASCALL -MlasMaximumPool( - const int8_t* const* Input, - int8_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasMaximumPool( - const uint8_t* const* Input, - uint8_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); diff --git a/onnxruntime/core/mlas/lib/pooling_fp16.cpp b/onnxruntime/core/mlas/lib/pooling_fp16.cpp deleted file mode 100644 index 98e84736cb55f..0000000000000 --- a/onnxruntime/core/mlas/lib/pooling_fp16.cpp +++ /dev/null @@ -1,355 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - pooling_fp16.cpp - -Abstract: - This module implements the pooling operation for fp16 - tensors in NHWC format. 
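The removed pooling routines above and the fp16 routines that follow share the same indirection-buffer contract: each output position owns KernelSize pointers, each pointer addresses a Channels-long vector, and the SSE2/LSX paths additionally flip the sign bit (XOR with 0x80) so an unsigned byte-max instruction can also order signed data; the fp16 variants skip null pointers for padding and, for average pooling, divide by the count of non-null entries. A scalar reference sketch of that contract, for reading purposes only (ReferenceMaximumPool is a hypothetical name, not an MLAS API):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>

// Scalar reference for NHWC max pooling over an indirection buffer.
// Input holds OutputCount * KernelSize pointers; each pointer addresses a
// Channels-long vector (a real input pixel or a padding vector).
template <typename T8Bits>
void ReferenceMaximumPool(const T8Bits* const* Input,
                          T8Bits* Output,
                          size_t Channels,
                          size_t OutputCount,
                          size_t KernelSize)
{
    for (size_t o = 0; o < OutputCount; o++) {
        // Each output position consumes its own group of KernelSize pointers.
        const T8Bits* const* Group = Input + o * KernelSize;
        for (size_t c = 0; c < Channels; c++) {
            int32_t Maximum = std::numeric_limits<T8Bits>::lowest();
            for (size_t k = 0; k < KernelSize; k++) {
                Maximum = std::max<int32_t>(Maximum, Group[k][c]);
            }
            Output[o * Channels + c] = static_cast<T8Bits>(Maximum);
        }
    }
}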
---*/ - -#include "mlasi.h" - -#include "fp16_common.h" - -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - - -template -typename AggregationType::CtxType -PoolCreateContext(size_t KernelSize); - -template -MLAS_FLOAT16X8 -PoolInit16x8(); - -template -MLAS_FLOAT16X4 -PoolInit16x4(); - -template -MLAS_FLOAT16X8 -PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element); - -template -MLAS_FLOAT16X4 -PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element); - -template -MLAS_FLOAT16X8 -PoolSummary16x8(MLAS_FLOAT16X8 agg, typename AggregationType::CtxType context); - -template -MLAS_FLOAT16X4 -PoolSummary16x4(MLAS_FLOAT16X4 agg, typename AggregationType::CtxType context); - - -struct MaxPoolAggregation { - typedef size_t CtxType; // useless type to satisfy compilers -}; - -template <> -MLAS_FORCEINLINE -size_t -PoolCreateContext(size_t KernelSize) -{ - MLAS_UNREFERENCED_PARAMETER(KernelSize); - return 0; -} - -template<> -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -PoolInit16x8() -{ - // lowest fp16 -65504.0f - return MlasBroadcastFloat16x8(0xfbff); -} - -template <> -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -PoolInit16x4() -{ - // lowest fp16 -65504.0f - return MlasBroadcastFloat16x4(0xfbff); -} - -template<> -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) -{ - return MlasMaximumFloat16x8(agg, element); -} - -template<> -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) -{ - return MlasMaximumFloat16x4(agg, element); -} - -template<> -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -PoolSummary16x8(MLAS_FLOAT16X8 agg, size_t size) -{ - MLAS_UNREFERENCED_PARAMETER(size); - return agg; -} - -template<> -MLAS_FORCEINLINE -MLAS_FLOAT16X4 -PoolSummary16x4(MLAS_FLOAT16X4 agg, size_t size) -{ - MLAS_UNREFERENCED_PARAMETER(size); - return agg; -} - -struct AveragePoolAggregation { - typedef MLAS_FLOAT16X8 CtxType; -}; - -template<> -MLAS_FLOAT16X8 -PoolCreateContext(size_t KernelSize) -{ - return MlasBroadcastFloat16x8(MLAS_Float2Half(float(KernelSize))); -} - - -template <> -MLAS_FORCEINLINE MLAS_FLOAT16X8 -PoolInit16x8() -{ - return MlasZeroFloat16x8(); -} - -template <> -MLAS_FORCEINLINE MLAS_FLOAT16X4 -PoolInit16x4() -{ - return MlasZeroFloat16x4(); -} - - -template <> -MLAS_FORCEINLINE MLAS_FLOAT16X8 -PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) -{ - return MlasAddFloat16x8(agg, element); -} - -template <> -MLAS_FORCEINLINE MLAS_FLOAT16X4 -PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) -{ - return MlasAddFloat16x4(agg, element); -} - -template <> -MLAS_FORCEINLINE MLAS_FLOAT16X8 -PoolSummary16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 context) -{ - return MlasDivFloat16x8(agg, context); -} - -template <> -MLAS_FORCEINLINE MLAS_FLOAT16X4 -PoolSummary16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X8 context) -{ - return MlasDivFloat16x4(agg, MlasToLowHalfFloat16x4(context)); -} - - -template -MLAS_FORCEINLINE -void -MlasPoolFp16HWC( - const _mlas_fp16_* const* Input, - _mlas_fp16_* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -{ - while (OutputCount > 0) { - size_t ChannelOffset = 0; - size_t c = Channels; - - while (c >= 32) { - MLAS_FLOAT16X8 MaximumVector0 = PoolInit16x8(); - MLAS_FLOAT16X8 MaximumVector1 = MaximumVector0; - MLAS_FLOAT16X8 MaximumVector2 = MaximumVector0; - MLAS_FLOAT16X8 MaximumVector3 = MaximumVector0; - size_t cnt = 0; - - for (size_t k = 0; k < KernelSize; k++) { - if (Input[k] == nullptr) { - continue; - } - MLAS_FLOAT16X8 InputVector0 = 
MlasLoadFloat16x8(&Input[k][ChannelOffset]); - MLAS_FLOAT16X8 InputVector1 = MlasLoadFloat16x8(&Input[k][ChannelOffset + 8]); - MLAS_FLOAT16X8 InputVector2 = MlasLoadFloat16x8(&Input[k][ChannelOffset + 16]); - MLAS_FLOAT16X8 InputVector3 = MlasLoadFloat16x8(&Input[k][ChannelOffset + 24]); - - MaximumVector0 = PoolAggregate16x8(MaximumVector0, InputVector0); - MaximumVector1 = PoolAggregate16x8(MaximumVector1, InputVector1); - MaximumVector2 = PoolAggregate16x8(MaximumVector2, InputVector2); - MaximumVector3 = PoolAggregate16x8(MaximumVector3, InputVector3); - cnt++; - } - typename AggregationType::CtxType context = PoolCreateContext(cnt); - MaximumVector0 = PoolSummary16x8(MaximumVector0, context); - MaximumVector1 = PoolSummary16x8(MaximumVector1, context); - MaximumVector2 = PoolSummary16x8(MaximumVector2, context); - MaximumVector3 = PoolSummary16x8(MaximumVector3, context); - - MlasStoreFloat16x8(&Output[0], MaximumVector0); - MlasStoreFloat16x8(&Output[8], MaximumVector1); - MlasStoreFloat16x8(&Output[16], MaximumVector2); - MlasStoreFloat16x8(&Output[24], MaximumVector3); - - Output += 32; - ChannelOffset += 32; - c -= 32; - } - - if (c >= 16) { - MLAS_FLOAT16X8 MaximumVector0 = PoolInit16x8(); - MLAS_FLOAT16X8 MaximumVector1 = MaximumVector0; - size_t cnt = 0; - - for (size_t k = 0; k < KernelSize; k++) { - if (Input[k] == nullptr) { - continue; - } - MLAS_FLOAT16X8 InputVector0 = MlasLoadFloat16x8(&Input[k][ChannelOffset]); - MLAS_FLOAT16X8 InputVector1 = MlasLoadFloat16x8(&Input[k][ChannelOffset + 8]); - - MaximumVector0 = PoolAggregate16x8(MaximumVector0, InputVector0); - MaximumVector1 = PoolAggregate16x8(MaximumVector1, InputVector1); - cnt++; - } - typename AggregationType::CtxType context = PoolCreateContext(cnt); - MaximumVector0 = PoolSummary16x8(MaximumVector0, context); - MaximumVector1 = PoolSummary16x8(MaximumVector1, context); - - MlasStoreFloat16x8(&Output[0], MaximumVector0); - MlasStoreFloat16x8(&Output[8], MaximumVector1); - - Output += 16; - ChannelOffset += 16; - c -= 16; - } - - if (c >= 8) { - MLAS_FLOAT16X8 MaximumVector0 = PoolInit16x8(); - size_t cnt = 0; - - for (size_t k = 0; k < KernelSize; k++) { - if (Input[k] == nullptr) { - continue; - } - MLAS_FLOAT16X8 InputVector0 = MlasLoadFloat16x8(&Input[k][ChannelOffset]); - MaximumVector0 = PoolAggregate16x8(MaximumVector0, InputVector0); - cnt++; - } - typename AggregationType::CtxType context = PoolCreateContext(cnt); - MaximumVector0 = PoolSummary16x8(MaximumVector0, context); - - MlasStoreFloat16x8(&Output[0], MaximumVector0); - - Output += 8; - ChannelOffset += 8; - c -= 8; - } - - if (c >= 4) { - MLAS_FLOAT16X4 MaximumVector0 = PoolInit16x4(); - size_t cnt = 0; - - for (size_t k = 0; k < KernelSize; k++) { - if (Input[k] == nullptr) { - continue; - } - MLAS_FLOAT16X4 InputVector0 = MlasLoadFloat16x4(&Input[k][ChannelOffset]); - MaximumVector0 = PoolAggregate16x4(MaximumVector0, InputVector0); - cnt++; - } - typename AggregationType::CtxType context = PoolCreateContext(cnt); - MaximumVector0 = PoolSummary16x4(MaximumVector0, context); - - MlasStoreFloat16x4(&Output[0], MaximumVector0); - - Output += 4; - ChannelOffset += 4; - c -= 4; - } - - if (c > 0) { - // possible over read by 7 bytes - MLAS_FLOAT16X4 MaximumVector0 = PoolInit16x4(); - size_t cnt = 0; - - for (size_t k = 0; k < KernelSize; k++) { - if (Input[k] == nullptr) { - continue; - } - MLAS_FLOAT16X4 InputVector0 = MlasLoadFloat16x4(&Input[k][ChannelOffset]); - MaximumVector0 = PoolAggregate16x4(MaximumVector0, InputVector0); - cnt++; - } - 
typename AggregationType::CtxType context = PoolCreateContext(cnt); - MaximumVector0 = PoolSummary16x4(MaximumVector0, context); - - MlasStorePartialFloat16x4(&Output[0], MaximumVector0, c); - Output += c; - } - - Input += KernelSize; - OutputCount -= 1; - } -} - - -void -MLASCALL -MlasNhwcMaxPool( - const MLAS_FP16* const* Input, - MLAS_FP16* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -{ - const _mlas_fp16_* const* input_ptr = reinterpret_cast(Input); - _mlas_fp16_* output_ptr = reinterpret_cast<_mlas_fp16_*>(Output); - MlasPoolFp16HWC(input_ptr, output_ptr, Channels, OutputCount, KernelSize); -} - - -void -MLASCALL -MlasNhwcAvgPool( - const MLAS_FP16* const* Input, - MLAS_FP16* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -{ - const _mlas_fp16_* const* input_ptr = reinterpret_cast(Input); - _mlas_fp16_* output_ptr = reinterpret_cast<_mlas_fp16_*>(Output); - MlasPoolFp16HWC(input_ptr, output_ptr, Channels, OutputCount, KernelSize); -} - - -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp b/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp deleted file mode 100644 index 560df1e6188ee..0000000000000 --- a/onnxruntime/core/mlas/lib/power/DgemmKernelPOWER10.cpp +++ /dev/null @@ -1,425 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelPower.cpp - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - ---*/ - -#include "DgemmKernelpower.h" -struct MlasDgemmBroadcastAElementsMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - double ARow[RowCount], - const double* A, - size_t lda - ) - { - ARow[Row] = A [Row * lda]; - } -}; - -template -MLAS_FORCEINLINE -void -MlasDgemmComputeAElements( - MLAS_FLOAT64X2 AElements[RowCount], - MLAS_FLOAT64X2 ABroadcast[RowCount] - ) -{ - ABroadcast[0] = vec_mergee (AElements[0], AElements[1]); - ABroadcast[1] = vec_mergee (AElements[2], AElements[3]); - ABroadcast[2] = vec_mergeo (AElements[0], AElements[1]); - ABroadcast[3] = vec_mergeo (AElements[2], AElements[3]); -} - -template -MLAS_FORCEINLINE -void -MlasDgemmComputeBlockMMA( - __vector_quad acc[8], - MLAS_FLOAT64X2 ABroadcast[RowCount], - MLAS_FLOAT64X2 A2Broadcast[RowCount], - const double* B, - size_t CountM - ) -{ - MLAS_FLOAT64X2 BElements[4]; - typedef __vector unsigned char vec_t; - __vector_pair A2pair, Apair; -#if (defined(__GNUC__) && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3)) -#if (__BYTE_ORDER__ != __ORDER_BIG_ENDIAN__) - __builtin_mma_assemble_pair (&Apair, reinterpret_cast(ABroadcast[1]), reinterpret_cast(ABroadcast[0])); - if (CountM == 8) { - __builtin_mma_assemble_pair (&A2pair, reinterpret_cast(A2Broadcast[1]), reinterpret_cast(A2Broadcast[0])); - } -#else - __builtin_mma_assemble_pair (&Apair, reinterpret_cast(ABroadcast[0]), reinterpret_cast(ABroadcast[1])); - if (CountM == 8) { - __builtin_mma_assemble_pair (&A2pair, reinterpret_cast(A2Broadcast[0]), reinterpret_cast(A2Broadcast[1])); - } -#endif -#elif (defined(__GNUC__) && (__GNUC__ == 11 && __GNUC_MINOR__ <= 2)) - Apair = *reinterpret_cast<__vector_pair *>(&ABroadcast[0]); - if (CountM == 8) { - A2pair = *reinterpret_cast<__vector_pair *>(&A2Broadcast[0]); - } -#else - __builtin_vsx_build_pair (&Apair, reinterpret_cast(ABroadcast[0]), reinterpret_cast(ABroadcast[1])); - if (CountM == 8) { - __builtin_vsx_build_pair (&A2pair, 
reinterpret_cast(A2Broadcast[0]), reinterpret_cast(A2Broadcast[1])); - } -#endif - BElements[0] = MlasLoadFloat64x2(B); - BElements[1] = MlasLoadFloat64x2(B + 2); - BElements[2] = MlasLoadFloat64x2(B + 4); - BElements[3] = MlasLoadFloat64x2(B + 6); - __builtin_mma_xvf64gerpp (&acc[0], Apair, reinterpret_cast(BElements[0])); - __builtin_mma_xvf64gerpp (&acc[1], Apair, reinterpret_cast(BElements[1])); - __builtin_mma_xvf64gerpp (&acc[2], Apair, reinterpret_cast(BElements[2])); - __builtin_mma_xvf64gerpp (&acc[3], Apair, reinterpret_cast(BElements[3])); - if (CountM == 8) { - __builtin_mma_xvf64gerpp (&acc[4], A2pair, reinterpret_cast(BElements[0])); - __builtin_mma_xvf64gerpp (&acc[5], A2pair, reinterpret_cast(BElements[1])); - __builtin_mma_xvf64gerpp (&acc[6], A2pair, reinterpret_cast(BElements[2])); - __builtin_mma_xvf64gerpp (&acc[7], A2pair, reinterpret_cast(BElements[3])); - } -} -template -struct MlasDgemmStoreVectorMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Result[4], - double* C, - size_t ldc, - MLAS_FLOAT64X2 AlphaBroadcast, - bool ZeroMode - ) - { - MLAS_FLOAT64X2 *rowC; - if (ZeroMode) { - rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); - rowC[0] = Result[Row] * AlphaBroadcast; - } else { - rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); - rowC[0] += Result[Row] * AlphaBroadcast; - } - } -}; - -struct MlasDgemmMultiplyAlphaTrailingMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount], - MLAS_FLOAT64X2 AlphaBroadcast - ) - { - Accumulators[Row] = MlasMultiplyFloat64x2(Accumulators[Row], AlphaBroadcast); - } -}; -template -struct MlasDgemmStoreScalarMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT64X2 Accumulators[RowCount], - double* C, - size_t ldc, - bool ZeroMode - ) - { - double* c = C + Row * ldc + Lane; - double Value = Accumulators[Row][Lane]; - if (!ZeroMode) { - Value += *c; - } - - *c = Value; - } -}; - -template -MLAS_FORCEINLINE -size_t -MlasDgemmMMAProcessCount( - const double* A, - const double* B, - double* C, - size_t CountM, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - MLAS_FLOAT64X2 AlphaBroadcast, - bool ZeroMode - ) -{ - do { - - const double* a = A; - size_t k = CountK; - - MLAS_FLOAT64X2 Accumulators[2][RowCount] = {{ 0 }}; - MLAS_FLOAT64X2 Result[RowCount]; - MLAS_FLOAT64X2 AElements[RowCount]; - MLAS_FLOAT64X2 ABroadcast[RowCount] = { 0 }; - MLAS_FLOAT64X2 A2Broadcast[RowCount] = { 0 }; - MLAS_FLOAT64X2 A3Broadcast[RowCount] = { 0 }; - MLAS_FLOAT64X2 A4Broadcast[RowCount] = { 0 }; - double ARow[RowCount] = { 0 }; - double A2Row[RowCount] = { 0 }; - __vector_quad acc[8]; - - // - // Clear the block accumulators. - // - __builtin_mma_xxsetaccz(&acc[0]); - __builtin_mma_xxsetaccz(&acc[1]); - __builtin_mma_xxsetaccz(&acc[2]); - __builtin_mma_xxsetaccz(&acc[3]); - __builtin_mma_xxsetaccz(&acc[4]); - __builtin_mma_xxsetaccz(&acc[5]); - __builtin_mma_xxsetaccz(&acc[6]); - __builtin_mma_xxsetaccz(&acc[7]); - - // - // Compute the output block. 
- // - while (k >= 4) { - - MlasLoopUnroll()(AElements, a, lda); - MlasDgemmComputeAElements(AElements, ABroadcast); - MlasLoopUnroll()(AElements, a+2, lda); - MlasDgemmComputeAElements(AElements, A3Broadcast); - if (CountM == 8) { - MlasLoopUnroll()(AElements, a + ( lda * 4), lda); - MlasDgemmComputeAElements(AElements, A2Broadcast); - MlasLoopUnroll()(AElements, (a+2) + ( lda * 4), lda); - MlasDgemmComputeAElements(AElements, A4Broadcast); - } - MlasDgemmComputeBlockMMA(&acc[0], &ABroadcast[0], &A2Broadcast[0], B, CountM); - MlasDgemmComputeBlockMMA(&acc[0], &ABroadcast[2], &A2Broadcast[2], B+8, CountM); - MlasDgemmComputeBlockMMA(&acc[0], &A3Broadcast[0], &A4Broadcast[0], B+16, CountM); - MlasDgemmComputeBlockMMA(&acc[0], &A3Broadcast[2], &A4Broadcast[2], B+24, CountM); - B += 8 * 4; - a += 4; - k -= 4; - } - while (k > 0) { - MlasLoopUnroll()(ARow, a, lda); - if (CountM == 8) { - MlasLoopUnroll()(A2Row, a + (lda * 4), lda); - } - - MlasDgemmComputeBlockMMA(&acc[0], (MLAS_FLOAT64X2 *)ARow, (MLAS_FLOAT64X2 *)A2Row, B, CountM); - a += 1; - B += 8; - k -= 1; - } - if (CountN >= 8) { - - // - // Store the entire output block. - // - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[1]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[2]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[3]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[5]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[6]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[7]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - } - } else { - - // - // Store the partial output block. 
- // - - if (CountN >= 6) { - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[1]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[2]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[5]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[6]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - if (CountN - 6 > 0) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[7]); - } - } - if (CountN - 6 > 0) { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[3]); - } - } else if (CountN >= 4) { - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[1]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[5]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[6]); - } - } - if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[2]); - } - } else if (CountN >= 2) { - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - if (CountN - 2 > 0) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[5]); - } - } - if (CountN - 2 > 0) { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[1]); - } - } else { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[0]); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[4]); - } - } - - // - // Store the remaining unaligned columns. - // - C += (CountN & ~1); - CountN &= 1; - - if (CountN > 0) { - - MlasLoopUnroll()(Accumulators[0], AlphaBroadcast); - MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); - if (CountM == 8) { - MlasLoopUnroll()(Accumulators[1], AlphaBroadcast); - MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); - } - } - - break; - } - - C += 8; - CountN -= 8; - - } while (CountN > 0); - - return CountM; -} - -size_t -MLASCALL -MlasDgemmKernelPOWER10( - const double* A, - const double* B, - double* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - double alpha, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasDgemmCopyPackB or MlasDgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. 
The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see DGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - size_t RowsHandled; - MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha); - if (CountM >= 8) { - RowsHandled = MlasDgemmMMAProcessCount<4>(A, B, C, 8 ,CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else if (CountM >= 4) { - RowsHandled = MlasDgemmMMAProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else if (CountM >= 2) { - RowsHandled = MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else { - RowsHandled = MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } - - return RowsHandled; -} diff --git a/onnxruntime/core/mlas/lib/power/DgemmKernelPower.cpp b/onnxruntime/core/mlas/lib/power/DgemmKernelPower.cpp deleted file mode 100644 index ba8317544000c..0000000000000 --- a/onnxruntime/core/mlas/lib/power/DgemmKernelPower.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelPower.cpp - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - ---*/ -#include "DgemmKernelpower.h" - -size_t -MLASCALL -MlasDgemmKernel( - const double* A, - const double* B, - double* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - double alpha, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasDgemmCopyPackB or MlasDgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see DGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. 
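Restated as a naive scalar GEMM, the contract of these DGEMM kernels is: compute alpha * A * B for up to a kernel-defined number of rows, either overwriting C (ZeroMode) or accumulating into it, and report how many rows were consumed. The sketch below assumes an unpacked row-major B with a hypothetical ldb purely for illustration; the real kernels consume B in the packed layout produced by MlasDgemmCopyPackB or MlasDgemmTransposePackB.

#include <algorithm>
#include <cstddef>

// Naive reference for one kernel invocation:
//   C = alpha * A * B        when ZeroMode is true
//   C += alpha * A * B       otherwise
// Returns the number of rows actually processed, like the real kernel.
size_t ReferenceDgemmKernel(const double* A, const double* B, double* C,
                            size_t CountK, size_t CountM, size_t CountN,
                            size_t lda, size_t ldb, size_t ldc,
                            double alpha, bool ZeroMode)
{
    const size_t RowsHandled = std::min<size_t>(CountM, 8);  // kernel-defined row blocking
    for (size_t m = 0; m < RowsHandled; m++) {
        for (size_t n = 0; n < CountN; n++) {
            double Sum = 0.0;
            for (size_t k = 0; k < CountK; k++) {
                Sum += A[m * lda + k] * B[k * ldb + n];
            }
            const double Value = alpha * Sum;
            C[m * ldc + n] = ZeroMode ? Value : C[m * ldc + n] + Value;
        }
    }
    return RowsHandled;
}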
- ---*/ -{ - size_t RowsHandled; - - MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha); - - if (CountM >= 4) { - RowsHandled = MlasDgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else if (CountM >= 2) { - RowsHandled = MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else { - RowsHandled = MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } - - return RowsHandled; -} diff --git a/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h b/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h deleted file mode 100644 index 0dca7e4e43961..0000000000000 --- a/onnxruntime/core/mlas/lib/power/DgemmKernelpower.h +++ /dev/null @@ -1,122 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelpower.h - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - ---*/ - -#include "FgemmKernelpower.h" - -template -MLAS_FORCEINLINE -size_t -MlasDgemmProcessCount( - const double* A, - const double* B, - double* C, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - MLAS_FLOAT64X2 AlphaBroadcast, - bool ZeroMode - ) -{ - do { - - const double* a = A; - size_t k = CountK; - - MLAS_FLOAT64X2 Accumulators[RowCount][4]; - MLAS_FLOAT64X2 AElements[RowCount]; - MLAS_FLOAT64X2 ABroadcast[RowCount]; - - // - // Clear the block accumulators. - // - - MlasLoopUnroll()(Accumulators); - - // - // Compute the output block. - // - while (k >= 2) { - - MlasLoopUnroll()(AElements, a, lda); - - MlasLoopUnroll>()(AElements, ABroadcast); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B); - - MlasLoopUnroll>()(AElements, ABroadcast); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 8); - - a += 2; - B += 8 * 2; - k -= 2; - } - if (k > 0) { - - MlasLoopUnroll()(ABroadcast, a, lda); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B); - - a += 1; - B += 8; - k -= 1; - } - - if (CountN >= 8) { - - // - // Store the entire output block. - // - - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - - } else { - - // - // Store the partial output block. - // - - // - if (CountN >= 6) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - } else if (CountN >= 4) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - } else if (CountN >= 2) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - } - // - // Store the remaining unaligned columns. - // - C += (CountN & ~1); - CountN &= 1; - - if (CountN > 0) { - - MlasLoopUnroll()(Accumulators, AlphaBroadcast); - - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); - } - - break; - } - - C += 8; - CountN -= 8; - - } while (CountN > 0); - - return RowCount; -} - diff --git a/onnxruntime/core/mlas/lib/power/FgemmKernelpower.h b/onnxruntime/core/mlas/lib/power/FgemmKernelpower.h deleted file mode 100644 index 3746dbc82b3f6..0000000000000 --- a/onnxruntime/core/mlas/lib/power/FgemmKernelpower.h +++ /dev/null @@ -1,333 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelPower.h - -Abstract: - - This module implements the kernels for the single/double precision matrix/matrix - multiply operation (DGEMM/SGEMM). 
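The main machinery in this header is a recursive template that unrolls a fixed-count loop over accumulator rows at compile time. The standalone sketch below reduces that idiom to its essentials; the names Unroll, UnrollStep, and ZeroRow are simplified stand-ins, not the exact MLAS signatures.

#include <cstddef>
#include <iostream>

// Recursive template that expands Iteration<0>..Iteration<Count-1> with no runtime loop.
template <size_t Count, size_t Index, typename IterationType>
struct UnrollStep {
    template <typename... Args>
    static void Step(Args&&... args) {
        IterationType::template Iteration<Count, Index>(args...);
        UnrollStep<Count, Index + 1, IterationType>::Step(args...);
    }
};

template <size_t Count, typename IterationType>
struct UnrollStep<Count, Count, IterationType> {
    template <typename... Args>
    static void Step(Args&&...) {}  // terminate the recursion
};

template <size_t Count, typename IterationType>
struct Unroll {
    template <typename... Args>
    void operator()(Args&&... args) {
        UnrollStep<Count, 0, IterationType>::Step(args...);
    }
};

// Example per-row action: zero one row of a Count x 4 accumulator block.
struct ZeroRow {
    template <size_t Count, size_t Row>
    static void Iteration(double (&acc)[Count][4]) {
        for (int j = 0; j < 4; j++) acc[Row][j] = 0.0;
    }
};

int main() {
    double acc[4][4];
    Unroll<4, ZeroRow>()(acc);       // expands to ZeroRow::Iteration<4, 0..3>
    std::cout << acc[3][3] << "\n";  // prints 0
}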
- ---*/ - -#include "mlasi.h" -#if defined(SINGLE) -#define MLAS_FLOATTYPE MLAS_FLOAT32X4 -#define MLAS_GEMMTYPE float -#define MLAS_LOAD_FLOAT MlasLoadFloat32x4 -#define MLAS_ZERO_FLOAT MlasZeroFloat32x4 -#define MLAS_STORE_FLOAT MlasStoreFloat32x4 -#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat32x4 -#define MLAS_MUL_FLOAT MlasMultiplyFloat32x4 -#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat32x4 -#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat32x4 -#else -#define MLAS_FLOATTYPE MLAS_FLOAT64X2 -#define MLAS_GEMMTYPE double -#define MLAS_LOAD_FLOAT MlasLoadFloat64x2 -#define MLAS_ZERO_FLOAT MlasZeroFloat64x2 -#define MLAS_STORE_FLOAT MlasStoreFloat64x2 -#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat64x2 -#define MLAS_MUL_FLOAT MlasMultiplyFloat64x2 -#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat64x2 -#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat64x2 -#endif -// -// Templates to ensure that a loop is unrolled. -// - -template -struct MlasLoopUnrollStep -{ - template - MLAS_FORCEINLINE - static - void - Step( - IterationArgs&&... Arguments - ) - { - IterationType::template Iteration(Arguments...); - MlasLoopUnrollStep::template Step(Arguments...); - } -}; - -template -struct MlasLoopUnrollStep -{ - template - MLAS_FORCEINLINE - static - void - Step( - IterationArgs&&... - ) - { - // Terminate the loop. - } -}; - -template -struct MlasLoopUnroll -{ - template - MLAS_FORCEINLINE - void - operator()( - IterationArgs&&... Arguments - ) - { - MlasLoopUnrollStep::template Step(Arguments...); - } -}; - -// -// Templates used with loop unrolling to perform an action on one row of the -// output. -// - -struct MlasFgemmZeroAccumulators -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[RowCount][4] - ) - { - Accumulators[Row][0] = MLAS_ZERO_FLOAT(); - Accumulators[Row][1] = MLAS_ZERO_FLOAT(); - Accumulators[Row][2] = MLAS_ZERO_FLOAT(); - Accumulators[Row][3] = MLAS_ZERO_FLOAT(); - } -}; - -struct MlasFgemmLoadAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE AElements[RowCount], - const MLAS_GEMMTYPE* A, - size_t lda - ) - { - AElements[Row] = MLAS_LOAD_FLOAT(A + Row * lda); - } -}; - -struct MlasFgemmBroadcastAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE ABroadcast[RowCount], - const MLAS_GEMMTYPE* A, - size_t lda - ) - { - ABroadcast[Row] = MLAS_BROADCAST_FLOAT(A + Row * lda); - } -}; - -template -struct MlasFgemmSplatAElements -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE AElements[RowCount], - MLAS_FLOATTYPE ABroadcast[RowCount] - ) - { - ABroadcast[Row] = vec_splat(AElements[Row], Lane); - } -}; - -struct MlasFgemmMultiplyAddRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[RowCount][4], - MLAS_FLOATTYPE ABroadcast[RowCount], - MLAS_FLOATTYPE BElements[4] - ) - { - Accumulators[Row][0] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[0], Accumulators[Row][0]); - Accumulators[Row][1] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[1], Accumulators[Row][1]); - Accumulators[Row][2] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[2], Accumulators[Row][2]); - Accumulators[Row][3] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[3], Accumulators[Row][3]); - } -}; - -template -MLAS_FORCEINLINE -void -MlasFgemmComputeBlock( - MLAS_FLOATTYPE Accumulators[RowCount][4], - MLAS_FLOATTYPE ABroadcast[RowCount], - const MLAS_GEMMTYPE* B - ) -{ - MLAS_FLOATTYPE BElements[4]; -#if defined(SINGLE) - BElements[0] = 
MLAS_LOAD_FLOAT(B); - BElements[1] = MLAS_LOAD_FLOAT(B + 4); - BElements[2] = MLAS_LOAD_FLOAT(B + 8); - BElements[3] = MLAS_LOAD_FLOAT(B + 12); -#else - BElements[0] = MLAS_LOAD_FLOAT(B); - BElements[1] = MLAS_LOAD_FLOAT(B + 2); - BElements[2] = MLAS_LOAD_FLOAT(B + 4); - BElements[3] = MLAS_LOAD_FLOAT(B + 6); -#endif - - MlasLoopUnroll()(Accumulators, ABroadcast, BElements); -} - -struct MlasFgemmMultiplyAlphaRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[4], - MLAS_FLOATTYPE AlphaBroadcast - ) - { - Accumulators[Index] = MLAS_MUL_FLOAT(Accumulators[Index], AlphaBroadcast); - } -}; - -struct MlasFgemmMultiplyAlphaAddRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[4], - MLAS_FLOATTYPE AlphaBroadcast, - const MLAS_GEMMTYPE* C - ) - { -#if defined(SINGLE) - Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], - AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 4)); -#else - Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], - AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 2)); -#endif - } -}; - -struct MlasFgemmStoreRow -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[4], - MLAS_GEMMTYPE* C - ) - { -#if defined(SINGLE) - MLAS_STORE_FLOAT(C + Index * 4, Accumulators[Index]); -#else - MLAS_STORE_FLOAT(C + Index * 2, Accumulators[Index]); -#endif - } -}; - -template -struct MlasFgemmStoreVector -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[RowCount][4], - MLAS_GEMMTYPE* C, - size_t ldc, - MLAS_FLOATTYPE AlphaBroadcast, - bool ZeroMode - ) - { - MLAS_GEMMTYPE* c = C + Row * ldc; - - if (ZeroMode) { - MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); - } else { - MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); - } - - MlasLoopUnroll()(Accumulators[Row], c); - - // - // Shift down any unaligned elements to the bottom for further processing. - // - - if (VectorCount < 4) { - Accumulators[Row][0] = Accumulators[Row][VectorCount]; - } - } -}; - -struct MlasFgemmMultiplyAlphaTrailing -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[RowCount][4], - MLAS_FLOATTYPE AlphaBroadcast - ) - { - Accumulators[Row][0] = MLAS_MUL_FLOAT(Accumulators[Row][0], AlphaBroadcast); - } -}; - -template -struct MlasFgemmStoreScalar -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOATTYPE Accumulators[RowCount][4], - MLAS_GEMMTYPE* C, - size_t ldc, - bool ZeroMode - ) - { - MLAS_GEMMTYPE* c = C + Row * ldc + Lane; - MLAS_GEMMTYPE Value = MLAS_EXTRACT_FLOAT(Accumulators[Row][0]); - - if (!ZeroMode) { - Value += *c; - } - - *c = Value; - } -}; - diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp deleted file mode 100644 index 2d4d791c3a000..0000000000000 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ /dev/null @@ -1,303 +0,0 @@ -#include -#include "mlasi.h" -#include - -// NOTE: Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions. -// ONNX Runtime CI pipelines do not build with all compiler versions. - -template -void -MLASCALL -MlasQuantizeLinearKernel( - const float* Input, - OutputType* Output, - size_t N, - float Scale, - OutputType ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer using the supplied quantization - parameters. - -Arguments: - - Input - Supplies the input buffer. 
- - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ -{ - constexpr int32_t MinimumValue = std::numeric_limits::lowest(); - constexpr int32_t MaximumValue = std::numeric_limits::max(); - - auto ScaleVector = vec_splats(Scale); - auto MinimumValueVector = vec_splats(float(MinimumValue)); - auto MaximumValueVector = vec_splats(float(MaximumValue)); - auto ZeroPointVector = vec_splats(float(ZeroPoint)); - - while (N >= 16) { - auto FloatVector0 = vec_xl(0, Input); - auto FloatVector1 = vec_xl(0, Input + 4); - auto FloatVector2 = vec_xl(0, Input + 8); - auto FloatVector3 = vec_xl(0, Input + 12); - - FloatVector0 = vec_div(FloatVector0, ScaleVector); - FloatVector1 = vec_div(FloatVector1, ScaleVector); - FloatVector2 = vec_div(FloatVector2, ScaleVector); - FloatVector3 = vec_div(FloatVector3, ScaleVector); - - FloatVector0 = vec_round(FloatVector0); - FloatVector1 = vec_round(FloatVector1); - FloatVector2 = vec_round(FloatVector2); - FloatVector3 = vec_round(FloatVector3); - - FloatVector0 = vec_add(FloatVector0, ZeroPointVector); - FloatVector1 = vec_add(FloatVector1, ZeroPointVector); - FloatVector2 = vec_add(FloatVector2, ZeroPointVector); - FloatVector3 = vec_add(FloatVector3, ZeroPointVector); - - FloatVector0 = vec_max(FloatVector0, MinimumValueVector); - FloatVector1 = vec_max(FloatVector1, MinimumValueVector); - FloatVector2 = vec_max(FloatVector2, MinimumValueVector); - FloatVector3 = vec_max(FloatVector3, MinimumValueVector); - - FloatVector0 = vec_min(FloatVector0, MaximumValueVector); - FloatVector1 = vec_min(FloatVector1, MaximumValueVector); - FloatVector2 = vec_min(FloatVector2, MaximumValueVector); - FloatVector3 = vec_min(FloatVector3, MaximumValueVector); - - auto IntegerVector0 = vec_signed(FloatVector0); - auto IntegerVector1 = vec_signed(FloatVector1); - auto IntegerVector2 = vec_signed(FloatVector2); - auto IntegerVector3 = vec_signed(FloatVector3); - - auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); - auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); - - if constexpr (std::is_same_v || std::is_same_v) { - auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, (int8_t *)Output); - } else { - static_assert(std::is_same_v || std::is_same_v); - vec_xst(ShortVector0, 0, (int16_t *)Output); - vec_xst(ShortVector1, 0, (int16_t *)&Output[8]); - } - - Output += 16; - Input += 16; - N -= 16; - } - - for (size_t n = 0; n < N; n++) { - - float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); - FloatValue = std::max(FloatValue, float(MinimumValue)); - FloatValue = std::min(FloatValue, float(MaximumValue)); - Output[n] = (OutputType)(int32_t)FloatValue; - } -} - -template -void -MLASCALL -MlasQuantizeLinearInt4Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer as int4 using the supplied quantization - parameters. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. Contains packed 4-bit elements. - - N - Supplies the number of elements to process. - - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. 
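The int4 kernel follows the same clamp(round(x / Scale) + ZeroPoint) recipe as the 8/16-bit kernels above and then packs two 4-bit results into each output byte. The scalar sketch below is illustrative only: ReferenceQuantizeLinearInt4 is a hypothetical name, and the low-nibble-first packing order is an assumption, not a statement about MlasPackInt4Elements.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Scalar sketch of 4-bit linear quantization with nibble packing.
// Signed=true gives the [-8, 7] range, Signed=false gives [0, 15].
template <bool Signed>
void ReferenceQuantizeLinearInt4(const float* Input, uint8_t* Output,
                                 size_t N, float Scale, int8_t ZeroPoint)
{
    const int32_t Minimum = Signed ? -8 : 0;
    const int32_t Maximum = Signed ? 7 : 15;

    for (size_t n = 0; n < N; n++) {
        float Value = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint);
        Value = std::max(Value, float(Minimum));
        Value = std::min(Value, float(Maximum));
        const uint8_t Nibble = uint8_t(int32_t(Value)) & 0x0F;

        uint8_t& Byte = Output[n >> 1];
        if ((n & 1) == 0) {
            Byte = Nibble;                         // even element -> low nibble (assumed order)
        } else {
            Byte = uint8_t(Byte | (Nibble << 4));  // odd element -> high nibble
        }
    }
}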
- ---*/ -{ - constexpr int32_t MinimumValue = Int4Traits::Min; - constexpr int32_t MaximumValue = Int4Traits::Max; - using UnpackedType = typename Int4Traits::UnpackedType; - - auto ScaleVector = vec_splats(Scale); - auto MinimumValueVector = vec_splats(float(MinimumValue)); - auto MaximumValueVector = vec_splats(float(MaximumValue)); - auto ZeroPointVector = vec_splats(float(ZeroPoint)); - - // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values. - UnpackedType TmpOutput[16] = {}; - - while (N >= 16) { - auto FloatVector0 = vec_xl(0, Input); - auto FloatVector1 = vec_xl(0, Input + 4); - auto FloatVector2 = vec_xl(0, Input + 8); - auto FloatVector3 = vec_xl(0, Input + 12); - - FloatVector0 = vec_div(FloatVector0, ScaleVector); - FloatVector1 = vec_div(FloatVector1, ScaleVector); - FloatVector2 = vec_div(FloatVector2, ScaleVector); - FloatVector3 = vec_div(FloatVector3, ScaleVector); - - FloatVector0 = vec_round(FloatVector0); - FloatVector1 = vec_round(FloatVector1); - FloatVector2 = vec_round(FloatVector2); - FloatVector3 = vec_round(FloatVector3); - - FloatVector0 = vec_add(FloatVector0, ZeroPointVector); - FloatVector1 = vec_add(FloatVector1, ZeroPointVector); - FloatVector2 = vec_add(FloatVector2, ZeroPointVector); - FloatVector3 = vec_add(FloatVector3, ZeroPointVector); - - FloatVector0 = vec_max(FloatVector0, MinimumValueVector); - FloatVector1 = vec_max(FloatVector1, MinimumValueVector); - FloatVector2 = vec_max(FloatVector2, MinimumValueVector); - FloatVector3 = vec_max(FloatVector3, MinimumValueVector); - - FloatVector0 = vec_min(FloatVector0, MaximumValueVector); - FloatVector1 = vec_min(FloatVector1, MaximumValueVector); - FloatVector2 = vec_min(FloatVector2, MaximumValueVector); - FloatVector3 = vec_min(FloatVector3, MaximumValueVector); - - auto IntegerVector0 = vec_signed(FloatVector0); - auto IntegerVector1 = vec_signed(FloatVector1); - auto IntegerVector2 = vec_signed(FloatVector2); - auto IntegerVector3 = vec_signed(FloatVector3); - - auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); - auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); - - auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0])); - - MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); - MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); - MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); - MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]); - MlasPackInt4Elements(Output++, TmpOutput[8], TmpOutput[9]); - MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]); - MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]); - MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]); - - Input += 16; - N -= 16; - } - - for (size_t n = 0; n < N; n++) { - - float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); - FloatValue = std::max(FloatValue, static_cast(MinimumValue)); - FloatValue = std::min(FloatValue, static_cast(MaximumValue)); - UnpackedType IntValue = static_cast(FloatValue); - - MlasSetInt4Element(Output, n, IntValue); - } -} - -void -MLASCALL -MlasQuantizeLinearU8Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS8Kernel( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, 
ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU16Kernel( - const float* Input, - uint16_t* Output, - size_t N, - float Scale, - uint16_t ZeroPoint - ) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS16Kernel( - const float* Input, - int16_t* Output, - size_t N, - float Scale, - int16_t ZeroPoint - ) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU4Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS4Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); -} - diff --git a/onnxruntime/core/mlas/lib/power/QuantizePowerVSX.cpp b/onnxruntime/core/mlas/lib/power/QuantizePowerVSX.cpp deleted file mode 100644 index de3a23452b128..0000000000000 --- a/onnxruntime/core/mlas/lib/power/QuantizePowerVSX.cpp +++ /dev/null @@ -1,132 +0,0 @@ -#include "mlasi.h" -#include - -template -void -MLASCALL -MlasQuantizeLinearVSX( - const float* Input, - OutputType* Output, - size_t N, - float Scale, - OutputType ZeroPoint - ) -{ - // Workaround for bad GCC warning that Scale is set but not used. - MLAS_UNREFERENCED_PARAMETER(Scale); - - constexpr int32_t MinimumValue = std::numeric_limits::min(); - constexpr int32_t MaximumValue = std::numeric_limits::max(); - - auto ScaleVector = vec_splats(Scale); - auto MinimumValueVector = vec_splats(float(MinimumValue)); - auto MaximumValueVector = vec_splats(float(MaximumValue)); - auto ZeroPointVector = vec_splats(float(ZeroPoint)); - - while (N >= 16) { - auto FloatVector0 = vec_xl(0, Input); - auto FloatVector1 = vec_xl(0, Input + 4); - auto FloatVector2 = vec_xl(0, Input + 8); - auto FloatVector3 = vec_xl(0, Input + 12); - - FloatVector0 = vec_div(FloatVector0, ScaleVector); - FloatVector1 = vec_div(FloatVector1, ScaleVector); - FloatVector2 = vec_div(FloatVector2, ScaleVector); - FloatVector3 = vec_div(FloatVector3, ScaleVector); - - FloatVector0 = vec_round(FloatVector0); - FloatVector1 = vec_round(FloatVector1); - FloatVector2 = vec_round(FloatVector2); - FloatVector3 = vec_round(FloatVector3); - - FloatVector0 = vec_add(FloatVector0, ZeroPointVector); - FloatVector1 = vec_add(FloatVector1, ZeroPointVector); - FloatVector2 = vec_add(FloatVector2, ZeroPointVector); - FloatVector3 = vec_add(FloatVector3, ZeroPointVector); - - FloatVector0 = vec_max(FloatVector0, MinimumValueVector); - FloatVector1 = vec_max(FloatVector1, MinimumValueVector); - FloatVector2 = vec_max(FloatVector2, MinimumValueVector); - FloatVector3 = vec_max(FloatVector3, MinimumValueVector); - - FloatVector0 = vec_min(FloatVector0, MaximumValueVector); - FloatVector1 = vec_min(FloatVector1, MaximumValueVector); - FloatVector2 = vec_min(FloatVector2, MaximumValueVector); - FloatVector3 = vec_min(FloatVector3, MaximumValueVector); - - auto IntegerVector0 = vec_signed(FloatVector0); - auto IntegerVector1 = vec_signed(FloatVector1); - auto IntegerVector2 = vec_signed(FloatVector2); - auto IntegerVector3 = vec_signed(FloatVector3); - - auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); - auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); - auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, (int8_t *) Output); - - Output += 16; - Input += 16; - N -= 16; - } - - 
while (N >= 4) { - auto FloatVector = vec_xl(0, Input); - FloatVector = vec_div(FloatVector, ScaleVector); - FloatVector = vec_round(FloatVector); - FloatVector = vec_add(FloatVector, ZeroPointVector); - - FloatVector = vec_max(FloatVector, MinimumValueVector); - FloatVector = vec_min(FloatVector, MaximumValueVector); - auto IntegerVector = vec_signed(FloatVector); - - auto ShortVector = vec_pack(IntegerVector, vec_splats((int32_t) 0)); - auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); - vec_xst_len(CharVector, (int8_t *) Output, N); - - Output += 4; - Input += 4; - N -= 4; - } - - if (N > 0) { - auto FloatVector = vec_xl_len( const_cast(Input), 4*N); - - FloatVector = vec_div(FloatVector, ScaleVector); - FloatVector = vec_round(FloatVector); - FloatVector = vec_add(FloatVector, ZeroPointVector); - - FloatVector = vec_max(FloatVector, MinimumValueVector); - FloatVector = vec_min(FloatVector, MaximumValueVector); - auto IntegerVector = vec_signed(FloatVector); - - auto ShortVector = vec_pack(IntegerVector, vec_splats((int32_t) 0)); - auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); - vec_xst_len(CharVector, (int8_t *) Output, N); - } -} - -void -MLASCALL -MlasQuantizeLinearU8KernelVSX( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - MlasQuantizeLinearVSX(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS8KernelVSX( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearVSX(Input, Output, N, Scale, ZeroPoint); -} diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp b/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp deleted file mode 100644 index 3dfe061c72524..0000000000000 --- a/onnxruntime/core/mlas/lib/power/SgemmKernelPOWER10.cpp +++ /dev/null @@ -1,412 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelPower.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). 
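The POWER10 path below builds each output tile from repeated outer-product accumulations: roughly, each xvf32gerpp step folds a[i] * b[j] into a 4x4 accumulator, and sweeping k with fresh slices of A and B yields the GEMM tile. A plain C++ sketch of that per-step update, with no MMA intrinsics and hypothetical names:

#include <cstddef>

struct Tile4x4 {
    float v[4][4] = {};
};

// Rank-1 outer-product update: acc[i][j] += a[i] * b[j].
inline void OuterProductAccumulate(Tile4x4& acc, const float a[4], const float b[4])
{
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            acc.v[i][j] += a[i] * b[j];
        }
    }
}

// Usage sketch: for each k, take 4 values from column k of A and 4 values from
// row k of B, call OuterProductAccumulate, and after the k loop scale the tile
// by alpha and store it into (or accumulate it onto) the matching block of C.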
- ---*/ - -#include "SgemmKernelpower.h" -struct MlasSgemmBroadcastAElementsMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 ABroadcast[RowCount], - const float* A, - size_t lda - ) - { - ABroadcast[0][Row] = A [Row * lda]; - } -}; - -template -MLAS_FORCEINLINE -void -MlasSgemmComputeAElements( - MLAS_FLOAT32X4 AElements[RowCount], - MLAS_FLOAT32X4 ABroadcast[RowCount] - ) -{ - __vector float a1,a2; - a1 = vec_mergee (AElements[0], AElements[1]); - a2 = vec_mergee (AElements[2], AElements[3]); - ABroadcast[0] =vec_xxpermdi(a1,a2,0); - ABroadcast[2] =vec_xxpermdi(a1,a2,3); - a1 = vec_mergeo (AElements[0], AElements[1]); - a2 = vec_mergeo (AElements[2], AElements[3]); - ABroadcast[1] =vec_xxpermdi(a1,a2,0); - ABroadcast[3] =vec_xxpermdi(a1,a2,3); -} -template -MLAS_FORCEINLINE -void -MlasSgemmComputeBlockMMA( - __vector_quad acc[8], - MLAS_FLOAT32X4 ABroadcast, - MLAS_FLOAT32X4 A2Broadcast, - const float* B, - size_t CountM - ) -{ - MLAS_FLOAT32X4 BElements[4]; - typedef __vector unsigned char vec_t; - - BElements[0] = MlasLoadFloat32x4(B); - BElements[1] = MlasLoadFloat32x4(B + 4); - BElements[2] = MlasLoadFloat32x4(B + 8); - BElements[3] = MlasLoadFloat32x4(B + 12); - __builtin_mma_xvf32gerpp (&acc[0], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[0])); - __builtin_mma_xvf32gerpp (&acc[1], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[1])); - __builtin_mma_xvf32gerpp (&acc[2], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[2])); - __builtin_mma_xvf32gerpp (&acc[3], reinterpret_cast(ABroadcast), reinterpret_cast(BElements[3])); - if (CountM == 8) { - __builtin_mma_xvf32gerpp (&acc[4], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[0])); - __builtin_mma_xvf32gerpp (&acc[5], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[1])); - __builtin_mma_xvf32gerpp (&acc[6], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[2])); - __builtin_mma_xvf32gerpp (&acc[7], reinterpret_cast(A2Broadcast), reinterpret_cast(BElements[3])); - } -} -template -struct MlasSgemmStoreVectorMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Result[4], - float* C, - size_t ldc, - MLAS_FLOAT32X4 AlphaBroadcast, - bool ZeroMode - ) - { - MLAS_FLOAT32X4 *rowC; - if (ZeroMode) { - rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); - rowC[0] = Result[Row] * AlphaBroadcast; - } else { - rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); - rowC[0] += Result[Row] * AlphaBroadcast; - } - } -}; - -struct MlasSgemmMultiplyAlphaTrailingMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount], - MLAS_FLOAT32X4 AlphaBroadcast - ) - { - Accumulators[Row] = MlasMultiplyFloat32x4(Accumulators[Row], AlphaBroadcast); - } -}; -template -struct MlasSgemmStoreScalarMMA -{ - template - MLAS_FORCEINLINE - static - void - Iteration( - MLAS_FLOAT32X4 Accumulators[RowCount], - float* C, - size_t ldc, - bool ZeroMode - ) - { - float* c = C + Row * ldc + Lane; - float Value = Accumulators[Row][Lane]; - if (!ZeroMode) { - Value += *c; - } - - *c = Value; - } -}; - -template -MLAS_FORCEINLINE -size_t -MlasSgemmMMAProcessCount( - const float* A, - const float* B, - float* C, - size_t CountM, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - MLAS_FLOAT32X4 AlphaBroadcast, - bool ZeroMode - ) -{ - do { - - const float* a = A; - size_t k = CountK; - - MLAS_FLOAT32X4 Accumulators[2][RowCount] = {{ 0 }}; - MLAS_FLOAT32X4 Result[RowCount]; - MLAS_FLOAT32X4 
AElements[RowCount]; - MLAS_FLOAT32X4 ABroadcast[RowCount] = { 0 }; - MLAS_FLOAT32X4 A2Broadcast[RowCount] = { 0 }; - __vector_quad acc[8]; - - // - // Clear the block accumulators. - // - __builtin_mma_xxsetaccz(&acc[0]); - __builtin_mma_xxsetaccz(&acc[1]); - __builtin_mma_xxsetaccz(&acc[2]); - __builtin_mma_xxsetaccz(&acc[3]); - __builtin_mma_xxsetaccz(&acc[4]); - __builtin_mma_xxsetaccz(&acc[5]); - __builtin_mma_xxsetaccz(&acc[6]); - __builtin_mma_xxsetaccz(&acc[7]); - - // - // Compute the output block. - // - while (k >= 4) { - - MlasLoopUnroll()(AElements, a, lda); - MlasSgemmComputeAElements(AElements, ABroadcast); - if (CountM == 8) { - MlasLoopUnroll()(AElements, a + ( lda * 4), lda); - MlasSgemmComputeAElements(AElements, A2Broadcast); - } - MlasSgemmComputeBlockMMA(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); - MlasSgemmComputeBlockMMA(&acc[0], ABroadcast[1], A2Broadcast[1], B+16, CountM); - MlasSgemmComputeBlockMMA(&acc[0], ABroadcast[2], A2Broadcast[2], B+32, CountM); - MlasSgemmComputeBlockMMA(&acc[0], ABroadcast[3], A2Broadcast[3], B+48, CountM); - B += 16 * 4; - a += 4; - k -= 4; - } - - while (k > 0) { - MlasLoopUnroll()(ABroadcast, a, lda); - if (CountM == 8) { - MlasLoopUnroll()(A2Broadcast, a + (lda * 4), lda); - } - MlasSgemmComputeBlockMMA(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); - a += 1; - B += 16; - k -= 1; - } - if (CountN >= 16) { - - // - // Store the entire output block. - // - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[1]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[2]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[3]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[5]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[6]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[7]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - } - } else { - - // - // Store the partial output block. 
- // - - if (CountN >= 12) { - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[1]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[2]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[5]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[6]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[7]); - } - } - if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[3]); - } - } else if (CountN >= 8) { - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[1]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - __builtin_mma_disassemble_acc (Result, &acc[5]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[6]); - } - } - if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[2]); - } - } else if (CountN >= 4) { - __builtin_mma_disassemble_acc (Result, &acc[0]); - MlasLoopUnroll>()(Result, C, ldc, AlphaBroadcast, ZeroMode); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Result, &acc[4]); - MlasLoopUnroll>()(Result, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); - if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[5]); - } - } - if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[1]); - } - } else { - __builtin_mma_disassemble_acc (Accumulators[0], &acc[0]); - if (CountM == 8) { - __builtin_mma_disassemble_acc (Accumulators[1], &acc[4]); - } - } - - // - // Store the remaining unaligned columns. - // - - C += (CountN & ~3); - CountN &= 3; - - if (CountN > 0) { - - MlasLoopUnroll()(Accumulators[0], AlphaBroadcast); - MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); - if (CountM == 8) { - MlasLoopUnroll()(Accumulators[1], AlphaBroadcast); - MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); - } - if (CountN >= 2) { - MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); - if (CountM == 8) { - MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); - } - } - if (CountN >= 3) { - MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); - if (CountM == 8) { - MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); - } - } - } - - break; - } - - C += 16; - CountN -= 16; - - } while (CountN > 0); - - return CountM; -} - -size_t -MLASCALL -MlasSgemmKernelPOWER10( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. 
The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see SGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - size_t RowsHandled; - MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); - - if (CountM >= 8) { - RowsHandled = MlasSgemmMMAProcessCount<4>(A, B, C, 8 ,CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else if (CountM >= 4) { - RowsHandled = MlasSgemmMMAProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else if (CountM >= 2) { - RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else { - RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } - - return RowsHandled; -} diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp b/onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp deleted file mode 100644 index 8c42348338c27..0000000000000 --- a/onnxruntime/core/mlas/lib/power/SgemmKernelPower.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelPower.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - ---*/ -#include "SgemmKernelpower.h" - -size_t -MLASCALL -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar multiplier (see SGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. 
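A minimal scalar sketch of the inner-kernel contract described in the routine comment above, assuming a single 16-column packed panel of B (CountN <= 16) so that each k step holds 16 consecutive column values; the function name and the panel-only restriction are illustrative, not part of the deleted kernel.

#include <cstddef>

// Scalar reference of the SGEMM inner-kernel contract: alpha scales the A*B
// product, and ZeroMode selects store versus accumulate into C.
void SgemmPanelReference(const float* A, const float* B, float* C,
                         std::size_t CountK, std::size_t CountM, std::size_t CountN,
                         std::size_t lda, std::size_t ldc, float alpha, bool ZeroMode)
{
    for (std::size_t m = 0; m < CountM; ++m) {
        for (std::size_t n = 0; n < CountN; ++n) {
            float sum = 0.0f;
            for (std::size_t k = 0; k < CountK; ++k) {
                sum += A[m * lda + k] * B[k * 16 + n];  // packed B: 16 columns per k step
            }
            const float value = alpha * sum;
            C[m * ldc + n] = ZeroMode ? value : C[m * ldc + n] + value;
        }
    }
}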
- ---*/ -{ - size_t RowsHandled; - - MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); - - if (CountM >= 4) { - RowsHandled = MlasSgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else if (CountM >= 2) { - RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } else { - RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); - } - - return RowsHandled; -} diff --git a/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h b/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h deleted file mode 100644 index 53be544bdbe3f..0000000000000 --- a/onnxruntime/core/mlas/lib/power/SgemmKernelpower.h +++ /dev/null @@ -1,139 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelpower.h - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - ---*/ - -#include "FgemmKernelpower.h" - -template -MLAS_FORCEINLINE -size_t -MlasSgemmProcessCount( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - MLAS_FLOAT32X4 AlphaBroadcast, - bool ZeroMode - ) -{ - do { - - const float* a = A; - size_t k = CountK; - - MLAS_FLOAT32X4 Accumulators[RowCount][4]; - MLAS_FLOAT32X4 AElements[RowCount]; - MLAS_FLOAT32X4 ABroadcast[RowCount]; - - // - // Clear the block accumulators. - // - - MlasLoopUnroll()(Accumulators); - - // - // Compute the output block. - // - - while (k >= 4) { - - MlasLoopUnroll()(AElements, a, lda); - - MlasLoopUnroll>()(AElements, ABroadcast); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B); - - MlasLoopUnroll>()(AElements, ABroadcast); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 16); - - MlasLoopUnroll>()(AElements, ABroadcast); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 32); - - MlasLoopUnroll>()(AElements, ABroadcast); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 48); - - a += 4; - B += 16 * 4; - k -= 4; - } - - while (k > 0) { - - MlasLoopUnroll()(ABroadcast, a, lda); - MlasFgemmComputeBlock(Accumulators, ABroadcast, B); - - a += 1; - B += 16; - k -= 1; - } - - if (CountN >= 16) { - - // - // Store the entire output block. - // - - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - - } else { - - // - // Store the partial output block. - // - - if (CountN >= 12) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - } else if (CountN >= 8) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - } else if (CountN >= 4) { - MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); - } - - // - // Store the remaining unaligned columns. 
- // - - C += (CountN & ~3); - CountN &= 3; - - if (CountN > 0) { - - MlasLoopUnroll()(Accumulators, AlphaBroadcast); - - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); - - if (CountN >= 2) { - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); - } - - if (CountN >= 3) { - MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); - } - } - - break; - } - - C += 16; - CountN -= 16; - - } while (CountN > 0); - - return RowCount; -} - diff --git a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp deleted file mode 100644 index 0f3bc1d579711..0000000000000 --- a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp +++ /dev/null @@ -1,1374 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_power10.cpp - -Abstract: - - This module implements QGEMM kernel for POWER10. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" -#include - -struct MLAS_GEMM_QUANT_KERNEL_POWER10 -{ - typedef int8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef uint8_t OffsetBType; - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 16, 256, 384 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 16, 128, 128 }; -}; - -constexpr size_t MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_POWER10::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_POWER10::PackedStrides; - -#define INC_BUFFER(cnt) \ - ColumnSumBuffer += cnt; \ - if (ZeroPointB != nullptr) { \ - ZeroPointB += cnt; \ - } -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointA( - int32_t ZeroPointA, - bool AIsSigned - ) -{ - if (!AIsSigned) { - ZeroPointA = MLAS_GEMM_QUANT_KERNEL_POWER10::OffsetAType(ZeroPointA ^ 0x80); - } - return ZeroPointA; -} - -template<> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (BIsSigned) { - ZeroPointB = MLAS_GEMM_QUANT_KERNEL_POWER10::OffsetBType(ZeroPointB ^ 0x80); - } - return ZeroPointB; - -} - -template -void -MlasGemmQuantCopyPackA8x8( - MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) -{ - constexpr uint8_t Flip = (AIsSigned ? 0 : 0x80); - Vtype vmask = reinterpret_cast(vec_splats(Flip)); - typedef __vector signed char vec_t; - - // Process eight rows of matrix A in a loop. - // - // The buffer is packed as a series of 4x4 byte vectors to help - // in getting into MMA loop. - // - // Unsigned buffers are converted to signed buffers in order to - // share a common kernel. - // This pattern is repeated (CountK / 4) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. 
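A scalar sketch of the A-packing idea described in the comment above: groups of 4 rows by 4 k-values, unsigned inputs flipped by XOR 0x80 so a single signed kernel can be shared, zero padding when K is not a multiple of four, and a per-row sum of the packed values recorded for later zero-point fixup. The layout is simplified to one 4-row group and the helper name is hypothetical.

#include <cstddef>
#include <cstdint>

// Packs 4 rows of A into 4x4-byte groups and accumulates per-row sums.
void PackAQuadReference(int8_t* D, const uint8_t* A, std::size_t lda,
                        std::size_t CountK, int32_t RowSum[4], bool AIsSigned)
{
    for (std::size_t r = 0; r < 4; ++r) RowSum[r] = 0;
    const uint8_t flip = AIsSigned ? 0 : 0x80;  // unsigned -> signed conversion

    for (std::size_t k = 0; k < CountK; k += 4) {
        for (std::size_t r = 0; r < 4; ++r) {        // 4 rows per group
            for (std::size_t i = 0; i < 4; ++i) {    // 4 k-values per row
                int8_t v = 0;                        // pad past the end of K
                if (k + i < CountK) {
                    v = static_cast<int8_t>(A[r * lda + k + i] ^ flip);
                }
                *D++ = v;
                RowSum[r] += v;
            }
        }
    }
}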
- // - while (CountM >= 8) { - const uint8_t *a = A; - __vector signed int vsum = { 0 }; - __vector signed int vsum2 = { 0 }; - size_t y = CountK; - while (y >= 16) { - Vtype a1 = *reinterpret_cast(&a[0]); - Vtype a2 = *reinterpret_cast(&a[lda]); - Vtype a3 = *reinterpret_cast(&a[lda * 2]); - Vtype a4 = *reinterpret_cast(&a[lda * 3]); - Vtype vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx2 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx3 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi(vx, vx1, 0); - Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi(vx, vx1, 3); - Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); - a1 = *reinterpret_cast(&a[lda*4]); - a2 = *reinterpret_cast(&a[lda*5]); - a3 = *reinterpret_cast(&a[lda*6]); - a4 = *reinterpret_cast(&a[lda*7]); - vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - vx2 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - vx3 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx8 = vec_xxpermdi(vx, vx1, 0); - Vtype vx9 = vec_xxpermdi(vx2, vx3, 0); - Vtype vx10 = vec_xxpermdi(vx, vx1, 3); - Vtype vx11 = vec_xxpermdi(vx2, vx3, 3); - vec_t vxx = - AIsSigned ? reinterpret_cast(vx4) : - reinterpret_cast(vec_sub(vx4, vmask)); - vsum = vec_sum4s(vxx, vsum); - *reinterpret_cast(&D[0]) = vxx; - vxx = AIsSigned ? reinterpret_cast(vx5) : - reinterpret_cast(vec_sub(vx5, vmask)); - vsum = vec_sum4s(vxx, vsum); - *reinterpret_cast(&D[16]) = vxx; - vxx = AIsSigned ? reinterpret_cast(vx6) : - reinterpret_cast(vec_sub(vx6, vmask)); - vsum = vec_sum4s(vxx, vsum); - *reinterpret_cast(&D[32]) = vxx; - vxx = AIsSigned ? reinterpret_cast(vx7) : - reinterpret_cast(vec_sub(vx7, vmask)); - vsum = vec_sum4s(vxx, vsum); - *reinterpret_cast(&D[48]) = vxx; - vxx = AIsSigned ? reinterpret_cast(vx8) : - reinterpret_cast(vec_sub(vx8, vmask)); - *reinterpret_cast(&D[64]) = vxx; - vsum2 = vec_sum4s(vxx, vsum2); - vxx = AIsSigned ? reinterpret_cast(vx9) : - reinterpret_cast(vec_sub(vx9, vmask)); - *reinterpret_cast(&D[80]) = vxx; - vsum2 = vec_sum4s(vxx, vsum2); - vxx = AIsSigned ? reinterpret_cast(vx10) : - reinterpret_cast(vec_sub(vx10, vmask)); - *reinterpret_cast(&D[96]) = vxx; - vsum2 = vec_sum4s(vxx, vsum2); - vxx = AIsSigned ? reinterpret_cast(vx11) : - reinterpret_cast(vec_sub(vx11, vmask)); - *reinterpret_cast(&D[112]) = vxx; - vsum2 = vec_sum4s(vxx, vsum2); - D += 16 * 8; - a += 16; - y -= 16; - } - size_t yval = y; - while (y >= 4) - { - int a1 = *reinterpret_cast(&a[0]); - int a2 = *reinterpret_cast(&a[lda]); - int a3 = *reinterpret_cast(&a[lda*2]); - int a4 = *reinterpret_cast(&a[lda*3]); - __vector int vx1 = { a1, a2, a3, a4}; - vec_t vx = - AIsSigned ? 
reinterpret_cast(vx1) : - reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); - vsum = vec_sum4s(vx, vsum); - *reinterpret_cast(&D[0]) = vx; - a1 = *reinterpret_cast(&a[lda*4]); - a2 = *reinterpret_cast(&a[lda*5]); - a3 = *reinterpret_cast(&a[lda*6]); - a4 = *reinterpret_cast(&a[lda*7]); - __vector int vx2 = { a1, a2, a3, a4}; - vx = AIsSigned ? reinterpret_cast(vx2) : - reinterpret_cast(vec_sub(reinterpret_cast(vx2), vmask)); - vsum2 = vec_sum4s(vx, vsum2); - if (CountK & 3) { - if (yval >= 12) { - *reinterpret_cast(&D[64]) = vx; - } else if (yval >= 8) { - *reinterpret_cast(&D[48]) = vx; - } else { - *reinterpret_cast(&D[32]) = vx; - } - } else { - if (yval >= 12) { - *reinterpret_cast(&D[48]) = vx; - } else if (yval >= 8) { - *reinterpret_cast(&D[32]) = vx; - } else { - *reinterpret_cast(&D[16]) = vx; - } - } - D += 16; - a += 4; - y -= 4; - } - if (yval >= 12) { - if (!(CountK & 3)) { - D += 48; - } - } else if (yval >= 8) { - if (!(CountK & 3)) { - D += 32; - } - } else if (yval >= 4) { - if (!(CountK & 3)) { - D += 16; - } - } - if (y >= 1) - { - Vtype a1 = vmask; - Vtype a2 = vmask; - Vtype a3 = vmask; - Vtype a4 = vmask; - a1[0] = a[0]; - a2[0] = a[lda]; - a3[0] = a[lda * 2]; - a4[0] = a[lda * 3]; - if (y >= 2) { - a1[1] = a[1]; - a2[1] = a[lda + 1]; - a3[1] = a[lda * 2 + 1]; - a4[1] = a[lda * 3 + 1]; - } - if (y >= 3) { - a1[2] = a[2]; - a2[2] = a[lda + 2]; - a3[2] = a[lda * 2 + 2]; - a4[2] = a[lda * 3 + 2]; - } - Vtype vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx2 = vec_xxpermdi(vx, vx1, 0); - vec_t vx3 = - AIsSigned ? reinterpret_cast(vx2) : - reinterpret_cast(vec_sub(vx2, vmask)); - vsum = vec_sum4s(vx3, vsum); - *reinterpret_cast(&D[0]) = vx3; - a1 = vmask; - a2 = vmask; - a3 = vmask; - a4 = vmask; - a1[0] = a[lda * 4]; - a2[0] = a[lda * 5]; - a3[0] = a[lda * 6]; - a4[0] = a[lda * 7]; - if (y >= 2) { - a1[1] = a[lda * 4 + 1]; - a2[1] = a[lda * 5 + 1]; - a3[1] = a[lda * 6 + 1]; - a4[1] = a[lda * 7 + 1]; - } - if (y >= 3) { - a1[2] = a[lda * 4 + 2]; - a2[2] = a[lda * 5 + 2]; - a3[2] = a[lda * 6 + 2]; - a4[2] = a[lda * 7 + 2]; - } - vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - vx2 = vec_xxpermdi(vx, vx1, 0); - vx3 = AIsSigned ? reinterpret_cast(vx2) : - reinterpret_cast(vec_sub(vx2, vmask)); - vsum2 = vec_sum4s(vx3, vsum2); - if (CountK % 16 >= 12) { - *reinterpret_cast(&D[64]) = vx3; - D += 80; - } else if (CountK % 16 >= 8) { - *reinterpret_cast(&D[48]) = vx3; - D += 64; - } else if (CountK % 16 >= 4) { - *reinterpret_cast(&D[32]) = vx3; - D += 48; - } else { - *reinterpret_cast(&D[16]) = vx3; - D += 16 * 2; - } - a += 16; - } - A += lda * 8; - *RowSumBuffer++ = vsum[0]; - *RowSumBuffer++ = vsum[1]; - *RowSumBuffer++ = vsum[2]; - *RowSumBuffer++ = vsum[3]; - *RowSumBuffer++ = vsum2[0]; - *RowSumBuffer++ = vsum2[1]; - *RowSumBuffer++ = vsum2[2]; - *RowSumBuffer++ = vsum2[3]; - CountM -= 8; - } - - // Process four rows of matrix A in a loop. 
- // - if (CountM >= 4) - { - const uint8_t *a = A; - __vector signed int vsum = { 0 }; - size_t y = CountK; - - while (y >= 16) - { - Vtype a1 = *reinterpret_cast(&a[0]); - Vtype a2 = *reinterpret_cast(&a[lda]); - Vtype a3 = *reinterpret_cast(&a[lda * 2]); - Vtype a4 = *reinterpret_cast(&a[lda * 3]); - Vtype vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx2 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx3 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi(vx, vx1, 0); - Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi(vx, vx1, 3); - Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); - vec_t vx0 = - AIsSigned ? reinterpret_cast(vx4) : - reinterpret_cast(vec_sub(vx4, vmask)); - *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx5) : - reinterpret_cast(vec_sub(vx5, vmask)); - *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx6) : - reinterpret_cast(vec_sub(vx6, vmask)); - *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx7) : - reinterpret_cast(vec_sub(vx7, vmask)); - *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s(vx0, vsum); - D += 16 * 4; - a += 16; - y -= 16; - } - while (y >= 4) - { - int a1 = *reinterpret_cast(&a[0]); - int a2 = *reinterpret_cast(&a[lda]); - int a3 = *reinterpret_cast(&a[lda*2]); - int a4 = *reinterpret_cast(&a[lda*3]); - __vector int vx1 = { a1, a2, a3, a4}; - vec_t vx = - AIsSigned ? reinterpret_cast(vx1) : - reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); - *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s(vx, vsum); - D += 16; - a += 4; - y -= 4; - } - if (y >= 1) - { - Vtype vx = vmask; - vx[0] = a[0]; - vx[4] = a[lda]; - vx[8] = a[lda * 2]; - vx[12] = a[lda * 3]; - if (y >= 2) { - vx[1] = a[1]; - vx[5] = a[lda + 1]; - vx[9] = a[lda * 2 + 1]; - vx[13] = a[lda * 3 + 1]; - } - if (y >= 3) { - vx[2] = a[2]; - vx[6] = a[lda + 2]; - vx[10] = a[lda * 2 + 2]; - vx[14] = a[lda * 3 + 2]; - } - vec_t vx1 = - AIsSigned ? reinterpret_cast(vx) : - reinterpret_cast(vec_sub(vx, vmask)); - *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s(vx1, vsum); - D += 16; - a += 16; - } - A += lda * 4; - *RowSumBuffer++ = vsum[0]; - *RowSumBuffer++ = vsum[1]; - *RowSumBuffer++ = vsum[2]; - *RowSumBuffer++ = vsum[3]; - CountM -= 4; - } - - // Process remaining rows of matrix A in a loop. 
- // - if (CountM <= 3 && CountM > 0) { - const uint8_t *a = A; - size_t y = CountK; - __vector signed int vsum = { 0 }; - - while (y >= 16) { - Vtype a4 = vmask; - Vtype a2 = vmask; - Vtype a3 = vmask; - Vtype a1 = *reinterpret_cast(&a[0]); - if (CountM == 3) { - a3 = *reinterpret_cast(&a[lda * 2]); - } - if (CountM >= 2) { - a2 = *reinterpret_cast(&a[lda]); - } - Vtype vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx2 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx3 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi(vx, vx1, 0); - Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi(vx, vx1, 3); - Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); - vec_t vx0 = - AIsSigned ? reinterpret_cast(vx4) : - reinterpret_cast(vec_sub(vx4, vmask)); - *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx5) : - reinterpret_cast(vec_sub(vx5, vmask)); - *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx6) : - reinterpret_cast(vec_sub(vx6, vmask)); - *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx7) : - reinterpret_cast(vec_sub(vx7, vmask)); - *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s(vx0, vsum); - D += 16 * 4; - a += 16; - y -= 16; - } - while (y >= 4) - { - Vtype vb = vmask; - __vector int vx1 = reinterpret_cast<__vector int>(vb); - vx1[0] = *reinterpret_cast(&a[0]); - if (CountM >= 2) { - vx1[1] = *reinterpret_cast(&a[lda]); - } - if (CountM >= 3) { - vx1[2] = *reinterpret_cast(&a[lda*2]); - } - vec_t vx = - AIsSigned ? reinterpret_cast(vx1) : - reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); - *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s(vx, vsum); - D += 16; - a += 4; - y -= 4; - } - if (y >= 1) - { - int8_t vz = 0; - vec_t vx = vec_splats(vz); - vx[0] = a[0] ^ Flip; - if (y >= 2) { - vx[1] = a[1] ^ Flip; - } - if (y >= 3) { - vx[2] = a[2] ^ Flip; - } - if (CountM >= 2) { - vx[4] = a[lda] ^ Flip; - if (y >= 2) { - vx[5] = a[lda + 1] ^ Flip; - } - if (y >= 3) { - vx[6] = a[lda + 2] ^ Flip; - } - } - if (CountM == 3) { - vx[8] = a[lda * 2] ^ Flip; - if (y >= 2) { - vx[9] = a[lda * 2 + 1] ^ Flip; - } - if (y >= 3) { - vx[10] = a[lda * 2 + 2] ^ Flip; - } - } - *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s(vx, vsum); - D += 16; - } - *RowSumBuffer++ = vsum[0]; - if (CountM >= 2) { - *RowSumBuffer++ = vsum[1]; - } - if (CountM >= 3) { - *RowSumBuffer++ = vsum[2]; - } - } -} - -template -void -MlasGemmQuantCopyPackB8x8( - MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer - ) -{ - [[maybe_unused]] constexpr uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); - typedef __vector unsigned char vec_t; - Vtype vmask = reinterpret_cast(vec_splats(BitFlipValue)); - vec_t mask = {0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15}; - - // Copy columns from matrix B to the packed buffer. Signed buffers are - // converted to unsigned buffers in order to share a common kernel. - // - // If CountK is not aligned to a multiple of four, then the packed buffer - // is padded with zero vectors. 
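A scalar sketch of the B-packing idea described in the comment above: for each group of 4 k-rows, the 4 consecutive k-values of every column are stored contiguously, signed inputs are flipped toward unsigned so the shared kernel sees one data type, padding past CountK becomes zero, and per-column sums are recorded. The layout is simplified to one 4-column panel and the helper name is hypothetical.

#include <cstddef>
#include <cstdint>

// Packs a 4-column panel of B into 4x4-byte groups and accumulates column sums.
void PackBQuadReference(uint8_t* D, const uint8_t* B, std::size_t ldb,
                        std::size_t CountK, int32_t ColumnSum[4], bool BIsSigned)
{
    for (std::size_t n = 0; n < 4; ++n) ColumnSum[n] = 0;
    const uint8_t flip = BIsSigned ? 0x80 : 0;  // signed -> unsigned conversion

    for (std::size_t k = 0; k < CountK; k += 4) {
        for (std::size_t n = 0; n < 4; ++n) {        // 4 columns per group
            for (std::size_t i = 0; i < 4; ++i) {    // 4 k-values per column
                uint8_t v = 0;                       // zero padding past CountK
                if (k + i < CountK) {
                    v = static_cast<uint8_t>(B[(k + i) * ldb + n] ^ flip);
                }
                *D++ = v;
                ColumnSum[n] += v;
            }
        }
    }
}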
- - // Process 16 columns of matrix B in a loop. - // - size_t PackedK = ((CountK + 4 - 1) / 4) * 16; - size_t k2 = PackedK; - size_t k3 = PackedK*2; - size_t k4 = PackedK*3; - - while (CountN >= 16) { - const uint8_t* b = B; - __vector unsigned int vsum = {0}; - __vector unsigned int vsum2 = {0}; - __vector unsigned int vsum3 = {0}; - __vector unsigned int vsum4 = {0}; - size_t y = CountK; - if (y >= 4) { - do { - Vtype b1 = *reinterpret_cast(&b[0]); - Vtype b2 = *reinterpret_cast(&b[ldb]); - Vtype b3 = *reinterpret_cast(&b[ldb*2]); - Vtype b4 = *reinterpret_cast(&b[ldb*3]); - Vtype t1 = vec_mergeh(b1, b3); - Vtype t2 = vec_mergel(b1, b3); - Vtype t3 = vec_mergeh(b2, b4); - Vtype t4 = vec_mergel(b2, b4); - b1 = vec_mergeh(t1, t3); - b2 = vec_mergel(t1, t3); - b3 = vec_mergeh(t2, t4); - b4 = vec_mergel(t2, t4); - vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(b1, vmask)) : - reinterpret_cast(b1); - vec_t vx2 = BIsSigned ? reinterpret_cast(vec_add(b2, vmask)) : - reinterpret_cast(b2); - vec_t vx3 = BIsSigned ? reinterpret_cast(vec_add(b3, vmask)) : - reinterpret_cast(b3); - vec_t vx4 = BIsSigned ? reinterpret_cast(vec_add(b4, vmask)) : - reinterpret_cast(b4); - *reinterpret_cast(&D[0]) = vx1; - *reinterpret_cast(&D[k2]) = vx2; - *reinterpret_cast(&D[k3]) = vx3; - *reinterpret_cast(&D[k4]) = vx4; - vsum = vec_sum4s(vx1, vsum); - vsum2 = vec_sum4s(vx2, vsum2); - vsum3 = vec_sum4s(vx3, vsum3); - vsum4 = vec_sum4s(vx4, vsum4); - D += 16; - b += ldb*4; - y -= 4; - } while (y >= 4); - } - if (y >= 1) { - Vtype b1 = *reinterpret_cast(&b[0]); - Vtype b2 = (y >= 2) ? *reinterpret_cast(&b[ldb]) : vmask; - Vtype b3 = (y >= 3) ? *reinterpret_cast(&b[ldb*2]) : vmask; - Vtype b4 = vmask; - Vtype t1 = vec_mergeh(b1, b3); - Vtype t2 = vec_mergel(b1, b3); - Vtype t3 = vec_mergeh(b2, b4); - Vtype t4 = vec_mergel(b2, b4); - b1 = vec_mergeh(t1, t3); - b2 = vec_mergel(t1, t3); - b3 = vec_mergeh(t2, t4); - b4 = vec_mergel(t2, t4); - vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(b1, vmask)) : - reinterpret_cast(b1); - vec_t vx2 = BIsSigned ? reinterpret_cast(vec_add(b2, vmask)) : - reinterpret_cast(b2); - vec_t vx3 = BIsSigned ? reinterpret_cast(vec_add(b3, vmask)) : - reinterpret_cast(b3); - vec_t vx4 = BIsSigned ? reinterpret_cast(vec_add(b4, vmask)) : - reinterpret_cast(b4); - *reinterpret_cast(&D[0]) = vx1; - *reinterpret_cast(&D[k2]) = vx2; - *reinterpret_cast(&D[k3]) = vx3; - *reinterpret_cast(&D[k4]) = vx4; - vsum = vec_sum4s(vx1, vsum); - vsum2 = vec_sum4s(vx2, vsum2); - vsum3 = vec_sum4s(vx3, vsum3); - vsum4 = vec_sum4s(vx4, vsum4); - D += 16; - } - *ColumnSumBuffer++ = vsum[0]; - *ColumnSumBuffer++ = vsum[1]; - *ColumnSumBuffer++ = vsum[2]; - *ColumnSumBuffer++ = vsum[3]; - *ColumnSumBuffer++ = vsum2[0]; - *ColumnSumBuffer++ = vsum2[1]; - *ColumnSumBuffer++ = vsum2[2]; - *ColumnSumBuffer++ = vsum2[3]; - *ColumnSumBuffer++ = vsum3[0]; - *ColumnSumBuffer++ = vsum3[1]; - *ColumnSumBuffer++ = vsum3[2]; - *ColumnSumBuffer++ = vsum3[3]; - *ColumnSumBuffer++ = vsum4[0]; - *ColumnSumBuffer++ = vsum4[1]; - *ColumnSumBuffer++ = vsum4[2]; - *ColumnSumBuffer++ = vsum4[3]; - B += 16; - CountN -= 16; - D += k4; - } - - // Process four columns of matrix B in a loop. 
- // - while (CountN >= 4) { - const uint8_t* b = B; - __vector unsigned int vsum = {0}; - size_t y = CountK; - if (y >= 4) { - do { - int b1 = *reinterpret_cast(&b[0]); - int b2 = *reinterpret_cast(&b[ldb]); - int b3 = *reinterpret_cast(&b[ldb*2]); - int b4 = *reinterpret_cast(&b[ldb*3]); - __vector int vb = {b1, b2, b3, b4}; - Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : - reinterpret_cast(vx); - *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s(vx1, vsum); - D += 16; - b += ldb*4; - y -= 4; - } while (y >= 4); - } - if (y >= 1) { - Vtype vb = vmask; - __vector int vb1 = reinterpret_cast<__vector int>(vb); - vb1[0] = *reinterpret_cast(&b[0]); - if (y >= 2) { - vb1[1] = *reinterpret_cast(&b[ldb]); - } - if (y >= 3) { - vb1[2] = *reinterpret_cast(&b[ldb*2]); - } - Vtype vx = vec_perm(reinterpret_cast(vb1), reinterpret_cast(vb1), mask); - vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : - reinterpret_cast(vx); - *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s(vx1, vsum); - D += 16; - } - *ColumnSumBuffer++ = vsum[0]; - *ColumnSumBuffer++ = vsum[1]; - *ColumnSumBuffer++ = vsum[2]; - *ColumnSumBuffer++ = vsum[3]; - B += 4; - CountN -= 4; - } - - // - // Process the remaining columns of matrix B. - // - if (CountN > 0) { - __vector unsigned int vsum = {0}; - const uint8_t* b = B; - size_t y = CountK; - if (y >= 4) { - do { - Vtype vb = vmask; - if (CountN == 1) { - vb[0] = b[0]; - vb[4] = b[ldb]; - vb[8] = b[ldb*2]; - vb[12] = b[ldb*3]; - } - if (CountN == 2) { - vb[0] = b[0]; - vb[1] = b[1]; - vb[4] = b[ldb]; - vb[5] = b[ldb+1]; - vb[8] = b[ldb*2]; - vb[9] = b[ldb*2+1]; - vb[12] = b[ldb*3]; - vb[13] = b[ldb*3+1]; - } - if (CountN == 3) { - vb[0] = b[0]; - vb[1] = b[1]; - vb[2] = b[2]; - vb[4] = b[ldb]; - vb[5] = b[ldb+1]; - vb[6] = b[ldb+2]; - vb[8] = b[ldb*2]; - vb[9] = b[ldb*2+1]; - vb[10] = b[ldb*2+2]; - vb[12] = b[ldb*3]; - vb[13] = b[ldb*3+1]; - vb[14] = b[ldb*3+2]; - } - Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : - reinterpret_cast(vx); - *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s(vx1, vsum); - D += 16; - b += ldb*4; - y -= 4; - } while (y >= 4); - } - if (y >= 1) { - Vtype vb = vmask; - if (CountN == 1) { - vb[0]= b[0]; - if (y >= 2) { - vb[4] = b[ldb]; - } - if (y >= 3) { - vb[8] = b[ldb*2]; - } - } - if (CountN == 2) { - vb[0] = b[0]; - vb[1] = b[1]; - if (y >= 2) { - vb[4] = b[ldb]; - vb[5] = b[ldb+1]; - } - if (y >= 3) { - vb[8] = b[ldb*2]; - vb[9] = b[ldb*2+1]; - } - } - if (CountN == 3) { - vb[0] = b[0]; - vb[1] = b[1]; - vb[2] = b[2]; - if (y >= 2) { - vb[4] = b[ldb]; - vb[5] = b[ldb+1]; - vb[6] = b[ldb+2]; - } - if (y >= 3) { - vb[8] = b[ldb*2]; - vb[9] = b[ldb*2+1]; - vb[10] = b[ldb*2+2]; - } - } - Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = BIsSigned ? 
reinterpret_cast(vec_add(vx, vmask)) : - reinterpret_cast(vx); - *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s(vx1, vsum); - D += 16; - } - *ColumnSumBuffer++ = vsum[0]; - if (CountN >= 2) { - *ColumnSumBuffer++ = vsum[1]; - } - if (CountN >= 3) { - *ColumnSumBuffer++ = vsum[2]; - } - } -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - if (AIsSigned) { - MlasGemmQuantCopyPackA8x8<__vector signed char, true>(D, A, lda, CountM, CountK, RowSumBuffer); - } else { - MlasGemmQuantCopyPackA8x8<__vector unsigned char, false>(D, A, lda, CountM, CountK, RowSumBuffer); - } -} -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - if (BIsSigned) { - MlasGemmQuantCopyPackB8x8<__vector signed char, true>(D, B, ldb, CountN, CountK, ColumnSumBuffer); - } else { - MlasGemmQuantCopyPackB8x8< __vector unsigned char, false>(D, B, ldb, CountN, CountK, ColumnSumBuffer); - } -} - -template -MLAS_FORCEINLINE -void -MlasQgemmStoreVectorMMA - ( - MLAS_INT32X4 result[4], - int32_t* C, - size_t ldc, - size_t row, - bool ZeroMode, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - int pos - ) -{ - size_t RowCount; - __vector signed int vsum0, vsum1, vsum2, vsum3; -#if defined(_AIX) && defined(__clang__) - __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); -#else - __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); -#endif - C += VectorCount; - if (ZeroPointB != nullptr) { -#if defined(_AIX) && defined(__clang__) - __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); -#else - __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); -#endif - if (ZeroMode) { - for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { - vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; - vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; - vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; - vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - *reinterpret_cast<__vector int *>(&C[ldc]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; - *reinterpret_cast<__vector int *>(&C[ldc*2]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; - *reinterpret_cast<__vector int *>(&C[ldc*3]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; - } - for (; RowCount < row; RowCount++, C += ldc) { - vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - } - } else { - for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { - vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; - vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; - vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; - vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 
0]) + vsum0; - *reinterpret_cast<__vector int *>(&C[ldc]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; - *reinterpret_cast<__vector int *>(&C[ldc*2]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; - *reinterpret_cast<__vector int *>(&C[ldc*3]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; - } - for (; RowCount < row; RowCount++, C += ldc) { - vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - } - } - } else { - if (ZeroMode) { - for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { - vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; - vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; - vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; - vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - *reinterpret_cast<__vector int *>(&C[ldc]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; - *reinterpret_cast<__vector int *>(&C[ldc*2]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; - *reinterpret_cast<__vector int *>(&C[ldc*3]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; - } - for (; RowCount < row; RowCount++, C += ldc) { - vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) = - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - } - } else { - for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { - vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; - vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; - vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; - vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - *reinterpret_cast<__vector int *>(&C[ldc]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; - *reinterpret_cast<__vector int *>(&C[ldc*2]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; - *reinterpret_cast<__vector int *>(&C[ldc*3]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; - } - for (; RowCount < row; RowCount++, C += ldc) { - vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; - *reinterpret_cast<__vector int *>(&C[0]) += - *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; - } - } - } -}; -template -MLAS_FORCEINLINE -void -MlasQgemmStoreScalarMMA( - MLAS_INT32X4 result[4], - int32_t* C, - size_t ldc, - size_t row, - bool ZeroMode, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB - ) -{ - if (ZeroPointB != nullptr) { - if (ZeroMode) { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - int sum = RowSumBuffer[RowCount]; - sum *= ZeroPointB[0]; - sum += ColumnSumBuffer[0]; - C[RowCount*ldc+Lane] = result[RowCount][Lane] + sum; - } - } else { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - int sum = RowSumBuffer[RowCount]; - sum *= ZeroPointB[0]; - sum += ColumnSumBuffer[0]; - C[RowCount*ldc+Lane] += result[RowCount][Lane] + sum; - } - } - } else { - if (ZeroMode) { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - int sum = RowSumBuffer[RowCount] + ColumnSumBuffer[0]; - 
C[RowCount*ldc+Lane] = result[RowCount][Lane] + sum; - } - } else { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - int sum = RowSumBuffer[RowCount] + ColumnSumBuffer[0]; - C[RowCount*ldc+Lane] += result[RowCount][Lane] + sum; - } - } - } -}; -template -MLAS_FORCEINLINE -void -MlasQgemmComputeMMA( - __vector_quad *acc0, - __vector_quad *acc1, - __vector unsigned char *va, - __vector unsigned char *vb - ) -{ - if (CountK == 16) { - __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); - __builtin_mma_xvi8ger4pp(acc0, va[2], vb[2]); - __builtin_mma_xvi8ger4pp(acc0, va[3], vb[3]); - if (CountM) { - __builtin_mma_xvi8ger4pp(acc1, va[4], vb[0]); - __builtin_mma_xvi8ger4pp(acc1, va[5], vb[1]); - __builtin_mma_xvi8ger4pp(acc1, va[6], vb[2]); - __builtin_mma_xvi8ger4pp(acc1, va[7], vb[3]); - } - } else if (CountK == 12) { - __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); - __builtin_mma_xvi8ger4pp(acc0, va[2], vb[2]); - if (CountM) { - __builtin_mma_xvi8ger4pp(acc1, va[3], vb[0]); - __builtin_mma_xvi8ger4pp(acc1, va[4], vb[1]); - __builtin_mma_xvi8ger4pp(acc1, va[5], vb[2]); - } - } else if (CountK == 8) { - __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); - if (CountM) { - __builtin_mma_xvi8ger4pp(acc1, va[2], vb[0]); - __builtin_mma_xvi8ger4pp(acc1, va[3], vb[1]); - } - } else { - __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); - if (CountM) { - __builtin_mma_xvi8ger4pp(acc1, va[1], vb[0]); - } - } -}; -template<> -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType* A, - const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - if (CountM < 8 && CountM >= 4) { - CountM = 4; - } - size_t Mval = CountM; - if (Mval >= 8) { - Mval = 4; - } - while (CountN > 0) { - const int8_t *a = A; - typedef __vector unsigned char vec_t; - const uint8_t *b = B; - int32_t *C1; - __vector_quad acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7; - // - // Initialize the accumulators with zero. - // - __builtin_mma_xxsetaccz(&acc0); - __builtin_mma_xxsetaccz(&acc1); - __builtin_mma_xxsetaccz(&acc2); - __builtin_mma_xxsetaccz(&acc3); - __builtin_mma_xxsetaccz(&acc4); - __builtin_mma_xxsetaccz(&acc5); - __builtin_mma_xxsetaccz(&acc6); - __builtin_mma_xxsetaccz(&acc7); - MLAS_INT32X4 result[4] = {0}; - MLAS_INT32X4 result1[4] = {0}; - size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK; - size_t k1 = PackedCountK; - // - // Compute the output block using POWER10 MMA builtins. 
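A scalar sketch of the rank-4 update that each __builtin_mma_xvi8ger4pp call performs as used in this kernel: 4 packed k-values for 4 rows of A (signed, from the packed A buffer) are multiplied against 4 packed k-values for 4 columns of B (unsigned, from the packed B buffer), updating a 4x4 int32 accumulator tile. The helper name and plain-array types are illustrative.

#include <cstdint>

// One rank-4 (dot-product-of-4) update of a 4x4 int32 accumulator tile.
void Int8Rank4Update(int32_t acc[4][4], const int8_t a[16], const uint8_t b[16])
{
    for (int i = 0; i < 4; ++i) {          // 4 rows of A
        for (int j = 0; j < 4; ++j) {      // 4 columns of B
            int32_t dot = 0;
            for (int t = 0; t < 4; ++t) {  // dot product over 4 packed k-values
                dot += static_cast<int32_t>(a[i * 4 + t]) *
                       static_cast<int32_t>(b[j * 4 + t]);
            }
            acc[i][j] += dot;
        }
    }
}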
- // - while (k >= 16) { - vec_t *va = const_cast(reinterpret_cast(a)); - vec_t *vb = const_cast(reinterpret_cast(b)); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } else { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*16])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } else { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*32])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } else { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*48])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } else { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } - b += 64; - if (CountM >= 8) { - a += 128; - } else { - a += 64; - } - k -= 16; - } - if (k >= 12) { - vec_t *va = const_cast(reinterpret_cast(a)); - vec_t *vb = const_cast(reinterpret_cast(b)); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } else { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*16])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } else { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*32])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } else { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*48])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } else { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } - if (CountM >= 8) { - a += 96; - } else { - a += 48; - } - b += 48; - k -= 12; - } - if (k >= 8) { - vec_t *va = const_cast(reinterpret_cast(a)); - vec_t *vb = const_cast(reinterpret_cast(b)); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } else { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*16])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } else { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*32])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } else { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*48])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } else { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } - if (CountM >= 8) { - a += 64; - } else { - a += 32; - } - b += 32; - k -= 8; - } - if (k >= 4) { - vec_t *va = const_cast(reinterpret_cast(a)); - vec_t *vb = const_cast(reinterpret_cast(b)); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } else { - MlasQgemmComputeMMA(&acc0, &acc4, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*16])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } else { - MlasQgemmComputeMMA(&acc1, &acc5, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*32])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } else { - MlasQgemmComputeMMA(&acc2, &acc6, va, vb); - } - vb = const_cast(reinterpret_cast(&b[k1*48])); - if (CountM >= 8) { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } else { - MlasQgemmComputeMMA(&acc3, &acc7, va, vb); - } - } - // Store matrix C with accumulator result. 
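A scalar sketch of the accumulator fixup performed by the store helpers above: each output element combines the int32 MMA result with the per-row sum, scaled by the per-column zero point of B when one is supplied, plus the per-column sum, and ZeroMode selects store versus accumulate. The caller's pre-adjustment of the sum buffers is assumed; the helper name is illustrative.

#include <cstddef>
#include <cstdint>

// Stores a rows x cols tile of int32 results into C with row/column-sum fixup.
void StoreTileReference(int32_t* C, std::size_t ldc, const int32_t result[4][4],
                        std::size_t rows, std::size_t cols, bool ZeroMode,
                        const int32_t* RowSum, const int32_t* ColumnSum,
                        const int32_t* ZeroPointB /* may be nullptr */)
{
    for (std::size_t i = 0; i < rows; ++i) {
        for (std::size_t j = 0; j < cols; ++j) {
            int32_t fixup = ColumnSum[j];
            fixup += (ZeroPointB != nullptr) ? RowSum[i] * ZeroPointB[j] : RowSum[i];
            const int32_t value = result[i][j] + fixup;
            C[i * ldc + j] = ZeroMode ? value : C[i * ldc + j] + value;
        }
    }
}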
- if (CountN >=16) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); - MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); - MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); - MlasQgemmStoreVectorMMA<8>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc3); - MlasQgemmStoreVectorMMA<12>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); - if (CountM >= 8) { - C1 = C+ldc*4; - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); - MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); - MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc6); - MlasQgemmStoreVectorMMA<8>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc7); - MlasQgemmStoreVectorMMA<12>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 12); - } - INC_BUFFER(16); - CountN -= 16; - B += 16 * 4 *PackedCountK; - C += 16; - } else { - if (CountN >=12 ) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); - MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); - MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); - MlasQgemmStoreVectorMMA<8>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); - if (CountM >= 8) { - C1 = C+ldc*4; - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); - MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); - MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc6); - MlasQgemmStoreVectorMMA<8>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); - } - INC_BUFFER(12); - if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc3); - if (CountM >= 8) { - __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc7); - } - } - CountN -= 12; - C += 12; - } else if (CountN >= 8) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); - MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); - MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - if (CountM >= 8) { - C1 = C+ldc*4; - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); - MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); - 
MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - } - INC_BUFFER(8); - if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); - if (CountM >= 8) { - __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc6); - } - } - CountN -= 8; - C += 8; - } else if (CountN >= 4) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); - MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - if (CountM >= 8) { - C1 = C+ldc*4; - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); - MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc5); - } - } - INC_BUFFER(4); - if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); - } - CountN -= 4; - C += 4; - } else { - __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); - if (CountM >= 8) { - __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc4); - } - } - CountN &= 3; - // - // Output the remaining partial output block. - // - if (CountN > 0) { - MlasQgemmStoreScalarMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); - if (CountM >= 8) { - MlasQgemmStoreScalarMMA<0>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); - } - INC_BUFFER(1); - if (CountN >= 2) { - MlasQgemmStoreScalarMMA<1>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); - if (CountM >= 8) { - MlasQgemmStoreScalarMMA<1>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); - } - INC_BUFFER(1); - } - if (CountN >= 3) { - MlasQgemmStoreScalarMMA<2>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); - if (CountM >= 8) { - MlasQgemmStoreScalarMMA<2>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); - } - INC_BUFFER(1); - } - } - CountN = 0; - } - } - if (CountM >= 8) { - return 8; - } - return CountM; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10 = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK, - MLAS_GEMM_QUANT_KERNEL_POWER10::PackedStrides.K, - 8 // Kernel M stride -}; diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp deleted file mode 100644 index 015d69de68766..0000000000000 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ /dev/null @@ -1,1874 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - q4_dq.cpp - -Abstract: - - This module contains the data structures and implementations - for blocked int4 quantization and dequantization. - - Int4 block quantization is used to compress weight tensors of large - language models. 
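A minimal sketch of symmetric blocked int4 quantization as described in the module abstract above: every block shares one scale derived from its largest-magnitude element, and two 4-bit codes are packed per byte. The block length of 32, the contiguous input layout, and the function name are assumptions for illustration.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Quantizes one block of up to 32 floats to 16 packed bytes plus a scale.
void QuantizeBlockQ4Sym(const float* src, std::size_t count /* <= 32 */,
                        float& scale, uint8_t packed[16])
{
    float max = 0.0f;                       // value with the largest magnitude
    for (std::size_t i = 0; i < count; ++i) {
        if (std::fabs(src[i]) > std::fabs(max)) max = src[i];
    }
    scale = max / -8.0f;                    // map the larger half-range onto [-8, 0]
    const float rscale = (scale != 0.0f) ? 1.0f / scale : 0.0f;

    for (std::size_t i = 0; i < 16; ++i) {  // two 4-bit codes per byte
        const float v0 = (i < count) ? src[i] * rscale : 0.0f;
        const float v1 = (i + 16 < count) ? src[i + 16] * rscale : 0.0f;
        const uint8_t q0 = static_cast<uint8_t>(std::min(15.0f, std::max(0.0f, v0 + 8.5f)));
        const uint8_t q1 = static_cast<uint8_t>(std::min(15.0f, std::max(0.0f, v1 + 8.5f)));
        packed[i] = static_cast<uint8_t>(q0 | (q1 << 4));
    }
}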
- ---*/ - -#include "q4common.h" - -template -constexpr -size_t -BlkQ4BufSize(size_t N, size_t K) -{ - const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); - return N * KBlocks * T::BlobSize; -} - -size_t -MLASCALL -MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K) -{ - if (GetMlasPlatform().FpQ4GemmDispatch == nullptr) { - return 0; - } - - switch (QType) { - case BlkQ4Sym: - return BlkQ4BufSize(N, K); - case BlkQ4Sym64: - return BlkQ4BufSize(N, K); - case BlkQ4Sym128: - return BlkQ4BufSize(N, K); - default: - return BlkQ4BufSize(N, K); - } -} - - -template -MLAS_FORCEINLINE -void -MlasQ4GemmPackBImpl(void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb) -{ - auto* dst_ptr = reinterpret_cast(PackedBuf); - - for (size_t n = 0; n < N; n ++) { - const float* src = FpData; // starting from top of the column - - for (size_t k = 0; k < K; k += T::BlkLen) { - size_t klen = std::min(size_t(T::BlkLen), K - k); - float amax = 0.0f; // abs(max) - float max = 0.0f; - - for (size_t l = 0; l < klen; l++) { - const float v = src[ldb * l]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float scale = max / (-8); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - MlasQ4BlkScale(dst_ptr) = scale; - uint8_t* data = MlasQ4BlkData(dst_ptr); - - for (size_t kk = 0; kk < klen; kk += 32) { - size_t kklen = std::min((size_t)32, klen - kk); - for (size_t l = 0; l < 16; l++) { - const float v0 = l < kklen ? src[ldb * (kk + l)] * reciprocal_scale : 0; - const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, v0 + 8.5f)); - - const size_t l1 = l + 16; - const float v1 = (l1 < kklen) ? src[ldb * (kk + l1)] * reciprocal_scale : 0; - const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, v1 + 8.5f)); - - data[l] = vi0 | (vi1 << 4); - } - data += 16; - } - - // Move to next block of values in this column - dst_ptr += T::BlobSize; - src += ldb * klen; - } - - FpData++; // move to next column - } -} - -template<> -MLAS_FORCEINLINE -void -MlasQ4GemmPackBImpl( - void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb) -{ - auto* dst_ptr = reinterpret_cast(PackedBuf); - - for (size_t n = 0; n < N; n++) { - const float* src = FpData; // starting from top of the column - - for (size_t k = 0; k < K; k += MLAS_Q4TYPE_BLK1::BlkLen) { - size_t klen = std::min(MLAS_Q4TYPE_BLK1::BlkLen, K - k); - float min = std::numeric_limits::max(); - float max = -min; - - for (size_t l = 0; l < klen; l++) { - const float v = src[ldb * l]; - if (v < min) min = v; - if (v > max) max = v; - } - min = std::min(min, 0.0f); - max = std::max(max, 0.0f); - - const float scale = (max - min) / ((1 << 4) - 1); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - float zero_point_fp = min; - if (scale != 0.0f) { - zero_point_fp = 0.f - min / scale; - } - - // Handle any clamping - uint8_t& zp = MlasQ4BlkZeroPoint(dst_ptr); - if (zero_point_fp < 0.0f) { - zp = 0; - } else if (zero_point_fp > 15.0f) { - zp = 15; - } else { - zp = (uint8_t)roundf(zero_point_fp); - } - MlasQ4BlkScale(dst_ptr) = scale; - uint8_t* data = MlasQ4BlkData(dst_ptr); - - for (size_t kk = 0; kk < klen; kk += 32) { - size_t kklen = std::min((size_t)32, klen - kk); - for (size_t l = 0; l < 16; l++) { - const float v0 = l < kklen ? src[ldb * (kk + l)] : 0; - const uint8_t vi0 = (uint8_t)std::min( - 15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); - - const size_t l1 = l + 16; - const float v1 = (l1 < kklen) ? 
src[ldb * (kk + l1)] : 0; - const uint8_t vi1 = (uint8_t)std::min( - 15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); - - data[l] = vi0 | (vi1 << 4); - } - data += 16; - } - // move to next block of values in this column - dst_ptr += MLAS_Q4TYPE_BLK1::BlobSize; - src += ldb * klen; - } - FpData++; // move to next column - } -} - -void -MLASCALL -MlasQ4GemmPackB( - MLAS_BLK_QUANT_TYPE QType, - void* PackedBuf, - const float* FpData, - size_t N, - size_t K, - size_t ldb - ) -{ - switch (QType) { - case BlkQ4Sym: - return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); - case BlkQ4Sym64: - return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); - case BlkQ4Sym128: - return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); - default: - return MlasQ4GemmPackBImpl(PackedBuf, FpData, N, K, ldb); - } -} - -template -MLAS_FORCEINLINE -void -MlasQ4GemmUnPackBImpl(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb) -{ - const auto* src = reinterpret_cast(PackedBuf); - for (size_t n = 0; n < N; n++) { - for (size_t k = 0; k < K; k += T::BlkLen) { - size_t CountK = std::min(K - k, T::BlkLen); - - float* dest = FpData + ldb * k + n; - const float scale = MlasQ4BlkScale(src); - const uint8_t* data = MlasQ4BlkData(src); - - for (size_t kk = 0; kk < CountK; kk += 32) { - size_t kklen = std::min((size_t)32, CountK - kk); - for (size_t l = 0; l < 16; l++) { - const uint8_t vi = data[l]; - - if (l < kklen) { - const int vi0 = (vi & 0x0F) - 8; - const float v0 = vi0 * scale; - dest[ldb * (kk + l)] = v0; - } - - const size_t l1 = l + 16; - if (l1 < kklen) { - const int vi1 = (vi >> 4) - 8; - const float v1 = vi1 * scale; - dest[ldb * (kk + l1)] = v1; - } - } - data += 16; - } - src += T::BlobSize; - } - } -} - -template<> -MLAS_FORCEINLINE -void -MlasQ4GemmUnPackBImpl( - float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb) -{ - const auto* src = reinterpret_cast(PackedBuf); - for (size_t n = 0; n < N; n++) { - for (size_t k = 0; k < K; k += MLAS_Q4TYPE_BLK1::BlkLen) { - size_t CountK = std::min(K - k, MLAS_Q4TYPE_BLK1::BlkLen); - - float* dest = FpData + ldb * k + n; - const float s = MlasQ4BlkScale(src); - const uint8_t z = MlasQ4BlkZeroPoint(src); - const uint8_t* pp = MlasQ4BlkData(src); - - for (size_t kk = 0; kk < CountK; kk += 32) { - size_t kklen = std::min((size_t)32, CountK - kk); - for (size_t l = 0; l < 16; l++) { - const uint8_t vi = pp[l]; - - if (l < kklen) { - const int8_t vi0 = vi & 0x0F; - const float v0 = (vi0 - z) * s; - dest[ldb * (kk + l)] = v0; - } - - size_t l1 = l + 16; - if (l1 < kklen) { - const int8_t vi1 = vi >> 4; - const float v1 = (vi1 - z) * s; - dest[ldb * (kk + l1)] = v1; - } - } - pp += 16; - } - src += MLAS_Q4TYPE_BLK1::BlobSize; - } - } -} - -void -MLASCALL -MlasQ4GemmUnPackB( - MLAS_BLK_QUANT_TYPE QType, - float* FpData, - const void* PackedBuf, - size_t N, - size_t K, - size_t ldb - ) -{ - switch (QType) { - case BlkQ4Sym: - return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); - case BlkQ4Sym64: - return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); - case BlkQ4Sym128: - return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); - default: - return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); - } -} - - - -/*************************************************************** - * The quantization format that pack data and quantization - * parameters into separate buffers. 
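A sketch of the min/max-to-(scale, zero point) rectification used by this separate-buffer format for unsigned 4-bit quantization: the observed range is widened to include zero, the scale spans the full 4-bit range, and the zero point is clamped and rounded into [0, 15]. The free-standing function name is illustrative.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Derives an asymmetric 4-bit scale and zero point from a block's min/max.
void Range2ScaleZp4Bit(float min, float max, float& scale, uint8_t& zp)
{
    min = std::min(min, 0.0f);              // make sure zero is exactly representable
    max = std::max(max, 0.0f);

    scale = (max - min) / 15.0f;            // full range of a 4-bit code
    float zero_point_fp = (scale != 0.0f) ? 0.0f - min / scale : min;
    zero_point_fp = std::min(15.0f, std::max(0.0f, zero_point_fp));
    zp = static_cast<uint8_t>(std::roundf(zero_point_fp));
}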
- */ - - -template < - int Row_, ///< rows of a matrix - int Column_ ///< columns of a matrix - > -struct Shape2D { - static int const kRow = Row_; ///< rows of a matrix - static int const kColumn = Column_; ///< columns of a matrix - static int const kCount = Row_ * Column_; ///< total number of elements in a matrix -}; - - -template -struct BitsTraits { - static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); - - static constexpr int kBits = qbits; - static constexpr int kMax = signed_quant ? (1 << (qbits -1)) - 1 : (1 << qbits) - 1; - static constexpr int kMid = signed_quant ? 0 : (1 << (qbits - 1)); - static constexpr int kMin = signed_quant ? -(1 << (qbits - 1)) : 0; - static constexpr float kMaxFp = static_cast(kMax); - static constexpr float kMinFp = static_cast(kMin); - static constexpr float fullRange = kMaxFp - kMinFp; - static constexpr float halfRange = static_cast(kMid - kMin); - - // number of qbit elements to pack into whole bytes - static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; - static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); -}; - - -/** - * @brief Rectify min/max from a set of weights, and convert to scale and zero point - * for quantization. - * @tparam ScaleT type of scale, usually floating point of various bits - * @tparam qbits number of int bits used for zero point value - * @tparam signed_quant output quantized type is signed - * @param[in] min - * @param[in] max - * @param[out] scale - * @param[out] zp - */ -template -MLAS_FORCEINLINE -void -range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) -{ - min = std::min(min, 0.0f); - max = std::max(max, 0.0f); - - float scale_f = (max - min) / BitsTraits::fullRange; - - float zero_point_fp = min; - if (scale_f != 0.0f) { - zero_point_fp = BitsTraits::kMinFp - min / scale_f; - } - - if (zero_point_fp < BitsTraits::kMinFp) { - zp = static_cast(BitsTraits::kMin); - } else if (zero_point_fp > BitsTraits::kMaxFp) { - zp = static_cast(BitsTraits::kMax); - } else { - zp = (uint8_t)roundf(zero_point_fp); - } - scale = ScaleT(scale_f); -} - -/** - * @brief Rectify min/max from a set of symmetric weights, and convert - * to scale for quantization. - */ -template -MLAS_FORCEINLINE -void -range2scale(float min, float max, ScaleT& scale) -{ - max = fabsf(max) > fabsf(min) ? max : min; - // !!Note: in the quantized space, abs of min -8 > abs of max 7. - // Therefore map the larger half FP space to [-8, 0]. - // Minus sign achieves this purpose. - scale = ScaleT(-max / BitsTraits::halfRange); -}; - - -/** - * @brief Blockwise quantization methods - * @tparam ElementT source data type, e.g. 
fp32/fp16 - * @tparam block_size number of elemenets quantized together - * @tparam qbits number of bits in each quantized element - * @tparam Columnwise true: elements in a block come from one single column - * false: elements in a block come from one single row - */ -template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -struct BlockwiseQuantizer { - // To support other qbits, need to add bit packing code for - // storing to dst and zero points - static_assert(qbits == 4, "Only 4b block quantization is supported!"); - - using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; - using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; - - static - MLAS_FORCEINLINE - void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) - { - meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; - meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; - } - - static - MLAS_FORCEINLINE - void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) { - int meta_rows; - int meta_cols; - quantizeMetaShape(rows, columns, meta_rows, meta_cols); - - // quantized matrix is stored in column major, packed by column - q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8; - q_cols = meta_cols * QuantBlk::kColumn; - } - - static MLAS_FORCEINLINE void quantizedBufferSizes( - int rows, int columns, size_t& data_bytes, size_t& scale_num_elements, size_t* zero_point_bytes - ) - { - int meta_rows, meta_cols; - quantizeMetaShape(rows, columns, meta_rows, meta_cols); - int q_rows, q_cols; - quantizedShape(rows, columns, q_rows, q_cols); - - data_bytes = q_rows * q_cols; - scale_num_elements = meta_rows * meta_cols; - - if (zero_point_bytes) { - // this works for qbits == 4 but may need to be updated for other qbits values - *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; - } - } - - /** - * @brief Quantized a Matrix shape [rows, columns], resulting quantized - * and packed data are stored in column major (transposed) - * @param[out] dst pointer to the quantized weights, column major: [columns, rows] - * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] - * @param[out] zero_points pointer to the zero points, same shape as scale - * @param[in] src pointer to the source matrix, row major: [rows, columns] - * @param rows - * @param columns - * @param leadingDimension stride of the source matrix, i.e. 
distance from one row to the next - */ - static void quantizeAndTranspose( - uint8_t* dst, - ElementT* scales, - uint8_t* zero_points, - const ElementT* src, - int32_t rows, - int32_t columns, - int32_t leadingDimension, - MLAS_THREADPOOL* thread_pool) - { - // Thread partitioning - const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; - const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; - const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; - - const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; - - int q_rows, q_cols; - quantizedShape(rows, columns, q_rows, q_cols); - - MlasTryBatchParallel( - thread_pool, total_thrd_blks, - [&](ptrdiff_t block_idx) { - uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); - - const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); - const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); - - const int32_t r = r_blk_idx * ThreadBlk::kRow; - const int32_t c = c_blk_idx * ThreadBlk::kColumn; - - const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); - const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); - - const int meta_row = r / QuantBlk::kRow; - const int meta_col = c / QuantBlk::kColumn; - - // compute scale and zero point - for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { - - // scan a single block to extract range [min, max] - float min = std::numeric_limits::max(); - float max = -min; - const int row_start = r + kpack * QuantBlk::kRow; - const int row_end = std::min(row_start + QuantBlk::kRow, r_end); - for (int i = row_start; i < row_end; ++i) { - for (int j = c; j < c_end; ++j) { - const float v = static_cast(src[i * leadingDimension + j]); - if (v < min) min = v; - if (v > max) max = v; - } - } - - // store scale and zero point at quant parameter matrix position - if (row_start < row_end) { - const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; - if (zero_points == nullptr) { - range2scale(min, max, scales[meta_idx]); - } else { - range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); - } - } - } - - // !! 4b specific code as we need to pack 2 4b numbers into one byte - if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; - zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); - } - - for (int32_t j = c; j < c_end; ++j) { - const int32_t meta_c = j / QuantBlk::kColumn; - for (int32_t i = r; i < r_end; i += 2) { - const int32_t meta_r = i / QuantBlk::kRow; - const float scale = static_cast(scales[meta_c * row_blks + meta_r]); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - const int8_t zp = zp_bytes[meta_r & 1]; - const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; - - const float v0 = static_cast(src[i * leadingDimension + j]); - const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), - 0.0f, BitsTraits::kMaxFp); - - uint8_t vi1 = (uint8_t)zp; - if (i + 1 < r_end) { - float reciprocal_scale1 = reciprocal_scale; - if constexpr (QuantBlk::kRow == 1) { - const float scale1 = - static_cast(scales[meta_c * row_blks + meta_r + 1]); - reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; - } - const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); - vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); - } - - // !! 
4b specific code - dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); - } - } - }); - } - - /** - * @brief Dequantize a column major quantized matrix, and store the result in a column major - * matrix for use in GEMM - * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] - * @param[in] weights pointer to the quantized weights, column major: [columns, rows] - * @param[in] scales pointer to the scales of quantized blocks, column major layout - * @param[in] zero_points pointer to the zero points of quantized blocks, packed column major - * scales - * @param[in] rows - * @param[in] columns - */ - static void dequantize( - ElementT* dst, - const uint8_t* weights, - const ElementT* scales, - const uint8_t* zero_points, - int32_t rows, - int32_t columns, - MLAS_THREADPOOL* thread_pool) - { - // Thread partitioning - const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; - const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; - const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; - - const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; - - int q_rows, q_cols; - quantizedShape(rows, columns, q_rows, q_cols); - - MlasTryBatchParallel( - thread_pool, total_thrd_blks, - [&](ptrdiff_t block_idx) { - int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); - int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); - - int32_t r = r_blk_idx * ThreadBlk::kRow; - int32_t c = c_blk_idx * ThreadBlk::kColumn; - - int32_t r_end = std::min(r + ThreadBlk::kRow, rows); - int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); - - for (int32_t j = c; j < c_end; ++j) { - const int32_t meta_col = j / QuantBlk::kColumn; - - // !! 4b specific code - // the whole loop is 4b specific due to sub 8 bit packing - // and unpacking. We can potentially make this qbits generic - // by wraping the packing/unpacking code like cutlass::Array - for (int32_t i = r; i < r_end; i += 2) { - const int32_t meta_row = i / QuantBlk::kRow; - - const float scale0 = - static_cast(scales[meta_col * row_blks + meta_row]); - - const int zp_pair = - (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); - - const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - - dst[j * rows + i] = static_cast(v0); - if ((i + 1) < r_end) { - float scale1 = scale0; - int zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = - static_cast(scales[meta_col * row_blks + meta_row + 1]); - zp1 = (zp_pair >> 4) & 0xf; - } - const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; - const float v1 = (static_cast(vi1) - zp1) * scale1; - dst[j * rows + (i + 1)] = static_cast(v1); - } - } - } - }); - } -}; - -/** - * @brief Blockwise quantization methods for QDQ format. Input tensor is quantized along column - * or row. Scales and zeros are calculated. Based on qbits, consecutive quantized elements - * in memory are packed together, which means the packing is along the row. Quantized data - * are stored in row major, so the output tensor reserves same shape, in terms of qbits type, - * as the input tensor. - * If has zero points, quantized type is unsigned. Otherwise, quantized type is signed and the - * zero point is 0. - * The transposed outputs are used by MatMulNBits, so quant type becomes uint4 with default - * zp at 8. - * @tparam Tin source data type, e.g. 
fp32/fp16 - * @tparam qbits number of bits in each quantized element - * @tparam signed_quant quantized type is signed - */ -template -struct BlockwiseQDQQuantizer; - -template -struct BlockwiseQDQQuantizer { - static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) - { - return (val >> (idx << 2)) & 0xF; - } - - static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) - { - auto shift = idx << 2; - return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); - } - - template - static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1) - { - if constexpr (add8) { - return ((v0 & 0xF) ^ 8) | (((v1 & 0xF) ^ 8) << 4); - } else { - return (v0 & 0xF) | ((v1 & 0xF) << 4); - } - } - - // If src is row major, then dst is column major. Transpose: - // | src0: low 4 bit | src0: high 4 bit | - // | src1: low 4 bit | src1: high 4 bit | - // --> - // | dst0: low 4 bit | dst1: low 4 bit | - // | dst0: high 4 bit| dst1: high 4 bit | - // If src is column major, then dst is row major. Transpose: - // | src0: low 4 bit | src1: low 4 bit | - // | src0: high 4 bit| src1: high 4 bit | - // --> - // | dst0: low 4 bit | dst0: high 4 bit | - // | dst1: low 4 bit | dst1: high 4 bit | - template - static MLAS_FORCEINLINE void Transpose(uint8_t src0, uint8_t src1, uint8_t& dst0, uint8_t& dst1) - { - if constexpr (add8) { - dst0 = ((src0 & 0xF) ^ 8) | (((src1 & 0xF) ^ 8) << 4); - dst1 = (((src0 & 0xF0) ^ 0x80) >> 4) | ((src1 & 0xF0) ^ 0x80); - } else { - dst0 = (src0 & 0xF) | ((src1 & 0xF) << 4); - dst1 = ((src0 & 0xF0) >> 4) | (src1 & 0xF0); - } - } - - static MLAS_FORCEINLINE uint8_t QuantizeV(Tin src, float reciprocal_scale, uint8_t zero_point) - { - return static_cast( - std::clamp( - static_cast( - std::roundf(static_cast(src) * reciprocal_scale) - ) + static_cast(zero_point), - BitsTraits<4, signed_quant>::kMin, - BitsTraits<4, signed_quant>::kMax - ) - ); - } - - /** - * @brief Quantize a matrix shape [rows, columns] column-wise. Scales and zero points are calculated. - * Quantized data are packed row-wise based on qbits. Quantized data are stored in row major - * so the output tensor reserves the shape, in terms output type. - * @param src the source matrix, row major: [rows * columns] - * @param scales the scales of quantized blocks, row major with shape: - * [ceil(rows/quant_block_size) * columns] - * @param zero_points the zero points of quantized blocks, packed. Same shape as scales in terms - * of output type. In uint8_t, the shape is: - * [ceil(columns * ceil(rows / quant_block_size) * qbits / 8)] - * @param dst the quantized weights, row major: [rows * columns] in terms of output type. - * In uint8_t, the shape is: [ceil(rows * columns * qbits / 8] - * @param rows number of rows in the source matrix - * @param columns number of columns in the source matrix. - * @param quant_block_size number of rows/columns quantized together - * @param thread_pool thread pool for parallel processing - */ - static void QuantizeColumnWise( - const Tin* src, - Tin* scales, - uint8_t* zero_points, - uint8_t* dst, - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool - ) - { - ORT_ENFORCE(zero_points || signed_quant, "Unsigned quant with no zero points is not supported."); - // Must avoid multiple thread write to a single byte, which means the starting index - // of a thread block must be even. To achieve that, we need to customize the thread - // block size based on the parity of columns. 
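        // Two adjacent 4-bit values share one output byte, and with an odd column count
        // those byte boundaries no longer coincide with column boundaries from one row to
        // the next. The unaligned path below therefore partitions the work so that no
        // packed byte is touched by more than one thread.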
- if (columns & 1) { - QuantizeColumnWisePackUnaligned( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); - } else { - QuantizeColumnWisePackAligned( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); - } - } - - - /** - * @brief Transpose quantized tensors, which has been column-wise quantized, for use in MatMulNbits. - * Since both src tensor and dst tensor are packed, it's not needed to consider sign - * during the unpacking/packing in transpose. - * @param src_weights The quantized weights, row major: [rows, columns] in qbits type. - * In uint8_t, size of [ceil(rows * columns * qbits / 8)]. - * @param src_scales [ceil(rows / quant_block_size), columns] - * @param src_zero_points [ceil(rows / quant_block_size), columns] in qbits type. In uint8_t, size of - * [ceil(ceil(rows / quant_block_size) * columns * qbits / 8 )]. - * @param dst_weights the transposed quantized weights, column major. In uint8_t, the shape is - * [columns, ceil(rows / quant_block_size), ceil(quant_block_size * qbits / 8)] - * @param dst_scales [columns, ceil(rows / quant_block_size)] - * @param dst_zero_points [columns, ceil(ceil(rows / quant_block_size) * qbits / 8)] in uint8_t. - * @param rows number of src rows in qbits type. - * @param columns number of src columns in qbits type. - * @param quant_block_size number of elements quantized together - * @param thread_pool thread pool for parallel processing - */ - static void TransposeColumnWiseQuantized( - const uint8_t* src_weights, - const Tin* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - Tin* dst_scales, - uint8_t* dst_zero_points, - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool - ) - { - ORT_ENFORCE( - src_zero_points || signed_quant || dst_zero_points, - "Unsigned quant types without zero points must allocate zero points with value 0." - ); - // Must avoid multiple thread write to a single byte, which means the starting index - // of a thread block must be even. To achieve that, we need to customize the thread - // block size based on the parity of columns. - if (columns & 1) { - TransposeColumnWiseQuantizedPackUnaligned( - src_weights, src_scales, src_zero_points, - dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool - ); - } else { - TransposeColumnWiseQuantizedPackAligned( - src_weights, src_scales, src_zero_points, - dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool - ); - } - } - -private: - static void QuantizeColumnWisePackAligned( - const Tin* src, - Tin* scales, - uint8_t* zero_points, - uint8_t* dst, - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool - ) - { - ORT_ENFORCE(columns % 2 == 0, "Columns must be multiple of 2."); - // Thread block is [quant_block_size, thread_blk_size]. thread_blk_size % 2 == 0. 
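        // Column blocks start at multiples of thread_blk_size (an even offset), so every
        // packed byte, holding one pair of adjacent columns, is written by exactly one thread.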
- constexpr int32_t thread_blk_size = 128; - const auto num_row_thread_blk = (rows + quant_block_size - 1) / quant_block_size; - const auto num_col_thread_blk = (columns + thread_blk_size - 1) / thread_blk_size; - const auto num_thread_blk = num_row_thread_blk * num_col_thread_blk; - constexpr auto minf = std::numeric_limits::lowest(); - constexpr auto maxf = std::numeric_limits::max(); - - MlasTryBatchParallel( - thread_pool, static_cast(num_thread_blk), - [&](ptrdiff_t thread_blk_idx) { - // !!warning!!: buffering the whole thread block - constexpr int32_t buffer_size = 128; - ORT_ENFORCE(buffer_size == thread_blk_size, "buffer size must be equal to thread block size."); - float reciprocal_scale_t[buffer_size]; - uint8_t zp_t[buffer_size]; - float vmin_t[buffer_size]; - float vmax_t[buffer_size]; - - const int32_t row_thread_blk_idx = static_cast(thread_blk_idx / num_col_thread_blk); - const int32_t col_thread_blk_idx = static_cast(thread_blk_idx % num_col_thread_blk); - const int32_t row_idx = row_thread_blk_idx * quant_block_size; - const int32_t col_idx = col_thread_blk_idx * buffer_size; - const int32_t row_size = std::min(quant_block_size, rows - row_idx); - const int32_t col_size = std::min(buffer_size, columns - col_idx); - // input_idx, scale_idx, col_size are aligned to 2 - auto input_idx = row_idx * columns + col_idx; - auto scale_idx = row_thread_blk_idx * columns + col_idx; - - Tin scale0_tt, scale1_tt; - uint8_t v0_tt, v1_tt; - - std::fill_n(vmin_t, buffer_size, maxf); - std::fill_n(vmax_t, buffer_size, minf); - - // calculate min/max - for (int32_t j = 0, input_idx_t = input_idx; j < row_size; ++j, input_idx_t += columns) { - // TODO(fajin): use SIMD - for (int32_t i = 0; i < col_size; i += 2) { - auto v0 = static_cast(src[input_idx_t + i]); - auto v1 = static_cast(src[input_idx_t + i + 1]); - vmin_t[i] = std::min(vmin_t[i], v0); - vmax_t[i] = std::max(vmax_t[i], v0); - vmin_t[i + 1] = std::min(vmin_t[i + 1], v1); - vmax_t[i + 1] = std::max(vmax_t[i + 1], v1); - } - } - - // calculate scale and zero point, and store - for (int32_t i = 0; i < col_size; i += 2) { - v0_tt = v1_tt = BitsTraits<4, signed_quant>::kMid; - - if (zero_points) { - range2scalezp(vmin_t[i], vmax_t[i], scale0_tt, v0_tt); - range2scalezp(vmin_t[i + 1], vmax_t[i + 1], scale1_tt, v1_tt); - zero_points[(scale_idx + i) >> 1] = Pack(v0_tt, v1_tt); - } else { - range2scale(vmin_t[i], vmax_t[i], scale0_tt); - range2scale(vmin_t[i + 1], vmax_t[i + 1], scale1_tt); - } - - scales[scale_idx + i] = scale0_tt; - scales[scale_idx + i + 1] = scale1_tt; - - float scalef0 = static_cast(scale0_tt); - reciprocal_scale_t[i] = scalef0 ? 1.0f / scalef0 : 0.0f; - zp_t[i] = v0_tt; - - float scalef1 = static_cast(scale1_tt); - reciprocal_scale_t[i + 1] = scalef1 ? 
1.0f / scalef1 : 0.0f; - zp_t[i + 1] = v1_tt; - } - - // quantize and pack - for (int32_t j = 0, input_idx_t = input_idx; j < row_size; ++j, input_idx_t += columns) { - // TODO(fajin): use SIMD - for (int32_t i = 0; i < col_size; i += 2) { - v0_tt = QuantizeV(src[input_idx_t + i], reciprocal_scale_t[i], zp_t[i]); - v1_tt = QuantizeV(src[input_idx_t + i + 1], reciprocal_scale_t[i + 1], zp_t[i + 1]); - dst[(input_idx_t + i) >> 1] = Pack(v0_tt, v1_tt); - } - } - } - ); - } - - static void QuantizeColumnWisePackUnaligned( - const Tin* src, - Tin* scales, - uint8_t* zero_points, - uint8_t* dst, - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool - ) - { - // Thread block is [quant_block_size * 2, columns], so the packed bytes do not cross threads. - constexpr auto minf = std::numeric_limits::lowest(); - constexpr auto maxf = std::numeric_limits::max(); - auto row_thread_blk_size = quant_block_size * 2; - auto num_row_thread_blk = (rows + row_thread_blk_size - 1) / (row_thread_blk_size); - - MlasTryBatchParallel( - thread_pool, static_cast(num_row_thread_blk), - [&](ptrdiff_t thread_blk_idx) { - constexpr int32_t buffer_size = 128; - float reciprocal_scale_t[buffer_size]; - uint8_t zp_t[buffer_size]; - float vmin_t[buffer_size]; - float vmax_t[buffer_size]; - - auto row_thread_blk_idx = static_cast(thread_blk_idx); - int32_t row_idx = row_thread_blk_idx * row_thread_blk_size; - int32_t row_idx_end = std::min(row_thread_blk_size + row_idx, rows); - auto input_idx = row_idx * columns; - auto scale_idx = row_thread_blk_idx * 2 * columns; - Tin scale0_tt, scale1_tt; - uint8_t v0_tt, v1_tt; - - for (; row_idx < row_idx_end; row_idx += quant_block_size) { - // per quant block row - auto quant_row_size = std::min(quant_block_size, row_idx_end - row_idx); - auto input_buffer_idx = input_idx; - auto scale_buffer_idx = scale_idx; - for (int32_t buffer_idx = 0; buffer_idx < columns; buffer_idx += buffer_size) { - // per buffer column - auto buffer_col_size = std::min(buffer_size, columns - buffer_idx); - - std::fill_n(vmin_t, buffer_size, maxf); - std::fill_n(vmax_t, buffer_size, minf); - // calculate min/max of [quant block, buffer] - auto input_idx_t = input_buffer_idx; - for (int32_t j = 0; j < quant_row_size; ++j, input_idx_t += columns) { - // TODO(fajin): use SIMD - for (int32_t i = 0; i < buffer_col_size; ++i) { - auto v = static_cast(src[input_idx_t + i]); - vmin_t[i] = std::min(vmin_t[i], v); - vmax_t[i] = std::max(vmax_t[i], v); - } - } - - // calculate scale and zero point - auto scale_buffer_idx_end = scale_buffer_idx + buffer_col_size; - int32_t col_idx = 0; - // leading unailgned zero points - if (scale_buffer_idx & 1) { - v0_tt = BitsTraits<4, signed_quant>::kMid; - if (zero_points) { - range2scalezp(vmin_t[0], vmax_t[0], scale0_tt, v0_tt); - zero_points[scale_buffer_idx >> 1] = SetElem( - v0_tt, 1, zero_points[scale_buffer_idx >> 1] - ); - } else { - range2scale(vmin_t[0], vmax_t[0], scale0_tt); - } - - scales[scale_buffer_idx] = scale0_tt; - - float scalef = static_cast(scale0_tt); - reciprocal_scale_t[0] = scalef ? 
1.0f / scalef : 0.0f; - zp_t[0] = v0_tt; - - ++col_idx; - ++scale_buffer_idx; - } - // aligned zero points - for (; scale_buffer_idx < scale_buffer_idx_end - 1; col_idx += 2, scale_buffer_idx += 2) { - v0_tt = v1_tt = BitsTraits<4, signed_quant>::kMid; - if (zero_points) { - range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); - range2scalezp( - vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt, v1_tt - ); - zero_points[scale_buffer_idx >> 1] = Pack(v0_tt, v1_tt); - } else { - range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); - range2scale(vmin_t[col_idx + 1], vmax_t[col_idx + 1], scale1_tt); - } - - scales[scale_buffer_idx] = scale0_tt; - scales[scale_buffer_idx + 1] = scale1_tt; - - float scalef0 = static_cast(scale0_tt); - reciprocal_scale_t[col_idx] = scalef0 ? 1.0f / scalef0 : 0.0f; - zp_t[col_idx] = v0_tt; - - float scalef1 = static_cast(scale1_tt); - reciprocal_scale_t[col_idx + 1] = scalef1 ? 1.0f / scalef1 : 0.0f; - zp_t[col_idx + 1] = v1_tt; - } - // tailing unaligned elements - if (scale_buffer_idx < scale_buffer_idx_end) { - v0_tt = BitsTraits<4, signed_quant>::kMid; - if (zero_points) { - range2scalezp(vmin_t[col_idx], vmax_t[col_idx], scale0_tt, v0_tt); - zero_points[scale_buffer_idx >> 1] = SetElem( - v0_tt, 0, zero_points[scale_buffer_idx >> 1] - ); - } else { - range2scale(vmin_t[col_idx], vmax_t[col_idx], scale0_tt); - } - - scales[scale_buffer_idx] = scale0_tt; - - float scalef = static_cast(scale0_tt); - reciprocal_scale_t[col_idx] = scalef ? 1.0f / scalef : 0.0f; - zp_t[col_idx] = v0_tt; - - ++scale_buffer_idx; - } - - // quantize and pack - input_idx_t = input_buffer_idx; - for (int32_t j = 0; j < quant_row_size; ++j, input_idx_t += columns) { - auto input_idx_t_start = input_idx_t; - auto input_idx_t_end = input_idx_t + buffer_col_size; - col_idx = 0; - // leading unaligned output - if (input_idx_t_start & 1) { - v1_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); - dst[input_idx_t_start >> 1] = SetElem(v1_tt, 1, dst[input_idx_t_start >> 1]); - - ++col_idx; - ++input_idx_t_start; - } - // aligned output - // TODO(fajin): use SIMD - for (; input_idx_t_start < input_idx_t_end - 1; col_idx += 2, input_idx_t_start += 2) { - v0_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); - v1_tt = QuantizeV( - src[input_idx_t_start + 1], reciprocal_scale_t[col_idx + 1], zp_t[col_idx + 1] - ); - - dst[input_idx_t_start >> 1] = Pack(v0_tt, v1_tt); - } - // tailing unaligned output - if (input_idx_t_start < input_idx_t_end) { - v0_tt = QuantizeV(src[input_idx_t_start], reciprocal_scale_t[col_idx], zp_t[col_idx]); - dst[input_idx_t_start >> 1] = SetElem(v0_tt, 0, dst[input_idx_t_start >> 1]); - } - } - - input_buffer_idx += buffer_size; - } - - input_idx += quant_block_size * columns; - scale_idx += columns; - } - } - ); - } - - static void TransposeColumnWiseQuantizedPackAligned( - const uint8_t* src_weights, // [rows, columns / 2] - const Tin* src_scales, // [ceil(rows / quant_block_size), columns] - const uint8_t* src_zero_points, // [ceil(rows / quant_block_size), columns / 2] - uint8_t* dst_weights, // [columns, ceil(rows / quant_block_size), ceil(quant_block_size / 2)] - Tin* dst_scales, // [columns, ceil(rows / quant_block_size)] - uint8_t* dst_zero_points, // [columns, ceil(ceil(rows / quant_block_size) / 2)] - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool - ) - { - ORT_ENFORCE(columns % 2 == 0, "Columns must be multiple of 2"); - - auto 
row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; - auto dst_bytes_per_quant_blk = (quant_block_size * 4 + 7) / 8; - // number of rows in transposed dst - auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; - auto packed_col_size = columns / 2; - - // weight transpose thread block is [dst_bytes_per_quant_blk, 2] on dst_Transpose. - // Map to src it is [quant_block_size, 1]. Both in uint8_t. - auto num_thread_blk = row_quant_blk_num * packed_col_size; - MlasTryBatchParallel( - thread_pool, static_cast(num_thread_blk), - [&](ptrdiff_t thread_blk_idx) { - uint8_t src0_t, src1_t; - uint8_t dst0_t, dst1_t; - - auto row_thread_blk_idx = static_cast(thread_blk_idx / packed_col_size); - auto col_thread_blk_idx = static_cast(thread_blk_idx % packed_col_size); - - auto dstT_row_idx = row_thread_blk_idx * dst_bytes_per_quant_blk; - auto dstT_col_idx = col_thread_blk_idx * 2; - auto dst_idx = dstT_col_idx * dstT_num_row + dstT_row_idx; - - auto src_row_idx = row_thread_blk_idx * quant_block_size; - auto src_row_end_idx = std::min(src_row_idx + quant_block_size, rows); - auto src_col_idx = col_thread_blk_idx; - auto src_idx = src_row_idx * packed_col_size + src_col_idx; - auto src_end_idx = src_row_end_idx * packed_col_size + src_col_idx; - - for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { - src0_t = src_weights[src_idx]; - src1_t = src_weights[src_idx + packed_col_size]; - src_idx += packed_col_size + packed_col_size; - Transpose(src0_t, src1_t, dst0_t, dst1_t); - dst_weights[dst_idx] = dst0_t; - dst_weights[dst_idx + dstT_num_row] = dst1_t; - } - - if (src_idx < src_end_idx) { - src0_t = src_weights[src_idx]; - src1_t = 0; - Transpose(src0_t, src1_t, dst0_t, dst1_t); - dst_weights[dst_idx] = dst0_t; - dst_weights[dst_idx + dstT_num_row] = dst1_t; - } - } - ); - - // Transpose scales. Thread block is [row_quant_blk_num, 1] on dst_Transpose. - MlasTryBatchParallel( - thread_pool, static_cast(columns), - [&](ptrdiff_t thread_blk_idx) { - auto col_thread_blk_idx = static_cast(thread_blk_idx); - auto src_idx = col_thread_blk_idx; - auto dst_idx = col_thread_blk_idx * row_quant_blk_num; - for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { - dst_scales[dst_idx] = src_scales[src_idx]; - } - } - ); - - if (src_zero_points) { - // Transpose zero points. Thread block is [ceil(row_quant_blk_num / 2), 2] - // on dst_Transpose. Map to src it is [row_quant_blk_num, 1]. Both in uint8_t. 
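            // Each source byte packs the zero points of two adjacent columns into its low
            // and high nibbles. Transpose() takes two such bytes from consecutive quant-block
            // rows and regroups the nibbles so that each destination byte holds two
            // consecutive zero points of a single column.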
- auto dst_zp_row_num = (row_quant_blk_num + 1) / 2; - MlasTryBatchParallel( - thread_pool, static_cast(packed_col_size), - [&](ptrdiff_t thread_blk_idx) { - uint8_t src0_t, src1_t; - uint8_t dst0_t, dst1_t; - - auto col_thread_blk_idx = static_cast(thread_blk_idx); - auto src_idx = col_thread_blk_idx; - auto src_end_idx = row_quant_blk_num * packed_col_size + col_thread_blk_idx; - auto dst_idx = col_thread_blk_idx * 2 * dst_zp_row_num; - - for (; src_idx < src_end_idx - packed_col_size; ++dst_idx) { - src0_t = src_zero_points[src_idx]; - src1_t = src_zero_points[src_idx + packed_col_size]; - Transpose(src0_t, src1_t, dst0_t, dst1_t); - dst_zero_points[dst_idx] = dst0_t; - dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; - src_idx += packed_col_size + packed_col_size; - } - - if (src_idx < src_end_idx) { - src0_t = src_zero_points[src_idx]; - src1_t = 0; - Transpose(src0_t, src1_t, dst0_t, dst1_t); - dst_zero_points[dst_idx] = dst0_t; - dst_zero_points[dst_idx + dst_zp_row_num] = dst1_t; - } - } - ); - } - } - - static void TransposeColumnWiseQuantizedPackUnaligned( - const uint8_t* src_weights, // size of [ceil(rows * columns / 2)] - const Tin* src_scales, // [ceil(rows / quant_block_size), columns] - const uint8_t* src_zero_points, // size of [ceil(ceil(rows / quant_block_size) * columns / 2)] - uint8_t *dst_weights, // [columns, ceil(rows / quant_block_size), ceil(quant_block_size / 2)] - Tin* dst_scales, // [columns, ceil(rows / quant_block_size)] - uint8_t* dst_zero_points, // [columns, ceil(ceil(rows / quant_block_size) / 2)] - int32_t rows, - int32_t columns, - int32_t quant_block_size, - MLAS_THREADPOOL* thread_pool) - { - auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size; - auto dst_bytes_per_quant_blk = (quant_block_size * 4 + 7) / 8; - // number of rows in transposed dst - auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk; - - // weight transpose thread block is [dst_bytes_per_quant_blk, 1] on dst_Transpose in uint8_t. - // Map to src it is [quant_block_size, 1] in int4. - auto num_thread_blk = row_quant_blk_num * columns; - MlasTryBatchParallel( - thread_pool, static_cast(num_thread_blk), - [&](ptrdiff_t thread_blk_idx) { - uint8_t src0_t, src1_t; - - auto row_thread_blk_idx = static_cast(thread_blk_idx / columns); - auto col_thread_blk_idx = static_cast(thread_blk_idx % columns); - - auto dstT_row_idx = row_thread_blk_idx * dst_bytes_per_quant_blk; - auto dst_idx = col_thread_blk_idx * dstT_num_row + dstT_row_idx; - - auto src_row_idx = row_thread_blk_idx * quant_block_size; - auto src_row_end_idx = std::min(src_row_idx + quant_block_size, rows); - auto src_idx = src_row_idx * columns + col_thread_blk_idx; - auto src_end_idx = src_row_end_idx * columns + col_thread_blk_idx; - - for (; src_idx < src_end_idx - columns; ++dst_idx) { - src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); - src1_t = GetElem(src_weights[(src_idx + columns) >> 1], (src_idx + columns) & 1); - dst_weights[dst_idx] = Pack(src0_t, src1_t); - src_idx += columns + columns; - } - - if (src_idx < src_end_idx) { - src0_t = GetElem(src_weights[src_idx >> 1], src_idx & 1); - dst_weights[dst_idx] = Pack(src0_t, 0); - } - } - ); - - // Transpose scales. Thread block is [row_quant_blk_num, 1] on dst_Transpose. 
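        // Scales are whole elements rather than packed nibbles, so this is a plain strided
        // copy from the row-major [row_quant_blk_num, columns] source layout.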
- MlasTryBatchParallel( - thread_pool, static_cast(columns), - [&](ptrdiff_t thread_blk_idx) { - auto col_thread_blk_idx = static_cast(thread_blk_idx); - auto src_idx = col_thread_blk_idx; - auto dst_idx = col_thread_blk_idx * row_quant_blk_num; - for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) { - dst_scales[dst_idx] = src_scales[src_idx]; - } - } - ); - - if (src_zero_points) { - // Transpose zero points. Thread block is [ceil(row_quant_blk_num / 2), 1] on dst_Transpose in uint8_t. - // Map to src it is [row_quant_blk_num, 1] in int4. - auto dst_zp_row_num = (row_quant_blk_num + 1) / 2; - MlasTryBatchParallel( - thread_pool, static_cast(columns), - [&](ptrdiff_t thread_blk_idx) { - uint8_t src0_t, src1_t; - - auto col_thread_blk_idx = static_cast(thread_blk_idx); - auto src_idx = col_thread_blk_idx; - auto src_end_idx = row_quant_blk_num * columns + col_thread_blk_idx; - auto dst_idx = col_thread_blk_idx * dst_zp_row_num; - - for (; src_idx < src_end_idx - columns; ++dst_idx) { - src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); - src1_t = GetElem(src_zero_points[(src_idx + columns) >> 1], (src_idx + columns) & 1); - dst_zero_points[dst_idx] = Pack(src0_t, src1_t); - src_idx += columns + columns; - } - - if (src_idx < src_end_idx) { - src0_t = GetElem(src_zero_points[src_idx >> 1], src_idx & 1); - dst_zero_points[dst_idx] = Pack(src0_t, 0); - } - } - ); - } - } -}; - -template -void -MlasBlockwiseQuantMetaShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& meta_rows, - int& meta_cols - ) -{ - switch (block_size) { - case 16: { - if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); - } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); - } - break; - } - case 32: { - if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); - } else { - BlockwiseQuantizer::quantizeMetaShape( - rows, columns, meta_rows, meta_cols); - } - break; - } - case 64: { - if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, - meta_cols); - } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, - meta_cols); - } - break; - } - case 128: { - if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, - meta_cols); - } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, - meta_cols); - } - break; - } - case 256: { - if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, - meta_cols); - } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, - meta_cols); - } - break; - } - default: - meta_rows = 0; - meta_cols = 0; - break; - } -} - - - -template -void -MlasBlockwiseQuantizedShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& q_rows, - int& q_cols - ) -{ - switch (block_size) { - case 16: { - if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } - break; - } - case 32: { - if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } else { - BlockwiseQuantizer::quantizedShape( - rows, columns, q_rows, q_cols); - } - break; - } - case 64: { - if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } - break; - } - 
case 128: { - if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } - break; - } - case 256: { - if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); - } - break; - } - default: - q_rows = 0; - q_cols = 0; - break; - } -} - - -template -void -MlasBlockwiseQuantMetaShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& meta_rows, - int& meta_cols - ); - -template -void -MlasBlockwiseQuantMetaShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& meta_rows, - int& meta_cols - ); - -template -void -MlasBlockwiseQuantizedShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& q_rows, - int& q_cols - ); - -template -void -MlasBlockwiseQuantizedShape( - int block_size, - bool columnwise, - int rows, - int columns, - int& q_rows, - int& q_cols - ); - -void MLASCALL -MlasBlockwiseQuantizedBufferSizes( - int qbits, - int block_size, - bool columnwise, - int rows, - int columns, - size_t& q_data_size_in_bytes, - size_t& q_scale_num_elements, - size_t* q_zero_point_size_in_bytes -) -{ - q_data_size_in_bytes = q_scale_num_elements = 0; - if (q_zero_point_size_in_bytes) { - *q_zero_point_size_in_bytes = 0; - } - - if (qbits == 4) { - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 32: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 64: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 128: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 256: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - default: - // Only block size 16, 32, 64, 128, 256 are supported. 
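            // The output sizes were zero-initialized above, so an unsupported block size is
            // reported to the caller as zero-sized buffers.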
- break; - } - } -} - - -template -void -MlasQuantizeBlockwise( - uint8_t* dst, - T* scales, - uint8_t* zero_points, - const T* src, - int block_size, - bool columnwise, - int rows, - int columns, - int leading_dimension, - MLAS_THREADPOOL* thread_pool - ) -{ - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } else { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } - break; - - case 32: - if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } else { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } - break; - - case 64: - if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } else { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } - break; - - case 128: - if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } else { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } - break; - - case 256: - if (columnwise) { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } else { - BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); - } - break; - - default: - // Only block size 16, 32, 64, 128, 256 are supported. 
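        // A minimal usage sketch (hypothetical caller, not part of this file), assuming
        // <vector>, the public MLAS headers, and the <float, 4> instantiation declared
        // further below: size the outputs with MlasBlockwiseQuantizedBufferSizes, then
        // quantize a row-major fp32 matrix column-wise with a block size of 32.
        //
        //   void ExampleQuantize(const float* weights, int rows, int columns,
        //                        MLAS_THREADPOOL* pool)
        //   {
        //       size_t data_bytes = 0, scale_count = 0, zp_bytes = 0;
        //       MlasBlockwiseQuantizedBufferSizes(4, /*block_size*/ 32, /*columnwise*/ true,
        //                                         rows, columns, data_bytes, scale_count,
        //                                         &zp_bytes);
        //       std::vector<uint8_t> packed(data_bytes);
        //       std::vector<float> scales(scale_count);
        //       std::vector<uint8_t> zero_points(zp_bytes);
        //       MlasQuantizeBlockwise<float, 4>(packed.data(), scales.data(),
        //                                       zero_points.data(), weights,
        //                                       /*block_size*/ 32, /*columnwise*/ true,
        //                                       rows, columns,
        //                                       /*leading_dimension*/ columns, pool);
        //   }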
- break; - } -} - -template -void -MlasQuantizeBlockwise( - uint8_t* dst, - float* scales, - uint8_t* zero_points, - const float* src, - int block_size, - bool columnwise, - int rows, - int columns, - int leading_dimension, - MLAS_THREADPOOL* thread_pool - ); - -template -void -MlasQuantizeBlockwise( - uint8_t* dst, - MLAS_FP16* scales, - uint8_t* zero_points, - const MLAS_FP16* src, - int block_size, - bool columnwise, - int rows, - int columns, - int leading_dimension, - MLAS_THREADPOOL* thread_pool - ); - - -template -void -MlasDequantizeBlockwise( - T* dst, - const uint8_t* src, - const T* scales, - const uint8_t* zero_points, - int block_size, - bool columnwise, - int rows, - int columns, - MLAS_THREADPOOL* thread_pool - ) -{ - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } - break; - case 32: - if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } - break; - case 64: - if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } - break; - case 128: - if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, - rows, columns, thread_pool); - } - break; - case 256: - if (columnwise) { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, - columns, thread_pool); - } else { - BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, - rows, columns, thread_pool); - } - break; - default: - // Only block size 16, 32, 64, 128, 256 are supported. 
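        // As with quantization, an unsupported block size is a no-op: dst is left untouched
        // rather than partially written.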
- break; - } -} - -template void -MlasDequantizeBlockwise( - float* dst, - const uint8_t* src, - const float* scales, - const uint8_t* zero_points, - int block_size, - bool columnwise, - int rows, - int columns, - MLAS_THREADPOOL* thread_pool -); - -template -bool -MlasQDQQuantizeBlockwise( - const Tin* src, - Tin* scales, - uint8_t* zero_points, - uint8_t* dst, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -) -{ - if (columnwise) { - if (zero_points) { - BlockwiseQDQQuantizer::QuantizeColumnWise( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); - return false; - } else { - BlockwiseQDQQuantizer::QuantizeColumnWise( - src, scales, zero_points, dst, rows, columns, quant_block_size, thread_pool - ); - return true; - } - } else { - ORT_THROW("Row-wise MlasQDQQuantizeBlockwise is not implemented"); - } -} - -template bool -MlasQDQQuantizeBlockwise( - const float* src, - float* scales, - uint8_t* zero_points, - uint8_t* dst, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); - -template bool -MlasQDQQuantizeBlockwise( - const MLAS_FP16* src, - MLAS_FP16* scales, - uint8_t* zero_points, - uint8_t* dst, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); - -template -void -MlasQDQTransposeBlockwiseQuantized( - const uint8_t* src_weights, - const Tin* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - Tin* dst_scales, - uint8_t* dst_zero_points, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -) -{ - if (columnwise) { - BlockwiseQDQQuantizer::TransposeColumnWiseQuantized( - src_weights, src_scales, src_zero_points, dst_weights, dst_scales, dst_zero_points, - rows, columns, quant_block_size, thread_pool - ); - } else { - ORT_THROW("Row-wise MlasQDQTransposeBlockwiseQuantized is not implemented"); - } -} - -template void -MlasQDQTransposeBlockwiseQuantized( - const uint8_t* src_weights, - const float* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - float* dst_scales, - uint8_t* dst_zero_points, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); - -template void -MlasQDQTransposeBlockwiseQuantized( - const uint8_t* src_weights, - const float* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - float* dst_scales, - uint8_t* dst_zero_points, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); - -template void -MlasQDQTransposeBlockwiseQuantized( - const uint8_t* src_weights, - const MLAS_FP16* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - MLAS_FP16* dst_scales, - uint8_t* dst_zero_points, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); - -template void -MlasQDQTransposeBlockwiseQuantized( - const uint8_t* src_weights, - const MLAS_FP16* src_scales, - const uint8_t* src_zero_points, - uint8_t* dst_weights, - MLAS_FP16* dst_scales, - uint8_t* dst_zero_points, - bool columnwise, - int rows, - int columns, - int quant_block_size, - MLAS_THREADPOOL* thread_pool -); diff --git a/onnxruntime/core/mlas/lib/q4_dq_cli.cpp b/onnxruntime/core/mlas/lib/q4_dq_cli.cpp deleted file mode 100644 index 9c330b9eaf12a..0000000000000 --- a/onnxruntime/core/mlas/lib/q4_dq_cli.cpp +++ /dev/null @@ -1,304 +0,0 @@ -/*++ - -Copyright 
(c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - q4_dq_cli.cpp - -Abstract: - - This module implements a command line tool that quantize fp32 into int4, - or reverse this process.. - ---*/ - -#include "mlas_q4.h" - -#include -#include -#include -#include -#include -#include - -char* -getCmdOption(char** begin, char** end, const std::string& option) -{ - char** itr = std::find(begin, end, option); - if (itr != end && ++itr != end) { - return *itr; - } - return nullptr; -} - -void -usage(const char* cli) -{ - std::cout << std::endl; - std::cout << "This utility performs int4 quantize and dequantize of a matrix, usage: " << std::endl; - std::cout << " " << cli << " ACTION NUM_ROWS NUM_COLS [OPTIONS]" << std::endl; - std::cout << " ACTION: can be either q (quantize) or dq (de-quantize)." << std::endl; - std::cout << " NUM_ROWS: number of rows in the matrix." << std::endl; - std::cout << " NUM_COLS: number of columns in the matrix." << std::endl; - std::cout << "options:" << std::endl; - std::cout << " --quant_type {0, 1}." << std::endl; - std::cout << " Type of the block quantization." << std::endl; - std::cout << " 0: Symmetric block quant, with fp32 scale." << std::endl; - std::cout << " 1: (default) Block quant with fp32 scale and int8 zero-point." << std::endl; - std::cout << " --input_file {PATH}." << std::endl; - std::cout << " Path to the input file." << std::endl; - std::cout << " --input_offset {N}." << std::endl; - std::cout << " Skip the first N bytes when reading the input file." << std::endl; - std::cout << " Ignored when read from std in." << std::endl; - std::cout << " --output_file {PATH}." << std::endl; - std::cout << " Path to the output file. Write to std out when missing" << std::endl; - std::cout << " --output_format {txt,bin}" << std::endl; - std::cout << " txt: (default) text format: space separated numbers." << std::endl; - std::cout << " bin: Binary format, can not be output to std out." 
<< std::endl; - std::cout << std::endl; -} - - -// -// Variable for commands -// -struct Cli { - bool dqmode = false; // false -> quantize, true -> dequantize - - size_t num_rows = 0; - size_t num_cols = 0; - - MLAS_BLK_QUANT_TYPE quant_type = BlkQ4Zp8; - - char* input_file = nullptr; - size_t input_offset = 0; - - char* output_file = nullptr; - bool output_bin = false; // false -> csv, true -> binary -}; - - -bool -parseArgs(int argc, char* argv[], Cli& cli) -{ - if (argc < 4) { - return false; - } - - if (strncmp(argv[1], "q", 2) == 0) { - cli.dqmode = false; - } else if (strncmp(argv[1], "dq", 3) == 0) { - cli.dqmode = true; - } else { - return false; - } - - errno = 0; - cli.num_rows = (size_t)strtoul(argv[2], nullptr, 0); - if (cli.num_rows == 0 || errno != 0) { - return false; - } - cli.num_cols = (size_t)strtoul(argv[3], nullptr, 0); - if (cli.num_cols == 0 || errno != 0) { - return false; - } - - char* quant_t = getCmdOption(argv + 4, argv + argc, "--quant_type"); - if (quant_t) { - if (strncmp(quant_t, "0", 2) == 0) { - cli.quant_type = BlkQ4Sym; - } - } - - cli.input_file = getCmdOption(argv + 4, argv + argc, "--input_file"); - char* offset_str = getCmdOption(argv + 4, argv + argc, "--input_offset"); - if (offset_str != nullptr) { - errno = 0; - cli.input_offset = (size_t)strtoul(offset_str, nullptr, 0); - if (errno != 0) { - return false; - } - } - - cli.output_file = getCmdOption(argv + 4, argv + argc, "--output_file"); - char* output_format_str = getCmdOption(argv + 4, argv + argc, "--output_format"); - if (output_format_str != nullptr) { - if (strncmp(output_format_str, "csv", 4) == 0) { - cli.output_bin = false; - } else if (strncmp(output_format_str, "bin", 4) == 0) { - cli.output_bin = true; - if (!cli.output_file) { - // can't dump binary file to std-out - return false; - } - } else { - return false; - } - } - return true; -} - - -void -readBinFile(const char* filename, size_t start, size_t expected_size, std::vector& buf) -{ - // open the file: - std::streampos fileSize; - std::ifstream file(filename, std::ios::binary); - - // get its size: - file.seekg(0, std::ios::end); - fileSize = file.tellg(); - file.seekg(0, std::ios::beg); - - file.seekg(start); - fileSize -= start; - if ((size_t)fileSize < expected_size) { - return; - } - - // read the data: - buf.resize(expected_size); - file.read((char*)buf.data(), expected_size); -} - - -void -writeUint8Txt(std::ostream& out, const uint8_t* data, size_t len) -{ - for (size_t i = 0; i < len; i++) { - out << (int)data[i] << " "; - if (((i+1) % 21 == 0)) { - out << std::endl; - } - } - out << std::endl; -} - - -int -quantize(const Cli& cli) -{ - std::vector srcbuf; - readBinFile(cli.input_file, cli.input_offset, cli.num_rows * cli.num_cols * sizeof(float), srcbuf); - if (srcbuf.size() == 0) { - std::cerr << "Failed to read expected amount of data from file " << cli.input_file - << std::endl; - return -1; - } - - size_t qsize = MlasQ4GemmPackBSize(cli.quant_type, cli.num_cols, cli.num_rows); - if (qsize == 0) { - std::cerr << "Int4 Quantization not yet supported on this platform!"; - return -1; - } - std::vector dstbuf(qsize); - MlasQ4GemmPackB(cli.quant_type, dstbuf.data(), (const float*)srcbuf.data(), cli.num_cols, - cli.num_rows, cli.num_cols); - - if (cli.output_bin) { - std::ofstream out(cli.output_file, std::ios::out | std::ios::binary); - if (!out) { - std::cerr << "Cannot open output file " << cli.output_file << std::endl; - return -1; - } - out.write((const char*)dstbuf.data(), dstbuf.size()); - } else { - std::streambuf* 
buf; - if (cli.output_file) { - std::ofstream out(cli.output_file, std::ios::out); - if (!out) { - std::cerr << "Cannot open output file " << cli.output_file << std::endl; - return -1; - } - buf = out.rdbuf(); - } else { - buf = std::cout.rdbuf(); - } -#if defined(__GNUC__) && __GNUC__ >= 12 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored \ - "-Wdangling-pointer" // TODO: suppress warning about dangling pointer until we have a fix - std::ostream stream(buf); -#pragma GCC diagnostic pop -#else - std::ostream stream(buf); -#endif - - writeUint8Txt(stream, dstbuf.data(), dstbuf.size()); - } - return 0; -} - -int -dequantize(const Cli& cli) -{ - size_t qsize = MlasQ4GemmPackBSize(cli.quant_type, cli.num_cols, cli.num_rows); - if (qsize == 0) { - std::cerr << "Int4 Quantization not yet supported on this platform!"; - return -1; - } - std::vector srcbuf; - readBinFile(cli.input_file, cli.input_offset, qsize, srcbuf); - if (srcbuf.size() == 0) { - std::cerr << "Failed to read expected amount of data from file " << cli.input_file - << std::endl; - return -1; - } - - std::vector dstbuf(cli.num_rows * cli.num_cols); - MlasQ4GemmUnPackB(cli.quant_type, dstbuf.data(), srcbuf.data(), cli.num_cols, cli.num_rows, - cli.num_cols); - - if (cli.output_bin) { - std::ofstream out(cli.output_file, std::ios::out | std::ios::binary); - if (!out) { - std::cerr << "Cannot open output file " << cli.output_file << std::endl; - return -1; - } - out.write((const char*)dstbuf.data(), std::streamsize(dstbuf.size()) * sizeof(float)); - } else { - std::streambuf* buf; - std::ofstream file_output_stream; - if (cli.output_file) { - file_output_stream.open(cli.output_file, std::ios::out); - if (file_output_stream.fail()) { - std::cerr << "Cannot open output file " << cli.output_file << std::endl; - return -1; - } - buf = file_output_stream.rdbuf(); - } else { - buf = std::cout.rdbuf(); - } - std::ostream stream(buf); - size_t lcount = 0; - for (float v : dstbuf) { - stream << v << " "; - if (++lcount >= 16) { - stream << std::endl; - lcount = 0; - } - } - stream << std::endl; - } - return 0; -} - - -int -main(int argc, char* argv[]) -{ - Cli cli; - if (!parseArgs(argc, argv, cli)) { - usage(argv[0]); - return -1; - } - if (cli.dqmode) { - return dequantize(cli); - } else { - return quantize(cli); - } -} diff --git a/onnxruntime/core/mlas/lib/q4common.h b/onnxruntime/core/mlas/lib/q4common.h deleted file mode 100644 index ed54b7a207067..0000000000000 --- a/onnxruntime/core/mlas/lib/q4common.h +++ /dev/null @@ -1,156 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - q4common.h - -Abstract: - - Define int4 block quantization types. - - Int4 block quantization is used to compress weight tensors of large - language models. It takes a number (must be multiple of 32) of floating - point values, calculates their quantization parameters, and saves - the parameters and the quantized data in a blob. 
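
    For example, the default 32-element block with a zero point (MLAS_Q4TYPE_BLK1
    below) stores a float scale, a uint8_t zero point and 16 bytes of packed
    nibbles: 21 bytes for every 32 weights.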
---*/ - -#include "core/common/common.h" - -#include "mlas_q4.h" -#include "mlasi.h" - -#include -#include - -// -// Functions for locating data from a quantized blob -// -template -MLAS_FORCEINLINE -float& -MlasQ4BlkScale(uint8_t* BlkPtr) -{ - return *reinterpret_cast(BlkPtr); -} - -template -MLAS_FORCEINLINE -float -MlasQ4BlkScale(const uint8_t* BlkPtr) -{ - return *reinterpret_cast(BlkPtr); -} - -template -uint8_t& -MlasQ4BlkZeroPoint(uint8_t* BlkPtr); - -template -uint8_t -MlasQ4BlkZeroPoint(const uint8_t* BlkPtr); - -template -MLAS_FORCEINLINE -uint8_t* -MlasQ4BlkData(uint8_t* BlkPtr) -{ - return BlkPtr + sizeof(float); -} - -template -MLAS_FORCEINLINE -const uint8_t* -MlasQ4BlkData(const uint8_t* BlkPtr) -{ - return BlkPtr + sizeof(float); -} - -/** - * @brief Every block quantization type, its block size (BlkLen) - * Must be multiple of 32! - */ -constexpr size_t MLAS_QUANT4_BLK_UNIT = 32; - -/** - * @brief Representing int4 quantize type, block quant type 0: - * - * Block size 32, use 32 fp32 numbers to find quantization parameter: - * scale (fp 32) and no zero point, then quantize the numbers - * into int4. The resulting blob takes 16 + 4 = 20 bytes. - */ -struct MLAS_Q4TYPE_BLK0 { - static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT; - static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float); -}; - -/** - * @brief Representing int4 quantize type, block quant type 1: - * - * Block size 32, use 32 fp32 numbers to find quantization parameter: - * scale (fp 32) and zero point (int8), and then quantize the numbers - * into int4. The resulting blob takes 16 + 5 = 21 bytes. - * - * So far this is the only type that includes a zero-point value. - * Maybe we should consider store the quantization parameters seperatedly. - */ -struct MLAS_Q4TYPE_BLK1 { - static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT; - static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float) + sizeof(uint8_t); -}; - -template<> -inline uint8_t& -MlasQ4BlkZeroPoint(uint8_t* BlkPtr) -{ - return *(BlkPtr + sizeof(float)); -} - -template<> -inline uint8_t -MlasQ4BlkZeroPoint(const uint8_t* BlkPtr) -{ - return *(BlkPtr + sizeof(float)); -} - -template<> -inline uint8_t* -MlasQ4BlkData(uint8_t* BlkPtr) -{ - return BlkPtr + sizeof(float) + sizeof(uint8_t); -} - -template<> -inline const uint8_t* -MlasQ4BlkData(const uint8_t* BlkPtr) -{ - return BlkPtr + sizeof(float) + sizeof(uint8_t); -} - -/** - * @brief Representing int4 quantize type, block quant type 2: - * - * Block size 64, use 64 fp32 numbers to find quantization parameter: - * scale (fp 32) and no zero point, then quantize the numbers - * into int4. The resulting blob takes 32 + 4 = 36 bytes. - */ -struct MLAS_Q4TYPE_BLK2 { - static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT * 2; - static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float); -}; - - -/** - * @brief Representing int4 quantize type, block quant type 4: - * - * Block size 128, use 128 fp32 numbers to find quantization parameter: - * scale (fp 32) and no zero point, then quantize the numbers - * into int4. The resulting blob takes 32 + 4 = 36 bytes. - */ -struct MLAS_Q4TYPE_BLK4 { - static constexpr size_t BlkLen = MLAS_QUANT4_BLK_UNIT * 4; - static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float); -}; diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp deleted file mode 100644 index a734f53432bb6..0000000000000 --- a/onnxruntime/core/mlas/lib/q4gemm.cpp +++ /dev/null @@ -1,179 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. 
All rights reserved. - -Licensed under the MIT License. - -Module Name: - - q4gemm.cpp - -Abstract: - - This module implements the fp32 matrix multiplication with compressed - weight tensor (right hand side). The assumption is the right hand side - tensor can be pre-packed and compressed using int-4 quantization to save - memory. ---*/ - -#include "q4gemm.h" - - -size_t -MLASCALL -MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K) -{ - if (GetMlasPlatform().Q8Q4GemmDispatch == nullptr) { - return 0; - } - switch (QType) { - case BlkQ4Zp8: - return MlasQ80BlkQuantSizeImpl(M, K); - case BlkQ4Sym64: - return MlasQ80BlkQuantSizeImpl(M, K); - case BlkQ4Sym128: - return MlasQ80BlkQuantSizeImpl(M, K); - default: - return MlasQ80BlkQuantSizeImpl(M, K); - } -} - - -void -MLASCALL -MlasQ80BlkQuant( - MLAS_BLK_QUANT_TYPE QType, - void* Qblob, - const float* A, - size_t M, - size_t K, - size_t lda, - MLAS_THREADPOOL* ThreadPool - ) -{ - auto* dispatch = GetMlasPlatform().Q8Q4GemmDispatch; - dispatch->Quants[QType](Qblob, A, M, K, lda, ThreadPool); -} - - -template -MLAS_FORCEINLINE -void -MlasQ4GemmBatchDriver( - MLAS_BLK_QUANT_TYPE QType, - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const ParamBlockType* DataParams, - MLAS_THREADPOOL* ThreadPool - ) -{ - //const MLAS_Q4GEMM_DISPATCH* dispatch = MlasQ4GemmGetDispatch(); - //MLAS_Q4GEMM_OPERATION* operation = dispatch->Operation; - void (*operation)(const size_t, const ParamBlockType*, const size_t, const size_t, const size_t, - const size_t) = nullptr; - - if constexpr (std::is_same_v) - { - operation = GetMlasPlatform().FpQ4GemmDispatch->Operations[QType]; - } - else { - operation = GetMlasPlatform().Q8Q4GemmDispatch->Operations[QType]; - } - - if (ThreadPool == nullptr) { - for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { - auto Data = &DataParams[gemm_i]; - operation(K, Data, 0, M, 0, N); - } - return; - } - - // - // Compute the number of target threads given the complexity of the SGEMM - // operation. Small requests should run using the single threaded path. 
- // - - const double Complexity = double(M) * double(N) * double(K) * double(BatchN); - - ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; - if (ThreadsPerGemm < 1) { - ThreadsPerGemm = 1; - } - - constexpr size_t StrideM = 128; - - size_t nc = N; - if (ThreadsPerGemm > 1) { - // more than one thread per GEMM - - const size_t BlockedM = MlasDivRoundup(M, StrideM); - const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); - if (max_nc < nc) { - nc = std::min(nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * - MLAS_QGEMM_STRIDEN_THREAD_ALIGN); - } - } - const size_t StrideN = nc; - - const size_t ThreadCountM = MlasDivRoundup(M, StrideM); - const size_t ThreadCountN = MlasDivRoundup(N, StrideN); - ThreadsPerGemm = ThreadCountM * ThreadCountN; - - MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { - const auto gemm_i = tid / ThreadsPerGemm; - const auto blk_i = tid % ThreadsPerGemm; - auto Data = &DataParams[gemm_i]; - - const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; - const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; - - const size_t RangeStartM = ThreadIdM * StrideM; - const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); - - const size_t RangeStartN = ThreadIdN * StrideN; - const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - - operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); - }); -} - - -void -MLASCALL -MlasQ4GemmBatch( - MLAS_BLK_QUANT_TYPE QType, - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool - ) -{ - MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); -} - -void -MLASCALL -MlasQ8Q4GemmBatch( - MLAS_BLK_QUANT_TYPE QType, - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool - ) -{ - MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h deleted file mode 100644 index d16798eb8945f..0000000000000 --- a/onnxruntime/core/mlas/lib/q4gemm.h +++ /dev/null @@ -1,288 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - q4gemm.h - -Abstract: - - int4 block quantization gemm kernel template declarations. - - Int4 block quantization is used to compress weight tensors of large - language models. It takes a number (must be multiple of 32) of floating - point values, calculates their quantization parameters, and saves - the parameters and the quantized data in a blob. 
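The deleted batch driver splits each flat thread id into a GEMM index plus an M/N tile index over 128-row by StrideN-column blocks. The sketch below replays that mapping with plain loops; PartitionBatch and DivRoundup are illustrative names, the complexity-based thread-count heuristic is replaced by a caller-supplied MaxThreads, and the 16-element N-stride alignment is a placeholder standing in for MLAS_QGEMM_STRIDEN_THREAD_ALIGN.

// Sketch of the tid -> (gemm, row range, column range) mapping, without MlasTrySimpleParallel.
#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t DivRoundup(size_t a, size_t b) { return (a + b - 1) / b; }

void PartitionBatch(size_t M, size_t N, size_t BatchN, size_t MaxThreads) {
    constexpr size_t StrideM = 128;
    // Spread the available threads across the GEMMs in the batch.
    size_t ThreadsPerGemm = std::max<size_t>(MaxThreads / BatchN, 1);

    // Shrink the N stride so each GEMM yields roughly ThreadsPerGemm tiles.
    size_t nc = N;
    if (ThreadsPerGemm > 1) {
        const size_t BlockedM = DivRoundup(M, StrideM);
        const size_t max_nc = DivRoundup(N * BlockedM, ThreadsPerGemm);
        if (max_nc < nc) nc = std::min(nc, DivRoundup(max_nc, 16) * 16);
    }
    const size_t StrideN = nc;

    const size_t ThreadCountM = DivRoundup(M, StrideM);
    const size_t ThreadCountN = DivRoundup(N, StrideN);
    ThreadsPerGemm = ThreadCountM * ThreadCountN;

    for (size_t tid = 0; tid < ThreadsPerGemm * BatchN; tid++) {
        const size_t gemm_i = tid / ThreadsPerGemm;   // which GEMM in the batch
        const size_t blk_i = tid % ThreadsPerGemm;    // which tile within that GEMM
        const size_t ThreadIdN = blk_i / ThreadCountM;
        const size_t ThreadIdM = blk_i % ThreadCountM;

        const size_t RangeStartM = ThreadIdM * StrideM;
        const size_t RangeCountM = std::min(M - RangeStartM, StrideM);
        const size_t RangeStartN = ThreadIdN * StrideN;
        const size_t RangeCountN = std::min(N - RangeStartN, StrideN);

        std::printf("gemm %zu: rows [%zu, %zu), cols [%zu, %zu)\n", gemm_i,
                    RangeStartM, RangeStartM + RangeCountM,
                    RangeStartN, RangeStartN + RangeCountN);
    }
}

int main() { PartitionBatch(300, 1024, 2, 8); return 0; }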
---*/ - -#include "q4common.h" - - -template -MLAS_FORCEINLINE -size_t -MlasQ4GemmKernel( - const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias -); - -template -MLAS_FORCEINLINE -void -MlasBlkQ4DequantB(float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb); - - -template -MLAS_FORCEINLINE void -AddBiasAvx(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc); - - - -template -void MLASCALL -MlasQ4GemmOperation( - const size_t K, - const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const size_t k_blks = MlasDivRoundup(K, Q4TYPE::BlkLen); - const size_t ldb = k_blks * Q4TYPE::BlobSize; - const float* A = DataParams->A + RangeStartM * lda; - const uint8_t* PackedB = (const uint8_t*)DataParams->B; - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - const float* Bias = DataParams->Bias; - - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, (size_t)128); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* bias = (Bias == nullptr) ? nullptr : Bias + RangeStartN + n; - const uint8_t* b_col = PackedB + (RangeStartN + n) * ldb; - float* c_blk = C + n; - const float* a_row = A; - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - auto RowsHandled = MlasQ4GemmKernel( - a_row, b_col, c_blk, RowsRemaining, CountN, K, lda, ldb, ldc, bias); - - if (DataParams->OutputProcessor != nullptr) { - DataParams->OutputProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } - return; - } - - constexpr size_t StrideN = 32; - size_t bufsize = k_blks * Q4TYPE::BlkLen * StrideN * sizeof(float); - MlasThreadedBufAlloc(bufsize); - auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, (size_t)StrideN); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* bias = (Bias == nullptr) ? 
nullptr : Bias + RangeStartN + n; - const uint8_t* b_col = PackedB + (RangeStartN + n) * ldb; - float* c_blk = C + n; - const float* a_row = A; - - MlasBlkQ4DequantB(dequant_b, b_col, CountN, K, ldb); - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true); -#else - auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, - CountN, lda, ldc, 1.f); -#endif - - if (bias) { - AddBiasAvx(bias, c_blk, RowsHandled, CountN, ldc); - } - if (DataParams->OutputProcessor != nullptr) { - DataParams->OutputProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, - RowsHandled, CountN, ldc); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } -} - -typedef -void -(MLAS_Q4GEMM_OPERATION)( - const size_t K, - const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ); - -struct MLAS_FPQ4GEMM_DISPATCH { - MLAS_Q4GEMM_OPERATION** Operations; -}; - -/** - * @brief Compute the size of a quantized block, one byte per value + fp32 scale - * @tparam QType - * @return - */ -template -constexpr size_t -Q8BlobUnitSize() -{ - return (QType::BlkLen + sizeof(float)); -} - -template -constexpr size_t -MlasQ80BlkQuantSizeImpl(size_t M, size_t K) -{ - const size_t KBlocks = MlasDivRoundup(K, QType::BlkLen); - - const size_t NumBlocks = M * KBlocks; - - return NumBlocks * Q8BlobUnitSize(); -} - -typedef -void -(MLAS_Q80_BLKQUANT)( - void* Qblob, - const float* A, - size_t M, - size_t K, - size_t lda, - MLAS_THREADPOOL* ThreadPool - ); - -template -MLAS_FORCEINLINE -size_t -MlasQ8Q4GemmKernel( - const int8_t* QuantA, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ); - - -template -void MLASCALL -MlasQ8Q4GemmOperation( - const size_t K, - const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ - const size_t k_blks = MlasDivRoundup(K, Q4TYPE::BlkLen); - const size_t ldb = k_blks * Q4TYPE::BlobSize; - const size_t lda = k_blks * Q8BlobUnitSize(); - const size_t ldc = DataParams->ldc; - - const int8_t* A = reinterpret_cast(DataParams->A) + RangeStartM * lda; - const uint8_t* PackedB = (const uint8_t*)DataParams->B; - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - const float* Bias = DataParams->Bias; - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, (size_t)128); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* bias = (Bias == nullptr) ? 
nullptr : Bias + RangeStartN + n; - const uint8_t* b_col = PackedB + (RangeStartN + n) * ldb; - float* c_blk = C + n; - const int8_t* a_row = A; - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - auto RowsHandled = MlasQ8Q4GemmKernel( - a_row, b_col, c_blk, RowsRemaining, CountN, K, lda, ldb, ldc, bias); - - if (DataParams->OutputProcessor != nullptr) { - DataParams->OutputProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, - RowsHandled, CountN, DataParams->ldc); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } -} - -typedef -void -(MLAS_Q8Q4GEMM_OPERATION)( - const size_t K, - const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ); - -struct MLAS_Q8Q4GEMM_DISPATCH { - MLAS_Q80_BLKQUANT** Quants; - MLAS_Q8Q4GEMM_OPERATION** Operations; -}; diff --git a/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp b/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp deleted file mode 100644 index f7af82ed12e0f..0000000000000 --- a/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp +++ /dev/null @@ -1,1509 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - q4gemm_avx512.cpp - -Abstract: - - This module implements the fp32 matrix multiplication with compressed - weight tensor (right hand side). The assumption is the right hand side - tensor can be pre-packed and compressed using int-4 quantization to save - memory. - Specificially on x64 avx512 ---*/ - -#include "q4gemm.h" - -#include -#include - -struct MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI { - static constexpr size_t StrideM = 256; -}; - -/** - * @brief Horizontally sum 4 vectors and store - * the results in the returned vector - */ -static MLAS_FORCEINLINE __m128 -FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3) -{ - __m512 acc_lo01 = _mm512_unpacklo_ps(acc0, acc1); - __m512 acc_hi01 = _mm512_unpackhi_ps(acc0, acc1); - __m512 acc_lo23 = _mm512_unpacklo_ps(acc2, acc3); - __m512 acc_hi23 = _mm512_unpackhi_ps(acc2, acc3); - - __m512 acc_lo0123 = _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23))); - __m512 acc_hi0123 = _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23))); - acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23))); - acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23))); - acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); - - __m256 acc_y = - _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); - return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); -} - - -template -MLAS_FORCEINLINE -size_t -MlasQ4GemmKernelAvx512f( - const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - // We process 32 quantized values in a batch. 
- static_assert(MLAS_QUANT4_BLK_UNIT == 32); - static_assert(Q4Type::BlkLen % MLAS_QUANT4_BLK_UNIT == 0); - - const __m256i lowMask = _mm256_set1_epi8(0xF); - - for (size_t m = 0; m < CountM; m++) { - const auto* b_col = PackedB; - auto* sum_ptr = C; - const auto* bias_ptr = Bias; - - int64_t nblk = (int64_t)(CountN) - 4; - while (nblk >= 0) { - __m512 acc_lo0 = _mm512_setzero_ps(); - __m512 acc_lo1 = _mm512_setzero_ps(); - __m512 acc_lo2 = _mm512_setzero_ps(); - __m512 acc_lo3 = _mm512_setzero_ps(); - const auto* b = b_col; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - size_t ck = std::min(CountK - k, Q4Type::BlkLen); - - const float scale_v0 = MlasQ4BlkScale(b); - const float scale_v1 = MlasQ4BlkScale(b + ldb); - const float scale_v2 = MlasQ4BlkScale(b + ldb * 2); - const float scale_v3 = MlasQ4BlkScale(b + ldb * 3); - - const __m128i* b0ptr = (const __m128i*)MlasQ4BlkData(b); - const __m128i* b1ptr = (const __m128i*)MlasQ4BlkData(b + ldb); - const __m128i* b2ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 2); - const __m128i* b3ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 3); - - for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { - size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); - - // Load A row vectors - uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT - kklen); - __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - - mask = mask >> 16; - __m512 av_hi = mask == 0 ? _mm512_setzero_ps() - : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); - - // Load B col vectors - const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); - const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); - const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); - const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); - - // expand 4b into byte array - __m256i bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); - __m256i bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); - __m256i bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); - __m256i bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); - bytes0 = _mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - - // Subtract zero-point from the integers - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - bytes0 = _mm256_sub_epi8( - bytes0, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b))); - bytes1 = _mm256_sub_epi8( - bytes1, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb))); - bytes2 = _mm256_sub_epi8( - bytes2, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 2))); - bytes3 = _mm256_sub_epi8( - bytes3, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 3))); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = _mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - const __m256i vx16_lo0 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); - const __m256i vx16_hi0 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); - const __m256i vx16_lo1 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); - const __m256i vx16_hi1 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); - const __m256i vx16_lo2 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); - const __m256i vx16_hi2 = - 
_mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); - const __m256i vx16_lo3 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); - const __m256i vx16_hi3 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); - __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); - __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); - __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); - __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); - __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); - __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); - __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); - - __m512 s = _mm512_set1_ps(scale_v0); - bvf_lo0 = _mm512_mul_ps(bvf_lo0, s); - bvf_hi0 = _mm512_mul_ps(bvf_hi0, s); - s = _mm512_set1_ps(scale_v1); - bvf_lo1 = _mm512_mul_ps(bvf_lo1, s); - bvf_hi1 = _mm512_mul_ps(bvf_hi1, s); - s = _mm512_set1_ps(scale_v2); - bvf_lo2 = _mm512_mul_ps(bvf_lo2, s); - bvf_hi2 = _mm512_mul_ps(bvf_hi2, s); - s = _mm512_set1_ps(scale_v3); - bvf_lo3 = _mm512_mul_ps(bvf_lo3, s); - bvf_hi3 = _mm512_mul_ps(bvf_hi3, s); - - acc_lo0 = _mm512_fmadd_ps(bvf_lo0, av_lo, acc_lo0); - acc_lo0 = _mm512_fmadd_ps(bvf_hi0, av_hi, acc_lo0); - acc_lo1 = _mm512_fmadd_ps(bvf_lo1, av_lo, acc_lo1); - acc_lo1 = _mm512_fmadd_ps(bvf_hi1, av_hi, acc_lo1); - acc_lo2 = _mm512_fmadd_ps(bvf_lo2, av_lo, acc_lo2); - acc_lo2 = _mm512_fmadd_ps(bvf_hi2, av_hi, acc_lo2); - acc_lo3 = _mm512_fmadd_ps(bvf_lo3, av_lo, acc_lo3); - acc_lo3 = _mm512_fmadd_ps(bvf_hi3, av_hi, acc_lo3); - } - - b += Q4Type::BlobSize; - } - - __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); - if (Bias != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); - } - _mm_storeu_ps(sum_ptr, acc_x); - - // move to next 4 columns - b_col += 4 * ldb; - sum_ptr += 4; - bias_ptr += 4; - nblk -= 4; - } - - // left over columns less than 4 ? - nblk += 4; - if (nblk > 0) { - __m512 acc_lo[4]{}; - const auto* b = b_col; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - size_t ck = std::min(CountK - k, Q4Type::BlkLen); - - float scale_v[4]; - const __m128i* b_ptr[4]; - for (int64_t nn = 0; nn < nblk; nn++) { - const auto* bb = b + ldb * nn; - scale_v[nn] = MlasQ4BlkScale(bb); - b_ptr[nn] = (const __m128i*)MlasQ4BlkData(bb); - } - - for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { - size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); - - uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT - kklen); - __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - - mask = mask >> 16; - __m512 av_hi = mask == 0 - ? 
_mm512_setzero_ps() - : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); - - for (int64_t nn = 0; nn < nblk; nn++) { - const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); - __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); - bytes = _mm256_and_si256(lowMask, bytes); - - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - const auto* bb = b + ldb * nn; - const uint8_t zp = MlasQ4BlkZeroPoint(bb); - bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes = _mm256_sub_epi8(bytes, eight); - } - - // Convert to 16-bit int - const __m256i vx16_lo = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); - const __m256i vx16_hi = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); - __m512 bvf_hi = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); - __m512 s = _mm512_set1_ps(scale_v[nn]); - bvf_lo = _mm512_mul_ps(bvf_lo, s); - bvf_hi = _mm512_mul_ps(bvf_hi, s); - - acc_lo[nn] = _mm512_fmadd_ps(bvf_lo, av_lo, acc_lo[nn]); - acc_lo[nn] = _mm512_fmadd_ps(bvf_hi, av_hi, acc_lo[nn]); - } - } - b += Q4Type::BlobSize; - } - - for (int64_t nn = 0; nn < nblk; nn++) { - sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); - sum_ptr[nn] += Bias == nullptr ? 0.0f : bias_ptr[nn]; - } - } - - // Prepare pointers for the next row - C += ldc; - A += lda; - } - return CountM; -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ4GemmKernel( - const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, - ldb, ldc, Bias); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ4GemmKernel( - const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, - ldb, ldc, Bias); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ4GemmKernel( - const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, - ldb, ldc, Bias); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ4GemmKernel( - const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ4GemmKernelAvx512f(A, PackedB, C, CountM, CountN, CountK, lda, - ldb, ldc, Bias); -} - - -MLAS_FORCEINLINE -void -Transpose16x16Avx512( - float* dest, - __m512i r0, - __m512i r1, - __m512i r2, - __m512i r3, - __m512i r4, - __m512i r5, - __m512i r6, - __m512i r7, - __m512i r8, - __m512i r9, - __m512i ra, - __m512i rb, - __m512i rc, - __m512i rd, - __m512i re, - __m512i rf) -{ - - __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; - - t0 = _mm512_unpacklo_epi32( - r0, r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 - t1 = _mm512_unpackhi_epi32( - r0, r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 - t2 = _mm512_unpacklo_epi32(r2, r3); // 32 48 33 49 ... - t3 = _mm512_unpackhi_epi32(r2, r3); // 34 50 35 51 ... 
- t4 = _mm512_unpacklo_epi32(r4, r5); // 64 80 65 81 ... - t5 = _mm512_unpackhi_epi32(r4, r5); // 66 82 67 83 ... - t6 = _mm512_unpacklo_epi32(r6, r7); // 96 112 97 113 ... - t7 = _mm512_unpackhi_epi32(r6, r7); // 98 114 99 115 ... - t8 = _mm512_unpacklo_epi32(r8, r9); // 128 ... - t9 = _mm512_unpackhi_epi32(r8, r9); // 130 ... - ta = _mm512_unpacklo_epi32(ra, rb); // 160 ... - tb = _mm512_unpackhi_epi32(ra, rb); // 162 ... - tc = _mm512_unpacklo_epi32(rc, rd); // 196 ... - td = _mm512_unpackhi_epi32(rc, rd); // 198 ... - te = _mm512_unpacklo_epi32(re, rf); // 228 ... - tf = _mm512_unpackhi_epi32(re, rf); // 230 ... - - r0 = _mm512_unpacklo_epi64(t0, t2); // 0 16 32 48 ... - r1 = _mm512_unpackhi_epi64(t0, t2); // 1 17 33 49 ... - r2 = _mm512_unpacklo_epi64(t1, t3); // 2 18 34 49 ... - r3 = _mm512_unpackhi_epi64(t1, t3); // 3 19 35 51 ... - r4 = _mm512_unpacklo_epi64(t4, t6); // 64 80 96 112 ... - r5 = _mm512_unpackhi_epi64(t4, t6); // 65 81 97 114 ... - r6 = _mm512_unpacklo_epi64(t5, t7); // 66 82 98 113 ... - r7 = _mm512_unpackhi_epi64(t5, t7); // 67 83 99 115 ... - r8 = _mm512_unpacklo_epi64(t8, ta); // 128 144 160 176 ... - r9 = _mm512_unpackhi_epi64(t8, ta); // 129 145 161 178 ... - ra = _mm512_unpacklo_epi64(t9, tb); // 130 146 162 177 ... - rb = _mm512_unpackhi_epi64(t9, tb); // 131 147 163 179 ... - rc = _mm512_unpacklo_epi64(tc, te); // 192 208 228 240 ... - rd = _mm512_unpackhi_epi64(tc, te); // 193 209 229 241 ... - re = _mm512_unpacklo_epi64(td, tf); // 194 210 230 242 ... - rf = _mm512_unpackhi_epi64(td, tf); // 195 211 231 243 ... - - t0 = - _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... - t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... - t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... - t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... - t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... - t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... - t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... - t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... - t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... - t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... - ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... - tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... - tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... - td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... - te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... - tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... - - r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 - r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 - r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 - r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 - r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... - r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... - r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... - r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... - r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... - r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... - ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... - rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... - rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... - rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... - re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... 
- rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 - - _mm512_storeu_si512(dest, r0); - dest += 16; - _mm512_storeu_si512(dest, r1); - dest += 16; - _mm512_storeu_si512(dest, r2); - dest += 16; - _mm512_storeu_si512(dest, r3); - dest += 16; - _mm512_storeu_si512(dest, r4); - dest += 16; - _mm512_storeu_si512(dest, r5); - dest += 16; - _mm512_storeu_si512(dest, r6); - dest += 16; - _mm512_storeu_si512(dest, r7); - dest += 16; - _mm512_storeu_si512(dest, r8); - dest += 16; - _mm512_storeu_si512(dest, r9); - dest += 16; - _mm512_storeu_si512(dest, ra); - dest += 16; - _mm512_storeu_si512(dest, rb); - dest += 16; - _mm512_storeu_si512(dest, rc); - dest += 16; - _mm512_storeu_si512(dest, rd); - dest += 16; - _mm512_storeu_si512(dest, re); - dest += 16; - _mm512_storeu_si512(dest, rf); - dest += 16; -} - - -template -MLAS_FORCEINLINE -void -BlkQ4DequantBAvx512f( - float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) -{ - const __m256i lowMask = _mm256_set1_epi8(0xF); - - const auto* b_col = PackedB; - - int64_t nblk = (int64_t)(CountN)-16; - while (nblk >= 0) { - const auto* b = b_col; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - size_t ck = std::min(CountK - k, Q4Type::BlkLen); - - const float scale_v0 = MlasQ4BlkScale(b); - const float scale_v1 = MlasQ4BlkScale(b + ldb); - const float scale_v2 = MlasQ4BlkScale(b + ldb * 2); - const float scale_v3 = MlasQ4BlkScale(b + ldb * 3); - const float scale_v4 = MlasQ4BlkScale(b + ldb * 4); - const float scale_v5 = MlasQ4BlkScale(b + ldb * 5); - const float scale_v6 = MlasQ4BlkScale(b + ldb * 6); - const float scale_v7 = MlasQ4BlkScale(b + ldb * 7); - const float scale_v8 = MlasQ4BlkScale(b + ldb * 8); - const float scale_v9 = MlasQ4BlkScale(b + ldb * 9); - const float scale_va = MlasQ4BlkScale(b + ldb * 10); - const float scale_vb = MlasQ4BlkScale(b + ldb * 11); - const float scale_vc = MlasQ4BlkScale(b + ldb * 12); - const float scale_vd = MlasQ4BlkScale(b + ldb * 13); - const float scale_ve = MlasQ4BlkScale(b + ldb * 14); - const float scale_vf = MlasQ4BlkScale(b + ldb * 15); - - const __m128i* b0ptr = (const __m128i*)MlasQ4BlkData(b); - const __m128i* b1ptr = (const __m128i*)MlasQ4BlkData(b + ldb); - const __m128i* b2ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 2); - const __m128i* b3ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 3); - const __m128i* b4ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 4); - const __m128i* b5ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 5); - const __m128i* b6ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 6); - const __m128i* b7ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 7); - const __m128i* b8ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 8); - const __m128i* b9ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 9); - const __m128i* baptr = (const __m128i*)MlasQ4BlkData(b + ldb * 10); - const __m128i* bbptr = (const __m128i*)MlasQ4BlkData(b + ldb * 11); - const __m128i* bcptr = (const __m128i*)MlasQ4BlkData(b + ldb * 12); - const __m128i* bdptr = (const __m128i*)MlasQ4BlkData(b + ldb * 13); - const __m128i* beptr = (const __m128i*)MlasQ4BlkData(b + ldb * 14); - const __m128i* bfptr = (const __m128i*)MlasQ4BlkData(b + ldb * 15); - - for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { - size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); - - // Load B col vectors - const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); - const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); - const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); - const __m128i bvi4_3 = 
_mm_loadu_si128(b3ptr++); - - // expand 4b into byte array - __m256i bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); - __m256i bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); - __m256i bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); - __m256i bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); - bytes0 = _mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - - // Subtract zero-point from the integers - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - bytes0 = _mm256_sub_epi8( - bytes0, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b))); - bytes1 = _mm256_sub_epi8( - bytes1, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb))); - bytes2 = _mm256_sub_epi8( - bytes2, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 2))); - bytes3 = _mm256_sub_epi8( - bytes3, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 3))); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = _mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - __m256i vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); - __m256i vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); - __m256i vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); - __m256i vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); - __m256i vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); - __m256i vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); - __m256i vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); - __m256i vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); - __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); - __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); - __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); - __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); - __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); - __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); - __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); - - __m512 s = _mm512_set1_ps(scale_v0); - bvf_lo0 = _mm512_mul_ps(bvf_lo0, s); - bvf_hi0 = _mm512_mul_ps(bvf_hi0, s); - s = _mm512_set1_ps(scale_v1); - bvf_lo1 = _mm512_mul_ps(bvf_lo1, s); - bvf_hi1 = _mm512_mul_ps(bvf_hi1, s); - s = _mm512_set1_ps(scale_v2); - bvf_lo2 = _mm512_mul_ps(bvf_lo2, s); - bvf_hi2 = _mm512_mul_ps(bvf_hi2, s); - s = _mm512_set1_ps(scale_v3); - bvf_lo3 = _mm512_mul_ps(bvf_lo3, s); - bvf_hi3 = _mm512_mul_ps(bvf_hi3, s); - - // Load B col vectors - const __m128i bvi4_4 = _mm_loadu_si128(b4ptr++); - const __m128i bvi4_5 = _mm_loadu_si128(b5ptr++); - const __m128i bvi4_6 = _mm_loadu_si128(b6ptr++); - const __m128i bvi4_7 = _mm_loadu_si128(b7ptr++); - - // expand 4b into byte array - bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_4, 4), bvi4_4); - bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_5, 4), bvi4_5); - bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_6, 4), bvi4_6); - bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_7, 4), bvi4_7); - bytes0 = 
_mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - - // Subtract zero-point from the integers - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - bytes0 = _mm256_sub_epi8( - bytes0, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 4))); - bytes1 = _mm256_sub_epi8( - bytes1, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 5))); - bytes2 = _mm256_sub_epi8( - bytes2, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 6))); - bytes3 = _mm256_sub_epi8( - bytes3, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 7))); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = _mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); - vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); - vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); - vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); - vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); - vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); - vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); - vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo4 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); - __m512 bvf_hi4 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); - __m512 bvf_lo5 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); - __m512 bvf_hi5 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); - __m512 bvf_lo6 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); - __m512 bvf_hi6 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); - __m512 bvf_lo7 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); - __m512 bvf_hi7 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); - - s = _mm512_set1_ps(scale_v4); - bvf_lo4 = _mm512_mul_ps(bvf_lo4, s); - bvf_hi4 = _mm512_mul_ps(bvf_hi4, s); - s = _mm512_set1_ps(scale_v5); - bvf_lo5 = _mm512_mul_ps(bvf_lo5, s); - bvf_hi5 = _mm512_mul_ps(bvf_hi5, s); - s = _mm512_set1_ps(scale_v6); - bvf_lo6 = _mm512_mul_ps(bvf_lo6, s); - bvf_hi6 = _mm512_mul_ps(bvf_hi6, s); - s = _mm512_set1_ps(scale_v7); - bvf_lo7 = _mm512_mul_ps(bvf_lo7, s); - bvf_hi7 = _mm512_mul_ps(bvf_hi7, s); - - // Load B col vectors - const __m128i bvi4_8 = _mm_loadu_si128(b8ptr++); - const __m128i bvi4_9 = _mm_loadu_si128(b9ptr++); - const __m128i bvi4_a = _mm_loadu_si128(baptr++); - const __m128i bvi4_b = _mm_loadu_si128(bbptr++); - - // expand 4b into byte array - bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_8, 4), bvi4_8); - bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_9, 4), bvi4_9); - bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_a, 4), bvi4_a); - bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_b, 4), bvi4_b); - bytes0 = _mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - - // Subtract zero-point from the integers - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - bytes0 = _mm256_sub_epi8( - bytes0, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 8))); - bytes1 = _mm256_sub_epi8( - bytes1, - 
_mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 9))); - bytes2 = _mm256_sub_epi8( - bytes2, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 10))); - bytes3 = _mm256_sub_epi8( - bytes3, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 11))); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = _mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); - vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); - vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); - vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); - vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); - vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); - vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); - vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo8 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); - __m512 bvf_hi8 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); - __m512 bvf_lo9 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); - __m512 bvf_hi9 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); - __m512 bvf_loa = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); - __m512 bvf_hia = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); - __m512 bvf_lob = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); - __m512 bvf_hib = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); - - s = _mm512_set1_ps(scale_v8); - bvf_lo8 = _mm512_mul_ps(bvf_lo8, s); - bvf_hi8 = _mm512_mul_ps(bvf_hi8, s); - s = _mm512_set1_ps(scale_v9); - bvf_lo9 = _mm512_mul_ps(bvf_lo9, s); - bvf_hi9 = _mm512_mul_ps(bvf_hi9, s); - s = _mm512_set1_ps(scale_va); - bvf_loa = _mm512_mul_ps(bvf_loa, s); - bvf_hia = _mm512_mul_ps(bvf_hia, s); - s = _mm512_set1_ps(scale_vb); - bvf_lob = _mm512_mul_ps(bvf_lob, s); - bvf_hib = _mm512_mul_ps(bvf_hib, s); - - // Load B col vectors - const __m128i bvi4_c = _mm_loadu_si128(bcptr++); - const __m128i bvi4_d = _mm_loadu_si128(bdptr++); - const __m128i bvi4_e = _mm_loadu_si128(beptr++); - const __m128i bvi4_f = _mm_loadu_si128(bfptr++); - - // expand 4b into byte array - bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_c, 4), bvi4_c); - bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_d, 4), bvi4_d); - bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_e, 4), bvi4_e); - bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_f, 4), bvi4_f); - bytes0 = _mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - - // Subtract zero-point from the integers - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - bytes0 = _mm256_sub_epi8( - bytes0, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 12))); - bytes1 = _mm256_sub_epi8( - bytes1, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 13))); - bytes2 = _mm256_sub_epi8( - bytes2, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 14))); - bytes3 = _mm256_sub_epi8( - bytes3, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 15))); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = 
_mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - vx16_lo0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); - vx16_hi0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); - vx16_lo1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); - vx16_hi1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); - vx16_lo2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); - vx16_hi2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); - vx16_lo3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); - vx16_hi3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_loc = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); - __m512 bvf_hic = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); - __m512 bvf_lod = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); - __m512 bvf_hid = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); - __m512 bvf_loe = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); - __m512 bvf_hie = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); - __m512 bvf_lof = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); - __m512 bvf_hif = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); - - s = _mm512_set1_ps(scale_vc); - bvf_loc = _mm512_mul_ps(bvf_loc, s); - bvf_hic = _mm512_mul_ps(bvf_hic, s); - s = _mm512_set1_ps(scale_vd); - bvf_lod = _mm512_mul_ps(bvf_lod, s); - bvf_hid = _mm512_mul_ps(bvf_hid, s); - s = _mm512_set1_ps(scale_ve); - bvf_loe = _mm512_mul_ps(bvf_loe, s); - bvf_hie = _mm512_mul_ps(bvf_hie, s); - s = _mm512_set1_ps(scale_vf); - bvf_lof = _mm512_mul_ps(bvf_lof, s); - bvf_hif = _mm512_mul_ps(bvf_hif, s); - Transpose16x16Avx512(FpData, _mm512_castps_si512(bvf_lo0), - _mm512_castps_si512(bvf_lo1), _mm512_castps_si512(bvf_lo2), - _mm512_castps_si512(bvf_lo3), _mm512_castps_si512(bvf_lo4), - _mm512_castps_si512(bvf_lo5), _mm512_castps_si512(bvf_lo6), - _mm512_castps_si512(bvf_lo7), _mm512_castps_si512(bvf_lo8), - _mm512_castps_si512(bvf_lo9), _mm512_castps_si512(bvf_loa), - _mm512_castps_si512(bvf_lob), _mm512_castps_si512(bvf_loc), - _mm512_castps_si512(bvf_lod), _mm512_castps_si512(bvf_loe), - _mm512_castps_si512(bvf_lof)); - if (kklen > 16) { - Transpose16x16Avx512(FpData + 16 * 16, _mm512_castps_si512(bvf_hi0), - _mm512_castps_si512(bvf_hi1), _mm512_castps_si512(bvf_hi2), - _mm512_castps_si512(bvf_hi3), _mm512_castps_si512(bvf_hi4), - _mm512_castps_si512(bvf_hi5), _mm512_castps_si512(bvf_hi6), - _mm512_castps_si512(bvf_hi7), _mm512_castps_si512(bvf_hi8), - _mm512_castps_si512(bvf_hi9), _mm512_castps_si512(bvf_hia), - _mm512_castps_si512(bvf_hib), _mm512_castps_si512(bvf_hic), - _mm512_castps_si512(bvf_hid), _mm512_castps_si512(bvf_hie), - _mm512_castps_si512(bvf_hif)); - } - FpData += 16 * kklen; - } - - b += Q4Type::BlobSize; - } - - // move to next 16 columns - b_col += 16 * ldb; - nblk -= 16; - } - - // left over columns less than 16 ? 
- nblk += 16; - if (nblk > 0) { - const auto* b = b_col; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - size_t ck = std::min(CountK - k, Q4Type::BlkLen); - - float scale_v[16]; - const __m128i* b_ptr[16]; - for (int64_t nn = 0; nn < nblk; nn++) { - const auto* bb = b + ldb * nn; - scale_v[nn] = MlasQ4BlkScale(bb); - b_ptr[nn] = (const __m128i*)MlasQ4BlkData(bb); - } - - for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT) { - size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT, ck - kk); - __m512 bvf_lo[16]; - __m512 bvf_hi[16]; - for (int64_t nn = 0; nn < nblk; nn++) { - const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); - __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); - bytes = _mm256_and_si256(lowMask, bytes); - - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - const auto* bb = b + ldb * nn; - const uint8_t zp = MlasQ4BlkZeroPoint(bb); - bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes = _mm256_sub_epi8(bytes, eight); - } - - // Convert to 16-bit int - const __m256i vx16_lo = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); - const __m256i vx16_hi = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); - - // Convert to 32-bit int -> float 32 - bvf_lo[nn] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); - bvf_hi[nn] = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); - const __m512 s = _mm512_set1_ps(scale_v[nn]); - bvf_lo[nn] = _mm512_mul_ps(bvf_lo[nn], s); - bvf_hi[nn] = _mm512_mul_ps(bvf_hi[nn], s); - } - for (int64_t nn = nblk; nn < 16; nn++) { - bvf_lo[nn] = _mm512_setzero_ps(); - bvf_hi[nn] = _mm512_setzero_ps(); - } - Transpose16x16Avx512( - FpData, _mm512_castps_si512(bvf_lo[0]), _mm512_castps_si512(bvf_lo[1]), - _mm512_castps_si512(bvf_lo[2]), _mm512_castps_si512(bvf_lo[3]), - _mm512_castps_si512(bvf_lo[4]), _mm512_castps_si512(bvf_lo[5]), - _mm512_castps_si512(bvf_lo[6]), _mm512_castps_si512(bvf_lo[7]), - _mm512_castps_si512(bvf_lo[8]), _mm512_castps_si512(bvf_lo[9]), - _mm512_castps_si512(bvf_lo[10]), _mm512_castps_si512(bvf_lo[11]), - _mm512_castps_si512(bvf_lo[12]), _mm512_castps_si512(bvf_lo[13]), - _mm512_castps_si512(bvf_lo[14]), _mm512_castps_si512(bvf_lo[15])); - if (kklen > 16) { - Transpose16x16Avx512( - FpData + 16 * 16, _mm512_castps_si512(bvf_hi[0]), - _mm512_castps_si512(bvf_hi[1]), _mm512_castps_si512(bvf_hi[2]), - _mm512_castps_si512(bvf_hi[3]), _mm512_castps_si512(bvf_hi[4]), - _mm512_castps_si512(bvf_hi[5]), _mm512_castps_si512(bvf_hi[6]), - _mm512_castps_si512(bvf_hi[7]), _mm512_castps_si512(bvf_hi[8]), - _mm512_castps_si512(bvf_hi[9]), _mm512_castps_si512(bvf_hi[10]), - _mm512_castps_si512(bvf_hi[11]), _mm512_castps_si512(bvf_hi[12]), - _mm512_castps_si512(bvf_hi[13]), _mm512_castps_si512(bvf_hi[14]), - _mm512_castps_si512(bvf_hi[15])); - } - FpData += 16 * kklen; - } - b += Q4Type::BlobSize; - } - } -} - - -template<> -MLAS_FORCEINLINE void -MlasBlkQ4DequantB( - float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) -{ - BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); -} - -template <> -MLAS_FORCEINLINE void -MlasBlkQ4DequantB( - float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) -{ - BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); -} - -template <> -MLAS_FORCEINLINE void -MlasBlkQ4DequantB( - float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) -{ - 
BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); -} - -template <> -MLAS_FORCEINLINE void -MlasBlkQ4DequantB( - float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) -{ - BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); -} - -/** - * @brief For testing purpose, - * Dequantize the data intp fp32, and then pack them for use - * in sgemm kernel. equivalent to MlasQ4GemmUnPackB and then - * MlasSgemmCopyPackB - * @param QType - * @param FpData - * @param PackedB - * @param CountN - * @param CountK - * @param ldb - */ -void -MlasBlkQ4DequantSgemmPackB( - MLAS_BLK_QUANT_TYPE QType, - float* FpData, - const uint8_t* PackedB, - size_t CountN, - size_t CountK, - size_t ldb) -{ - switch (QType) { - case BlkQ4Zp8: - return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); - case BlkQ4Sym64: - return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); - case BlkQ4Sym128: - return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); - default: - return BlkQ4DequantBAvx512f(FpData, PackedB, CountN, CountK, ldb); - } -} - -template<> -MLAS_FORCEINLINE -void -AddBiasAvx( - const float* Bias, - float* C, - size_t CountM, - size_t CountN, - size_t ldc - ) -{ - for (size_t m = 0; m < CountM; m++) { - const float* bias = Bias; - float* sum = C; - for (size_t n = 0; n < CountN; n += 4) { - if (CountN - n < 4) { - for (size_t nn = n; nn < CountN; nn++) { - *sum += *bias; - sum++; - bias++; - } - break; - } - - __m128 acc_x = _mm_loadu_ps(sum); - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias)); - _mm_storeu_ps(sum, acc_x); - bias += 4; - sum += 4; - } - C += ldc; - } -} - - -static MLAS_Q4GEMM_OPERATION* Q4Operations_avx512vnni[] = { - MlasQ4GemmOperation, - MlasQ4GemmOperation, - MlasQ4GemmOperation, - nullptr, - MlasQ4GemmOperation -}; - -const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512 = { - Q4Operations_avx512vnni -}; - - -//////////////////////////////////////////////////////////// -// Block int8 quantization, currently we only -// implement symmetric quant, with no zero-point - -template -MLAS_FORCEINLINE void -MlasQ80BlkQuantRow(const float* A, void* Qblob, size_t size) -{ - static_assert(QType::BlkLen % 16 == 0); - const __m512 signBit = _mm512_set1_ps(-0.0f); - int8_t* blob = reinterpret_cast(Qblob); - for (size_t k = 0; k < size; k += QType::BlkLen) { - const size_t step = std::min(QType::BlkLen, size - k); - - __m512 maxAbs = _mm512_setzero_ps(); - for (size_t kk = 0; kk < step; kk += 16) { - const size_t klen = std::min(size_t(16), step - kk); - - uint32_t mask = 0xffff >> (16 - klen); - __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - - // Compute max(abs(e)) for the block - maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0)); - } - - __m256 max8 = - _mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0)); - __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8)); - max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); - max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); - const float maxScalar = _mm_cvtss_f32(max4); - - // Quantize these floats - const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); - - const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; - const __m512 mul = _mm512_set1_ps(inverse_scale); - __m128i* dst = reinterpret_cast<__m128i*>(blob); - - for (size_t kk = 0; kk < step; kk += 16) { - const size_t klen = std::min(size_t(16), step - kk); - - uint32_t mask = 0xffff >> (16 - klen); - __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - v0 = _mm512_mul_ps(v0, mul); - - // Round to nearest integer - v0 = _mm512_roundscale_ps(v0, _MM_ROUND_NEAREST); - - // Convert floats to integers - __m512i i0 = _mm512_cvtps_epi32(v0); - - // Convert int32 to int8 - _mm_storeu_si128(dst++, _mm512_cvtepi32_epi8(i0)); - } - if (step < QType::BlkLen) { - memset(blob + step, 0, QType::BlkLen - step); - } - blob += QType::BlkLen; - } -} - -template -void -Q80BlkQuant(void* Qblob, const float* A, size_t M, size_t K, size_t lda, MLAS_THREADPOOL* ThreadPool) -{ - const size_t parts = (size_t)ceil(double(M) * K / (16.0 * 1024)); - const size_t TargetThreadCnt = - std::max(std::min(parts, (size_t)MlasGetMaximumThreadCount(ThreadPool)), (size_t)1); - const size_t linesize = MlasQ80BlkQuantSizeImpl(1, K); - - size_t M_stride = MlasDivRoundup(M, TargetThreadCnt); - size_t threads = MlasDivRoundup(M, M_stride); - MlasTrySimpleParallel(ThreadPool, threads, [&](ptrdiff_t tid) { - const size_t m = tid * M_stride; - const float* src = A + lda * m; - uint8_t* dst = reinterpret_cast(Qblob) + m * linesize; - for (size_t i = 0; i < std::min(M_stride, M-m); i++) { - MlasQ80BlkQuantRow(src, dst, K); - src += lda; - dst += linesize; - } - }); -} - -static MLAS_Q80_BLKQUANT* Q80Quant_avx512vnni[] = { - Q80BlkQuant, - Q80BlkQuant, - Q80BlkQuant, - nullptr, - Q80BlkQuant -}; - - -static -MLAS_FORCEINLINE -__m128 -FoldAccumulators( - const __m256& acc0, - const __m256& acc1, - const __m256& acc2, - const __m256& acc3 - ) -{ - __m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1); - __m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1); - __m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3); - __m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3); - - __m256 acc_lo0123 = _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23))); - __m256 acc_hi0123 = _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23))); - acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23))); - acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23))); - acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); - - return _mm_add_ps(_mm256_extractf32x4_ps(acc_lo0123, 0), _mm256_extractf32x4_ps(acc_lo0123, 1)); -} - -static inline float -mm256_reduce_add_ps(__m256& x) -{ - /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ - const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); - /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ - const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); - /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ - const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); - /* Conversion to float is a no-op on x86-64 */ - return _mm_cvtss_f32(x32); -} - - -template -MLAS_FORCEINLINE -size_t -MlasQ8Q4GemmKernelAvx512f( - const int8_t* QuantA, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - // We process 32 quantized values in a batch. 
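A scalar counterpart of the AVX512 row quantizer above, assuming the same symmetric scheme: per block, an fp32 scale of maxAbs / 127 followed by BlkLen rounded int8 values, with a short final block zero padded. Q80QuantRowRef is an illustrative name for this sketch, not an MLAS entry point.

// Scalar sketch of symmetric int8 (Q8_0-style) block quantization of one row.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

void Q80QuantRowRef(const float* a, int8_t* blob, size_t size, size_t blkLen) {
    for (size_t k = 0; k < size; k += blkLen) {
        const size_t step = std::min(blkLen, size - k);

        float maxAbs = 0.0f;
        for (size_t i = 0; i < step; i++) maxAbs = std::max(maxAbs, std::fabs(a[k + i]));

        const float scale = maxAbs / 127.0f;
        std::memcpy(blob, &scale, sizeof(float));   // per-block fp32 scale
        blob += sizeof(float);

        const float invScale = (maxAbs != 0.0f) ? 127.0f / maxAbs : 0.0f;
        for (size_t i = 0; i < step; i++) {
            // Round to nearest under the default rounding mode; values stay in [-127, 127].
            blob[i] = (int8_t)std::lrintf(a[k + i] * invScale);
        }
        if (step < blkLen) std::memset(blob + step, 0, blkLen - step);  // zero pad the tail
        blob += blkLen;
    }
}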
- static_assert(MLAS_QUANT4_BLK_UNIT == 32); - static_assert(Q4Type::BlkLen % MLAS_QUANT4_BLK_UNIT == 0); - - const __m256i zero = _mm256_setzero_si256(); - const __m256i lowMask = _mm256_set1_epi8(0xF); - - for (size_t m = 0; m < CountM; m++) { - const uint8_t* b_col = PackedB; - auto* sum_ptr = C; - auto* bias_ptr = Bias; - - int64_t nblk = (int64_t)(CountN) - 4; - while (nblk >= 0) { - __m256 acc_lo0 = _mm256_setzero_ps(); - __m256 acc_lo1 = _mm256_setzero_ps(); - __m256 acc_lo2 = _mm256_setzero_ps(); - __m256 acc_lo3 = _mm256_setzero_ps(); - const int8_t* ablob = QuantA; - const auto* b = b_col; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - const float a_scale = *reinterpret_cast(ablob); - ablob += sizeof(float); - const float scale_v0 = MlasQ4BlkScale(b) * a_scale; - const float scale_v1 = MlasQ4BlkScale(b + ldb) * a_scale; - const float scale_v2 = MlasQ4BlkScale(b + ldb * 2) * a_scale; - const float scale_v3 = MlasQ4BlkScale(b + ldb * 3) * a_scale; - - const __m128i* b0ptr = (const __m128i*)MlasQ4BlkData(b); - const __m128i* b1ptr = (const __m128i*)MlasQ4BlkData(b + ldb); - const __m128i* b2ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 2); - const __m128i* b3ptr = (const __m128i*)MlasQ4BlkData(b + ldb * 3); - - for (size_t kk = 0; kk < Q4Type::BlkLen; kk += MLAS_QUANT4_BLK_UNIT) { - // Load A row vector - const __m256i a_bytes = _mm256_loadu_si256((const __m256i*)ablob); - ablob += MLAS_QUANT4_BLK_UNIT; - - // Load 4 B column vectors (quantized to int4 blobs) - const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); - const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); - const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); - const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); - - // expand 4b into byte array - __m256i bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); - __m256i bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); - __m256i bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); - __m256i bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); - bytes0 = _mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - - // Subtract zero-point from the integers - if constexpr (std::is_same_v) { - bytes0 = _mm256_sub_epi8( - bytes0, _mm256_set1_epi8(MlasQ4BlkZeroPoint(b))); - bytes1 = _mm256_sub_epi8( - bytes1, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb))); - bytes2 = _mm256_sub_epi8( - bytes2, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 2))); - bytes3 = _mm256_sub_epi8( - bytes3, - _mm256_set1_epi8(MlasQ4BlkZeroPoint(b + ldb * 3))); - } else { - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = _mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // to use vnni unsigned x signed int, negate all negative - // b vals to make it all positive, and then also negate the - // corresponding a vals to compensate - const __m256i summed_pairs0 = _mm256_dpbusd_epi32( - zero, _mm256_sign_epi8(bytes0, bytes0), _mm256_sign_epi8(a_bytes, bytes0)); - const __m256i summed_pairs1 = _mm256_dpbusd_epi32( - zero, _mm256_sign_epi8(bytes1, bytes1), _mm256_sign_epi8(a_bytes, bytes1)); - const __m256i summed_pairs2 = _mm256_dpbusd_epi32( - zero, _mm256_sign_epi8(bytes2, bytes2), _mm256_sign_epi8(a_bytes, bytes2)); - const __m256i summed_pairs3 = _mm256_dpbusd_epi32( - zero, _mm256_sign_epi8(bytes3, bytes3), _mm256_sign_epi8(a_bytes, 
bytes3)); - - const __m256 sums0 = _mm256_cvtepi32_ps(summed_pairs0); - const __m256 sums1 = _mm256_cvtepi32_ps(summed_pairs1); - const __m256 sums2 = _mm256_cvtepi32_ps(summed_pairs2); - const __m256 sums3 = _mm256_cvtepi32_ps(summed_pairs3); - acc_lo0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sums0, acc_lo0); - acc_lo1 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v1), sums1, acc_lo1); - acc_lo2 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v2), sums2, acc_lo2); - acc_lo3 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v3), sums3, acc_lo3); - } - b += Q4Type::BlobSize; - } - - __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); - if (Bias != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); - } - _mm_storeu_ps(sum_ptr, acc_x); - - // move to next 4 columns - b_col += 4 * ldb; - sum_ptr += 4; - bias_ptr += 4; - nblk -= 4; - } - - // left over columns less than 4 ? - nblk += 4; - if (nblk > 0) { - __m256 acc_lo[4]{}; - const int8_t* ablob = QuantA; - const auto* b = b_col; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - const float a_scale = *reinterpret_cast(ablob); - ablob += sizeof(float); - - float scale_v[4]; - const __m128i* b_ptr[4]; - for (int64_t nn = 0; nn < nblk; nn++) { - const auto* bb = b + ldb * nn; - scale_v[nn] = MlasQ4BlkScale(bb) * a_scale; - b_ptr[nn] = (const __m128i*)MlasQ4BlkData(bb); - } - - for (size_t kk = 0; kk < Q4Type::BlkLen; kk += MLAS_QUANT4_BLK_UNIT) { - const __m256i a_bytes = _mm256_loadu_si256((const __m256i*)ablob); - ablob += MLAS_QUANT4_BLK_UNIT; - - for (int64_t nn = 0; nn < nblk; nn++) { - const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); - __m256i b_bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); - b_bytes = _mm256_and_si256(lowMask, b_bytes); - - if constexpr (std::is_same_v) { - // Subtract zero-point from the integers - const auto* bb = b + ldb * nn; - const uint8_t zp = MlasQ4BlkZeroPoint(bb); - b_bytes = _mm256_sub_epi8(b_bytes, _mm256_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - b_bytes = _mm256_sub_epi8(b_bytes, eight); - } - - // to use vnni unsigned x signed int, negate all negative - // b vals to make it all positive, - const __m256i ax = _mm256_sign_epi8(b_bytes, b_bytes); - // and then also negate the corresponding a vals to compensate - const __m256i sy = _mm256_sign_epi8(a_bytes, b_bytes); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - const __m256 sum = _mm256_cvtepi32_ps(summed_pairs); - acc_lo[nn] = _mm256_fmadd_ps(_mm256_set1_ps(scale_v[nn]), sum, acc_lo[nn]); - } - } - b += Q4Type::BlobSize; - } - - for (int64_t nn = 0; nn < nblk; nn++) { - sum_ptr[nn] = mm256_reduce_add_ps(acc_lo[nn]); - sum_ptr[nn] += Bias == nullptr ? 
0.0f : bias_ptr[nn]; - } - } - - // Prepare pointers for the next row - C += ldc; - QuantA += lda; - } - return CountM; -} - - -template<> -MLAS_FORCEINLINE -size_t -MlasQ8Q4GemmKernel( - const int8_t* QuantA, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, - lda, ldb, ldc, Bias); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ8Q4GemmKernel( - const int8_t* QuantA, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, - lda, ldb, ldc, Bias); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ8Q4GemmKernel( - const int8_t* QuantA, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, - lda, ldb, ldc, Bias); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasQ8Q4GemmKernel( - const int8_t* QuantA, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias - ) -{ - return MlasQ8Q4GemmKernelAvx512f(QuantA, PackedB, C, CountM, CountN, CountK, - lda, ldb, ldc, Bias); -} - - -static MLAS_Q8Q4GEMM_OPERATION* Q8Q4Operations_avx512vnni[] = { - MlasQ8Q4GemmOperation, - MlasQ8Q4GemmOperation, - MlasQ8Q4GemmOperation, - nullptr, - MlasQ8Q4GemmOperation -}; - - -const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni = { - Q80Quant_avx512vnni, - Q8Q4Operations_avx512vnni -}; diff --git a/onnxruntime/core/mlas/lib/qdwconv.cpp b/onnxruntime/core/mlas/lib/qdwconv.cpp deleted file mode 100644 index 59f6877f70d56..0000000000000 --- a/onnxruntime/core/mlas/lib/qdwconv.cpp +++ /dev/null @@ -1,377 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qdwconv.cpp - -Abstract: - - This module implements the quantized integer depthwise convolution routines. - ---*/ - -#include "mlasi.h" - -template -void -MLASCALL -MlasConvDepthwiseKernel( - const InputType* const* Input, - InputType InputZeroPoint, - const FilterType* Filter, - FilterType FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -{ - // - // TODO Modify MlasConvDepthwiseGetKernelOutputCnt() function if this kernel - // is further optimized. 
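An aside on the Q8Q4 kernels above: _mm256_dpbusd_epi32 multiplies unsigned bytes by signed bytes, yet both the quantized A values and the zero-point-adjusted B values are signed. The _mm256_sign_epi8 pair makes the B operand non-negative and folds B's sign into A, which preserves every per-byte product. A minimal scalar sketch of that identity follows; it is illustration only, not part of the deleted sources, and a = -128 is skipped because byte negation wraps there exactly as the hardware instruction does.

    #include <cassert>
    #include <cstdint>

    // Scalar model of _mm256_sign_epi8(x, y): negate x where y < 0, zero x where y == 0.
    static int8_t sign8(int8_t x, int8_t y) {
        return y < 0 ? int8_t(-x) : (y == 0 ? int8_t(0) : x);
    }

    int main() {
        for (int a = -127; a <= 127; ++a) {        // quantized A byte
            for (int b = -8; b <= 7; ++b) {        // 4-bit B value after the zero-point shift
                uint8_t ub = uint8_t(sign8(int8_t(b), int8_t(b)));  // |b|, safe to treat as unsigned
                int8_t  sa = sign8(int8_t(a), int8_t(b));           // a with b's sign folded in
                assert(int32_t(ub) * int32_t(sa) == a * b);         // dpbusd sees ub (unsigned) * sa (signed)
            }
        }
        return 0;
    }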
- // -#if defined(MLAS_SSE2_INTRINSICS) - const __m128i ZeroVector = _mm_setzero_si128(); - const __m128i InputZeroPointVector = _mm_set1_epi16(InputZeroPoint); - const __m128i FilterZeroPointVector = _mm_set1_epi16(FilterZeroPoint); -#elif defined(MLAS_NEON_INTRINSICS) - const uint8x8_t InputZeroPointVector = vdup_n_u8(uint8_t(InputZeroPoint)); - const uint8x8_t FilterZeroPointVector = vdup_n_u8(uint8_t(FilterZeroPoint)); -#elif defined(MLAS_LSX_INTRINSICS) - const __m128i ZeroVector = __lsx_vldi(0); - const __m128i InputZeroPointVector = __lsx_vreplgr2vr_h(InputZeroPoint); - const __m128i FilterZeroPointVector = __lsx_vreplgr2vr_h(FilterZeroPoint); -#endif - - while (OutputCount > 0) { - size_t ChannelOffset = 0; - size_t c = Channels; - -#if defined(MLAS_SSE2_INTRINSICS) - - while (c >= 8) { - __m128i Accumulator0 = _mm_setzero_si128(); - __m128i Accumulator1 = _mm_setzero_si128(); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - __m128i InputVector = _mm_loadl_epi64((const __m128i*)&Input[k][ChannelOffset]); - __m128i FilterVector = - _mm_loadl_epi64((const __m128i*)&Filter[ChannelKernelOffset]); - - if (std::is_signed::value) { - InputVector = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, InputVector), 8); - } else { - InputVector = _mm_unpacklo_epi8(InputVector, ZeroVector); - } - - if (std::is_signed::value) { - FilterVector = _mm_srai_epi16(_mm_unpacklo_epi8(ZeroVector, FilterVector), 8); - } else { - FilterVector = _mm_unpacklo_epi8(FilterVector, ZeroVector); - } - - InputVector = _mm_sub_epi16(InputVector, InputZeroPointVector); - FilterVector = _mm_sub_epi16(FilterVector, FilterZeroPointVector); - - // N.B. Emulate PMULLD functionality on SSE2 by computing the low - // and high parts of the result and interleaving the results. 
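To make the N.B. above concrete: for signed 16-bit lanes, the full 32-bit product can be rebuilt from the two halves that the low/high multiplies produce, and the unpack instructions simply pair those halves back up. A per-lane scalar sketch (editorial illustration, not part of the deleted file):

    #include <cstdint>

    // Per-lane view of the _mm_mullo_epi16 / _mm_mulhi_epi16 / unpack sequence.
    int32_t widening_multiply(int16_t a, int16_t b) {
        int32_t  p  = int32_t(a) * int32_t(b);
        uint16_t lo = uint16_t(p);                    // what _mm_mullo_epi16 keeps
        uint16_t hi = uint16_t(uint32_t(p) >> 16);    // what _mm_mulhi_epi16 keeps (sign bits included)
        return int32_t((uint32_t(hi) << 16) | lo);    // what the unpacklo/unpackhi interleave reassembles
    }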
- __m128i MultiplyLowWords = _mm_mullo_epi16(InputVector, FilterVector); - __m128i MultiplyHighWords = _mm_mulhi_epi16(InputVector, FilterVector); - __m128i Multiply0 = _mm_unpacklo_epi16(MultiplyLowWords, MultiplyHighWords); - __m128i Multiply1 = _mm_unpackhi_epi16(MultiplyLowWords, MultiplyHighWords); - - Accumulator0 = _mm_add_epi32(Accumulator0, Multiply0); - Accumulator1 = _mm_add_epi32(Accumulator1, Multiply1); - ChannelKernelOffset += Channels; - } - - _mm_storeu_si128((__m128i*)&Output[0], Accumulator0); - _mm_storeu_si128((__m128i*)&Output[4], Accumulator1); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } - -#elif defined(MLAS_NEON_INTRINSICS) - - while (c >= 8) { - int32x4_t Accumulator0 = vdupq_n_s32(0); - int32x4_t Accumulator1 = vdupq_n_s32(0); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - uint8x8_t InputVector = - vld1_u8(reinterpret_cast(&Input[k][ChannelOffset])); - uint8x8_t FilterVector = - vld1_u8(reinterpret_cast(&Filter[ChannelKernelOffset])); - - int16x8_t InputVector16; - if (std::is_signed::value) { - InputVector16 = vsubl_s8(vreinterpret_s8_u8(InputVector), - vreinterpret_s8_u8(InputZeroPointVector)); - } else { - InputVector16 = - vreinterpretq_s16_u16(vsubl_u8(InputVector, InputZeroPointVector)); - } - - int16x8_t FilterVector16; - if (std::is_signed::value) { - FilterVector16 = vsubl_s8(vreinterpret_s8_u8(FilterVector), - vreinterpret_s8_u8(FilterZeroPointVector)); - } else { - FilterVector16 = - vreinterpretq_s16_u16(vsubl_u8(FilterVector, FilterZeroPointVector)); - } - - Accumulator0 = vmlal_s16(Accumulator0, vget_low_s16(InputVector16), - vget_low_s16(FilterVector16)); -#if defined(MLAS_NEON64_INTRINSICS) - Accumulator1 = vmlal_high_s16(Accumulator1, InputVector16, FilterVector16); -#else - Accumulator1 = vmlal_s16(Accumulator1, vget_high_s16(InputVector16), - vget_high_s16(FilterVector16)); -#endif - - ChannelKernelOffset += Channels; - } - - vst1q_s32(&Output[0], Accumulator0); - vst1q_s32(&Output[4], Accumulator1); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } -#elif defined(MLAS_LSX_INTRINSICS) - - while (c >= 8) { - __m128i Accumulator0 = __lsx_vldi(0); - __m128i Accumulator1 = __lsx_vldi(0); - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - __m128i InputVector = __lsx_vld((const __m128i*)&Input[k][ChannelOffset], 0); - __lsx_vinsgr2vr_d(InputVector, 0, 1); - __m128i FilterVector = - __lsx_vld((const __m128i*)&Filter[ChannelKernelOffset], 0); - __lsx_vinsgr2vr_d(FilterVector, 0, 1); - - if (std::is_signed::value) { - InputVector = __lsx_vsrai_h(__lsx_vilvl_b(InputVector, ZeroVector), 8); - } else { - InputVector = __lsx_vilvl_b(ZeroVector, InputVector ); - } - - if (std::is_signed::value) { - FilterVector = __lsx_vsrai_h(__lsx_vilvl_b(FilterVector, ZeroVector), 8); - } else { - FilterVector = __lsx_vilvl_b(ZeroVector, FilterVector); - } - - InputVector = __lsx_vsub_h(InputVector, InputZeroPointVector); - FilterVector = __lsx_vsub_h(FilterVector, FilterZeroPointVector); - - // N.B. Emulate PMULLD functionality on LSX by computing the low - // and high parts of the result and interleaving the results. 
- __m128i MultiplyLowWords = __lsx_vmul_h(InputVector, FilterVector); - __m128i MultiplyHighWords = __lsx_vmuh_h(InputVector, FilterVector); - __m128i Multiply0 = __lsx_vilvl_h(MultiplyHighWords, MultiplyLowWords); - __m128i Multiply1 = __lsx_vilvh_h(MultiplyHighWords, MultiplyLowWords); - - Accumulator0 = __lsx_vadd_w(Accumulator0, Multiply0); - Accumulator1 = __lsx_vadd_w(Accumulator1, Multiply1); - ChannelKernelOffset += Channels; - } - - __lsx_vst(Accumulator0, (__m128i*)&Output[0], 0); - __lsx_vst(Accumulator1, (__m128i*)&Output[4], 0); - Output += 8; - - ChannelOffset += 8; - c -= 8; - } - -#endif - - while (c > 0) { - int32_t Accumulator = 0; - size_t ChannelKernelOffset = ChannelOffset; - - for (size_t k = 0; k < KernelSize; k++) { - int32_t InputValue = int32_t(Input[k][ChannelOffset]) - InputZeroPoint; - int32_t FilterValue = int32_t(Filter[ChannelKernelOffset]) - FilterZeroPoint; - - Accumulator += InputValue * FilterValue; - ChannelKernelOffset += Channels; - } - - *Output++ = Accumulator; - - ChannelOffset += 1; - c -= 1; - } - - Input += KernelSize; - OutputCount -= 1; - } -} - -template -void -MLASCALL -MlasConvDepthwiseKernel( - const uint8_t* const* Input, - uint8_t InputZeroPoint, - const int8_t* Filter, - int8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernel( - const uint8_t* const* Input, - uint8_t InputZeroPoint, - const uint8_t* Filter, - uint8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernel( - const int8_t* const* Input, - int8_t InputZeroPoint, - const int8_t* Filter, - int8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -template -void -MLASCALL -MlasConvDepthwiseKernel( - const int8_t* const* Input, - int8_t InputZeroPoint, - const uint8_t* Filter, - uint8_t FilterZeroPoint, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ); - -void -MLASCALL -MlasConvDepthwise( - const void* const* Input, - int32_t InputZeroPoint, - bool InputIsSigned, - const void* Filter, - int32_t FilterZeroPoint, - bool FilterIsSigned, - int32_t* Output, - size_t Channels, - size_t OutputCount, - size_t KernelSize - ) -/*++ - -Routine Description: - - This routine implements the depthwise convolution operation. - - The input is supplied as an indirection buffer. Every pointer in the - indirection buffer points at a Channels length vector (either from the - input tensor or a vector of padding values). These are grouped in batches - of length KernelSize that are processed by the kernel to produce a single - output of length Channels. These batches are then repeated OutputCount - times. - - The filter tensor is organized in HW1O format, so the length of each row of - the filter tensor is Channels. The number of columns of the filter tensor - is KernelSize. - -Arguments: - - Input - Supplies an indirection buffer to the elements of the input tensor. - - InputZeroPoint - Supplies the zero point offset of the input tensor. - - InputIsSigned - Supplies true if the input tensor is signed data, else - false if the input tensor is unsigned data. - - Filter - Supplies the filter tensor. - - FilterZeroPoint - Supplies the zero point offset of the filter tensor. - - FilterIsSigned - Supplies true if the filter tensor is signed data, else - false if the filter tensor is unsigned data. 
- - Output - Supplies the output tensor in channels last format. - - Channels - Supplies the number of channels. - - OutputCount - Supplies the number of channel sized output elements to - produce. - - KernelSize - Supplies the total number of channel sized kernel elements to - consume. - -Return Value: - - None. - ---*/ -{ - if (InputIsSigned) { - if (FilterIsSigned) { - - GetMlasPlatform().ConvDepthwiseS8S8Kernel( - reinterpret_cast(Input), static_cast(InputZeroPoint), - reinterpret_cast(Filter), static_cast(FilterZeroPoint), - Output, Channels, OutputCount, KernelSize - ); - } else { - - GetMlasPlatform().ConvDepthwiseS8U8Kernel( - reinterpret_cast(Input), static_cast(InputZeroPoint), - reinterpret_cast(Filter), static_cast(FilterZeroPoint), - Output, Channels, OutputCount, KernelSize - ); - } - } else { - if (FilterIsSigned) { - - GetMlasPlatform().ConvDepthwiseU8S8Kernel( - reinterpret_cast(Input), static_cast(InputZeroPoint), - reinterpret_cast(Filter), static_cast(FilterZeroPoint), - Output, Channels, OutputCount, KernelSize - ); - } else { - - GetMlasPlatform().ConvDepthwiseU8U8Kernel( - reinterpret_cast(Input), static_cast(InputZeroPoint), - reinterpret_cast(Filter), static_cast(FilterZeroPoint), - Output, Channels, OutputCount, KernelSize - ); - } - } -} diff --git a/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp b/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp deleted file mode 100644 index 4985f91b64f36..0000000000000 --- a/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp +++ /dev/null @@ -1,621 +0,0 @@ -/*Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -Copyright 2019 Google LLC - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - * Neither the name Facebook nor the names of its contributors may be used to - endorse or promote products derived from this software without specific - prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -==============================================================================*/ -/* Modifications Copyright (c) Microsoft. */ - - -/*++ -Module Name: - - qdwconv_kernelsize.cpp - -Abstract: - - This module implements kernel of the quantized integer depthwise convolution with kernalsize 25. 
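The indirection-buffer contract described for MlasConvDepthwise above is easiest to see with a tiny stand-alone example. Everything below is illustrative only (the sizes, names, and 1-D layout are invented, not MLAS code): it builds OutputCount * KernelSize pointers, each aimed at either a Channels-long slice of real input or at a shared padding vector filled with the input zero point.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
        constexpr size_t Channels = 4, KernelSize = 3, InputWidth = 5;
        const uint8_t InputZeroPoint = 128;

        std::vector<uint8_t> input(InputWidth * Channels, 200);   // fake channels-last input, one row
        std::vector<uint8_t> padding(Channels, InputZeroPoint);   // shared "virtual" padding pixel

        // One pointer per (output pixel, kernel tap); "same" padding keeps OutputCount == InputWidth.
        std::vector<const uint8_t*> indirection;
        for (size_t out = 0; out < InputWidth; ++out) {
            for (size_t k = 0; k < KernelSize; ++k) {
                const ptrdiff_t in = ptrdiff_t(out) + ptrdiff_t(k) - 1;   // centred 3-tap window
                indirection.push_back((in < 0 || in >= ptrdiff_t(InputWidth))
                                          ? padding.data()
                                          : input.data() + size_t(in) * Channels);
            }
        }
        // indirection.size() == OutputCount * KernelSize; each batch of KernelSize
        // entries yields one Channels-long output vector when handed to the kernel.
        return 0;
    }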
- ---*/ - -#include "mlasi.h" - -extern "C" { - -#if defined(MLAS_TARGET_ARM64) - -void -MLASCALL -MlasConvSymDepthwiseKernelSize25ArmU8S8( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ) -{ - uint8_t const* const* IndirectBuf = (uint8_t const* const*)InputIndirection; - uint8_t* OutBuf = (uint8_t*)Output; - const uint8x16_t vu128 = vdupq_n_u8(128); - const int16x8_t voutput_zero_point = vld1q_dup_s16((int16_t const*)&PostProcessParams->OutputZeroPoint); - float32x4_t vscale_0123, vscale_4567, vscale_89AB, vscale_CDEF; - const bool is_per_channel = ((KernelFlags & MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE) != 0); - // Init them anyway due to some compiler will generate uninitialized warnings. - vscale_0123 = vscale_4567 = vscale_89AB = vscale_CDEF = vld1q_dup_f32(PostProcessParams->Scale); - while (OutputCount-- > 0) { - const uint8_t* i00 = IndirectBuf[0]; - const uint8_t* i01 = IndirectBuf[1]; - const uint8_t* i02 = IndirectBuf[2]; - const uint8_t* i03 = IndirectBuf[3]; - const uint8_t* i04 = IndirectBuf[4]; - const uint8_t* i05 = IndirectBuf[5]; - const uint8_t* i06 = IndirectBuf[6]; - const uint8_t* i07 = IndirectBuf[7]; - const uint8_t* i08 = IndirectBuf[8]; - const uint8_t* i09 = IndirectBuf[9]; - - const uint8_t* i10 = IndirectBuf[10]; - const uint8_t* i11 = IndirectBuf[11]; - const uint8_t* i12 = IndirectBuf[12]; - const uint8_t* i13 = IndirectBuf[13]; - const uint8_t* i14 = IndirectBuf[14]; - const uint8_t* i15 = IndirectBuf[15]; - const uint8_t* i16 = IndirectBuf[16]; - const uint8_t* i17 = IndirectBuf[17]; - const uint8_t* i18 = IndirectBuf[18]; - const uint8_t* i19 = IndirectBuf[19]; - - const uint8_t* i20 = IndirectBuf[20]; - const uint8_t* i21 = IndirectBuf[21]; - const uint8_t* i22 = IndirectBuf[22]; - const uint8_t* i23 = IndirectBuf[23]; - const uint8_t* i24 = IndirectBuf[24]; - - IndirectBuf += 25; - int32_t const* bias = PostProcessParams->Bias; - float const* scale = PostProcessParams->Scale; - for (size_t c = 0; c < Channels; c += 16) { - int8_t const* w = Filter + c; - int32x4_t vacc_0123 = vld1q_s32(bias); bias += 4; - int32x4_t vacc_4567 = vld1q_s32(bias); bias += 4; - int32x4_t vacc_89AB = vld1q_s32(bias); bias += 4; - int32x4_t vacc_CDEF = vld1q_s32(bias); bias += 4; - - // kernel pixel 0, 1 - const int8x16_t vi00 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i00))); i00 += 16; - const int8x16_t vk00 = vld1q_s8(w); w += Channels; - int16x8_t vprod_01234567 = vmull_s8(vget_low_s8(vi00), vget_low_s8(vk00)); - int16x8_t vprod_89ABCDEF = vmull_s8(vget_high_s8(vi00), vget_high_s8(vk00)); - - const int8x16_t vi01 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i01))); i01 += 16; - const int8x16_t vk01 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi01), vget_low_s8(vk01)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi01), vget_high_s8(vk01)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 2, 3 - const int8x16_t vi02 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i02))); i02 += 16; - const int8x16_t vk02 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi02), vget_low_s8(vk02)); - vprod_89ABCDEF = 
vmull_s8(vget_high_s8(vi02), vget_high_s8(vk02)); - - const int8x16_t vi03 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i03))); i03 += 16; - const int8x16_t vk03 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi03), vget_low_s8(vk03)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi03), vget_high_s8(vk03)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 4, 5 - const int8x16_t vi04 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i04))); i04 += 16; - const int8x16_t vk04 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi04), vget_low_s8(vk04)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi04), vget_high_s8(vk04)); - - const int8x16_t vi05 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i05))); i05 += 16; - const int8x16_t vk05 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi05), vget_low_s8(vk05)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi05), vget_high_s8(vk05)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 6, 7 - const int8x16_t vi06 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i06))); i06 += 16; - const int8x16_t vk06 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi06), vget_low_s8(vk06)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi06), vget_high_s8(vk06)); - - const int8x16_t vi07 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i07))); i07 += 16; - const int8x16_t vk07 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi07), vget_low_s8(vk07)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi07), vget_high_s8(vk07)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 8, 9 - const int8x16_t vi08 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i08))); i08 += 16; - const int8x16_t vk08 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi08), vget_low_s8(vk08)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi08), vget_high_s8(vk08)); - - const int8x16_t vi09 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i09))); i09 += 16; - const int8x16_t vk09 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi09), vget_low_s8(vk09)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi09), vget_high_s8(vk09)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 10, 11 - const int8x16_t vi10 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i10))); i10 += 16; - const int8x16_t vk10 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi10), vget_low_s8(vk10)); - vprod_89ABCDEF = 
vmull_s8(vget_high_s8(vi10), vget_high_s8(vk10)); - - const int8x16_t vi11 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i11))); i11 += 16; - const int8x16_t vk11 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi11), vget_low_s8(vk11)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi11), vget_high_s8(vk11)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 12, 13 - const int8x16_t vi12 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i12))); i12 += 16; - const int8x16_t vk12 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi12), vget_low_s8(vk12)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi12), vget_high_s8(vk12)); - - const int8x16_t vi13 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i13))); i13 += 16; - const int8x16_t vk13 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi13), vget_low_s8(vk13)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi13), vget_high_s8(vk13)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 14, 15 - const int8x16_t vi14 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i14))); i14 += 16; - const int8x16_t vk14 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi14), vget_low_s8(vk14)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi14), vget_high_s8(vk14)); - - const int8x16_t vi15 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i15))); i15 += 16; - const int8x16_t vk15 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi15), vget_low_s8(vk15)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi15), vget_high_s8(vk15)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 16, 17 - const int8x16_t vi16 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i16))); i16 += 16; - const int8x16_t vk16 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi16), vget_low_s8(vk16)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi16), vget_high_s8(vk16)); - - const int8x16_t vi17 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i17))); i17 += 16; - const int8x16_t vk17 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi17), vget_low_s8(vk17)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi17), vget_high_s8(vk17)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 18, 19 - const int8x16_t vi18 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i18))); i18 += 16; - const int8x16_t vk18 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi18), vget_low_s8(vk18)); - vprod_89ABCDEF = 
vmull_s8(vget_high_s8(vi18), vget_high_s8(vk18)); - - const int8x16_t vi19 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i19))); i19 += 16; - const int8x16_t vk19 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi19), vget_low_s8(vk19)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi19), vget_high_s8(vk19)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 20, 21 - const int8x16_t vi20 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i20))); i20 += 16; - const int8x16_t vk20 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi20), vget_low_s8(vk20)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi20), vget_high_s8(vk20)); - - const int8x16_t vi21 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i21))); i21 += 16; - const int8x16_t vk21 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi21), vget_low_s8(vk21)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi21), vget_high_s8(vk21)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 22, 23 - const int8x16_t vi22 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i22))); i22 += 16; - const int8x16_t vk22 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi22), vget_low_s8(vk22)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi22), vget_high_s8(vk22)); - - const int8x16_t vi23 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i23))); i23 += 16; - const int8x16_t vk23 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi23), vget_low_s8(vk23)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi23), vget_high_s8(vk23)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 24 - const int8x16_t vi24 = vreinterpretq_s8_u8(veorq_u8(vu128, vld1q_u8(i24))); i24 += 16; - const int8x16_t vk24 = vld1q_s8(w); // w += Channels; no need to add - vprod_01234567 = vmull_s8(vget_low_s8(vi24), vget_low_s8(vk24)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi24), vget_high_s8(vk24)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - if (is_per_channel) { - vscale_0123 = vld1q_f32(scale); scale += 4; - vscale_4567 = vld1q_f32(scale); scale += 4; - vscale_89AB = vld1q_f32(scale); scale += 4; - vscale_CDEF = vld1q_f32(scale); scale += 4; - } - - // requantize - vacc_0123 = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_0123), vscale_0123)); - vacc_4567 = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_4567), vscale_4567)); - vacc_89AB = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_89AB), vscale_89AB)); - vacc_CDEF = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_CDEF), vscale_CDEF)); - - const 
int16x8_t vacc_01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_0123), vacc_4567), voutput_zero_point); - const int16x8_t vacc_89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_89AB), vacc_CDEF), voutput_zero_point); - uint8x16_t vout = vqmovun_high_s16(vqmovun_s16(vacc_01234567), vacc_89ABCDEF); - - vst1q_u8(OutBuf, vout); - OutBuf += 16; - } - } -} - -void -MLASCALL -MlasConvSymDepthwiseKernelSize25ArmS8S8( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ) -{ - int8_t const* const* IndirectBuf = (int8_t const* const*)InputIndirection; - int8_t* OutBuf = (int8_t*)Output; - const int16x8_t voutput_zero_point = - vld1q_dup_s16((int16_t const*)&PostProcessParams->OutputZeroPoint); - float32x4_t vscale_0123, vscale_4567, vscale_89AB, vscale_CDEF; - const bool is_per_channel = ((KernelFlags & MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE) != 0); - // Init them anyway due to some compiler will generate uninitialized warnings. - vscale_0123 = vscale_4567 = vscale_89AB = vscale_CDEF = vld1q_dup_f32(PostProcessParams->Scale); - while (OutputCount-- > 0) { - const int8_t* i00 = IndirectBuf[0]; - const int8_t* i01 = IndirectBuf[1]; - const int8_t* i02 = IndirectBuf[2]; - const int8_t* i03 = IndirectBuf[3]; - const int8_t* i04 = IndirectBuf[4]; - const int8_t* i05 = IndirectBuf[5]; - const int8_t* i06 = IndirectBuf[6]; - const int8_t* i07 = IndirectBuf[7]; - const int8_t* i08 = IndirectBuf[8]; - const int8_t* i09 = IndirectBuf[9]; - - const int8_t* i10 = IndirectBuf[10]; - const int8_t* i11 = IndirectBuf[11]; - const int8_t* i12 = IndirectBuf[12]; - const int8_t* i13 = IndirectBuf[13]; - const int8_t* i14 = IndirectBuf[14]; - const int8_t* i15 = IndirectBuf[15]; - const int8_t* i16 = IndirectBuf[16]; - const int8_t* i17 = IndirectBuf[17]; - const int8_t* i18 = IndirectBuf[18]; - const int8_t* i19 = IndirectBuf[19]; - - const int8_t* i20 = IndirectBuf[20]; - const int8_t* i21 = IndirectBuf[21]; - const int8_t* i22 = IndirectBuf[22]; - const int8_t* i23 = IndirectBuf[23]; - const int8_t* i24 = IndirectBuf[24]; - - IndirectBuf += 25; - int32_t const* bias = PostProcessParams->Bias; - float const* scale = PostProcessParams->Scale; - for (size_t c = 0; c < Channels; c += 16) { - int8_t const* w = Filter + c; - int32x4_t vacc_0123 = vld1q_s32(bias); bias += 4; - int32x4_t vacc_4567 = vld1q_s32(bias); bias += 4; - int32x4_t vacc_89AB = vld1q_s32(bias); bias += 4; - int32x4_t vacc_CDEF = vld1q_s32(bias); bias += 4; - - // kernel pixel 0, 1 - const int8x16_t vi00 = vld1q_s8(i00); i00 += 16; - const int8x16_t vk00 = vld1q_s8(w); w += Channels; - int16x8_t vprod_01234567 = vmull_s8(vget_low_s8(vi00), vget_low_s8(vk00)); - int16x8_t vprod_89ABCDEF = vmull_s8(vget_high_s8(vi00), vget_high_s8(vk00)); - - const int8x16_t vi01 = vld1q_s8(i01); i01 += 16; - const int8x16_t vk01 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi01), vget_low_s8(vk01)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi01), vget_high_s8(vk01)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 2, 3 - const int8x16_t vi02 = vld1q_s8(i02); i02 += 16; - const int8x16_t vk02 = vld1q_s8(w); w += Channels; 
- vprod_01234567 = vmull_s8(vget_low_s8(vi02), vget_low_s8(vk02)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi02), vget_high_s8(vk02)); - - const int8x16_t vi03 = vld1q_s8(i03); i03 += 16; - const int8x16_t vk03 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi03), vget_low_s8(vk03)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi03), vget_high_s8(vk03)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 4, 5 - const int8x16_t vi04 = vld1q_s8(i04); i04 += 16; - const int8x16_t vk04 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi04), vget_low_s8(vk04)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi04), vget_high_s8(vk04)); - - const int8x16_t vi05 = vld1q_s8(i05); i05 += 16; - const int8x16_t vk05 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi05), vget_low_s8(vk05)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi05), vget_high_s8(vk05)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 6, 7 - const int8x16_t vi06 = vld1q_s8(i06); i06 += 16; - const int8x16_t vk06 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi06), vget_low_s8(vk06)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi06), vget_high_s8(vk06)); - - const int8x16_t vi07 = vld1q_s8(i07); i07 += 16; - const int8x16_t vk07 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi07), vget_low_s8(vk07)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi07), vget_high_s8(vk07)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 8, 9 - const int8x16_t vi08 = vld1q_s8(i08); i08 += 16; - const int8x16_t vk08 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi08), vget_low_s8(vk08)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi08), vget_high_s8(vk08)); - - const int8x16_t vi09 = vld1q_s8(i09); i09 += 16; - const int8x16_t vk09 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi09), vget_low_s8(vk09)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi09), vget_high_s8(vk09)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 10, 11 - const int8x16_t vi10 = vld1q_s8(i10); i10 += 16; - const int8x16_t vk10 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi10), vget_low_s8(vk10)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi10), vget_high_s8(vk10)); - - const int8x16_t vi11 = vld1q_s8(i11); i11 += 16; - const int8x16_t vk11 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi11), vget_low_s8(vk11)); 
- vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi11), vget_high_s8(vk11)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 12, 13 - const int8x16_t vi12 = vld1q_s8(i12); i12 += 16; - const int8x16_t vk12 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi12), vget_low_s8(vk12)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi12), vget_high_s8(vk12)); - - const int8x16_t vi13 = vld1q_s8(i13); i13 += 16; - const int8x16_t vk13 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi13), vget_low_s8(vk13)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi13), vget_high_s8(vk13)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 14, 15 - const int8x16_t vi14 = vld1q_s8(i14); i14 += 16; - const int8x16_t vk14 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi14), vget_low_s8(vk14)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi14), vget_high_s8(vk14)); - - const int8x16_t vi15 = vld1q_s8(i15); i15 += 16; - const int8x16_t vk15 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi15), vget_low_s8(vk15)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi15), vget_high_s8(vk15)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 16, 17 - const int8x16_t vi16 = vld1q_s8(i16); i16 += 16; - const int8x16_t vk16 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi16), vget_low_s8(vk16)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi16), vget_high_s8(vk16)); - - const int8x16_t vi17 = vld1q_s8(i17); i17 += 16; - const int8x16_t vk17 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi17), vget_low_s8(vk17)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi17), vget_high_s8(vk17)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 18, 19 - const int8x16_t vi18 = vld1q_s8(i18); i18 += 16; - const int8x16_t vk18 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi18), vget_low_s8(vk18)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi18), vget_high_s8(vk18)); - - const int8x16_t vi19 = vld1q_s8(i19); i19 += 16; - const int8x16_t vk19 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi19), vget_low_s8(vk19)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi19), vget_high_s8(vk19)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = 
vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 20, 21 - const int8x16_t vi20 = vld1q_s8(i20); i20 += 16; - const int8x16_t vk20 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi20), vget_low_s8(vk20)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi20), vget_high_s8(vk20)); - - const int8x16_t vi21 = vld1q_s8(i21); i21 += 16; - const int8x16_t vk21 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi21), vget_low_s8(vk21)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi21), vget_high_s8(vk21)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 22, 23 - const int8x16_t vi22 = vld1q_s8(i22); i22 += 16; - const int8x16_t vk22 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmull_s8(vget_low_s8(vi22), vget_low_s8(vk22)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi22), vget_high_s8(vk22)); - - const int8x16_t vi23 = vld1q_s8(i23); i23 += 16; - const int8x16_t vk23 = vld1q_s8(w); w += Channels; - vprod_01234567 = vmlal_s8(vprod_01234567, vget_low_s8(vi23), vget_low_s8(vk23)); - vprod_89ABCDEF = vmlal_s8(vprod_89ABCDEF, vget_high_s8(vi23), vget_high_s8(vk23)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - // kernel pixel 24 - const int8x16_t vi24 = vld1q_s8(i24); i24 += 16; - const int8x16_t vk24 = vld1q_s8(w); // w += Channels; no need to add - vprod_01234567 = vmull_s8(vget_low_s8(vi24), vget_low_s8(vk24)); - vprod_89ABCDEF = vmull_s8(vget_high_s8(vi24), vget_high_s8(vk24)); - - vacc_0123 = vaddw_s16(vacc_0123, vget_low_s16(vprod_01234567)); - vacc_4567 = vaddw_s16(vacc_4567, vget_high_s16(vprod_01234567)); - vacc_89AB = vaddw_s16(vacc_89AB, vget_low_s16(vprod_89ABCDEF)); - vacc_CDEF = vaddw_s16(vacc_CDEF, vget_high_s16(vprod_89ABCDEF)); - - if (is_per_channel) { - vscale_0123 = vld1q_f32(scale); scale += 4; - vscale_4567 = vld1q_f32(scale); scale += 4; - vscale_89AB = vld1q_f32(scale); scale += 4; - vscale_CDEF = vld1q_f32(scale); scale += 4; - } - - // requantize - vacc_0123 = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_0123), vscale_0123)); - vacc_4567 = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_4567), vscale_4567)); - vacc_89AB = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_89AB), vscale_89AB)); - vacc_CDEF = vcvtnq_s32_f32(vmulq_f32(vcvtq_f32_s32(vacc_CDEF), vscale_CDEF)); - - const int16x8_t vacc_01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_0123), vacc_4567), voutput_zero_point); - const int16x8_t vacc_89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_89AB), vacc_CDEF), voutput_zero_point); - int8x16_t vout = vqmovn_high_s16(vqmovn_s16(vacc_01234567), vacc_89ABCDEF); - - vst1q_s8(OutBuf, vout); - OutBuf += 16; - } - } -} - -#endif -} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp deleted file mode 100644 index 859fcd049ac7d..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ /dev/null @@ -1,552 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
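Both 25-tap kernels above finish with the same requantization tail; the U8S8 variant additionally converts its unsigned activations up front by XOR-ing with 0x80, which reinterprets a uint8 value v as the int8 value v - 128 so the multiplies run in signed arithmetic. A scalar sketch of the tail, for illustration only (the NEON sequence rounds to nearest-even and also saturates at the 16-bit stage, which this simplified version glosses over):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // int32 accumulator (bias already added) -> scaled -> rounded ->
    // output zero point -> saturated 8-bit result.
    uint8_t RequantizeU8(int32_t acc, float scale, int32_t output_zero_point) {
        const int32_t rounded = int32_t(std::lround(double(acc) * double(scale)));
        return uint8_t(std::clamp(rounded + output_zero_point, 0, 255));
    }

    int8_t RequantizeS8(int32_t acc, float scale, int32_t output_zero_point) {
        const int32_t rounded = int32_t(std::lround(double(acc) * double(scale)));
        return int8_t(std::clamp(rounded + output_zero_point, -128, 127));
    }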
- -Module Name: - - qgemm.cpp - -Abstract: - - This module implements the quantized integer matrix/matrix multiply - operation (QGEMM). - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Define the parameters to execute segments of a QGEMM operation on worker -// threads. -// - -struct MLAS_GEMM_QUANT_WORK_BLOCK { - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; -}; - -void -MlasGemmQuantThreaded( - const MLAS_GEMM_QUANT_WORK_BLOCK* WorkBlock, - const MLAS_GEMM_QUANT_SHAPE_PARAMS* Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS* Data, - ptrdiff_t ThreadId - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - QGEMM operation. - -Arguments: - - ThreadInfo - Supplies the structure containing the thread task partition info. - - Shape - Supplies the structure containing the GEMM input and output shapes. - - Data - Supplies the structure containing the GEMM input and output data layout - - ThreadId - Supplies the current index of the threaded operation. - -Return Value: - - None. - ---*/ -{ - const ptrdiff_t ThreadIdM = ThreadId / WorkBlock->ThreadCountN; - const ptrdiff_t ThreadIdN = ThreadId % WorkBlock->ThreadCountN; - - // - // Partition the operation along the M dimension. - // - - size_t RangeStartM; - size_t RangeCountM; - - const size_t M = Shape->M; - - MlasPartitionWork(ThreadIdM, WorkBlock->ThreadCountM, M, &RangeStartM, &RangeCountM); - - // - // Partition the operation along the N dimension. - // - - size_t RangeStartN; - size_t RangeCountN; - - const size_t N = Shape->N; - - const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - - MlasPartitionWork(ThreadIdN, WorkBlock->ThreadCountN, BlockedN, - &RangeStartN, &RangeCountN); - - RangeStartN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - RangeCountN *= MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - - RangeCountN = std::min(N - RangeStartN, RangeCountN); - - // - // Dispatch the partitioned operation. - // - - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(Shape->AIsSigned, Shape->BIsSigned); - MLAS_GEMM_QUANT_OPERATION* GemmQuantOperation; - - if (Data->BIsPacked) { - GemmQuantOperation = GemmQuantDispatch->PackedOperation; - } else { - GemmQuantOperation = GemmQuantDispatch->Operation; - } - - GemmQuantOperation(Shape, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); -} - - -int32_t -MlasQgemmGetKernelOutputCnt( - bool AIsSigned, - bool BIsSigned - ) -{ - const auto* dispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); - return int32_t(dispatch->StrideM); -} - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// VC++ suggests we can attempt to make 'MlasBitsOfFp32' constexpr, but it is not valid. -#pragma warning(disable : 26451) -#endif - -void -MLASCALL -MlasGemmBatch( - const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, - MLAS_THREADPOOL* ThreadPool) -{ - const size_t M = Shape.M; - const size_t N = Shape.N; - const size_t K = Shape.K; - - // - // Compute the number of target threads given the complexity of the SGEMM - // operation. Small requests should run using the single threaded path. 
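A hypothetical worked example of the N-dimension partitioning above (the value 16 for MLAS_QGEMM_STRIDEN_THREAD_ALIGN and the near-even block split are assumptions for illustration only): with N = 1000 and ThreadCountN = 4, BlockedN = ceil(1000 / 16) = 63 blocks; a near-even split gives 16, 16, 16 and 15 blocks, so the column ranges become (start 0, count 256), (256, 256), (512, 256) and (768, 240), and the final min(N - RangeStartN, RangeCountN) clamps the last range to 232 real columns. Every thread therefore starts on an alignment boundary, and only the last one sees a ragged edge.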
- // - - const double Complexity = double(M) * double(N) * double(K) * double(BatchN); - - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; - if (ThreadsPerGemm < 1) { - ThreadsPerGemm = 1; - } - - // - // Segment the operation across multiple threads. - // - // N.B. Currently, the operation is segmented as a 1D partition, which - // works okay for operations involving skinny matrices. - // - - MLAS_GEMM_QUANT_WORK_BLOCK WorkBlock; - - if (N > M) { - - const size_t BlockedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_QGEMM_STRIDEN_THREAD_ALIGN; - - if (size_t(ThreadsPerGemm) > BlockedN) { - ThreadsPerGemm = ptrdiff_t(BlockedN); - } - - WorkBlock.ThreadCountM = 1; - WorkBlock.ThreadCountN = ThreadsPerGemm; - - } else { - - if (size_t(ThreadsPerGemm) > M) { - ThreadsPerGemm = ptrdiff_t(M); - } - - WorkBlock.ThreadCountM = ThreadsPerGemm; - WorkBlock.ThreadCountN = 1; - } - TargetThreadCount = ThreadsPerGemm * BatchN; - - MlasTrySimpleParallel(ThreadPool, TargetThreadCount, [&](ptrdiff_t tid) { - const auto gemm_i = tid / ThreadsPerGemm; - const auto blk_i = tid % ThreadsPerGemm; - MlasGemmQuantThreaded(&WorkBlock, &Shape, &DataParams[gemm_i], blk_i); - }); -} - - -int32_t -MlasSymmQgemmGetKernelOutputCnt() -{ - const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch; - return int32_t(dispatch->StrideM); -} - - -void -MLASCALL -MlasSymmQgemmBatch( - const MLAS_GEMM_QUANT_SHAPE_PARAMS& Shape, - const MLAS_SYMM_QGEMM_DATA_PARAMS* DataParams, - const size_t BatchN, - MLAS_THREADPOOL* ThreadPool - ) -{ - const size_t M = Shape.M; - const size_t N = Shape.N; - const size_t K = Shape.K; - const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch; - - if (ThreadPool == nullptr) { - // So our caller handles threaded job partition. - // Call single threaded operation directly - auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().IsCurrentCoreArmv8NarrowLd(); - MLAS_SYMM_QGEMM_OPERATION* operation = - uarch ? dispatch->LitOperation : dispatch->BigOperation; - - for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { - auto Data = &DataParams[gemm_i]; - operation(&Shape, Data, 0, M, 0, N); - } - return; - } - - // - // Compute the number of target threads given the complexity of the SGEMM - // operation. Small requests should run using the single threaded path. 
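To see the intent of the complexity heuristic above (shared by MlasGemmBatch and MlasSymmQgemmBatch), plug in hypothetical numbers; the real MLAS_QGEMM_THREAD_COMPLEXITY constant is defined elsewhere, and 65536 here is purely illustrative. M = N = K = 64 with BatchN = 1 gives Complexity = 262144. On an 8-core platform the threshold 65536 * 8 = 524288 is not reached, so TargetThreadCount = 262144 / 65536 + 1 = 5 rather than the full 8: small problems deliberately under-subscribe the thread pool, while anything past the threshold simply takes every available thread (and is then clamped to the pool size).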
- // - - const double Complexity = double(M) * double(N) * double(K) * double(BatchN); - - ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; - if (ThreadsPerGemm < 1) { - ThreadsPerGemm = 1; - } - - const size_t StrideM = dispatch->StrideM; - - size_t nc = N; - if ((size_t)MlasGetMaximumThreadCount(ThreadPool) > BatchN) { - // more than one thread per GEMM - - const size_t BlockedM = MlasDivRoundup(M, StrideM); - const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); - if (max_nc < nc) { - nc = std::min(nc, MlasDivRoundup(nc, max_nc * MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * - MLAS_QGEMM_STRIDEN_THREAD_ALIGN); - } - } - const size_t StrideN = nc; - - const size_t ThreadCountM = MlasDivRoundup(M, StrideM); - const size_t ThreadCountN = MlasDivRoundup(N, StrideN); - ThreadsPerGemm = ThreadCountM * ThreadCountN; - - MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { - auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().IsCurrentCoreArmv8NarrowLd(); - MLAS_SYMM_QGEMM_OPERATION* operation = - uarch ? dispatch->LitOperation : dispatch->BigOperation; - - const auto gemm_i = tid / ThreadsPerGemm; - const auto blk_i = tid % ThreadsPerGemm; - auto Data = &DataParams[gemm_i]; - - const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; - const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; - - const size_t RangeStartM = ThreadIdM * StrideM; - const size_t RangeCountM = std::min(Shape.M - RangeStartM, (size_t)StrideM); - - const size_t RangeStartN = ThreadIdN * StrideN; - const size_t RangeCountN = std::min(Shape.N - RangeStartN, (size_t)StrideN); - - operation(&Shape, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); - }); -} - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif - -size_t -MLASCALL -MlasGemmPackBSize( - size_t N, - size_t K, - bool AIsSigned, - bool BIsSigned - ) -/*++ - -Routine Description: - - This routine computes the number of bytes required to pack a matrix with - the supplied shape and type. - -Arguments: - - N - Supplies the number of columns of matrix B. - - K - Supplies the the number of rows of matrix B. - - BIsSigned - Supplies true if matrix B is signed data, else false if matrix - B is unsigned data. - -Return Value: - - Returns the number of bytes required to pack the matrix, else zero if the - current implementation does not support packing. - ---*/ -{ - // - // Retrieve the packing parameters. - // - - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); - - size_t PackedK = GemmQuantDispatch->PackedK; - size_t PackedStrideK = GemmQuantDispatch->PackedStrideK; - - if (PackedStrideK == 0) { - return 0; - } - - // - // Compute the number of bytes required to hold the packed buffer. 
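For reference, the buffer being sized here (and filled in by MlasGemmPackB below) has a simple layout: a header of AlignedN int32 column sums followed by AlignedN x AlignedK packed uint8 data, with the total rounded up to the platform's preferred buffer alignment.

    [ AlignedN * sizeof(int32_t) column-sum header | AlignedN * AlignedK packed B bytes ]

The column-sum header is what later feeds the zero-point correction of the quantized GEMM (see MlasSymmQgemmPackB further down, which pre-scales these sums by -ZeroPointA).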
- // - - const size_t AlignedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - - const size_t BytesRequired = - (AlignedN * sizeof(int32_t)) + (AlignedN * AlignedK * sizeof(uint8_t)); - const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & - ~(BufferAlignment - 1); - - return AlignedBytesRequired; -} - -void -MLASCALL -MlasGemmPackB( - size_t N, - size_t K, - const uint8_t* B, - size_t ldb, - bool AIsSigned, - bool BIsSigned, - void* PackedB - ) -/*++ - -Routine Description: - - This routine packs the supplied matrix B to the supplied packed matrix B - buffer. The size of the packed buffer was obtained from MlasGemmPackBSize. - -Arguments: - - N - Supplies the number of columns of matrix B. - - K - Supplies the the number of rows of matrix B. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - BIsSigned - Supplies true if matrix B is signed data, else false if matrix - B is unsigned data. - - PackedB - Supplies the address of packed matrix B. - -Return Value: - - None. - ---*/ -{ - // - // Retrieve the packing parameters. - // - - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); - - size_t PackedK = GemmQuantDispatch->PackedK; - size_t PackedStrideK = GemmQuantDispatch->PackedStrideK; - - // - // Reserve and initialize storage for the column sum buffer to hold the sums - // of the elements along each of the columns. - // - - const size_t AlignedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - - int32_t* PackedColumnSumBuffer = (int32_t*)PackedB; - std::fill_n(PackedColumnSumBuffer, AlignedN, 0); - PackedB = PackedColumnSumBuffer + AlignedN; - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, PackedStrideK); - - // - // Step through each slice of matrix B along the N dimension. - // - - const size_t AlignedK = (CountK + PackedK - 1) & ~(PackedK - 1); - uint8_t* pb = (uint8_t*)PackedB; - size_t CountN; - - for (size_t n = 0; n < N; n += CountN) { - - constexpr size_t BatchedN = 128; - MLAS_DECLSPEC_ALIGN(int32_t ColumnSumBuffer[BatchedN], 64); - - CountN = std::min(N - n, BatchedN); - - GemmQuantDispatch->CopyPackBRoutine(pb, B + n, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); - - // - // Accumulate this batch of the column sum buffer into the packed - // buffer accumulators. 
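A side note on the rounding expressions used in MlasGemmPackBSize and MlasGemmPackB above: for a power-of-two alignment A, (x + A - 1) & ~(A - 1) rounds x up to the next multiple of A. A few compile-time checks, purely as an illustration:

    // Round-up-to-multiple idiom, valid when the alignment is a power of two.
    static_assert(((100 + 16 - 1) & ~(16 - 1)) == 112, "100 rounds up to 112");
    static_assert(((112 + 16 - 1) & ~(16 - 1)) == 112, "an aligned value is unchanged");
    static_assert(((113 + 16 - 1) & ~(16 - 1)) == 128, "113 rounds up to 128");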
- // - - for (size_t nn = 0; nn < CountN; nn++) { - PackedColumnSumBuffer[n + nn] += ColumnSumBuffer[nn]; - } - - pb += CountN * AlignedK; - } - - PackedB = (uint8_t*)PackedB + AlignedN * AlignedK; - B += ldb * CountK; - } -} - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// We can not make this function constexpr across different platforms -#pragma warning(disable : 26497) -#endif - -size_t -MLASCALL -MlasSymmQgemmPackBSize( - size_t N, - size_t K, - bool AIsSigned - ) -{ -#ifndef MLAS_TARGET_ARM64 - - // Only have arm64 impl for now - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - return 0; -#else - - // Only support s8s8 for now - if (!AIsSigned) { - return 0; - } - - const auto* Dispatch = GetMlasPlatform().SymmQgemmDispatch; - - size_t PackedK = Dispatch->PackedK; - - // - // Compute the number of bytes required to hold the packed buffer. - // - - const size_t AlignedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - - const size_t BytesRequired = - (AlignedN * sizeof(int32_t)) + (AlignedN * AlignedK * sizeof(uint8_t)); - const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & - ~(BufferAlignment - 1); - - return AlignedBytesRequired; -#endif // !MLAS_TARGET_ARM64 -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif - - -void -MLASCALL -MlasSymmQgemmPackB( - size_t N, - size_t K, - const int8_t* B, - size_t ldb, - bool AIsSigned, - int32_t ZeroPointA, - void* PackedB - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - - const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch = GetMlasPlatform().SymmQgemmDispatch; - - const size_t AlignedN = - (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - int32_t* PackedColumnSumBuffer = (int32_t*)PackedB; - PackedB = PackedColumnSumBuffer + AlignedN; - - SymmQgemmDispatch->CopyPackBRoutine((uint8_t*)PackedB, (const uint8_t*)B, ldb, N, K, - PackedColumnSumBuffer, true); - for (size_t n = 0; n < AlignedN; n++) { - PackedColumnSumBuffer[n] *= -ZeroPointA; - } -} diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h deleted file mode 100644 index 1ef5b5f7411f0..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ /dev/null @@ -1,908 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm.h - -Abstract: - - This module defines the set of template functions to implement a kernel of - quantized integer matrix/matrix multiply operation (QGEMM). - - To implement a new kernel, template functions below need to be specialized: - MlasGemmQuantFixupZeroPointA - MlasGemmQuantFixupZeroPointB - MlasGemmQuantCopyPackA - MlasGemmQuantCopyPackB - MlasGemmQuantKernel - Specialization of MlasGemmQuantTryGemvKernel is optional. - - MlasGemmQuantOperation and MlasGemmQuantPackedOperation are shared kernel drivers. - MlasGemmQuantScaleSumBuffer is a helper function. - - It also includes the dispatcher logics. - ---*/ - -#pragma once - -#include "mlasi.h" - -#include -#include -#include - -// -// Define the default striding parameters used for the quantized integer -// matrix/matrix multiply operation. 
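Taken together, MlasGemmPackBSize and MlasGemmPackB give callers a pre-packing path: query the packed size once, copy B into that layout (AlignedN int32 column sums followed by the packed data), and reuse the buffer across GEMM calls. A rough usage sketch with a hypothetical caller; the allocation details are assumptions (mlas.h is assumed to be on the include path, and 64-byte alignment is assumed to satisfy the preferred buffer alignment already folded into the reported size):

    #include <cstddef>
    #include <cstdint>
    #include <new>
    #include "mlas.h"

    // Hypothetical helper: pack B once so repeated quantized GEMMs can reuse it.
    void* PackBOnce(const uint8_t* B, size_t N, size_t K, size_t ldb,
                    bool AIsSigned, bool BIsSigned)
    {
        const size_t PackedSize = MlasGemmPackBSize(N, K, AIsSigned, BIsSigned);
        if (PackedSize == 0) {
            return nullptr;  // packing is not supported by the current implementation
        }

        void* PackedB = ::operator new(PackedSize, std::align_val_t{64});
        MlasGemmPackB(N, K, B, ldb, AIsSigned, BIsSigned, PackedB);
        return PackedB;  // release later with ::operator delete(ptr, std::align_val_t{64})
    }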
-// - -struct MLAS_GEMM_QUANT_STRIDES { - size_t M; - size_t N; - size_t K; -}; - -template -MLAS_FORCEINLINE -bool -MlasGemmQuantTryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool AIsSigned, - bool BIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(A); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(CountK); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - return false; -} - -template -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointA( - int32_t ZeroPointA, - bool AIsSigned) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - return ZeroPointA; -} - -template -int32_t constexpr -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - return ZeroPointB; -} - -template -MLAS_FORCEINLINE -void -MlasGemmQuantFixupZeroPointB( - const uint8_t* PackedZeroPointB, - int32_t* ZeroPointBBuffer, - size_t N, - bool BIsSigned -) -{ - int32_t ZeroPointB; - - for (size_t n = 0; n < N; n++) { - - ZeroPointB = typename KernelType::OffsetBType(PackedZeroPointB[n]); - ZeroPointB = MlasGemmQuantFixupZeroPointB(ZeroPointB, BIsSigned); - - ZeroPointBBuffer[n] = -ZeroPointB; - } - - // - // Fill the misaligned slots of the zero point buffer with zeros to guard - // against tools that check for uninitialized data usage. - // - - size_t AlignedN = (N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - - for (size_t n = N; n < AlignedN; n++) { - ZeroPointBBuffer[n] = 0; - } -} - -template -void -MlasGemmQuantCopyPackA( - typename KernelType::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned -); - -template -void -MlasGemmQuantCopyPackB( - typename KernelType::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned -); - -template -size_t -MlasGemmQuantKernel( - const typename KernelType::PackedAType* A, - const typename KernelType::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode -); - -/** - * @brief Usually a wrapper of assembly/intrinsic kernel - * of symmetric quant gemm - * @tparam KernelType - * @param A Left hand side matrix - * @param B Prepacked right hand side matrix - * @param C Result matrix - * @param PackedCountK Number of packed rows from B - * @param CountM Number of rows to process - * @param CountN Number of columns to process - * @param ldc Row stride of C - * @param lda Row stride of A - * @param ColumnSumVector Column sum of B scaled by zero point A - * @return Number of rows processed -*/ -template -size_t -MlasSymmQGemmKernel( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector -); - -inline -void -MlasGemmQuantScaleSumBuffer( - int32_t* Output, - const int32_t* Input, - size_t N, - int32_t Scale -) -{ - for (size_t n = 0; n < N; n++) { - Output[n] = Input[n] * Scale; - } -} - - -MLAS_FORCEINLINE -void -MlasGemmQuantScaleSumBuffer( - int32_t* SumBuffer, - size_t N, - int32_t Scale -) -{ - return MlasGemmQuantScaleSumBuffer(SumBuffer, 
SumBuffer, N, Scale); -} - -template -MLAS_FORCEINLINE -void -MlasGemmQuantThreadInit() -{ - constexpr MLAS_GEMM_QUANT_STRIDES Strides = KernelType::Strides; - constexpr size_t packASize = - UpAlignSize(Strides.M * Strides.K * sizeof(typename KernelType::PackedAType)); - constexpr size_t packBSize = - UpAlignSize(Strides.N * Strides.K * sizeof(typename KernelType::PackedBType)); - constexpr size_t rowSumSize = UpAlignSize(Strides.M * sizeof(int32_t)); - constexpr size_t colSumSize = UpAlignSize(Strides.N * sizeof(int32_t)); - constexpr size_t zpbSize = UpAlignSize(Strides.N * sizeof(int32_t)); - - constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides = KernelType::PackedStrides; - constexpr size_t packedASize = - UpAlignSize(PackedStrides.M * PackedStrides.K * sizeof(typename KernelType::PackedAType)); - - constexpr size_t bufsize = std::max(packASize + packBSize, packedASize) + rowSumSize + colSumSize + zpbSize; - - MlasThreadedBufAlloc(bufsize); -} - -template -void -MlasGemmQuantOperation( - const MLAS_GEMM_QUANT_SHAPE_PARAMS* Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ) -/*++ - -Routine Description: - - This routine implements the quantized integer matrix/matrix multiply - operation (QGEMM). - -Arguments: - - Shape - Supplies the structure containing the GEMM input and output shapes. - - Data - Supplies the structure containing the GEMM input and output data layout - - RangeStartM - Supplies the starting row index to output. - - RangeCountM - Supplies the number of rows to output. - - RangeStartN - Supplies the starting column index to output. - - RangeCountN - Supplies the number of columns to output. - -Return Value: - - None. - ---*/ -{ - constexpr MLAS_GEMM_QUANT_STRIDES Strides = KernelType::Strides; - constexpr size_t packASize = - UpAlignSize(Strides.M * Strides.K * sizeof(typename KernelType::PackedAType)); - constexpr size_t packBSize = - UpAlignSize(Strides.N * Strides.K * sizeof(typename KernelType::PackedBType)); - constexpr size_t rowSumSize = UpAlignSize(Strides.M * sizeof(int32_t)); - constexpr size_t colSumSize = UpAlignSize(Strides.N * sizeof(int32_t)); - - MlasGemmQuantThreadInit(); - - uint8_t* p = ThreadedBufHolder.get(); - typename KernelType::PackedAType* PanelA = - reinterpret_cast(p); - p += packASize; - typename KernelType::PackedBType* PanelB = - reinterpret_cast(p); - p += packBSize; - int32_t* RowSumBuffer = reinterpret_cast(p); - p += rowSumSize; - int32_t* ColumnSumBuffer = reinterpret_cast(p); - p += colSumSize; - int32_t* ZeroPointBBuffer = reinterpret_cast(p); - - - const size_t K = Shape->K; - - const size_t lda = Data->lda; - const size_t ldb = Data->ldb; - const size_t ldc = Data->ldc; - - const uint8_t* A = Data->A + RangeStartM * lda; - const uint8_t* B = (const uint8_t*)Data->B + RangeStartN; - int32_t* C = Data->C + RangeStartM * ldc + RangeStartN; - const uint8_t* PackedZeroPointB = Data->PerColumnZeroPoints ? - Data->ZeroPointB + RangeStartN : nullptr; - bool IsAccumulateMode = Shape->IsAccumulateMode; - - int32_t ZeroPointA = typename KernelType::OffsetAType(Data->ZeroPointA); - int32_t ZeroPointB = typename KernelType::OffsetBType(*Data->ZeroPointB); - - // - // Try to use a GEMV kernel if supported by this kernel type. 
- // - - if ((RangeCountM == 1) && - (ZeroPointA == 0) && (PackedZeroPointB == nullptr) && (ZeroPointB == 0) && - (Data->OutputProcessor == nullptr)) { - if (MlasGemmQuantTryGemvKernel(A, B, ldb, C, K, RangeCountN, Shape->AIsSigned, Shape->BIsSigned)) { - return; - } - } - - // - // Fixup the sign bit of the per-matrix zero point offset of matrix A if the - // kernel requires opposite-signed data. - // - - ZeroPointA = MlasGemmQuantFixupZeroPointA(ZeroPointA, Shape->AIsSigned); - - // - // Fixup the sign bit of the per-matrix zero point offset of matrix B if the - // data is the opposite format of the kernel implementation. This value is - // ignored if per-column zero point offsets are used instead. - // - - ZeroPointB = MlasGemmQuantFixupZeroPointB(ZeroPointB, Shape->BIsSigned); - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, Strides.K); - - const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - - for (size_t n = 0; n < RangeCountN; n += CountN) { - - CountN = std::min(RangeCountN - n, Strides.N); - - // - // Fixup the sign bit of the per-column zero point offsets of matrix B - // if the data is the opposite format of the kernel implementation. - // - - if (PackedZeroPointB != nullptr) { - MlasGemmQuantFixupZeroPointB( - PackedZeroPointB + n, - ZeroPointBBuffer, - CountN, - Shape->BIsSigned); - } - - // - // Copy a panel of matrix B to a local packed buffer. - // - - MlasGemmQuantCopyPackB( - PanelB, - B + n, - ldb, - CountN, - CountK, - ColumnSumBuffer, - Shape->BIsSigned); - - MlasGemmQuantScaleSumBuffer(ColumnSumBuffer, CountN, -ZeroPointA); - - // - // Step through each slice of matrix A along the M dimension. - // - - int32_t* c = C + n; - size_t CountM; - - for (size_t m = 0; m < RangeCountM; m += CountM) { - - CountM = std::min(RangeCountM - m, Strides.M); - - // - // Copy a panel of matrix A to a local packed buffer. - // - - MlasGemmQuantCopyPackA( - PanelA, - A + m * lda, - lda, - CountM, - CountK, - RowSumBuffer, - Shape->AIsSigned); - - // - // Apply the global depth value constant without the ZeroPointB scaling from: - // - // (A[i] - ZeroPointA) * (B[i] - ZeroPointB) - // ==> - // A[i] * B[i] - A[i] * ZeroPointB - B[i] * ZeroPointA + ZeroPointA * ZeroPointB - // - // The ZeroPointB term is factored out and either applied below for per-matrix - // quantization or inside the kernel for per-column quantization. - // - - for (size_t mm = 0; mm < CountM; mm++) { - RowSumBuffer[mm] -= int32_t(CountK) * ZeroPointA; - } - - // - // Scale the row sums by the per-matrix zero point offset of matrix B. - // - - if (PackedZeroPointB == nullptr) { - MlasGemmQuantScaleSumBuffer(RowSumBuffer, CountM, -ZeroPointB); - } - - // - // Step through the rows of the local packed buffer. - // - - typename KernelType::PackedAType* pa = PanelA; - int32_t* RowSums = RowSumBuffer; - size_t RowsRemaining = CountM; - - bool ZeroMode = (k == 0) && !IsAccumulateMode; - bool PostProcess = (k + CountK == K); - - while (RowsRemaining > 0) { - - size_t RowsHandled = MlasGemmQuantKernel( - pa, - PanelB, - c, - PackedCountK, - RowsRemaining, - CountN, - ldc, - RowSums, - ColumnSumBuffer, - (PackedZeroPointB != nullptr) ? 
ZeroPointBBuffer : nullptr, - ZeroMode); - - if (PostProcess && Data->OutputProcessor != nullptr) { - Data->OutputProcessor->Process( - Data->C, - RangeStartM + m + CountM - RowsRemaining, - RangeStartN + n, - RowsHandled, - CountN, - Data->ldc); - } - - c += ldc * RowsHandled; - pa += KernelType::PackedK * PackedCountK * RowsHandled; - RowSums += RowsHandled; - RowsRemaining -= RowsHandled; - } - } - } - - A += CountK; - B += CountK * ldb; - } -} - - -template -void -MlasGemmQuantPackedOperation( - const MLAS_GEMM_QUANT_SHAPE_PARAMS* Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ) -/*++ - -Routine Description: - - This routine implements the quantized integer matrix/matrix multiply - operation (QGEMM). - -Arguments: - - Shape - Supplies the structure containing the GEMM input and output shapes. - - Data - Supplies the structure containing the GEMM input and output data layout - - RangeStartM - Supplies the starting row index to output. - - RangeCountM - Supplies the number of rows to output. - - RangeStartN - Supplies the starting column index to output. - - RangeCountN - Supplies the number of columns to output. - -Return Value: - - None. - ---*/ -{ - constexpr MLAS_GEMM_QUANT_STRIDES Strides = KernelType::PackedStrides; - constexpr size_t packASize = - UpAlignSize(Strides.M * Strides.K * sizeof(typename KernelType::PackedAType)); - constexpr size_t rowSumSize = UpAlignSize(Strides.M * sizeof(int32_t)); - constexpr size_t colSumSize = UpAlignSize(Strides.N * sizeof(int32_t)); - - MlasGemmQuantThreadInit(); - - uint8_t* p = ThreadedBufHolder.get(); - typename KernelType::PackedAType* PanelA = - reinterpret_cast(p); - p += packASize; - int32_t* RowSumBuffer = reinterpret_cast(p); - p += rowSumSize; - int32_t* ColumnSumBuffer = reinterpret_cast(p); - p += colSumSize; - int32_t* ZeroPointBBuffer = reinterpret_cast(p); - - const size_t K = Shape->K; - - const size_t lda = Data->lda; - const size_t ldc = Data->ldc; - - const uint8_t* A = Data->A + RangeStartM * lda; - const uint8_t* PackedB = (const uint8_t*)Data->B; - int32_t* C = Data->C + RangeStartM * ldc + RangeStartN; - const uint8_t* PackedZeroPointB = Data->PerColumnZeroPoints ? - Data->ZeroPointB + RangeStartN : nullptr; - bool IsAccumulateMode = Shape->IsAccumulateMode; - - int32_t ZeroPointA = typename KernelType::OffsetAType(Data->ZeroPointA); - int32_t ZeroPointB = typename KernelType::OffsetBType(*Data->ZeroPointB); - - // - // Fixup the sign bit of the per-matrix zero point offset of matrix A if the - // kernel requires signed data. - // - - ZeroPointA = MlasGemmQuantFixupZeroPointA(ZeroPointA, Shape->AIsSigned); - - // - // Fixup the sign bit of the per-matrix zero point offset of matrix B if the - // data is the opposite format of the kernel implementation. This value is - // ignored if per-column zero point offsets are used instead. - // - - ZeroPointB = MlasGemmQuantFixupZeroPointB(ZeroPointB, Shape->BIsSigned); - - // - // Extract the pointer to the column sum buffer from the packed matrix. - // - - const size_t AlignedN = - (Shape->N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - const int32_t* PackedColumnSumBuffer = (const int32_t*)PackedB; - PackedB = (const uint8_t*)(PackedColumnSumBuffer + AlignedN); - PackedColumnSumBuffer += RangeStartN; - - // - // Step through each slice of matrix B along the K dimension. 
- // - - size_t CountK; - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, Strides.K); - - const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; - - if (k > 0) { - std::fill_n(ColumnSumBuffer, Strides.N, 0); - } - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - - for (size_t n = 0; n < RangeCountN; n += CountN) { - - CountN = std::min(RangeCountN - n, Strides.N); - - if (k == 0) { - MlasGemmQuantScaleSumBuffer(ColumnSumBuffer, PackedColumnSumBuffer + n, - CountN, -ZeroPointA); - } - - // - // Fixup the sign bit of the per-column zero point offsets of matrix B - // if the data is the opposite format of the kernel implementation. - // - - if (PackedZeroPointB != nullptr) { - MlasGemmQuantFixupZeroPointB( - PackedZeroPointB + n, - ZeroPointBBuffer, - CountN, - Shape->BIsSigned); - } - - // - // Step through each slice of matrix A along the M dimension. - // - - const uint8_t* b = PackedB + (RangeStartN + n) * - KernelType::PackedK * PackedCountK; - int32_t* c = C + n; - size_t CountM; - - for (size_t m = 0; m < RangeCountM; m += CountM) { - - CountM = std::min(RangeCountM - m, Strides.M); - - // - // Copy a panel of matrix A to a local packed buffer. - // - - MlasGemmQuantCopyPackA( - PanelA, - A + m * lda, - lda, - CountM, - CountK, - RowSumBuffer, - Shape->AIsSigned); - - // - // Apply the global depth value constant without the ZeroPointB scaling from: - // - // (A[i] - ZeroPointA) * (B[i] - ZeroPointB) - // ==> - // A[i] * B[i] - A[i] * ZeroPointB - B[i] * ZeroPointA + ZeroPointA * ZeroPointB - // - // The ZeroPointB term is factored out and either applied below for per-matrix - // quantization or inside the kernel for per-column quantization. - // - - for (size_t mm = 0; mm < CountM; mm++) { - RowSumBuffer[mm] -= int32_t(CountK) * ZeroPointA; - } - - // - // Scale the row sums by the per-matrix zero point offset of matrix B. - // - - if (PackedZeroPointB == nullptr) { - MlasGemmQuantScaleSumBuffer(RowSumBuffer, CountM, -ZeroPointB); - } - - // - // Step through the rows of the local packed buffer. - // - - typename KernelType::PackedAType* pa = PanelA; - int32_t* RowSums = RowSumBuffer; - size_t RowsRemaining = CountM; - - bool ZeroMode = (k == 0) && !IsAccumulateMode; - bool PostProcess = (k + CountK == K); - - while (RowsRemaining > 0) { - - size_t RowsHandled = MlasGemmQuantKernel( - pa, - b, - c, - PackedCountK, - RowsRemaining, - CountN, - ldc, - RowSums, - ColumnSumBuffer, - (PackedZeroPointB != nullptr) ? 
ZeroPointBBuffer : nullptr, - ZeroMode); - - if (PostProcess && Data->OutputProcessor != nullptr) { - Data->OutputProcessor->Process( - Data->C, - RangeStartM + m + CountM - RowsRemaining, - RangeStartN + n, - RowsHandled, - CountN, - Data->ldc); - } - - c += ldc * RowsHandled; - pa += KernelType::PackedK * PackedCountK * RowsHandled; - RowSums += RowsHandled; - RowsRemaining -= RowsHandled; - } - } - } - - A += CountK; - PackedB = (const uint8_t*)PackedB + AlignedN * CountK; - } -} - -/** - * @brief Operation for Quantized GEMM where B is symmetrically - * quantized and packed matrix - * @param Shape - * @param Data - * @param RangeStartM - * @param RangeCountM - * @param RangeStartN - * @param RangeCountN -*/ -template -void -MlasSymmQGemmPackedOperation( - const MLAS_GEMM_QUANT_SHAPE_PARAMS* Shape, - const MLAS_SYMM_QGEMM_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ) -{ - - const size_t K = Shape->K; - - const size_t lda = Data->lda; - const size_t ldc = Data->ldc; - - const int8_t* PanelA = (const int8_t*)(Data->A) + RangeStartM * lda; - const int8_t* PackedB = (const int8_t*)Data->B; - int32_t* C = (int32_t*)(Data->C) + RangeStartM * ldc + RangeStartN; - - // - // Extract the pointer to the column sum buffer from the packed matrix. - // - const size_t AlignedN = - (Shape->N + MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_QGEMM_STRIDEN_THREAD_ALIGN - 1); - const int32_t* PackedColumnSumBuffer = (const int32_t*)PackedB; - PackedB = (const int8_t*)(PackedColumnSumBuffer + AlignedN); - PackedColumnSumBuffer += RangeStartN; - - const size_t PackedCountK = (K + KernelType::PackedK - 1) / KernelType::PackedK; - - // - // Apply the global depth value constant without the ZeroPointB scaling from: - // - // (A[i] - ZeroPointA) * (B[i] - ZeroPointB) - // ==> - // A[i] * B[i] - A[i] * ZeroPointB - B[i] * ZeroPointA + ZeroPointA * ZeroPointB - // - // ZeroPointB is zero, which makes this much simpler - // - - const int8_t* b = PackedB + RangeStartN * KernelType::PackedK * PackedCountK; - int32_t* c = C; - - auto pa = PanelA; - size_t RowsRemaining = RangeCountM; - - while (RowsRemaining > 0) { - size_t RowsHandled = MlasSymmQGemmKernel( - pa, b, c, PackedCountK, RowsRemaining, RangeCountN, ldc, lda, PackedColumnSumBuffer); - - c += ldc * RowsHandled; - pa += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } -} - - -// -// Quantized integer matrix/matrix dispatch structure. 
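Both operation drivers above lean on the same algebra: the row sums of A and column sums of B absorb the zero-point terms, so the inner kernel only ever multiplies raw, unshifted integers. A minimal scalar reference for a single output element, per-matrix zero points only, with a hypothetical function name:

    #include <cstddef>
    #include <cstdint>

    // Reference for one C[m][n]: sum over k of (A[m][k] - ZeroPointA) * (B[k][n] - ZeroPointB),
    // staged through sums the same way the drivers above do.
    int32_t QGemmReferenceElement(const uint8_t* ARow,     // A[m][0..K)
                                  const uint8_t* BColumn,  // B[0..K)[n], row stride ldb
                                  size_t K, size_t ldb,
                                  int32_t ZeroPointA, int32_t ZeroPointB)
    {
        int32_t SumAB = 0, SumA = 0, SumB = 0;
        for (size_t k = 0; k < K; k++) {
            const int32_t a = ARow[k];
            const int32_t b = BColumn[k * ldb];
            SumAB += a * b;
            SumA += a;
            SumB += b;
        }

        // (a - ZPa) * (b - ZPb) summed over k expands to
        //   SumAB - ZeroPointB*SumA - ZeroPointA*SumB + K*ZeroPointA*ZeroPointB,
        // which is what the kernel accumulates on top of the pre-scaled
        // row/column sum buffers.
        return SumAB - ZeroPointB * SumA - ZeroPointA * SumB +
               int32_t(K) * ZeroPointA * ZeroPointB;
    }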
-// - -typedef -void -(MLAS_GEMM_QUANT_OPERATION)( - const MLAS_GEMM_QUANT_SHAPE_PARAMS* Shape, - const MLAS_GEMM_QUANT_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ); - -typedef -void -(MLAS_SYMM_QGEMM_OPERATION)( - const MLAS_GEMM_QUANT_SHAPE_PARAMS* Shape, - const MLAS_SYMM_QGEMM_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ); - -typedef -void -(MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE)( - uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ); - -struct MLAS_GEMM_QUANT_DISPATCH { - MLAS_GEMM_QUANT_OPERATION* Operation; - MLAS_GEMM_QUANT_OPERATION* PackedOperation; - MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine; - size_t PackedK; - size_t PackedStrideK; - size_t StrideM; -}; - -struct MLAS_SYMM_QGEMM_DISPATCH { - MLAS_SYMM_QGEMM_OPERATION* LitOperation; /// running on little cores with narrow memory load - MLAS_SYMM_QGEMM_OPERATION* BigOperation; /// running on big cores with wider memory load - MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine; - size_t StrideM; /**< num of rows processed by kernel at a time */ - size_t PackedK; -}; - -MLAS_FORCEINLINE -const MLAS_GEMM_QUANT_DISPATCH* -MlasGemmQuantGetDispatch( - bool AIsSigned, - bool BIsSigned -) -{ - const MLAS_GEMM_QUANT_DISPATCH* GemmQuantDispatch = &MlasGemmQuantDispatchDefault; - -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) - if (AIsSigned) { - GemmQuantDispatch = - BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; - } else { - GemmQuantDispatch = - BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; - } -#elif defined(MLAS_TARGET_ARM64) - if(BIsSigned) { - GemmQuantDispatch = AIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmU8S8Dispatch; - } else if(!AIsSigned) { - GemmQuantDispatch = GetMlasPlatform().GemmU8U8Dispatch; - } -#elif defined(MLAS_TARGET_ARM64EC) || (defined(MLAS_TARGET_ARM) && !defined(_MSC_VER)) - if(BIsSigned || !AIsSigned) { - GemmQuantDispatch = &MlasGemmU8X8DispatchNeon; - } -#elif defined(MLAS_TARGET_WASM_SIMD) - if (!AIsSigned) { - GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd; - } -#elif defined(MLAS_TARGET_POWER) && (defined(__linux__) || defined(_AIX)) && defined(POWER10) && \ - ((defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ - (defined(__clang__) && (__clang_major__ >= 12))) - if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchPOWER10) { - GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; - } -#endif - - if (nullptr == GemmQuantDispatch) { - std::stringstream ss; - ss << "Quant GEMM format: AIsSigned(" << AIsSigned << "), BIsSigned(" << BIsSigned - << ") is not supported on this device"; - MLAS_THROW_EX(std::invalid_argument, ss.str()); - } - - return GemmQuantDispatch; -} diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp deleted file mode 100644 index 479a82e712c5e..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp +++ /dev/null @@ -1,827 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_amx.cpp - -Abstract: - - This module implements QGEMM kernels for amx. 
- ---*/ - -#include "mlasi.h" -#include "qgemm.h" -#include "amx_common.h" - - -#define TMM0 0 -#define TMM1 1 -#define TMM2 2 -#define TMM3 3 -#define TMM4 4 -#define TMM5 5 -#define TMM6 6 -#define TMM7 7 - -#define KPACK (4 / sizeof(type_t)) // Vertical K packing into Dword - -#define TILE_M 16 -#define TILE_N 16 -#define TILE_K 64 - -/******************************************************************* - * Packing and Gemm kernels for U8S8 AMX - ******************************************************************/ -struct MLAS_GEMM_U8S8_KERNEL_AMX { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = TILE_K; - - // Use smaller stride for debugging, - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{32, 128, 1024}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{32, 512, 2048}; -}; - -constexpr size_t MLAS_GEMM_U8S8_KERNEL_AMX::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8S8_KERNEL_AMX::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8S8_KERNEL_AMX::PackedStrides; - -extern "C" { - - void - MLASCALL - MlasGemmU8S8CopyPackAAmx( - uint8_t* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ); - - void - MLASCALL - MlasGemmU8S8CopyPackBAmx( - uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ); - - /* Fall back to AVX512VNNI when GEMM size is small - * TODO!! What if some future AMX chips does NOT support AVX512VNNI? - */ - size_t - MlasGemmU8S8KernelAvx512Vnni( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ); - -} - - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointA( - int32_t ZeroPointA, - bool AIsSigned - ) -{ - if (AIsSigned) { - ZeroPointA = (uint8_t)(ZeroPointA ^ 0x80); - } - - return ZeroPointA; -} - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (!BIsSigned) { - ZeroPointB = MLAS_GEMM_U8S8_KERNEL_AMX::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - - -template<> -MLAS_FORCEINLINE -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmU8S8CopyPackAAmx(D, A, lda, CountM, CountK, RowSumBuffer); -} - - -template<> -MLAS_FORCEINLINE -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - MlasGemmU8S8CopyPackBAmx(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); -} - - -// Tile configure structure -struct tileconfig_t { - uint8_t palette_id = 0; - uint8_t start_row = 0; - uint8_t reserved1[14] = {0}; - uint16_t colb[8] = {0}; - uint8_t reserved2[16] = {0}; - uint8_t rows[8] = {0}; - uint8_t reserved3[8] = {0}; -}; - -template <> -MLAS_FORCEINLINE -void -MlasGemmQuantThreadInit() -{ - constexpr MLAS_GEMM_QUANT_STRIDES Strides = MLAS_GEMM_U8S8_KERNEL_AMX::Strides; - constexpr size_t packASize = UpAlignSize( - Strides.M * Strides.K * sizeof(typename MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType)); - 
constexpr size_t packBSize = UpAlignSize( - Strides.N * Strides.K * sizeof(typename MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType)); - constexpr size_t rowSumSize = UpAlignSize(Strides.M * sizeof(int32_t)); - constexpr size_t colSumSize = UpAlignSize(Strides.N * sizeof(int32_t)); - constexpr size_t zpbSize = UpAlignSize(Strides.N * sizeof(int32_t)); - - constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides = MLAS_GEMM_U8S8_KERNEL_AMX::PackedStrides; - constexpr size_t packedASize = - UpAlignSize(PackedStrides.M * PackedStrides.K * - sizeof(typename MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType)); - - constexpr size_t bufsize = - std::max(packASize + packBSize, packedASize) + rowSumSize + colSumSize + zpbSize; - - MlasThreadedBufAlloc(bufsize); - - static thread_local struct tileconfig_t tc = {0}; - struct tileconfig_t current_tc = {0}; - tile_storeconfig(¤t_tc); - - if (tc.palette_id == 0 || (std::memcmp(¤t_tc.colb, &tc.colb, sizeof(uint16_t) * 8) != 0 && - std::memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { - // Filling tile configure structure. - tc.palette_id = 1; - for (int t = 0; t < 8; t++) { - tc.rows[t] = 16; - tc.colb[t] = 64; - } - - tile_loadconfig(&tc); - } -} - - -static inline -void -InitHalfTileWithRowColSums( - int32_t* Tile, - const int32_t* rowsum_ptr, - const __m512i colsum, - const int32_t* c_ptr, - const size_t ldc, - bool ZeroMode - ) -{ - __m512i row0,row1,row2,row3,row4,row5,row6,row7; - row0 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[0])); - row1 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[1])); - row2 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[2])); - row3 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[3])); - row4 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[4])); - row5 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[5])); - row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6])); - row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7])); - if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); - } - _mm512_storeu_si512(Tile, row0); - _mm512_storeu_si512(Tile+16, row1); - _mm512_storeu_si512(Tile+32, row2); - _mm512_storeu_si512(Tile+48, row3); - _mm512_storeu_si512(Tile+64, row4); - _mm512_storeu_si512(Tile+80, row5); - _mm512_storeu_si512(Tile+96, row6); - _mm512_storeu_si512(Tile+112, row7); - //Tile += 128; - //rowsum_ptr+=8; - //c_ptr += ldc * 8; -} - -static inline -void -InitHalfTileWithRowColSumsZeroPoints( - int32_t* Tile, - const int32_t* rowsum_ptr, - const __m512i colsum, - const __m512i zeropoint, - const int32_t* c_ptr, - const size_t ldc, - bool ZeroMode - ) -{ - __m512i row0,row1,row2,row3,row4,row5,row6,row7; - row0 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[0])); - row1 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[1])); - row2 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[2])); - row3 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[3])); - row4 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[4])); - row5 = 
_mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[5])); - row6 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[6])); - row7 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[7])); - row0 = _mm512_add_epi32(colsum, row0); - row1 = _mm512_add_epi32(colsum, row1); - row2 = _mm512_add_epi32(colsum, row2); - row3 = _mm512_add_epi32(colsum, row3); - row4 = _mm512_add_epi32(colsum, row4); - row5 = _mm512_add_epi32(colsum, row5); - row6 = _mm512_add_epi32(colsum, row6); - row7 = _mm512_add_epi32(colsum, row7); - if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); - } - _mm512_storeu_si512(Tile, row0); - _mm512_storeu_si512(Tile+16, row1); - _mm512_storeu_si512(Tile+32, row2); - _mm512_storeu_si512(Tile+48, row3); - _mm512_storeu_si512(Tile+64, row4); - _mm512_storeu_si512(Tile+80, row5); - _mm512_storeu_si512(Tile+96, row6); - _mm512_storeu_si512(Tile+112, row7); - //Tile += 128; - //rowsum_ptr+=8; - //c_ptr += ldc * 8; -} - - -static inline -void -InitTileWithRowColSumsZeroPoints( - int32_t* Tile, - size_t cntM, - uint16_t MaskN, - const int32_t* rowsum_ptr, - __m512i colsum, - __m512i zeropoint, - bool ZeroMode, - const int32_t* c_blk, - size_t ldc - ) -{ - for (size_t m = 0; m < cntM; m++){ - __m512i row = _mm512_set1_epi32(rowsum_ptr[0]); - row = _mm512_mullo_epi32(zeropoint, row); - row = _mm512_maskz_add_epi32(MaskN, colsum, row); - if (!ZeroMode){ - __m512i c = _mm512_maskz_loadu_epi32(MaskN, c_blk); - row = _mm512_maskz_add_epi32(MaskN, row, c); - } - _mm512_storeu_si512(Tile, row); - Tile += 16; - rowsum_ptr++; - c_blk += ldc; - } -} - -static inline -void -InitTileWithRowColSums( - int32_t* Tile, - size_t cntM, - uint16_t MaskN, - const int32_t* rowsum_ptr, - __m512i colsum, - bool ZeroMode, - const int32_t* c_blk, - size_t ldc - ) -{ - for (size_t m = 0; m < cntM; m++){ - __m512i row = _mm512_set1_epi32(rowsum_ptr[0]); - row = _mm512_maskz_add_epi32(MaskN, colsum, row); - if (!ZeroMode){ - __m512i c = _mm512_maskz_loadu_epi32(MaskN, c_blk); - row = _mm512_maskz_add_epi32(MaskN, row, c); - } - _mm512_storeu_si512(Tile, row); - Tile += 16; - rowsum_ptr++; - c_blk += ldc; - } -} - - -/** - * @brief move data from Tile buffer to C - * - */ -static inline -void -MoveTile(const int32_t* Tile, size_t cntM, uint16_t MaskN, int32_t* c_ptr, size_t ldc) -{ - for (size_t i = 0; i < cntM; i++){ - __m512i c = _mm512_maskz_loadu_epi32(MaskN, Tile); - Tile += TILE_N; - _mm512_mask_storeu_epi32(c_ptr, MaskN, c); - c_ptr += ldc; - } -} - - -template <> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* A, - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) -{ - // All 8 tile registers are utilized in the main block. 
- // We use Tile 4 - 7 as accumulators, use Tile 2,3 to load - // 32x64 block from A, and Tile 0,1 to load 64x32 block from B: - // B T0 B T1 - // A T2 T4 T6 - // A T3 T5 T7 - // - int32_t Tile4[TILE_M * TILE_N]; - int32_t Tile5[TILE_M * TILE_N]; - int32_t Tile6[TILE_M * TILE_N]; - int32_t Tile7[TILE_M * TILE_N]; - PackedCountK *= TILE_K; - - // Compute masks for left over N - // Values are incorrect when there is no leftover - auto neg = (0LL - static_cast(CountN)) & (2 * TILE_N - 1); - const uint32_t nmasks = 0xFFFFFFFFUL >> neg; - - if (CountM < 2 * TILE_M){ - constexpr uint16_t FullMask = 0xFFFF; - const int leftover_m = static_cast(CountM); - const int m0 = std::min(leftover_m, TILE_M); - const int m1 = std::max(leftover_m - TILE_M, 0); - - int32_t* c_blk = C; // C - beginning of the row - int32_t* c16_blk = C + ldc * TILE_M; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* b_blk = B; // restart B - const int32_t* col_sum_ptr = ColumnSumBuffer; - const int32_t* zp_ptr = ZeroPointB; - - size_t n = CountN; - for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_loadu_si512(zp_ptr); - zp_ptr += TILE_N; - InitTileWithRowColSumsZeroPoints( - Tile4, m0, FullMask, RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - if (m1 != 0){ - InitTileWithRowColSumsZeroPoints( - Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - } - } else { - InitTileWithRowColSums( - Tile4, m0, FullMask, RowSumBuffer, colsum, - ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - if (m1 != 0){ - InitTileWithRowColSums( - Tile5, m1, FullMask, RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - } - } - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - if (ZeroPointB != nullptr) { - __m512i zeropoint = _mm512_loadu_si512(zp_ptr); - zp_ptr += TILE_N; - InitTileWithRowColSumsZeroPoints( - Tile6, m0, FullMask, RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - if (m1 != 0){ - InitTileWithRowColSumsZeroPoints( - Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - } - } else { - InitTileWithRowColSums( - Tile6, m0, FullMask, RowSumBuffer, colsum, - ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - if (m1 != 0){ - InitTileWithRowColSums( - Tile7, m1, FullMask, RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - } - } - - // Restart A from row start - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - tile_loadd(TMM0, b_blk, TILE_K); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - - tile_dpbusd(TMM4, TMM2, TMM0); - tile_dpbusd(TMM6, TMM2, TMM1); - if (m1 > 0){ - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM5, TMM3, TMM0); - tile_dpbusd(TMM7, TMM3, TMM1); - } - b_blk += TILE_N * TILE_K; - a_blk += TILE_K; - a_next_blk += TILE_K; 
- } - if (m0 == TILE_M) { - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); - - } else { - tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - - MoveTile(Tile4, m0, FullMask, c_blk, ldc); - MoveTile(Tile6, m0, FullMask, c_blk + TILE_N, ldc); - } - if (m1 != 0){ - tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); - MoveTile(Tile5, m1, FullMask, c16_blk, ldc); - tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); - MoveTile(Tile7, m1, FullMask, c16_blk + TILE_N, ldc); - } - c_blk += 2 * TILE_N; - c16_blk += 2 * TILE_N; - b_blk += PackedCountK * TILE_N; - } - - if (n != 0) { - const uint16_t nmask_high = static_cast(nmasks >> 16); - - __m512i colsum = _mm512_maskz_loadu_epi32(static_cast(nmasks), col_sum_ptr); - col_sum_ptr += TILE_N; - if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_maskz_loadu_epi32(static_cast(nmasks), zp_ptr); - zp_ptr += TILE_N; - InitTileWithRowColSumsZeroPoints( - Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - if (m1 > 0){ - InitTileWithRowColSumsZeroPoints( - Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - } - } else { - InitTileWithRowColSums( - Tile4, m0, static_cast(nmasks), RowSumBuffer, colsum, - ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - if (m1 > 0){ - InitTileWithRowColSums( - Tile5, m1, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - } - } - if (nmask_high != 0){ - colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); - if (ZeroPointB!=nullptr){ - __m512i zeropoint = _mm512_maskz_loadu_epi32(nmask_high, zp_ptr); - InitTileWithRowColSumsZeroPoints( - Tile6, m0, nmask_high, RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - if (m1 > 0){ - InitTileWithRowColSumsZeroPoints( - Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - } - } else { - InitTileWithRowColSums( - Tile6, m0, nmask_high, RowSumBuffer, colsum, - ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - if (m1 > 0){ - InitTileWithRowColSums( - Tile7, m1, nmask_high, RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - } - } - } - - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - tile_loadd(TMM0, b_blk, TILE_K); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - - tile_dpbusd(TMM4, TMM2, TMM0); - if (m1 > 0){ - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM5, TMM3, TMM0); - } - if (nmask_high != 0){ - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); - if (m1 > 0){ - tile_dpbusd(TMM7, TMM3, TMM1); - } - } - b_blk += TILE_N * TILE_K; - a_blk += TILE_K; - a_next_blk += TILE_K; - } - if ((static_cast(nmasks) & 0x8000) != 0 && m0 == TILE_M){ - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - } else { - 
tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - MoveTile(Tile4, m0, static_cast(nmasks), c_blk, ldc); - } - if (m1 > 0){ - tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); - MoveTile(Tile5, m1, static_cast(nmasks), c16_blk, ldc); - } - if (nmask_high != 0){ - tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - MoveTile(Tile6, m0, nmask_high, c_blk + TILE_N, ldc); - if (m1 > 0){ - tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); - MoveTile(Tile7, m1, nmask_high, c16_blk + TILE_N, ldc); - } - } - } - return CountM; - } - - - int32_t* c_blk = C; // C - beginning of the row - int32_t* c16_blk = C + ldc * TILE_M; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* b_blk = B; // restart B - const int32_t* col_sum_ptr = ColumnSumBuffer; - const int32_t* zp_ptr = ZeroPointB; - - size_t n = CountN; - for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - // Restart A from row start - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - - if (ZeroPointB != nullptr){ - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - __m512i zeropoint = _mm512_loadu_si512(zp_ptr); - zp_ptr += TILE_N; - tile_loadd(TMM0, b_blk, TILE_K); - InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - zeropoint = _mm512_loadu_si512(zp_ptr); - zp_ptr += TILE_N; - InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_dpbusd(TMM4, TMM2, TMM0); - InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode); - InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); - } else { - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - tile_loadd(TMM0, b_blk, TILE_K); - InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode); - 
tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_dpbusd(TMM4, TMM2, TMM0); - InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode); - InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); - } - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - - for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) { - b_blk += TILE_N * TILE_K; - a_blk += TILE_K; - a_next_blk += TILE_K; - tile_dpbusd(TMM5, TMM3, TMM0); - tile_loadd(TMM0, b_blk, TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM7, TMM3, TMM1); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM4, TMM2, TMM0); - } - tile_dpbusd(TMM5, TMM3, TMM0); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_dpbusd(TMM7, TMM3, TMM1); - - b_blk += PackedCountK * TILE_N + TILE_N * TILE_K; - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); - tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); - - c_blk += 2 * TILE_N; - tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); - c16_blk += 2 * TILE_N; - } - - if (n != 0) { - const uint16_t nmask_high = static_cast(nmasks >> 16); - __m512i colsum = _mm512_maskz_loadu_epi32(static_cast(nmasks), col_sum_ptr); - col_sum_ptr += TILE_N; - if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_maskz_loadu_epi32(static_cast(nmasks), zp_ptr); - zp_ptr += TILE_N; - InitTileWithRowColSumsZeroPoints( - Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitTileWithRowColSumsZeroPoints( - Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - } else { - InitTileWithRowColSums( - Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, - ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitTileWithRowColSums( - Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - } - if (nmask_high != 0){ - colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); - if (ZeroPointB != nullptr){ - __m512i zeropoint = _mm512_maskz_loadu_epi32(nmask_high, zp_ptr); - InitTileWithRowColSumsZeroPoints( - Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - InitTileWithRowColSumsZeroPoints( - Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - } else { - InitTileWithRowColSums( - Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, - ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - InitTileWithRowColSums( - Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); - } - } - - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; - const 
MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - tile_loadd(TMM0, b_blk, TILE_K); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - - tile_dpbusd(TMM4, TMM2, TMM0); - tile_dpbusd(TMM5, TMM3, TMM0); - - if (nmask_high != 0){ - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_dpbusd(TMM7, TMM3, TMM1); - - } - b_blk += TILE_N * TILE_K; - a_blk += TILE_K; - a_next_blk += TILE_K; - } - if ((static_cast(nmasks) & 0x8000) != 0){ - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); - - } else { - tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); - tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); - - MoveTile(Tile4, TILE_M, static_cast(nmasks), c_blk, ldc); - MoveTile(Tile5, TILE_M, static_cast(nmasks), c16_blk, ldc); - } - if (nmask_high != 0){ - tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); - - MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc); - MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc); - } - } - - return 2 * TILE_M; -} - - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchAmx = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8S8_KERNEL_AMX::PackedK, - MLAS_GEMM_U8S8_KERNEL_AMX::PackedStrides.K, - 32 // StridM -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp deleted file mode 100644 index a6dbe8defd0e4..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ /dev/null @@ -1,527 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_avx2.cpp - -Abstract: - - This module implements QGEMM kernels for avx2. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Stores a vector to transpose a 4x4 byte vector using vpshufb. -// - -MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint8_t MlasTranspose4x4BytesAvx[16], 16) = -{ 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; - -// -// Define the prototypes of the AVX2/AVX512 routines written in assembly. 
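The AMX kernel above handles a partial final column block with a 32-bit mask whose low and high halves cover its two 16-wide tiles. A small worked sketch of that mask computation (hypothetical helper; only meaningful when the leftover column count is in [1, 31], matching the kernel's own caveat that the value is not used when there is no leftover):

    #include <cstddef>
    #include <cstdint>

    // Mirrors the leftover-N mask in the AMX kernel (TILE_N == 16, so one pass
    // covers two 16-column tiles).
    uint32_t LeftoverNMask(size_t CountN)   // assumes CountN % 32 != 0
    {
        const unsigned neg = (0u - static_cast<unsigned>(CountN)) & (2u * 16u - 1u);
        return 0xFFFFFFFFu >> neg;          // low (CountN % 32) bits set
    }

    // Example: LeftoverNMask(20) == 0x000FFFFF. The low 16 bits (0xFFFF) select all
    // columns of the first tile; the high 16 bits (0x000F) select the remaining
    // 4 columns of the second tile.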
-// - -extern "C" { - - void - MLASCALL - MlasGemmU8S8CopyPackAAvx2( - uint8_t* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ); - - void - MLASCALL - MlasGemmU8S8CopyPackBAvx2( - uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ); - - void - MLASCALL - MlasGemmU8U8CopyPackAAvx2( - int16_t* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ); - - void - MLASCALL - MlasGemmU8U8CopyPackBAvx2( - uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer - ); - - void - MLASCALL - MlasGemmS8CopyPackAAvx2Vnni( - uint8_t* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ); - - void - MLASCALL - MlasGemmU8CopyPackBAvx2Vnni( - uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer - ); - - void - MLASCALL - MlasGemmS8CopyPackBAvx2Vnni( - uint8_t* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer - ); -} - -struct MLAS_GEMM_U8S8_KERNEL_AVX2 -{ - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 256, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 48, 256, 384 }; -}; - -constexpr size_t MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides; - -template<> -MLAS_FORCEINLINE -bool -MlasGemmQuantTryGemvKernel( - const uint8_t* A, - const uint8_t* B, - size_t ldb, - int32_t* C, - size_t CountK, - size_t CountN, - bool AIsSigned, - bool BIsSigned - ) -{ - if (!AIsSigned && BIsSigned) { - GetMlasPlatform().GemvU8S8Kernel(A, B, C, CountK, CountN, ldb); - return true; - } - - return false; -} - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (!BIsSigned) { - ZeroPointB = MLAS_GEMM_U8S8_KERNEL_AVX2::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template<> -MLAS_FORCEINLINE -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); -} - -template<> -MLAS_FORCEINLINE -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - MlasGemmU8S8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedAType* A, - const MLAS_GEMM_U8S8_KERNEL_AVX2::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return GetMlasPlatform().GemmU8S8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH 
MlasGemmU8S8DispatchAvx2 = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8S8_KERNEL_AVX2::PackedK, - MLAS_GEMM_U8S8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride -}; - -struct MLAS_GEMM_U8U8_KERNEL_AVX2 -{ - typedef int16_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 2; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 256, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 48, 256, 384 }; -}; - -constexpr size_t MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides; - - -template<> -MLAS_FORCEINLINE -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmU8U8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); -} - -template<> -MLAS_FORCEINLINE -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - MlasGemmU8U8CopyPackBAvx2(D, B, ldb, CountN, CountK, ColumnSumBuffer); -} - -template<> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedAType* A, - const MLAS_GEMM_U8U8_KERNEL_AVX2::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return GetMlasPlatform().GemmU8U8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2 = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8U8_KERNEL_AVX2::PackedK, - MLAS_GEMM_U8U8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride -}; - -// U8U8 AVX-VNNI-INT8 support -struct MLAS_GEMM_U8U8_KERNEL_AVX2VNNI { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 256, 128}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{48, 256, 384}; -}; - -template <> -MLAS_FORCEINLINE void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmU8S8CopyPackAAvx2(D, A, lda, CountM, CountK, RowSumBuffer); -} - -template <> -MLAS_FORCEINLINE void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - MlasGemmU8CopyPackBAvx2Vnni(D, B, ldb, CountN, CountK, ColumnSumBuffer); -} - -template <> -MLAS_FORCEINLINE - size_t - MlasGemmQuantKernel( - const MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedAType* A, - const MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedBType* B, - int32_t* C, - size_t 
PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return MlasGemmU8U8KernelAvx2Vnni(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2Vnni = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedK, - MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedStrides.K, - 6 // assembly kernel M stride -}; - -// S8S8 AVX-VNNI-INT8 support -struct MLAS_GEMM_S8S8_KERNEL_AVX2 { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 256, 128}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{48, 256, 384}; -}; - -template <> -MLAS_FORCEINLINE void -MlasGemmQuantCopyPackA( - MLAS_GEMM_S8S8_KERNEL_AVX2::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmS8CopyPackAAvx2Vnni(D, A, lda, CountM, CountK, RowSumBuffer); -} - -template <> -MLAS_FORCEINLINE void -MlasGemmQuantCopyPackB( - MLAS_GEMM_S8S8_KERNEL_AVX2::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - MlasGemmS8CopyPackBAvx2Vnni(D, B, ldb, CountN, CountK, ColumnSumBuffer); -} - -template <> -MLAS_FORCEINLINE - size_t - MlasGemmQuantKernel( - const MLAS_GEMM_S8S8_KERNEL_AVX2::PackedAType* A, - const MLAS_GEMM_S8S8_KERNEL_AVX2::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return GetMlasPlatform().GemmS8S8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchAvx2Vnni = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_S8S8_KERNEL_AVX2::PackedK, - MLAS_GEMM_S8S8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride -}; - -// S8U8 AVX-VNNI-INT8 support -struct MLAS_GEMM_S8U8_KERNEL_AVX2 { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 256, 128}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{48, 256, 384}; -}; - -template <> -MLAS_FORCEINLINE void -MlasGemmQuantCopyPackA( - MLAS_GEMM_S8U8_KERNEL_AVX2::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - MlasGemmS8CopyPackAAvx2Vnni(D, A, lda, CountM, CountK, RowSumBuffer); -} - -template <> -MLAS_FORCEINLINE void -MlasGemmQuantCopyPackB( - MLAS_GEMM_S8U8_KERNEL_AVX2::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned -) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - MlasGemmU8CopyPackBAvx2Vnni(D, B, ldb, CountN, CountK, ColumnSumBuffer); -} - -template <> -MLAS_FORCEINLINE - 
size_t - MlasGemmQuantKernel( - const MLAS_GEMM_S8U8_KERNEL_AVX2::PackedAType* A, - const MLAS_GEMM_S8U8_KERNEL_AVX2::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return GetMlasPlatform().GemmS8U8Kernel(A, B, C, PackedCountK, CountM, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchAvx2Vnni = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_S8U8_KERNEL_AVX2::PackedK, - MLAS_GEMM_S8U8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp deleted file mode 100644 index 8f4baaa0ffafc..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_default.cpp +++ /dev/null @@ -1,224 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_default.cpp - -Abstract: - - This module implements default QGEMM kernel. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -struct MLAS_GEMM_QUANT_KERNEL_DEFAULT -{ - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 16, 128, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 16, 128, 128 }; -}; - -constexpr size_t MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_DEFAULT::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedStrides; - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointA( - int32_t ZeroPointA, - bool AIsSigned - ) -{ - if (AIsSigned) { - ZeroPointA = (uint8_t)(ZeroPointA ^ 0x80); - } - - return ZeroPointA; -} - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (BIsSigned) { - ZeroPointB = MLAS_GEMM_QUANT_KERNEL_DEFAULT::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - const size_t AlignedCountK = (CountK + MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK - 1) & - ~(MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK - 1); - - const uint8_t BitFlipValue = (AIsSigned ? 0x80 : 0); - - // - // Process a single row of matrix A in a loop. - // - - while (CountM-- > 0) { - - int32_t RowSum = 0; - - for (size_t k = 0; k < CountK; k++) { - - uint8_t a0 = A[k] ^ BitFlipValue; - D[k] = a0; - - RowSum += a0; - } - - for (size_t k = CountK; k < AlignedCountK; k++) { - D[k] = 0; - } - - *RowSumBuffer++ = RowSum; - - A += lda; - D += AlignedCountK; - } -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - const size_t AlignedCountK = - (CountK + MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK - 1) & ~(MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK - 1); - const uint8_t BitFlipValue = (BIsSigned ? 
0x80 : 0); - - // - // Process a single column of matrix B in a loop. - // - - while (CountN-- > 0) { - - const uint8_t* b = B; - int32_t ColumnSum = 0; - - // - // Transpose the data from matrix B to the packed buffer. - // - - for (size_t k = 0; k < CountK; k++) { - - uint8_t b0 = b[0] ^ BitFlipValue; - D[k] = b0; - - ColumnSum += b0; - - b += ldb; - } - - for (size_t k = CountK; k < AlignedCountK; k++) { - D[k] = 0; - } - - *ColumnSumBuffer++ = ColumnSum; - - B += 1; - D += AlignedCountK; - } -} - -template<> -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedAType* A, - const MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - // - // Process a single column of matrix B in a loop. - // - - while (CountN-- > 0) { - - int32_t Accumulator = *RowSumBuffer; - - if (ZeroPointB != nullptr) { - Accumulator *= *ZeroPointB++; - } - - Accumulator += *ColumnSumBuffer++; - - const auto* a = A; - - for (size_t k = 0; k < PackedCountK; k++) { - - Accumulator += a[0] * B[0]; - Accumulator += a[1] * B[1]; - Accumulator += a[2] * B[2]; - Accumulator += a[3] * B[3]; - - a += 4; - B += 4; - } - - if (!ZeroMode) { - Accumulator += C[0]; - } - - C[0] = Accumulator; - C += 1; - } - - return 1; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault = { - MlasGemmQuantOperation, - nullptr, - nullptr, - MLAS_GEMM_QUANT_KERNEL_DEFAULT::PackedK, - 0, - MLAS_GEMM_QUANT_KERNEL_DEFAULT::Strides.M -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp deleted file mode 100644 index 7d5817335bd77..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_lsx.cpp +++ /dev/null @@ -1,531 +0,0 @@ -/*++ - -Copyright (C) 2023 Loongson Technology Corporation Limited. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_lsx.cpp - -Abstract: - - This module implements QGEMM kernels for LSX. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" -#include - -struct MLAS_GEMM_U8X8_KERNEL_LSX -{ - typedef int16_t PackedAType; - typedef int16_t PackedBType; - typedef uint8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 2; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; -}; - -constexpr size_t MLAS_GEMM_U8X8_KERNEL_LSX::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_LSX::Strides; - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (!BIsSigned) { - ZeroPointB = MLAS_GEMM_U8X8_KERNEL_LSX::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - const __m128i ZeroVector = __lsx_vrepli_d(0); - uint16_t val = 1; - const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); - uint8_t PaddedMatrixAData[8] = { 0 }; - - // - // Process a single row of matrix A in a loop. 
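// Illustrative sketch (not part of the deleted MLAS sources): the value
// mapping behind the 0x80 bit flips used by these kernels. XOR with 0x80
// re-biases int8_t data into uint8_t range (x -> x + 128), so a single
// unsigned compute path can serve signed inputs as long as the matching zero
// point is flipped the same way, which is what MlasGemmQuantFixupZeroPointB
// does above.
#include <cstdint>
#include <cassert>

inline uint8_t FlipToUnsigned(int8_t x) {
    // Reinterpreting x as unsigned and toggling the sign bit equals x + 128 (mod 256).
    return static_cast<uint8_t>(static_cast<uint8_t>(x) ^ 0x80);
}

inline void FlipSketchCheck() {
    assert(FlipToUnsigned(-128) == 0);   // most negative int8 maps to 0
    assert(FlipToUnsigned(0) == 128);    // zero maps to the unsigned midpoint
    assert(FlipToUnsigned(127) == 255);  // most positive int8 maps to 255
    // The mapping is order preserving, so (a - ZeroPointA) is unchanged when
    // both the data and the zero point are flipped together.
}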
- // - - while (CountM > 0) { - - const uint8_t* a = A; - size_t k = CountK; - __m128i ReductionVector = ZeroVector; - - // - // Zero extend the source bytes to 16-bits and write to the packed - // buffer. - // - // The packed buffer has the same data ordering as the source bytes, - // but CountK is aligned up to a multiple of 2 to maintain 32-bit - // alignment. All extra bytes are zero-padded. - // - // These 16-bit values are also accumulated into an intermediate per-row - // accumulator. CountK cannot be greater than 128 to avoid overflowing - // these signed 16-bit accumulators. - // - - while (k >= 8) { - - __m128i Bytes = __lsx_vld((const __m128i*) & a[0], 0); - __lsx_vinsgr2vr_d(Bytes, 0, 1); - __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); - - ReductionVector = __lsx_vadd_h(ReductionVector, Words); - - __lsx_vst(Words, (__m128i*) & D[0], 0); - - a += 8; - D += 8; - k -= 8; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - uint8_t* padded = PaddedMatrixAData; - uint8_t* padded_end = padded + k; - - do { - padded[0] = a[0]; - padded++; - a++; - } while (padded < padded_end); - - __m128i Bytes = __lsx_vld((__m128i*)PaddedMatrixAData, 0); - __lsx_vinsgr2vr_d(Bytes, 0, 1); - __m128i Words = __lsx_vilvl_b(ZeroVector, Bytes); - - ReductionVector = __lsx_vadd_h(ReductionVector, Words); - - // - // Copy pairs of 16-bit values from the vector to the packed - // buffer and rotate the vector for the next iteration. - // - - for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { - __lsx_vstelm_w(Words, (int32_t*)D, 0 , 0); - D += 2; - Words = __lsx_vshuf4i_w(Words, 0x39); //(0, 3, 2, 1) - } - } - - // - // Reduce the partial accumulators. - // - __m128i tmp1 = ZeroVector, tmp2 = ZeroVector; - tmp1 = __lsx_vmaddwev_w_h(tmp1, ReductionVector, OnesWordBroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, ReductionVector, OnesWordBroadcast); - ReductionVector = __lsx_vadd_w(tmp1, tmp2); - ReductionVector = __lsx_vadd_w(ReductionVector, - __lsx_vshuf4i_w(ReductionVector, 0xee)); - ReductionVector = __lsx_vadd_w(ReductionVector, - __lsx_vshuf4i_w(ReductionVector, 0x11)); - - __lsx_vstelm_w(ReductionVector, RowSumBuffer++, 0 , 0); - - A += lda; - CountM -= 1; - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessLSX( - MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, - __m128i BytesRow0, - __m128i BytesRow1, - __m128i BitFlipVector, - __m128i ColumnSums[2] -) -{ - __m128i BytesInterleaved = __lsx_vilvl_b(BytesRow1, BytesRow0); - - BytesInterleaved = __lsx_vxor_v(BytesInterleaved, BitFlipVector); - - __m128i WordsInterleaved0 = __lsx_vsrai_h(__lsx_vilvl_b(BytesInterleaved, BytesInterleaved), 8); - __m128i WordsInterleaved1 = __lsx_vsrai_h(__lsx_vilvh_b(BytesInterleaved, BytesInterleaved), 8); - - ColumnSums[0] = __lsx_vadd_h(ColumnSums[0], WordsInterleaved0); - ColumnSums[1] = __lsx_vadd_h(ColumnSums[1], WordsInterleaved1); - - __lsx_vst(WordsInterleaved0, (__m128i*) & D[0], 0); - __lsx_vst(WordsInterleaved1, (__m128i*) & D[8], 0); -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - uint16_t val = 1; - const __m128i OnesWordBroadcast = __lsx_vreplgr2vr_h(val); - const __m128i BitFlipVector = __lsx_vreplgr2vr_w(BIsSigned ? 0 : 0x80808080); - - // - // Process 8 columns of matrix B in a loop. 
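// Illustrative arithmetic (not part of the deleted MLAS sources): a
// conservative bound behind the "CountK cannot be greater than 128" comment
// above. Even if every one of the CountK values (each at most 255 after the
// unsigned conversion) landed in a single 16-bit lane, CountK <= 128 keeps the
// partial sum inside signed 16-bit range.
#include <cstdint>

static_assert(128 * 255 == 32640, "worst-case partial sum for CountK == 128");
static_assert(128 * 255 <= INT16_MAX, "fits in a signed 16-bit accumulator");
static_assert(129 * 255 > INT16_MAX, "one more element could overflow");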
- // - - while (CountN >= 8) { - - const uint8_t* b = B; - size_t k = CountK; - __m128i ColumnSums[2]; - - ColumnSums[0] = __lsx_vldi(0); - ColumnSums[1] = __lsx_vldi(0); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - // These values are also zero-extended and accumulated into an - // intermediate per-column accumulator. CountK cannot be greater than - // 128 to avoid overflowing these signed 16-bit accumulators. - // - - while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { - - __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); - __lsx_vinsgr2vr_d(BytesRow0, 0, 1); - __m128i BytesRow1 = __lsx_vld((const __m128i*) & b[ldb], 0); - __lsx_vinsgr2vr_d(BytesRow1, 0, 1); - - MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); - - b += ldb * 2; - D += 16; - k -= 2; - } - - if (k > 0) { - - __m128i BytesRow0 = __lsx_vld((const __m128i*) & b[0], 0); - __lsx_vinsgr2vr_d(BytesRow0, 0, 1); - - MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); - - D += 16; - } - - __m128i tmp1, tmp2; - tmp1 = tmp2 = __lsx_vldi(0); - tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); - ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); - tmp1 = tmp2 = __lsx_vldi(0); - tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); - ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); - - __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); - __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. - // - - if (CountN > 0) { - - const uint8_t* b = B; - size_t k = CountK; - __m128i ColumnSums[2]; - uint8_t PaddedMatrixBData[16]; - - __lsx_vst(BitFlipVector, (__m128i*)PaddedMatrixBData, 0); - - ColumnSums[0] = __lsx_vldi(0); - ColumnSums[1] = __lsx_vldi(0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. 
- // - - while (k >= MLAS_GEMM_U8X8_KERNEL_LSX::PackedK) { - - const uint8_t* bcopy = b; - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = bcopy[0]; - padded[8] = bcopy[ldb]; - padded++; - bcopy++; - } while (padded < padded_end); - - __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); - __lsx_vinsgr2vr_d(BytesRow0, 0, 1); - __m128i BytesRow1 = __lsx_vld((__m128i*) & PaddedMatrixBData[8], 0); - __lsx_vinsgr2vr_d(BytesRow1, 0, 1); - - MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); - - b += ldb * 2; - D += 16; - k -= 2; - } - - if (k > 0) { - - const uint8_t* bcopy = b; - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = bcopy[0]; - padded++; - bcopy++; - } while (padded < padded_end); - - __m128i BytesRow0 = __lsx_vld((__m128i*) & PaddedMatrixBData[0], 0); - __lsx_vinsgr2vr_d(BytesRow0, 0, 1); - - MlasGemmU8X8CopyPackBProcessLSX(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); - } - - __m128i tmp1, tmp2; - tmp1 = tmp2 = __lsx_vldi(0); - tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[0], OnesWordBroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[0], OnesWordBroadcast); - ColumnSums[0]= __lsx_vadd_w(tmp1, tmp2); - tmp1 = tmp2 = __lsx_vldi(0); - tmp1 = __lsx_vmaddwev_w_h(tmp1, ColumnSums[1], OnesWordBroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, ColumnSums[1], OnesWordBroadcast); - ColumnSums[1]= __lsx_vadd_w(tmp1, tmp2); - - __lsx_vst(ColumnSums[0], (__m128i*) & ColumnSumBuffer[0], 0); - __lsx_vst(ColumnSums[1], (__m128i*) & ColumnSumBuffer[4], 0); - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowLSX( - __m128i ABroadcast, - const int16_t* B, - __m128i Accumulators[2] -) -{ - __m128i BElements0 = __lsx_vld((__m128i*) & B[0], 0); - __m128i BElements1 = __lsx_vld((__m128i*) & B[8], 0); - - __m128i tmp1, tmp2; - tmp1 = tmp2 = __lsx_vldi(0); - tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements0, ABroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements0, ABroadcast); - Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vadd_w(tmp1, tmp2)); - tmp1 = tmp2 = __lsx_vldi(0); - tmp1 = __lsx_vmaddwev_w_h(tmp1, BElements1, ABroadcast); - tmp2 = __lsx_vmaddwod_w_h(tmp2, BElements1, ABroadcast); - Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vadd_w(tmp1, tmp2)); -} - -template<> -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8X8_KERNEL_LSX::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_LSX::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - __m128i Accumulators[2]; - - // - // Initialize the accumulators with the row and column sums. 
- // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { - - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } - - ZeroPointB += 8; - - Accumulators[0] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[0], 0); - Accumulators[1] = __lsx_vld((__m128i*) & ScaledRowSumBuffer[4], 0); - - } - else { - - Accumulators[0] = __lsx_vreplgr2vr_w(RowSumValue); - Accumulators[1] = Accumulators[0]; - } - - Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((const __m128i*) & ColumnSumBuffer[0], 0)); - Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((const __m128i*) & ColumnSumBuffer[4], 0)); - ColumnSumBuffer += 8; - - // - // Broadcast each pair of 16-bit values from the matrix A and multiply - // with the pair of 16-bit values from matrix B, and add the 32-bit - // intermediate into the accumulator registers. - // - - const int16_t* a = A; - size_t k = PackedCountK; - - while (k >= 4) { - - __m128i AElements = __lsx_vld((__m128i*)a, 0); - __m128i ABroadcast; - - ABroadcast = __lsx_vreplvei_w(AElements, 0); - MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); - - ABroadcast = __lsx_vreplvei_w(AElements, 1); - MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[16], Accumulators); - - ABroadcast = __lsx_vreplvei_w(AElements, 2); - MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[32], Accumulators); - - ABroadcast = __lsx_vreplvei_w(AElements, 3); - MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[48], Accumulators); - - a += 4 * 2; - B += 4 * 16; - k -= 4; - } - - while (k > 0) { - - __m128i ABroadcast = __lsx_vldrepl_w((int32_t*)a, 0); - MlasGemmU8X8MultiplyAccumulateRowLSX(ABroadcast, &B[0], Accumulators); - - a += 2; - B += 16; - k -= 1; - } - - // - // Output the accumulator block after optionally accumulating the values - // from matrix C. - // - - if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); - Accumulators[1] = __lsx_vadd_w(Accumulators[1], __lsx_vld((__m128i*) & C[4], 0)); - } - - __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); - __lsx_vst(Accumulators[1], (__m128i*) & C[4], 0); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. 
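// Illustrative sketch (not part of the deleted MLAS sources): the algebra
// behind seeding each accumulator with row and column sums. With zero points
// za and zb,
//   sum_k (A[k] - za) * (B[k] - zb)
//     = sum_k A[k]*B[k] - zb * sum_k A[k] - za * sum_k B[k] + K*za*zb,
// so the kernel only has to add the raw dot product on top of precomputed
// sums. The exact sign and scaling conventions of RowSumBuffer and
// ColumnSumBuffer live in the shared qgemm driver, which is not shown in this
// diff.
#include <cstdint>
#include <cstddef>
#include <cassert>

inline void RowColumnSumIdentitySketch() {
    const int32_t za = 3, zb = 5;
    const int32_t A[4] = {10, 20, 30, 40};
    const int32_t B[4] = {1, 2, 3, 4};
    int32_t direct = 0, dot = 0, rowSum = 0, colSum = 0;
    for (size_t k = 0; k < 4; ++k) {
        direct += (A[k] - za) * (B[k] - zb);
        dot += A[k] * B[k];
        rowSum += A[k];
        colSum += B[k];
    }
    const int32_t seeded = dot - zb * rowSum - za * colSum + 4 * za * zb;
    assert(direct == seeded);
}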
- // - - if ((CountN & 4) != 0) { - - if (!ZeroMode) { - Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vld((__m128i*) & C[0], 0)); - } - - __lsx_vst(Accumulators[0], (__m128i*) & C[0], 0); - C += 4; - - Accumulators[0] = Accumulators[1]; - } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = __lsx_vadd_w(Accumulators[0], __lsx_vinsgr2vr_d(__lsx_vld((__m128i*) & C[0], 0), 0, 1)); - } - - *((uint64_t *)&C[0]) = __lsx_vpickve2gr_d(Accumulators[0], 0); - C += 2; - - Accumulators[0] = __lsx_vshuf4i_w(Accumulators[0], 0xee); - } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = __lsx_vpickve2gr_w(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; - } - - C[0] = AccumulatorValue; - } - - CountN = 0; - } - } - - return 1; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchLSX = { - MlasGemmQuantOperation, - nullptr, - nullptr, - MLAS_GEMM_U8X8_KERNEL_LSX::PackedK, - 0, - 1 // aLSXmbly kernel M stride -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp deleted file mode 100644 index 50e23a02510ec..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp +++ /dev/null @@ -1,1229 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_neon.cpp - -Abstract: - - This module implements QGEMM kernel for neon. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Define the prototypes of the NEON routines written in assembly. -// -// N.B. The kernel has not been ported to build with the Windows ARM32 toolset. -// - -extern "C" { - - size_t - MLASCALL - MlasGemmU8X8KernelNeon( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode - ); -} - -struct MLAS_GEMM_U8X8_KERNEL_NEON -{ - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 128, 256 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 24, 128, 256 }; -}; - -constexpr size_t MLAS_GEMM_U8X8_KERNEL_NEON::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides; - -template <> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointA( - int32_t ZeroPointA, - bool AIsSigned - ) -{ - if(AIsSigned) { - ZeroPointA = (uint8_t)(ZeroPointA ^ 0x80); - } - - return ZeroPointA; -} - -template<> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (BIsSigned) { - ZeroPointB = MLAS_GEMM_U8X8_KERNEL_NEON::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template -void -MlasGemmQuantCopyPackAU8X8Neon( - MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) -{ - const uint8_t BitFlipByte = 0x80; - const uint32_t BitFlip4Bytes = 0x80808080; - const uint32x4_t BitFlipVector = vdupq_n_u32(BitFlip4Bytes); - - if constexpr (!AIsSigned) { - - MLAS_UNREFERENCED_PARAMETER(BitFlipByte); - MLAS_UNREFERENCED_PARAMETER(BitFlip4Bytes); - MLAS_UNREFERENCED_PARAMETER(BitFlipVector); - } - - // - // Process four rows of matrix A in a loop. 
- // - // The buffer is packed as a series of 16 byte vectors where four rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // - // This pattern is repeated (CountK / 4) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - while (CountM >= 4) { - - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - const uint8_t* a2 = a1 + lda; - const uint8_t* a3 = a2 + lda; - - size_t k = CountK; - uint32x4_t RowSums = vmovq_n_u32(0); - - while (k >= 16) { - - uint32x4_t v0 = vld1q_u32(reinterpret_cast(a0)); - a0 += 16; - uint32x4_t v1 = vld1q_u32(reinterpret_cast(a1)); - a1 += 16; - uint32x4_t v2 = vld1q_u32(reinterpret_cast(a2)); - a2 += 16; - uint32x4_t v3 = vld1q_u32(reinterpret_cast(a3)); - a3 += 16; - - if constexpr (AIsSigned) { - - v0 = veorq_u32(v0, BitFlipVector); - v1 = veorq_u32(v1, BitFlipVector); - v2 = veorq_u32(v2, BitFlipVector); - v3 = veorq_u32(v3, BitFlipVector); - } - -#if defined(MLAS_NEON32_INTRINSICS) - uint32x4x2_t z0 = vzipq_u32(v0, v2); - uint32x4x2_t z1 = vzipq_u32(v1, v3); - - v0 = z0.val[0]; - v1 = z0.val[1]; - v2 = z1.val[0]; - v3 = z1.val[1]; - - uint32x4x2_t z2 = vzipq_u32(v0, v2); - uint32x4x2_t z3 = vzipq_u32(v1, v3); - - v0 = z2.val[0]; - v1 = z2.val[1]; - v2 = z3.val[0]; - v3 = z3.val[1]; -#else - uint32x4_t z0 = vzip1q_u32(v0, v2); - uint32x4_t z1 = vzip2q_u32(v0, v2); - uint32x4_t z2 = vzip1q_u32(v1, v3); - uint32x4_t z3 = vzip2q_u32(v1, v3); - - v0 = vzip1q_u32(z0, z2); - v1 = vzip2q_u32(z0, z2); - v2 = vzip1q_u32(z1, z3); - v3 = vzip2q_u32(z1, z3); -#endif - - vst1q_u8(&D[0], vreinterpretq_u8_u32(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_u32(v1)); - vst1q_u8(&D[32], vreinterpretq_u8_u32(v2)); - vst1q_u8(&D[48], vreinterpretq_u8_u32(v3)); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v0))); - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v1))); - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v2))); - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v3))); - - D += 64; - k -= 16; - } - - while (k >= 4) { - - uint32_t v0 = *reinterpret_cast(a0); - a0 += 4; - uint32_t v1 = *reinterpret_cast(a1); - a1 += 4; - uint32_t v2 = *reinterpret_cast(a2); - a2 += 4; - uint32_t v3 = *reinterpret_cast(a3); - a3 += 4; - - if constexpr (AIsSigned) { - - v0 = v0 ^ BitFlip4Bytes; - v1 = v1 ^ BitFlip4Bytes; - v2 = v2 ^ BitFlip4Bytes; - v3 = v3 ^ BitFlip4Bytes; - } - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - *reinterpret_cast(&D[8]) = v2; - *reinterpret_cast(&D[12]) = v3; - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(&D[0]))); - - D += 16; - k -= 4; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - uint8_t* d = D; - - vst1q_u8(&D[0], vmovq_n_u8(0)); - - while (k > 0) { - - if constexpr (AIsSigned) { - - d[0] = (*a0++) ^ BitFlipByte; - d[4] = (*a1++) ^ BitFlipByte; - d[8] = (*a2++) ^ BitFlipByte; - d[12] = (*a3++) ^ BitFlipByte; - } else { - - d[0] = *a0++; - d[4] = *a1++; - d[8] = *a2++; - d[12] = *a3++; - } - - d += 1; - k -= 1; - } - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(&D[0]))); - - D += 16; - } - - vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); - RowSumBuffer += 4; - - A = A + lda * 4; - CountM -= 4; - } - - // - // Process two rows of matrix A. 
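// Illustrative scalar model (not part of the deleted MLAS sources) of the
// four-row, PackedK = 4 interleave described above: each group of four K
// values is emitted row by row, and the final group is zero padded. The
// sign-bit flip and the row-sum accumulation are omitted for brevity.
#include <cstdint>
#include <cstddef>

inline void PackAFourRowsScalarSketch(
    uint8_t* D, const uint8_t* A, size_t lda, size_t CountK)
{
    const size_t PackedK = 4;
    const size_t AlignedCountK = (CountK + PackedK - 1) & ~(PackedK - 1);
    for (size_t k = 0; k < AlignedCountK; k += PackedK) {
        for (size_t row = 0; row < 4; ++row) {
            for (size_t kk = 0; kk < PackedK; ++kk) {
                const size_t col = k + kk;
                *D++ = (col < CountK) ? A[row * lda + col] : 0;
            }
        }
    }
}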
- // - // The buffer is packed as a series of 8 byte vectors where two rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 ] - // - // This pattern is repeated (CountK / 4) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - if ((CountM & 2) != 0) { - - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - - size_t k = CountK; - uint32x2_t RowSums = vmov_n_u32(0); - - while (k >= 4) { - - uint32_t v0 = *reinterpret_cast(a0); - a0 += 4; - uint32_t v1 = *reinterpret_cast(a1); - a1 += 4; - - if constexpr (AIsSigned) { - - v0 = v0 ^ BitFlip4Bytes; - v1 = v1 ^ BitFlip4Bytes; - } - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - - RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(&D[0]))); - - D += 8; - k -= 4; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - uint8_t* d = D; - - vst1_u8(&D[0], vmov_n_u8(0)); - - while (k > 0) { - - if constexpr (AIsSigned) { - - d[0] = (*a0++) ^ BitFlipByte; - d[4] = (*a1++) ^ BitFlipByte; - } else { - - d[0] = *a0++; - d[4] = *a1++; - } - - d += 1; - k -= 1; - } - - RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(&D[0]))); - - D += 8; - } - - vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); - RowSumBuffer += 2; - - A = A + lda * 2; - } - - // - // Process one row of matrix A. - // - // The buffer is packed as a series of 4 byte with the following pattern: - // - // [ A0 A1 A2 A3 ] - // [ A4 A5 A6 A7 ] - // - // This pattern is repeated (CountK / 4) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - if ((CountM & 1) != 0) { - - const uint8_t* a = A; - size_t k = CountK; - uint32x4_t RowSums = vmovq_n_u32(0); - - while (k >= 16) { - - uint8x16_t v = vld1q_u8(a); - a += 16; - - if constexpr (AIsSigned) { - v = veorq_u8(v, vreinterpretq_u8_u32(BitFlipVector)); - } - - vst1q_u8(D, v); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); - - D += 16; - k -= 16; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - vst1q_u8(&D[0], vmovq_n_u8(0)); - - for (size_t kk = 0; kk < k; kk++) { - D[kk] = a[kk]; - } - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(&D[0]))); - } - -#if defined(MLAS_NEON32_INTRINSICS) - uint32x2_t RowSumsLow = vpadd_u32(vget_high_u32(RowSums), vget_low_u32(RowSums)); - RowSumsLow = vpadd_u32(RowSumsLow, RowSumsLow); - vst1_lane_u32(reinterpret_cast(RowSumBuffer), RowSumsLow, 0); -#elif defined(_M_ARM64) - // N.B. The workaround of defining a local vaddvq_u32 doesn't work here - // as VS2019 added new intrinsics to make the operation work. Also, not - // all build environments using VS2019 have the up-to-date arm64_neon.h, - // so fallback to pairwise addition. 
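// Illustrative scalar model (not part of the deleted MLAS sources) of the
// pairwise-add fallback mentioned above: two rounds of adjacent-pair addition
// leave the horizontal sum of a four-lane vector in every lane, after which
// lane 0 is stored in place of a single vaddvq reduction.
#include <cstdint>

inline int32_t PairwiseHorizontalSumSketch(const int32_t v[4]) {
    // Round 1 (pairwise add of a vector with itself): {v0+v1, v2+v3, v0+v1, v2+v3}.
    const int32_t r1[4] = {v[0] + v[1], v[2] + v[3], v[0] + v[1], v[2] + v[3]};
    // Round 2: every lane now holds v0+v1+v2+v3.
    const int32_t r2[4] = {r1[0] + r1[1], r1[2] + r1[3], r1[0] + r1[1], r1[2] + r1[3]};
    return r2[0];
}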
- RowSums = vpaddq_u32(RowSums, RowSums); - RowSums = vpaddq_u32(RowSums, RowSums); - vst1q_lane_u32(reinterpret_cast(RowSumBuffer), RowSums, 0); -#else - * RowSumBuffer = int32_t(vaddvq_u32(RowSums)); -#endif - } -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - if (AIsSigned) { - MlasGemmQuantCopyPackAU8X8Neon(D, A, lda, CountM, CountK, RowSumBuffer); - } else { - MlasGemmQuantCopyPackAU8X8Neon(D, A, lda, CountM, CountK, RowSumBuffer); - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessNeon( - uint8_t* D, - const uint8_t* B, - uint8x8_t BitFlipVector, - uint32x4_t ColumnSums[2] -) -{ - uint8x8_t BytesRow = veor_u8(vld1_u8(B), BitFlipVector); - vst1_u8(D, BytesRow); - - uint16x8_t WordsRow = vmovl_u8(BytesRow); - ColumnSums[0] = vaddq_u32(ColumnSums[0], vmovl_u16(vget_low_u16(WordsRow))); -#if defined(MLAS_NEON32_INTRINSICS) - ColumnSums[1] = vaddq_u32(ColumnSums[1], vmovl_u16(vget_high_u16(WordsRow))); -#else - ColumnSums[1] = vaddq_u32(ColumnSums[1], vmovl_high_u16(WordsRow)); -#endif -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - const uint8x8_t BitFlipVector = vdup_n_u8(BIsSigned ? 0x80 : 0); - const uint8x8_t ZeroVector = vmov_n_u8(0); - const size_t AlignedCountK = - (CountK + MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1) & ~(MLAS_GEMM_U8X8_KERNEL_NEON::PackedK - 1); - - // - // Process 8 columns of matrix B in a loop. - // - // Copy columns from matrix B to the packed buffer. Signed buffers are - // converted to unsigned buffers in order to share a common kernel. - // - // If CountK is not aligned to a multiple of four, then the packed buffer - // is padded with zero vectors. - // - // If CountN is not aligned to a multiple of four, then the extra columns - // are padded with zeroes. - // - - while (CountN >= 8) { - - const uint8_t* b = B; - uint32x4_t ColumnSums[2]; - - ColumnSums[0] = vmovq_n_u32(0); - ColumnSums[1] = vmovq_n_u32(0); - - for (size_t k = CountK; k > 0; k--) { - - MlasGemmU8X8CopyPackBProcessNeon(D, b, BitFlipVector, ColumnSums); - - b += ldb; - D += 8; - } - - for (size_t k = CountK; k < AlignedCountK; k++) { - vst1_u8(D, ZeroVector); - D += 8; - } - - vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); - vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. 
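// Illustrative sketch (not part of the deleted MLAS sources): the bit trick
// used above to round CountK up to the next multiple of PackedK (a power of
// two), so the packed panel is always consumed in whole PackedK-sized groups.
#include <cstddef>

constexpr size_t AlignUp(size_t value, size_t powerOfTwo) {
    return (value + powerOfTwo - 1) & ~(powerOfTwo - 1);
}

static_assert(AlignUp(13, 4) == 16, "13 K values pad out to four groups of 4");
static_assert(AlignUp(16, 4) == 16, "already aligned counts are unchanged");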
- // - - if (CountN > 0) { - - const uint8_t* b = B; - uint8_t PaddedMatrixBData[8]; - uint32x4_t ColumnSums[2]; - - vst1_u8(PaddedMatrixBData, ZeroVector); - - ColumnSums[0] = vmovq_n_u32(0); - ColumnSums[1] = vmovq_n_u32(0); - - for (size_t k = CountK; k > 0; k--) { - - for (size_t n = 0; n < CountN; n++) { - PaddedMatrixBData[n] = b[n]; - } - - MlasGemmU8X8CopyPackBProcessNeon(D, PaddedMatrixBData, BitFlipVector, ColumnSums); - - b += ldb; - D += 8; - } - - for (size_t k = CountK; k < AlignedCountK; k++) { - vst1_u8(D, ZeroVector); - D += 8; - } - - vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); - vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); - } -} - -template<> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8X8_KERNEL_NEON::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_NEON::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return MlasGemmU8X8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8X8_KERNEL_NEON::PackedK, - MLAS_GEMM_U8X8_KERNEL_NEON::PackedStrides.K, - 4 // Kernel Stride M -}; - -#if defined(MLAS_TARGET_ARM64) -/*------------------------- - * NEON kernel for signed int8 - */ - -extern "C" { - // Prototype of NEON s8 kernel in assembly - - size_t - MLASCALL - MlasGemmS8S8KernelNeon( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode - ); - - size_t - MLASCALL - MlasSymQgemmS8KernelNeon( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector - ); - -} - -struct MLAS_GEMM_X8S8_KERNEL_NEON { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 16; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; -}; - -constexpr size_t MLAS_GEMM_X8S8_KERNEL_NEON::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_X8S8_KERNEL_NEON::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_X8S8_KERNEL_NEON::PackedStrides; - -template <> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointA( - int32_t ZeroPointA, - bool AIsSigned - ) -{ - if(AIsSigned) { - return ZeroPointA; - } - - return MLAS_GEMM_X8S8_KERNEL_NEON::OffsetAType(ZeroPointA ^ 0x80); -} - -template -void -MlasGemmQuantCopyPackAX8S8Neon( - MLAS_GEMM_X8S8_KERNEL_NEON::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer - ) -{ - const uint8x16_t BitFlipVector = vdupq_n_u8(0x80); - if constexpr (AIsSigned) { - MLAS_UNREFERENCED_PARAMETER(BitFlipVector); - } - - // - // Process four rows of matrix A. 
- // - - while (CountM >= 4) { - - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - const uint8_t* a2 = a1 + lda; - const uint8_t* a3 = a2 + lda; - - size_t k = CountK; - int32x4_t RowSums0 = vdupq_n_s32(0); - int32x4_t RowSums1 = vdupq_n_s32(0); - int32x4_t RowSums2 = vdupq_n_s32(0); - int32x4_t RowSums3 = vdupq_n_s32(0); - - while (k >= 16) { - - int8x16_t v0; - int8x16_t v1; - int8x16_t v2; - int8x16_t v3; - - if constexpr (AIsSigned) { - - v0 = vreinterpretq_s8_u8(vld1q_u8(a0)); - a0 += 16; - v1 = vreinterpretq_s8_u8(vld1q_u8(a1)); - a1 += 16; - v2 = vreinterpretq_s8_u8(vld1q_u8(a2)); - a2 += 16; - v3 = vreinterpretq_s8_u8(vld1q_u8(a3)); - a3 += 16; - } else { - - v0 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a0), BitFlipVector)); - a0 += 16; - v1 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a1), BitFlipVector)); - a1 += 16; - v2 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a2), BitFlipVector)); - a2 += 16; - v3 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a3), BitFlipVector)); - a3 += 16; - } - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(v1)); - RowSums2 = vpadalq_s16(RowSums2, vpaddlq_s8(v2)); - RowSums3 = vpadalq_s16(RowSums3, vpaddlq_s8(v3)); - - vst1q_u8(&D[0], vreinterpretq_u8_s8(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_s8(v1)); - vst1q_u8(&D[32], vreinterpretq_u8_s8(v2)); - vst1q_u8(&D[48], vreinterpretq_u8_s8(v3)); - - D += 64; - k -= 16; - } - - if (k > 0) { - - uint8_t* d = D; - - if constexpr (AIsSigned) { - - vst1q_u8(&D[0], vmovq_n_u8(0)); - vst1q_u8(&D[16], vmovq_n_u8(0)); - vst1q_u8(&D[32], vmovq_n_u8(0)); - vst1q_u8(&D[48], vmovq_n_u8(0)); - } else { - - vst1q_u8(&D[0], BitFlipVector); - vst1q_u8(&D[16], BitFlipVector); - vst1q_u8(&D[32], BitFlipVector); - vst1q_u8(&D[48], BitFlipVector); - } - - if (k >= 8) { - - vst1_u8(&d[0], vld1_u8(a0)); - a0 += 8; - vst1_u8(&d[16], vld1_u8(a1)); - a1 += 8; - vst1_u8(&d[32], vld1_u8(a2)); - a2 += 8; - vst1_u8(&d[48], vld1_u8(a3)); - a3 += 8; - - d += 8; - k -= 8; - } - - if (k >= 4) { - - uint32_t v0 = *reinterpret_cast(a0); - a0 += 4; - uint32_t v1 = *reinterpret_cast(a1); - a1 += 4; - uint32_t v2 = *reinterpret_cast(a2); - a2 += 4; - uint32_t v3 = *reinterpret_cast(a3); - a3 += 4; - - *reinterpret_cast(&d[0]) = v0; - *reinterpret_cast(&d[16]) = v1; - *reinterpret_cast(&d[32]) = v2; - *reinterpret_cast(&d[48]) = v3; - - d += 4; - k -= 4; - } - - while (k > 0) { - - d[0] = *a0++; - d[16] = *a1++; - d[32] = *a2++; - d[48] = *a3++; - - d += 1; - k -= 1; - } - - int8x16_t v0; - int8x16_t v1; - int8x16_t v2; - int8x16_t v3; - - if constexpr (AIsSigned) { - - v0 = vreinterpretq_s8_u8(vld1q_u8(&D[0])); - v1 = vreinterpretq_s8_u8(vld1q_u8(&D[16])); - v2 = vreinterpretq_s8_u8(vld1q_u8(&D[32])); - v3 = vreinterpretq_s8_u8(vld1q_u8(&D[48])); - } else { - - v0 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[0]), BitFlipVector)); - v1 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[16]), BitFlipVector)); - v2 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[32]), BitFlipVector)); - v3 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[48]), BitFlipVector)); - } - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(v1)); - RowSums2 = vpadalq_s16(RowSums2, vpaddlq_s8(v2)); - RowSums3 = vpadalq_s16(RowSums3, vpaddlq_s8(v3)); - - if constexpr (!AIsSigned) { - - vst1q_u8(&D[0], vreinterpretq_u8_s8(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_s8(v1)); - vst1q_u8(&D[32], vreinterpretq_u8_s8(v2)); - vst1q_u8(&D[48], vreinterpretq_u8_s8(v3)); - } - - D += 64; - } - - RowSums0 = 
vpaddq_s32(RowSums0, RowSums1); - RowSums2 = vpaddq_s32(RowSums2, RowSums3); - RowSums0 = vpaddq_s32(RowSums0, RowSums2); - - vst1q_s32(&RowSumBuffer[0], RowSums0); - RowSumBuffer += 4; - - A = A + lda * 4; - CountM -= 4; - } - - // - // Process two rows of matrix A. - // - - if ((CountM & 2) != 0) { - - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - - size_t k = CountK; - int32x4_t RowSums0 = vdupq_n_s32(0); - int32x4_t RowSums1 = vdupq_n_s32(0); - - while (k >= 16) { - - int8x16_t v0; - int8x16_t v1; - - if constexpr (AIsSigned) { - - v0 = vreinterpretq_s8_u8(vld1q_u8(a0)); - a0 += 16; - v1 = vreinterpretq_s8_u8(vld1q_u8(a1)); - a1 += 16; - } else { - - v0 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a0), BitFlipVector)); - a0 += 16; - v1 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a1), BitFlipVector)); - a1 += 16; - } - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(v1)); - - vst1q_u8(&D[0], vreinterpretq_u8_s8(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_s8(v1)); - - D += 32; - k -= 16; - } - - if (k > 0) { - - uint8_t* d = D; - - if constexpr (AIsSigned) { - - vst1q_u8(&D[0], vmovq_n_u8(0)); - vst1q_u8(&D[16], vmovq_n_u8(0)); - } else { - - vst1q_u8(&D[0], BitFlipVector); - vst1q_u8(&D[16], BitFlipVector); - } - - while (k > 0) { - - d[0] = *a0++; - d[16] = *a1++; - - d += 1; - k -= 1; - } - - int8x16_t v0; - int8x16_t v1; - - if constexpr (AIsSigned) { - - v0 = vreinterpretq_s8_u8(vld1q_u8(&D[0])); - v1 = vreinterpretq_s8_u8(vld1q_u8(&D[16])); - } else { - - v0 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[0]), BitFlipVector)); - v1= vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[16]), BitFlipVector)); - } - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(v1)); - - if constexpr (!AIsSigned) { - - vst1q_u8(&D[0], vreinterpretq_u8_s8(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_s8(v1)); - } - - D += 32; - } - - RowSums0 = vpaddq_s32(RowSums0, RowSums1); - RowSums0 = vpaddq_s32(RowSums0, RowSums0); - - vst1_s32(RowSumBuffer, vget_low_s32(RowSums0)); - RowSumBuffer += 2; - - A = A + lda * 2; - } - - // - // Process one row of matrix A. - // - - if ((CountM & 1) != 0) { - - const uint8_t* a0 = A; - - size_t k = CountK; - int32x4_t RowSums0 = vdupq_n_s32(0); - - while (k >= 16) { - - int8x16_t v0; - if constexpr (AIsSigned){ - - v0 = vreinterpretq_s8_u8(vld1q_u8(a0)); - a0 += 16; - } else { - - v0 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(a0), BitFlipVector)); - a0 += 16; - } - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - - vst1q_u8(&D[0], vreinterpretq_u8_s8(v0)); - - D += 16; - k -= 16; - } - - if (k > 0) { - - uint8_t* d = D; - - if constexpr (AIsSigned) { - vst1q_u8(&D[0], vmovq_n_u8(0)); - } else{ - vst1q_u8(&D[0], BitFlipVector); - } - - while (k > 0) { - - d[0] = *a0++; - - d += 1; - k -= 1; - } - - int8x16_t v0; - if constexpr (AIsSigned) { - - v0 = vreinterpretq_s8_u8(vld1q_u8(&D[0])); - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - } else { - - v0 = vreinterpretq_s8_u8(veorq_u8(vld1q_u8(&D[0]), BitFlipVector)); - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(v0)); - vst1q_u8(&D[0], vreinterpretq_u8_s8(v0)); - } - - D += 16; - } - -#if defined(_M_ARM64) - // N.B. The workaround of defining a local vaddvq_u32 doesn't work here - // as VS2019 added new intrinsics to make the operation work. Also, not - // all build environments using VS2019 have the up-to-date arm64_neon.h, - // so fallback to pairwise addition. 
- RowSums0 = vpaddq_s32(RowSums0, RowSums0); - RowSums0 = vpaddq_s32(RowSums0, RowSums0); - vst1q_lane_s32(RowSumBuffer, RowSums0, 0); -#else - *RowSumBuffer = vaddvq_s32(RowSums0); -#endif - } -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_X8S8_KERNEL_NEON::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - if(AIsSigned) { - MlasGemmQuantCopyPackAX8S8Neon(D, A, lda, CountM, CountK, RowSumBuffer); - } else { - MlasGemmQuantCopyPackAX8S8Neon(D, A, lda, CountM, CountK, RowSumBuffer); - } -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_X8S8_KERNEL_NEON::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - while (CountN >= 4) { - - const uint8_t* b = B; - size_t k = CountK; - int32x4_t ColumnSums0 = vdupq_n_s32(0); - int32x4_t ColumnSums1 = vdupq_n_s32(0); - int32x4_t ColumnSums2 = vdupq_n_s32(0); - int32x4_t ColumnSums3 = vdupq_n_s32(0); - - while (k >= 16) { - - for (size_t nn = 0; nn < 4; nn++) { - for (size_t kk = 0; kk < 16; kk++) { - D[nn * 16 + kk] = (b[kk * ldb + nn] ); - } - } - - ColumnSums0 = vpadalq_s16(ColumnSums0, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[0])))); - ColumnSums1 = vpadalq_s16(ColumnSums1, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[16])))); - ColumnSums2 = vpadalq_s16(ColumnSums2, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[32])))); - ColumnSums3 = vpadalq_s16(ColumnSums3, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[48])))); - - b += 16 * ldb; - D += 64; - k -= 16; - } - - if (k > 0) { - - vst1q_u8(&D[0], vdupq_n_u8(0)); - vst1q_u8(&D[16], vdupq_n_u8(0)); - vst1q_u8(&D[32], vdupq_n_u8(0)); - vst1q_u8(&D[48], vdupq_n_u8(0)); - - for (size_t nn = 0; nn < 4; nn++) { - for (size_t kk = 0; kk < k; kk++) { - D[nn * 16 + kk] = (b[kk * ldb + nn] ); - } - } - - ColumnSums0 = vpadalq_s16(ColumnSums0, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[0])))); - ColumnSums1 = vpadalq_s16(ColumnSums1, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[16])))); - ColumnSums2 = vpadalq_s16(ColumnSums2, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[32])))); - ColumnSums3 = vpadalq_s16(ColumnSums3, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[48])))); - - D += 64; - } - - ColumnSums0 = vpaddq_s32(ColumnSums0, ColumnSums1); - ColumnSums2 = vpaddq_s32(ColumnSums2, ColumnSums3); - ColumnSums0 = vpaddq_s32(ColumnSums0, ColumnSums2); - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums0); - ColumnSumBuffer += 4; - - B += 4; - CountN -= 4; - } - - if (CountN > 0) { - - const uint8_t* b = B; - size_t k = CountK; - int32x4_t ColumnSums0 = vdupq_n_s32(0); - int32x4_t ColumnSums1 = vdupq_n_s32(0); - int32x4_t ColumnSums2 = vdupq_n_s32(0); - int32x4_t ColumnSums3 = vdupq_n_s32(0); - - while (k >= 16) { - - for (size_t nn = 0; nn < CountN; nn++) { - for (size_t kk = 0; kk < 16; kk++) { - D[nn * 16 + kk] = (b[kk * ldb + nn] ); - } - } - - ColumnSums0 = vpadalq_s16(ColumnSums0, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[0])))); - ColumnSums1 = vpadalq_s16(ColumnSums1, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[16])))); - ColumnSums2 = vpadalq_s16(ColumnSums2, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[32])))); - ColumnSums3 = vpadalq_s16(ColumnSums3, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[48])))); - - b += 16 * ldb; - D += 64; - k -= 16; - } - - if (k > 0) { - - vst1q_u8(&D[0], vdupq_n_u8(0)); - vst1q_u8(&D[16], vdupq_n_u8(0)); - vst1q_u8(&D[32], vdupq_n_u8(0)); - vst1q_u8(&D[48], 
vdupq_n_u8(0)); - - for (size_t nn = 0; nn < CountN; nn++) { - for (size_t kk = 0; kk < k; kk++) { - D[nn * 16 + kk] = (b[kk * ldb + nn] ); - } - } - - ColumnSums0 = vpadalq_s16(ColumnSums0, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[0])))); - ColumnSums1 = vpadalq_s16(ColumnSums1, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[16])))); - ColumnSums2 = vpadalq_s16(ColumnSums2, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[32])))); - ColumnSums3 = vpadalq_s16(ColumnSums3, vpaddlq_s8(vreinterpretq_s8_u8(vld1q_u8(&D[48])))); - - D += 64; - } - - ColumnSums0 = vpaddq_s32(ColumnSums0, ColumnSums1); - ColumnSums2 = vpaddq_s32(ColumnSums2, ColumnSums3); - ColumnSums0 = vpaddq_s32(ColumnSums0, ColumnSums2); - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums0); - ColumnSumBuffer += 4; - } -} - -template<> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_X8S8_KERNEL_NEON::PackedAType* A, - const MLAS_GEMM_X8S8_KERNEL_NEON::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return MlasGemmS8S8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - - -template<> -MLAS_FORCEINLINE -size_t MlasSymmQGemmKernel( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector -) -{ - return MlasSymQgemmS8KernelNeon(A, B, C, PackedCountK, CountM, CountN, ldc, lda, - ColumnSumVector); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_X8S8_KERNEL_NEON::PackedK, - MLAS_GEMM_X8S8_KERNEL_NEON::PackedStrides.K, - 4 // Kernel Stride M -}; - -const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = { - MlasSymmQGemmPackedOperation, - MlasSymmQGemmPackedOperation, - MlasGemmQuantCopyPackB, - 4, // StrideM - MLAS_GEMM_X8S8_KERNEL_NEON::PackedK -}; - -#endif //defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp deleted file mode 100644 index 5370b859bc73a..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp +++ /dev/null @@ -1,1077 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_sdot.cpp - -Abstract: - - This module implements sdot QGEMM kernel. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Define the prototypes of the NEON SDOT routines written in assembly. 
-// - -extern "C" { - - size_t - MLASCALL - MlasGemmS8S8KernelSDot( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode - ); -} - -struct MLAS_GEMM_S8S8_KERNEL_SDOT -{ - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 8; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 128, 256 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 24, 128, 384 }; -}; - -constexpr size_t MLAS_GEMM_S8S8_KERNEL_SDOT::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SDOT::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SDOT::PackedStrides; - -template<> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - return ZeroPointB; -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_S8S8_KERNEL_SDOT::PackedAType* D_uint8_t, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - int8_t* D = reinterpret_cast(D_uint8_t); - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - int8_t PaddedMatrixAData[16]; - - // - // Process 8 rows of matrix A. - // - // DOT kernels load 8x4 block of A with two vector registers. So A is packed - // a series of 16 byte vectors where four rows are interleaved with the - // following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ] - // - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // [ E4 E5 E6 E7 F4 F5 F6 F7 G4 G5 G6 G7 H4 H5 H6 H7 ] - // - // ... - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
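// Illustrative scalar model (not part of the deleted MLAS sources) of what one
// dot-product lane computes: four signed 8-bit products accumulated into a
// 32-bit lane. The PackedK = 8 layout above keeps four consecutive K values of
// a row together in each 32-bit lane so the assembly kernel can feed them
// straight into the instruction.
#include <cstdint>

inline int32_t SdotLaneModel(int32_t acc, const int8_t a[4], const int8_t b[4]) {
    for (int i = 0; i < 4; ++i) {
        acc += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }
    return acc;
}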
- // - - while (CountM >= 8) { - const int8_t* a0 = reinterpret_cast(A); - const int8_t* a1 = a0 + lda; - const int8_t* a2 = a0 + lda * 2; - const int8_t* a3 = a0 + lda * 3; - const int8_t* a4 = a0 + lda * 4; - const int8_t* a5 = a0 + lda * 5; - const int8_t* a6 = a0 + lda * 6; - const int8_t* a7 = a0 + lda * 7; - - size_t k = CountK; - int32x4_t RowSums0 = vmovq_n_s32(0); - int32x4_t RowSums1 = vmovq_n_s32(0); - - while (k >= 16) { - int32x4_t v0 = vld1q_s32(reinterpret_cast(a0)); - a0 += 16; - int32x4_t v1 = vld1q_s32(reinterpret_cast(a1)); - a1 += 16; - int32x4_t v2 = vld1q_s32(reinterpret_cast(a2)); - a2 += 16; - int32x4_t v3 = vld1q_s32(reinterpret_cast(a3)); - a3 += 16; - int32x4_t v4 = vld1q_s32(reinterpret_cast(a4)); - a4 += 16; - int32x4_t v5 = vld1q_s32(reinterpret_cast(a5)); - a5 += 16; - int32x4_t v6 = vld1q_s32(reinterpret_cast(a6)); - a6 += 16; - int32x4_t v7 = vld1q_s32(reinterpret_cast(a7)); - a7 += 16; - - int32x4_t z0 = vzip1q_s32(v0, v2); - int32x4_t z1 = vzip2q_s32(v0, v2); - int32x4_t z2 = vzip1q_s32(v1, v3); - int32x4_t z3 = vzip2q_s32(v1, v3); - - int32x4_t z4 = vzip1q_s32(v4, v6); - int32x4_t z5 = vzip2q_s32(v4, v6); - int32x4_t z6 = vzip1q_s32(v5, v7); - int32x4_t z7 = vzip2q_s32(v5, v7); - - v0 = vzip1q_s32(z0, z2); - v1 = vzip2q_s32(z0, z2); - v2 = vzip1q_s32(z1, z3); - v3 = vzip2q_s32(z1, z3); - - v4 = vzip1q_s32(z4, z6); - v5 = vzip2q_s32(z4, z6); - v6 = vzip1q_s32(z5, z7); - v7 = vzip2q_s32(z5, z7); - - vst1q_s8(&D[0], vreinterpretq_s8_s32(v0)); - vst1q_s8(&D[16], vreinterpretq_s8_s32(v4)); - vst1q_s8(&D[32], vreinterpretq_s8_s32(v1)); - vst1q_s8(&D[48], vreinterpretq_s8_s32(v5)); - vst1q_s8(&D[64], vreinterpretq_s8_s32(v2)); - vst1q_s8(&D[80], vreinterpretq_s8_s32(v6)); - vst1q_s8(&D[96], vreinterpretq_s8_s32(v3)); - vst1q_s8(&D[112], vreinterpretq_s8_s32(v7)); - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(vreinterpretq_s8_s32(v0))); - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(vreinterpretq_s8_s32(v1))); - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(vreinterpretq_s8_s32(v2))); - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(vreinterpretq_s8_s32(v3))); - - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(vreinterpretq_s8_s32(v4))); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(vreinterpretq_s8_s32(v5))); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(vreinterpretq_s8_s32(v6))); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(vreinterpretq_s8_s32(v7))); - - D += 128; - k -= 16; - } - - while (k >= 4) { - int32_t v0 = *reinterpret_cast(a0); - a0 += 4; - int32_t v1 = *reinterpret_cast(a1); - a1 += 4; - int32_t v2 = *reinterpret_cast(a2); - a2 += 4; - int32_t v3 = *reinterpret_cast(a3); - a3 += 4; - int32_t v4 = *reinterpret_cast(a4); - a4 += 4; - int32_t v5 = *reinterpret_cast(a5); - a5 += 4; - int32_t v6 = *reinterpret_cast(a6); - a6 += 4; - int32_t v7 = *reinterpret_cast(a7); - a7 += 4; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - *reinterpret_cast(&D[8]) = v2; - *reinterpret_cast(&D[12]) = v3; - *reinterpret_cast(&D[16]) = v4; - *reinterpret_cast(&D[20]) = v5; - *reinterpret_cast(&D[24]) = v6; - *reinterpret_cast(&D[28]) = v7; - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(vld1q_s8(D))); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(vld1q_s8(&D[16]))); - - D += 32; - k -= 4; - } - - if (k > 0) { - // - // Copy the remaining bytes to the zero padded stack buffer. 
- // - int8_t* d = D; - - vst1q_s8(d, vmovq_n_s8(0)); - vst1q_s8(&d[16], vmovq_n_s8(0)); - - while (k > 0) { - d[0] = *a0++; - d[4] = *a1++; - d[8] = *a2++; - d[12] = *a3++; - d[16] = *a4++; - d[20] = *a5++; - d[24] = *a6++; - d[28] = *a7++; - d += 1; - k -= 1; - } - - RowSums0 = vpadalq_s16(RowSums0, vpaddlq_s8(vld1q_s8(D))); - RowSums1 = vpadalq_s16(RowSums1, vpaddlq_s8(vld1q_s8(&D[16]))); - - D += 32; - } - - if (((CountK - 1) & 7) < 4) { - vst1q_s8(D, vmovq_n_s8(0)); - vst1q_s8(&D[16], vmovq_n_s8(0)); - D += 32; - } - - vst1q_s32(RowSumBuffer, RowSums0); - vst1q_s32(&RowSumBuffer[4], RowSums1); - - RowSumBuffer += 8; - - A = A + lda * 8; - CountM -= 8; - } - - // - // Process four rows of matrix A. - // - // The buffer is packed as a series of 16 byte vectors where four rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. - // - - if (CountM >= 4) { - - const int8_t* a0 = reinterpret_cast(A); - const int8_t* a1 = a0 + lda; - const int8_t* a2 = a1 + lda; - const int8_t* a3 = a2 + lda; - - size_t k = CountK; - int32x4_t RowSums = vmovq_n_s32(0); - - while (k >= 16) { - - int32x4_t v0 = vld1q_s32(reinterpret_cast(a0)); - a0 += 16; - int32x4_t v1 = vld1q_s32(reinterpret_cast(a1)); - a1 += 16; - int32x4_t v2 = vld1q_s32(reinterpret_cast(a2)); - a2 += 16; - int32x4_t v3 = vld1q_s32(reinterpret_cast(a3)); - a3 += 16; - - int32x4_t z0 = vzip1q_s32(v0, v2); - int32x4_t z1 = vzip2q_s32(v0, v2); - int32x4_t z2 = vzip1q_s32(v1, v3); - int32x4_t z3 = vzip2q_s32(v1, v3); - - v0 = vzip1q_s32(z0, z2); - v1 = vzip2q_s32(z0, z2); - v2 = vzip1q_s32(z1, z3); - v3 = vzip2q_s32(z1, z3); - - vst1q_s8(&D[0], vreinterpretq_s8_s32(v0)); - vst1q_s8(&D[16], vreinterpretq_s8_s32(v1)); - vst1q_s8(&D[32], vreinterpretq_s8_s32(v2)); - vst1q_s8(&D[48], vreinterpretq_s8_s32(v3)); - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(vreinterpretq_s8_s32(v0))); - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(vreinterpretq_s8_s32(v1))); - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(vreinterpretq_s8_s32(v2))); - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(vreinterpretq_s8_s32(v3))); - - D += 64; - k -= 16; - } - - while (k >= 4) { - - int32_t v0 = *reinterpret_cast(a0); - a0 += 4; - int32_t v1 = *reinterpret_cast(a1); - a1 += 4; - int32_t v2 = *reinterpret_cast(a2); - a2 += 4; - int32_t v3 = *reinterpret_cast(a3); - a3 += 4; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - *reinterpret_cast(&D[8]) = v2; - *reinterpret_cast(&D[12]) = v3; - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(vld1q_s8(D))); - - D += 16; - k -= 4; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - int8_t* d = PaddedMatrixAData; - - vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); - - while (k > 0) { - - d[0] = *a0++; - d[4] = *a1++; - d[8] = *a2++; - d[12] = *a3++; - - d += 1; - k -= 1; - } - - int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); - vst1q_s8(D, PackedVector); - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(PackedVector)); - - D += 16; - } - - if (((CountK - 1) & 7) < 4) { - - vst1q_s8(D, vmovq_n_s8(0)); - - D += 16; - } - - vst1q_s32(RowSumBuffer, RowSums); - RowSumBuffer += 4; - - A = A + lda * 4; - CountM -= 4; - } - - // - // Process two rows of matrix A. 
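// Illustrative sketch (not part of the deleted MLAS sources): the padding test
// used above. The packing loops emit one bundle per four K values, but PackedK
// is 8, so the packed stream must contain an even number of bundles;
// ((CountK - 1) & 7) < 4 is true exactly when ceil(CountK / 4) is odd and one
// extra zero bundle is required.
#include <cstddef>
#include <cassert>

inline void SdotPaddingConditionSketch() {
    for (size_t CountK = 1; CountK <= 64; ++CountK) {
        const bool needsExtraBundle = ((CountK - 1) & 7) < 4;
        const size_t bundles = (CountK + 3) / 4;  // bundles of four K values
        assert(needsExtraBundle == (bundles % 2 == 1));
    }
}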
- // - // The buffer is packed as a series of 8 byte vectors where two rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - if (CountM >= 2) { - - const int8_t* a0 = reinterpret_cast(A); - const int8_t* a1 = a0 + lda; - - size_t k = CountK; - int32x2_t RowSums = vmov_n_s32(0); - - while (k >= 4) { - - int32_t v0 = *reinterpret_cast(a0); - a0 += 4; - int32_t v1 = *reinterpret_cast(a1); - a1 += 4; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - - RowSums = vpadal_s16(RowSums, vpaddl_s8(vld1_s8(D))); - - D += 8; - k -= 4; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - int8_t* d = PaddedMatrixAData; - - vst1_s8(PaddedMatrixAData, vmov_n_s8(0)); - - while (k > 0) { - - d[0] = *a0++; - d[4] = *a1++; - - d += 1; - k -= 1; - } - - int8x8_t PackedVector = vld1_s8(PaddedMatrixAData); - vst1_s8(D, PackedVector); - - RowSums = vpadal_s16(RowSums, vpaddl_s8(PackedVector)); - - D += 8; - } - - if (((CountK - 1) & 7) < 4) { - - vst1_s8(D, vmov_n_s8(0)); - - D += 8; - } - - vst1_s32(RowSumBuffer, RowSums); - RowSumBuffer += 2; - - A = A + lda * 2; - CountM -= 2; - } - - // - // Process one row of matrix A. - // - // The buffer is packed as a series of 4 byte with the following pattern: - // - // [ A0 A1 A2 A3 ] - // [ A4 A5 A6 A7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - if (CountM > 0) { - - const int8_t* a = reinterpret_cast(A); - size_t k = CountK; - int32x4_t RowSums = vmovq_n_s32(0); - - while (k >= 16) { - - int8x16_t v = vld1q_s8(a); - a += 16; - - vst1q_s8(D, v); - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); - - D += 16; - k -= 16; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); - - for (size_t kk = 0; kk < k; kk++) { - PaddedMatrixAData[kk] = a[kk]; - } - - int8x16_t v = vld1q_s8(PaddedMatrixAData); - vst1q_s8(D, v); - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); - } - -#if defined(_M_ARM64) - // N.B. The workaround of defining a local vaddvq_u32 doesn't work here - // as VS2019 added new intrinsics to make the operation work. Also, not - // all build environments using VS2019 have the up-to-date arm64_neon.h, - // so fallback to pairwise addition. 
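As an aside, the pairwise fallback used below and vaddvq_s32 produce the same horizontal sum; a small sketch comparing the two, assuming an AArch64 compiler where both intrinsics are available (illustrative only):

#include <arm_neon.h>
#include <cstdio>

int main()
{
    const int32_t values[4] = {1, 2, 3, 4};
    int32x4_t v = vld1q_s32(values);

    // Two pairwise additions leave the full sum in every lane.
    int32x4_t p = vpaddq_s32(v, v);          // {1+2, 3+4, 1+2, 3+4}
    p = vpaddq_s32(p, p);                    // {10, 10, 10, 10}
    int32_t sum_pairwise = vgetq_lane_s32(p, 0);

    // Single-instruction across-vector reduction.
    int32_t sum_addv = vaddvq_s32(v);

    printf("%d %d\n", sum_pairwise, sum_addv);   // both print 10
    return 0;
}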
- RowSums = vpaddq_s32(RowSums, RowSums); - RowSums = vpaddq_s32(RowSums, RowSums); - vst1q_lane_s32(reinterpret_cast(RowSumBuffer), RowSums, 0); -#else - *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); -#endif - } -} - -MLAS_FORCEINLINE -void -MlasGemmS8S8CopyPackBProcessSDot( - int8_t* D, - int8x8_t BytesRow[4], - int32x4_t ColumnSums[2] - ) -{ - int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); - int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); - - int8x16x2_t zw = vzipq_s8(v02, v13); - int16x8x2_t zd = vzipq_s16(vreinterpretq_s16_s8(zw.val[0]), vreinterpretq_s16_s8(zw.val[1])); - - vst1q_s8(&D[0], vreinterpretq_s8_s16(zd.val[0])); - vst1q_s8(&D[16], vreinterpretq_s8_s16(zd.val[1])); - - ColumnSums[0] = vpadalq_s16(ColumnSums[0], vpaddlq_s8(vreinterpretq_s8_s16(zd.val[0]))); - ColumnSums[1] = vpadalq_s16(ColumnSums[1], vpaddlq_s8(vreinterpretq_s8_s16(zd.val[1]))); -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_S8S8_KERNEL_SDOT::PackedBType* Dst, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - int8_t* D = reinterpret_cast(Dst); - const int8x16_t ZeroVector = vmovq_n_s8(0); - int8x8_t BytesRow[4]; - - // - // Process 8 columns of matrix B in a loop. - // - // The buffer is packed as a series of 16 byte vectors where eight rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // [ E4 E5 E6 E7 F4 F5 F6 F7 G4 G5 G6 G7 H4 H5 H6 H7 ] - // - // Copy columns from matrix B to the packed buffer. - // - // If CountK is not aligned to a multiple of eight, then the packed buffer - // is padded with zero vectors. - // - // If CountN is not aligned to a multiple of eight, then the extra columns - // are padded with zeroes. - // - - while (CountN >= 8) { - - const int8_t* b = reinterpret_cast(B); - size_t k = CountK; - int32x4_t ColumnSums[2]; - - ColumnSums[0] = vmovq_n_s32(0); - ColumnSums[1] = vmovq_n_s32(0); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - - while (k >= 4) { - - BytesRow[0] = vld1_s8(&b[ldb * 0]); - BytesRow[1] = vld1_s8(&b[ldb * 1]); - BytesRow[2] = vld1_s8(&b[ldb * 2]); - BytesRow[3] = vld1_s8(&b[ldb * 3]); - - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums); - - b += ldb * 4; - D += 32; - k -= 4; - } - - if (k > 0) { - - BytesRow[0] = vld1_s8(&b[ldb * 0]); - BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); - BytesRow[2] = (k > 2) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); - BytesRow[3] = vget_low_s8(ZeroVector); - - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums); - - D += 32; - } - - // - // Zero pad the output buffer to a multiple of PackedK if the above - // processed an odd number of four row bundles. - // - - if (((CountK - 1) & 7) < 4) { - - vst1q_s8(&D[0], ZeroVector); - vst1q_s8(&D[16], ZeroVector); - - D += 32; - } - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); - vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. 
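For readers of the packed buffer, a scalar model of the layout produced by the CountN >= 8 loop above may help: for every four K rows, the eight columns are emitted as two 16-byte groups, each column contributing four consecutive K values. A rough sketch, assuming CountK and CountN are already multiples of 4 and 8 (the real code zero-pads the tails); PackB_Sdot_Reference is a made-up name, not an MLAS routine:

#include <cstddef>
#include <cstdint>

// Scalar model of the SDOT B-packing layout described above:
// [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ]
// [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ]
// where letters are columns and digits are K indices.
void PackB_Sdot_Reference(int8_t* dst, const int8_t* B, size_t ldb,
                          size_t CountN, size_t CountK)
{
    for (size_t n = 0; n < CountN; n += 8) {                  // 8 columns per tile
        for (size_t k = 0; k < CountK; k += 4) {              // 4 K rows per bundle
            for (size_t group = 0; group < 8; group += 4) {   // two 16-byte halves
                for (size_t col = group; col < group + 4; ++col) {
                    for (size_t kk = 0; kk < 4; ++kk) {
                        *dst++ = B[(k + kk) * ldb + n + col];
                    }
                }
            }
        }
    }
}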
- // - - if (CountN > 0) { - - const int8_t* b = reinterpret_cast(B); - size_t k = CountK; - int8_t PaddedMatrixBData[32]; - int32x4_t ColumnSums[2]; - - vst1q_s8(&PaddedMatrixBData[0], ZeroVector); - vst1q_s8(&PaddedMatrixBData[16], ZeroVector); - - ColumnSums[0] = vmovq_n_s32(0); - ColumnSums[1] = vmovq_n_s32(0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. - // - - while (k > 0) { - - const int8_t* bcopy0 = &b[ldb * 0]; - const int8_t* bcopy1 = &b[ldb * 1]; - const int8_t* bcopy2 = &b[ldb * 2]; - const int8_t* bcopy3 = &b[ldb * 3]; - - if (k >= 4) { - - b += ldb * 4; - k -= 4; - - } else { - - vst1q_s8(&PaddedMatrixBData[0], ZeroVector); - vst1q_s8(&PaddedMatrixBData[16], ZeroVector); - - bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[24]; - bcopy2 = (k > 2) ? bcopy2 : &PaddedMatrixBData[24]; - bcopy3 = &PaddedMatrixBData[24]; - - k = 0; - } - - int8_t* padded = PaddedMatrixBData; - int8_t* padded_end = padded + CountN; - - do { - padded[0] = *bcopy0++; - padded[8] = *bcopy1++; - padded[16] = *bcopy2++; - padded[24] = *bcopy3++; - } while (++padded < padded_end); - - BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); - BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); - BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); - BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); - - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums); - - D += 32; - } - - // - // Zero pad the output buffer to a multiple of PackedK if the above - // processed an odd number of four row bundles. - // - - if (((CountK - 1) & 7) < 4) { - - vst1q_s8(&D[0], ZeroVector); - vst1q_s8(&D[16], ZeroVector); - - D += 32; - } - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); - vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); - } -} - -template<> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_S8S8_KERNEL_SDOT::PackedAType* A, - const MLAS_GEMM_S8S8_KERNEL_SDOT::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return MlasGemmS8S8KernelSDot(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_S8S8_KERNEL_SDOT::PackedK, - MLAS_GEMM_S8S8_KERNEL_SDOT::PackedStrides.K, - 8 // Kernel Stride M -}; - - -/** - * @brief Type parameter for symmetric qgemm - */ -struct MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 16; -}; - -constexpr size_t MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK; - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedBType* Dst, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - int8_t* D = reinterpret_cast(Dst); - const int8x16_t ZeroVector = vmovq_n_s8(0); - int8x8_t BytesRow[4]; - - // Kernel MlasSymQgemmS8KernelSdot code loads 4x16 B block like: - // - // |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| - // | ... ... ... ... ... ... ... ... 
| - // |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| - // - // So we process B data similar with MlasGemmQuantCopyPackB - // But twice as wide and twice as deep: - // 16 CountN and 16 CountK - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ] - // [ I4 I5 I6 I7 J4 J5 J6 J7 K4 K5 K6 K7 L4 L5 L6 L7 ] - // [ M4 M5 M6 M7 N4 N5 N6 N7 O4 O5 O6 O7 P4 P5 P6 P7 ] - // .... repeat 4 times on K dimension .... - // - // If CountK is not aligned to a multiple of 16, then the packed buffer - // is padded with zero vectors. - // - // If CountN is not aligned to a multiple of 16, then the extra columns - // are padded with zeroes. - // - - while (CountN >= 16) { - - const int8_t* b = reinterpret_cast(B); - size_t k = CountK; - int32x4_t ColumnSums0[2]; - int32x4_t ColumnSums1[2]; - - ColumnSums0[0] = vmovq_n_s32(0); - ColumnSums0[1] = vmovq_n_s32(0); - ColumnSums1[0] = vmovq_n_s32(0); - ColumnSums1[1] = vmovq_n_s32(0); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - - while (k >= 4) { - - BytesRow[0] = vld1_s8(&b[ldb * 0]); - BytesRow[1] = vld1_s8(&b[ldb * 1]); - BytesRow[2] = vld1_s8(&b[ldb * 2]); - BytesRow[3] = vld1_s8(&b[ldb * 3]); - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums0); - D += 32; - - BytesRow[0] = vld1_s8(&b[ldb * 0 + 8]); - BytesRow[1] = vld1_s8(&b[ldb * 1 + 8]); - BytesRow[2] = vld1_s8(&b[ldb * 2 + 8]); - BytesRow[3] = vld1_s8(&b[ldb * 3 + 8]); - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums1); - D += 32; - - b += ldb * 4; - k -= 4; - } - - if (k > 0) { - - BytesRow[0] = vld1_s8(&b[ldb * 0]); - BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); - BytesRow[2] = (k > 2) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); - BytesRow[3] = vget_low_s8(ZeroVector); - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums0); - D += 32; - - BytesRow[0] = vld1_s8(&b[ldb * 0 + 8]); - BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1 + 8]) : vget_low_s8(ZeroVector); - BytesRow[2] = (k > 2) ? vld1_s8(&b[ldb * 2 + 8]) : vget_low_s8(ZeroVector); - BytesRow[3] = vget_low_s8(ZeroVector); - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums1); - D += 32; - } - - // - // Zero pad the output buffer to a multiple of PackedK - // - constexpr size_t mask = MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK - 1; - size_t remain = (MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK - (CountK & mask)) & mask; - remain = remain >> 2; // divid by 4 - for (; remain > 0; remain--) { - vst1q_s8(&D[0], ZeroVector); - vst1q_s8(&D[16], ZeroVector); - vst1q_s8(&D[32], ZeroVector); - vst1q_s8(&D[48], ZeroVector); - D += 64; - } - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums0[0]); - vst1q_s32(&ColumnSumBuffer[4], ColumnSums0[1]); - vst1q_s32(&ColumnSumBuffer[8], ColumnSums1[0]); - vst1q_s32(&ColumnSumBuffer[12], ColumnSums1[1]); - ColumnSumBuffer += 16; - - B += 16; - CountN -= 16; - } - - // - // Process the remaining columns of matrix B. 
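The remainder arithmetic above ((PackedK - (CountK & mask)) & mask, then divide by 4) can be sanity-checked with small numbers; a quick sketch with arbitrarily chosen CountK values:

#include <cstdio>
#include <cstddef>
#include <initializer_list>

int main()
{
    constexpr size_t PackedK = 16;
    constexpr size_t mask = PackedK - 1;
    // How many extra 4-row bundles of zeros are appended after the
    // (possibly padded) last bundle so that K reaches a multiple of PackedK.
    for (size_t CountK : {9, 13, 16, 17}) {
        size_t remain = (PackedK - (CountK & mask)) & mask;
        printf("CountK=%zu -> %zu extra zero bundle(s)\n", CountK, remain >> 2);
    }
    return 0;
}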
- // - - if (CountN > 0) { - - const int8_t* b = reinterpret_cast(B); - size_t k = CountK; - int8_t PaddedMatrixBData[64]; - int32x4_t ColumnSums0[2]; - int32x4_t ColumnSums1[2]; - - vst1q_s8(&PaddedMatrixBData[0], ZeroVector); - vst1q_s8(&PaddedMatrixBData[16], ZeroVector); - vst1q_s8(&PaddedMatrixBData[32], ZeroVector); - vst1q_s8(&PaddedMatrixBData[48], ZeroVector); - - ColumnSums0[0] = vmovq_n_s32(0); - ColumnSums0[1] = vmovq_n_s32(0); - ColumnSums1[0] = vmovq_n_s32(0); - ColumnSums1[1] = vmovq_n_s32(0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. - // - - while (k > 0) { - - const int8_t* bcopy0 = &b[ldb * 0]; - const int8_t* bcopy1 = &b[ldb * 1]; - const int8_t* bcopy2 = &b[ldb * 2]; - const int8_t* bcopy3 = &b[ldb * 3]; - - if (k >= 4) { - - b += ldb * 4; - k -= 4; - - } else { - - vst1q_s8(&PaddedMatrixBData[0], ZeroVector); - vst1q_s8(&PaddedMatrixBData[16], ZeroVector); - vst1q_s8(&PaddedMatrixBData[32], ZeroVector); - vst1q_s8(&PaddedMatrixBData[48], ZeroVector); - - bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[48]; - bcopy2 = (k > 2) ? bcopy2 : &PaddedMatrixBData[48]; - bcopy3 = &PaddedMatrixBData[48]; - - k = 0; - } - - int8_t* padded = PaddedMatrixBData; - int8_t* padded_end = padded + CountN; - - do { - padded[0] = *bcopy0++; - padded[16] = *bcopy1++; - padded[32] = *bcopy2++; - padded[48] = *bcopy3++; - } while (++padded < padded_end); - - BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); - BytesRow[1] = vld1_s8(&PaddedMatrixBData[16]); - BytesRow[2] = vld1_s8(&PaddedMatrixBData[32]); - BytesRow[3] = vld1_s8(&PaddedMatrixBData[48]); - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums0); - D += 32; - - BytesRow[0] = vld1_s8(&PaddedMatrixBData[8]); - BytesRow[1] = vld1_s8(&PaddedMatrixBData[24]); - BytesRow[2] = vld1_s8(&PaddedMatrixBData[40]); - BytesRow[3] = vld1_s8(&PaddedMatrixBData[56]); - MlasGemmS8S8CopyPackBProcessSDot(D, BytesRow, ColumnSums1); - D += 32; - } - - // - // Zero pad the output buffer to a multiple of PackedK - // - constexpr size_t mask = MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK - 1; - size_t remain = (MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK - (CountK & mask)) & mask; - remain = remain >> 2; // divid by 4 - for (; remain > 0; remain--) { - vst1q_s8(&D[0], ZeroVector); - vst1q_s8(&D[16], ZeroVector); - vst1q_s8(&D[32], ZeroVector); - vst1q_s8(&D[48], ZeroVector); - D += 64; - } - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums0[0]); - vst1q_s32(&ColumnSumBuffer[4], ColumnSums0[1]); - vst1q_s32(&ColumnSumBuffer[8], ColumnSums1[0]); - vst1q_s32(&ColumnSumBuffer[12], ColumnSums1[1]); - } -} - -extern "C" { - // Prototype of SDOT symmetric qgemm kernel in assembly - - size_t - MLASCALL - MlasSymQgemmS8KernelSdot( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector - ); - - size_t - MLASCALL - MlasSymQgemmS8KernelSdotLd64( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector - ); - -} - -template<> -MLAS_FORCEINLINE -size_t MlasSymmQGemmKernel( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector -) -{ - return MlasSymQgemmS8KernelSdot(A, B, C, PackedCountK, CountM, CountN, ldc, lda, - ColumnSumVector); -} - -/** - * @brief Type 
parameter for symmetric qgemm, little core - */ -struct MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT_LIT { - static constexpr size_t PackedK = MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK; -}; -constexpr size_t MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT_LIT::PackedK; - - -template <> -MLAS_FORCEINLINE -size_t MlasSymmQGemmKernel( - const int8_t* A, - const int8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - size_t lda, - const int32_t* ColumnSumVector -) -{ - return MlasSymQgemmS8KernelSdotLd64(A, B, C, PackedCountK, CountM, CountN, ldc, lda, - ColumnSumVector); -} - -const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchSdot = { - MlasSymmQGemmPackedOperation, - MlasSymmQGemmPackedOperation, - MlasGemmQuantCopyPackB, - 4, // StrideM - MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT::PackedK -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp deleted file mode 100644 index c41f43ca22d18..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp +++ /dev/null @@ -1,964 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_smmla.cpp - -Abstract: - - This module implements smmla QGEMM kernel. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Define the prototypes of the NEON SMMLA routines written in assembly. -// - -extern "C" { - -size_t MLASCALL -MlasGemmS8S8KernelSmmlaZero(const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB); - -size_t MLASCALL -MlasGemmS8S8KernelSmmlaAdd(const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB); -} - -struct MLAS_GEMM_S8S8_KERNEL_SMMLA { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef int8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 8; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; -}; - -constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; - -template <> -MLAS_FORCEINLINE int32_t -MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - return ZeroPointB; -} - -template <> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned) -{ - int8_t* D = reinterpret_cast(D_uint8_t); - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - int8_t PaddedMatrixAData[64]; - - // - // Process 8 rows of matrix A. - // - // MMLA kernels load 8x8 block of A with four vector registers. 
So A is packed - // a series of 64 byte vectors where eight rows are interleaved with the - // following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // [ B0 B1 B2 B3 B4 B5 B6 B7 ] - // [ C0 C1 C2 C3 C4 C5 C6 C7 ] - // [ D0 D1 D2 D3 D4 D5 D6 D7 ] - // [ E0 E1 E2 E3 E4 E5 E6 E7 ] - // [ F0 F1 F2 F3 F4 F5 F6 F7 ] - // [ G0 G1 G2 G3 G4 G5 G6 G7 ] - // [ H0 H1 H2 H3 H4 H5 H6 H7 ] - // - // ... - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. - // - - while (CountM >= 8) { - const int8_t* a0 = reinterpret_cast(A); - const int8_t* a1 = a0 + lda; - const int8_t* a2 = a0 + lda * 2; - const int8_t* a3 = a0 + lda * 3; - const int8_t* a4 = a0 + lda * 4; - const int8_t* a5 = a0 + lda * 5; - const int8_t* a6 = a0 + lda * 6; - const int8_t* a7 = a0 + lda * 7; - - size_t k = CountK; - int32x4_t RowSums0 = vmovq_n_s32(0); - int32x4_t RowSums1 = vmovq_n_s32(0); - - while (k >= 16) { - int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); - a0 += 16; - int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); - a1 += 16; - int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); - a2 += 16; - int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); - a3 += 16; - int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); - a4 += 16; - int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); - a5 += 16; - int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); - a6 += 16; - int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); - a7 += 16; - - int64x2_t z0 = vzip1q_s64(v0, v1); - int64x2_t z1 = vzip2q_s64(v0, v1); - int64x2_t z2 = vzip1q_s64(v2, v3); - int64x2_t z3 = vzip2q_s64(v2, v3); - - int64x2_t z4 = vzip1q_s64(v4, v5); - int64x2_t z5 = vzip2q_s64(v4, v5); - int64x2_t z6 = vzip1q_s64(v6, v7); - int64x2_t z7 = vzip2q_s64(v6, v7); - - vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); - vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); - vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); - vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); - vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); - vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); - vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); - vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); - - int32x4_t RowSums0L_pada = vmovq_n_s32(0); - RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); - RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); - - int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); - int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); - int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), - vdups_laneq_s32(RowSums0L_add, 2)}; - - int32x4_t RowSums0H_pada = vmovq_n_s32(0); - RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); - RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); - - int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); - int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); - int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), - vdups_laneq_s32(RowSums0H_add, 2)}; - - RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); - - int32x4_t RowSums1L_pada = vmovq_n_s32(0); - RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); - RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); - - int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); - int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); - int32x2_t RowSums1L = 
{vdups_laneq_s32(RowSums1L_add, 0), - vdups_laneq_s32(RowSums1L_add, 2)}; - - int32x4_t RowSums1H_pada = vmovq_n_s32(0); - RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); - RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); - - int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); - int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); - int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), - vdups_laneq_s32(RowSums1H_add, 2)}; - - RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); - - D += 128; - k -= 16; - } - - while (k >= 8) { - int64x1_t v0 = *reinterpret_cast(a0); - a0 += 8; - int64x1_t v1 = *reinterpret_cast(a1); - a1 += 8; - int64x1_t v2 = *reinterpret_cast(a2); - a2 += 8; - int64x1_t v3 = *reinterpret_cast(a3); - a3 += 8; - int64x1_t v4 = *reinterpret_cast(a4); - a4 += 8; - int64x1_t v5 = *reinterpret_cast(a5); - a5 += 8; - int64x1_t v6 = *reinterpret_cast(a6); - a6 += 8; - int64x1_t v7 = *reinterpret_cast(a7); - a7 += 8; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[8]) = v1; - *reinterpret_cast(&D[16]) = v2; - *reinterpret_cast(&D[24]) = v3; - *reinterpret_cast(&D[32]) = v4; - *reinterpret_cast(&D[40]) = v5; - *reinterpret_cast(&D[48]) = v6; - *reinterpret_cast(&D[56]) = v7; - - int64x2_t z01 = vcombine_s64(v0, v1); - int64x2_t z23 = vcombine_s64(v2, v3); - int64x2_t z45 = vcombine_s64(v4, v5); - int64x2_t z67 = vcombine_s64(v6, v7); - - int32x4_t RowSums0L_pada = vmovq_n_s32(0); - RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); - - int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); - int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); - int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), - vdups_laneq_s32(RowSums0L_add, 2)}; - - int32x4_t RowSums0H_pada = vmovq_n_s32(0); - RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); - - int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); - int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); - int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), - vdups_laneq_s32(RowSums0H_add, 2)}; - - RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); - - int32x4_t RowSums1L_pada = vmovq_n_s32(0); - RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); - - int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); - int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); - int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), - vdups_laneq_s32(RowSums1L_add, 2)}; - - int32x4_t RowSums1H_pada = vmovq_n_s32(0); - RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); - - int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); - int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); - int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), - vdups_laneq_s32(RowSums1H_add, 2)}; - - RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); - - D += 64; - k -= 8; - } - - if (k > 0) { - // - // zero pad the remaining columns to 8 - // - int8_t* d = D; - - vst1q_s8(d, vmovq_n_s8(0)); - vst1q_s8(&d[16], vmovq_n_s8(0)); - vst1q_s8(&d[32], vmovq_n_s8(0)); - vst1q_s8(&d[48], vmovq_n_s8(0)); - - while (k > 0) { - d[0] = *a0++; - d[8] = *a1++; - d[16] = *a2++; - d[24] = *a3++; - d[32] = *a4++; - d[40] = *a5++; - d[48] = *a6++; - d[56] = 
*a7++; - d += 1; - k -= 1; - } - d = D; - int64x1_t v0 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v1 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v2 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v3 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v4 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v5 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v6 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v7 = *reinterpret_cast(d); - d = d + 8; - - int64x2_t z01 = vcombine_s64(v0, v1); - int64x2_t z23 = vcombine_s64(v2, v3); - int64x2_t z45 = vcombine_s64(v4, v5); - int64x2_t z67 = vcombine_s64(v6, v7); - - int32x4_t RowSums0L_pada = vmovq_n_s32(0); - RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); - - int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); - int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); - int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), - vdups_laneq_s32(RowSums0L_add, 2)}; - - int32x4_t RowSums0H_pada = vmovq_n_s32(0); - RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); - - int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); - int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); - int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), - vdups_laneq_s32(RowSums0H_add, 2)}; - - RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); - - int32x4_t RowSums1L_pada = vmovq_n_s32(0); - RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); - - int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); - int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); - int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), - vdups_laneq_s32(RowSums1L_add, 2)}; - - int32x4_t RowSums1H_pada = vmovq_n_s32(0); - RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); - - int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); - int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); - int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), - vdups_laneq_s32(RowSums1H_add, 2)}; - - RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); - - D += 64; - } - - vst1q_s32(RowSumBuffer, RowSums0); - vst1q_s32(&RowSumBuffer[4], RowSums1); - - RowSumBuffer += 8; - - A = A + lda * 8; - CountM -= 8; - } - - // - // Process four rows of matrix A. - // - // The buffer is packed as a series of 32 byte vectors where four rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // [ B0 B1 B2 B3 B4 B5 B6 B7 ] - // [ C0 C1 C2 C3 C4 C5 C6 C7 ] - // [ D0 D1 D2 D3 D4 D5 D6 D7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
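A scalar model of this layout may be easier to follow than the zip sequences that implement it: in the eight-, four- and two-row paths alike, every group of 8 K values is written as one 8-byte block per row, row after row. A rough sketch, assuming CountK is already a multiple of 8 (the real code zero-pads the tail); PackA_Smmla_Reference is a made-up name, not an MLAS routine:

#include <cstddef>
#include <cstdint>

// Scalar model of the MMLA A-packing layout described above: for every
// 8 consecutive K values, each of RowCount rows contributes one 8-byte block,
// so an 8-row tile becomes a series of 64-byte groups.
void PackA_Smmla_Reference(int8_t* dst, const int8_t* A, size_t lda,
                           size_t RowCount, size_t CountK)
{
    for (size_t k = 0; k < CountK; k += 8) {           // one 8-wide K slice
        for (size_t row = 0; row < RowCount; ++row) {  // rows interleaved per slice
            for (size_t kk = 0; kk < 8; ++kk) {
                *dst++ = A[row * lda + k + kk];
            }
        }
    }
}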
- // - - if (CountM >= 4) { - const int8_t* a0 = reinterpret_cast(A); - const int8_t* a1 = a0 + lda; - const int8_t* a2 = a1 + lda; - const int8_t* a3 = a2 + lda; - - size_t k = CountK; - int32x4_t RowSums = vmovq_n_s32(0); - - while (k >= 16) { - int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); - a0 += 16; - int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); - a1 += 16; - int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); - a2 += 16; - int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); - a3 += 16; - - int64x2_t z0 = vzip1q_s64(v0, v1); - int64x2_t z1 = vzip2q_s64(v0, v1); - int64x2_t z2 = vzip1q_s64(v2, v3); - int64x2_t z3 = vzip2q_s64(v2, v3); - - vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); - vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); - vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); - vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); - - int32x4_t RowSumsL_pada = vmovq_n_s32(0); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); - - int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); - int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); - int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), - vdups_laneq_s32(RowSumsL_add, 2)}; - - int32x4_t RowSumsH_pada = vmovq_n_s32(0); - RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); - RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); - - int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); - int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); - int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), - vdups_laneq_s32(RowSumsH_add, 2)}; - - RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); - - D += 64; - k -= 16; - } - - while (k >= 8) { - int64x1_t v0 = *reinterpret_cast(a0); - a0 += 8; - int64x1_t v1 = *reinterpret_cast(a1); - a1 += 8; - int64x1_t v2 = *reinterpret_cast(a2); - a2 += 8; - int64x1_t v3 = *reinterpret_cast(a3); - a3 += 8; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[8]) = v1; - *reinterpret_cast(&D[16]) = v2; - *reinterpret_cast(&D[24]) = v3; - - int64x2_t z01 = vcombine_s64(v0, v1); - int64x2_t z23 = vcombine_s64(v2, v3); - - int32x4_t RowSumsL_pada = vmovq_n_s32(0); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); - - int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); - int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); - int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), - vdups_laneq_s32(RowSumsL_add, 2)}; - - int32x4_t RowSumsH_pada = vmovq_n_s32(0); - RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); - - int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); - int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); - int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), - vdups_laneq_s32(RowSumsH_add, 2)}; - - RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); - - D += 32; - k -= 8; - } - - if (k > 0) { - // - // Copy the remaining bytes with zero padding. 
- // - int8_t* d = D; - - vst1q_s8(d, vmovq_n_s8(0)); - vst1q_s8(&d[16], vmovq_n_s8(0)); - - while (k > 0) { - d[0] = *a0++; - d[8] = *a1++; - d[16] = *a2++; - d[24] = *a3++; - d += 1; - k -= 1; - } - - d = D; - int64x1_t v0 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v1 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v2 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v3 = *reinterpret_cast(d); - d = d + 8; - - int64x2_t z01 = vcombine_s64(v0, v1); - int64x2_t z23 = vcombine_s64(v2, v3); - - int32x4_t RowSums0L_pada = vmovq_n_s32(0); - RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); - - int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); - int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); - int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), - vdups_laneq_s32(RowSums0L_add, 2)}; - - int32x4_t RowSums0H_pada = vmovq_n_s32(0); - RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); - - int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); - int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); - int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), - vdups_laneq_s32(RowSums0H_add, 2)}; - - RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); - - D += 32; - } - - vst1q_s32(RowSumBuffer, RowSums); - RowSumBuffer += 4; - - A = A + lda * 4; - CountM -= 4; - } - - // - // Process two rows of matrix A. - // - // The buffer is packed as a series of 16 byte vectors where two rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // [ B0 B1 B2 B3 B4 B5 B6 B7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
- // - - if (CountM >= 2) { - const int8_t* a0 = reinterpret_cast(A); - const int8_t* a1 = a0 + lda; - - size_t k = CountK; - int32x2_t RowSums = vmov_n_s32(0); - - while (k >= 16) { - int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); - a0 += 16; - int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); - a1 += 16; - - int64x2_t z0 = vzip1q_s64(v0, v1); - int64x2_t z1 = vzip2q_s64(v0, v1); - - vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); - vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); - - int32x4_t RowSumsL_pada = vmovq_n_s32(0); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); - - int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); - int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); - int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), - vdups_laneq_s32(RowSumsL_add, 2)}; - - RowSums = vadd_s32(RowSums, RowSumsL); - - D += 32; - k -= 16; - } - - while (k >= 8) { - int64x1_t v0 = *reinterpret_cast(a0); - a0 += 8; - int64x1_t v1 = *reinterpret_cast(a1); - a1 += 8; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[8]) = v1; - - int64x2_t z01 = vcombine_s64(v0, v1); - int32x4_t RowSumsL_pada = vmovq_n_s32(0); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); - - int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); - int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); - int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), - vdups_laneq_s32(RowSumsL_add, 2)}; - - RowSums = vadd_s32(RowSums, RowSumsL); - - D += 16; - k -= 8; - } - - if (k > 0) { - // - // Zero pad the remaining elements to make 8 columns. - // - - int8_t* d = PaddedMatrixAData; - vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); - - while (k > 0) { - d[0] = *a0++; - d[8] = *a1++; - - d += 1; - k -= 1; - } - - d = PaddedMatrixAData; - int64x1_t v0 = *reinterpret_cast(d); - d = d + 8; - int64x1_t v1 = *reinterpret_cast(d); - d = d + 8; - - int64x2_t z01 = vcombine_s64(v0, v1); - int32x4_t RowSumsL_pada = vmovq_n_s32(0); - RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); - - int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); - int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); - int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), - vdups_laneq_s32(RowSumsL_add, 2)}; - - RowSums = vadd_s32(RowSums, RowSumsL); - - int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); - vst1q_s8(D, PackedVector); - - D += 16; - } - - vst1_s32(RowSumBuffer, RowSums); - RowSumBuffer += 2; - - A = A + lda * 2; - CountM -= 2; - } - - // - // Process one row of matrix A. - // - // The buffer is packed as a series of 8 byte with the following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of 8, then the vector is padded - // with zeroes. - // - - if (CountM > 0) { - // No need to pad the rows to 2, the .S takes care of zero pdding - const int8_t* a = reinterpret_cast(A); - size_t k = CountK; - int32x4_t RowSums = vmovq_n_s32(0); - - while (k >= 16) { - int8x16_t v = vld1q_s8(a); - a += 16; - - vst1q_s8(D, v); - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); - - D += 16; - k -= 16; - } - - if (k > 0) { - // - // Copy the remaining bytes to the zero padded stack buffer. 
- // - - vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); - - for (size_t kk = 0; kk < k; kk++) { - PaddedMatrixAData[kk] = a[kk]; - } - - int8x16_t v = vld1q_s8(PaddedMatrixAData); - vst1q_s8(D, v); - - RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); - } - - *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); - } -} - -MLAS_FORCEINLINE -void -MlasGemmS8S8CopyPackBProcessSmmla(int8_t* D, int8x8_t BytesRow[8], int32x4_t ColumnSums[2]) -{ - int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); - int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); - - int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); - int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); - - int8x16x2_t zw1 = vzipq_s8(v02, v13); - int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); - - int8x16x2_t zw2 = vzipq_s8(v46, v57); - int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); - - int32x4x2_t zd3 = - vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); - int32x4x2_t zd4 = - vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); - - vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); - vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); - vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); - vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); - - int32x4_t ColSums0L_pada = vmovq_n_s32(0); - ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); - int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); - int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); - int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; - - int32x4_t ColSums0H_pada = vmovq_n_s32(0); - ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); - int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); - int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); - int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; - - ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); - - int32x4_t ColSums1L_pada = vmovq_n_s32(0); - ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); - int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); - int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); - int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; - - int32x4_t ColSums1H_pada = vmovq_n_s32(0); - ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); - int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); - int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); - int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; - - ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); -} - -template <> -void -MlasGemmQuantCopyPackB(MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) -{ - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - int8_t* D = reinterpret_cast(Dst); - const int8x16_t ZeroVector = vmovq_n_s8(0); - int8x8_t BytesRow[8]; - - // - // Copy data from matrix B into the destination buffer 8x2 blocks at a - // time. 
- // - // - while (CountN >= 8) { - const int8_t* b = reinterpret_cast(B); - size_t k = CountK; - int32x4_t ColumnSums[2]; - - ColumnSums[0] = vmovq_n_s32(0); - ColumnSums[1] = vmovq_n_s32(0); - - while (k >= 8) { - BytesRow[0] = vld1_s8(&b[ldb * 0]); - BytesRow[1] = vld1_s8(&b[ldb * 1]); - BytesRow[2] = vld1_s8(&b[ldb * 2]); - BytesRow[3] = vld1_s8(&b[ldb * 3]); - BytesRow[4] = vld1_s8(&b[ldb * 4]); - BytesRow[5] = vld1_s8(&b[ldb * 5]); - BytesRow[6] = vld1_s8(&b[ldb * 6]); - BytesRow[7] = vld1_s8(&b[ldb * 7]); - - MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); - - D += 64; - b += ldb * 8; - k -= 8; - } - - if (k > 0) { - // Pad k to 8 - - BytesRow[0] = vld1_s8(&b[ldb * 0]); - BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); - BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); - BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); - BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); - BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); - BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); - BytesRow[7] = vget_low_s8(ZeroVector); - - MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); - - D += 64; - } - - // Zero pad the output buffer to a multiple of PackedK if the above - // processed an odd number of four row bundles. - // - vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); - vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); - - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. - // - - if (CountN > 0) { - const int8_t* b = reinterpret_cast(B); - size_t k = CountK; - int8_t PaddedMatrixBData[64]; - int32x4_t ColumnSums[2]; - - vst1q_s8(&PaddedMatrixBData[0], ZeroVector); - vst1q_s8(&PaddedMatrixBData[16], ZeroVector); - vst1q_s8(&PaddedMatrixBData[32], ZeroVector); - vst1q_s8(&PaddedMatrixBData[48], ZeroVector); - - ColumnSums[0] = vmovq_n_s32(0); - ColumnSums[1] = vmovq_n_s32(0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. - // - - while (k > 0) { - const int8_t* bcopy0 = &b[ldb * 0]; - const int8_t* bcopy1 = &b[ldb * 1]; - const int8_t* bcopy2 = &b[ldb * 2]; - const int8_t* bcopy3 = &b[ldb * 3]; - const int8_t* bcopy4 = &b[ldb * 4]; - const int8_t* bcopy5 = &b[ldb * 5]; - const int8_t* bcopy6 = &b[ldb * 6]; - const int8_t* bcopy7 = &b[ldb * 7]; - - if (k >= 8) { - b += ldb * 8; - k -= 8; - - } else { - vst1q_s8(&PaddedMatrixBData[0], ZeroVector); - vst1q_s8(&PaddedMatrixBData[16], ZeroVector); - vst1q_s8(&PaddedMatrixBData[32], ZeroVector); - vst1q_s8(&PaddedMatrixBData[48], ZeroVector); - - bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; - bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; - bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; - bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; - bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; - bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56]; - bcopy7 = &PaddedMatrixBData[56]; - - k = 0; - } - - int8_t* padded = PaddedMatrixBData; - int8_t* padded_end = padded + CountN; - do { - padded[0] = *bcopy0++; - padded[8] = *bcopy1++; - padded[16] = *bcopy2++; - padded[24] = *bcopy3++; - padded[32] = *bcopy4++; - padded[40] = *bcopy5++; - padded[48] = *bcopy6++; - padded[56] = *bcopy7++; - - } while (++padded < padded_end); - - BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); - BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); - BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); - BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); - BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]); - BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]); - BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]); - BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]); - - MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); - - D += 64; - } - - vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); - vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); - } -} - -template <> -MLAS_FORCEINLINE size_t -MlasGemmQuantKernel(const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A, - const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) -{ - size_t RowsHandled; - - if (ZeroMode) { - RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB); - } else { - RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB); - } - - return RowsHandled; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK, - MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K, - 8}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp deleted file mode 100644 index 65c9e2b5ae1d7..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sse.cpp +++ /dev/null @@ -1,497 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_sse.cpp - -Abstract: - - This module implements QGEMM kernels for sse. 
- ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -struct MLAS_GEMM_U8X8_KERNEL_SSE -{ - typedef int16_t PackedAType; - typedef int16_t PackedBType; - typedef uint8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 2; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; -}; - -constexpr size_t MLAS_GEMM_U8X8_KERNEL_SSE::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_SSE::Strides; - -template<> -MLAS_FORCEINLINE constexpr -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (!BIsSigned) { - ZeroPointB = MLAS_GEMM_U8X8_KERNEL_SSE::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - const __m128i ZeroVector = _mm_setzero_si128(); - const __m128i OnesWordBroadcast = _mm_set1_epi16(1); - uint8_t PaddedMatrixAData[8] = { 0 }; - - // - // Process a single row of matrix A in a loop. - // - - while (CountM > 0) { - - const uint8_t* a = A; - size_t k = CountK; - __m128i ReductionVector = ZeroVector; - - // - // Zero extend the source bytes to 16-bits and write to the packed - // buffer. - // - // The packed buffer has the same data ordering as the source bytes, - // but CountK is aligned up to a multiple of 2 to maintain 32-bit - // alignment. All extra bytes are zero-padded. - // - // These 16-bit values are also accumulated into an intermediate per-row - // accumulator. CountK cannot be greater than 128 to avoid overflowing - // these signed 16-bit accumulators. - // - - while (k >= 8) { - - __m128i Bytes = _mm_loadl_epi64((const __m128i*) & a[0]); - __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); - - ReductionVector = _mm_add_epi16(ReductionVector, Words); - - _mm_storeu_si128((__m128i*) & D[0], Words); - - a += 8; - D += 8; - k -= 8; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - uint8_t* padded = PaddedMatrixAData; - uint8_t* padded_end = padded + k; - - do { - padded[0] = a[0]; - padded++; - a++; - } while (padded < padded_end); - - __m128i Bytes = _mm_loadl_epi64((__m128i*)PaddedMatrixAData); - __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); - - ReductionVector = _mm_add_epi16(ReductionVector, Words); - - // - // Copy pairs of 16-bit values from the vector to the packed - // buffer and rotate the vector for the next iteration. - // - - for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { - *((int32_t*)D) = _mm_cvtsi128_si32(Words); - D += 2; - Words = _mm_shuffle_epi32(Words, _MM_SHUFFLE(0, 3, 2, 1)); - } - } - - // - // Reduce the partial accumulators. 
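The reduction that follows collapses the eight 16-bit accumulator lanes with a multiply-add against a vector of ones and two shuffle/add steps. A small standalone sketch of that idiom, assuming an x86 toolchain with SSE2 (illustrative only; the real code accumulates many iterations before reducing):

#include <emmintrin.h>
#include <cstdint>
#include <cstdio>

int main()
{
    const uint8_t a[8] = {1, 2, 3, 4, 5, 6, 7, 8};

    __m128i zero = _mm_setzero_si128();
    __m128i bytes = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(a));
    __m128i words = _mm_unpacklo_epi8(bytes, zero);       // zero-extend u8 -> i16

    __m128i acc = words;                                  // per-lane 16-bit accumulator

    // Reduce eight 16-bit lanes to one 32-bit sum: madd against ones gives four
    // 32-bit pair sums, then two shuffles/adds fold those into lane 0.
    __m128i pairs = _mm_madd_epi16(acc, _mm_set1_epi16(1));
    pairs = _mm_add_epi32(pairs, _mm_shuffle_epi32(pairs, _MM_SHUFFLE(3, 2, 3, 2)));
    pairs = _mm_add_epi32(pairs, _mm_shuffle_epi32(pairs, _MM_SHUFFLE(0, 1, 0, 1)));

    printf("%d\n", _mm_cvtsi128_si32(pairs));             // prints 36
    return 0;
}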
- // - - ReductionVector = _mm_madd_epi16(ReductionVector, OnesWordBroadcast); - ReductionVector = _mm_add_epi32(ReductionVector, - _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(3, 2, 3, 2))); - ReductionVector = _mm_add_epi32(ReductionVector, - _mm_shuffle_epi32(ReductionVector, _MM_SHUFFLE(0, 1, 0, 1))); - - *RowSumBuffer++ = _mm_cvtsi128_si32(ReductionVector); - - A += lda; - CountM -= 1; - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessSse( - MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, - __m128i BytesRow0, - __m128i BytesRow1, - __m128i BitFlipVector, - __m128i ColumnSums[2] -) -{ - __m128i BytesInterleaved = _mm_unpacklo_epi8(BytesRow0, BytesRow1); - - BytesInterleaved = _mm_xor_si128(BytesInterleaved, BitFlipVector); - - __m128i WordsInterleaved0 = _mm_srai_epi16(_mm_unpacklo_epi8(BytesInterleaved, BytesInterleaved), 8); - __m128i WordsInterleaved1 = _mm_srai_epi16(_mm_unpackhi_epi8(BytesInterleaved, BytesInterleaved), 8); - - ColumnSums[0] = _mm_add_epi16(ColumnSums[0], WordsInterleaved0); - ColumnSums[1] = _mm_add_epi16(ColumnSums[1], WordsInterleaved1); - - _mm_storeu_si128((__m128i*) & D[0], WordsInterleaved0); - _mm_storeu_si128((__m128i*) & D[8], WordsInterleaved1); -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - const __m128i OnesWordBroadcast = _mm_set1_epi16(1); - const __m128i BitFlipVector = _mm_set1_epi32(BIsSigned ? 0 : 0x80808080); - - // - // Process 8 columns of matrix B in a loop. - // - - while (CountN >= 8) { - - const uint8_t* b = B; - size_t k = CountK; - __m128i ColumnSums[2]; - - ColumnSums[0] = _mm_setzero_si128(); - ColumnSums[1] = _mm_setzero_si128(); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - // These values are also zero-extended and accumulated into an - // intermediate per-column accumulator. CountK cannot be greater than - // 128 to avoid overflowing these signed 16-bit accumulators. - // - - while (k >= MLAS_GEMM_U8X8_KERNEL_SSE::PackedK) { - - __m128i BytesRow0 = _mm_loadl_epi64((const __m128i*) & b[0]); - __m128i BytesRow1 = _mm_loadl_epi64((const __m128i*) & b[ldb]); - - MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); - - b += ldb * 2; - D += 16; - k -= 2; - } - - if (k > 0) { - - __m128i BytesRow0 = _mm_loadl_epi64((const __m128i*) & b[0]); - - MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); - - D += 16; - } - - ColumnSums[0] = _mm_madd_epi16(ColumnSums[0], OnesWordBroadcast); - ColumnSums[1] = _mm_madd_epi16(ColumnSums[1], OnesWordBroadcast); - - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[0], ColumnSums[0]); - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[4], ColumnSums[1]); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. - // - - if (CountN > 0) { - - const uint8_t* b = B; - size_t k = CountK; - __m128i ColumnSums[2]; - uint8_t PaddedMatrixBData[16]; - - _mm_storeu_si128((__m128i*)PaddedMatrixBData, BitFlipVector); - - ColumnSums[0] = _mm_setzero_si128(); - ColumnSums[1] = _mm_setzero_si128(); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. 
- // - - while (k >= MLAS_GEMM_U8X8_KERNEL_SSE::PackedK) { - - const uint8_t* bcopy = b; - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = bcopy[0]; - padded[8] = bcopy[ldb]; - padded++; - bcopy++; - } while (padded < padded_end); - - __m128i BytesRow0 = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[0]); - __m128i BytesRow1 = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[8]); - - MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); - - b += ldb * 2; - D += 16; - k -= 2; - } - - if (k > 0) { - - const uint8_t* bcopy = b; - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = bcopy[0]; - padded++; - bcopy++; - } while (padded < padded_end); - - __m128i BytesRow0 = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[0]); - - MlasGemmU8X8CopyPackBProcessSse(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); - } - - ColumnSums[0] = _mm_madd_epi16(ColumnSums[0], OnesWordBroadcast); - ColumnSums[1] = _mm_madd_epi16(ColumnSums[1], OnesWordBroadcast); - - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[0], ColumnSums[0]); - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[4], ColumnSums[1]); - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowSse( - __m128i ABroadcast, - const int16_t* B, - __m128i Accumulators[2] -) -{ - __m128i BElements0 = _mm_load_si128((__m128i*) & B[0]); - __m128i BElements1 = _mm_load_si128((__m128i*) & B[8]); - - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_madd_epi16(BElements0, ABroadcast)); - Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_madd_epi16(BElements1, ABroadcast)); -} - -template<> -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8X8_KERNEL_SSE::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_SSE::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - __m128i Accumulators[2]; - - // - // Initialize the accumulators with the row and column sums. - // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { - - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } - - ZeroPointB += 8; - - Accumulators[0] = _mm_loadu_si128((__m128i*) & ScaledRowSumBuffer[0]); - Accumulators[1] = _mm_loadu_si128((__m128i*) & ScaledRowSumBuffer[4]); - - } - else { - - Accumulators[0] = _mm_set1_epi32(RowSumValue); - Accumulators[1] = Accumulators[0]; - } - - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*) & ColumnSumBuffer[0])); - Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*) & ColumnSumBuffer[4])); - ColumnSumBuffer += 8; - - // - // Broadcast each pair of 16-bit values from the matrix A and multiply - // with the pair of 16-bit values from matrix B, and add the 32-bit - // intermediate into the accumulator registers. 
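The loop below leans on _mm_madd_epi16 performing a pair-wise dot product: each 32-bit result lane is a0*b0 + a1*b1 for one column's pair of K values. A tiny sketch of that behaviour with arbitrary values (illustrative only):

#include <emmintrin.h>
#include <cstdio>

int main()
{
    // Broadcast one pair of 16-bit A values (2, 3) across the register, the
    // way the kernel broadcasts a packed K-pair of matrix A.
    __m128i a = _mm_setr_epi16(2, 3, 2, 3, 2, 3, 2, 3);

    // Four columns of matrix B, each contributing a pair of 16-bit values.
    __m128i b = _mm_setr_epi16(1, 2, 3, 4, 5, 6, 7, 8);

    // Each 32-bit lane becomes a0*b0 + a1*b1 for one column.
    __m128i dot = _mm_madd_epi16(a, b);

    printf("%d\n", _mm_cvtsi128_si32(dot));   // prints 2*1 + 3*2 = 8
    return 0;
}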
- // - - const int16_t* a = A; - size_t k = PackedCountK; - - while (k >= 4) { - - __m128i AElements = _mm_loadu_si128((__m128i*)a); - __m128i ABroadcast; - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(0, 0, 0, 0)); - MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[0], Accumulators); - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(1, 1, 1, 1)); - MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[16], Accumulators); - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(2, 2, 2, 2)); - MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[32], Accumulators); - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(3, 3, 3, 3)); - MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[48], Accumulators); - - a += 4 * 2; - B += 4 * 16; - k -= 4; - } - - while (k > 0) { - - __m128i ABroadcast = _mm_set1_epi32(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowSse(ABroadcast, &B[0], Accumulators); - - a += 2; - B += 16; - k -= 1; - } - - // - // Output the accumulator block after optionally accumulating the values - // from matrix C. - // - - if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*) & C[0])); - Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*) & C[4])); - } - - _mm_storeu_si128((__m128i*) & C[0], Accumulators[0]); - _mm_storeu_si128((__m128i*) & C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - - if (!ZeroMode) { - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*) & C[0])); - } - - _mm_storeu_si128((__m128i*) & C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; - } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*) & C[0])); - } - - _mm_storel_epi64((__m128i*) & C[0], Accumulators[0]); - C += 2; - - Accumulators[0] = _mm_shuffle_epi32(Accumulators[0], _MM_SHUFFLE(3, 2, 3, 2)); - } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulators[0]); - - if (!ZeroMode) { - AccumulatorValue += C[0]; - } - - C[0] = AccumulatorValue; - } - - CountN = 0; - } - } - - return 1; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchSse = { - MlasGemmQuantOperation, - nullptr, - nullptr, - MLAS_GEMM_U8X8_KERNEL_SSE::PackedK, - 0, - 1 // assembly kernel M stride -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp deleted file mode 100644 index 68931c53eed79..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_sse41.cpp +++ /dev/null @@ -1,449 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_sse41.cpp - -Abstract: - - This module implements QGEMM kernels for sse41. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// N.B. MSVC does not require turning on SSE 4.1 intrinsics and the current use -// for this code is Windows only, so restrict this kernel to that environment. 
- -struct MLAS_GEMM_U8S8_KERNEL_SSE41 -{ - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 4; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 128, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 24, 128, 128 }; -}; - -constexpr size_t MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8S8_KERNEL_SSE41::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8S8_KERNEL_SSE41::PackedStrides; - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8S8_KERNEL_SSE41::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - const __m128i ZeroVector = _mm_setzero_si128(); - const __m128i OnesWordBroadcast = _mm_set1_epi16(1); - - // - // Process a single row of matrix A in a loop. - // - - while (CountM > 0) { - - const uint8_t* a = A; - size_t k = CountK; - __m128i ReductionVector = ZeroVector; - - // - // Copy the source bytes to the packed buffer. - // - // The packed buffer has the same data ordering as the source bytes, - // but CountK is aligned up to a multiple of 4 to maintain 32-bit - // alignment. All extra bytes are zero-padded. - // - - while (k >= 8) { - - __m128i Bytes = _mm_loadl_epi64((const __m128i*) & a[0]); - - __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); - ReductionVector = _mm_add_epi32(ReductionVector, _mm_madd_epi16(Words, OnesWordBroadcast)); - - _mm_storel_epi64((__m128i*) & D[0], Bytes); - - a += 8; - D += 8; - k -= 8; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - _mm_storel_epi64((__m128i*) & D[0], ZeroVector); - - std::copy_n(&a[0], k, &D[0]); - - __m128i Bytes = _mm_loadl_epi64((__m128i*) & D[0]); - D += (k + 3) & ~3; - - __m128i Words = _mm_unpacklo_epi8(Bytes, ZeroVector); - ReductionVector = _mm_add_epi32(ReductionVector, _mm_madd_epi16(Words, OnesWordBroadcast)); - } - - // - // Reduce the partial accumulators. 
- // - - ReductionVector = _mm_hadd_epi32(ReductionVector, ReductionVector); - ReductionVector = _mm_hadd_epi32(ReductionVector, ReductionVector); - - *RowSumBuffer++ = _mm_cvtsi128_si32(ReductionVector); - - A += lda; - CountM -= 1; - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessSse41( - MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, - __m128i BytesRows[4], - __m128i OnesByteBroadcast, - __m128i OnesWordBroadcast, - __m128i ColumnSums[2] -) -{ - __m128i PairsInterleaved0 = _mm_unpacklo_epi8(BytesRows[0], BytesRows[1]); - __m128i PairsInterleaved1 = _mm_unpacklo_epi8(BytesRows[2], BytesRows[3]); - - __m128i QuadsInterleaved0 = _mm_unpacklo_epi16(PairsInterleaved0, PairsInterleaved1); - __m128i QuadsInterleaved1 = _mm_unpackhi_epi16(PairsInterleaved0, PairsInterleaved1); - - __m128i PairwiseAdd0 = _mm_maddubs_epi16(OnesByteBroadcast, QuadsInterleaved0); - __m128i PairwiseAdd1 = _mm_maddubs_epi16(OnesByteBroadcast, QuadsInterleaved1); - - PairwiseAdd0 = _mm_madd_epi16(PairwiseAdd0, OnesWordBroadcast); - PairwiseAdd1 = _mm_madd_epi16(PairwiseAdd1, OnesWordBroadcast); - - ColumnSums[0] = _mm_add_epi32(ColumnSums[0], PairwiseAdd0); - ColumnSums[1] = _mm_add_epi32(ColumnSums[1], PairwiseAdd1); - - _mm_storeu_si128((__m128i*) & D[0], QuadsInterleaved0); - _mm_storeu_si128((__m128i*) & D[16], QuadsInterleaved1); -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - const __m128i OnesByteBroadcast = _mm_set1_epi8(1); - const __m128i OnesWordBroadcast = _mm_set1_epi16(1); - __m128i BytesRows[4]; - - MLAS_UNREFERENCED_PARAMETER(BIsSigned); - - // - // Process 8 columns of matrix B in a loop. - // - - while (CountN >= 8) { - - const uint8_t* b = B; - size_t k = CountK; - __m128i ColumnSums[2]; - - ColumnSums[0] = _mm_setzero_si128(); - ColumnSums[1] = _mm_setzero_si128(); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - - while (k >= MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK) { - - BytesRows[0] = _mm_loadl_epi64((const __m128i*) & b[ldb * 0]); - BytesRows[1] = _mm_loadl_epi64((const __m128i*) & b[ldb * 1]); - BytesRows[2] = _mm_loadl_epi64((const __m128i*) & b[ldb * 2]); - BytesRows[3] = _mm_loadl_epi64((const __m128i*) & b[ldb * 3]); - - MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, ColumnSums); - - b += ldb * 4; - D += 32; - k -= 4; - } - - if (k > 0) { - - BytesRows[0] = _mm_loadl_epi64((const __m128i*) & b[ldb * 0]); - BytesRows[1] = _mm_setzero_si128(); - BytesRows[2] = _mm_setzero_si128(); - BytesRows[3] = _mm_setzero_si128(); - - if (k >= 2) { - BytesRows[1] = _mm_loadl_epi64((const __m128i*) & b[ldb * 1]); - } - - if (k >= 3) { - BytesRows[2] = _mm_loadl_epi64((const __m128i*) & b[ldb * 2]); - } - - MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, ColumnSums); - - D += 32; - } - - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[0], ColumnSums[0]); - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[4], ColumnSums[1]); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. 
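The unpack/interleave sequence above produces the PackedK = 4 layout that the SSE4.1 compute kernel consumes: within each 32-byte group, the four K values of column 0 are stored contiguously, then the four K values of column 1, and so on across eight columns, with zero padding for partial K or N. A scalar model of that layout, with a hypothetical helper name, purely for illustration:

    #include <cstddef>
    #include <cstdint>

    // Pack a CountK x CountN tile of B (row-major, leading dimension ldb)
    // into the quad-per-column layout described above. Illustrative only.
    void PackB_K4_N8_Ref(uint8_t* D, const uint8_t* B, size_t ldb,
                         size_t CountN /* <= 8 */, size_t CountK)
    {
        for (size_t k = 0; k < CountK; k += 4) {
            for (size_t n = 0; n < 8; n++) {
                for (size_t kk = 0; kk < 4; kk++) {
                    uint8_t value = 0;  // zero padding for partial K or N
                    if (n < CountN && (k + kk) < CountK) {
                        value = B[(k + kk) * ldb + n];
                    }
                    *D++ = value;
                }
            }
        }
    }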
- // - - if (CountN > 0) { - - const __m128i ZeroVector = _mm_setzero_si128(); - - __m128i ColumnSums[2]; - uint8_t PaddedMatrixBData[32]; - - ColumnSums[0] = _mm_setzero_si128(); - ColumnSums[1] = _mm_setzero_si128(); - - while (CountK > 0) { - - size_t k = std::min(CountK, MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK); - CountK -= k; - - _mm_storeu_si128((__m128i*) & PaddedMatrixBData[0], ZeroVector); - _mm_storeu_si128((__m128i*) & PaddedMatrixBData[16], ZeroVector); - - uint8_t* padded = PaddedMatrixBData; - - do { - - std::copy_n(B, CountN, padded); - - padded += 8; - B += ldb; - k -= 1; - - } while (k > 0); - - BytesRows[0] = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[0]); - BytesRows[1] = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[8]); - BytesRows[2] = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[16]); - BytesRows[3] = _mm_loadl_epi64((__m128i*) & PaddedMatrixBData[24]); - - MlasGemmU8X8CopyPackBProcessSse41(D, BytesRows, OnesByteBroadcast, OnesWordBroadcast, ColumnSums); - - D += 32; - } - - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[0], ColumnSums[0]); - _mm_storeu_si128((__m128i*) & ColumnSumBuffer[4], ColumnSums[1]); - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowSse41( - __m128i ABroadcast, - const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* B, - __m128i OnesWordBroadcast, - __m128i Accumulators[2] -) -{ - __m128i BElements0 = _mm_load_si128((__m128i*) & B[0]); - __m128i BElements1 = _mm_load_si128((__m128i*) & B[16]); - - __m128i Intermediate0 = _mm_maddubs_epi16(ABroadcast, BElements0); - __m128i Intermediate1 = _mm_maddubs_epi16(ABroadcast, BElements1); - - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_madd_epi16(Intermediate0, OnesWordBroadcast)); - Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_madd_epi16(Intermediate1, OnesWordBroadcast)); -} - -template<> -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedAType* A, - const MLAS_GEMM_U8S8_KERNEL_SSE41::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - const __m128i OnesWordBroadcast = _mm_set1_epi16(1); - - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - __m128i Accumulators[2]; - - // - // Initialize the accumulators with the row and column sums. - // - - Accumulators[0] = _mm_set1_epi32(RowSumBuffer[0]); - Accumulators[1] = Accumulators[0]; - - if (ZeroPointB != nullptr) { - Accumulators[0] = _mm_mullo_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*) & ZeroPointB[0])); - Accumulators[1] = _mm_mullo_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*) & ZeroPointB[4])); - ZeroPointB += 8; - } - - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((const __m128i*) & ColumnSumBuffer[0])); - Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((const __m128i*) & ColumnSumBuffer[4])); - ColumnSumBuffer += 8; - - // - // Broadcast each quad of 8-bit values from the matrix A and multiply - // with the quad of 8-bit values from matrix B, and add the 32-bit - // intermediate into the accumulator registers. 
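The multiply-accumulate helper above is built on the PMADDUBSW/PMADDWD pair: unsigned bytes from A are multiplied with signed bytes from B, adjacent products are summed into saturating 16-bit intermediates, and a multiply by a vector of ones widens and sums those pairs into 32-bit lanes. A scalar model of one four-byte dot product, including the saturation step the real instruction performs; the name is made up for illustration:

    #include <algorithm>
    #include <cstdint>

    // What one 32-bit accumulator lane receives from the pmaddubsw + pmaddwd
    // sequence: a dot product of four unsigned A bytes with four signed B
    // bytes, with each intermediate pair sum clamped to the int16_t range.
    int32_t DotQuadU8S8(const uint8_t a[4], const int8_t b[4])
    {
        int32_t pair0 = int32_t(a[0]) * b[0] + int32_t(a[1]) * b[1];
        int32_t pair1 = int32_t(a[2]) * b[2] + int32_t(a[3]) * b[3];

        pair0 = std::clamp(pair0, -32768, 32767);  // PMADDUBSW saturates
        pair1 = std::clamp(pair1, -32768, 32767);

        return pair0 + pair1;                      // PMADDWD with ones widens
    }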
- // - - const uint8_t* a = A; - size_t k = PackedCountK; - - while (k >= 4) { - - __m128i AElements = _mm_loadu_si128((__m128i*)a); - __m128i ABroadcast; - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(0, 0, 0, 0)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[0], OnesWordBroadcast, Accumulators); - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(1, 1, 1, 1)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[32], OnesWordBroadcast, Accumulators); - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(2, 2, 2, 2)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[64], OnesWordBroadcast, Accumulators); - - ABroadcast = _mm_shuffle_epi32(AElements, _MM_SHUFFLE(3, 3, 3, 3)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[96], OnesWordBroadcast, Accumulators); - - a += 4 * 4; - B += 4 * 32; - k -= 4; - } - - while (k > 0) { - - __m128i ABroadcast = _mm_set1_epi32(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowSse41(ABroadcast, &B[0], OnesWordBroadcast, Accumulators); - - a += 4; - B += 32; - k -= 1; - } - - // - // Output the accumulator block after optionally accumulating the values - // from matrix C. - // - - if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*) & C[0])); - Accumulators[1] = _mm_add_epi32(Accumulators[1], _mm_loadu_si128((__m128i*) & C[4])); - } - - _mm_storeu_si128((__m128i*) & C[0], Accumulators[0]); - _mm_storeu_si128((__m128i*) & C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - - if (!ZeroMode) { - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadu_si128((__m128i*) & C[0])); - } - - _mm_storeu_si128((__m128i*) & C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; - } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = _mm_add_epi32(Accumulators[0], _mm_loadl_epi64((__m128i*) & C[0])); - } - - _mm_storel_epi64((__m128i*) & C[0], Accumulators[0]); - C += 2; - - Accumulators[0] = _mm_shuffle_epi32(Accumulators[0], _MM_SHUFFLE(3, 2, 3, 2)); - } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = _mm_cvtsi128_si32(Accumulators[0]); - - if (!ZeroMode) { - AccumulatorValue += C[0]; - } - - C[0] = AccumulatorValue; - } - - CountN = 0; - } - } - - return 1; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchSse41 = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8S8_KERNEL_SSE41::PackedK, - MLAS_GEMM_U8S8_KERNEL_SSE41::PackedStrides.K, - 1 // assembly kernel M stride -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp deleted file mode 100644 index 5cec72542d0d9..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_udot.cpp +++ /dev/null @@ -1,763 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_udot.cpp - -Abstract: - - This module implements udot QGEMM kernel. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Define the prototypes of the NEON UDOT routines written in assembly. 
-// - -extern "C" { - - size_t - MLASCALL - MlasGemmU8X8KernelUdot( - const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB, - bool ZeroMode - ); -} - -struct MLAS_GEMM_U8X8_KERNEL_UDOT -{ - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 8; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 128, 256 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 24, 128, 384 }; -}; - -constexpr size_t MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UDOT::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides; - -template<> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (BIsSigned) { - ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UDOT::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8X8_KERNEL_UDOT::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - uint8_t PaddedMatrixAData[16]; - - // - // Process 8 rows of matrix A. - // - // DOT kernels load 8x4 block of A with two vector registers. So A is packed - // a series of 16 byte vectors where four rows are interleaved with the - // following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ] - // - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // [ E4 E5 E6 E7 F4 F5 F6 F7 G4 G5 G6 G7 H4 H5 H6 H7 ] - // - // ... - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
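Both the zero-point fixup above and the BitFlipVector used by the packing code in this file rely on the same trick: XOR-ing the sign bit converts an int8 value into the corresponding uint8 value offset by 128, and offsetting the zero point by the same 128 leaves (value - zero_point) unchanged, which is what lets one unsigned kernel serve both signed and unsigned B. A small self-contained check of that equivalence, illustrative rather than part of the removed sources:

    #include <cassert>
    #include <cstdint>

    // Verify that flipping the sign bit of both the value and the zero point
    // preserves their difference for every int8 combination.
    void CheckBitFlipEquivalence()
    {
        for (int v = -128; v <= 127; v++) {
            for (int zp = -128; zp <= 127; zp++) {
                uint8_t flipped_v  = uint8_t(uint8_t(v) ^ 0x80);   // v + 128
                uint8_t flipped_zp = uint8_t(uint8_t(zp) ^ 0x80);  // zp + 128
                assert(int(flipped_v) - int(flipped_zp) == v - zp);
            }
        }
    }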
- // - - while (CountM >= 8) { - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - const uint8_t* a2 = a0 + lda * 2; - const uint8_t* a3 = a0 + lda * 3; - const uint8_t* a4 = a0 + lda * 4; - const uint8_t* a5 = a0 + lda * 5; - const uint8_t* a6 = a0 + lda * 6; - const uint8_t* a7 = a0 + lda * 7; - - size_t k = CountK; - uint32x4_t RowSums0 = vmovq_n_u32(0); - uint32x4_t RowSums1 = vmovq_n_u32(0); - - while (k >= 16) { - uint32x4_t v0 = vld1q_u32(reinterpret_cast(a0)); - a0 += 16; - uint32x4_t v1 = vld1q_u32(reinterpret_cast(a1)); - a1 += 16; - uint32x4_t v2 = vld1q_u32(reinterpret_cast(a2)); - a2 += 16; - uint32x4_t v3 = vld1q_u32(reinterpret_cast(a3)); - a3 += 16; - uint32x4_t v4 = vld1q_u32(reinterpret_cast(a4)); - a4 += 16; - uint32x4_t v5 = vld1q_u32(reinterpret_cast(a5)); - a5 += 16; - uint32x4_t v6 = vld1q_u32(reinterpret_cast(a6)); - a6 += 16; - uint32x4_t v7 = vld1q_u32(reinterpret_cast(a7)); - a7 += 16; - - uint32x4_t z0 = vzip1q_u32(v0, v2); - uint32x4_t z1 = vzip2q_u32(v0, v2); - uint32x4_t z2 = vzip1q_u32(v1, v3); - uint32x4_t z3 = vzip2q_u32(v1, v3); - - uint32x4_t z4 = vzip1q_u32(v4, v6); - uint32x4_t z5 = vzip2q_u32(v4, v6); - uint32x4_t z6 = vzip1q_u32(v5, v7); - uint32x4_t z7 = vzip2q_u32(v5, v7); - - v0 = vzip1q_u32(z0, z2); - v1 = vzip2q_u32(z0, z2); - v2 = vzip1q_u32(z1, z3); - v3 = vzip2q_u32(z1, z3); - - v4 = vzip1q_u32(z4, z6); - v5 = vzip2q_u32(z4, z6); - v6 = vzip1q_u32(z5, z7); - v7 = vzip2q_u32(z5, z7); - - vst1q_u8(&D[0], vreinterpretq_u8_u32(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_u32(v4)); - vst1q_u8(&D[32], vreinterpretq_u8_u32(v1)); - vst1q_u8(&D[48], vreinterpretq_u8_u32(v5)); - vst1q_u8(&D[64], vreinterpretq_u8_u32(v2)); - vst1q_u8(&D[80], vreinterpretq_u8_u32(v6)); - vst1q_u8(&D[96], vreinterpretq_u8_u32(v3)); - vst1q_u8(&D[112], vreinterpretq_u8_u32(v7)); - - RowSums0 = vpadalq_u16(RowSums0, vpaddlq_u8(vreinterpretq_u8_u32(v0))); - RowSums0 = vpadalq_u16(RowSums0, vpaddlq_u8(vreinterpretq_u8_u32(v1))); - RowSums0 = vpadalq_u16(RowSums0, vpaddlq_u8(vreinterpretq_u8_u32(v2))); - RowSums0 = vpadalq_u16(RowSums0, vpaddlq_u8(vreinterpretq_u8_u32(v3))); - - RowSums1 = vpadalq_u16(RowSums1, vpaddlq_u8(vreinterpretq_u8_u32(v4))); - RowSums1 = vpadalq_u16(RowSums1, vpaddlq_u8(vreinterpretq_u8_u32(v5))); - RowSums1 = vpadalq_u16(RowSums1, vpaddlq_u8(vreinterpretq_u8_u32(v6))); - RowSums1 = vpadalq_u16(RowSums1, vpaddlq_u8(vreinterpretq_u8_u32(v7))); - - D += 128; - k -= 16; - } - - while (k >= 4) { - uint32_t v0 = *reinterpret_cast(a0); - a0 += 4; - uint32_t v1 = *reinterpret_cast(a1); - a1 += 4; - uint32_t v2 = *reinterpret_cast(a2); - a2 += 4; - uint32_t v3 = *reinterpret_cast(a3); - a3 += 4; - uint32_t v4 = *reinterpret_cast(a4); - a4 += 4; - uint32_t v5 = *reinterpret_cast(a5); - a5 += 4; - uint32_t v6 = *reinterpret_cast(a6); - a6 += 4; - uint32_t v7 = *reinterpret_cast(a7); - a7 += 4; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - *reinterpret_cast(&D[8]) = v2; - *reinterpret_cast(&D[12]) = v3; - *reinterpret_cast(&D[16]) = v4; - *reinterpret_cast(&D[20]) = v5; - *reinterpret_cast(&D[24]) = v6; - *reinterpret_cast(&D[28]) = v7; - - RowSums0 = vpadalq_u16(RowSums0, vpaddlq_u8(vld1q_u8(D))); - RowSums1 = vpadalq_u16(RowSums1, vpaddlq_u8(vld1q_u8(&D[16]))); - - D += 32; - k -= 4; - } - - if (k > 0) { - // - // Copy the remaining bytes to the zero padded stack buffer. 
- // - uint8_t* d = D; - - vst1q_u8(d, vmovq_n_u8(0)); - vst1q_u8(&d[16], vmovq_n_u8(0)); - - while (k > 0) { - d[0] = *a0++; - d[4] = *a1++; - d[8] = *a2++; - d[12] = *a3++; - d[16] = *a4++; - d[20] = *a5++; - d[24] = *a6++; - d[28] = *a7++; - d += 1; - k -= 1; - } - - RowSums0 = vpadalq_u16(RowSums0, vpaddlq_u8(vld1q_u8(D))); - RowSums1 = vpadalq_u16(RowSums1, vpaddlq_u8(vld1q_u8(&D[16]))); - - D += 32; - } - - if (((CountK - 1) & 7) < 4) { - vst1q_u8(D, vmovq_n_u8(0)); - vst1q_u8(&D[16], vmovq_n_u8(0)); - D += 32; - } - - vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); - vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); - - RowSumBuffer += 8; - - A = A + lda * 8; - CountM -= 8; - } - - // - // Process four rows of matrix A. - // - // The buffer is packed as a series of 16 byte vectors where four rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. - // - - if (CountM >= 4) { - - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - const uint8_t* a2 = a1 + lda; - const uint8_t* a3 = a2 + lda; - - size_t k = CountK; - uint32x4_t RowSums = vmovq_n_u32(0); - - while (k >= 16) { - - uint32x4_t v0 = vld1q_u32(reinterpret_cast(a0)); - a0 += 16; - uint32x4_t v1 = vld1q_u32(reinterpret_cast(a1)); - a1 += 16; - uint32x4_t v2 = vld1q_u32(reinterpret_cast(a2)); - a2 += 16; - uint32x4_t v3 = vld1q_u32(reinterpret_cast(a3)); - a3 += 16; - - uint32x4_t z0 = vzip1q_u32(v0, v2); - uint32x4_t z1 = vzip2q_u32(v0, v2); - uint32x4_t z2 = vzip1q_u32(v1, v3); - uint32x4_t z3 = vzip2q_u32(v1, v3); - - v0 = vzip1q_u32(z0, z2); - v1 = vzip2q_u32(z0, z2); - v2 = vzip1q_u32(z1, z3); - v3 = vzip2q_u32(z1, z3); - - vst1q_u8(&D[0], vreinterpretq_u8_u32(v0)); - vst1q_u8(&D[16], vreinterpretq_u8_u32(v1)); - vst1q_u8(&D[32], vreinterpretq_u8_u32(v2)); - vst1q_u8(&D[48], vreinterpretq_u8_u32(v3)); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v0))); - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v1))); - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v2))); - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vreinterpretq_u8_u32(v3))); - - D += 64; - k -= 16; - } - - while (k >= 4) { - - uint32_t v0 = *reinterpret_cast(a0); - a0 += 4; - uint32_t v1 = *reinterpret_cast(a1); - a1 += 4; - uint32_t v2 = *reinterpret_cast(a2); - a2 += 4; - uint32_t v3 = *reinterpret_cast(a3); - a3 += 4; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - *reinterpret_cast(&D[8]) = v2; - *reinterpret_cast(&D[12]) = v3; - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(vld1q_u8(D))); - - D += 16; - k -= 4; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. 
- // - - uint8_t* d = PaddedMatrixAData; - - vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); - - while (k > 0) { - - d[0] = *a0++; - d[4] = *a1++; - d[8] = *a2++; - d[12] = *a3++; - - d += 1; - k -= 1; - } - - uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); - vst1q_u8(D, PackedVector); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(PackedVector)); - - D += 16; - } - - if (((CountK - 1) & 7) < 4) { - - vst1q_u8(D, vmovq_n_u8(0)); - - D += 16; - } - - vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); - RowSumBuffer += 4; - - A = A + lda * 4; - CountM -= 4; - } - - // - // Process two rows of matrix A. - // - // The buffer is packed as a series of 8 byte vectors where two rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - if (CountM >= 2) { - - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - - size_t k = CountK; - uint32x2_t RowSums = vmov_n_u32(0); - - while (k >= 4) { - - uint32_t v0 = *reinterpret_cast(a0); - a0 += 4; - uint32_t v1 = *reinterpret_cast(a1); - a1 += 4; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[4]) = v1; - - RowSums = vpadal_u16(RowSums, vpaddl_u8(vld1_u8(D))); - - D += 8; - k -= 4; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - uint8_t* d = PaddedMatrixAData; - - vst1_u8(PaddedMatrixAData, vmov_n_u8(0)); - - while (k > 0) { - - d[0] = *a0++; - d[4] = *a1++; - - d += 1; - k -= 1; - } - - uint8x8_t PackedVector = vld1_u8(PaddedMatrixAData); - vst1_u8(D, PackedVector); - - RowSums = vpadal_u16(RowSums, vpaddl_u8(PackedVector)); - - D += 8; - } - - if (((CountK - 1) & 7) < 4) { - - vst1_u8(D, vmov_n_u8(0)); - - D += 8; - } - - vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); - RowSumBuffer += 2; - - A = A + lda * 2; - CountM -= 2; - } - - // - // Process one row of matrix A. - // - // The buffer is packed as a series of 4 byte with the following pattern: - // - // [ A0 A1 A2 A3 ] - // [ A4 A5 A6 A7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of four, then the vector is padded - // with zeroes. - // - - if (CountM > 0) { - - const uint8_t* a = A; - size_t k = CountK; - uint32x4_t RowSums = vmovq_n_u32(0); - - while (k >= 16) { - - uint8x16_t v = vld1q_u8(a); - a += 16; - - vst1q_u8(D, v); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); - - D += 16; - k -= 16; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); - - for (size_t kk = 0; kk < k; kk++) { - PaddedMatrixAData[kk] = a[kk]; - } - - uint8x16_t v = vld1q_u8(PaddedMatrixAData); - vst1q_u8(D, v); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); - } - -#if defined(_M_ARM64) - // N.B. The workaround of defining a local vaddvq_u32 doesn't work here - // as VS2019 added new intrinsics to make the operation work. Also, not - // all build environments using VS2019 have the up-to-date arm64_neon.h, - // so fallback to pairwise addition. 
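The pairwise-add fallback mentioned above (and implemented just below) folds the four 32-bit row-sum lanes with two pairwise adds and then stores lane 0. A scalar equivalent of that horizontal reduction, for reference only:

    #include <cstdint>

    // {a, b, c, d} -> {a+b, c+d, a+b, c+d} -> {a+b+c+d, ...}; lane 0 holds
    // the total, which is what the removed code stores into RowSumBuffer.
    uint32_t HorizontalSum4(const uint32_t v[4])
    {
        uint32_t pair0 = v[0] + v[1];
        uint32_t pair1 = v[2] + v[3];
        return pair0 + pair1;
    }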
- RowSums = vpaddq_u32(RowSums, RowSums); - RowSums = vpaddq_u32(RowSums, RowSums); - vst1q_lane_u32(reinterpret_cast(RowSumBuffer), RowSums, 0); -#else - *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); -#endif - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessUdot( - MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, - uint8x8_t BytesRow[4], - uint8x16_t BitFlipVector, - uint32x4_t ColumnSums[2] - ) -{ - uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); - uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); - - uint8x16x2_t zw = vzipq_u8(v02, v13); - uint16x8x2_t zd = vzipq_u16(vreinterpretq_u16_u8(zw.val[0]), vreinterpretq_u16_u8(zw.val[1])); - - vst1q_u8(&D[0], vreinterpretq_u8_u16(zd.val[0])); - vst1q_u8(&D[16], vreinterpretq_u8_u16(zd.val[1])); - - ColumnSums[0] = vpadalq_u16(ColumnSums[0], vpaddlq_u8(vreinterpretq_u8_u16(zd.val[0]))); - ColumnSums[1] = vpadalq_u16(ColumnSums[1], vpaddlq_u8(vreinterpretq_u8_u16(zd.val[1]))); -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - const uint8x16_t ZeroVector = vmovq_n_u8(0); - const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 0x80 : 0); - uint8x8_t BytesRow[4]; - - // - // Process 8 columns of matrix B in a loop. - // - // The buffer is packed as a series of 16 byte vectors where eight rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 B0 B1 B2 B3 C0 C1 C2 C3 D0 D1 D2 D3 ] - // [ E0 E1 E2 E3 F0 F1 F2 F3 G0 G1 G2 G3 H0 H1 H2 H3 ] - // [ A4 A5 A6 A7 B4 B5 B6 B7 C4 C5 C6 C7 D4 D5 D6 D7 ] - // [ E4 E5 E6 E7 F4 F5 F6 F7 G4 G5 G6 G7 H4 H5 H6 H7 ] - // - // Copy columns from matrix B to the packed buffer. Signed buffers are - // converted to unsigned buffers in order to share a common kernel. - // - // If CountK is not aligned to a multiple of eight, then the packed buffer - // is padded with zero vectors. - // - // If CountN is not aligned to a multiple of four, then the extra columns - // are padded with zeroes. - // - - while (CountN >= 8) { - - const uint8_t* b = B; - size_t k = CountK; - uint32x4_t ColumnSums[2]; - - ColumnSums[0] = vmovq_n_u32(0); - ColumnSums[1] = vmovq_n_u32(0); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - - while (k >= 4) { - - BytesRow[0] = vld1_u8(&b[ldb * 0]); - BytesRow[1] = vld1_u8(&b[ldb * 1]); - BytesRow[2] = vld1_u8(&b[ldb * 2]); - BytesRow[3] = vld1_u8(&b[ldb * 3]); - - MlasGemmU8X8CopyPackBProcessUdot(D, BytesRow, BitFlipVector, ColumnSums); - - b += ldb * 4; - D += 32; - k -= 4; - } - - if (k > 0) { - - BytesRow[0] = vld1_u8(&b[ldb * 0]); - BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); - BytesRow[2] = (k > 2) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); - BytesRow[3] = vget_low_u8(BitFlipVector); - - MlasGemmU8X8CopyPackBProcessUdot(D, BytesRow, BitFlipVector, ColumnSums); - - D += 32; - } - - // - // Zero pad the output buffer to a multiple of PackedK if the above - // processed an odd number of four row bundles. - // - - if (((CountK - 1) & 7) < 4) { - - vst1q_u8(&D[0], ZeroVector); - vst1q_u8(&D[16], ZeroVector); - - D += 32; - } - - vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); - vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. 
- // - - if (CountN > 0) { - - const uint8_t* b = B; - size_t k = CountK; - uint8_t PaddedMatrixBData[32]; - uint32x4_t ColumnSums[2]; - - vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); - - ColumnSums[0] = vmovq_n_u32(0); - ColumnSums[1] = vmovq_n_u32(0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. - // - - while (k > 0) { - - const uint8_t* bcopy0 = &b[ldb * 0]; - const uint8_t* bcopy1 = &b[ldb * 1]; - const uint8_t* bcopy2 = &b[ldb * 2]; - const uint8_t* bcopy3 = &b[ldb * 3]; - - if (k >= 4) { - - b += ldb * 4; - k -= 4; - - } else { - - vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); - - bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[24]; - bcopy2 = (k > 2) ? bcopy2 : &PaddedMatrixBData[24]; - bcopy3 = &PaddedMatrixBData[24]; - - k = 0; - } - - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = *bcopy0++; - padded[8] = *bcopy1++; - padded[16] = *bcopy2++; - padded[24] = *bcopy3++; - } while (++padded < padded_end); - - BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); - BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); - BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); - BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); - - MlasGemmU8X8CopyPackBProcessUdot(D, BytesRow, BitFlipVector, ColumnSums); - - D += 32; - } - - // - // Zero pad the output buffer to a multiple of PackedK if the above - // processed an odd number of four row bundles. - // - - if (((CountK - 1) & 7) < 4) { - - vst1q_u8(&D[0], ZeroVector); - vst1q_u8(&D[16], ZeroVector); - - D += 32; - } - - vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); - vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); - } -} - -template<> -MLAS_FORCEINLINE -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_UDOT::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - return MlasGemmU8X8KernelUdot(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8X8_KERNEL_UDOT::PackedK, - MLAS_GEMM_U8X8_KERNEL_UDOT::PackedStrides.K, - 8 -}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp deleted file mode 100644 index 3936154432ac7..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp +++ /dev/null @@ -1,967 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_ummla.cpp - -Abstract: - - This module implements ummla QGEMM kernel. - ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// -// Define the prototypes of the NEON UMMLA routines written in assembly. 
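The UDOT compute kernel above is a thin wrapper over MlasGemmU8X8KernelUdot, which lives in assembly and is not part of this diff. Its core operation is the UDOT instruction: each 32-bit accumulator lane receives a four-way dot product of unsigned bytes, which is why the packing code arranges A and B in groups of four consecutive K values per lane. A scalar model of one such step, under my reading of the instruction semantics; treat it as a sketch rather than a specification:

    #include <cstdint>

    // Scalar model of one vdotq_u32-style step: four unsigned bytes of A are
    // dotted with four unsigned bytes of B per 32-bit accumulator lane.
    void UdotStepRef(uint32_t acc[4], const uint8_t a[16], const uint8_t b[16])
    {
        for (int lane = 0; lane < 4; lane++) {
            for (int i = 0; i < 4; i++) {
                acc[lane] += uint32_t(a[lane * 4 + i]) * uint32_t(b[lane * 4 + i]);
            }
        }
    }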
-// - -extern "C" { - -size_t MLASCALL -MlasGemmU8X8KernelUmmlaZero(const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB); - -size_t MLASCALL -MlasGemmU8X8KernelUmmlaAdd(const uint8_t* A, - const uint8_t* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumVector, - const int32_t* ColumnSumVector, - const int32_t* ZeroPointB); -} - -struct MLAS_GEMM_U8X8_KERNEL_UMMLA { - typedef uint8_t PackedAType; - typedef uint8_t PackedBType; - typedef uint8_t OffsetAType; - typedef uint8_t OffsetBType; - - static constexpr size_t PackedK = 8; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; -}; - -constexpr size_t MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::Strides; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides; - -template <> -MLAS_FORCEINLINE int32_t -MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) -{ - if (BIsSigned) { - ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UMMLA::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template <> -void -MlasGemmQuantCopyPackA(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - uint8_t PaddedMatrixAData[64]; - - // - // Process 8 rows of matrix A. - // - // MMLA kernels load 8x8 block of A with four vector registers. So A is packed - // a series of 64 byte vectors where eight rows are interleaved with the - // following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // [ B0 B1 B2 B3 B4 B5 B6 B7 ] - // [ C0 C1 C2 C3 C4 C5 C6 C7 ] - // [ D0 D1 D2 D3 D4 D5 D6 D7 ] - // [ E0 E1 E2 E3 E4 E5 E6 E7 ] - // [ F0 F1 F2 F3 F4 F5 F6 F7 ] - // [ G0 G1 G2 G3 G4 G5 G6 G7 ] - // [ H0 H1 H2 H3 H4 H5 H6 H7 ] - // - // ... - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
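The UMMLA kernels referenced above are likewise assembly routines outside this diff. As I understand the instruction, each UMMLA step multiplies a 2x8 block of unsigned A bytes by an 8x2 block of unsigned B bytes and accumulates the resulting 2x2 int32 tile row-major into the four destination lanes, which is what the 8-row packing described above is feeding. A scalar model of a single step, offered as an illustration rather than a definitive description:

    #include <cstdint>

    // Scalar sketch of one 8-bit matrix multiply-accumulate step: a[] holds
    // two rows of eight A bytes, b[] holds two columns of eight B bytes.
    void UmmlaStepRef(uint32_t acc[4], const uint8_t a[16], const uint8_t b[16])
    {
        for (int row = 0; row < 2; row++) {
            for (int col = 0; col < 2; col++) {
                uint32_t sum = 0;
                for (int k = 0; k < 8; k++) {
                    sum += uint32_t(a[row * 8 + k]) * uint32_t(b[col * 8 + k]);
                }
                acc[row * 2 + col] += sum;
            }
        }
    }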
- // - - while (CountM >= 8) { - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - const uint8_t* a2 = a0 + lda * 2; - const uint8_t* a3 = a0 + lda * 3; - const uint8_t* a4 = a0 + lda * 4; - const uint8_t* a5 = a0 + lda * 5; - const uint8_t* a6 = a0 + lda * 6; - const uint8_t* a7 = a0 + lda * 7; - - size_t k = CountK; - uint32x4_t RowSums0 = vmovq_n_u32(0); - uint32x4_t RowSums1 = vmovq_n_u32(0); - - while (k >= 16) { - uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); - a0 += 16; - uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); - a1 += 16; - uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); - a2 += 16; - uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); - a3 += 16; - uint64x2_t v4 = vld1q_u64(reinterpret_cast(a4)); - a4 += 16; - uint64x2_t v5 = vld1q_u64(reinterpret_cast(a5)); - a5 += 16; - uint64x2_t v6 = vld1q_u64(reinterpret_cast(a6)); - a6 += 16; - uint64x2_t v7 = vld1q_u64(reinterpret_cast(a7)); - a7 += 16; - - uint64x2_t z0 = vzip1q_u64(v0, v1); - uint64x2_t z1 = vzip2q_u64(v0, v1); - uint64x2_t z2 = vzip1q_u64(v2, v3); - uint64x2_t z3 = vzip2q_u64(v2, v3); - - uint64x2_t z4 = vzip1q_u64(v4, v5); - uint64x2_t z5 = vzip2q_u64(v4, v5); - uint64x2_t z6 = vzip1q_u64(v6, v7); - uint64x2_t z7 = vzip2q_u64(v6, v7); - - vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); - vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); - vst1q_u8(&D[32], vreinterpretq_u8_u64(z4)); - vst1q_u8(&D[48], vreinterpretq_u8_u64(z6)); - vst1q_u8(&D[64], vreinterpretq_u8_u64(z1)); - vst1q_u8(&D[80], vreinterpretq_u8_u64(z3)); - vst1q_u8(&D[96], vreinterpretq_u8_u64(z5)); - vst1q_u8(&D[112], vreinterpretq_u8_u64(z7)); - - uint32x4_t RowSums0L_pada = vmovq_n_u32(0); - RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); - RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); - - uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); - uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); - uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), - vdups_laneq_u32(RowSums0L_add, 2)}; - - uint32x4_t RowSums0H_pada = vmovq_n_u32(0); - RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); - RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); - - uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); - uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); - uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), - vdups_laneq_u32(RowSums0H_add, 2)}; - - RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); - - uint32x4_t RowSums1L_pada = vmovq_n_u32(0); - RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z4))); - RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z5))); - - uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); - uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); - uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), - vdups_laneq_u32(RowSums1L_add, 2)}; - - uint32x4_t RowSums1H_pada = vmovq_n_u32(0); - RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z6))); - RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z7))); - - uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); - uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); - uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), - vdups_laneq_u32(RowSums1H_add, 2)}; - - 
RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); - - D += 128; - k -= 16; - } - - while (k >= 8) { - uint64x1_t v0 = *reinterpret_cast(a0); - a0 += 8; - uint64x1_t v1 = *reinterpret_cast(a1); - a1 += 8; - uint64x1_t v2 = *reinterpret_cast(a2); - a2 += 8; - uint64x1_t v3 = *reinterpret_cast(a3); - a3 += 8; - uint64x1_t v4 = *reinterpret_cast(a4); - a4 += 8; - uint64x1_t v5 = *reinterpret_cast(a5); - a5 += 8; - uint64x1_t v6 = *reinterpret_cast(a6); - a6 += 8; - uint64x1_t v7 = *reinterpret_cast(a7); - a7 += 8; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[8]) = v1; - *reinterpret_cast(&D[16]) = v2; - *reinterpret_cast(&D[24]) = v3; - *reinterpret_cast(&D[32]) = v4; - *reinterpret_cast(&D[40]) = v5; - *reinterpret_cast(&D[48]) = v6; - *reinterpret_cast(&D[56]) = v7; - - uint64x2_t z01 = vcombine_u64(v0, v1); - uint64x2_t z23 = vcombine_u64(v2, v3); - uint64x2_t z45 = vcombine_u64(v4, v5); - uint64x2_t z67 = vcombine_u64(v6, v7); - - uint32x4_t RowSums0L_pada = vmovq_n_u32(0); - RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); - - uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); - uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); - uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), - vdups_laneq_u32(RowSums0L_add, 2)}; - - uint32x4_t RowSums0H_pada = vmovq_n_u32(0); - RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); - - uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); - uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); - uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), - vdups_laneq_u32(RowSums0H_add, 2)}; - - RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); - - uint32x4_t RowSums1L_pada = vmovq_n_u32(0); - RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); - - uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); - uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); - uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), - vdups_laneq_u32(RowSums1L_add, 2)}; - - uint32x4_t RowSums1H_pada = vmovq_n_u32(0); - RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); - - uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); - uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); - uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), - vdups_laneq_u32(RowSums1H_add, 2)}; - - RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); - - D += 64; - k -= 8; - } - - if (k > 0) { - // - // zero pad the remaining columns to 8 - // - uint8_t* d = D; - - vst1q_u8(d, vmovq_n_u8(0)); - vst1q_u8(&d[16], vmovq_n_u8(0)); - vst1q_u8(&d[32], vmovq_n_u8(0)); - vst1q_u8(&d[48], vmovq_n_u8(0)); - - while (k > 0) { - d[0] = *a0++; - d[8] = *a1++; - d[16] = *a2++; - d[24] = *a3++; - d[32] = *a4++; - d[40] = *a5++; - d[48] = *a6++; - d[56] = *a7++; - d += 1; - k -= 1; - } - d = D; - uint64x1_t v0 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v1 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v2 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v3 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v4 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v5 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v6 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v7 = *reinterpret_cast(d); - d = d + 8; - - uint64x2_t z01 = vcombine_u64(v0, v1); - 
uint64x2_t z23 = vcombine_u64(v2, v3); - uint64x2_t z45 = vcombine_u64(v4, v5); - uint64x2_t z67 = vcombine_u64(v6, v7); - - uint32x4_t RowSums0L_pada = vmovq_n_u32(0); - RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); - - uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); - uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); - uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), - vdups_laneq_u32(RowSums0L_add, 2)}; - - uint32x4_t RowSums0H_pada = vmovq_n_u32(0); - RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); - - uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); - uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); - uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), - vdups_laneq_u32(RowSums0H_add, 2)}; - - RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); - - uint32x4_t RowSums1L_pada = vmovq_n_u32(0); - RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); - - uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); - uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); - uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), - vdups_laneq_u32(RowSums1L_add, 2)}; - - uint32x4_t RowSums1H_pada = vmovq_n_u32(0); - RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); - - uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); - uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); - uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), - vdups_laneq_u32(RowSums1H_add, 2)}; - - RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); - - D += 64; - } - - vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); - vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); - - RowSumBuffer += 8; - - A = A + lda * 8; - CountM -= 8; - } - - // - // Process four rows of matrix A. - // - // The buffer is packed as a series of 32 byte vectors where four rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // [ B0 B1 B2 B3 B4 B5 B6 B7 ] - // [ C0 C1 C2 C3 C4 C5 C6 C7 ] - // [ D0 D1 D2 D3 D4 D5 D6 D7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
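All of these packing paths build their row sums from the same two NEON primitives: vpaddlq_u8 pairwise-widens sixteen bytes into eight 16-bit sums, and vpadalq_u16 pairwise-widens and accumulates those into four 32-bit lanes; the vext/vadd/vdups sequences that follow essentially combine adjacent lanes so each final lane corresponds to one interleaved row. A scalar model of one vector's worth of the widening reduction, for reference:

    #include <cstdint>

    // Equivalent of vpadalq_u16(acc, vpaddlq_u8(bytes)): each 32-bit lane
    // accumulates the sum of its four corresponding bytes.
    void PairwiseWidenAccumulate(uint32_t acc[4], const uint8_t bytes[16])
    {
        for (int lane = 0; lane < 4; lane++) {
            acc[lane] += uint32_t(bytes[lane * 4 + 0]) + bytes[lane * 4 + 1] +
                         bytes[lane * 4 + 2] + bytes[lane * 4 + 3];
        }
    }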
- // - - if (CountM >= 4) { - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - const uint8_t* a2 = a1 + lda; - const uint8_t* a3 = a2 + lda; - - size_t k = CountK; - uint32x4_t RowSums = vmovq_n_u32(0); - - while (k >= 16) { - uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); - a0 += 16; - uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); - a1 += 16; - uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); - a2 += 16; - uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); - a3 += 16; - - uint64x2_t z0 = vzip1q_u64(v0, v1); - uint64x2_t z1 = vzip2q_u64(v0, v1); - uint64x2_t z2 = vzip1q_u64(v2, v3); - uint64x2_t z3 = vzip2q_u64(v2, v3); - - vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); - vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); - vst1q_u8(&D[32], vreinterpretq_u8_u64(z1)); - vst1q_u8(&D[48], vreinterpretq_u8_u64(z3)); - - uint32x4_t RowSumsL_pada = vmovq_n_u32(0); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); - - uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); - uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); - uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), - vdups_laneq_u32(RowSumsL_add, 2)}; - - uint32x4_t RowSumsH_pada = vmovq_n_u32(0); - RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); - RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); - - uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); - uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); - uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), - vdups_laneq_u32(RowSumsH_add, 2)}; - - RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); - - D += 64; - k -= 16; - } - - while (k >= 8) { - uint64x1_t v0 = *reinterpret_cast(a0); - a0 += 8; - uint64x1_t v1 = *reinterpret_cast(a1); - a1 += 8; - uint64x1_t v2 = *reinterpret_cast(a2); - a2 += 8; - uint64x1_t v3 = *reinterpret_cast(a3); - a3 += 8; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[8]) = v1; - *reinterpret_cast(&D[16]) = v2; - *reinterpret_cast(&D[24]) = v3; - - uint64x2_t z01 = vcombine_u64(v0, v1); - uint64x2_t z23 = vcombine_u64(v2, v3); - - uint32x4_t RowSumsL_pada = vmovq_n_u32(0); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); - - uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); - uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); - uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), - vdups_laneq_u32(RowSumsL_add, 2)}; - - uint32x4_t RowSumsH_pada = vmovq_n_u32(0); - RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); - - uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); - uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); - uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), - vdups_laneq_u32(RowSumsH_add, 2)}; - - RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); - - D += 32; - k -= 8; - } - - if (k > 0) { - // - // Copy the remaining bytes with zero padding. 
- // - uint8_t* d = D; - - vst1q_u8(d, vmovq_n_u8(0)); - vst1q_u8(&d[16], vmovq_n_u8(0)); - - while (k > 0) { - d[0] = *a0++; - d[8] = *a1++; - d[16] = *a2++; - d[24] = *a3++; - d += 1; - k -= 1; - } - - d = D; - uint64x1_t v0 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v1 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v2 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v3 = *reinterpret_cast(d); - d = d + 8; - - uint64x2_t z01 = vcombine_u64(v0, v1); - uint64x2_t z23 = vcombine_u64(v2, v3); - - uint32x4_t RowSums0L_pada = vmovq_n_u32(0); - RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); - - uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); - uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); - uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), - vdups_laneq_u32(RowSums0L_add, 2)}; - - uint32x4_t RowSums0H_pada = vmovq_n_u32(0); - RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); - - uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); - uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); - uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), - vdups_laneq_u32(RowSums0H_add, 2)}; - - RowSums = vaddq_u32(RowSums, vcombine_u32(RowSums0L, RowSums0H)); - - D += 32; - } - - vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); - RowSumBuffer += 4; - - A = A + lda * 4; - CountM -= 4; - } - - // - // Process two rows of matrix A. - // - // The buffer is packed as a series of 16 byte vectors where two rows are - // interleaved with the following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // [ B0 B1 B2 B3 B4 B5 B6 B7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of eight, then the vector is padded - // with zeroes. 
- // - - if (CountM >= 2) { - const uint8_t* a0 = A; - const uint8_t* a1 = a0 + lda; - - size_t k = CountK; - uint32x2_t RowSums = vmov_n_u32(0); - - while (k >= 16) { - uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); - a0 += 16; - uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); - a1 += 16; - - uint64x2_t z0 = vzip1q_u64(v0, v1); - uint64x2_t z1 = vzip2q_u64(v0, v1); - - vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); - vst1q_u8(&D[16], vreinterpretq_u8_u64(z1)); - - uint32x4_t RowSumsL_pada = vmovq_n_u32(0); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); - - uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); - uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); - uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), - vdups_laneq_u32(RowSumsL_add, 2)}; - - RowSums = vadd_u32(RowSums, RowSumsL); - - D += 32; - k -= 16; - } - - while (k >= 8) { - uint64x1_t v0 = *reinterpret_cast(a0); - a0 += 8; - uint64x1_t v1 = *reinterpret_cast(a1); - a1 += 8; - - *reinterpret_cast(&D[0]) = v0; - *reinterpret_cast(&D[8]) = v1; - - uint64x2_t z01 = vcombine_u64(v0, v1); - uint32x4_t RowSumsL_pada = vmovq_n_u32(0); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); - - uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); - uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); - uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), - vdups_laneq_u32(RowSumsL_add, 2)}; - - RowSums = vadd_u32(RowSums, RowSumsL); - - D += 16; - k -= 8; - } - - if (k > 0) { - // - // Zero pad the remaining elements to make 8 columns. - // - - uint8_t* d = PaddedMatrixAData; - vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); - - while (k > 0) { - d[0] = *a0++; - d[8] = *a1++; - - d += 1; - k -= 1; - } - - d = PaddedMatrixAData; - uint64x1_t v0 = *reinterpret_cast(d); - d = d + 8; - uint64x1_t v1 = *reinterpret_cast(d); - d = d + 8; - - uint64x2_t z01 = vcombine_u64(v0, v1); - uint32x4_t RowSumsL_pada = vmovq_n_u32(0); - RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); - - uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); - uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); - uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), - vdups_laneq_u32(RowSumsL_add, 2)}; - - RowSums = vadd_u32(RowSums, RowSumsL); - - uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); - vst1q_u8(D, PackedVector); - - D += 16; - } - - vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); - RowSumBuffer += 2; - - A = A + lda * 2; - CountM -= 2; - } - - // - // Process one row of matrix A. - // - // The buffer is packed as a series of 8 byte with the following pattern: - // - // [ A0 A1 A2 A3 A4 A5 A6 A7 ] - // - // This pattern is repeated (CountK / 8) times. - // - // If CountK is not aligned to a multiple of 8, then the vector is padded - // with zeroes. - // - - if (CountM > 0) { - // No need to pad the rows to 2, the .S takes care of zero pdding - const uint8_t* a = A; - size_t k = CountK; - uint32x4_t RowSums = vmovq_n_u32(0); - - while (k >= 16) { - uint8x16_t v = vld1q_u8(a); - a += 16; - - vst1q_u8(D, v); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); - - D += 16; - k -= 16; - } - - if (k > 0) { - // - // Copy the remaining bytes to the zero padded stack buffer. 
- // - - vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); - - for (size_t kk = 0; kk < k; kk++) { - PaddedMatrixAData[kk] = a[kk]; - } - - uint8x16_t v = vld1q_u8(PaddedMatrixAData); - vst1q_u8(D, v); - - RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); - } - - *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessUmmla(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, - uint8x8_t BytesRow[8], - uint8x16_t BitFlipVector, - uint32x4_t ColumnSums[2]) -{ - uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); - uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); - - uint8x16_t v46 = veorq_u8(vcombine_u8(BytesRow[4], BytesRow[6]), BitFlipVector); - uint8x16_t v57 = veorq_u8(vcombine_u8(BytesRow[5], BytesRow[7]), BitFlipVector); - - uint8x16x2_t zw1 = vzipq_u8(v02, v13); - uint16x8x2_t zd1 = - vzipq_u16(vreinterpretq_u16_u8(zw1.val[0]), vreinterpretq_u16_u8(zw1.val[1])); - - uint8x16x2_t zw2 = vzipq_u8(v46, v57); - uint16x8x2_t zd2 = - vzipq_u16(vreinterpretq_u16_u8(zw2.val[0]), vreinterpretq_u16_u8(zw2.val[1])); - - uint32x4x2_t zd3 = - vzipq_u32(vreinterpretq_u32_u16(zd1.val[0]), vreinterpretq_u32_u16(zd2.val[0])); - uint32x4x2_t zd4 = - vzipq_u32(vreinterpretq_u32_u16(zd1.val[1]), vreinterpretq_u32_u16(zd2.val[1])); - - vst1q_u8(&D[0], vreinterpretq_u8_u32(zd3.val[0])); - vst1q_u8(&D[16], vreinterpretq_u8_u32(zd3.val[1])); - vst1q_u8(&D[32], vreinterpretq_u8_u32(zd4.val[0])); - vst1q_u8(&D[48], vreinterpretq_u8_u32(zd4.val[1])); - - uint32x4_t ColSums0L_pada = vmovq_n_u32(0); - ColSums0L_pada = vpadalq_u16(ColSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[0]))); - uint32x4_t ColSums0L_ext = vextq_u32(ColSums0L_pada, ColSums0L_pada, 1); - uint32x4_t ColSums0L_add = vaddq_u32(ColSums0L_pada, ColSums0L_ext); - uint32x2_t ColSums0L = {vdups_laneq_u32(ColSums0L_add, 0), vdups_laneq_u32(ColSums0L_add, 2)}; - - uint32x4_t ColSums0H_pada = vmovq_n_u32(0); - ColSums0H_pada = vpadalq_u16(ColSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[1]))); - uint32x4_t ColSums0H_ext = vextq_u32(ColSums0H_pada, ColSums0H_pada, 1); - uint32x4_t ColSums0H_add = vaddq_u32(ColSums0H_pada, ColSums0H_ext); - uint32x2_t ColSums0H = {vdups_laneq_u32(ColSums0H_add, 0), vdups_laneq_u32(ColSums0H_add, 2)}; - - ColumnSums[0] = vaddq_u32(ColumnSums[0], vcombine_u32(ColSums0L, ColSums0H)); - - uint32x4_t ColSums1L_pada = vmovq_n_u32(0); - ColSums1L_pada = vpadalq_u16(ColSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[0]))); - uint32x4_t ColSums1L_ext = vextq_u32(ColSums1L_pada, ColSums1L_pada, 1); - uint32x4_t ColSums1L_add = vaddq_u32(ColSums1L_pada, ColSums1L_ext); - uint32x2_t ColSums1L = {vdups_laneq_u32(ColSums1L_add, 0), vdups_laneq_u32(ColSums1L_add, 2)}; - - uint32x4_t ColSums1H_pada = vmovq_n_u32(0); - ColSums1H_pada = vpadalq_u16(ColSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[1]))); - uint32x4_t ColSums1H_ext = vextq_u32(ColSums1H_pada, ColSums1H_pada, 1); - uint32x4_t ColSums1H_add = vaddq_u32(ColSums1H_pada, ColSums1H_ext); - uint32x2_t ColSums1H = {vdups_laneq_u32(ColSums1H_add, 0), vdups_laneq_u32(ColSums1H_add, 2)}; - - ColumnSums[1] = vaddq_u32(ColumnSums[1], vcombine_u32(ColSums1L, ColSums1H)); -} - -template <> -void -MlasGemmQuantCopyPackB(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned) -{ - const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 
0x80 : 0); - uint8x8_t BytesRow[8]; - - // - // Copy data from matrix B into the destination buffer 8x2 blocks at a - // time. - // - // - while (CountN >= 8) { - const uint8_t* b = B; - size_t k = CountK; - uint32x4_t ColumnSums[2]; - ColumnSums[0] = vmovq_n_u32(0); - ColumnSums[1] = vmovq_n_u32(0); - - while (k >= 8) { - BytesRow[0] = vld1_u8(&b[ldb * 0]); - BytesRow[1] = vld1_u8(&b[ldb * 1]); - BytesRow[2] = vld1_u8(&b[ldb * 2]); - BytesRow[3] = vld1_u8(&b[ldb * 3]); - BytesRow[4] = vld1_u8(&b[ldb * 4]); - BytesRow[5] = vld1_u8(&b[ldb * 5]); - BytesRow[6] = vld1_u8(&b[ldb * 6]); - BytesRow[7] = vld1_u8(&b[ldb * 7]); - - MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); - - D += 64; - b += ldb * 8; - k -= 8; - } - - if (k > 0) { - // Pad k to 8 - - BytesRow[0] = vld1_u8(&b[ldb * 0]); - BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); - BytesRow[2] = (k >= 3) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); - BytesRow[3] = (k >= 4) ? vld1_u8(&b[ldb * 3]) : vget_low_u8(BitFlipVector); - BytesRow[4] = (k >= 5) ? vld1_u8(&b[ldb * 4]) : vget_low_u8(BitFlipVector); - BytesRow[5] = (k >= 6) ? vld1_u8(&b[ldb * 5]) : vget_low_u8(BitFlipVector); - BytesRow[6] = (k >= 7) ? vld1_u8(&b[ldb * 6]) : vget_low_u8(BitFlipVector); - BytesRow[7] = vget_low_u8(BitFlipVector); - - MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); - - D += 64; - } - - // Zero pad the output buffer to a multiple of PackedK if the above - // processed an odd number of four row bundles. - // - vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); - vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); - - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. - // - - if (CountN > 0) { - const uint8_t* b = B; - size_t k = CountK; - uint8_t PaddedMatrixBData[64]; - uint32x4_t ColumnSums[2]; - - vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); - - ColumnSums[0] = vmovq_n_u32(0); - ColumnSums[1] = vmovq_n_u32(0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. - // - - while (k > 0) { - const uint8_t* bcopy0 = &b[ldb * 0]; - const uint8_t* bcopy1 = &b[ldb * 1]; - const uint8_t* bcopy2 = &b[ldb * 2]; - const uint8_t* bcopy3 = &b[ldb * 3]; - const uint8_t* bcopy4 = &b[ldb * 4]; - const uint8_t* bcopy5 = &b[ldb * 5]; - const uint8_t* bcopy6 = &b[ldb * 6]; - const uint8_t* bcopy7 = &b[ldb * 7]; - - if (k >= 8) { - b += ldb * 8; - k -= 8; - - } else { - vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); - vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); - - bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; - bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; - bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; - bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; - bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; - bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56]; - bcopy7 = &PaddedMatrixBData[56]; - - k = 0; - } - - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - do { - padded[0] = *bcopy0++; - padded[8] = *bcopy1++; - padded[16] = *bcopy2++; - padded[24] = *bcopy3++; - padded[32] = *bcopy4++; - padded[40] = *bcopy5++; - padded[48] = *bcopy6++; - padded[56] = *bcopy7++; - - } while (++padded < padded_end); - - BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); - BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); - BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); - BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); - BytesRow[4] = vld1_u8(&PaddedMatrixBData[32]); - BytesRow[5] = vld1_u8(&PaddedMatrixBData[40]); - BytesRow[6] = vld1_u8(&PaddedMatrixBData[48]); - BytesRow[7] = vld1_u8(&PaddedMatrixBData[56]); - - MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); - - D += 64; - } - - vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); - vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); - } -} - -template <> -MLAS_FORCEINLINE size_t -MlasGemmQuantKernel(const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode) -{ - size_t RowsHandled; - - if (ZeroMode) { - RowsHandled = MlasGemmU8X8KernelUmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB); - } else { - RowsHandled = MlasGemmU8X8KernelUmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, - RowSumBuffer, ColumnSumBuffer, ZeroPointB); - } - - return RowsHandled; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla = { - MlasGemmQuantOperation, - MlasGemmQuantPackedOperation, - MlasGemmQuantCopyPackB, - MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK, - MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides.K, - 8}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp deleted file mode 100644 index 1f33d77adf4b9..0000000000000 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp +++ /dev/null @@ -1,509 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qgemm_kernel_wasmsimd.cpp - -Abstract: - - This module implements QGEMM kernel for WebAssembly SIMD128. 
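Both the UMMLA packing above and the WebAssembly kernels below lean on the same algebraic trick: the zero-point corrections are not applied per element but folded in afterwards from the row and column sums produced during packing, using sum_k (A[i][k]-za)*(B[k][j]-zb) = sum_k A[i][k]*B[k][j] - zb*RowSum(A,i) - za*ColSum(B,j) + K*za*zb. A minimal scalar sketch of that identity follows; it is illustrative only, not MLAS code, and the names are invented.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference quantized GEMM that folds the zero points in via row/column sums,
// mirroring how the packed kernels seed their accumulators from
// RowSumBuffer/ColumnSumBuffer instead of subtracting za/zb per element.
void QuantGemmReference(const uint8_t* A, const uint8_t* B, int32_t* C,
                        size_t M, size_t N, size_t K, int32_t za, int32_t zb)
{
    std::vector<int32_t> rowSum(M, 0), colSum(N, 0);
    for (size_t i = 0; i < M; i++)
        for (size_t k = 0; k < K; k++)
            rowSum[i] += A[i * K + k];
    for (size_t k = 0; k < K; k++)
        for (size_t j = 0; j < N; j++)
            colSum[j] += B[k * N + j];

    for (size_t i = 0; i < M; i++) {
        for (size_t j = 0; j < N; j++) {
            int32_t acc = 0;
            for (size_t k = 0; k < K; k++)
                acc += int32_t(A[i * K + k]) * int32_t(B[k * N + j]);
            // Fold the zero-point terms in afterwards, as the packed kernels do.
            C[i * N + j] = acc - zb * rowSum[i] - za * colSum[j] + int32_t(K) * za * zb;
        }
    }
}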
- ---*/ - -#include "mlasi.h" -#include "qgemm.h" - -// wasm implementation of "_mm_unpacklo_epi8" -v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i8x16_unpacklo(v128_t a, v128_t b) { - return wasm_i8x16_shuffle(a, b, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23); -} - -// wasm implementation of "_mm_unpackhi_epi8" -v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i8x16_unpackhi(v128_t a, v128_t b) { - return wasm_i8x16_shuffle(a, b, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31); -} - -struct MLAS_GEMM_U8X8_KERNEL_WASMSIMD -{ - typedef int16_t PackedAType; - typedef int16_t PackedBType; - typedef uint8_t OffsetAType; - typedef int8_t OffsetBType; - - static constexpr size_t PackedK = 2; - static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; - static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; -}; - -constexpr size_t MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK; -constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_WASMSIMD::Strides; - -template<> -MLAS_FORCEINLINE -int32_t -MlasGemmQuantFixupZeroPointB( - int32_t ZeroPointB, - bool BIsSigned - ) -{ - if (!BIsSigned) { - ZeroPointB = MLAS_GEMM_U8X8_KERNEL_WASMSIMD::OffsetBType(ZeroPointB ^ 0x80); - } - - return ZeroPointB; -} - -template<> -void -MlasGemmQuantCopyPackA( - MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* D, - const uint8_t* A, - size_t lda, - size_t CountM, - size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned - ) -{ - MLAS_UNREFERENCED_PARAMETER(AIsSigned); - const v128_t ZeroVector = wasm_i64x2_const(0, 0); - const v128_t OnesWordBroadcast = wasm_i16x8_splat(1); - uint8_t PaddedMatrixAData[8] = { 0 }; - - // - // Process a single row of matrix A in a loop. - // - - while (CountM > 0) { - - const uint8_t* a = A; - size_t k = CountK; - v128_t ReductionVector = ZeroVector; - - // - // Zero extend the source bytes to 16-bits and write to the packed - // buffer. - // - // The packed buffer has the same data ordering as the source bytes, - // but CountK is aligned up to a multiple of 2 to maintain 32-bit - // alignment. All extra bytes are zero-padded. - // - // These 16-bit values are also accumulated into an intermediate per-row - // accumulator. CountK cannot be greater than 128 to avoid overflowing - // these signed 16-bit accumulators. - // - - while (k >= 8) { - - v128_t Bytes = wasm_v128_load64_zero(&a[0]); - v128_t Words = wasm_i8x16_unpacklo(Bytes, ZeroVector); - - ReductionVector = wasm_i16x8_add(ReductionVector, Words); - - wasm_v128_store(&D[0], Words); - - a += 8; - D += 8; - k -= 8; - } - - if (k > 0) { - - // - // Copy the remaining bytes to the zero padded stack buffer. - // - - uint8_t* padded = PaddedMatrixAData; - uint8_t* padded_end = padded + k; - - do { - padded[0] = a[0]; - padded++; - a++; - } while (padded < padded_end); - - v128_t Bytes = wasm_v128_load64_zero(PaddedMatrixAData); - v128_t Words = wasm_i8x16_unpacklo(Bytes, ZeroVector); - - ReductionVector = wasm_i16x8_add(ReductionVector, Words); - - // - // Copy pairs of 16-bit values from the vector to the packed - // buffer and rotate the vector for the next iteration. - // - - for (size_t pairs = (k + 1) / 2; pairs > 0; pairs--) { - *((int32_t*)D) = wasm_i32x4_extract_lane(Words, 0); - D += 2; - Words = wasm_i32x4_shuffle(Words, wasm_i32x4_splat(0), 1, 2, 3, 0); - } - } - - // - // Reduce the partial accumulators. 
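The "CountK cannot be greater than 128" remarks in these packing routines come down to simple arithmetic: a signed 16-bit lane that absorbs one unsigned byte per step can take at most 128 * 255 = 32640 before exceeding INT16_MAX (32767); the dot product with a vector of ones then widens those partials into 32-bit lanes. A scalar sketch of the same bound, illustrative only and not MLAS code:

#include <cstddef>
#include <cstdint>

// One 16-bit accumulator seeing one byte per step, as in the per-column sums:
// 128 steps of at most 255 give 32640, still within int16_t, so the widening to
// 32 bits can safely be deferred to a single reduction at the end.
int32_t PartialSumReference(const uint8_t* data, size_t CountK /* <= 128 */)
{
    int16_t partial = 0;
    for (size_t k = 0; k < CountK; k++)
        partial = int16_t(partial + data[k]);   // never overflows for CountK <= 128
    return int32_t(partial);                    // widened once, like the dot-with-ones step
}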
- // - - ReductionVector = wasm_i32x4_dot_i16x8(ReductionVector, OnesWordBroadcast); - ReductionVector = wasm_i32x4_add(ReductionVector, - wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 2, 3, 2, 3)); - ReductionVector = wasm_i32x4_add(ReductionVector, - wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 1, 0, 1, 0)); - - *RowSumBuffer++ = wasm_i32x4_extract_lane(ReductionVector, 0); - - A += lda; - CountM -= 1; - } -} - - -MLAS_FORCEINLINE -void -MlasGemmU8X8CopyPackBProcessWasmSimd( - MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* D, - v128_t BytesRow0, - v128_t BytesRow1, - v128_t BitFlipVector, - v128_t ColumnSums[2] -) -{ - v128_t BytesInterleaved = wasm_i8x16_unpacklo(BytesRow0, BytesRow1); - - BytesInterleaved = wasm_v128_xor(BytesInterleaved, BitFlipVector); - - v128_t WordsInterleaved0 = wasm_i16x8_shr(wasm_i8x16_unpacklo(BytesInterleaved, BytesInterleaved), 8); - v128_t WordsInterleaved1 = wasm_i16x8_shr(wasm_i8x16_unpackhi(BytesInterleaved, BytesInterleaved), 8); - - ColumnSums[0] = wasm_i16x8_add(ColumnSums[0], WordsInterleaved0); - ColumnSums[1] = wasm_i16x8_add(ColumnSums[1], WordsInterleaved1); - - wasm_v128_store(&D[0], WordsInterleaved0); - wasm_v128_store(&D[8], WordsInterleaved1); -} - -template<> -void -MlasGemmQuantCopyPackB( - MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* D, - const uint8_t* B, - size_t ldb, - size_t CountN, - size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned - ) -{ - const v128_t OnesWordBroadcast = wasm_i16x8_splat(1); - const v128_t BitFlipVector = wasm_i32x4_splat(BIsSigned ? 0 : 0x80808080); - - // - // Process 8 columns of matrix B in a loop. - // - - while (CountN >= 8) { - - const uint8_t* b = B; - size_t k = CountK; - v128_t ColumnSums[2]; - - ColumnSums[0] = wasm_i64x2_const(0, 0); - ColumnSums[1] = wasm_i64x2_const(0, 0); - - // - // Interleave rows of matrix B and write to the packed buffer. - // - // These values are also zero-extended and accumulated into an - // intermediate per-column accumulator. CountK cannot be greater than - // 128 to avoid overflowing these signed 16-bit accumulators. - // - - while (k >= MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK) { - - v128_t BytesRow0 = wasm_v128_load64_zero(&b[0]); - v128_t BytesRow1 = wasm_v128_load64_zero(&b[ldb]); - - MlasGemmU8X8CopyPackBProcessWasmSimd(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); - - b += ldb * 2; - D += 16; - k -= 2; - } - - if (k > 0) { - - v128_t BytesRow0 = wasm_v128_load64_zero(&b[0]); - - MlasGemmU8X8CopyPackBProcessWasmSimd(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); - - D += 16; - } - - ColumnSums[0] = wasm_i32x4_dot_i16x8(ColumnSums[0], OnesWordBroadcast); - ColumnSums[1] = wasm_i32x4_dot_i16x8(ColumnSums[1], OnesWordBroadcast); - - wasm_v128_store(&ColumnSumBuffer[0], ColumnSums[0]); - wasm_v128_store(&ColumnSumBuffer[4], ColumnSums[1]); - ColumnSumBuffer += 8; - - B += 8; - CountN -= 8; - } - - // - // Process the remaining columns of matrix B. - // - - if (CountN > 0) { - - const uint8_t* b = B; - size_t k = CountK; - v128_t ColumnSums[2]; - uint8_t PaddedMatrixBData[16]; - - wasm_v128_store(PaddedMatrixBData, BitFlipVector); - - ColumnSums[0] = wasm_i64x2_const(0, 0); - ColumnSums[1] = wasm_i64x2_const(0, 0); - - // - // Interleave rows of matrix B using an intermediate zero padded stack - // buffer and write to the packed buffer. 
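The remainder-column path just described reuses the full-width interleave by staging the live bytes in a small buffer that is pre-filled with the same value the packer XORs against, so the padded lanes contribute exact zeroes to both the packed output and the column sums. A sketch of that staging step, illustrative only and not MLAS code:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Stage two K-rows of a narrow (CountN < 8) column block into a padded buffer.
// PadByte matches the BitFlipVector used by the surrounding path, so the padding
// cancels to zero after the XOR in the shared 8-column packing step.
void StageTwoRows(uint8_t Padded[16], const uint8_t* b, size_t ldb,
                  size_t CountN, uint8_t PadByte)
{
    std::memset(Padded, PadByte, 16);
    std::memcpy(&Padded[0], &b[0], CountN);    // row k   -> bytes 0..CountN-1
    std::memcpy(&Padded[8], &b[ldb], CountN);  // row k+1 -> bytes 8..8+CountN-1
    // The caller loads Padded[0..7] and Padded[8..15] and runs the same
    // interleave/accumulate step as a full 8-column block.
}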
- // - - while (k >= MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK) { - - const uint8_t* bcopy = b; - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = bcopy[0]; - padded[8] = bcopy[ldb]; - padded++; - bcopy++; - } while (padded < padded_end); - - v128_t BytesRow0 = wasm_v128_load64_zero(&PaddedMatrixBData[0]); - v128_t BytesRow1 = wasm_v128_load64_zero(&PaddedMatrixBData[8]); - - MlasGemmU8X8CopyPackBProcessWasmSimd(D, BytesRow0, BytesRow1, BitFlipVector, ColumnSums); - - b += ldb * 2; - D += 16; - k -= 2; - } - - if (k > 0) { - - const uint8_t* bcopy = b; - uint8_t* padded = PaddedMatrixBData; - uint8_t* padded_end = padded + CountN; - - do { - padded[0] = bcopy[0]; - padded++; - bcopy++; - } while (padded < padded_end); - - v128_t BytesRow0 = wasm_v128_load64_zero(&PaddedMatrixBData[0]); - - MlasGemmU8X8CopyPackBProcessWasmSimd(D, BytesRow0, BitFlipVector, BitFlipVector, ColumnSums); - } - - ColumnSums[0] = wasm_i32x4_dot_i16x8(ColumnSums[0], OnesWordBroadcast); - ColumnSums[1] = wasm_i32x4_dot_i16x8(ColumnSums[1], OnesWordBroadcast); - - wasm_v128_store(&ColumnSumBuffer[0], ColumnSums[0]); - wasm_v128_store(&ColumnSumBuffer[4], ColumnSums[1]); - } -} - -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowWasmSimd( - v128_t ABroadcast, - const int16_t* B, - v128_t Accumulators[2] -) -{ - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = wasm_v128_load(&B[8]); - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); -} - - -template<> -size_t -MlasGemmQuantKernel( - const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, - const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, - int32_t* C, - size_t PackedCountK, - size_t CountM, - size_t CountN, - size_t ldc, - const int32_t* RowSumBuffer, - const int32_t* ColumnSumBuffer, - const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - v128_t Accumulators[2]; - - // - // Initialize the accumulators with the row and column sums. - // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { - - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } - - ZeroPointB += 8; - - Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); - Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); - - } - else { - - Accumulators[0] = wasm_i32x4_splat(RowSumValue); - Accumulators[1] = Accumulators[0]; - } - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); - ColumnSumBuffer += 8; - - // - // Broadcast each pair of 16-bit values from the matrix A and multiply - // with the pair of 16-bit values from matrix B, and add the 32-bit - // intermediate into the accumulator registers. 
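With PackedK equal to two, each step of this loop consumes one (a0, a1) pair from packed A and sixteen int16 values from packed B (the two K-rows interleaved per column), and the accumulators were already seeded from RowSumBuffer/ColumnSumBuffer, which carry the zero-point corrections. A scalar view of one such step, illustrative only and not MLAS code:

#include <cstdint>

// One PackedK = 2 multiply-accumulate step over 8 output columns. Each i32 lane
// gets a0*b_row0[col] + a1*b_row1[col], which is exactly what the paired
// wasm_i32x4_dot_i16x8 calls compute four columns at a time.
void MacStep(const int16_t a[2], const int16_t b[16], int32_t acc[8])
{
    for (int j = 0; j < 8; j++) {
        acc[j] += int32_t(a[0]) * b[2 * j] + int32_t(a[1]) * b[2 * j + 1];
    }
}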
- // - - const int16_t* a = A; - size_t k = PackedCountK; - - while (k >= 4) { - - v128_t AElements = wasm_v128_load((v128_t*)a); - v128_t ABroadcast; - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[16], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[32], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[48], Accumulators); - - a += 4 * 2; - B += 4 * 16; - k -= 4; - } - - while (k > 0) { - - v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); - - a += 2; - B += 16; - k -= 1; - } - - // - // Output the accumulator block after optionally accumulating the values - // from matrix C. - // - - if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - wasm_v128_store(&C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; - } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); - } - - wasm_v128_store64_lane(&C[0], Accumulators[0], 0); - C += 2; - - Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); - } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; - } - - C[0] = AccumulatorValue; - } - - CountN = 0; - } - } - - return 1; -} - -const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { - MlasGemmQuantOperation, - nullptr, - nullptr, - MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK, - 0, - 4 // multiple of kernel stride M -}; diff --git a/onnxruntime/core/mlas/lib/qladd.cpp b/onnxruntime/core/mlas/lib/qladd.cpp deleted file mode 100644 index 5dafa17c2ae66..0000000000000 --- a/onnxruntime/core/mlas/lib/qladd.cpp +++ /dev/null @@ -1,812 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qladd.cpp - -Abstract: - - This module implements routines to quantize linear add. - - For quantization formula as specified in the ONNX operator documentation is: - - Output = Saturate(RoundToEven(Input / Scale) + ZeroPoint) - ---*/ - -#include "qladd.h" - -// Pure C++ helper, back off here in rare case. 
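Before the platform-specific variants, it helps to see the whole formula in one place: dequantize each input, add in float, divide by the output scale, clamp so the final value saturates to the 8-bit range, and round to nearest even (std::nearbyintf under the default rounding mode). A stand-alone uint8 sketch, illustrative only and not MLAS code:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Reference for one element of QLinearAdd: OutputC = Saturate(RoundToEven(
// (ScaleA*(a - ZeroPointA) + ScaleB*(b - ZeroPointB)) / ScaleC) + ZeroPointC).
uint8_t QLinearAddReference(uint8_t a, float ScaleA, int32_t ZeroPointA,
                            uint8_t b, float ScaleB, int32_t ZeroPointB,
                            float ScaleC, int32_t ZeroPointC)
{
    const float va = ScaleA * (int32_t(a) - ZeroPointA);
    const float vb = ScaleB * (int32_t(b) - ZeroPointB);
    float vc = (va + vb) / ScaleC;

    const float lo = float(std::numeric_limits<uint8_t>::min() - ZeroPointC);
    const float hi = float(std::numeric_limits<uint8_t>::max() - ZeroPointC);
    vc = std::min(std::max(vc, lo), hi);          // saturate before re-adding the zero point
    return uint8_t(int32_t(std::nearbyintf(vc + float(ZeroPointC))));
}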
-template -MLAS_FORCEINLINE -static -void -MlasQLinearAddKernelRawHelper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const float MinimumValue = (float)((int)std::numeric_limits::min() - ZeroPointC); - const float MaximumValue = (float)((int)std::numeric_limits::max() - ZeroPointC); - - float ValueB; - - if (IsScalarB) { - ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB); - } - - for (size_t n = 0; n < N; n++) { - float ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA); - if (!IsScalarB) { - ValueB = ScaleB * (int32_t(InputB[n]) - ZeroPointB); - } - float ValueC = (ValueA + ValueB) / ScaleC; - ValueC = std::min(std::max(ValueC, MinimumValue), MaximumValue); - OutputC[n] = (DataType)(int32_t)std::nearbyintf(ValueC + ZeroPointC); - } -} - -#if defined(MLAS_NEON_INTRINSICS) - -bool MlasCalcQLinearAddParameters( - float ScaleRatio_AC, - float ScaleRatio_BC, - int32_t& Shift, - int32_t& MultiplierA, - int32_t& MultiplierB) { - constexpr float MinScaleRatio = 6.103515625e-05f; // std::stof("0x1.0p-14f"); - constexpr float MaxScaleRatio = 256.0f; //std::stof("0x1.0p+8f"); - if (ScaleRatio_AC < MinScaleRatio || ScaleRatio_AC >= MaxScaleRatio || - ScaleRatio_BC < MinScaleRatio || ScaleRatio_BC >= MaxScaleRatio) { - return false; - } - - const float GreaterScaleRatio = std::max(ScaleRatio_AC, ScaleRatio_BC); - const int32_t GreaterExponent = (int32_t)(MlasBitsOfFp32(GreaterScaleRatio) >> 23) - 127; - Shift = 21 - GreaterExponent; - if (Shift > 31 || Shift < 13) return false; - - const float MultiplierFloatValue = MlasFp32FromBits((uint32_t)(21 - GreaterExponent + 127) << 23); - MultiplierA = (int32_t)lrintf(ScaleRatio_AC * MultiplierFloatValue); - MultiplierB = (int32_t)lrintf(ScaleRatio_BC * MultiplierFloatValue); - return ((MultiplierA < 0x00400000 && MultiplierB < 0x00400000) && - (MultiplierA >= 0x00200000 || MultiplierB >= 0x00200000)); // the greater one must fullfil this check -} - -template -static -void -MlasQLinearAddKernelHelper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - typedef MLAS_SignedUnsignedIntOps SUI; - - int32_t Shift, MultiplierA, MultiplierB; - const float ScaleRatio_AC = ScaleA / ScaleC; - const float ScaleRatio_BC = ScaleB / ScaleC; - if (!MlasCalcQLinearAddParameters(ScaleRatio_AC, ScaleRatio_BC, Shift, MultiplierA, MultiplierB)) { - MlasQLinearAddKernelRawHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - return; - } - - const int32x4_t VectorMultiplierA = vld1q_dup_s32(&MultiplierA); - const int32x4_t VectorMultiplierB = vld1q_dup_s32(&MultiplierB); - const typename SUI::i8x8_t VectorZeroPointA = SUI::vmov_n_i8((DataType)ZeroPointA); - const typename SUI::i8x8_t VectorZeroPointB = SUI::vmov_n_i8((DataType)ZeroPointB); - const int16x8_t VectorZeroPointC = vmovq_n_s16((int16_t)ZeroPointC); - const int32x4_t vright_shift = vmovq_n_s32(-Shift); // vld1q_dup_s32(&right_shift); - const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); - - int32x4_t vscalar; - if (IsScalarB) { - const typename SUI::i8x8_t VectorB0 = SUI::vmov_n_i8(*InputB); - const int16x8_t vb_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorB0, VectorZeroPointB)); - vscalar = 
vmulq_s32(vmovl_s16(vget_low_s16(vb_s16x8)), VectorMultiplierB); - } - -#if defined(MLAS_NEON64_INTRINSICS) - - while (N >= 32) { - int32x4_t vacc0_lo, vacc0_hi, vacc1_lo, vacc1_hi, vacc2_lo, vacc2_hi, vacc3_lo, vacc3_hi; - if (IsScalarB) { - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); - const typename SUI::i8x16_t VectorA1 = SUI::vld1q_i8(InputA + 16); - InputA += 32; - const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); - const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); - const int16x8_t va2_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA1), VectorZeroPointA)); - const int16x8_t va3_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA1), VectorZeroPointA)); - - vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); - vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); - vacc2_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va2_s16x8)), VectorMultiplierA); - vacc3_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va3_s16x8)), VectorMultiplierA); - vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); - vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); - vacc2_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va2_s16x8), VectorMultiplierA); - vacc3_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va3_s16x8), VectorMultiplierA); - } else { - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); - const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB); - const typename SUI::i8x16_t VectorA1 = SUI::vld1q_i8(InputA + 16); - const typename SUI::i8x16_t VectorB1 = SUI::vld1q_i8(InputB + 16); - InputA += 32; - InputB += 32; - const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); - const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB)); - const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); - const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB)); - const int16x8_t va2_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA1), VectorZeroPointA)); - const int16x8_t vb2_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB1), VectorZeroPointB)); - const int16x8_t va3_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA1), VectorZeroPointA)); - const int16x8_t vb3_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB1), VectorZeroPointB)); - - vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); - vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); - vacc2_lo = vmulq_s32(vmovl_s16(vget_low_s16(va2_s16x8)), VectorMultiplierA); - vacc3_lo = vmulq_s32(vmovl_s16(vget_low_s16(va3_s16x8)), VectorMultiplierA); - vacc0_hi = vmulq_s32(MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); - vacc1_hi = vmulq_s32(MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); - vacc2_hi = vmulq_s32(MlasMoveHighS16S32(va2_s16x8), VectorMultiplierA); - vacc3_hi = vmulq_s32(MlasMoveHighS16S32(va3_s16x8), VectorMultiplierA); - - vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB); - vacc1_lo = 
vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB); - vacc2_lo = vmlaq_s32(vacc2_lo, vmovl_s16(vget_low_s16(vb2_s16x8)), VectorMultiplierB); - vacc3_lo = vmlaq_s32(vacc3_lo, vmovl_s16(vget_low_s16(vb3_s16x8)), VectorMultiplierB); - vacc0_hi = vmlaq_s32(vacc0_hi, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB); - vacc1_hi = vmlaq_s32(vacc1_hi, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB); - vacc2_hi = vmlaq_s32(vacc2_hi, MlasMoveHighS16S32(vb2_s16x8), VectorMultiplierB); - vacc3_hi = vmlaq_s32(vacc3_hi, MlasMoveHighS16S32(vb3_s16x8), VectorMultiplierB); - } - - vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); - vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); - vacc2_lo = vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31); - vacc3_lo = vsraq_n_s32(vacc3_lo, vbicq_s32(vacc3_lo, vzero_shift_mask), 31); - vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); - vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); - vacc2_hi = vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31); - vacc3_hi = vsraq_n_s32(vacc3_hi, vbicq_s32(vacc3_hi, vzero_shift_mask), 31); - - vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); - vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); - vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift); - vacc3_lo = vrshlq_s32(vacc3_lo, vright_shift); - vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); - vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); - vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift); - vacc3_hi = vrshlq_s32(vacc3_hi, vright_shift); - - // Pack, saturate, and add output zero point. - const int16x8_t vacc0 = vqaddq_s16(MlasCombineS16S32(vacc0_lo, vacc0_hi), VectorZeroPointC); - const int16x8_t vacc1 = vqaddq_s16(MlasCombineS16S32(vacc1_lo, vacc1_hi), VectorZeroPointC); - const int16x8_t vacc2 = vqaddq_s16(MlasCombineS16S32(vacc2_lo, vacc2_hi), VectorZeroPointC); - const int16x8_t vacc3 = vqaddq_s16(MlasCombineS16S32(vacc3_lo, vacc3_hi), VectorZeroPointC); - - const typename SUI::i8x16_t vc0 = SUI::combine_i8_s16(vacc0, vacc1); - const typename SUI::i8x16_t vc1 = SUI::combine_i8_s16(vacc2, vacc3); - - SUI::vst1q_i8(OutputC, vc0); - SUI::vst1q_i8(OutputC + 16, vc1); - N -= 32; - OutputC += 32; - } - -#endif - - while (N >= 16) { - int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi; - if (IsScalarB) { - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); - InputA += 16; - const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); - const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); - - vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); - vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); - vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); - vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); - } else { - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); - const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB); - InputA += 16; - InputB += 16; - const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); - const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB)); - const int16x8_t va1_s16x8 = 
SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); - const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB)); - - vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); - vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); - vacc0_hi = vmulq_s32(MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); - vacc1_hi = vmulq_s32(MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); - - vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB); - vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB); - vacc0_hi = vmlaq_s32(vacc0_hi, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB); - vacc1_hi = vmlaq_s32(vacc1_hi, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB); - } - - vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); - vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); - vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); - vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); - - vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); - vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); - vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); - vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); - - // Pack, saturate, and add output zero point. - const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC); - const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC); - typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1); - - N -= 16; - SUI::vst1q_i8(OutputC, vc); - OutputC += 16; - } - - if (N > 0) { - typename SUI::T TailDataA[16] = { 0 }; - typename SUI::T TailDataB[16] = { 0 }; - - MlasCopyTailBytes((uint8_t*)TailDataA, (const uint8_t*)InputA, N); - if (!IsScalarB) { - MlasCopyTailBytes((uint8_t*)TailDataB, (const uint8_t*)InputB, N); - } - - int32x4_t vacc0_lo, vacc1_lo, vacc0_hi, vacc1_hi; - if (IsScalarB) { - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA); - const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); - const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); - - vacc0_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); - vacc1_lo = vmlaq_s32(vscalar, vmovl_s16(vget_low_s16(va1_s16x8)), VectorMultiplierA); - vacc0_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); - vacc1_hi = vmlaq_s32(vscalar, MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); - } else { - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(TailDataA); - const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(TailDataB); - const int16x8_t va0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorA0), VectorZeroPointA)); - const int16x8_t vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_low_i8(VectorB0), VectorZeroPointB)); - const int16x8_t va1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorA0), VectorZeroPointA)); - const int16x8_t vb1_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(SUI::vget_high_i8(VectorB0), VectorZeroPointB)); - - vacc0_lo = vmulq_s32(vmovl_s16(vget_low_s16(va0_s16x8)), VectorMultiplierA); - vacc1_lo = vmulq_s32(vmovl_s16(vget_low_s16(va1_s16x8)), 
VectorMultiplierA); - vacc0_hi = vmulq_s32(MlasMoveHighS16S32(va0_s16x8), VectorMultiplierA); - vacc1_hi = vmulq_s32(MlasMoveHighS16S32(va1_s16x8), VectorMultiplierA); - - vacc0_lo = vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vb0_s16x8)), VectorMultiplierB); - vacc1_lo = vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vb1_s16x8)), VectorMultiplierB); - vacc0_hi = vmlaq_s32(vacc0_hi, MlasMoveHighS16S32(vb0_s16x8), VectorMultiplierB); - vacc1_hi = vmlaq_s32(vacc1_hi, MlasMoveHighS16S32(vb1_s16x8), VectorMultiplierB); - } - - vacc0_lo = vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); - vacc1_lo = vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); - vacc0_hi = vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); - vacc1_hi = vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); - - vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); - vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); - vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); - vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); - - // Pack, saturate, and add output zero point. - const int16x8_t vacc0 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), VectorZeroPointC); - const int16x8_t vacc1 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), VectorZeroPointC); - typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1); - - typename SUI::i8x8_t i8x8 = SUI::vget_low_i8(vc); - if (N & 8) { - SUI::vst1_i8(OutputC, i8x8); - OutputC += 8; - i8x8 = SUI::vget_high_i8(vc); - } - if (N & 4) { - vst1_lane_u32_ex((uint32_t*)OutputC, SUI::vreinterpret_u32_i8(i8x8), 0, 8); - OutputC += 4; - i8x8 = SUI::template vext_i8<4>(i8x8, i8x8); - } - if (N & 2) { - vst1_lane_u16_ex((uint16_t*)OutputC, SUI::vreinterpret_u16_i8(i8x8), 0, 8); - OutputC += 2; - i8x8 = SUI::template vext_i8<2>(i8x8, i8x8); - } - if (N & 1) { - SUI::template vst1_lane_i8<0>(OutputC, i8x8); - } - } -} - -#elif defined(MLAS_SSE2_INTRINSICS) - -template -static -void -MlasQLinearAddKernelHelper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const float ScaleRatio_AC = ScaleA / ScaleC; - const float ScaleRatio_BC = ScaleB / ScaleC; - const auto VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); - const auto VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); - auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); - - MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi; - if (IsScalarB) { - vb_lo = _mm_set1_ps((float)*InputB); - VectorFixedPart = _mm_add_ps(VectorFixedPart, _mm_mul_ps(vb_lo, VectorScaleRatio_BC)); - } - - while (N >= 8) { - const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputA); - const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half); - InputA += 8; - va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24)); - va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24)); - - if (!IsScalarB) { - const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)InputB); - const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half); - InputB += 8; - vb_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(vb_i16x8, vb_i16x8), 24)); - vb_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(vb_i16x8, vb_i16x8), 24)); - } - - MLAS_INT32X4 r_lo, r_hi; - if 
(IsScalarB) { - r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC))); - r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC))); - } else { - r_lo = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)), _mm_mul_ps(vb_lo, VectorScaleRatio_BC))); - r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC))); - } - const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi); - MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); - - N -= 8; - _mm_storel_epi64((MLAS_INT32X4*)OutputC, vc); - OutputC += 8; - } - - if (N > 0) { - uint8_t TailData[8] = { 0 }; - - MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N); - const auto va_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData); - const auto va_i16x8 = _mm_unpacklo_epi8(va_low_half, va_low_half); - va_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(va_i16x8, va_i16x8), 24)); - va_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(va_i16x8, va_i16x8), 24)); - - if (!IsScalarB) { - MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N); - const auto vb_low_half = _mm_loadl_epi64((const MLAS_INT32X4*)TailData); - const auto vb_i16x8 = _mm_unpacklo_epi8(vb_low_half, vb_low_half); - vb_lo = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpacklo_epi16(vb_i16x8, vb_i16x8), 24)); - vb_hi = _mm_cvtepi32_ps(MlasShiftRightInt32(_mm_unpackhi_epi16(vb_i16x8, vb_i16x8), 24)); - } - - MLAS_INT32X4 r_lo, r_hi; - if (IsScalarB) { - r_lo = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC))); - r_hi = _mm_cvtps_epi32(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC))); - } else { - r_lo = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_lo, VectorScaleRatio_AC)), _mm_mul_ps(vb_lo, VectorScaleRatio_BC))); - r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_add_ps(VectorFixedPart, _mm_mul_ps(va_hi, VectorScaleRatio_AC)), _mm_mul_ps(vb_hi, VectorScaleRatio_BC))); - } - const auto vc_i16x8 = _mm_packs_epi32(r_lo, r_hi); - MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); - - if (N & 4) { - *(int*)OutputC = _mm_cvtsi128_si32(vc); - N -= 4; - OutputC += 4; - vc = _mm_shuffle_epi32(vc, _MM_SHUFFLE(0, 3, 2, 1)); - } - - uint32_t PackedValueC = (uint32_t)_mm_cvtsi128_si32(vc); - for (size_t i = 0; i < N; ++i) { - *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; - PackedValueC >>= 8; - } - } -} -#elif defined(MLAS_TARGET_POWER) -template -static -void -MlasQLinearAddKernelHelper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - if (N >= 16) { - float ScaleRatio_AC = ScaleA / ScaleC; - float ScaleRatio_BC = ScaleB / ScaleC; - MLAS_FLOAT32X4 VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); - MLAS_FLOAT32X4 VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); - MLAS_FLOAT32X4 VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); - MLAS_FLOAT32X4 vb0_lo, vb0_hi, vb1_lo, vb1_hi; - const uint8_t flip = 128; - MLAS_UNREFERENCED_PARAMETER(flip); - __vector unsigned char vmask = reinterpret_cast<__vector unsigned char>(vec_splats(flip)); - __vector signed short vmask1 = reinterpret_cast<__vector signed short>(vec_splats((short)flip)); - - if (IsScalarB) { - vb0_lo 
= MlasBroadcastFloat32x4((float)*InputB); - VectorFixedPart = vec_add(VectorFixedPart, vec_mul(vb0_lo, VectorScaleRatio_BC)); - } - while (N >= 16) { - MLAS_INT32X4 r_lo, r_hi; - MLAS_FLOAT32X4 va_lo, va_hi; - MLAS_UNREFERENCED_PARAMETER(VectorScaleRatio_AC); - MLAS_UNREFERENCED_PARAMETER(VectorScaleRatio_BC); - auto va = MlasPackL8(InputA, vmask); - auto vshort = vec_unpackh(va); - vshort = MlasPackS16(vshort, vmask1); - auto va1 = vec_unpackl(vshort); - auto va0 = vec_unpackh(vshort); - va_lo = vec_ctf(va0, 0); - va_hi = vec_ctf(va1, 0); - if (!IsScalarB) { - auto vb = MlasPackL8(InputB, vmask); - vshort = vec_unpackh(vb); - vshort = MlasPackS16(vshort, vmask1); - auto vb1 = vec_unpackl(vshort); - auto vb0 = vec_unpackh(vshort); - vb0_lo = vec_ctf(vb0, 0); - vb0_hi= vec_ctf(vb1, 0); - vshort = vec_unpackl(vb); - vshort = MlasPackS16(vshort, vmask1); - vb1 = vec_unpackl(vshort); - vb0 = vec_unpackh(vshort); - vb1_lo = vec_ctf(vb0, 0); - vb1_hi= vec_ctf(vb1, 0); - InputB += 16; - } - va_lo = vec_mul(va_lo, VectorScaleRatio_AC); - va_hi = vec_mul(va_hi, VectorScaleRatio_AC); - if (IsScalarB) { - r_lo = vec_cts(vec_round(vec_add(VectorFixedPart, va_lo)), 0); - r_hi = vec_cts(vec_round(vec_add(VectorFixedPart, va_hi)), 0); - } else { - vb0_lo = vec_mul(vb0_lo, VectorScaleRatio_BC); - vb0_hi = vec_mul(vb0_hi, VectorScaleRatio_BC); - r_lo = vec_cts(vec_round(vec_add(vec_add(VectorFixedPart, va_lo), vb0_lo)), 0); - r_hi = vec_cts(vec_round(vec_add(vec_add(VectorFixedPart, va_hi), vb0_hi)), 0); - } - const auto vc0 = vec_packs(r_lo, r_hi); - vshort = vec_unpackl(va); - vshort = MlasPackS16(vshort, vmask1); - va1 = vec_unpackl(vshort); - va0 = vec_unpackh(vshort); - va_lo = vec_ctf(va0, 0); - va_hi = vec_ctf(va1, 0); - va_lo = vec_mul(va_lo, VectorScaleRatio_AC); - va_hi = vec_mul(va_hi, VectorScaleRatio_AC); - if (IsScalarB) { - r_lo = vec_cts(vec_round(vec_add(VectorFixedPart, va_lo)), 0); - r_hi = vec_cts(vec_round(vec_add(VectorFixedPart, va_hi)), 0); - } else { - vb1_lo = vec_mul(vb1_lo, VectorScaleRatio_BC); - vb1_hi = vec_mul(vb1_hi, VectorScaleRatio_BC); - r_lo = vec_cts(vec_round(vec_add(vec_add(VectorFixedPart, va_lo), vb1_lo)), 0); - r_hi = vec_cts(vec_round(vec_add(vec_add(VectorFixedPart, va_hi), vb1_hi)), 0); - } - const auto vc1 = vec_packs(r_lo, r_hi); - MLAS_INT32X4 vc = MlasPackS16_128(vc0, vc1); - vec_vsx_st(vc, 0, reinterpret_cast(OutputC)); - N -= 16; - InputA += 16; - OutputC += 16; - } - } - if (N > 0) { - MlasQLinearAddKernelRawHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } -} -#elif defined(MLAS_LSX_INTRINSICS) - -template -static -void -MlasQLinearAddKernelHelper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const float ScaleRatio_AC = ScaleA / ScaleC; - const float ScaleRatio_BC = ScaleB / ScaleC; - const auto VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); - const auto VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); - auto VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); - - MLAS_FLOAT32X4 va_lo, va_hi, vb_lo, vb_hi; - if (IsScalarB) { - float tmp_f = (float)*InputB; - uint32_t *tmp_p = (uint32_t *)&tmp_f; - vb_lo = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w(*tmp_p)); - VectorFixedPart = __lsx_vfmadd_s(vb_lo, VectorScaleRatio_BC, VectorFixedPart); - } - - 
__m128i tmp, tmp1; - - while (N >= 8) { - const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputA, 0), 0 ,1); - const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); - InputA += 8; - va_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); - va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); - - if (!IsScalarB) { - const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)InputB, 0), 0 ,1); - const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); - InputB += 8; - vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); - vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); - } - - MLAS_INT32X4 r_lo, r_hi; - if (IsScalarB) { - r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); - r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); - } else { - r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); - r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); - } - tmp = __lsx_vsat_w(r_lo, 15); - tmp1 = __lsx_vsat_w(r_hi, 15); - const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); - - MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); - - N -= 8; - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((MLAS_INT32X4*)OutputC, 0), __lsx_vpickve2gr_d(vc, 0), 0), (MLAS_INT32X4*)OutputC, 0); - OutputC += 8; - } - - if (N > 0) { - uint8_t TailData[8] = { 0 }; - - MlasCopyTailBytes(TailData, (const uint8_t*)InputA, N); - const auto va_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); - const auto va_i16x8 = __lsx_vilvl_b(va_low_half, va_low_half); - va_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(va_i16x8, va_i16x8), 24)); - va_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(va_i16x8, va_i16x8), 24)); - - if (!IsScalarB) { - MlasCopyTailBytes(TailData, (const uint8_t*)InputB, N); - const auto vb_low_half = __lsx_vinsgr2vr_d(__lsx_vld((const MLAS_INT32X4*)TailData, 0), 0 ,1); - const auto vb_i16x8 = __lsx_vilvl_b(vb_low_half, vb_low_half); - vb_lo = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvl_h(vb_i16x8, vb_i16x8), 24)); - vb_hi = __lsx_vffint_s_w(MlasShiftRightInt32(__lsx_vilvh_h(vb_i16x8, vb_i16x8), 24)); - } - - MLAS_INT32X4 r_lo, r_hi; - if (IsScalarB) { - r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart)); - r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart)); - } else { - r_lo = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_lo, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_lo, VectorScaleRatio_BC))); - r_hi = __lsx_vftint_w_s(__lsx_vfadd_s(__lsx_vfmadd_s(va_hi, VectorScaleRatio_AC, VectorFixedPart), __lsx_vfmul_s(vb_hi, VectorScaleRatio_BC))); - } - tmp = __lsx_vsat_w(r_lo, 15); - tmp1 = __lsx_vsat_w(r_hi, 15); - const auto vc_i16x8 = __lsx_vpickev_h(tmp1, tmp); - - MLAS_INT32X4 vc = MlasPackS16_128(vc_i16x8, vc_i16x8); - - if (N & 4) { - __lsx_vstelm_w(vc, (int*)OutputC, 0, 0); - N -= 4; - OutputC += 4; - vc = __lsx_vshuf4i_w(vc, 0x39); //_MM_SHUFFLE(0, 3, 2, 1) - } - - uint32_t PackedValueC = (uint32_t)__lsx_vpickve2gr_w(vc, 0); - for (size_t i = 0; i < N; ++i) { - *((uint8_t*)OutputC + i) = (uint8_t)PackedValueC; - PackedValueC >>= 8; - } - } -} -#else - 
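The SSE2, Power and LoongArch paths above avoid per-element zero-point handling by precomputing a single fixed part: (ScaleA*(a - za) + ScaleB*(b - zb)) / ScaleC + zc rewrites to a*rAC + b*rBC + (zc - rAC*za - rBC*zb), with rAC = ScaleA/ScaleC and rBC = ScaleB/ScaleC, so the loop is two multiply-adds per element. A scalar sketch of that rewrite, illustrative only and not MLAS code:

#include <cmath>
#include <cstdint>

uint8_t QLinearAddFixedPart(uint8_t a, uint8_t b,
                            float ScaleA, int32_t za,
                            float ScaleB, int32_t zb,
                            float ScaleC, int32_t zc)
{
    const float rAC = ScaleA / ScaleC;
    const float rBC = ScaleB / ScaleC;
    // Everything that does not depend on the current element, hoisted once.
    const float FixedPart = float(zc) - (rAC * float(za) + rBC * float(zb));

    float v = FixedPart + float(a) * rAC + float(b) * rBC;
    v = std::fmin(std::fmax(v, 0.0f), 255.0f);   // saturate to the uint8 range
    return uint8_t(std::lrintf(v));              // round to nearest
}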
-template -static -void -MlasQLinearAddKernelHelper( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - // Pure C++ implementation. - MlasQLinearAddKernelRawHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); -} - -#endif - -template -static -void -MLASCALL -MlasQLinearAddKernel( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N, - bool IsScalarB - ) -{ - if (IsScalarB) { - MlasQLinearAddKernelHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } else { - MlasQLinearAddKernelHelper( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } -} - -template<> -void -MLASCALL -MlasQLinearAdd( - const int8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const int8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - int8_t* OutputC, - size_t N, - bool IsScalarB - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QLinearAddS8Kernel( -#else - MlasQLinearAddKernel( -#endif - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); -} - -template<> -void -MLASCALL -MlasQLinearAdd( - const uint8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const uint8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - uint8_t* OutputC, - size_t N, - bool IsScalarB - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QLinearAddU8Kernel( -#else - MlasQLinearAddKernel( -#endif - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); -} - -// -// Function definition for platform usage -// - -void -MLASCALL -MlasQLinearAddS8Kernel( - const int8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const int8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - int8_t* OutputC, - size_t N, - bool IsScalarB - ) -{ - MlasQLinearAddKernel( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); -} - -void -MLASCALL -MlasQLinearAddU8Kernel( - const uint8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const uint8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - uint8_t* OutputC, - size_t N, - bool IsScalarB - ) -{ - MlasQLinearAddKernel( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); -} diff --git a/onnxruntime/core/mlas/lib/qladd.h b/onnxruntime/core/mlas/lib/qladd.h deleted file mode 100644 index 94568941a5660..0000000000000 --- a/onnxruntime/core/mlas/lib/qladd.h +++ /dev/null @@ -1,594 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qladd.h - -Abstract: - - This module contains the private data structures and procedure prototypes - for QLinearAdd function usage . 
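One detail worth calling out in the dispatchers just above: the "B is a scalar" case is resolved once, by picking a template specialization, rather than being re-tested inside the element loop. A minimal sketch of that pattern, illustrative only and not MLAS code (the quantized arithmetic is replaced by a plain add):

#include <cstddef>
#include <cstdint>

template <bool IsScalarB>
static void AddLoop(const uint8_t* a, const uint8_t* b, uint8_t* c, size_t n)
{
    for (size_t i = 0; i < n; i++) {
        // IsScalarB is a compile-time constant here, so the dead branch disappears.
        const uint8_t bv = IsScalarB ? b[0] : b[i];
        c[i] = uint8_t(a[i] + bv);   // stand-in for the real quantized add
    }
}

static void Add(const uint8_t* a, const uint8_t* b, uint8_t* c, size_t n, bool IsScalarB)
{
    if (IsScalarB) {
        AddLoop<true>(a, b, c, n);
    } else {
        AddLoop<false>(a, b, c, n);
    }
}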
- ---*/ - -#pragma once - -#include "mlasi.h" - -MLAS_FORCEINLINE -static -void -MlasCopyTailBytes( - uint8_t* target, - const uint8_t* src, - size_t N) -{ - while (N >= sizeof(uint32_t)) { - *(uint32_t*)(target) = *(uint32_t*)(src); - N -= sizeof(uint32_t); - target += sizeof(uint32_t); - src += sizeof(uint32_t); - } - while (N > 0) { - *target++ = *src++; - --N; - } -} - -bool -MlasCalcQLinearAddParameters( - float ScaleRatio_AC, - float ScaleRatio_BC, - int32_t& Shift, - int32_t& MultiplierA, - int32_t& MultiplierB - ); - -#if defined(MLAS_NEON_INTRINSICS) - -#if ! defined(_MSC_VER) - -#define vld1q_s8_ex(pD, align) vld1q_s8((int8_t*)__builtin_assume_aligned(pD, ((align)/8))) -#define vst1_s8_ex(pD, D, align) vst1_s8((int8_t*)__builtin_assume_aligned(pD, ((align)/8)), D) -#define vst1q_s8_ex(pD, D, align) vst1q_s8((int8_t*)__builtin_assume_aligned(pD, ((align)/8)), D) -#define vld1q_u8_ex(pD, align) vld1q_u8((uint8_t*)__builtin_assume_aligned(pD, ((align)/8))) -#define vst1_u8_ex(pD, D, align) vst1_u8((uint8_t*)__builtin_assume_aligned(pD, ((align)/8)), D) -#define vst1q_u8_ex(pD, D, align) vst1q_u8((uint8_t*)__builtin_assume_aligned(pD, ((align)/8)), D) -#define vst1_lane_u32_ex(pD, D, lane, align) vst1_lane_u32((uint32_t*)__builtin_assume_aligned(pD, ((align)/8)), D, lane) -#define vst1_lane_u16_ex(pD, D, lane, align) vst1_lane_u16((uint16_t*)__builtin_assume_aligned(pD, ((align)/8)), D, lane) - -#endif - -template -class MLAS_SignedUnsignedIntOps; - -template <> -class MLAS_SignedUnsignedIntOps -{ -public: - typedef uint8_t T; - typedef uint8x8_t i8x8_t; - typedef uint8x16_t i8x16_t; - typedef uint16x8_t i16x8_t; - - static MLAS_FORCEINLINE i8x8_t vmov_n_i8(T value) - { - return vmov_n_u8(value); - } - - static MLAS_FORCEINLINE i8x8_t vget_low_i8(i8x16_t a) - { - return vget_low_u8(a); - } - - static MLAS_FORCEINLINE i8x8_t vget_high_i8(i8x16_t a) - { - return vget_high_u8(a); - } - - static MLAS_FORCEINLINE i16x8_t vsubl_i8(i8x8_t a, i8x8_t b) - { - return vsubl_u8(a, b); - } - - static MLAS_FORCEINLINE int16x8_t vreinterpretq_s16_i16(i16x8_t a) - { - return vreinterpretq_s16_u16(a); - } - - static MLAS_FORCEINLINE uint32x4_t vreinterpretq_u32_i8(i8x16_t a) - { - return vreinterpretq_u32_u8(a); - } - - static MLAS_FORCEINLINE uint16x8_t vreinterpretq_u16_i8(i8x16_t a) - { - return vreinterpretq_u16_u8(a); - } - - static MLAS_FORCEINLINE uint32x2_t vreinterpret_u32_i8(i8x8_t a) - { - return vreinterpret_u32_u8(a); - } - - static MLAS_FORCEINLINE uint16x4_t vreinterpret_u16_i8(i8x8_t a) - { - return vreinterpret_u16_u8(a); - } - - static MLAS_FORCEINLINE i8x16_t vld1q_i8(T const * ptr) - { - return vld1q_u8_ex(ptr, 8); - } - - static MLAS_FORCEINLINE void vst1_i8(T* ptr, i8x8_t a) - { - vst1_u8_ex(ptr, a, 8); - } - - static MLAS_FORCEINLINE void vst1q_i8(T* ptr, i8x16_t a) - { - vst1q_u8_ex(ptr, a, 8); - } - - template - static MLAS_FORCEINLINE void vst1_lane_i8(T* ptr, i8x8_t a) - { - vst1_lane_u8(ptr, a, n); - } - - template - static MLAS_FORCEINLINE i8x16_t vextq_i8(i8x16_t lo, i8x16_t hi) - { - return vextq_u8(lo, hi, n); - } - - template - static MLAS_FORCEINLINE i8x8_t vext_i8(i8x8_t lo, i8x8_t hi) - { - return vext_u8(lo, hi, n); - } - - static MLAS_FORCEINLINE i8x16_t combine_i8_s16(int16x8_t v0, int16x8_t v1) - { - -#if defined(MLAS_NEON64_INTRINSICS) - return vqmovun_high_s16(vqmovun_s16(v0), v1); -#else - return vcombine_u8(vqmovun_s16(v0), vqmovun_s16(v1)); -#endif - - } -}; - -template <> -class MLAS_SignedUnsignedIntOps -{ -public: - typedef int8_t T; - typedef int8x8_t 
i8x8_t; - typedef int8x16_t i8x16_t; - typedef int16x8_t i16x8_t; - - static MLAS_FORCEINLINE i8x8_t vmov_n_i8(T value) - { - return vmov_n_s8(value); - } - - static MLAS_FORCEINLINE i8x8_t vget_low_i8(i8x16_t a) - { - return vget_low_s8(a); - } - - static MLAS_FORCEINLINE i8x8_t vget_high_i8(i8x16_t a) - { - return vget_high_s8(a); - } - - static MLAS_FORCEINLINE i16x8_t vsubl_i8(i8x8_t a, i8x8_t b) - { - return vsubl_s8(a, b); - } - - static MLAS_FORCEINLINE int16x8_t vreinterpretq_s16_i16(i16x8_t a) - { - return a; - } - - static MLAS_FORCEINLINE uint32x4_t vreinterpretq_u32_i8(i8x16_t a) - { - return vreinterpretq_u32_s8(a); - } - - static MLAS_FORCEINLINE uint16x8_t vreinterpretq_u16_i8(i8x16_t a) - { - return vreinterpretq_u16_s8(a); - } - - static MLAS_FORCEINLINE uint32x2_t vreinterpret_u32_i8(i8x8_t a) - { - return vreinterpret_u32_s8(a); - } - - static MLAS_FORCEINLINE uint16x4_t vreinterpret_u16_i8(i8x8_t a) - { - return vreinterpret_u16_s8(a); - } - - static MLAS_FORCEINLINE i8x16_t vld1q_i8(T const * ptr) - { - return vld1q_s8_ex(ptr, 8); - } - - static MLAS_FORCEINLINE void vst1_i8(T* ptr, i8x8_t a) - { - vst1_s8_ex(ptr, a, 8); - } - - static MLAS_FORCEINLINE void vst1q_i8(T* ptr, i8x16_t a) - { - vst1q_s8_ex(ptr, a, 8); - } - - template - static MLAS_FORCEINLINE void vst1_lane_i8(T* ptr, i8x8_t a) - { - vst1_lane_s8(ptr, a, n); - } - - template - static MLAS_FORCEINLINE i8x16_t vextq_i8(i8x16_t lo, i8x16_t hi) - { - return vextq_s8(lo, hi, n); - } - - template - static MLAS_FORCEINLINE i8x8_t vext_i8(i8x8_t lo, i8x8_t hi) - { - return vext_s8(lo, hi, n); - } - - static MLAS_FORCEINLINE i8x16_t combine_i8_s16(int16x8_t v0, int16x8_t v1) - { - -#if defined(MLAS_NEON64_INTRINSICS) - return vqmovn_high_s16(vqmovn_s16(v0), v1); -#else - return vcombine_s8(vqmovn_s16(v0), vqmovn_s16(v1)); -#endif - - } -}; - -#if defined(MLAS_NEON64_INTRINSICS) - -#define MlasMoveHighS16S32(s16x8) vmovl_high_s16(s16x8) -#define MlasCombineS16S32(lo, hi) vqmovn_high_s32(vqmovn_s32(lo), hi) - -#else - -#define MlasMoveHighS16S32(s16x8) vmovl_s16(vget_high_s16(s16x8)) -#define MlasCombineS16S32(lo, hi) vcombine_s16(vqmovn_s32(lo), vqmovn_s32(hi)) - -#endif - -#elif defined(MLAS_SSE2_INTRINSICS) - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt32( - MLAS_INT32X4 v, - int imm - ); - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt32( - MLAS_INT32X4 v, - int imm - ) -{ - return _mm_srai_epi32(v, imm); -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt32( - MLAS_INT32X4 v, - int imm - ) -{ - return _mm_srli_epi32(v, imm); -} - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt16( - MLAS_INT32X4 v, - int imm - ); - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt16( - MLAS_INT32X4 v, - int imm - ) -{ - return _mm_srai_epi16(v, imm); -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt16( - MLAS_INT32X4 v, - int imm - ) -{ - return _mm_srli_epi16(v, imm); -} - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - MLAS_INT32X4 a, - MLAS_INT32X4 b - ); - -template <> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - MLAS_INT32X4 a, - MLAS_INT32X4 b - ) -{ - return _mm_packus_epi16(a, b); -} - -template <> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - MLAS_INT32X4 a, - MLAS_INT32X4 b - ) -{ - return _mm_packs_epi16(a, b); -} - -#elif defined(MLAS_TARGET_POWER) -typedef __vector signed char MLAS_INT8; -typedef __vector short MLAS_SHORT; -template -MLAS_FORCEINLINE -MLAS_INT8 -MlasPackL8( - const DataType* Input, 
- __vector unsigned char vmask - ); - -template <> -MLAS_FORCEINLINE -MLAS_INT8 -MlasPackL8( - const uint8_t* Input, - __vector unsigned char vmask - ) -{ - __vector unsigned char va = vec_vsx_ld(0,Input); - return reinterpret_cast(vec_sub(reinterpret_cast<__vector unsigned char>(va), vmask)); -} - -template <> -MLAS_FORCEINLINE -MLAS_INT8 -MlasPackL8( - const int8_t* Input, - __vector unsigned char vmask - ) -{ - MLAS_UNREFERENCED_PARAMETER(vmask); - return reinterpret_cast(vec_vsx_ld(0,Input)); -} - -template -MLAS_FORCEINLINE -MLAS_SHORT -MlasPackS16( - __vector short a, - __vector short b - ); - -template <> -MLAS_FORCEINLINE -MLAS_SHORT -MlasPackS16( - __vector short a, - __vector short b - ) -{ - return vec_add(a, b); -} - -template <> -MLAS_FORCEINLINE -MLAS_SHORT -MlasPackS16( - __vector short a, - __vector short b - ) -{ - MLAS_UNREFERENCED_PARAMETER(b); - return a; -} - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - __vector short a, - __vector short b - ); - -template <> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - __vector short a, - __vector short b - ) -{ - return reinterpret_cast(vec_packsu(a, b)); -} - -template <> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - __vector short a, - __vector short b - ) -{ - return reinterpret_cast(vec_packs(a, b)); -} -#elif defined(MLAS_LSX_INTRINSICS) - -#define LSX_DBG 1 -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt32( - MLAS_INT32X4 v, - int imm - ); - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt32( - MLAS_INT32X4 v, - int imm - ) -{ -#if LSX_DBG - MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); - return __lsx_vsra_w(v, imm_v); -#else - return __lsx_vsrai_w(v, imm); -#endif -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt32( - MLAS_INT32X4 v, - int imm - ) -{ -#if LSX_DBG - MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_w(imm); - return __lsx_vsrl_w(v, imm_v); -#else - return __lsx_vsrli_w(v, imm); -#endif -} - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt16( - MLAS_INT32X4 v, - int imm - ); - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt16( - MLAS_INT32X4 v, - int imm - ) -{ -#if LSX_DBG - MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); - return __lsx_vsra_h(v, imm_v); -#else - return __lsx_vsrai_h(v, imm); -#endif -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasShiftRightInt16( - MLAS_INT32X4 v, - int imm - ) -{ -#if LSX_DBG - MLAS_INT32X4 imm_v = __lsx_vreplgr2vr_h(imm); - return __lsx_vsrl_h(v, imm_v); -#else - return __lsx_vsrli_h(v, imm); -#endif -} - -template -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - MLAS_INT32X4 a, - MLAS_INT32X4 b - ); - -template <> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - MLAS_INT32X4 a, - MLAS_INT32X4 b - ) -{ - // return _mm_packus_epi16(a, b); - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2, tmp3; - - tmp = __lsx_vmax_h(zero, a); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(zero, b); - tmp3 = __lsx_vsat_hu(tmp, 7); - return __lsx_vpickev_b(tmp3, tmp2); - -} - -template <> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasPackS16_128( - MLAS_INT32X4 a, - MLAS_INT32X4 b - ) -{ - // return _mm_packs_epi16(a, b); - __m128i tmp, tmp1; - - tmp = __lsx_vsat_h(a, 7); - tmp1 = __lsx_vsat_h(b, 7); - return __lsx_vpickev_b(tmp1, tmp); - -} -#endif diff --git a/onnxruntime/core/mlas/lib/qlgavgpool.cpp b/onnxruntime/core/mlas/lib/qlgavgpool.cpp deleted file mode 100644 index e44d7ad25c446..0000000000000 --- a/onnxruntime/core/mlas/lib/qlgavgpool.cpp +++ /dev/null @@ -1,1183 +0,0 @@ -/*++ - 
-Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qlgavgpool.cpp - -Abstract: - - This module implements routines for quantized linear global average pool. - ---*/ - -#include "mlasi.h" - -size_t -MLASCALL -MlasQLinearSafePaddingElementCount( - size_t ElementSize, - size_t ElementCount - ) -{ - if (!(ElementSize == 1 || ElementSize == 2 || ElementSize == 4 || ElementSize == 8 || - ElementSize == 16)) { - MLAS_THROW_EX(std::invalid_argument, - "ElementSize must be power of 2 and less or equal than 16!"); - } - return ElementCount + (size_t{256} / ElementSize - 1); -} - -MLAS_FORCEINLINE -float -CheckQLinearGlobalAveragePoolScaleAndSize( - float ScaleInput, - float ScaleOutput, - size_t ImageSize - ) -{ - if (ImageSize >= 0x1000000) { - MLAS_THROW_EX(std::invalid_argument, "QLinearGlobalAveragePool ImageSize too large!"); - } - - float scale = ScaleInput / (ScaleOutput * static_cast(ImageSize)); - if (scale < 0x1.0p-32f || scale >= 256.0f) { - // In first case, the scale is too small, ScaleInput/ScaleOutput < 1/256 no matter what - // ImageSize In second case, the scale is too large, ScaleInput/ScaleOutput >= 256 no matter - // what Image Size both case make output value constant, and hence not meaningful. - MLAS_THROW_EX(std::invalid_argument, - "QLinearGlobalAveragePool parameter out of computation range!"); - } - return scale; -} - -#if defined(MLAS_NEON_INTRINSICS) - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - int32_t* AccumulateBuffer - ) -{ - float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); - int32_t bias[] = {-ZeroPointInput * static_cast(ImageSize), 0, 0, 0}; - const int32x4_t vbias = vld1q_s32(bias); - const int32x4_t vzero = vmovq_n_s32(0); - const uint8_t* InputU8 = (const uint8_t*)(Input); - - int32_t* sum_buffer = AccumulateBuffer; - uint8_t tail_buffer[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - for (size_t c = Channels; c > 0; c--) { - - int32x4_t vacc_lo = vbias; - int32x4_t vacc_hi = vzero; - auto Len = ImageSize; - for (; Len >= 32; Len -= 32) { - - const uint8x8_t vi0 = vld1_u8(InputU8); - const uint8x8_t vi1 = vld1_u8(InputU8 + 8); - const uint8x8_t vi2 = vld1_u8(InputU8 + 16); - const uint8x8_t vi3 = vld1_u8(InputU8 + 24); - - int16x8_t vsum; - if constexpr (std::is_signed::value) { - - const int16x8_t vs01 = vaddl_s8(vreinterpret_s8_u8(vi0), vreinterpret_s8_u8(vi1)); - const int16x8_t vs23 = vaddl_s8(vreinterpret_s8_u8(vi2), vreinterpret_s8_u8(vi3)); - vsum = vaddq_s16(vs01, vs23); - } else { - - const uint16x8_t vs01 = vaddl_u8(vi0, vi1); - const uint16x8_t vs23 = vaddl_u8(vi2, vi3); - vsum = vreinterpretq_s16_u16(vaddq_u16(vs01, vs23)); - } - - vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum)); - vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum)); - InputU8 += 32; - } - for (; Len >= 8; Len -= 8) { - - int16x8_t vsum; - if constexpr (std::is_signed::value) { - vsum = vmovl_s8(vreinterpret_s8_u8(vld1_u8(InputU8))); - } else { - vsum = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(InputU8))); - } - vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum)); - vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum)); - InputU8 += 8; - } - - if (Len > 0) { - - memcpy(tail_buffer, InputU8, Len); - int16x8_t vsum; - if constexpr (std::is_signed::value) { - vsum = vmovl_s8(vreinterpret_s8_u8(vld1_u8(tail_buffer))); 
- } else { - vsum = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(tail_buffer))); - } - - vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum)); - vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum)); - InputU8 += Len; - } - - vacc_lo = vaddq_s32(vacc_lo, vacc_hi); - int32x2_t vacc = vadd_s32(vget_high_s32(vacc_lo), vget_low_s32(vacc_lo)); - *sum_buffer++ = vget_lane_s32(vpadd_s32(vacc, vacc), 0); - } - - MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, - static_cast(ZeroPointOutput), 0, 0, 1, Channels); -} - -template -MLAS_FORCEINLINE -void -MlasQLinearGlobalAveragePoolNhwcSingleBatch( - const T8Bits* Input, - T8Bits* Output, - const T8Bits* LastOf8, - size_t ImageSize, - size_t Channels, - size_t Stride, - int32_t Bias, - float Scale, - T8Bits Output_zero_point, - int32_t* AccumulateBuffer, - const T8Bits* ZeroBuffer - ) -{ -#define LOAD_FULL_CHANNELS() \ - const uint8x8_t vi0 = vld1_u8(i0); \ - i0 += 8; \ - const uint8x8_t vi1 = vld1_u8(i1); \ - i1 += 8; \ - const uint8x8_t vi2 = vld1_u8(i2); \ - i2 += 8; \ - const uint8x8_t vi3 = vld1_u8(i3); \ - i3 += 8; \ - const uint8x8_t vi4 = vld1_u8(i4); \ - i4 += 8; \ - const uint8x8_t vi5 = vld1_u8(i5); \ - i5 += 8; \ - const uint8x8_t vi6 = vld1_u8(i6); \ - i6 += 8 - -#define CALCULATE_ACCUMULATE_VECTORS() \ - int32x4_t vacc_lo = finish_one_pass ? vld1q_s32(acc) : vbias; \ - int32x4_t vacc_hi = finish_one_pass ? vld1q_s32(acc + 4) : vbias; \ - int16x8_t vsum; \ - if constexpr (std::is_signed::value) { \ - const int16x8_t vsum01 = vaddl_s8(vreinterpret_s8_u8(vi0), vreinterpret_s8_u8(vi1)); \ - const int16x8_t vsum23 = vaddl_s8(vreinterpret_s8_u8(vi2), vreinterpret_s8_u8(vi3)); \ - const int16x8_t vsum45 = vaddl_s8(vreinterpret_s8_u8(vi4), vreinterpret_s8_u8(vi5)); \ - const int16x8_t vsum016 = vaddw_s8(vsum01, vreinterpret_s8_u8(vi6)); \ - const int16x8_t vsum2345 = vaddq_s16(vsum23, vsum45); \ - vsum = vaddq_s16(vsum016, vsum2345); \ - } else { \ - const uint16x8_t vsum01 = vaddl_u8(vi0, vi1); \ - const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); \ - const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); \ - const uint16x8_t vsum016 = vaddw_u8(vsum01, vi6); \ - const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45); \ - vsum = vreinterpretq_s16_u16(vaddq_u16(vsum016, vsum2345)); \ - } \ - vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum)); \ - vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum)) - - uint8_t tail[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - const int32x4_t vbias = vld1q_dup_s32(&Bias); - bool finish_one_pass = false; - const size_t step_next_group = 7 * Stride - (Channels & ~size_t{7}); - - const uint8_t* LastOf8U8 = (const uint8_t*)LastOf8; - const uint8_t* i0 = (const uint8_t*)Input; - const uint8_t* i1 = i0 + Stride; - const uint8_t* i4 = i0 + Stride * 4; - const uint8_t* i2 = i1 + Stride; - const uint8_t* i5 = i4 + Stride; - const uint8_t* i3 = i2 + Stride; - const uint8_t* i6 = i5 + Stride; - - for (; ImageSize > 7; ImageSize -= 7) { - - int32_t* acc = AccumulateBuffer; - size_t c = Channels; - for (; c >= 8; c -= 8) { - - LOAD_FULL_CHANNELS(); - - CALCULATE_ACCUMULATE_VECTORS(); - - vst1q_s32(acc, vacc_lo); - vst1q_s32(acc + 4, vacc_hi); - acc += 8; - } - if (c > 0) { - - const uint8x8_t vi0 = vld1_u8(((i0 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i0, c) : i0)); - const uint8x8_t vi1 = vld1_u8(((i1 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i1, c) : i1)); - const uint8x8_t vi2 = vld1_u8(((i2 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i2, c) : i2)); - const uint8x8_t vi3 = vld1_u8(((i3 >= LastOf8U8) ? 
(const uint8_t*)memcpy(tail, i3, c) : i3)); - const uint8x8_t vi4 = vld1_u8(((i4 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i4, c) : i4)); - const uint8x8_t vi5 = vld1_u8(((i5 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i5, c) : i5)); - const uint8x8_t vi6 = vld1_u8(((i6 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i6, c) : i6)); - - CALCULATE_ACCUMULATE_VECTORS(); - - vst1q_s32(acc, vacc_lo); - vst1q_s32(acc + 4, vacc_hi); - } - finish_one_pass = true; - - i0 += step_next_group; - i1 += step_next_group; - i2 += step_next_group; - i3 += step_next_group; - i4 += step_next_group; - i5 += step_next_group; - i6 += step_next_group; - } - - if (ImageSize > 0) { - - switch (ImageSize) { - case 1: - i1 = (const uint8_t*)ZeroBuffer; /* fall through */ - case 2: - i2 = (const uint8_t*)ZeroBuffer; /* fall through */ - case 3: - i3 = (const uint8_t*)ZeroBuffer; /* fall through */ - case 4: - i4 = (const uint8_t*)ZeroBuffer; /* fall through */ - case 5: - i5 = (const uint8_t*)ZeroBuffer; /* fall through */ - case 6: - i6 = (const uint8_t*)ZeroBuffer; /* fall through */ - default: - break; - } - - int32_t* acc = AccumulateBuffer; - size_t c = Channels; - for (; c >= 8; c -= 8) { - - LOAD_FULL_CHANNELS(); - - CALCULATE_ACCUMULATE_VECTORS(); - - vst1q_s32(acc, vacc_lo); - vst1q_s32(acc + 4, vacc_hi); - acc += 8; - } - - if (c > 0) { - - const uint8x8_t vi0 = - vld1_u8(((i0 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i0, c) : i0)); - const uint8x8_t vi1 = vld1_u8( - ((1 < ImageSize && i1 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i1, c) : i1)); - const uint8x8_t vi2 = vld1_u8( - ((2 < ImageSize && i2 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i2, c) : i2)); - const uint8x8_t vi3 = vld1_u8( - ((3 < ImageSize && i3 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i3, c) : i3)); - const uint8x8_t vi4 = vld1_u8( - ((4 < ImageSize && i4 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i4, c) : i4)); - const uint8x8_t vi5 = vld1_u8( - ((5 < ImageSize && i5 >= LastOf8U8) ? (const uint8_t*)memcpy(tail, i5, c) : i5)); - const uint8x8_t vi6 = vld1_u8( - ((6 < ImageSize && i6 >= LastOf8U8) ? 
(const uint8_t*)memcpy(tail, i6, c) : i6)); - - CALCULATE_ACCUMULATE_VECTORS(); - - vst1q_s32(acc, vacc_lo); - vst1q_s32(acc + 4, vacc_hi); - } - } - MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, - Output_zero_point, 0, 0, 1, Channels); -} - -#elif defined(MLAS_SSE2_INTRINSICS) - -template -void MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - int32_t* AccumulateBuffer - ) -{ - float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); - const int32_t bias[] = {-ZeroPointInput * static_cast(ImageSize), 0, 0, 0}; - const auto vbias = _mm_loadu_si128((const __m128i*)&bias); - const auto vzero = _mm_setzero_si128(); - uint8_t buffer[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - int32_t* sum_buffer = AccumulateBuffer; - for (size_t c = Channels; c > 0; c--) { - - __m128i vacc_lo = vbias; - __m128i vacc_hi = vzero; - auto Len = ImageSize; - for (; Len >= 32; Len -= 32) { - - const __m128i vi0 = _mm_loadl_epi64((const __m128i*)Input); - const __m128i vi1 = _mm_loadl_epi64((const __m128i*)(Input + 8)); - const __m128i vi2 = _mm_loadl_epi64((const __m128i*)(Input + 16)); - const __m128i vi3 = _mm_loadl_epi64((const __m128i*)(Input + 24)); - - if constexpr (std::is_signed::value) { - - const __m128i vxi0 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi0), 8); - const __m128i vxi1 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi1), 8); - const __m128i vxi2 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi2), 8); - const __m128i vxi3 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi3), 8); - const __m128i vsum = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), - _mm_add_epi16(vxi2, vxi3)); - vacc_lo = _mm_add_epi32(vacc_lo, _mm_srai_epi32(_mm_unpacklo_epi16(vzero, vsum), 16)); - vacc_hi = _mm_add_epi32(vacc_hi, _mm_srai_epi32(_mm_unpackhi_epi16(vzero, vsum), 16)); - } else { - - const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); - const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); - const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); - const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); - const __m128i vsum = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), - _mm_add_epi16(vxi2, vxi3)); - vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); - vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); - } - - Input += 32; - } - for (; Len >= 8; Len -= 8) { - - if constexpr (std::is_signed::value) { - - const __m128i vsum = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, _mm_loadl_epi64((const __m128i*)Input)), 8); - vacc_lo = _mm_add_epi32(vacc_lo, _mm_srai_epi32(_mm_unpacklo_epi16(vzero, vsum), 16)); - vacc_hi = _mm_add_epi32(vacc_hi, _mm_srai_epi32(_mm_unpackhi_epi16(vzero, vsum), 16)); - } else { - - const __m128i vsum = _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i*)Input), vzero); - vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); - vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); - } - - Input += 8; - } - if (Len > 0) { - - memcpy(buffer, Input, Len); - - if constexpr (std::is_signed::value) { - - const __m128i vsum = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, _mm_loadl_epi64((const __m128i*)buffer)), 8); - vacc_lo = _mm_add_epi32(vacc_lo, _mm_srai_epi32(_mm_unpacklo_epi16(vzero, vsum), 16)); - vacc_hi = _mm_add_epi32(vacc_hi, _mm_srai_epi32(_mm_unpackhi_epi16(vzero, vsum), 16)); - } else { - - const __m128i vsum = 
_mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i*)buffer), vzero); - vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); - vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); - } - - Input += Len; - } - - __m128i vacc = _mm_add_epi32(vacc_lo, vacc_hi); // [ D C | B A ] - __m128i vshuf = _mm_shuffle_epi32(vacc, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ] - __m128i vsums = _mm_add_epi32(vacc, vshuf); // [ D+C C+D | B+A A+B ] - vshuf = _mm_shuffle_epi32(vsums, _MM_SHUFFLE(1, 0, 3, 2)); // [ B+A A+B | D+C C+D ] - vsums = _mm_add_epi32(vsums, vshuf); - *sum_buffer++ = _mm_cvtsi128_si32(vsums); - } - - MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, - static_cast(ZeroPointOutput), 0, 0, 1, Channels); -} - -template -MLAS_FORCEINLINE -void -MlasQLinearGlobalAveragePoolNhwcSingleBatch( - const T8Bits* Input, - T8Bits* Output, - const T8Bits* LastOf8, - size_t ImageSize, - size_t Channels, - size_t Stride, - int32_t Bias, - float Scale, - T8Bits Output_zero_point, - int32_t* AccumulateBuffer, - const T8Bits* ZeroBuffer - ) -{ -#if defined(MLAS_TARGET_IX86) - - constexpr size_t PixelsPerIteration = 4; - -#define LOAD_FULL_CHANNELS() \ - const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); \ - i0 += 8; \ - const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); \ - i1 += 8; \ - const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); \ - i2 += 8; \ - const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); \ - i3 += 8; - -#define CALCULATE_ACCUMULATE_VECTORS() \ - __m128i vacc_lo = finish_one_pass ? _mm_loadu_si128((__m128i*)acc) : vbias; \ - __m128i vacc_hi = finish_one_pass ? _mm_loadu_si128(((__m128i*)acc) + 1) : vbias; \ - __m128i vxi0; \ - __m128i vxi1; \ - __m128i vxi2; \ - __m128i vxi3; \ - if constexpr (std::is_signed::value) { \ - vxi0 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi0), 8); \ - vxi1 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi1), 8); \ - vxi2 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi2), 8); \ - vxi3 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi3), 8); \ - } else { \ - vxi0 = _mm_unpacklo_epi8(vi0, vzero); \ - vxi1 = _mm_unpacklo_epi8(vi1, vzero); \ - vxi2 = _mm_unpacklo_epi8(vi2, vzero); \ - vxi3 = _mm_unpacklo_epi8(vi3, vzero); \ - } \ - __m128i vsum01 = _mm_add_epi16(vxi0, vxi1); \ - __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); \ - __m128i vsum = _mm_add_epi16(vsum01, vsum23); \ - \ - if constexpr (std::is_signed::value) { \ - vacc_lo = _mm_add_epi32(vacc_lo, _mm_srai_epi32(_mm_unpacklo_epi16(vzero, vsum), 16)); \ - vacc_hi = _mm_add_epi32(vacc_hi, _mm_srai_epi32(_mm_unpackhi_epi16(vzero, vsum), 16)); \ - } else { \ - vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); \ - vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); \ - } - -#else - - constexpr size_t PixelsPerIteration = 7; -#define LOAD_FULL_CHANNELS() \ - const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); \ - i0 += 8; \ - const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); \ - i1 += 8; \ - const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); \ - i2 += 8; \ - const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); \ - i3 += 8; \ - const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); \ - i4 += 8; \ - const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); \ - i5 += 8; \ - const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); \ - i6 += 8 - -#define CALCULATE_ACCUMULATE_VECTORS() \ - __m128i vacc_lo = finish_one_pass ? 
_mm_loadu_si128((__m128i*)acc) : vbias; \ - __m128i vacc_hi = finish_one_pass ? _mm_loadu_si128(((__m128i*)acc) + 1) : vbias; \ - __m128i vxi0; \ - __m128i vxi1; \ - __m128i vxi2; \ - __m128i vxi3; \ - __m128i vxi4; \ - __m128i vxi5; \ - __m128i vxi6; \ - if constexpr (std::is_signed::value) { \ - vxi0 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi0), 8); \ - vxi1 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi1), 8); \ - vxi2 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi2), 8); \ - vxi3 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi3), 8); \ - vxi4 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi4), 8); \ - vxi5 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi5), 8); \ - vxi6 = _mm_srai_epi16(_mm_unpacklo_epi8(vzero, vi6), 8); \ - } else { \ - vxi0 = _mm_unpacklo_epi8(vi0, vzero); \ - vxi1 = _mm_unpacklo_epi8(vi1, vzero); \ - vxi2 = _mm_unpacklo_epi8(vi2, vzero); \ - vxi3 = _mm_unpacklo_epi8(vi3, vzero); \ - vxi4 = _mm_unpacklo_epi8(vi4, vzero); \ - vxi5 = _mm_unpacklo_epi8(vi5, vzero); \ - vxi6 = _mm_unpacklo_epi8(vi6, vzero); \ - } \ - const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1); \ - const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); \ - const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); \ - const __m128i vsum016 = _mm_add_epi16(vsum01, vxi6); \ - const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45); \ - const __m128i vsum = _mm_add_epi16(vsum016, vsum2345); \ - if constexpr (std::is_signed::value) { \ - vacc_lo = _mm_add_epi32(vacc_lo, _mm_srai_epi32(_mm_unpacklo_epi16(vzero, vsum), 16)); \ - vacc_hi = _mm_add_epi32(vacc_hi, _mm_srai_epi32(_mm_unpackhi_epi16(vzero, vsum), 16)); \ - } else { \ - vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); \ - vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); \ - } - -#endif - - T8Bits tail[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - bool finish_one_pass = false; - const __m128i vbias = _mm_set1_epi32(Bias); - const __m128i vzero = _mm_setzero_si128(); - size_t step_next_group = PixelsPerIteration * Stride - (Channels & ~size_t{7}); - - const T8Bits* i0 = Input; - const T8Bits* i1 = i0 + Stride; - const T8Bits* i2 = i1 + Stride; - const T8Bits* i3 = i2 + Stride; -#if !defined(MLAS_TARGET_IX86) - const T8Bits* i4 = i0 + Stride * 4; - const T8Bits* i5 = i4 + Stride; - const T8Bits* i6 = i5 + Stride; -#endif - - for (; ImageSize > PixelsPerIteration; ImageSize -= PixelsPerIteration) { - - int32_t* acc = AccumulateBuffer; - size_t c = Channels; - for (; c >= 8; c -= 8) { - - LOAD_FULL_CHANNELS(); - - CALCULATE_ACCUMULATE_VECTORS(); - - _mm_storeu_si128((__m128i*)acc, vacc_lo); - _mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi); - acc += 8; - } - if (c > 0) { - const __m128i vi0 = - _mm_loadl_epi64((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0)); - const __m128i vi1 = - _mm_loadl_epi64((const __m128i*)(i1 >= LastOf8 ? memcpy(tail, i1, c) : i1)); - const __m128i vi2 = - _mm_loadl_epi64((const __m128i*)(i2 >= LastOf8 ? memcpy(tail, i2, c) : i2)); - const __m128i vi3 = - _mm_loadl_epi64((const __m128i*)(i3 >= LastOf8 ? memcpy(tail, i3, c) : i3)); -#if !defined(MLAS_TARGET_IX86) - const __m128i vi4 = - _mm_loadl_epi64((const __m128i*)(i4 >= LastOf8 ? memcpy(tail, i4, c) : i4)); - const __m128i vi5 = - _mm_loadl_epi64((const __m128i*)(i5 >= LastOf8 ? memcpy(tail, i5, c) : i5)); - const __m128i vi6 = - _mm_loadl_epi64((const __m128i*)(i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6)); -#endif - - CALCULATE_ACCUMULATE_VECTORS(); - - _mm_storeu_si128((__m128i*)acc, vacc_lo); - _mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi); - } - finish_one_pass = true; - - i0 += step_next_group; - i1 += step_next_group; - i2 += step_next_group; - i3 += step_next_group; -#if !defined(MLAS_TARGET_IX86) - i4 += step_next_group; - i5 += step_next_group; - i6 += step_next_group; -#endif - } - - if (ImageSize > 0) { -#if defined(MLAS_TARGET_IX86) - switch (ImageSize) { - case 1: - i1 = ZeroBuffer; - [[fallthrough]]; - case 2: - i2 = ZeroBuffer; - [[fallthrough]]; - case 3: - i3 = ZeroBuffer; - [[fallthrough]]; - default: - break; - } -#else - switch (ImageSize) { - case 1: - i1 = ZeroBuffer; - [[fallthrough]]; - case 2: - i2 = ZeroBuffer; - [[fallthrough]]; - case 3: - i3 = ZeroBuffer; - [[fallthrough]]; - case 4: - i4 = ZeroBuffer; - [[fallthrough]]; - case 5: - i5 = ZeroBuffer; - [[fallthrough]]; - case 6: - i6 = ZeroBuffer; - [[fallthrough]]; - default: - break; - } -#endif - - int32_t* acc = AccumulateBuffer; - size_t c = Channels; - for (; c >= 8; c -= 8) { - - LOAD_FULL_CHANNELS(); - - CALCULATE_ACCUMULATE_VECTORS(); - - _mm_storeu_si128((__m128i*)acc, vacc_lo); - _mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi); - acc += 8; - } - - if (c > 0) { - const __m128i vi0 = - _mm_loadl_epi64((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0)); - const __m128i vi1 = _mm_loadl_epi64( - (const __m128i*)(1 < ImageSize && i1 >= LastOf8 ? memcpy(tail, i1, c) : i1)); - const __m128i vi2 = _mm_loadl_epi64( - (const __m128i*)(2 < ImageSize && i2 >= LastOf8 ? memcpy(tail, i2, c) : i2)); - const __m128i vi3 = _mm_loadl_epi64( - (const __m128i*)(3 < ImageSize && i3 >= LastOf8 ? memcpy(tail, i3, c) : i3)); -#if !defined(MLAS_TARGET_IX86) - const __m128i vi4 = _mm_loadl_epi64( - (const __m128i*)(4 < ImageSize && i4 >= LastOf8 ? memcpy(tail, i4, c) : i4)); - const __m128i vi5 = _mm_loadl_epi64( - (const __m128i*)(5 < ImageSize && i5 >= LastOf8 ? memcpy(tail, i5, c) : i5)); - const __m128i vi6 = _mm_loadl_epi64( - (const __m128i*)(6 < ImageSize && i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6)); -#endif - - CALCULATE_ACCUMULATE_VECTORS(); - - _mm_storeu_si128((__m128i*)acc, vacc_lo); - _mm_storeu_si128(((__m128i*)acc) + 1, vacc_hi); - } - } - MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, - Output_zero_point, 0, 0, 1, Channels); -} - -#elif defined(MLAS_LSX_INTRINSICS) - -template -void MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - int32_t* AccumulateBuffer - ) -{ - float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); - const int32_t bias[] = {-ZeroPointInput * static_cast(ImageSize), 0, 0, 0}; - const auto vbias = __lsx_vld((const __m128i*)&bias, 0); - const auto vzero = __lsx_vldi(0); - uint8_t buffer[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - - int32_t* sum_buffer = AccumulateBuffer; - for (size_t c = Channels; c > 0; c--) { - - __m128i vacc_lo = vbias; - __m128i vacc_hi = vzero; - auto Len = ImageSize; - for (; Len >= 32; Len -= 32) { - - const __m128i vi0 = __lsx_vld((const __m128i*)Input, 0); - __lsx_vinsgr2vr_d(vi0, 0, 1); - const __m128i vi1 = __lsx_vld((const __m128i*)(Input + 8), 0); - __lsx_vinsgr2vr_d(vi1, 0, 1); - const __m128i vi2 = __lsx_vld((const __m128i*)(Input + 16), 0); - __lsx_vinsgr2vr_d(vi2, 0, 1); - const __m128i vi3 = __lsx_vld((const __m128i*)(Input + 24), 0); - __lsx_vinsgr2vr_d(vi3, 0, 1); - - if constexpr (std::is_signed::value) { - - const __m128i vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); - const __m128i vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); - const __m128i vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); - const __m128i vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); - const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), - __lsx_vadd_h(vxi2, vxi3)); - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); - } else { - - const __m128i vxi0 = __lsx_vilvl_b(vzero, vi0); - const __m128i vxi1 = __lsx_vilvl_b(vzero, vi1); - const __m128i vxi2 = __lsx_vilvl_b(vzero, vi2); - const __m128i vxi3 = __lsx_vilvl_b(vzero, vi3); - const __m128i vsum = __lsx_vadd_h(__lsx_vadd_h(vxi0, vxi1), - __lsx_vadd_h(vxi2, vxi3)); - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); - } - - Input += 32; - } - for (; Len >= 8; Len -= 8) { - - if constexpr (std::is_signed::value) { - - const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1), vzero), 8); - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); - } else { - - const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)Input, 0), 0, 1)); - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); - } - - Input += 8; - } - if (Len > 0) { - - memcpy(buffer, Input, Len); - - if constexpr (std::is_signed::value) { - - const __m128i vsum = __lsx_vsrai_h(__lsx_vilvl_b(__lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1), vzero), 8); - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, 
vzero), 16)); - } else { - - const __m128i vsum = __lsx_vilvl_b(vzero, __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)buffer, 0), 0, 1)); - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); - } - - Input += Len; - } - - __m128i vacc = __lsx_vadd_w(vacc_lo, vacc_hi); // [ D C | B A ] - __m128i vshuf = __lsx_vshuf4i_w(vacc, 0xb1); // [ C D | A B ] _MM_SHUFFLE(2, 3, 0, 1) - __m128i vsums = __lsx_vadd_w(vacc, vshuf); // [ D+C C+D | B+A A+B ] - vshuf = __lsx_vshuf4i_w(vsums, 0x4e); // [ B+A A+B | D+C C+D ] _MM_SHUFFLE(1, 0, 3, 2) - vsums = __lsx_vadd_w(vsums, vshuf); - __lsx_vstelm_w(vsums, sum_buffer++, 0 , 0); - } - - MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &scale, false, - static_cast(ZeroPointOutput), 0, 0, 1, Channels); -} - -template -MLAS_FORCEINLINE -void -MlasQLinearGlobalAveragePoolNhwcSingleBatch( - const T8Bits* Input, - T8Bits* Output, - const T8Bits* LastOf8, - size_t ImageSize, - size_t Channels, - size_t Stride, - int32_t Bias, - float Scale, - T8Bits Output_zero_point, - int32_t* AccumulateBuffer, - const T8Bits* ZeroBuffer - ) -{ - - constexpr size_t PixelsPerIteration = 7; -#define LOAD_FULL_CHANNELS() \ - const __m128i vi0 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i0, 0), 0 , 1); \ - i0 += 8; \ - const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i1, 0), 0 , 1); \ - i1 += 8; \ - const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i2, 0), 0 , 1); \ - i2 += 8; \ - const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i3, 0), 0 , 1); \ - i3 += 8; \ - const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i4, 0), 0 , 1); \ - i4 += 8; \ - const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i5, 0), 0 , 1); \ - i5 += 8; \ - const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)i6, 0), 0 , 1); \ - i6 += 8 - -#define CALCULATE_ACCUMULATE_VECTORS() \ - __m128i vacc_lo = finish_one_pass ? __lsx_vld((__m128i*)acc, 0) : vbias; \ - __m128i vacc_hi = finish_one_pass ? 
__lsx_vld(((__m128i*)acc) + 1, 0) : vbias; \ - __m128i vxi0; \ - __m128i vxi1; \ - __m128i vxi2; \ - __m128i vxi3; \ - __m128i vxi4; \ - __m128i vxi5; \ - __m128i vxi6; \ - if constexpr (std::is_signed::value) { \ - vxi0 = __lsx_vsrai_h(__lsx_vilvl_b(vi0, vzero), 8); \ - vxi1 = __lsx_vsrai_h(__lsx_vilvl_b(vi1, vzero), 8); \ - vxi2 = __lsx_vsrai_h(__lsx_vilvl_b(vi2, vzero), 8); \ - vxi3 = __lsx_vsrai_h(__lsx_vilvl_b(vi3, vzero), 8); \ - vxi4 = __lsx_vsrai_h(__lsx_vilvl_b(vi4, vzero), 8); \ - vxi5 = __lsx_vsrai_h(__lsx_vilvl_b(vi5, vzero), 8); \ - vxi6 = __lsx_vsrai_h(__lsx_vilvl_b(vi6, vzero), 8); \ - } else { \ - vxi0 = __lsx_vilvl_b(vzero, vi0); \ - vxi1 = __lsx_vilvl_b(vzero, vi1); \ - vxi2 = __lsx_vilvl_b(vzero, vi2); \ - vxi3 = __lsx_vilvl_b(vzero, vi3); \ - vxi4 = __lsx_vilvl_b(vzero, vi4); \ - vxi5 = __lsx_vilvl_b(vzero, vi5); \ - vxi6 = __lsx_vilvl_b(vzero, vi6); \ - } \ - const __m128i vsum01 = __lsx_vadd_h(vxi0, vxi1); \ - const __m128i vsum23 = __lsx_vadd_h(vxi2, vxi3); \ - const __m128i vsum45 = __lsx_vadd_h(vxi4, vxi5); \ - const __m128i vsum016 = __lsx_vadd_h(vsum01, vxi6); \ - const __m128i vsum2345 = __lsx_vadd_h(vsum23, vsum45); \ - const __m128i vsum = __lsx_vadd_h(vsum016, vsum2345); \ - if constexpr (std::is_signed::value) { \ - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vsrai_w(__lsx_vilvl_h(vsum, vzero), 16)); \ - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vsrai_w(__lsx_vilvh_h(vsum, vzero), 16)); \ - } else { \ - vacc_lo = __lsx_vadd_w(vacc_lo, __lsx_vilvl_h(vzero, vsum)); \ - vacc_hi = __lsx_vadd_w(vacc_hi, __lsx_vilvh_h(vzero, vsum)); \ - } - - - T8Bits tail[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - bool finish_one_pass = false; - const __m128i vbias = __lsx_vreplgr2vr_w(Bias); - const __m128i vzero = __lsx_vldi(0); - size_t step_next_group = PixelsPerIteration * Stride - (Channels & ~size_t{7}); - - const T8Bits* i0 = Input; - const T8Bits* i1 = i0 + Stride; - const T8Bits* i2 = i1 + Stride; - const T8Bits* i3 = i2 + Stride; - const T8Bits* i4 = i0 + Stride * 4; - const T8Bits* i5 = i4 + Stride; - const T8Bits* i6 = i5 + Stride; - - for (; ImageSize > PixelsPerIteration; ImageSize -= PixelsPerIteration) { - - int32_t* acc = AccumulateBuffer; - size_t c = Channels; - for (; c >= 8; c -= 8) { - - LOAD_FULL_CHANNELS(); - - CALCULATE_ACCUMULATE_VECTORS(); - - __lsx_vst(vacc_lo, (__m128i*)acc, 0); - __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); - acc += 8; - } - if (c > 0) { - const __m128i vi0 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); - const __m128i vi1 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0 ,1); - const __m128i vi2 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0 ,1); - const __m128i vi3 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0 ,1); - const __m128i vi4 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0 ,1); - const __m128i vi5 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0 ,1); - const __m128i vi6 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0 ,1); - - CALCULATE_ACCUMULATE_VECTORS(); - - __lsx_vst(vacc_lo, (__m128i*)acc, 0); - __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); - } - finish_one_pass = true; - - i0 += step_next_group; - i1 += step_next_group; - i2 += step_next_group; - i3 += step_next_group; - i4 += step_next_group; - i5 += step_next_group; - i6 += step_next_group; - } - - if (ImageSize > 0) { - switch (ImageSize) { - case 1: - i1 = ZeroBuffer; - [[fallthrough]]; - case 2: - i2 = ZeroBuffer; - [[fallthrough]]; - case 3: - i3 = ZeroBuffer; - [[fallthrough]]; - case 4: - i4 = ZeroBuffer; - [[fallthrough]]; - case 5: - i5 = ZeroBuffer; - [[fallthrough]]; - case 6: - i6 = ZeroBuffer; - [[fallthrough]]; - default: - break; - } - - int32_t* acc = AccumulateBuffer; - size_t c = Channels; - for (; c >= 8; c -= 8) { - - LOAD_FULL_CHANNELS(); - - CALCULATE_ACCUMULATE_VECTORS(); - - __lsx_vst(vacc_lo, (__m128i*)acc, 0); - __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); - acc += 8; - } - - if (c > 0) { - const __m128i vi0 = - __lsx_vinsgr2vr_d(__lsx_vld((const __m128i*)(i0 >= LastOf8 ? memcpy(tail, i0, c) : i0), 0), 0 ,1); - const __m128i vi1 = __lsx_vinsgr2vr_d(__lsx_vld( - (const __m128i*)(1 < ImageSize && i1 >= LastOf8 ? memcpy(tail, i1, c) : i1), 0), 0, 1); - const __m128i vi2 = __lsx_vinsgr2vr_d(__lsx_vld( - (const __m128i*)(2 < ImageSize && i2 >= LastOf8 ? memcpy(tail, i2, c) : i2), 0), 0, 1); - const __m128i vi3 = __lsx_vinsgr2vr_d(__lsx_vld( - (const __m128i*)(3 < ImageSize && i3 >= LastOf8 ? memcpy(tail, i3, c) : i3), 0), 0, 1); - const __m128i vi4 = __lsx_vinsgr2vr_d(__lsx_vld( - (const __m128i*)(4 < ImageSize && i4 >= LastOf8 ? memcpy(tail, i4, c) : i4), 0), 0, 1); - const __m128i vi5 = __lsx_vinsgr2vr_d(__lsx_vld( - (const __m128i*)(5 < ImageSize && i5 >= LastOf8 ? memcpy(tail, i5, c) : i5), 0), 0, 1); - const __m128i vi6 = __lsx_vinsgr2vr_d(__lsx_vld( - (const __m128i*)(6 < ImageSize && i6 >= LastOf8 ? 
memcpy(tail, i6, c) : i6), 0), 0, 1); - - CALCULATE_ACCUMULATE_VECTORS(); - - __lsx_vst(vacc_lo, (__m128i*)acc, 0); - __lsx_vst(vacc_hi, ((__m128i*)acc) + 1, 0); - } - } - MlasRequantizeOutput(AccumulateBuffer, Channels, Output, Channels, nullptr, &Scale, false, - Output_zero_point, 0, 0, 1, Channels); -} - -#else - -// Pure C++ Implementation - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - int32_t* /* AccumulateBuffer */ - ) -{ - float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); - int32_t bias = -ZeroPointInput * static_cast(ImageSize); - for (; Channels > 0; Channels--) { - - int32_t acc = bias; - for (size_t i = 0; i < ImageSize; ++i) { - acc += static_cast(*Input++); - } - int32_t v = static_cast(std::nearbyintf(acc * scale)) + ZeroPointOutput; - v = std::min(static_cast(std::numeric_limits::max()), v); - v = std::max(static_cast(std::numeric_limits::lowest()), v); - *Output++ = static_cast(v); - } -} - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNhwc( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Batch, - size_t ImageSize, - size_t Stride, - size_t Channels, - int32_t* AccumulateBuffer, - const T8Bits* /*ZeroBuffer*/ - ) -{ - float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); - int32_t bias = -ZeroPointInput * static_cast(ImageSize); - for (; Batch > 0; Batch--) { - - const T8Bits* batch_input = Input; - T8Bits* batch_output = Output; - Input += Stride * ImageSize; - Output += Stride; - std::fill_n(AccumulateBuffer, Channels, bias); - for (size_t i = 0; i < ImageSize; ++i) { - - for (size_t c = 0; c < Channels; ++c) { - AccumulateBuffer[c] += static_cast(batch_input[c]); - } - - batch_input += Stride; - } - - for (size_t c = 0; c < Channels; ++c) { - - int32_t v = static_cast(std::nearbyintf(AccumulateBuffer[c] * scale)) + ZeroPointOutput; - v = std::min(static_cast(std::numeric_limits::max()), v); - v = std::max(static_cast(std::numeric_limits::lowest()), v); - *batch_output++ = static_cast(v); - } - } -} - -#endif - -#if defined(MLAS_NEON_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNhwc( - const T8Bits* Input, - float ScaleInput, - int32_t ZeroPointInput, - T8Bits* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Batch, - size_t ImageSize, - size_t Stride, - size_t Channels, - int32_t* AccumulateBuffer, - const T8Bits* ZeroBuffer - ) -{ - float scale = CheckQLinearGlobalAveragePoolScaleAndSize(ScaleInput, ScaleOutput, ImageSize); - const int32_t bias = -ZeroPointInput * static_cast(ImageSize); - const T8Bits* inputLastOf8 = Input + (Batch * ImageSize * Stride - Stride + Channels) - 8; - - for (; Batch > 0; Batch--) { - MlasQLinearGlobalAveragePoolNhwcSingleBatch( - Input, Output, inputLastOf8, ImageSize, Channels, Stride, bias, scale, - static_cast(ZeroPointOutput), AccumulateBuffer, ZeroBuffer); - Input += ImageSize * Stride; - Output += Stride; - } -} - -#endif - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const int8_t* Input, - float ScaleInput, - int32_t ZeroPointInput, - int8_t* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - 
int32_t* AccumulateBuffer - ); - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNchw( - const uint8_t* Input, - float ScaleInput, - int32_t ZeroPointInput, - uint8_t* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Channels, - size_t ImageSize, - int32_t* AccumulateBuffer - ); - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNhwc( - const int8_t* Input, - float ScaleInput, - int32_t ZeroPointInput, - int8_t* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Batch, - size_t ImageSize, - size_t Stride, - size_t Channels, - int32_t* AccumulateBuffer, - const int8_t* ZeroBuffer - ); - -template -void -MLASCALL -MlasQLinearGlobalAveragePoolNhwc( - const uint8_t* Input, - float ScaleInput, - int32_t ZeroPointInput, - uint8_t* Output, - float ScaleOutput, - int32_t ZeroPointOutput, - size_t Batch, - size_t ImageSize, - size_t Stride, - size_t Channels, - int32_t* AccumulateBuffer, - const uint8_t* ZeroBuffer - ); diff --git a/onnxruntime/core/mlas/lib/qlmul.cpp b/onnxruntime/core/mlas/lib/qlmul.cpp deleted file mode 100644 index 4a6d57db0d211..0000000000000 --- a/onnxruntime/core/mlas/lib/qlmul.cpp +++ /dev/null @@ -1,650 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qlmul.cpp - -Abstract: - - This module implements routines to quantize linear mul. - - For quantization formula as specified in the ONNX operator documentation is: - - Output = Saturate(RoundToEven(Input / Scale) + ZeroPoint) - ---*/ - -#include "qladd.h" - -#if defined(MLAS_NEON64_INTRINSICS) - -template -MLAS_FORCEINLINE -static -int16x8_t -MlasExtendToS16Debias( - typename SUI::i8x16_t Int8Vector, - typename SUI::i8x8_t VectorBias - ) -{ - auto HalfVector = IsLow ? SUI::vget_low_i8(Int8Vector) : SUI::vget_high_i8(Int8Vector); - return SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(HalfVector, VectorBias)); -} - -MLAS_FORCEINLINE -static -int16x8_t -MlasQLinearMulVectorS16( - int16x8_t va_s16x8, - int16x8_t vb_s16x8, - float32x4_t VectorScaleRatio, - float32x4_t VectorZeroPointC - ) -{ - int32x4_t vacc0_lo = vmull_s16(vget_low_s16(va_s16x8), vget_low_s16(vb_s16x8)); - int32x4_t vacc0_hi = vmull_s16(vget_high_s16(va_s16x8), vget_high_s16(vb_s16x8)); - auto vacc0_lo_f32 = vaddq_f32(VectorZeroPointC, vmulq_f32(VectorScaleRatio, vcvtq_f32_s32(vacc0_lo))); - auto vacc0_hi_f32 = vaddq_f32(VectorZeroPointC, vmulq_f32(VectorScaleRatio, vcvtq_f32_s32(vacc0_hi))); - // using rounding to nearst, ties to even - vacc0_lo = vcvtnq_s32_f32(vacc0_lo_f32); - vacc0_hi = vcvtnq_s32_f32(vacc0_hi_f32); - // Pack and saturate. 
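// Net effect per int16 lane: RoundToEven(ZeroPointC + ScaleRatio * float(a * b)), where
// ScaleRatio is ScaleA * ScaleB / ScaleC, narrowed back to int16 with saturation.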
- return vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)); -} - -template -static -void -MlasQLinearMulKernel( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - typedef MLAS_SignedUnsignedIntOps SUI; - - const float32x4_t VectorScaleRatio = vmovq_n_f32(ScaleA * ScaleB / ScaleC); - const typename SUI::i8x8_t VectorZeroPointA = SUI::vmov_n_i8((DataType)ZeroPointA); - const typename SUI::i8x8_t VectorZeroPointB = SUI::vmov_n_i8((DataType)ZeroPointB); - const float32x4_t VectorZeroPointC = vmovq_n_f32((float)ZeroPointC); - - typename SUI::T TailDataA[16] = { 0 }; - typename SUI::T TailDataB[16] = { 0 }; - int16x8_t vb0_s16x8, vb1_s16x8; - if (IsScalarB) { - const typename SUI::i8x8_t VectorB0 = SUI::vmov_n_i8(*InputB); - vb0_s16x8 = SUI::vreinterpretq_s16_i16(SUI::vsubl_i8(VectorB0, VectorZeroPointB)); - vb1_s16x8 = vb0_s16x8; - } - - while (N > 0) { - if (N < 16) { - MlasCopyTailBytes((uint8_t*)TailDataA, (const uint8_t*)InputA, N); - InputA = (const DataType*)TailDataA; - if (!IsScalarB) { - MlasCopyTailBytes((uint8_t*)TailDataB, (const uint8_t*)InputB, N); - InputB = (const DataType*)TailDataB; - } - } - - const typename SUI::i8x16_t VectorA0 = SUI::vld1q_i8(InputA); - InputA += 16; - const int16x8_t va0_s16x8 = MlasExtendToS16Debias(VectorA0, VectorZeroPointA); - const int16x8_t va1_s16x8 = MlasExtendToS16Debias(VectorA0, VectorZeroPointA);; - - if (!IsScalarB) { - const typename SUI::i8x16_t VectorB0 = SUI::vld1q_i8(InputB); - InputB += 16; - vb0_s16x8 = MlasExtendToS16Debias(VectorB0, VectorZeroPointB); - vb1_s16x8 = MlasExtendToS16Debias(VectorB0, VectorZeroPointB); - } - - const int16x8_t vacc0 = MlasQLinearMulVectorS16(va0_s16x8, vb0_s16x8, VectorScaleRatio, VectorZeroPointC); - const int16x8_t vacc1 = MlasQLinearMulVectorS16(va1_s16x8, vb1_s16x8, VectorScaleRatio, VectorZeroPointC); - typename SUI::i8x16_t vc = SUI::combine_i8_s16(vacc0, vacc1); - - if (N >= 16) { - N -= 16; - SUI::vst1q_i8(OutputC, vc); - OutputC += 16; - } else { - SUI::vst1q_i8(TailDataA, vc); - MlasCopyTailBytes((uint8_t*)OutputC, (const uint8_t*)TailDataA, N); - N = 0; - } - } -} - -#elif defined(MLAS_SSE2_INTRINSICS) - -template -MLAS_FORCEINLINE -static -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ); - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - return _mm_unpacklo_epi8(Int8Vector, ZeroVector); -} - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - return _mm_unpackhi_epi8(Int8Vector, ZeroVector); -} - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - MLAS_UNREFERENCED_PARAMETER(ZeroVector); - return _mm_srai_epi16(_mm_unpacklo_epi8(Int8Vector, Int8Vector), 8); -} - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - MLAS_UNREFERENCED_PARAMETER(ZeroVector); - return _mm_srai_epi16(_mm_unpackhi_epi8(Int8Vector, Int8Vector), 8); -} - -template -MLAS_FORCEINLINE -static -__m128i -MlasExtendToS16Debias( - __m128i Int8Vector, - __m128i ZeroVector, - __m128i VectorBias - ) -{ - return _mm_sub_epi16(MlasExtendToS16(Int8Vector, ZeroVector), VectorBias); -} - -MLAS_FORCEINLINE -static -__m128i -MlasQLinearMulVectorS16( - __m128i va_s16x8, - __m128i vb_s16x8, - __m128 VectorScaleRatio, - 
__m128 VectorZeroPointC - ) -{ - const auto ab_lo = _mm_mullo_epi16(va_s16x8, vb_s16x8); - const auto ab_hi = _mm_mulhi_epi16(va_s16x8, vb_s16x8); - auto r_lo = _mm_unpacklo_epi16(ab_lo, ab_hi); - auto r_hi = _mm_unpackhi_epi16(ab_lo, ab_hi); - r_lo = _mm_cvtps_epi32(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(r_lo), VectorScaleRatio), VectorZeroPointC)); - r_hi = _mm_cvtps_epi32(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(r_hi), VectorScaleRatio), VectorZeroPointC)); - return _mm_packs_epi32(r_lo, r_hi); -} - -template -static -void -MlasQLinearMulKernel( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const auto VectorZeroPointA = _mm_set1_epi16((int16_t)ZeroPointA); - const auto VectorZeroPointB = _mm_set1_epi16((int16_t)ZeroPointB); - const auto VectorZeroPointC = MlasBroadcastFloat32x4((float)ZeroPointC); - const auto VectorScaleRatio = MlasBroadcastFloat32x4(ScaleA * ScaleB / ScaleC); - const auto ZeroVector = _mm_setzero_si128(); - - uint8_t TailDataA[16] = { 0 }; - uint8_t TailDataB[16] = { 0 }; - __m128i vb_lo_s16x8, vb_hi_s16x8; - - if (IsScalarB) { - vb_lo_s16x8 = _mm_sub_epi16(_mm_set1_epi16((int16_t)*InputB), VectorZeroPointB); - vb_hi_s16x8 = vb_lo_s16x8; - } - - while (N > 0) { - if (N < 16) { - MlasCopyTailBytes(TailDataA, (const uint8_t*)InputA, N); - InputA = (const DataType*)TailDataA; - if (!IsScalarB) { - MlasCopyTailBytes(TailDataB, (const uint8_t*)InputB, N); - InputB = (const DataType*)TailDataB; - } - } - - const auto va_i8x16 = _mm_loadu_si128((const MLAS_INT32X4*)InputA); - InputA += 16; - const auto va_lo_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); - const auto va_hi_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); - - if (!IsScalarB) { - const auto vb_i8x16 = _mm_loadu_si128((const MLAS_INT32X4*)InputB); - InputB += 16; - vb_lo_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); - vb_hi_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); - } - - const auto vc_lo_s16x8 = MlasQLinearMulVectorS16(va_lo_s16x8, vb_lo_s16x8, VectorScaleRatio, VectorZeroPointC); - const auto vc_hi_s16x8 = MlasQLinearMulVectorS16(va_hi_s16x8, vb_hi_s16x8, VectorScaleRatio, VectorZeroPointC); - auto vc = MlasPackS16_128(vc_lo_s16x8, vc_hi_s16x8); - - if (N >= 16) { - _mm_storeu_si128((__m128i*)OutputC, vc); - OutputC += 16; - N -= 16; - } else { - _mm_storeu_si128((__m128i*)TailDataA, vc); - MlasCopyTailBytes((uint8_t*)OutputC, TailDataA, N); - N = 0; - } - } -} - -#elif defined(MLAS_VSX_INTRINSICS) - -template -static -void -MlasQLinearMulKernel( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const float MinimumValue = (float)((int)std::numeric_limits::min() - ZeroPointC); - const float MaximumValue = (float)((int)std::numeric_limits::max() - ZeroPointC); - - auto ZeroPointAVector = vec_splats(int32_t(ZeroPointA)); - auto ZeroPointBVector = vec_splats(int32_t(ZeroPointB)); - auto ZeroPointCVector = vec_splats(float(ZeroPointC)); - - auto ScaleAVector = vec_splats(ScaleA); - auto ScaleBVector = vec_splats(ScaleB); - auto ScaleCVector = vec_splats(ScaleC); - - auto MinimumVector = vec_splats(MinimumValue); - auto MaximumVector = vec_splats(MaximumValue); - - float ValueB; - __vector float 
ValueBVector; - - if (IsScalarB) { - ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB); - ValueBVector = vec_splats(ValueB); - } - - while (N >= 4) { -#if defined(_AIX) && defined(__clang__) - __vector int IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; -#else - __vector int32_t IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; -#endif - auto IntegerVector = vec_sub(IntegerAVector, ZeroPointAVector); - auto ValueAVector = vec_mul(ScaleAVector, vec_ctf(IntegerVector, 0)); - - if (!IsScalarB) { -#if defined(_AIX) && defined(__clang__) - __vector int IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; -#else - __vector int32_t IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; -#endif - IntegerVector = vec_sub(IntegerBVector, ZeroPointBVector); - ValueBVector = vec_mul(ScaleBVector, vec_ctf(IntegerVector, 0)); - } - - auto ValueCVector = vec_div(vec_mul(ValueAVector, ValueBVector), ScaleCVector); - ValueCVector = vec_min(vec_max(ValueCVector, MinimumVector), MaximumVector); - ValueCVector = vec_nearbyint(vec_add(ValueCVector, ZeroPointCVector)); - - auto IntegerValueCVector = vec_signed(ValueCVector); - OutputC[0] = (DataType) IntegerValueCVector[0]; - OutputC[1] = (DataType) IntegerValueCVector[1]; - OutputC[2] = (DataType) IntegerValueCVector[2]; - OutputC[3] = (DataType) IntegerValueCVector[3]; - - OutputC += 4; - InputA += 4; - InputB += 4; - - N -= 4; - - // Suppress wrong GCC warnings - MLAS_UNREFERENCED_PARAMETER(ValueAVector); - } - - while (N > 0) { - float ValueA = ScaleA * (int32_t(*InputA) - ZeroPointA); - if (!IsScalarB) { - ValueB = ScaleB * (int32_t(*InputB) - ZeroPointB); - } - float ValueC = (ValueA * ValueB) / ScaleC; - ValueC = std::min(std::max(ValueC, MinimumValue), MaximumValue); - - *OutputC = (DataType)(int32_t)std::nearbyintf(ValueC + ZeroPointC); - - InputA++; - InputB++; - OutputC++; - N--; - } - - // Suppress wrong GCC warnings - MLAS_UNREFERENCED_PARAMETER(ScaleAVector); - MLAS_UNREFERENCED_PARAMETER(ScaleBVector); - MLAS_UNREFERENCED_PARAMETER(ValueBVector); -} - -#elif defined(MLAS_LSX_INTRINSICS) - -template -MLAS_FORCEINLINE -static -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ); - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - return __lsx_vilvl_b(ZeroVector, Int8Vector); -} - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - return __lsx_vilvh_b(ZeroVector, Int8Vector); -} - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - MLAS_UNREFERENCED_PARAMETER(ZeroVector); - return __lsx_vsrai_h(__lsx_vilvl_b(Int8Vector, Int8Vector), 8); -} - -template <> -MLAS_FORCEINLINE -__m128i -MlasExtendToS16( - __m128i Int8Vector, - __m128i ZeroVector - ) -{ - MLAS_UNREFERENCED_PARAMETER(ZeroVector); - return __lsx_vsrai_h(__lsx_vilvh_b(Int8Vector, Int8Vector), 8); -} - -template -MLAS_FORCEINLINE -static -__m128i -MlasExtendToS16Debias( - __m128i Int8Vector, - __m128i ZeroVector, - __m128i VectorBias - ) -{ - return __lsx_vsub_h(MlasExtendToS16(Int8Vector, ZeroVector), VectorBias); -} - -MLAS_FORCEINLINE -static -__m128i -MlasQLinearMulVectorS16( - __m128i va_s16x8, - __m128i vb_s16x8, - __m128 VectorScaleRatio, - __m128 VectorZeroPointC - ) -{ - __m128i tmp, tmp1; - - const auto ab_lo = __lsx_vmul_h(va_s16x8, vb_s16x8); - const auto ab_hi = __lsx_vmuh_h(va_s16x8, vb_s16x8); - auto r_lo = __lsx_vilvl_h(ab_hi, ab_lo); - auto 
r_hi = __lsx_vilvh_h(ab_hi, ab_lo); - r_lo = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_lo), VectorScaleRatio, VectorZeroPointC)); - r_hi = __lsx_vftint_w_s(__lsx_vfmadd_s(__lsx_vffint_s_w(r_hi), VectorScaleRatio, VectorZeroPointC)); - - tmp = __lsx_vsat_w(r_lo, 15); - tmp1 = __lsx_vsat_w(r_hi, 15); - return __lsx_vpickev_h(tmp1, tmp); -} - -template -static -void -MlasQLinearMulKernel( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const auto VectorZeroPointA = __lsx_vreplgr2vr_h((int16_t)ZeroPointA); - const auto VectorZeroPointB = __lsx_vreplgr2vr_h((int16_t)ZeroPointB); - const auto VectorZeroPointC = MlasBroadcastFloat32x4((float)ZeroPointC); - const auto VectorScaleRatio = MlasBroadcastFloat32x4(ScaleA * ScaleB / ScaleC); - const auto ZeroVector = __lsx_vldi(0); - - uint8_t TailDataA[16] = { 0 }; - uint8_t TailDataB[16] = { 0 }; - __m128i vb_lo_s16x8, vb_hi_s16x8; - - if (IsScalarB) { - vb_lo_s16x8 = __lsx_vsub_h(__lsx_vreplgr2vr_h((int16_t)*InputB), VectorZeroPointB); - vb_hi_s16x8 = vb_lo_s16x8; - } - - while (N > 0) { - if (N < 16) { - MlasCopyTailBytes(TailDataA, (const uint8_t*)InputA, N); - InputA = (const DataType*)TailDataA; - if (!IsScalarB) { - MlasCopyTailBytes(TailDataB, (const uint8_t*)InputB, N); - InputB = (const DataType*)TailDataB; - } - } - - const auto va_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputA, 0); - InputA += 16; - const auto va_lo_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); - const auto va_hi_s16x8 = MlasExtendToS16Debias(va_i8x16, ZeroVector, VectorZeroPointA); - - if (!IsScalarB) { - const auto vb_i8x16 = __lsx_vld((const MLAS_INT32X4*)InputB, 0); - InputB += 16; - vb_lo_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); - vb_hi_s16x8 = MlasExtendToS16Debias(vb_i8x16, ZeroVector, VectorZeroPointB); - } - - const auto vc_lo_s16x8 = MlasQLinearMulVectorS16(va_lo_s16x8, vb_lo_s16x8, VectorScaleRatio, VectorZeroPointC); - const auto vc_hi_s16x8 = MlasQLinearMulVectorS16(va_hi_s16x8, vb_hi_s16x8, VectorScaleRatio, VectorZeroPointC); - auto vc = MlasPackS16_128(vc_lo_s16x8, vc_hi_s16x8); - - if (N >= 16) { - __lsx_vst(vc, (__m128i*)OutputC, 0); - OutputC += 16; - N -= 16; - } else { - __lsx_vst(vc, (__m128i*)TailDataA, 0); - MlasCopyTailBytes((uint8_t*)OutputC, TailDataA, N); - N = 0; - } - } -} - - -#else - -// Pure C++ implementation. 
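(A worked example of the arithmetic the scalar kernel below implements, with illustrative values: ScaleA = 0.02 and ZeroPointA = 128 dequantize InputA = 178 to 0.02 * (178 - 128) = 1.0; ScaleB = 0.05 and ZeroPointB = 128 dequantize InputB = 148 to 1.0; with ScaleC = 0.1 and ZeroPointC = 128 the product requantizes to nearbyint(1.0 * 1.0 / 0.1 + 128) = 138, the intermediate having first been clamped so the result stays within the output type's range.)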
-template -static -void -MlasQLinearMulKernel( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N - ) -{ - const float MinimumValue = (float)((int)std::numeric_limits::min() - ZeroPointC); - const float MaximumValue = (float)((int)std::numeric_limits::max() - ZeroPointC); - - float ValueB; - - if (IsScalarB) { - ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB); - } - - for (size_t n = 0; n < N; n++) { - float ValueA = ScaleA * (int32_t(InputA[n]) - ZeroPointA); - if (!IsScalarB) { - ValueB = ScaleB * (int32_t(InputB[n]) - ZeroPointB); - } - float ValueC = (ValueA * ValueB) / ScaleC; - ValueC = std::min(std::max(ValueC, MinimumValue), MaximumValue); - OutputC[n] = (DataType)(int32_t)std::nearbyintf(ValueC + ZeroPointC); - } -} - -#endif - -template -void -MLASCALL -MlasQLinearMul( - const DataType* InputA, - float ScaleA, - int32_t ZeroPointA, - const DataType* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - DataType* OutputC, - size_t N, - bool IsScalarB - ) -{ - if (IsScalarB) { - MlasQLinearMulKernel( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } else { - MlasQLinearMulKernel( - InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); - } -} - -// Explicit instantiation -template -void -MlasQLinearMul( - const uint8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const uint8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - uint8_t* OutputC, - size_t N, - bool IsScalarB - ); - -template -void -MlasQLinearMul( - const int8_t* InputA, - float ScaleA, - int32_t ZeroPointA, - const int8_t* InputB, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC, - int8_t* OutputC, - size_t N, - bool IsScalarB - ); diff --git a/onnxruntime/core/mlas/lib/qpostprocessor.cpp b/onnxruntime/core/mlas/lib/qpostprocessor.cpp deleted file mode 100644 index 97e9000a19b30..0000000000000 --- a/onnxruntime/core/mlas/lib/qpostprocessor.cpp +++ /dev/null @@ -1,238 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - qpostprocessor.cpp - -Abstract: - - This module implements the post processor for QGEMM. 
- ---*/ - -#include "mlasi.h" - -void MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::Process( - const int32_t* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc - ) const -{ - if (Bias_) { - if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { - if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } else { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } - } else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } else { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } - } else { - if (QuantGran_ == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { - if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } else { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } - } else if (OutputMode_ == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } else { - ProcessImpl( - C, - StartM, - StartN, - CountM, - CountN, - ldc); - } - } -} - -template -inline -void -MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR::ProcessImpl( - const int32_t* C, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN, - size_t ldc) const -/*++ - -Routine Description: - - This routine converts the output matrix C to a floating point format using - the stored scale and bias parameters. - -Arguments: - - C - Supplies the address of matrix C. - - StartM - Supplies the starting row offset relative to the matrix. - - StartN - Supplies the starting column offset relative to the matrix. - - CountM - Supplies the number of rows of the output matrix to process. - - CountN - Supplies the number of columns of the output matrix to process. - - ldc - Supplies the leading dimension of C. - -Return Value: - - None. 
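In scalar terms, the per-element transformation described above can be sketched as follows (an illustrative reference only; the free-function shape and parameter names are assumptions, not part of the original source):

#include <cstddef>
#include <cstdint>

// Scalar sketch of the QGEMM scale/bias post-processing: convert the int32 accumulator
// to float, apply a per-tensor or per-column scale, add the optional per-column bias,
// and either overwrite the float output or accumulate into it.
static void ScaleBiasPostProcessReference(const int32_t* C, size_t ldc,
                                          float* Output, size_t ldOutput,
                                          size_t CountM, size_t CountN,
                                          const float* Scale, bool PerColumnScale,
                                          const float* Bias, bool Accumulate)
{
    for (size_t m = 0; m < CountM; m++) {
        for (size_t n = 0; n < CountN; n++) {
            float Value = static_cast<float>(C[m * ldc + n]) *
                          (PerColumnScale ? Scale[n] : Scale[0]);
            if (Bias != nullptr) {
                Value += Bias[n];
            }
            if (Accumulate) {
                Output[m * ldOutput + n] += Value;
            } else {
                Output[m * ldOutput + n] = Value;
            }
        }
    }
}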
- ---*/ -{ - float* Output = Output_; - const float* Bias = Bias_; - const float* Scale = Scale_; - - if (HasBias) { - Bias += StartN; - } - - if(QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn){ - Scale += StartN; - } - - MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale_); -#if !defined(MLAS_SSE2_INTRINSICS) - float ScaleValue = MlasExtractLaneFloat32x4<0>(ScaleVector); -#endif - - C += StartM * ldc + StartN; - Output += StartM * LeadingDimensionOutput_ + StartN; - - - while (CountM-- > 0) { - - float* c_out = Output; - const int32_t* c = C; - const float* bias = Bias; - const float* scale = Scale; - - size_t n = CountN; - - while (n >= 4) { - - MLAS_FLOAT32X4 FloatVector = MlasCastToFloat32x4(MlasLoadInt32x4(c)); - - if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { - ScaleVector = MlasLoadFloat32x4(scale); - scale += 4; - } - - if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - FloatVector = MlasMultiplyAddFloat32x4(FloatVector, ScaleVector, MlasLoadFloat32x4(c_out)); - } else { - FloatVector = MlasMultiplyFloat32x4(FloatVector, ScaleVector); - } - - if (HasBias) { - FloatVector = MlasAddFloat32x4(FloatVector, MlasLoadFloat32x4(bias)); - bias += 4; - } - - MlasStoreFloat32x4(c_out, FloatVector); - - c_out += 4; - c += 4; - n -= 4; - } - - for (size_t offset = 0; offset < n; offset++) { - -#if defined(MLAS_SSE2_INTRINSICS) - __m128 FloatVector = _mm_set_ss(float(c[offset])); - - if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { - ScaleVector = _mm_load_ss(&scale[offset]); - } - - if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - FloatVector = _mm_add_ps(_mm_mul_ss(FloatVector, ScaleVector), _mm_load_ss(&c_out[offset])); - } else { - FloatVector = _mm_mul_ss(FloatVector, ScaleVector); - } - - if (HasBias) { - FloatVector = _mm_add_ss(FloatVector, _mm_load_ss(&bias[offset])); - } - - _mm_store_ss(&c_out[offset], FloatVector); -#else - if (QuantGran == MLAS_QUANTIZATION_GRANULARITY::PerColumn) { - ScaleValue = scale[offset]; - } - - float result = float(c[offset]) * ScaleValue; - if (HasBias) { - result += bias[offset]; - } - - if (Mode == MLAS_QGEMM_OUTPUT_MODE::AccumulateMode) { - c_out[offset] += result; - } else { - c_out[offset] = result; - } -#endif - } - - C += ldc; - Output += LeadingDimensionOutput_; - } -} diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp deleted file mode 100644 index ae638fafee18f..0000000000000 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ /dev/null @@ -1,2121 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - quantize.cpp - -Abstract: - - This module implements routines to quantize buffers. - - For quantization formula as specified in the ONNX operator documentation is: - - Output = Saturate(RoundToEven(Input / Scale) + ZeroPoint) - ---*/ - -#include "mlasi.h" - -#if defined(MLAS_NEON64_INTRINSICS) || defined(MLAS_SSE2_INTRINSICS) || \ - defined(MLAS_LSX_INTRINSICS) - -#include - -// -// QuantizeLinear implementation using NEON or SSE2 intrinsics. -// - -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearVector( - MLAS_FLOAT32X4 FloatVector, - MLAS_FLOAT32X4 ScaleVector, - MLAS_FLOAT32X4 MinimumValueVector, - MLAS_FLOAT32X4 MaximumValueVector, - MLAS_INT32X4 ZeroPointVector - ) -{ - // - // Scale the input vector and clamp the values to the minimum and maximum - // range (adjusted by the zero point value). 
- // - - FloatVector = MlasDivideFloat32x4(FloatVector, ScaleVector); - -#if defined(MLAS_NEON64_INTRINSICS) - // N.B. FMINNM and FMAXNM returns the numeric value if either of the values - // is a NaN. - FloatVector = vmaxnmq_f32(FloatVector, MinimumValueVector); - FloatVector = vminnmq_f32(FloatVector, MaximumValueVector); -#elif defined(MLAS_LSX_INTRINSICS) - FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); - FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); -#else - // N.B. MINPS and MAXPS returns the value from the second vector if the - // value from the first vector is a NaN. - FloatVector = _mm_max_ps(FloatVector, MinimumValueVector); - FloatVector = _mm_min_ps(FloatVector, MaximumValueVector); -#endif - - // - // Convert the float values to integer using "round to nearest even" and - // then shift the output range using the zero point value. - // - -#if defined(MLAS_NEON64_INTRINSICS) - auto IntegerVector = vcvtnq_s32_f32(FloatVector); - IntegerVector = vaddq_s32(IntegerVector, ZeroPointVector); -#elif defined(MLAS_LSX_INTRINSICS) - auto IntegerVector = __lsx_vftint_w_s(FloatVector); - IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); -#else - // N.B. Assumes MXCSR has been configured with the default rounding mode of - // "round to nearest even". - auto IntegerVector = _mm_cvtps_epi32(FloatVector); - IntegerVector = _mm_add_epi32(IntegerVector, ZeroPointVector); -#endif - - return IntegerVector; -} - -template -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ); - -template -void -MlasQuantizeLinearStore4PackedValues( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ); - -template -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ); - -#if defined(MLAS_NEON64_INTRINSICS) - -template -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - // - // Swizzle the least significant byte from each int32_t element to the - // bottom four bytes of the vector register. - // - - uint16x8_t WordVector = vreinterpretq_u16_s32(IntegerVector); - WordVector = vuzp1q_u16(WordVector, WordVector); - uint8x16_t ByteVector = vreinterpretq_u8_u16(WordVector); - ByteVector = vuzp1q_u8(ByteVector, ByteVector); - - return vreinterpretq_s32_u8(ByteVector); -} - -template<> -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - // - // Swizzle the least significant u16 from each int32_t element to the - // bottom eight bytes of the vector register. - // - - uint16x8_t WordVector = vreinterpretq_u16_s32(IntegerVector); - WordVector = vuzp1q_u16(WordVector, WordVector); - return vreinterpretq_s32_u16(WordVector); -} - -template<> -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - // - // Swizzle the least significant u16 from each int32_t element to the - // bottom eight bytes of the vector register. - // - - int16x8_t WordVector = vreinterpretq_s16_s32(IntegerVector); - WordVector = vuzp1q_s16(WordVector, WordVector); - return vreinterpretq_s32_s16(WordVector); -} - -template -MLAS_FORCEINLINE -void -MlasQuantizeLinearStore4PackedValues( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ) -{ - // Copies the lower 4 packed elements of the vector into memory (Output). 
- - if constexpr (std::is_same_v || std::is_same_v) { - vst1q_lane_s32(reinterpret_cast(Output), IntegerVector, 0); - } else { - static_assert(std::is_same_v || std::is_same_v); - vst1q_lane_s64(reinterpret_cast(Output), vreinterpretq_s64_s32(IntegerVector), 0); - } -} - -template <> -MLAS_FORCEINLINE -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - uint8_t* Output - ) -{ - // Copies the lower 8-bit element of the vector into memory (Output). - vst1q_lane_u8(Output, vreinterpretq_u8_s32(IntegerVector), 0); -} - -template <> -MLAS_FORCEINLINE -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - int8_t* Output - ) -{ - // Copies the lower 8-bit element of the vector into memory (Output). - vst1q_lane_s8(Output, vreinterpretq_s8_s32(IntegerVector), 0); -} - -template <> -MLAS_FORCEINLINE -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - uint16_t* Output - ) -{ - // Copies the lower 16-bit element of the vector into memory (Output). - vst1q_lane_u16(Output, vreinterpretq_u16_s32(IntegerVector), 0); -} - -template <> -MLAS_FORCEINLINE -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - int16_t* Output - ) -{ - // Copies the lower 16-bit element of the vector into memory (Output). - vst1q_lane_s16(Output, vreinterpretq_s16_s32(IntegerVector), 0); -} - -#elif defined(MLAS_LSX_INTRINSICS) -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 integervector - ) -{ - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2; - - tmp = __lsx_vmax_h(integervector, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - - integervector = __lsx_vpickev_b(tmp2, tmp2); - - - tmp = __lsx_vmax_h(integervector, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - - integervector = __lsx_vpickev_b(tmp2, tmp2); - return integervector; -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 integervector - ) -{ - - __m128i tmp, tmp1; - - tmp = __lsx_vsat_h(integervector, 7); - tmp1 = __lsx_vsat_h(integervector, 7); - integervector = __lsx_vpickev_b(tmp1, tmp); - - tmp = __lsx_vsat_h(integervector, 7); - tmp1 = __lsx_vsat_h(integervector, 7); - integervector = __lsx_vpickev_b(tmp1, tmp); - return integervector; -} - -template -MLAS_FORCEINLINE -void -MlasQuantizeLinearStore4PackedValues( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ) -{ - // Copies the lower 4 packed elements of the vector into memory (Output). - - if constexpr (std::is_same_v || std::is_same_v) { - __lsx_vstelm_w(IntegerVector, reinterpret_cast(Output), 0, 0); - } else { - static_assert(std::is_same_v || std::is_same_v); - - __lsx_vstelm_d(IntegerVector, reinterpret_cast(Output), 0, 0); - } -} - - -template -MLAS_FORCEINLINE -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ) -{ - static_assert(std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v); - - // Copies the lower element of the vector into memory (Output). - // Expects that the 32-bit element in lane 0 is already within the valid numerical - // range of the OutputType. 
- *Output = static_cast(__lsx_vpickve2gr_w(IntegerVector, 0)); -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2; - - tmp = __lsx_vmax_w(IntegerVector, zero); - tmp2 = __lsx_vsat_wu(tmp, 15); - - IntegerVector = __lsx_vpickev_h(tmp2, tmp2); - return IntegerVector; -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - __m128i tmp, tmp1; - - tmp = __lsx_vsat_w(IntegerVector, 15); - tmp1 = __lsx_vsat_w(IntegerVector, 15); - IntegerVector = __lsx_vpickev_h(tmp1, tmp); - return IntegerVector; -} -#else - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - IntegerVector = _mm_packus_epi16(IntegerVector, IntegerVector); - IntegerVector = _mm_packus_epi16(IntegerVector, IntegerVector); - - return IntegerVector; -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - IntegerVector = _mm_packs_epi16(IntegerVector, IntegerVector); - IntegerVector = _mm_packs_epi16(IntegerVector, IntegerVector); - - return IntegerVector; -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ -#if defined(MLAS_SSE41_INTRINSICS) - IntegerVector = _mm_packus_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. -#else - // Cannot use _mm_packus_epi32 because that was not available until SSE4.1. - // Instead, emulate by sign-extending the first 16-bits of each packed 32-bit element. - // Afterwards, can use _mm_packs_epi32, which is available on SSE2. - // See: https://stackoverflow.com/a/11028244 - - IntegerVector = _mm_slli_epi32(IntegerVector, 16); - IntegerVector = _mm_srai_epi32(IntegerVector, 16); // Sign-extend: undo left shift with right arithmetic shift - IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. -#endif // defined(MLAS_SSE41_INTRINSICS) - - return IntegerVector; -} - -template<> -MLAS_FORCEINLINE -MLAS_INT32X4 -MlasQuantizeLinearPackBytes( - MLAS_INT32X4 IntegerVector - ) -{ - IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); // 16-bit values packed in lower 8 bytes. - - return IntegerVector; -} - -template -MLAS_FORCEINLINE -void -MlasQuantizeLinearStore4PackedValues( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ) -{ - // Copies the lower 4 packed elements of the vector into memory (Output). - - if constexpr (std::is_same_v || std::is_same_v) { - *(reinterpret_cast(Output)) = _mm_cvtsi128_si32(IntegerVector); - } else { - static_assert(std::is_same_v || std::is_same_v); - -#if defined(MLAS_TARGET_IX86) - // x86 does not support _mm_cvtsi128_si64, so use _mm_maskmoveu_si128 instead. - constexpr uint32_t bytes_high_bit = 0x80808080; - const __m128i first_8_bytes_mask = _mm_set_epi32(0, 0, bytes_high_bit, bytes_high_bit); - _mm_maskmoveu_si128(IntegerVector, first_8_bytes_mask, reinterpret_cast(Output)); -#else - *(reinterpret_cast(Output)) = _mm_cvtsi128_si64(IntegerVector); -#endif // defined(MLAS_TARGET_IX86) - } -} - -template -MLAS_FORCEINLINE -void -MlasQuantizeLinearStoreSingleValue( - MLAS_INT32X4 IntegerVector, - OutputType* Output - ) -{ - static_assert(std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v); - - // Copies the lower element of the vector into memory (Output). 
- // Expects that the 32-bit element in lane 0 is already within the valid numerical - // range of the OutputType. - *Output = static_cast(_mm_cvtsi128_si32(IntegerVector)); -} - -#endif - -template -void -MLASCALL -MlasQuantizeLinearKernel( - const float* Input, - OutputType* Output, - size_t N, - float Scale, - OutputType ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer using the supplied quantization - parameters. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ -{ - constexpr int32_t MinimumValue = std::numeric_limits::lowest(); - constexpr int32_t MaximumValue = std::numeric_limits::max(); - - auto ScaleVector = MlasBroadcastFloat32x4(Scale); - auto MinimumValueVector = MlasBroadcastFloat32x4(float(MinimumValue - ZeroPoint)); - auto MaximumValueVector = MlasBroadcastFloat32x4(float(MaximumValue - ZeroPoint)); - auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); - - while (N >= 4) { - - auto FloatVector = MlasLoadFloat32x4(Input); - auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, - MinimumValueVector, MaximumValueVector, ZeroPointVector); - - IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - MlasQuantizeLinearStore4PackedValues(IntegerVector, Output); - - Input += 4; - Output += 4; - N -= 4; - } - - for (size_t n = 0; n < N; n++) { - -#if defined(MLAS_NEON64_INTRINSICS) - auto FloatVector = vld1q_dup_f32(Input + n); -#elif defined(MLAS_LSX_INTRINSICS) - MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); -#else - auto FloatVector = _mm_load_ss(Input + n); -#endif - auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, - MinimumValueVector, MaximumValueVector, ZeroPointVector); - - MlasQuantizeLinearStoreSingleValue(IntegerVector, &Output[n]); - } -} - -template -void -MLASCALL -MlasQuantizeLinearInt4Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - constexpr int32_t MinimumValue = Int4Traits::Min; - constexpr int32_t MaximumValue = Int4Traits::Max; - using UnpackedType = typename Int4Traits::UnpackedType; - - auto ScaleVector = MlasBroadcastFloat32x4(Scale); - auto MinimumValueVector = MlasBroadcastFloat32x4(static_cast(MinimumValue - ZeroPoint)); - auto MaximumValueVector = MlasBroadcastFloat32x4(static_cast(MaximumValue - ZeroPoint)); - auto ZeroPointVector = MlasBroadcastInt32x4(ZeroPoint); - - // Holds 4 quantized 8bit values that will be packed into the output as packed 4bit values. 
- UnpackedType TmpOutput[4] = {}; - - while (N >= 4) { - - auto FloatVector = MlasLoadFloat32x4(Input); - auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, - MinimumValueVector, MaximumValueVector, ZeroPointVector); - - IntegerVector = MlasQuantizeLinearPackBytes(IntegerVector); - MlasQuantizeLinearStore4PackedValues(IntegerVector, &TmpOutput[0]); - MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); - MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); - - Input += 4; - N -= 4; - } - - for (size_t n = 0; n < N; n++) { - -#if defined(MLAS_NEON64_INTRINSICS) - auto FloatVector = vld1q_dup_f32(Input + n); -#elif defined(MLAS_LSX_INTRINSICS) - MLAS_FLOAT32X4 FloatVector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input+n, 0); -#else - auto FloatVector = _mm_load_ss(Input + n); -#endif - auto IntegerVector = MlasQuantizeLinearVector(FloatVector, ScaleVector, - MinimumValueVector, MaximumValueVector, ZeroPointVector); - - MlasQuantizeLinearStoreSingleValue(IntegerVector, &TmpOutput[0]); - MlasSetInt4Element(Output, n, TmpOutput[0]); - } -} - -void -MLASCALL -MlasQuantizeLinearS4Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU4Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS8Kernel( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU8Kernel( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint -) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU16Kernel( - const float* Input, - uint16_t* Output, - size_t N, - float Scale, - uint16_t ZeroPoint -) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS16Kernel( - const float* Input, - int16_t* Output, - size_t N, - float Scale, - int16_t ZeroPoint -) -{ - MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearS4Kernel( -#else - MlasQuantizeLinearS4Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearU4Kernel( -#else - MlasQuantizeLinearU4Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearS8Kernel( -#else - MlasQuantizeLinearS8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearU8Kernel( -#else - MlasQuantizeLinearU8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - 
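For readers coming to these entry points for the first time, a minimal caller sketch follows. It is illustrative only: it assumes the declarations in the public mlas.h header are visible on the include path and that the program links against MLAS; the function name QuantizeLinearUsageSketch and the scale / zero-point values are invented for the example, and the S4 destination size of (N + 1) / 2 bytes follows from the two-values-per-byte packing shown above.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "mlas.h"  // public MLAS API header (assumed to be on the include path)

void QuantizeLinearUsageSketch()
{
    const std::vector<float> input = {-1.0f, -0.25f, 0.0f, 0.5f, 1.0f};
    const size_t N = input.size();

    // 8-bit unsigned quantization, one output byte per element.
    // Per the module header formula, e.g. round(0.5 / 0.0078125) + 128 = 192.
    std::vector<uint8_t> u8(N);
    MlasQuantizeLinear<uint8_t>(input.data(), u8.data(), N, 0.0078125f, 128);

    // Signed 4-bit quantization, two elements packed per output byte,
    // so the destination holds (N + 1) / 2 bytes.
    std::vector<uint8_t> s4((N + 1) / 2);
    MlasQuantizeLinearS4(input.data(), s4.data(), N, 0.125f, 0);
}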
-template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - uint16_t* Output, - size_t N, - float Scale, - uint16_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearU16Kernel( -#else - MlasQuantizeLinearU16Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - int16_t* Output, - size_t N, - float Scale, - int16_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().QuantizeLinearS16Kernel( -#else - MlasQuantizeLinearS16Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -#else - -#if defined(MLAS_TARGET_POWER) - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - GetMlasPlatform().QuantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - GetMlasPlatform().QuantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - int16_t* Output, - size_t N, - float Scale, - int16_t ZeroPoint - ) -{ - GetMlasPlatform().QuantizeLinearS16Kernel(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - uint16_t* Output, - size_t N, - float Scale, - uint16_t ZeroPoint - ) -{ - GetMlasPlatform().QuantizeLinearU16Kernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearS4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - GetMlasPlatform().QuantizeLinearS4Kernel(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasQuantizeLinearU4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - GetMlasPlatform().QuantizeLinearU4Kernel(Input, Output, N, Scale, ZeroPoint); -} -#endif - -// -// QuantizeLinear implementation using the C++ runtime. -// - -template -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - OutputType* Output, - size_t N, - float Scale, - OutputType ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer using the supplied quantization - parameters. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. 
- ---*/ -{ - constexpr int32_t MinimumValue = std::numeric_limits::lowest(); - constexpr int32_t MaximumValue = std::numeric_limits::max(); - - for (size_t n = 0; n < N; n++) { - - float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); - FloatValue = std::max(FloatValue, float(MinimumValue)); - FloatValue = std::min(FloatValue, float(MaximumValue)); - Output[n] = (OutputType)(int32_t)FloatValue; - } -} - -#if !defined(MLAS_TARGET_POWER) -template -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - int8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ); - -template -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ); - -template -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - int16_t* Output, - size_t N, - float Scale, - int16_t ZeroPoint - ); - -template -void -MLASCALL -MlasQuantizeLinear( - const float* Input, - uint16_t* Output, - size_t N, - float Scale, - uint16_t ZeroPoint - ); - -template -void -MLASCALL -MlasQuantizeLinearInt4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - constexpr int32_t MinimumValue = Int4Traits::Min; - constexpr int32_t MaximumValue = Int4Traits::Max; - using UnpackedType = typename Int4Traits::UnpackedType; - - for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); - FloatValue = std::max(FloatValue, static_cast(MinimumValue)); - FloatValue = std::min(FloatValue, static_cast(MaximumValue)); - UnpackedType IntValue = static_cast(FloatValue); - - MlasSetInt4Element(Output, n, IntValue); - } -} - -// QuantizeLinear INT4 implementation using the C++ runtime. -void -MLASCALL -MlasQuantizeLinearS4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); -} - -// QuantizeLinear UINT4 implementation using the C++ runtime. -void -MLASCALL -MlasQuantizeLinearU4( - const float* Input, - uint8_t* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasQuantizeLinearInt4(Input, Output, N, Scale, ZeroPoint); -} -#endif - -#endif - -#if defined(MLAS_SSE2_INTRINSICS) - -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - OutputType* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - OutputType ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ) -{ - const __m128 PerMatrixScaleVector = PerColumnScale ? _mm_setzero_ps() : _mm_load1_ps(Scale); - const __m128 MinimumValueVector = _mm_set1_ps(float(std::numeric_limits::lowest() - ZeroPoint)); - const __m128 MaximumValueVector = _mm_set1_ps(float(std::numeric_limits::max() - ZeroPoint)); - const __m128i ZeroPointVector = _mm_set1_epi32(ZeroPoint); - - if (nullptr != Bias) { - Bias += StartN; - } - if (PerColumnScale) { - Scale += StartN; - } - - Input += StartM * InputLeadingDimension + StartN; - Output += StartM * OutputLeadingDimension + StartN; - - // - // Step through each row of the output matrix. - // - - while (CountM-- > 0) { - - const int32_t* bias = Bias; - const float* scale = PerColumnScale ? Scale : nullptr; - size_t n = CountN; - - auto* RowInput = Input; - auto* RowOutput = Output; - - // - // Process 16 columns of the matrices at a time. 
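    // (Illustrative aside, not part of the original file: per element, the net effect
    //  of this loop is Output = saturate(round_to_nearest_even(Accumulator * Scale) + ZeroPoint),
    //  where Accumulator already includes any per-column bias. For example, an int32
    //  accumulator of 23 with Scale = 0.25 and ZeroPoint = 10 requantizes to
    //  round(5.75) + 10 = 16, well inside the uint8 range.)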
- // - - while (n >= 16) { - - // - // Load the input data and optionally add the per-column bias. - // - - __m128i IntegerVector0 = _mm_loadu_si128((const __m128i*)&RowInput[0]); - __m128i IntegerVector1 = _mm_loadu_si128((const __m128i*)&RowInput[4]); - __m128i IntegerVector2 = _mm_loadu_si128((const __m128i*)&RowInput[8]); - __m128i IntegerVector3 = _mm_loadu_si128((const __m128i*)&RowInput[12]); - RowInput += 16; - - if (bias != nullptr) { - IntegerVector0 = _mm_add_epi32(IntegerVector0, _mm_loadu_si128((const __m128i *)&bias[0])); - IntegerVector1 = _mm_add_epi32(IntegerVector1, _mm_loadu_si128((const __m128i *)&bias[4])); - IntegerVector2 = _mm_add_epi32(IntegerVector2, _mm_loadu_si128((const __m128i *)&bias[8])); - IntegerVector3 = _mm_add_epi32(IntegerVector3, _mm_loadu_si128((const __m128i *)&bias[12])); - bias += 16; - } - - // - // Convert to integer values to float and apply the per-tensor or - // per-column scaling. - // - - __m128 FloatVector0 = _mm_cvtepi32_ps(IntegerVector0); - __m128 FloatVector1 = _mm_cvtepi32_ps(IntegerVector1); - __m128 FloatVector2 = _mm_cvtepi32_ps(IntegerVector2); - __m128 FloatVector3 = _mm_cvtepi32_ps(IntegerVector3); - - if (scale != nullptr) { - - FloatVector0 = _mm_mul_ps(FloatVector0, _mm_loadu_ps(&scale[0])); - FloatVector1 = _mm_mul_ps(FloatVector1, _mm_loadu_ps(&scale[4])); - FloatVector2 = _mm_mul_ps(FloatVector2, _mm_loadu_ps(&scale[8])); - FloatVector3 = _mm_mul_ps(FloatVector3, _mm_loadu_ps(&scale[12])); - scale += 16; - - } else { - - FloatVector0 = _mm_mul_ps(FloatVector0, PerMatrixScaleVector); - FloatVector1 = _mm_mul_ps(FloatVector1, PerMatrixScaleVector); - FloatVector2 = _mm_mul_ps(FloatVector2, PerMatrixScaleVector); - FloatVector3 = _mm_mul_ps(FloatVector3, PerMatrixScaleVector); - } - - FloatVector0 = _mm_max_ps(FloatVector0, MinimumValueVector); - FloatVector1 = _mm_max_ps(FloatVector1, MinimumValueVector); - FloatVector2 = _mm_max_ps(FloatVector2, MinimumValueVector); - FloatVector3 = _mm_max_ps(FloatVector3, MinimumValueVector); - - FloatVector0 = _mm_min_ps(FloatVector0, MaximumValueVector); - FloatVector1 = _mm_min_ps(FloatVector1, MaximumValueVector); - FloatVector2 = _mm_min_ps(FloatVector2, MaximumValueVector); - FloatVector3 = _mm_min_ps(FloatVector3, MaximumValueVector); - - IntegerVector0 = _mm_cvtps_epi32(FloatVector0); - IntegerVector1 = _mm_cvtps_epi32(FloatVector1); - IntegerVector2 = _mm_cvtps_epi32(FloatVector2); - IntegerVector3 = _mm_cvtps_epi32(FloatVector3); - - IntegerVector0 = _mm_add_epi32(IntegerVector0, ZeroPointVector); - IntegerVector1 = _mm_add_epi32(IntegerVector1, ZeroPointVector); - IntegerVector2 = _mm_add_epi32(IntegerVector2, ZeroPointVector); - IntegerVector3 = _mm_add_epi32(IntegerVector3, ZeroPointVector); - - __m128i WordVector0; - __m128i WordVector1; - __m128i ByteVector; - - if (std::is_signed::value) { - - WordVector0 = _mm_packs_epi32(IntegerVector0, IntegerVector1); - WordVector1 = _mm_packs_epi32(IntegerVector2, IntegerVector3); - ByteVector = _mm_packs_epi16(WordVector0, WordVector1); - - } else { - - WordVector0 = _mm_packus_epi16(IntegerVector0, IntegerVector1); - WordVector1 = _mm_packus_epi16(IntegerVector2, IntegerVector3); - ByteVector = _mm_packus_epi16(WordVector0, WordVector1); - - } - - _mm_storeu_si128((__m128i*)RowOutput, ByteVector); - RowOutput += 16; - - n -= 16; - } - - // - // Process the remaining columns of the matrices. - // - - while (n > 0) { - - // - // Load the input data and optionally add the per-column bias. 
- // - - __m128i IntegerVector; - - if (n >= 4) { - - IntegerVector = _mm_loadu_si128((const __m128i*)&RowInput[0]); - RowInput += 4; - - if (bias != nullptr) { - IntegerVector = _mm_add_epi32(IntegerVector, _mm_loadu_si128((const __m128i*)&bias[0])); - bias += 4; - } - - } else { - - int32_t IntegerValue = *RowInput++; - - if (bias != nullptr) { - IntegerValue += *bias++; - } - - IntegerVector = _mm_cvtsi32_si128(IntegerValue); - } - - // - // Convert to integer values to float and apply the per-tensor or - // per-column scaling. - // - - __m128 FloatVector = _mm_cvtepi32_ps(IntegerVector); - __m128 ScaleVector; - - if (scale != nullptr) { - - if (n >= 4) { - ScaleVector = _mm_loadu_ps(scale); - scale += 4; - } else { - ScaleVector = _mm_load_ss(scale); - scale += 1; - } - - } else { - ScaleVector = PerMatrixScaleVector; - } - - FloatVector = _mm_mul_ps(FloatVector, ScaleVector); - - FloatVector = _mm_max_ps(FloatVector, MinimumValueVector); - FloatVector = _mm_min_ps(FloatVector, MaximumValueVector); - - IntegerVector = _mm_cvtps_epi32(FloatVector); - IntegerVector = _mm_add_epi32(IntegerVector, ZeroPointVector); - - if (std::is_signed::value) { - - IntegerVector = _mm_packs_epi32(IntegerVector, IntegerVector); - IntegerVector = _mm_packs_epi16(IntegerVector, IntegerVector); - - } else { - - IntegerVector = _mm_packus_epi16(IntegerVector, IntegerVector); - IntegerVector = _mm_packus_epi16(IntegerVector, IntegerVector); - - } - - uint32_t OutputValue = uint32_t(_mm_cvtsi128_si32(IntegerVector)); - - if (n >= 4) { - - *reinterpret_cast(RowOutput) = OutputValue; - RowOutput += 4; - - n -= 4; - - } else { - - *RowOutput = uint8_t(OutputValue); - RowOutput += 1; - - n -= 1; - } - } - - // Next Row - Input += InputLeadingDimension; - Output += OutputLeadingDimension; - } -} - -#elif defined(MLAS_NEON64_INTRINSICS) - -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - OutputType* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - OutputType ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ) -{ - const float32x4_t PerMatrixScaleVector = PerColumnScale ? vdupq_n_f32(0) : vld1q_dup_f32(Scale); - const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); - - if (nullptr != Bias) { - Bias += StartN; - } - if (PerColumnScale) { - Scale += StartN; - } - - Input += StartM * InputLeadingDimension + StartN; - Output += StartM * OutputLeadingDimension + StartN; - - // - // Step through each row of the output matrix. - // - - while (CountM-- > 0) { - - const int32_t* bias = Bias; - const float* scale = PerColumnScale ? Scale : nullptr; - size_t n = CountN; - - auto* RowInput = Input; - auto* RowOutput = Output; - - // - // Process 16 columns of the matrices at a time. - // - - while (n >= 16) { - - // - // Load the input data and optionally add the per-column bias. 
- // - - int32x4x4_t IntegerVector; - - IntegerVector.val[0] = vld1q_s32(&RowInput[0]); - IntegerVector.val[1] = vld1q_s32(&RowInput[4]); - IntegerVector.val[2] = vld1q_s32(&RowInput[8]); - IntegerVector.val[3] = vld1q_s32(&RowInput[12]); - RowInput += 16; - - if (bias != nullptr) { - IntegerVector.val[0] = vaddq_s32(IntegerVector.val[0], vld1q_s32(&bias[0])); - IntegerVector.val[1] = vaddq_s32(IntegerVector.val[1], vld1q_s32(&bias[4])); - IntegerVector.val[2] = vaddq_s32(IntegerVector.val[2], vld1q_s32(&bias[8])); - IntegerVector.val[3] = vaddq_s32(IntegerVector.val[3], vld1q_s32(&bias[12])); - bias += 16; - } - - // - // Convert to integer values to float and apply the per-tensor or - // per-column scaling. - // - - float32x4x4_t FloatVector; - - FloatVector.val[0] = vcvtq_f32_s32(IntegerVector.val[0]); - FloatVector.val[1] = vcvtq_f32_s32(IntegerVector.val[1]); - FloatVector.val[2] = vcvtq_f32_s32(IntegerVector.val[2]); - FloatVector.val[3] = vcvtq_f32_s32(IntegerVector.val[3]); - - if (scale != nullptr) { - - float32x4x4_t PerColumnScaleVector; - - PerColumnScaleVector.val[0] = vld1q_f32(&scale[0]); - PerColumnScaleVector.val[1] = vld1q_f32(&scale[4]); - PerColumnScaleVector.val[2] = vld1q_f32(&scale[8]); - PerColumnScaleVector.val[3] = vld1q_f32(&scale[12]); - scale += 16; - - FloatVector.val[0] = vmulq_f32(FloatVector.val[0], PerColumnScaleVector.val[0]); - FloatVector.val[1] = vmulq_f32(FloatVector.val[1], PerColumnScaleVector.val[1]); - FloatVector.val[2] = vmulq_f32(FloatVector.val[2], PerColumnScaleVector.val[2]); - FloatVector.val[3] = vmulq_f32(FloatVector.val[3], PerColumnScaleVector.val[3]); - - } else { - - FloatVector.val[0] = vmulq_f32(FloatVector.val[0], PerMatrixScaleVector); - FloatVector.val[1] = vmulq_f32(FloatVector.val[1], PerMatrixScaleVector); - FloatVector.val[2] = vmulq_f32(FloatVector.val[2], PerMatrixScaleVector); - FloatVector.val[3] = vmulq_f32(FloatVector.val[3], PerMatrixScaleVector); - } - - // - // Convert the float values to integer using "round to nearest even". - // Results are saturated to the range of int32_t. - // - - IntegerVector.val[0] = vcvtnq_s32_f32(FloatVector.val[0]); - IntegerVector.val[1] = vcvtnq_s32_f32(FloatVector.val[1]); - IntegerVector.val[2] = vcvtnq_s32_f32(FloatVector.val[2]); - IntegerVector.val[3] = vcvtnq_s32_f32(FloatVector.val[3]); - - // - // Pack the integers with saturation to 16-bit values and shift by - // the zero point, then pack the integers again to bytes. - // - - int16x8x2_t WordVector; - - WordVector.val[0] = vqmovn_high_s32(vqmovn_s32(IntegerVector.val[0]), IntegerVector.val[1]); - WordVector.val[1] = vqmovn_high_s32(vqmovn_s32(IntegerVector.val[2]), IntegerVector.val[3]); - - WordVector.val[0] = vqaddq_s16(WordVector.val[0], ZeroPointVector); - WordVector.val[1] = vqaddq_s16(WordVector.val[1], ZeroPointVector); - - if (std::is_signed::value) { - vst1q_s8(reinterpret_cast(RowOutput), - vqmovn_high_s16(vqmovn_s16(WordVector.val[0]), WordVector.val[1])); - } else { - vst1q_u8(reinterpret_cast(RowOutput), - vqmovun_high_s16(vqmovun_s16(WordVector.val[0]), WordVector.val[1])); - } - RowOutput += 16; - - n -= 16; - } - - // - // Process the remaining columns of the matrices. - // - - while (n > 0) { - - // - // Load the input data and optionally add the per-column bias. 
- // - - int32x4_t IntegerVector; - - if (n >= 4) { - - IntegerVector = vld1q_s32(&RowInput[0]); - RowInput += 4; - - if (bias != nullptr) { - IntegerVector = vaddq_s32(IntegerVector, vld1q_s32(&bias[0])); - bias += 4; - } - - } else { - - IntegerVector = vld1q_dup_s32(RowInput); - RowInput += 1; - - if (bias != nullptr) { - IntegerVector = vaddq_s32(IntegerVector, vld1q_dup_s32(bias)); - bias += 1; - } - } - - // - // Convert to integer values to float and apply the per-tensor or - // per-column scaling. - // - - float32x4_t FloatVector = vcvtq_f32_s32(IntegerVector); - float32x4_t ScaleVector; - - if (scale != nullptr) { - - if (n >= 4) { - ScaleVector = vld1q_f32(scale); - scale += 4; - } else { - ScaleVector = vld1q_dup_f32(scale); - scale += 1; - } - - } else { - ScaleVector = PerMatrixScaleVector; - } - - FloatVector = vmulq_f32(FloatVector, ScaleVector); - - // - // Convert the float values to integer using "round to nearest even". - // Results are saturated to the range of int32_t. - // - - IntegerVector = vcvtnq_s32_f32(FloatVector); - - // - // Pack the integers with saturation to 16-bit values and shift by - // the zero point, then pack the integers again to unsigned bytes. - // - - int16x8_t WordVector = vcombine_s16(vqmovn_s32(IntegerVector), vdup_n_s16(0)); - WordVector = vqaddq_s16(WordVector, ZeroPointVector); - - uint8x16_t ByteVector; - - if (std::is_signed::value) { - ByteVector = vcombine_u8(vreinterpret_u8_s8(vqmovn_s16(WordVector)), vdup_n_u8(0)); - } else { - ByteVector = vcombine_u8(vqmovun_s16(WordVector), vdup_n_u8(0)); - } - - if (n >= 4) { - - vst1q_lane_u32(reinterpret_cast(RowOutput), - vreinterpretq_u32_u8(ByteVector), 0); - RowOutput += 4; - - n -= 4; - - } else { - - vst1q_lane_u8(reinterpret_cast(RowOutput), ByteVector, 0); - RowOutput += 1; - - n -= 1; - } - } - - // Next Row - Input += InputLeadingDimension; - Output += OutputLeadingDimension; - } -} - -#elif defined(MLAS_TARGET_POWER) - -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - OutputType* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - OutputType ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ) -{ - float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale; - float MinimumValue = float(std::numeric_limits::lowest() - ZeroPoint); - float MaximumValue = float(std::numeric_limits::max() - ZeroPoint); - - auto PerMatrixScaleVector = vec_splats(PerMatrixScaleValue); - auto MinimumVector = vec_splats(MinimumValue); - auto MaximumVector = vec_splats(MaximumValue); - auto ZeroPointVector = vec_splats(int32_t(ZeroPoint)); - - // Workaround to avoid 'variable set but not used' message - MLAS_UNREFERENCED_PARAMETER(PerMatrixScaleVector); - - if (nullptr != Bias) { - Bias += StartN; - } - if (PerColumnScale) { - Scale += StartN; - } - - Input += StartM * InputLeadingDimension + StartN; - Output += StartM * OutputLeadingDimension + StartN; - - // - // Step through each row of the output matrix. - // - - while (CountM-- > 0) { - - const int32_t* bias = Bias; - const float* scale = PerColumnScale ? 
Scale : nullptr; - size_t n = CountN; - - auto* RowInput = Input; - auto* RowOutput = Output; - - // Process 16 cols at a time - - while (n >= 16) { - - auto IntegerVector0 = vec_xl(0, &RowInput[0]); - auto IntegerVector1 = vec_xl(0, &RowInput[4]); - auto IntegerVector2 = vec_xl(0, &RowInput[8]); - auto IntegerVector3 = vec_xl(0, &RowInput[12]); - RowInput += 16; - - if (bias != nullptr) { - IntegerVector0 = vec_add(IntegerVector0, vec_xl(0, &bias[0])); - IntegerVector1 = vec_add(IntegerVector1, vec_xl(0, &bias[4])); - IntegerVector2 = vec_add(IntegerVector2, vec_xl(0, &bias[8])); - IntegerVector3 = vec_add(IntegerVector3, vec_xl(0, &bias[12])); - bias += 16; - } - - auto FloatVector0 = vec_ctf(IntegerVector0, 0); - auto FloatVector1 = vec_ctf(IntegerVector1, 0); - auto FloatVector2 = vec_ctf(IntegerVector2, 0); - auto FloatVector3 = vec_ctf(IntegerVector3, 0); - - if (scale != nullptr) { - FloatVector0 = vec_mul(FloatVector0, vec_xl(0, &scale[0])); - FloatVector1 = vec_mul(FloatVector1, vec_xl(0, &scale[4])); - FloatVector2 = vec_mul(FloatVector2, vec_xl(0, &scale[8])); - FloatVector3 = vec_mul(FloatVector3, vec_xl(0, &scale[12])); - scale += 16; - } else { - FloatVector0 = vec_mul(FloatVector0, PerMatrixScaleVector); - FloatVector1 = vec_mul(FloatVector1, PerMatrixScaleVector); - FloatVector2 = vec_mul(FloatVector2, PerMatrixScaleVector); - FloatVector3 = vec_mul(FloatVector3, PerMatrixScaleVector); - } - - FloatVector0 = vec_max(FloatVector0, MinimumVector); - FloatVector1 = vec_max(FloatVector1, MinimumVector); - FloatVector2 = vec_max(FloatVector2, MinimumVector); - FloatVector3 = vec_max(FloatVector3, MinimumVector); - - FloatVector0 = vec_min(FloatVector0, MaximumVector); - FloatVector1 = vec_min(FloatVector1, MaximumVector); - FloatVector2 = vec_min(FloatVector2, MaximumVector); - FloatVector3 = vec_min(FloatVector3, MaximumVector); - - FloatVector0 = vec_round(FloatVector0); - FloatVector1 = vec_round(FloatVector1); - FloatVector2 = vec_round(FloatVector2); - FloatVector3 = vec_round(FloatVector3); - - auto IntegerOutVector0 = vec_signed(FloatVector0); - auto IntegerOutVector1 = vec_signed(FloatVector1); - auto IntegerOutVector2 = vec_signed(FloatVector2); - auto IntegerOutVector3 = vec_signed(FloatVector3); - - IntegerOutVector0 = vec_add(IntegerOutVector0, ZeroPointVector); - IntegerOutVector1 = vec_add(IntegerOutVector1, ZeroPointVector); - IntegerOutVector2 = vec_add(IntegerOutVector2, ZeroPointVector); - IntegerOutVector3 = vec_add(IntegerOutVector3, ZeroPointVector); - - auto ShortVector0 = vec_pack(IntegerOutVector0, IntegerOutVector1); - auto ShortVector1 = vec_pack(IntegerOutVector2, IntegerOutVector3); - auto CharVector = vec_pack(ShortVector0, ShortVector1); - - vec_xst(CharVector, 0, (int8_t *) RowOutput); - RowOutput += 16; - n -= 16; - } - - while (n >= 4) { - int8_t OutputBuffer[16]; - - auto IntegerVector = vec_xl(0, &RowInput[0]); - RowInput += 4; - - if (bias != nullptr) { - IntegerVector = vec_add(IntegerVector, vec_xl(0, &bias[0])); - bias += 4; - } - - auto FloatVector = vec_ctf(IntegerVector, 0); - - if (scale != nullptr) { - FloatVector = vec_mul(FloatVector, vec_xl(0, scale)); - scale += 4; - } else { - FloatVector = vec_mul(FloatVector, PerMatrixScaleVector); - } - - FloatVector = vec_max(FloatVector, MinimumVector); - FloatVector = vec_min(FloatVector, MaximumVector); - FloatVector = vec_round(FloatVector); - - auto IntegerOutVector = vec_signed(FloatVector); - IntegerOutVector = vec_add(IntegerOutVector, ZeroPointVector); - - auto ShortVector = 
vec_pack(IntegerOutVector, vec_splats((int32_t) 0)); - auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); - - vec_xst(CharVector, 0, OutputBuffer); - memcpy(RowOutput, OutputBuffer, 4); - - RowOutput += 4; - n -= 4; - } - - while (n > 0) { - auto IntegerValue = RowInput[0]; - RowInput += 1; - - if (bias != nullptr) { - IntegerValue += bias[0]; - bias += 1; - } - - float FloatValue = float(IntegerValue); - float ScaleValue = PerColumnScale ? *scale++ : PerMatrixScaleValue; - - FloatValue *= ScaleValue; - FloatValue = std::max(FloatValue, MinimumValue); - FloatValue = std::min(FloatValue, MaximumValue); - - IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) - - MLAS_ROUNDING_BIAS_MAGIC_BITS; - - *RowOutput++ = OutputType(IntegerValue + ZeroPoint); - - n -= 1; - } - - // Next Row - Input += InputLeadingDimension; - Output += OutputLeadingDimension; - } -} - -#elif defined(MLAS_LSX_INTRINSICS) - -template -void -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - OutputType* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - OutputType ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ) -{ - //TO BE CHECK - float min_f = float(std::numeric_limits::lowest() - ZeroPoint); - float max_f = float(std::numeric_limits::max() - ZeroPoint); - const __m128 PerMatrixScaleVector = PerColumnScale ? MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0)); - const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&min_f))); - const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&max_f))); - const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint); - - if (nullptr != Bias) { - Bias += StartN; - } - if (PerColumnScale) { - Scale += StartN; - } - - Input += StartM * InputLeadingDimension + StartN; - Output += StartM * OutputLeadingDimension + StartN; - // - // Step through each row of the output matrix. - // - - while (CountM-- > 0) { - - const int32_t* bias = Bias; - const float* scale = PerColumnScale ? Scale : nullptr; - size_t n = CountN; - - auto* RowInput = Input; - auto* RowOutput = Output; - - // - // Process 16 columns of the matrices at a time. - // - - while (n >= 16) { - - // - // Load the input data and optionally add the per-column bias. - // - - __m128i IntegerVector0 = __lsx_vld((const __m128i*)&RowInput[0], 0); - __m128i IntegerVector1 = __lsx_vld((const __m128i*)&RowInput[4], 0); - __m128i IntegerVector2 = __lsx_vld((const __m128i*)&RowInput[8], 0); - __m128i IntegerVector3 = __lsx_vld((const __m128i*)&RowInput[12], 0); - RowInput += 16; - - if (bias != nullptr) { - IntegerVector0 = __lsx_vadd_w(IntegerVector0, __lsx_vld((const __m128i *)&bias[0], 0)); - IntegerVector1 = __lsx_vadd_w(IntegerVector1, __lsx_vld((const __m128i *)&bias[4], 0)); - IntegerVector2 = __lsx_vadd_w(IntegerVector2, __lsx_vld((const __m128i *)&bias[8], 0)); - IntegerVector3 = __lsx_vadd_w(IntegerVector3, __lsx_vld((const __m128i *)&bias[12], 0)); - bias += 16; - } - - // - // Convert to integer values to float and apply the per-tensor or - // per-column scaling. 
- // - - __m128 FloatVector0 = __lsx_vffint_s_w(IntegerVector0); - __m128 FloatVector1 = __lsx_vffint_s_w(IntegerVector1); - __m128 FloatVector2 = __lsx_vffint_s_w(IntegerVector2); - __m128 FloatVector3 = __lsx_vffint_s_w(IntegerVector3); - - if (scale != nullptr) { - - FloatVector0 = __lsx_vfmul_s(FloatVector0, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[0], 0))); - FloatVector1 = __lsx_vfmul_s(FloatVector1, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[4], 0))); - FloatVector2 = __lsx_vfmul_s(FloatVector2, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[8], 0))); - FloatVector3 = __lsx_vfmul_s(FloatVector3, MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)&scale[12], 0))); - scale += 16; - - } else { - - FloatVector0 = __lsx_vfmul_s(FloatVector0, PerMatrixScaleVector); - FloatVector1 = __lsx_vfmul_s(FloatVector1, PerMatrixScaleVector); - FloatVector2 = __lsx_vfmul_s(FloatVector2, PerMatrixScaleVector); - FloatVector3 = __lsx_vfmul_s(FloatVector3, PerMatrixScaleVector); - } - FloatVector0 = __lsx_vfmax_s(FloatVector0, MinimumValueVector); - FloatVector1 = __lsx_vfmax_s(FloatVector1, MinimumValueVector); - FloatVector2 = __lsx_vfmax_s(FloatVector2, MinimumValueVector); - FloatVector3 = __lsx_vfmax_s(FloatVector3, MinimumValueVector); - - FloatVector0 = __lsx_vfmin_s(FloatVector0, MaximumValueVector); - FloatVector1 = __lsx_vfmin_s(FloatVector1, MaximumValueVector); - FloatVector2 = __lsx_vfmin_s(FloatVector2, MaximumValueVector); - FloatVector3 = __lsx_vfmin_s(FloatVector3, MaximumValueVector); - - IntegerVector0 = __lsx_vftint_w_s(FloatVector0); - IntegerVector1 = __lsx_vftint_w_s(FloatVector1); - IntegerVector2 = __lsx_vftint_w_s(FloatVector2); - IntegerVector3 = __lsx_vftint_w_s(FloatVector3); - - IntegerVector0 = __lsx_vadd_w(IntegerVector0, ZeroPointVector); - IntegerVector1 = __lsx_vadd_w(IntegerVector1, ZeroPointVector); - IntegerVector2 = __lsx_vadd_w(IntegerVector2, ZeroPointVector); - IntegerVector3 = __lsx_vadd_w(IntegerVector3, ZeroPointVector); - - __m128i WordVector0; - __m128i WordVector1; - __m128i ByteVector; - - if (std::is_signed::value) { - - __m128i tmp, tmp1; - tmp = __lsx_vsat_w(IntegerVector0, 15); - tmp1 = __lsx_vsat_w(IntegerVector1, 15); - WordVector0 = __lsx_vpickev_h(tmp1, tmp); - - tmp = __lsx_vsat_w(IntegerVector2, 15); - tmp1 = __lsx_vsat_w(IntegerVector3, 15); - WordVector1 = __lsx_vpickev_h(tmp1, tmp); - - tmp = __lsx_vsat_h(WordVector0, 7); - tmp1 = __lsx_vsat_h(WordVector1, 7); - ByteVector = __lsx_vpickev_b(tmp1, tmp); - - - } else { - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2, tmp3; - - tmp = __lsx_vmax_h(IntegerVector0, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(IntegerVector1, zero); - tmp3 = __lsx_vsat_hu(tmp, 7); - WordVector0 = __lsx_vpickev_b(tmp3, tmp2); - - tmp = __lsx_vmax_h(IntegerVector2, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(IntegerVector3, zero); - tmp3 = __lsx_vsat_hu(tmp, 7); - WordVector1 = __lsx_vpickev_b(tmp3, tmp2); - - tmp = __lsx_vmax_h(WordVector0, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(WordVector1, zero); - tmp3 = __lsx_vsat_hu(tmp, 7); - ByteVector = __lsx_vpickev_b(tmp3, tmp2); - - } - - __lsx_vst(ByteVector, (__m128i*)RowOutput, 0); - RowOutput += 16; - - n -= 16; - } - - // - // Process the remaining columns of the matrices. - // - - while (n > 0) { - - // - // Load the input data and optionally add the per-column bias. 
- // - - __m128i IntegerVector; - - if (n >= 4) { - - IntegerVector = __lsx_vld((const __m128i*)&RowInput[0], 0); - RowInput += 4; - - if (bias != nullptr) { - IntegerVector = __lsx_vadd_w(IntegerVector, __lsx_vld((const __m128i*)&bias[0], 0)); - bias += 4; - } - - } else { - - int32_t IntegerValue = *RowInput++; - - if (bias != nullptr) { - IntegerValue += *bias++; - } - IntegerVector = __lsx_vldrepl_w(&IntegerValue, 0); - } - - // - // Convert to integer values to float and apply the per-tensor or - // per-column scaling. - // - __m128 FloatVector = __lsx_vffint_s_w(IntegerVector); - __m128 ScaleVector; - - if (scale != nullptr) { - - if (n >= 4) { - ScaleVector = MlasReinterpretAsFloat32x4(__lsx_vld((__m128i *)scale, 0)); - scale += 4; - } else { - ScaleVector = (__m128)__lsx_vldrepl_w(scale, 0); - scale += 1; - } - - } else { - ScaleVector = PerMatrixScaleVector; - } - FloatVector = __lsx_vfmul_s(FloatVector, ScaleVector); - - FloatVector = __lsx_vfmax_s(FloatVector, MinimumValueVector); - FloatVector = __lsx_vfmin_s(FloatVector, MaximumValueVector); - - IntegerVector = __lsx_vftint_w_s(FloatVector); - IntegerVector = __lsx_vadd_w(IntegerVector, ZeroPointVector); - - if (std::is_signed::value) { - - __m128i tmp; - tmp = __lsx_vsat_w(IntegerVector, 15); - IntegerVector = __lsx_vpickev_h(tmp, tmp); - - tmp = __lsx_vsat_h(IntegerVector, 7); - IntegerVector = __lsx_vpickev_b(tmp, tmp); - - } else { - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2; - - tmp = __lsx_vmax_h(IntegerVector, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - IntegerVector = __lsx_vpickev_b(tmp2, tmp2); - - tmp = __lsx_vmax_h(IntegerVector, zero); - tmp2 = __lsx_vsat_hu(tmp, 7); - IntegerVector = __lsx_vpickev_b(tmp2, tmp2); - - } - - uint32_t OutputValue = uint32_t(__lsx_vpickve2gr_w(IntegerVector, 0)); - - if (n >= 4) { - - *reinterpret_cast(RowOutput) = OutputValue; - RowOutput += 4; - - n -= 4; - - } else { - - *RowOutput = uint8_t(OutputValue); - RowOutput += 1; - - n -= 1; - } - } - - // Next Row - Input += InputLeadingDimension; - Output += OutputLeadingDimension; - } -} - -#else - -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - OutputType* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - OutputType ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ) -{ - const float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale; - const float MinimumValue = float(std::numeric_limits::lowest() - ZeroPoint); - const float MaximumValue = float(std::numeric_limits::max() - ZeroPoint); - - if (nullptr != Bias) { - Bias += StartN; - } - if (PerColumnScale) { - Scale += StartN; - } - - Input += StartM * InputLeadingDimension + StartN; - Output += StartM * OutputLeadingDimension + StartN; - - // - // Step through each row of the output matrix. - // - - while (CountM-- > 0) { - - const int32_t* bias = Bias; - const float* scale = Scale; - size_t n = CountN; - - auto* RowInput = Input; - auto* RowOutput = Output; - - while (n > 0) { - - int32_t IntegerValue = *RowInput++; - - if (bias != nullptr) { - IntegerValue += *bias++; - } - - float FloatValue = float(IntegerValue); - float ScaleValue = PerColumnScale ? 
*scale++ : PerMatrixScaleValue; - - FloatValue *= ScaleValue; - FloatValue = std::max(FloatValue, MinimumValue); - FloatValue = std::min(FloatValue, MaximumValue); - - // - // Use the fast rounding trick adapted from XNNPACK: bias the floating - // point value by the first floating point value that has no - // fractional bits. The add operation performs the "round to nearest - // even". Extract the mantissa bits from this floating point value to - // obtain the rounded integer value. - // - - IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) - - MLAS_ROUNDING_BIAS_MAGIC_BITS; - - *RowOutput++ = OutputType(IntegerValue + ZeroPoint); - - n -= 1; - } - - // Next Row - Input += InputLeadingDimension; - Output += OutputLeadingDimension; - } -} - -#endif - -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - int8_t* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - int8_t ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ); - -template -void -MLASCALL -MlasRequantizeOutput( - const int32_t* Input, - size_t InputLeadingDimension, - uint8_t* Output, - size_t OutputLeadingDimension, - const int32_t* Bias, - const float* Scale, - bool PerColumnScale, - uint8_t ZeroPoint, - size_t StartM, - size_t StartN, - size_t CountM, - size_t CountN - ); - -void -MLASCALL -MlasFindMinMaxElement( - const float* Input, - float* Min, - float* Max, - size_t N - ) -/*++ - -Routine Description: - - This routine finds the minimum and maximum values of the supplied buffer. - -Arguments: - - Input - Supplies the input buffer. - - Min - Returns the minimum value of the supplied buffer. - - Max - Returns the maximum value of the supplied buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().ReduceMinimumMaximumF32Kernel(Input, Min, Max, N); -#else - MlasReduceMinimumMaximumF32Kernel(Input, Min, Max, N); -#endif -} diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp deleted file mode 100644 index b329ea2ffb149..0000000000000 --- a/onnxruntime/core/mlas/lib/reorder.cpp +++ /dev/null @@ -1,980 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - reorder.cpp - -Abstract: - - This module implements routines to reorder buffers to and from blocked - formats. - ---*/ - -#include "mlasi.h" - -// -// Define the parameters to execute segments of a NCHW output reordering -// operation on worker threads. -// - -struct MLAS_REORDER_OUTPUT_NCHW_BLOCK { - ptrdiff_t TargetThreadCount; - const float* S; - float* D; - size_t OutputChannels; - size_t OutputSize; - size_t TasksCount; -}; - -MLAS_FORCEINLINE -void -MlasReorderGatherFloat32x4( - const float* S, - float* D, - size_t GatherStride - ) -/*++ - -Routine Description: - - This routine gathers floats from the source buffer and writes a vector to - the destination buffer. - -Arguments: - - S - Supplies the address of the source buffer. - - D - Supplies the address of the destination buffer. - - GatherStride - Supplies the stride to read elements from the source buffer. - -Return Value: - - None. 
- ---*/ -{ -#if defined(MLAS_SSE41_INTRINSICS) - __m128 v = _mm_load_ss(&S[0 * GatherStride]); - v = _mm_insert_ps(v, _mm_load_ss(&S[1 * GatherStride]), 0x10); - v = _mm_insert_ps(v, _mm_load_ss(&S[2 * GatherStride]), 0x20); - v = _mm_insert_ps(v, _mm_load_ss(&S[3 * GatherStride]), 0x30); - - _mm_storeu_ps(D, v); -#else - float f0 = S[0 * GatherStride]; - float f1 = S[1 * GatherStride]; - float f2 = S[2 * GatherStride]; - float f3 = S[3 * GatherStride]; - - D[0] = f0; - D[1] = f1; - D[2] = f2; - D[3] = f3; -#endif -} - -MLAS_FORCEINLINE -void -MlasReorderScatterFloat32x4( - const float* S, - float* D, - size_t ScatterStride - ) -/*++ - -Routine Description: - - This routine scatters a vector read from the source buffer to the - destination buffer. - -Arguments: - - S - Supplies the address of the source buffer. - - D - Supplies the address of the destination buffer. - - ScatterStride - Supplies the stride to write elements to the destination - buffer. - -Return Value: - - None. - ---*/ -{ -#if defined(MLAS_SSE41_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) - MLAS_FLOAT32X4 v = MlasLoadFloat32x4(S); - - MlasStoreLaneFloat32x4<0>(&D[ScatterStride * 0], v); - MlasStoreLaneFloat32x4<1>(&D[ScatterStride * 1], v); - MlasStoreLaneFloat32x4<2>(&D[ScatterStride * 2], v); - MlasStoreLaneFloat32x4<3>(&D[ScatterStride * 3], v); -#else - float f0 = S[0]; - float f1 = S[1]; - float f2 = S[2]; - float f3 = S[3]; - - D[ScatterStride * 0] = f0; - D[ScatterStride * 1] = f1; - D[ScatterStride * 2] = f2; - D[ScatterStride * 3] = f3; -#endif -} - -MLAS_FORCEINLINE -void -MlasReorderTransposeFloat32x4x4( - const float* S, - float* D, - size_t GatherStride, - size_t ScatterStride - ) -/*++ - -Routine Description: - - This routine transposes a 4x4 float matrix read from the source buffer and - writes the result to the destination buffer. - -Arguments: - - S - Supplies the address of the source buffer. - - D - Supplies the address of the destination buffer. - - GatherStride - Supplies the stride to read elements from the source buffer. - - ScatterStride - Supplies the stride to write vectors to the destination - buffer. - -Return Value: - - None. 
- ---*/ -{ -#if defined(MLAS_SSE2_INTRINSICS) - MLAS_FLOAT32X4 v[4]; - MLAS_FLOAT32X4 t[4]; - - v[0] = MlasLoadFloat32x4(&S[GatherStride * 0]); - v[1] = MlasLoadFloat32x4(&S[GatherStride * 1]); - v[2] = MlasLoadFloat32x4(&S[GatherStride * 2]); - v[3] = MlasLoadFloat32x4(&S[GatherStride * 3]); - - t[0] = _mm_unpacklo_ps(v[0], v[1]); - t[2] = _mm_unpackhi_ps(v[0], v[1]); - t[1] = _mm_unpacklo_ps(v[2], v[3]); - t[3] = _mm_unpackhi_ps(v[2], v[3]); - - v[0] = _mm_movelh_ps(t[0], t[1]); - v[1] = _mm_movehl_ps(t[1], t[0]); - v[2] = _mm_movelh_ps(t[2], t[3]); - v[3] = _mm_movehl_ps(t[3], t[2]); - - MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); - MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); - MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); - MlasStoreFloat32x4(&D[ScatterStride * 3], v[3]); -#elif defined(MLAS_LSX_INTRINSICS) - - MLAS_FLOAT32X4 v[4]; - MLAS_FLOAT32X4 t[4]; - - v[0] = MlasLoadFloat32x4(&S[GatherStride * 0]); - v[1] = MlasLoadFloat32x4(&S[GatherStride * 1]); - v[2] = MlasLoadFloat32x4(&S[GatherStride * 2]); - v[3] = MlasLoadFloat32x4(&S[GatherStride * 3]); - - t[0] = (__m128)__lsx_vilvl_w((__m128i)v[1], (__m128i)v[0]); - t[2] = (__m128)__lsx_vilvh_w((__m128i)v[1], (__m128i)v[0]); - t[1] = (__m128)__lsx_vilvl_w((__m128i)v[3], (__m128i)v[2]); - t[3] = (__m128)__lsx_vilvh_w((__m128i)v[3], (__m128i)v[2]); - - - v[0] = (__m128)__lsx_vpickev_d((__m128i) t[1],(__m128i) t[0]); - v[1] = (__m128)__lsx_vpickod_d((__m128i) t[1],(__m128i) t[0]); - v[2] = (__m128)__lsx_vpickev_d((__m128i) t[3],(__m128i) t[2]); - v[3] = (__m128)__lsx_vpickod_d((__m128i) t[3],(__m128i) t[2]); - - MlasStoreFloat32x4(&D[ScatterStride * 0], v[0]); - MlasStoreFloat32x4(&D[ScatterStride * 1], v[1]); - MlasStoreFloat32x4(&D[ScatterStride * 2], v[2]); - MlasStoreFloat32x4(&D[ScatterStride * 3], v[3]); -#else - MlasReorderScatterFloat32x4(&S[GatherStride * 0], &D[0], ScatterStride); - MlasReorderScatterFloat32x4(&S[GatherStride * 1], &D[1], ScatterStride); - MlasReorderScatterFloat32x4(&S[GatherStride * 2], &D[2], ScatterStride); - MlasReorderScatterFloat32x4(&S[GatherStride * 3], &D[3], ScatterStride); -#endif -} - -void -MLASCALL -MlasReorderInputNchw( - const float* S, - float* D, - size_t InputChannels, - size_t InputSize - ) -/*++ - -Routine Description: - - This routine reorders an input buffer from NCHW to NCHWc format. - -Arguments: - - S - Supplies the address of the source tensor. - - D - Supplies the address of the destination tensor. - - InputChannels - Supplies the number of NCHW channels. - - InputSize - Supplies the spatial input size of the tensors. - -Return Value: - - None. - ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - // - // Iterate over BlockSize batches of the input channels. 
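    // (Illustrative aside, not part of the original file: the NCHWc layout produced
    //  below places element (channel c, spatial index p) of the NCHW source at
    //  destination offset (c / BlockSize) * BlockSize * InputSize + p * BlockSize + (c % BlockSize),
    //  i.e. channels are grouped into blocks of BlockSize and interleaved per spatial
    //  position; channels past InputChannels in the final block are zero-filled.)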
- // - - for (size_t i = InputChannels; i > 0;) { - - const size_t InputChannelsThisIteration = std::min(i, BlockSize); - i -= InputChannelsThisIteration; - - const float* s = S; - float* d = D; - size_t InputSizeRemaining = InputSize; - - for (; InputSizeRemaining >= 4; InputSizeRemaining -= 4) { - - const float* ss = s; - float* dd = d; - size_t bc = 0; - - for (; bc < InputChannelsThisIteration; bc += 4) { - MlasReorderTransposeFloat32x4x4(ss, dd, InputSize, BlockSize); - ss += 4 * InputSize; - dd += 4; - } - - for (; bc < BlockSize; bc += 4) { - MlasStoreFloat32x4(&dd[BlockSize * 0], ZeroFloat32x4); - MlasStoreFloat32x4(&dd[BlockSize * 1], ZeroFloat32x4); - MlasStoreFloat32x4(&dd[BlockSize * 2], ZeroFloat32x4); - MlasStoreFloat32x4(&dd[BlockSize * 3], ZeroFloat32x4); - dd += 4; - } - - s += 4; - d += 4 * BlockSize; - } - - for (; InputSizeRemaining > 0; InputSizeRemaining--) { - - const float* ss = s; - float* dd = d; - size_t bc = 0; - - for (; bc < InputChannelsThisIteration; bc += 4) { - MlasReorderGatherFloat32x4(ss, dd, InputSize); - ss += 4 * InputSize; - dd += 4; - } - - for (; bc < BlockSize; bc += 4) { - MlasStoreFloat32x4(dd, ZeroFloat32x4); - dd += 4; - } - - s += 1; - d += BlockSize; - } - - S += BlockSize * InputSize; - D += BlockSize * InputSize; - } -} - -void -MLASCALL -MlasReorderInputNhwc( - const float* S, - float* D, - size_t InputChannels, - size_t RowCount, - size_t FullRowCount - ) -/*++ - -Routine Description: - - This routine reorders an input buffer from NHWC to NCHWc format. - -Arguments: - - S - Supplies the address of the source tensor. - - D - Supplies the address of the destination tensor. - - InputChannels - Supplies the number of NHWC channels. - - RowCount - Supplies the number of NHWC rows to process. This number may be - less than FullRowCount to support threaded operation. - - FullRowCount - Supplies the total number of NHWC rows per image. - -Return Value: - - None. - ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - // - // Iterate over batches of the input size to improve locality. - // - - for (size_t OuterRowCountRemaining = RowCount; OuterRowCountRemaining > 0; ) { - - constexpr size_t OuterRowCountBatch = 32; - - const size_t OuterRowCountThisIteration = std::min(OuterRowCountRemaining, OuterRowCountBatch); - OuterRowCountRemaining -= OuterRowCountThisIteration; - - // - // Iterate over BlockSize batches of the input channels. 
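The NHWC variant below targets the same NCHWc destination layout: for each processed row r, channel c lands in channel block c / BlockSize at lane c % BlockSize, with FullRowCount supplying the per-block row pitch, and the lane tail past InputChannels zero-filled. A naive sketch of that mapping (names illustrative):

#include <cstddef>

// Naive NHWC -> NCHWc mapping for RowCount source rows.
static void ReorderInputNhwcReference(const float* S, float* D,
                                      std::size_t InputChannels,
                                      std::size_t RowCount,
                                      std::size_t FullRowCount,
                                      std::size_t BlockSize)
{
    const std::size_t ChannelBlocks = (InputChannels + BlockSize - 1) / BlockSize;

    for (std::size_t r = 0; r < RowCount; ++r) {
        for (std::size_t Block = 0; Block < ChannelBlocks; ++Block) {
            for (std::size_t Lane = 0; Lane < BlockSize; ++Lane) {
                const std::size_t c = Block * BlockSize + Lane;
                const float Value = (c < InputChannels) ? S[r * InputChannels + c] : 0.0f;
                D[(Block * FullRowCount + r) * BlockSize + Lane] = Value;
            }
        }
    }
}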
- // - - const float* s = S; - float* d = D; - - for (size_t i = InputChannels; i > 0;) { - - const size_t InputChannelsThisIteration = std::min(i, BlockSize); - i -= InputChannelsThisIteration; - - const float* ss = s; - float* dd = d; - size_t InnerRowCountRemaining = OuterRowCountThisIteration; - - if (InputChannelsThisIteration == BlockSize) { - - if (BlockSize == 8) { - - while (InnerRowCountRemaining-- > 0) { - - MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]); - MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]); - - MlasStoreFloat32x4(&dd[0], v0); - MlasStoreFloat32x4(&dd[4], v1); - - ss += InputChannels; - dd += 8; - } - - } else { - - while (InnerRowCountRemaining-- > 0) { - - MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]); - MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]); - MLAS_FLOAT32X4 v2 = MlasLoadFloat32x4(&ss[8]); - MLAS_FLOAT32X4 v3 = MlasLoadFloat32x4(&ss[12]); - - MlasStoreFloat32x4(&dd[0], v0); - MlasStoreFloat32x4(&dd[4], v1); - MlasStoreFloat32x4(&dd[8], v2); - MlasStoreFloat32x4(&dd[12], v3); - - ss += InputChannels; - dd += 16; - } - } - - } else { - - size_t BlockPadding = BlockSize - InputChannelsThisIteration; - - while (InnerRowCountRemaining-- > 0) { - - std::copy_n(ss, InputChannelsThisIteration, dd); - std::fill_n(dd + InputChannelsThisIteration, BlockPadding, 0.0f); - - ss += InputChannels; - dd += BlockSize; - } - } - - s += InputChannelsThisIteration; - d += BlockSize * FullRowCount; - } - - S += InputChannels * OuterRowCountThisIteration; - D += BlockSize * OuterRowCountThisIteration; - } -} - -void -MlasReorderOutputNchwThreaded( - void* Context, - ptrdiff_t Index - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - NCHW output reordering operation. - -Arguments: - - Context - Supplies the pointer to the context for the threaded operation. - - Index - Supplies the current index of the threaded operation. - -Return Value: - - None. - ---*/ -{ - const auto* WorkBlock = (MLAS_REORDER_OUTPUT_NCHW_BLOCK*)Context; - - const size_t OutputChannels = WorkBlock->OutputChannels; - const size_t OutputSize = WorkBlock->OutputSize; - const float* S = WorkBlock->S; - float* D = WorkBlock->D; - - const size_t BlockSize = MlasNchwcGetBlockSize(); - const size_t TasksPerBatch = size_t(ceil(((float)OutputChannels) / BlockSize)); - const size_t LastTaskInBatchIndex = TasksPerBatch - 1; - - // - // Compute the range of task indices to use for this thread. - // - - size_t TaskStart; - size_t TasksRemaining; - - MlasPartitionWork(Index, WorkBlock->TargetThreadCount, WorkBlock->TasksCount, - &TaskStart, &TasksRemaining); - - size_t TaskEnd = TaskStart + TasksRemaining; - // - // Rebase the pointers to the source and destination buffers for this thread. - // - - size_t FirstBatchIndex = TaskStart / TasksPerBatch; - size_t FirstTaskInBatchIndex = TaskStart % TasksPerBatch; - S += BlockSize * OutputSize * (FirstBatchIndex * TasksPerBatch + FirstTaskInBatchIndex); - D += OutputSize * (FirstBatchIndex * OutputChannels + BlockSize * FirstTaskInBatchIndex); - - // - // Transpose NCHWc blocks associated with tasks in the range [TaskStart, TaskEnd) - // from the source buffer to the destination buffer. - // - - for (size_t t = TaskStart; t < TaskEnd; t++) { - size_t TaskInBatchIndex = t % TasksPerBatch; - - const size_t OutputChannelsThisIteration = (TaskInBatchIndex < LastTaskInBatchIndex) ? 
- BlockSize : OutputChannels - BlockSize * LastTaskInBatchIndex; - const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); - - const float* s = S; - float* d = D; - size_t OutputSizeRemaining = OutputSize; - - for (; OutputSizeRemaining >= 4; OutputSizeRemaining -= 4) { - - const float* ss = s; - float* dd = d; - size_t bc = 0; - - for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { - MlasReorderTransposeFloat32x4x4(ss, dd, BlockSize, OutputSize); - ss += 4; - dd += 4 * OutputSize; - } - - for (; bc < OutputChannelsThisIteration; bc += 1) { - MlasReorderGatherFloat32x4(ss, dd, BlockSize); - ss += 1; - dd += OutputSize; - } - - s += 4 * BlockSize; - d += 4; - } - - for (; OutputSizeRemaining > 0; OutputSizeRemaining--) { - - const float* ss = s; - float* dd = d; - size_t bc = 0; - - for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { - MlasReorderScatterFloat32x4(ss, dd, OutputSize); - ss += 4; - dd += 4 * OutputSize; - } - - for (; bc < OutputChannelsThisIteration; bc += 1) { - *dd = *ss++; - dd += OutputSize; - } - - s += BlockSize; - d += 1; - } - - S += BlockSize * OutputSize; - D += OutputChannelsThisIteration * OutputSize; - } -} - - -void -MLASCALL -MlasReorderOutputNchw( - const int64_t* OutputShape, - const float* S, - float* D, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine reorders an output buffer from NCHWc to NCHW format. - -Arguments: - - OutputShape - Supplies the shape of the output tensor. - - S - Supplies the address of the source tensor. - - D - Supplies the address of the destination tensor. - -Return Value: - - None. - ---*/ -{ - MLAS_REORDER_OUTPUT_NCHW_BLOCK WorkBlock; - - // - // Capture the NCHW reorder output operation parameters to the work block. - // - - WorkBlock.S = S; - WorkBlock.D = D; - WorkBlock.OutputChannels = size_t(OutputShape[1]); - WorkBlock.OutputSize = size_t(OutputShape[2]) * size_t(OutputShape[3]); - - const size_t BlockSize = MlasNchwcGetBlockSize(); - const size_t TasksPerBatch = size_t(ceil(((float)WorkBlock.OutputChannels) / BlockSize)); - const size_t BatchCount = size_t(OutputShape[0]); - const size_t TasksCount = BatchCount * TasksPerBatch; - WorkBlock.TasksCount = TasksCount; - - // - // Schedule the operation across a set of worker threads if the output - // tensor is sufficienly large. Limit the number of threads to at least - // the number of available tasks. - // - - ptrdiff_t TargetThreadCount = 1; - const size_t BufferSize = BatchCount * WorkBlock.OutputChannels * WorkBlock.OutputSize; - if (BufferSize > 1024 && TasksCount > 1) { - TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); - if (size_t(TargetThreadCount) > TasksCount) { - TargetThreadCount = ptrdiff_t(TasksCount); - } - } - WorkBlock.TargetThreadCount = TargetThreadCount; - - MlasExecuteThreaded(MlasReorderOutputNchwThreaded, &WorkBlock, TargetThreadCount, ThreadPool); -} - -void -MLASCALL -MlasReorderOutputNhwc( - const int64_t* OutputShape, - const float* S, - float* D - ) -/*++ - -Routine Description: - - This routine reorders an output buffer from NCHWc to NHWC format. - -Arguments: - - OutputShape - Supplies the shape of the output tensor. - - S - Supplies the address of the source tensor. - - D - Supplies the address of the destination tensor. - -Return Value: - - None. 
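The threaded NCHW output reorder above relies on MlasPartitionWork to hand each worker a contiguous, non-overlapping range of tasks. That helper is not part of this diff; the sketch below assumes the usual even split in which the first TotalWork % ThreadCount workers take one extra item, which is enough to see that every task index is covered exactly once:

#include <cstddef>
#include <cstdio>

// Assumed even-split semantics for partitioning TotalWork items over ThreadCount workers.
static void PartitionWorkSketch(std::size_t ThreadId, std::size_t ThreadCount,
                                std::size_t TotalWork,
                                std::size_t* WorkIndex, std::size_t* WorkRemaining)
{
    const std::size_t WorkPerThread = TotalWork / ThreadCount;
    const std::size_t WorkExtra = TotalWork % ThreadCount;

    if (ThreadId < WorkExtra) {
        *WorkIndex = (WorkPerThread + 1) * ThreadId;
        *WorkRemaining = WorkPerThread + 1;
    } else {
        *WorkIndex = WorkPerThread * ThreadId + WorkExtra;
        *WorkRemaining = WorkPerThread;
    }
}

int main()
{
    // Example: OutputChannels = 8, BlockSize = 16, BatchCount = 3 gives
    // TasksPerBatch = 1 and TasksCount = 3, split over two workers as [0,2) and [2,3).
    const std::size_t OutputChannels = 8, BlockSize = 16, BatchCount = 3;
    const std::size_t TasksPerBatch = (OutputChannels + BlockSize - 1) / BlockSize;
    const std::size_t TasksCount = BatchCount * TasksPerBatch;

    for (std::size_t Tid = 0; Tid < 2; ++Tid) {
        std::size_t Start, Count;
        PartitionWorkSketch(Tid, 2, TasksCount, &Start, &Count);
        std::printf("worker %zu: tasks [%zu, %zu)\n", Tid, Start, Start + Count);
    }
    return 0;
}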
- ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - const size_t BatchCount = size_t(OutputShape[0]); - const size_t OutputChannels = size_t(OutputShape[3]); - const size_t OutputSize = size_t(OutputShape[1]) * size_t(OutputShape[2]); - - const size_t AlignedOutputChannels = (OutputChannels + BlockSize - 1) & ~(BlockSize - 1); - - // - // Copy NCHWc blocks from the source buffer to the destination buffer. - // - - for (size_t batch = 0; batch < BatchCount; batch++) { - - const float* s = S; - size_t OutputSizeRemaining = OutputSize; - - for (; OutputSizeRemaining > 0; OutputSizeRemaining--) { - - const float* ss = s; - - for (size_t o = OutputChannels; o > 0;) { - - const size_t OutputChannelsThisIteration = std::min(o, BlockSize); - const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); - o -= OutputChannelsThisIteration; - - size_t bc = 0; - - for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { - MlasStoreFloat32x4(&D[bc], MlasLoadFloat32x4(&ss[bc])); - } - - for (; bc < OutputChannelsThisIteration; bc += 1) { - D[bc] = ss[bc]; - } - - ss += BlockSize * OutputSize; - D += OutputChannelsThisIteration; - } - - s += BlockSize; - } - - S += AlignedOutputChannels * OutputSize; - } -} - -void -MLASCALL -MlasReorderFilterOIHWBiBo( - const int64_t* FilterShape, - const float* S, - float* D - ) -/*++ - -Routine Description: - - This routine reorders a filter buffer from OIHW to OIHWBiBo format. - -Arguments: - - FilterShape - Supplies the shape of the filter tensor. - - S - Supplies the address of the source tensor. - - D - Supplies the address of the destination tensor. - -Return Value: - - None. - ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - const size_t OutputChannels = size_t(FilterShape[0]); - const size_t InputChannels = size_t(FilterShape[1]); - const size_t KernelHeight = size_t(FilterShape[2]); - const size_t KernelWidth = size_t(FilterShape[3]); - - const size_t KernelSize = KernelHeight * KernelWidth; - const size_t InputStride = InputChannels * KernelSize; - - const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - // - // Transform the filter tensor from format OIHW to OIHWBiBo: - // - // OutputChannelBlock[0] = { - // InputChannelBlock[0] = { - // Kernel[0][0] = { - // InputChannel[0] = { filter[0 filter[1] ... filter[BlockSize-1] }, - // InputChannel[1] = { filter[0 filter[1] ... filter[BlockSize-1] }, - // ... - // InputChannel[BlockSize-1] = { filter[0] filter[1] ... filter[BlockSize-1] }, - // }, - // Kernel[0][1] = { - // ... - // }, - // ... - // Kernel[KernelHeight-1][KernelWidth-1] = { - // ... - // }, - // }, - // InputChannelBlock[BlockSize] = { - // ... - // }, - // ... - // InputChannelBlock[InputChannels-BlockSize] = { - // ... - // }, - // }, - // OutputChannelBlock[BlockSize] = { - // ... - // }, - // OutputChannelBlock[OutputChannels-BlockSize] = { - // ... - // }; - // - - // - // Iterate over BlockSize batches of the output channels. - // - // The final batch may be less than BlockSize, but must be a multiple of 4. - // The unaligned count results in zero padding below. - // - - for (size_t o = OutputChannels; o > 0;) { - - const size_t OutputChannelsThisIteration = std::min(o, BlockSize); - const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); - o -= OutputChannelsThisIteration; - - // - // Iterate over BlockSize batches of the input channels. - // - // The final batch may be less than BlockSize, but must be a multiple - // of 4. 
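Flattened to a single offset, the OIHWBiBo destination index the loops below produce for a source filter element (o, i, kh, kw) follows from the diagram above: with ob = o / BlockSize, bo = o % BlockSize, ib = i / BlockSize, bi = i % BlockSize, and k = kh * KernelWidth + kw, the element lands at (((ob * InputBlocks + ib) * KernelSize + k) * BlockSize + bi) * BlockSize + bo, and lanes past the real channel counts are zero-filled. A small sketch of that index calculation (illustrative helper, not an MLAS routine):

#include <cstddef>

// Destination offset of filter element (o, i, kh, kw) after the OIHWBiBo
// reorder described above. Padding lanes are written as zeros by the reorder
// itself, so only real elements need this mapping.
static std::size_t OihwBiBoOffset(std::size_t o, std::size_t i,
                                  std::size_t kh, std::size_t kw,
                                  std::size_t InputChannels,
                                  std::size_t KernelHeight, std::size_t KernelWidth,
                                  std::size_t BlockSize)
{
    const std::size_t InputBlocks = (InputChannels + BlockSize - 1) / BlockSize;
    const std::size_t KernelSize = KernelHeight * KernelWidth;
    const std::size_t k = kh * KernelWidth + kw;

    const std::size_t ob = o / BlockSize, bo = o % BlockSize;
    const std::size_t ib = i / BlockSize, bi = i % BlockSize;

    return (((ob * InputBlocks + ib) * KernelSize + k) * BlockSize + bi) * BlockSize + bo;
}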
- // - - const float* S_InputChannels = S; - - for (size_t i = InputChannels; i > 0;) { - - const size_t InputChannelsThisIteration = std::min(i, BlockSize); - i -= InputChannelsThisIteration; - - // - // Iterate over each index of the kernel. - // - - const float* S_KernelSize = S_InputChannels; - - for (size_t k = 0; k < KernelSize; k++) { - - // - // Construct a filter block of BlockSize by BlockSize. - // - - const float* S_BlockSize = S_KernelSize; - - for (size_t bi = 0; bi < InputChannelsThisIteration; bi++) { - - // - // Transpose from the source filter buffer to the destination - // buffer. Zero pad the filter block if the output channels - // is not block aligned. - // - - const float* s = S_BlockSize; - size_t bo = 0; - - for (; bo < AlignedOutputChannelsThisIteration; bo += 4) { - MlasReorderGatherFloat32x4(s, D, InputStride); - s += 4 * InputStride; - D += 4; - } - - for (; bo < OutputChannelsThisIteration; bo += 1) { - *D++ = *s; - s += InputStride; - } - - for (; bo < BlockSize; bo += 1) { - *D++ = 0.0f; - } - - S_BlockSize += KernelSize; - } - - for (size_t z = 0; z < (BlockSize - InputChannelsThisIteration) * (BlockSize / 4); z++) { - MlasStoreFloat32x4(D, ZeroFloat32x4); - D += 4; - } - - S_KernelSize += 1; - } - - S_InputChannels += BlockSize * KernelSize; - } - - S += BlockSize * InputStride; - } -} - -void -MLASCALL -MlasReorderFilterOIHWBo( - const int64_t* FilterShape, - const float* S, - float* D - ) -/*++ - -Routine Description: - - This routine reorders a filter buffer from OIHW to OIHWBo format. - -Arguments: - - FilterShape - Supplies the shape of the filter tensor. - - S - Supplies the address of the source tensor. - - D - Supplies the address of the destination tensor. - -Return Value: - - None. - ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - const size_t OutputChannels = size_t(FilterShape[0]); - const size_t InputChannels = size_t(FilterShape[1]); - const size_t KernelHeight = size_t(FilterShape[2]); - const size_t KernelWidth = size_t(FilterShape[3]); - - const size_t KernelSize = KernelHeight * KernelWidth; - const size_t InputStride = InputChannels * KernelSize; - - // - // Transform the filter tensor from format OIHW to OIHWBo: - // - // OutputChannelBlock[0] = { - // InputChannel[0] = { - // Kernel[0][0] = filter[0 filter[1] ... filter[BlockSize-1] }, - // Kernel[0][1] = { filter[0 filter[1] ... filter[BlockSize-1] }, - // ... - // Kernel[KernelHeight-1][KernelWidth-1] = { filter[0 filter[1] ... filter[BlockSize-1] }, - // }, - // InputChannel[1] = { - // ... - // }, - // ... - // InputChannel[InputChannels-1] = { - // ... - // }, - // }, - // OutputChannelBlock[BlockSize] = { - // ... - // }, - // OutputChannelBlock[OutputChannels-BlockSize] = { - // ... - // }; - // - - // - // Iterate over BlockSize batches of the output channels. - // - // The final batch may be less than BlockSize, but must be a multiple of 4. - // The unaligned count results in zero padding below. - // - - for (size_t o = OutputChannels; o > 0;) { - - const size_t OutputChannelsThisIteration = std::min(o, BlockSize); - const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); - o -= OutputChannelsThisIteration; - - // - // Iterate over each of the input channels. - // - - const float* S_InputChannels = S; - - for (size_t i = 0; i < InputChannels; i += 1) { - - // - // Iterate over each index of the kernel. 
- // - - const float* S_KernelSize = S_InputChannels; - - for (size_t k = 0; k < KernelSize; k++) { - - // - // Transpose a float[4] from the source filter buffer to the - // destination buffer. Zero pad the filter block if the output - // channels is not block aligned. - // - - const float* s = S_KernelSize; - size_t bo = 0; - - for (; bo < AlignedOutputChannelsThisIteration; bo += 4) { - MlasReorderGatherFloat32x4(s, D, InputStride); - s += 4 * InputStride; - D += 4; - } - - for (; bo < OutputChannelsThisIteration; bo += 1) { - *D++ = *s; - s += InputStride; - } - - for (; bo < BlockSize; bo += 1) { - *D++ = 0.0f; - } - - S_KernelSize += 1; - } - - S_InputChannels += KernelSize; - } - - S += BlockSize * InputStride; - } -} diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h deleted file mode 100644 index de7fd72fad45a..0000000000000 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ /dev/null @@ -1,399 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - sbgemm.h - -Abstract: - - This module defines the set of template functions to implement bfloat16 - precision matrix/matrix multiply operation (SBGEMM). - - To implement a new kernel, template functions below need to be specialized: - MlasSBGemmConvertPackB - MlasSBGemmPackedBOffset - MlasSBGemmPackedBLeadingDim - MlasSBGemmKernel - - MlasSBGemmOperation is the shared kernel driver. - - A kernel type should define the following constants: - bool PackNeeded; Whether B needs to be packed - size_t KernelMaxM; Max # rows the vectorized kernel can process - size_t PackedK; Packed alignment on the K dim (power of 2) - size_t PackedN; Packed alignment on the n dim (power of 2) - MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; ---*/ - -#if defined(__aarch64__) && defined(__linux__) - -#pragma once - -#include -#include - -#include "mlasi.h" - -/** - * @brief Define the default striding parameters for - * the bfloat16 precision gemm operation - */ -struct MLAS_SBGEMM_STRIDES { - size_t M; - size_t N; - size_t K; -}; - -/** - * @brief Convert fp32 matrix B to bf16 and pack the data - * - * @tparam KernelType - * @param[out] D Address of packing buffer - * @param[in] B Address of source matrix B in fp32 - * @param[in] ldb Leading dimension of B - * @param[in] CountN # of column to pack - * @param[in] CountK # of rows to pack - */ -template -void -MlasSBGemmConvertPackB( - bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK -); - -/** - * @brief Find the location of PackedB[StartK, StartN] - * - * @tparam KernelType - * @param PackedB - * @param DimN Total columns of the packing buffer - * @param DimK Total rows of the packing buffer - * @param StartN - * @param StartK - * @return Address of PackedB[StartK, StartN] - */ -template -MLAS_FORCEINLINE const bfloat16_t* -MlasSBGemmPackedBOffset( - const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK -) -{ - // By default the packed buffer is just a row major - // K row by N column buffer - MLAS_UNREFERENCED_PARAMETER(DimK); - return PackedB + StartK * DimN + StartN; -} - -/** - * @brief leading dimension of the packed B buffer - * Related to how B is packed - * @tparam KernelType - * @param DimN - * @param DimK - * @return leading dimension of the packed B buffer - */ -template -MLAS_FORCEINLINE size_t -MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) -{ - // By default the 
packed buffer is just a row major - // K row by N column buffer - MLAS_UNREFERENCED_PARAMETER(DimK); - return DimN; -} - -template -void -MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); - -template -MLAS_FORCEINLINE void -MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) -{ - constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; - size_t PackedStrideN = Strides.N; - size_t PackedStrideK = Strides.K; - - // - // Step through each slice of matrix B along the N dimension. - // - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - const size_t SliceStartN = RangeStartN + n; - CountN = std::min(RangeCountN - n, PackedStrideN); - - // - // Step through each slice of matrix B along the K dimension. - // - size_t CountK; - for (size_t k = 0; k < K; k += CountK) { - bool ZeroMode = (k == 0); - CountK = std::min(K - k, PackedStrideK); - - const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; - float* c = C + n; - const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); - MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); - } - if (PostProcessor != nullptr) { - ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) - ->Process(C + n, M, SliceStartN, M, CountN, ldc); - } - } -} - -template -void -MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) -{ - // - // Compute the strides to step through slices of the input matrices. - // - // Expand the N stride if K is small or expand the K stride if N is small - // for better utilization of the B panel. Avoid changing the K stride if - // the A panel needs to be used for transposing. - // - constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; - size_t StrideN = Strides.N; - size_t StrideK = Strides.K; - - if (N >= K) { - while (StrideK / 2 >= K) { - StrideN *= 2; - StrideK /= 2; - } - } else { - while (StrideN > 16 && StrideN / 2 >= N) { - StrideK *= 2; - StrideN /= 2; - } - } - - constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); - MlasThreadedBufAlloc(packBSize); - uint8_t* p = ThreadedBufHolder.get(); - auto* PanelB = reinterpret_cast(p); - - // - // Step through each slice of matrix B along the N dimension. - // - size_t CountN; - for (size_t n = 0; n < N; n += CountN) { - CountN = std::min(N - n, StrideN); - - // - // Step through each slice of matrix B along the N dimension. - // - size_t CountK; - for (size_t k = 0; k < K; k += CountK) { - CountK = std::min(K - k, StrideK); - - // - // Copy a panel of matrix B to a local packed buffer. - // - MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); - - auto* c = C + n; - const float* pbias = - ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart - - bool ZeroMode = (k == 0); - MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? 
pbias : nullptr, ZeroMode); - } - if (PostProcessor != nullptr) { - ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); - } - } -} - -template -void -MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) -{ - const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; - const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; - - // - // Partition the operation along the M dimension. - // - size_t RangeStartM; - size_t RangeCountM; - - MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); - - // - // Partition the operation along the N dimension. - // - size_t RangeStartN; - size_t RangeCountN; - - const size_t BlockedN = - (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); - - RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - RangeCountN = std::min(N - RangeStartN, RangeCountN); - - // - // Dispatch the partitioned operation. - // - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - const float* A = (const float*)DataParams->A + RangeStartM * lda; - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - const float* bias = DataParams->Bias; - - if (!DataParams->BIsfp32) { - MlasSBGemmPackedOperation( - RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, - lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor - ); - } else { - const size_t ldb = DataParams->ldb; - const float* B = (const float*)DataParams->B + RangeStartN; - MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); - } -} - -// -// dispatch structure. -// -typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); - -typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( - bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK -); - -/** - * @brief Hardware dependent dispatch for half precision GEMM - */ -struct MLAS_SBGEMM_DISPATCH { - MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ - MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ - size_t PackedK; - size_t PackedN; - size_t StrideM; - size_t BufOverRead; -}; - -extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; - -MLAS_FORCEINLINE -const MLAS_SBGEMM_DISPATCH* -MlasSBGemmGetDispatch() -{ -#if defined(MLAS_TARGET_ARM64) - return &MlasSBGemmDispatchNeon; -#else - std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; - exit(1); -#endif -} - -size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K) -{ - // - // Compute the number of bytes required to hold the packed buffer. 
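The computation that follows rounds K and N up to the kernel's packing granularity, adds the kernel's over-read padding, and then rounds the byte count up to the preferred buffer alignment. A worked example with illustrative constants (they stand in for the dispatch values and are not claimed to match any particular kernel):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    const std::size_t N = 100, K = 50;
    const std::size_t PackedK = 4, PackedN = 16;   // illustrative packing granularity
    const std::size_t BufOverRead = 32;            // illustrative over-read padding
    const std::size_t BufferAlignment = 64;        // illustrative preferred alignment

    const std::size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1);   // 52
    const std::size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1);   // 112
    const std::size_t Bytes = AlignedN * AlignedK * sizeof(std::uint16_t) + BufOverRead;
    const std::size_t Aligned = (Bytes + BufferAlignment - 1) & ~(BufferAlignment - 1);

    // Prints: AlignedN=112 AlignedK=52 bytes=11680 aligned=11712
    std::printf("AlignedN=%zu AlignedK=%zu bytes=%zu aligned=%zu\n",
                AlignedN, AlignedK, Bytes, Aligned);
    return 0;
}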
- // - const auto* dispatch = MlasSBGemmGetDispatch(); - if (dispatch == nullptr) return 0; - - const auto padding = dispatch->BufOverRead; - const auto PackedK = dispatch->PackedK; - const auto PackedN = dispatch->PackedN; - - const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); - const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; - const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = - (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); - - return AlignedBytesRequired; -} - -void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) -{ - const auto* dispatch = MlasSBGemmGetDispatch(); - if (dispatch == nullptr) return; - - dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); -} - -void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) -{ - const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); - if (dispatch == nullptr) return; - - MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; - - // - // Compute the number of target threads given the complexity of the SGEMM - // operation. Small requests should run using the single threaded path. - // - - const double Complexity = double(M) * double(N) * double(K); - - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - // - // Segment the operation across multiple threads. - // - // N.B. Currently, the operation is segmented as a 1D partition, which - // works okay for operations involving skinny matrices. - // - ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - - if (N > M) { - const size_t BlockedN = - (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - if (size_t(ThreadsPerGemm) > BlockedN) { - ThreadsPerGemm = ptrdiff_t(BlockedN); - } - - ThreadCountM = 1; - ThreadCountN = ThreadsPerGemm; - - } else { - if (size_t(ThreadsPerGemm) > M) { - ThreadsPerGemm = ptrdiff_t(M); - } - - ThreadCountM = ThreadsPerGemm; - ThreadCountN = 1; - } - - MlasTrySimpleParallel( - ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { - ptrdiff_t GemmIdx = tid / ThreadsPerGemm; - ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; - operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); - } - ); -} -#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp deleted file mode 100644 index a6a73996c548b..0000000000000 --- a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp +++ /dev/null @@ -1,362 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. 
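The NEON packing routine later in this file converts fp32 to bf16 on the fly with vcvtq_low_bf16_f32 and vcvtq_high_bf16_f32. For readers without the BF16 extension at hand, a portable sketch of the basic conversion follows; it uses round-to-nearest-even on the discarded mantissa bits, which is offered as a common convention rather than a statement about the exact rounding of the NEON instructions, and it ignores NaN payloads:

#include <cstdint>
#include <cstring>
#include <cstdio>

// Portable fp32 -> bf16 sketch: keep the sign, the 8 exponent bits and the
// top 7 mantissa bits, rounding the dropped bits to nearest even.
static std::uint16_t Float32ToBF16(float Value)
{
    std::uint32_t Bits;
    std::memcpy(&Bits, &Value, sizeof(Bits));

    const std::uint32_t RoundingBias = 0x7FFF + ((Bits >> 16) & 1);
    return static_cast<std::uint16_t>((Bits + RoundingBias) >> 16);
}

int main()
{
    // 1.0f is exact (0x3f80); 1 + 2^-8 is a tie and rounds to the even 0x3f80 as well.
    std::printf("%04x %04x\n", Float32ToBF16(1.0f), Float32ToBF16(1.00390625f));
    return 0;
}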
- -Module Name: - - sbgemm_kernel_neon.cpp - -Abstract: - - This module implements bfloat16 precision GEMM kernel for neon. - ---*/ - -#if defined(__aarch64__) && defined(__linux__) - -#include "arm_neon.h" -#include "mlasi.h" -#include "sbgemm.h" - -struct MLAS_SBGEMM_KERNEL_NEON { - static constexpr bool PackNeeded = true; - static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process - static constexpr size_t PackedK = 4; - static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K -}; - -bool MLASCALL -MlasBf16AccelerationSupported() -{ -#if defined(MLAS_TARGET_ARM64) - return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); -#else - return false; -#endif -} - -/* - This routine converts fp32 to bf16 and copies elements from the source - matrix to the destination packed buffer. - - 4x2 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the SBGEMM kernels. The remaining - rows and columns are padded to 4 and 2 alignment. -*/ -MLAS_FORCEINLINE -void -MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) -{ - // - // Copy data from matrix B into the destination buffer 4x2 blocks at a - // time. - // - // - while (CountN >= 8) { - const float* b = B; - int y = static_cast(CountK); - - while (y > 0) { - MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); - - if (y >= 4) { - t0_l = MlasLoadFloat32x4(&b[ldb * 0]); - t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); - t1_l = MlasLoadFloat32x4(&b[ldb * 1]); - t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); - t2_l = MlasLoadFloat32x4(&b[ldb * 2]); - t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); - t3_l = MlasLoadFloat32x4(&b[ldb * 3]); - t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); - } else { - switch (y) { - case 3: - t0_l = MlasLoadFloat32x4(&b[ldb * 0]); - t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); - t1_l = MlasLoadFloat32x4(&b[ldb * 1]); - t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); - t2_l = MlasLoadFloat32x4(&b[ldb * 2]); - t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); - break; - case 2: - t0_l = MlasLoadFloat32x4(&b[ldb * 0]); - t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); - t1_l = MlasLoadFloat32x4(&b[ldb * 1]); - t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); - break; - case 1: - t0_l = MlasLoadFloat32x4(&b[ldb * 0]); - t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); - break; - } - } - - float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); - float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); - float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); - float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); - t0_l = o0_l.val[0]; - t1_l = o0_l.val[1]; - t2_l = o1_l.val[0]; - t3_l = o1_l.val[1]; - - bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); - bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); - - bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); - bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); - - vst1q_bf16(&D[0], t0t1_l_8h); - vst1q_bf16(&D[8], t2t3_l_8h); - - float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); - float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); - float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); - float32x4x2_t o1_h = 
vzipq_f32(z0_h.val[1], z1_h.val[1]); - t0_h = o0_h.val[0]; - t1_h = o0_h.val[1]; - t2_h = o1_h.val[0]; - t3_h = o1_h.val[1]; - - bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); - bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); - - bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); - bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); - - vst1q_bf16(&D[16], t0t1_h_8h); - vst1q_bf16(&D[24], t2t3_h_8h); - - D += 32; - b += ldb * 4; - y -= 4; - }; - B += 8; - CountN -= 8; - } - - // - // Special case the handling of the remaining columns less than 8 elements - // wide. - // - if (CountN > 0) { - int y = static_cast(CountK); - while (y > 0) { - const float* b = B; - size_t b_inc = 0; - if ((CountN & 4) != 0) { - MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); - MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); - if (y >= 4) { - t0 = MlasLoadFloat32x4(&b[ldb * 0]); - t1 = MlasLoadFloat32x4(&b[ldb * 1]); - t2 = MlasLoadFloat32x4(&b[ldb * 2]); - t3 = MlasLoadFloat32x4(&b[ldb * 3]); - } else { - switch (y) { - case 3: - t0 = MlasLoadFloat32x4(&b[ldb * 0]); - t1 = MlasLoadFloat32x4(&b[ldb * 1]); - t2 = MlasLoadFloat32x4(&b[ldb * 2]); - break; - case 2: - t0 = MlasLoadFloat32x4(&b[ldb * 0]); - t1 = MlasLoadFloat32x4(&b[ldb * 1]); - break; - case 1: - t0 = MlasLoadFloat32x4(&b[ldb * 0]); - break; - } - } - - float32x4x2_t z0 = vzipq_f32(t0, t2); - float32x4x2_t z1 = vzipq_f32(t1, t3); - float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); - float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); - - t0 = o0.val[0]; - t1 = o0.val[1]; - t2 = o1.val[0]; - t3 = o1.val[1]; - - bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); - bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); - - bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); - bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); - - vst1q_bf16(&D[0], t0t1_8h); - vst1q_bf16(&D[8], t2t3_8h); - - D += 16; - b += 4; - b_inc += 4; - } - - if ((CountN & 2) != 0) { - float32x2_t t0 = {0x0, 0x0}; - float32x2_t t1 = {0x0, 0x0}; - float32x2_t t2 = {0x0, 0x0}; - float32x2_t t3 = {0x0, 0x0}; - - if (y >= 4) { - t0 = vld1_f32(&b[ldb * 0]); - t1 = vld1_f32(&b[ldb * 1]); - t2 = vld1_f32(&b[ldb * 2]); - t3 = vld1_f32(&b[ldb * 3]); - } else { - switch (y) { - case 3: - t0 = vld1_f32(&b[ldb * 0]); - t1 = vld1_f32(&b[ldb * 1]); - t2 = vld1_f32(&b[ldb * 2]); - break; - case 2: - t0 = vld1_f32(&b[ldb * 0]); - t1 = vld1_f32(&b[ldb * 1]); - break; - case 1: - t0 = vld1_f32(&b[ldb * 0]); - break; - } - } - - float32x2x2_t z0 = vzip_f32(t0, t2); - float32x2x2_t z1 = vzip_f32(t1, t3); - float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); - float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); - - float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); - float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); - - bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); - bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); - - vst1q_bf16(&D[0], t_8h); - - D += 8; - b += 2; - b_inc += 2; - } - if ((CountN & 1) != 0) { - float a = 0.0f; - float b = 0.0f; - float c = 0.0f; - float d = 0.0f; - - if (y >= 4) { - a = *(float*)(&B[ldb * 0 + b_inc]); - b = *(float*)(&B[ldb * 1 + b_inc]); - c = *(float*)(&B[ldb * 2 + b_inc]); - d = *(float*)(&B[ldb * 3 + b_inc]); - } else { - switch (y) { - case 3: - a = *(float*)(&B[ldb * 0 + b_inc]); - b = *(float*)(&B[ldb * 1 + b_inc]); - c = *(float*)(&B[ldb * 2 + b_inc]); - break; - case 2: - a = *(float*)(&B[ldb * 0 + b_inc]); - b = *(float*)(&B[ldb * 1 + b_inc]); - 
break; - case 1: - a = *(float*)(&B[ldb * 0 + b_inc]); - break; - } - } - - float32x2_t t0 = {a, 0x0}; - float32x2_t t1 = {b, 0x0}; - float32x2_t t2 = {c, 0x0}; - float32x2_t t3 = {d, 0x0}; - - float32x2x2_t z0 = vzip_f32(t0, t2); - float32x2x2_t z1 = vzip_f32(t1, t3); - float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); - float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); - - float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); - float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); - - bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); - bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); - - vst1q_bf16(&D[0], t_8h); - - D += 8; - b += 1; - b_inc += 1; - } - B += 4 * ldb; - y -= 4; - } - } -} - -template -void -MlasSBGemmConvertPackB( - bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK -) -{ - const auto* dispatch = MlasSBGemmGetDispatch(); - if (dispatch == nullptr) return; - - const auto PackedN = dispatch->PackedN; - - const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); - - // - // Step through each slice of matrix B along the K dimension. - // - size_t K_block_size; - constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; - - for (size_t k = 0; k < CountK; k += K_block_size) { - K_block_size = std::min(CountK - k, Strides.K); - - MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); - PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; - } -} - -template <> -MLAS_FORCEINLINE void -MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) -{ - while (CountM > 0) { - size_t RowsHandled; - if (ZeroMode) { - RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); - } else { - RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); - } - C += ldc * RowsHandled; - A += lda * RowsHandled; - CountM -= RowsHandled; - } -} - -const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { - MlasSBGemmOperation, - MlasSBGemmConvertPackB, - MLAS_SBGEMM_KERNEL_NEON::PackedK, - MLAS_SBGEMM_KERNEL_NEON::PackedN, - MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, - 32 // kernel may read beyond buffer end by 32 bytes -}; -#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/scalar/SconvDepthwiseKernelScalar.cpp b/onnxruntime/core/mlas/lib/scalar/SconvDepthwiseKernelScalar.cpp deleted file mode 100644 index da1cdb96063af..0000000000000 --- a/onnxruntime/core/mlas/lib/scalar/SconvDepthwiseKernelScalar.cpp +++ /dev/null @@ -1,193 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvDepthwiseKernelScalar.cpp - -Abstract: - - This module implements the kernels for the single precision direct - convolution kernels. - ---*/ - -#include "mlasi.h" - -static -void -MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - float* Output, - const float* Zeros - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute convolution on one channel input with one filter channel. - -Arguments: - - Parameters - conv parameters calculated based on conv parameters like padding, strides, dilations, etc. - - Input - input channel data start. Input is NCHW, so this pointer point to single H x W image data. 
- - Filter - Whole filters are of F x CpG x FH x FW, this filter point to single FH x FW filter data. - - Output - whole output are of N x F x OH x OW. This pointer point to single OH x OW output image data. - - Zeroes - Point to working buffer where all 0.0f are filled. - ---*/ -{ - const size_t W = Parameters->InputShape[1]; - const float beta = Parameters->Beta; - - if (W > 1) { - - const float w00 = Filter[0]; - const float w01 = Filter[1]; - const float w02 = Filter[2]; - const float w10 = Filter[3]; - const float w11 = Filter[4]; - const float w12 = Filter[5]; - const float w20 = Filter[6]; - const float w21 = Filter[7]; - const float w22 = Filter[8]; - - const size_t H = Parameters->InputShape[0]; - const size_t pad_top = Parameters->Padding[0]; - const size_t pad_left = Parameters->Padding[1]; - const size_t stride_h = Parameters->StrideShape[0]; - const size_t stride_w = Parameters->StrideShape[1]; - - // We treat pad_left, pad_top are hard require. - // While pad_right and pad_bottom could be adjusted if they do not 100% match other parameters. - const size_t pad_right = (((Parameters->OutputShape[1] - 1) * stride_w + 3) > (pad_left + W)) ? 1 : 0; - - const float* row0 = (pad_top > 0) ? Zeros : (Input - pad_left); - // Need to handle effective pad_bottom is 2 when H == 1 - const float* row1 = (H + pad_top <= 1) ? Zeros : (Input + (1 - pad_top) * W) - pad_left; - const float* row2 = (H + pad_top <= 2) ? Zeros : (row1 + W); - - for (size_t h = 0, out_row = Parameters->OutputShape[0]; out_row > 0; --out_row) { - auto out_col = Parameters->OutputShape[1]; - - if (pad_left == 1) { - float dotsum = w01 * row0[1] + w02 * row0[2] + w11 * row1[1] + w12 * row1[2] + - w21 * row2[1] + w22 * row2[2] + (beta == 0.f ? 0.f : *Output * beta); - *Output++ = dotsum; - out_col--; - row0 += stride_w; - row1 += stride_w; - row2 += stride_w; - } - - for (; out_col > pad_right; out_col--) { - float dotsum = w00 * row0[0] + w01 * row0[1] + w02 * row0[2] + w10 * row1[0] + - w11 * row1[1] + w12 * row1[2] + w20 * row2[0] + w21 * row2[1] + - w22 * row2[2] + (beta == 0.f ? 0.f : *Output * beta); - *Output++ = dotsum; - row0 += stride_w; - row1 += stride_w; - row2 += stride_w; - } - - if (out_col == 1) { // pad_right == 1 - float dotsum = w00 * row0[0] + w01 * row0[1] + w10 * row1[0] + w11 * row1[1] + - w20 * row2[0] + w21 * row2[1] + (beta == 0.f ? 0.f : *Output * beta); - *Output++ = dotsum; - } - - h += stride_h; - row0 = (Input + (h - pad_top) * W) - pad_left; - row1 = row0 + W; - row2 = (h + 2 >= H + pad_top) ? Zeros : (row1 + W); - } - - } else { // W == 1 - - const size_t H = Parameters->InputShape[0]; - const size_t pad_left = Parameters->Padding[1]; - const size_t pad_top = Parameters->Padding[0]; - const size_t stride_h = Parameters->StrideShape[0]; - size_t out_row = Parameters->OutputShape[0]; - - // Make sure pad_bottom is consistent with other parameters. - size_t pad_bottom = ((out_row - 1) * stride_h + 3) > (pad_top + H) ? - ((out_row - 1) * stride_h + 3) - (pad_top + H) : 0; - - const float w0 = Filter[pad_left ? 1 : 0]; - const float w1 = Filter[pad_left ? 4 : 3]; - const float w2 = Filter[pad_left ? 7 : 6]; - auto init_v = (beta == 0.f ? 0.f : *Output * beta); - - if (pad_top == 1) { - *Output++ = w1 * Input[0] + w2 * ((H + pad_top <= 2) ? 0.0f : Input[1]) + init_v; - out_row--; - } - - for (const float* row = Input + pad_top * stride_h - pad_top; out_row > pad_bottom; --out_row) { - // All pixels are in the input col - auto init = (beta == 0.f ? 
0.f : *Output * beta); - *Output++ = w0 * row[0] + w1 * row[1] + w2 * row[2] + init; - row += stride_h; - } - - if (out_row > 0) { - // last 1 or 2 rows are from the padding zero row. - // out_row == 1 when arrive here - if (pad_bottom == 1) { - const float* row = Input + H - 2; - *Output++ = w0 * row[0] + w1 * row[1] + init_v; - } else { // pad_bottom == 2 and H == 1 and padding_top == 0 - *Output++ = w0 * Input[0] + init_v; - } - } - } - -} - - -void -MlasConvDepthwiseFloat_CHW( - const MLAS_CONV_PARAMETERS* Parameters, - const float* Input, - const float* Filter, - float* Output, - const float* Zeros - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute depthwise convolution for one filter channel on one input channel. - -Arguments: - - Parameters - conv parameters calculated based on conv parameters like padding, strides, dilations, etc. - - Input - input channel data start. Input is NCHW, so this pointer point to single H x W image data. - - Filter - Whole filters are of F x CpG x FH x FW, this filter point to single FH x FW filter data. - - Output - whole output are of N x F x OH x OW. This pointer point to single OH x OW output image data. - - Zeroes - Point to working buffer where all 0.0f are filled. - -Note: - No checking here as it is inner loop. Logic in generating Parameters controls the check. - - Currently only support 2d kernel 3x3. - Will add general case and more special case if needed later. - ---*/ -{ - MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(Parameters, Input, Filter, Output, Zeros); -} diff --git a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp deleted file mode 100644 index 62729256dac23..0000000000000 --- a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp +++ /dev/null @@ -1,474 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelScalar.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - ---*/ - -#include "mlasi.h" - -template -size_t -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. Note that in scalar, - the packing wide is 4. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. 
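Setting aside the two-row and partial-column paths, what the scalar kernel below computes is easiest to audit against a naive statement over the 4-wide packed B layout (one CountK x 4 panel per group of four columns). A reference sketch for the zero-initialized case, names illustrative:

#include <cstddef>

// Naive reference: C[m][n] = alpha * sum_k A[m][k] * PackedB[panel(n)][k][n % 4],
// where PackedB stores one CountK x 4 panel per group of four columns.
static void SgemmPackedReference(const float* A, const float* PackedB, float* C,
                                 std::size_t CountM, std::size_t CountN, std::size_t CountK,
                                 std::size_t lda, std::size_t ldc, float alpha)
{
    for (std::size_t m = 0; m < CountM; ++m) {
        for (std::size_t n = 0; n < CountN; ++n) {
            const float* Panel = PackedB + (n / 4) * (CountK * 4);
            float Sum = 0.0f;
            for (std::size_t k = 0; k < CountK; ++k) {
                Sum += A[m * lda + k] * Panel[k * 4 + (n % 4)];
            }
            C[m * ldc + n] = alpha * Sum;
        }
    }
}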
- ---*/ -{ - float Row0Block00; - float Row0Block01; - float Row0Block02; - float Row0Block03; - - float Row1Block00; - float Row1Block01; - float Row1Block02; - float Row1Block03; - -#if defined(_WIN32) - - if (!ProcessTwoRows) { - UNREFERENCED_PARAMETER(lda); - UNREFERENCED_PARAMETER(ldc); - } - -#endif - - do { - - float BElements00; - float BElements01; - float BElements02; - float BElements03; - - float Row0AElements0; - float Row0AElements1; - float Row1AElements0; - float Row1AElements1; - - // - // Clear the block accumulators. - // - - Row0Block00 = 0.0f; - Row0Block01 = 0.0f; - Row0Block02 = 0.0f; - Row0Block03 = 0.0f; - - if (ProcessTwoRows) { - Row1Block00 = 0.0f; - Row1Block01 = 0.0f; - Row1Block02 = 0.0f; - Row1Block03 = 0.0f; - } - - // - // Compute the 4x1 or 4x2 output block. - // - - const float* a = A; - size_t k = CountK; - - while (k >= 2) { - - Row0AElements0 = a[0]; - Row0AElements1 = a[1]; - - if (ProcessTwoRows) { - Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; - } - - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; - Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; - Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; - Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; - Row0Block03 = Row0Block03 + BElements03 * Row0AElements0; - - if (ProcessTwoRows) { - Row1Block00 = Row1Block00 + BElements00 * Row1AElements0; - Row1Block01 = Row1Block01 + BElements01 * Row1AElements0; - Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; - Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; - } - - BElements00 = B[4]; - BElements01 = B[5]; - BElements02 = B[6]; - BElements03 = B[7]; - Row0Block00 = Row0Block00 + BElements00 * Row0AElements1; - Row0Block01 = Row0Block01 + BElements01 * Row0AElements1; - Row0Block02 = Row0Block02 + BElements02 * Row0AElements1; - Row0Block03 = Row0Block03 + BElements03 * Row0AElements1; - - if (ProcessTwoRows) { - Row1Block00 = Row1Block00 + BElements00 * Row1AElements1; - Row1Block01 = Row1Block01 + BElements01 * Row1AElements1; - Row1Block02 = Row1Block02 + BElements02 * Row1AElements1; - Row1Block03 = Row1Block03 + BElements03 * Row1AElements1; - } - - a += 2; - B += 8; - k -= 2; - } - - if (k > 0) { - - Row0AElements0 = a[0]; - - if (ProcessTwoRows) { - Row1AElements0 = a[lda]; - } - - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; - Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; - Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; - Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; - Row0Block03 = Row0Block03 + BElements03 * Row0AElements0; - - if (ProcessTwoRows) { - Row1Block00 = Row1Block00 + BElements00 * Row1AElements0; - Row1Block01 = Row1Block01 + BElements01 * Row1AElements0; - Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; - Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; - } - - B += 4; - } - - // - // Multiply by the alpha value. - // - - Row0Block00 = Row0Block00 * alpha; - Row0Block01 = Row0Block01 * alpha; - Row0Block02 = Row0Block02 * alpha; - Row0Block03 = Row0Block03 * alpha; - - if (ProcessTwoRows) { - Row1Block00 = Row1Block00 * alpha; - Row1Block01 = Row1Block01 * alpha; - Row1Block02 = Row1Block02 * alpha; - Row1Block03 = Row1Block03 * alpha; - } - - if (CountN >= 4) { - - // - // Store the entire output block. 
- // - - if (!ZeroMode) { - Row0Block00 = Row0Block00 + C[0]; - Row0Block01 = Row0Block01 + C[1]; - Row0Block02 = Row0Block02 + C[2]; - Row0Block03 = Row0Block03 + C[3]; - } - - C[0] = Row0Block00; - C[1] = Row0Block01; - C[2] = Row0Block02; - C[3] = Row0Block03; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block00 = Row1Block00 + C[ldc]; - Row1Block01 = Row1Block01 + C[ldc + 1]; - Row1Block02 = Row1Block02 + C[ldc + 2]; - Row1Block03 = Row1Block03 + C[ldc + 3]; - } - - C[ldc] = Row1Block00; - C[ldc + 1] = Row1Block01; - C[ldc + 2] = Row1Block02; - C[ldc + 3] = Row1Block03; - } - - } else { - - // - // Store the partial output block. - // - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Row0Block00 = Row0Block00 + C[0]; - Row0Block01 = Row0Block01 + C[1]; - } - - C[0] = Row0Block00; - C[1] = Row0Block01; - Row0Block00 = Row0Block02; - Row0Block01 = Row0Block03; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block00 = Row1Block00 + C[ldc]; - Row1Block01 = Row1Block01 + C[ldc + 1]; - } - - C[ldc] = Row1Block00; - C[ldc + 1] = Row1Block01; - Row1Block00 = Row1Block02; - Row1Block01 = Row1Block03; - } - - C += 2; - } - - if ((CountN & 1) != 0) { - - if (!ZeroMode) { - Row0Block00 = Row0Block00 + C[0]; - } - - C[0] = Row0Block00; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block00 = Row1Block00 + C[ldc]; - } - - C[ldc] = Row1Block00; - } - } - - break; - } - - C += 4; - CountN -= 4; - - } while (CountN > 0); - - return ProcessTwoRows ? 2 : 1; -} - -template -size_t -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - size_t RowsHandled; - - if (CountM >= 2) { - RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountN, lda, ldc, alpha); - } - - return RowsHandled; -} - -size_t -MLASCALL -MlasSgemmKernelZero( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. 
- - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - return MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); -} - -size_t -MLASCALL -MlasSgemmKernelAdd( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - return MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); -} diff --git a/onnxruntime/core/mlas/lib/scalar/SgemvKernelScalar.cpp b/onnxruntime/core/mlas/lib/scalar/SgemvKernelScalar.cpp deleted file mode 100644 index 609a6f251e22f..0000000000000 --- a/onnxruntime/core/mlas/lib/scalar/SgemvKernelScalar.cpp +++ /dev/null @@ -1,169 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemvKernelScalar.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/vector - multiply operation (SGEMV). - ---*/ - -#include "mlasi.h" - -void -MLASCALL -MlasGemvFloatKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t ldb, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. This handles the special case of M=1. - - The elements in matrix B are not transposed. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb - Supplies the first dimension of matrix B. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - None. 
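As with the other scalar kernels, the unrolled loops below are easier to check against a naive statement of the M = 1 contract described above. A reference sketch (degenerate CountK == 0 handling is ignored, name illustrative):

#include <cstddef>

// Naive reference for the M = 1 GEMV kernel below:
// C[n] (+)= sum_k A[k] * B[k][n], with B row-major and leading dimension ldb.
static void GemvFloatReference(const float* A, const float* B, float* C,
                               std::size_t CountK, std::size_t CountN,
                               std::size_t ldb, bool ZeroMode)
{
    for (std::size_t n = 0; n < CountN; ++n) {
        float Sum = ZeroMode ? 0.0f : C[n];
        for (std::size_t k = 0; k < CountK; ++k) {
            Sum += A[k] * B[k * ldb + n];
        }
        C[n] = Sum;
    }
}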
- ---*/ -{ - if (ZeroMode && CountK > 0) { - float* c = C; - const float* b = B; - const float A0 = A[0]; - auto N = CountN; - constexpr size_t kWidth = 4; - for (; N >= kWidth; N -= kWidth) { - c[0] = A0 * b[0]; - c[1] = A0 * b[1]; - c[2] = A0 * b[2]; - c[3] = A0 * b[3]; - c += kWidth; - b += kWidth; - } - - for (; N > 0; N--) { - c[0] = A0 * b[0]; - c++; - b++; - } - A++; - B += ldb; - - CountK--; - } - - for (; CountK >= 4; CountK -= 4) { - float* c = C; - const float* b = B; - const float* b2 = B + ldb * 2; - - const float A0 = A[0]; - const float A1 = A[1]; - const float A2 = A[2]; - const float A3 = A[3]; - - constexpr size_t kWidth = 4; - auto N = CountN; - for (; N >= kWidth; N -= kWidth) { - float c0 = c[0] + A0 * b[0]; - float c1 = c[1] + A0 * b[1]; - float c2 = c[2] + A0 * b[2]; - float c3 = c[3] + A0 * b[3]; - - c0 += A1 * b[ldb + 0]; - c1 += A1 * b[ldb + 1]; - c2 += A1 * b[ldb + 2]; - c3 += A1 * b[ldb + 3]; - - c0 += A2 * b2[0]; - c1 += A2 * b2[1]; - c2 += A2 * b2[2]; - c3 += A2 * b2[3]; - - c0 += A3 * b2[ldb + 0]; - c1 += A3 * b2[ldb + 1]; - c2 += A3 * b2[ldb + 2]; - c3 += A3 * b2[ldb + 3]; - - c[0] = c0; - c[1] = c1; - c[2] = c2; - c[3] = c3; - - c += kWidth; - b += kWidth; - b2 += kWidth; - } - - for (; N > 0; N--) { - c[0] += A0 * b[0] + A1 * b[ldb] + A2 * b2[0] + A3 * b2[ldb]; - c++; - b++; - b2++; - } - - B += 4 * ldb; - A += 4; - } - - for (; CountK > 0; CountK--) { - float* c = C; - const float* b = B; - const float A0 = A[0]; - constexpr size_t kWidth = 4; - auto N = CountN; - for (; N >= kWidth; N -= kWidth) { - c[0] += A0 * b[0]; - c[1] += A0 * b[1]; - c[2] += A0 * b[2]; - c[3] += A0 * b[3]; - - c += kWidth; - b += kWidth; - } - - for (; N > 0; N--) { - c[0] += A0 * b[0]; - c++; - b++; - } - B += ldb; - A++; - } -} diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp deleted file mode 100644 index 4d7a1ceb4eee7..0000000000000 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ /dev/null @@ -1,1741 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sgemm.cpp - -Abstract: - - This module implements the single precision matrix/matrix multiply - operation (SGEMM). - ---*/ - -#include "mlasi.h" - -// -// Define the number of rows from matrix A to transpose to a local buffer. -// -// N.B. AVX processes a maximum of 4 rows, FMA3 processes a maximum of 6 -// rows, and AVX512F processes a maximum of 12 rows. -// - -#define MLAS_SGEMM_TRANSA_ROWS 12 - -// -// Define the parameters to execute segments of a SGEMM operation on worker -// threads. -// - -void -MlasSgemmMultiplyBeta( - float* C, - size_t CountM, - size_t CountN, - size_t ldc, - float beta - ) -/*++ - -Routine Description: - - This routine multiplies all elements of the output matrix by the beta - scalar value. - -Arguments: - - C - Supplies the address of matrix C. - - CountM - Supplies the number of rows from matrix C. - - CountN - Supplies the number of columns from matrix C. - - ldc - Supplies the first dimension of matrix C. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - -Return Value: - - None. 
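The beta pass described above scales an M x N block of C in place; the vectorized body that follows is equivalent to this scalar sketch:

    #include <cstddef>

    // Scale every element of an M x N block of C (leading dimension ldc) by beta.
    void MultiplyBetaRef(float* C, std::size_t CountM, std::size_t CountN,
                         std::size_t ldc, float beta)
    {
        for (std::size_t m = 0; m < CountM; m++) {
            for (std::size_t n = 0; n < CountN; n++) {
                C[m * ldc + n] *= beta;
            }
        }
    }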
- ---*/ -{ - MLAS_FLOAT32X4 BetaBroadcast = MlasBroadcastFloat32x4(beta); - - while (CountM-- > 0) { - - float* c = C; - size_t n = CountN; - - while (n >= 4) { - MlasStoreFloat32x4(c, MlasMultiplyFloat32x4(MlasLoadFloat32x4(c), BetaBroadcast)); - c += 4; - n -= 4; - } - - while (n > 0) { -#if defined(MLAS_SSE2_INTRINSICS) - _mm_store_ss(c, _mm_mul_ss(_mm_load_ss(c), BetaBroadcast)); -#else - *c = *c * beta; -#endif - c += 1; - n -= 1; - } - - C += ldc; - } -} - -void -MlasSgemmTransposeA( - float* D, - const float* A, - size_t lda, - size_t CountY, - size_t CountX - ) -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - buffer. - -Arguments: - - D - Supplies the address of the destination buffer. - - A - Supplies the address of the source matrix. - - lda - Supplies the number of elements per row of the source matrix. - - CountY - Supplies the number of columns of the source matrix to transpose. - - CountX - Supplies the number of rows of the source matrix to transpose. - -Return Value: - - None. - ---*/ -{ - size_t ldd = CountX; - - // - // Transpose elements from matrix A into the destination buffer 4 columns - // at a time. - // - - while (CountX >= 4) { - - float* d = D; - const float* a = A; - size_t y = CountY; - - do { - - float t0 = a[0]; - float t1 = a[lda]; - float t2 = a[lda * 2]; - float t3 = a[lda * 3]; - - d[0] = t0; - d[1] = t1; - d[2] = t2; - d[3] = t3; - - d += ldd; - a += 1; - y--; - - } while (y > 0); - - D += 4; - A += lda * 4; - CountX -= 4; - } - - // - // Transpose elements from matrix A into the destination buffer for the - // remaining columns. - // - - if (CountX >= 2) { - - float* d = D; - const float* a = A; - size_t y = CountY; - - do { - - float t0 = a[0]; - float t1 = a[lda]; - - d[0] = t0; - d[1] = t1; - - d += ldd; - a += 1; - y--; - - } while (y > 0); - - D += 2; - A += lda * 2; - CountX -= 2; - } - - if (CountX >= 1) { - - float* d = D; - const float* a = A; - size_t y = CountY; - - do { - - d[0] = a[0]; - - d += ldd; - a += 1; - y--; - - } while (y > 0); - } -} - -#if !defined(MLAS_TARGET_WASM_SCALAR) - -void -MlasSgemmCopyPackB( - float* D, - const float* B, - size_t ldb, - size_t CountX, - size_t CountY - ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - - Columns of 16 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the SGEMM kernels. Any remaining - columns less than 16 elements wide are zero-padded. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountX - Supplies the number of columns of the source matrix to copy. - - CountY - Supplies the number of rows of the source matrix to copy. - -Return Value: - - None. - ---*/ -{ - // - // Copy data from matrix B into the destination buffer 16 columns at a - // time. 
- // - - while (CountX >= 16) { - - const float* b = B; - size_t y = CountY; - - do { - -#if defined(MLAS_NEON_INTRINSICS) - vst4q_f32(D, vld4q_f32(b)); -#else - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); - MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[4]); - MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(&b[8]); - MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(&b[12]); - - MlasStoreAlignedFloat32x4(&D[0], t0); - MlasStoreAlignedFloat32x4(&D[4], t1); - MlasStoreAlignedFloat32x4(&D[8], t2); - MlasStoreAlignedFloat32x4(&D[12], t3); -#endif - - D += 16; - b += ldb; - y--; - - } while (y > 0); - - B += 16; - CountX -= 16; - } - - // - // Special case the handling of the remaining columns less than 16 elements - // wide. - // - - if (CountX > 0) { - - MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - -#if defined(MLAS_NEON_INTRINSICS) - float32x4x4_t ZeroFloat32x4x4 = { ZeroFloat32x4, ZeroFloat32x4, ZeroFloat32x4, ZeroFloat32x4 }; -#endif - - size_t y = CountY; - - do { - - float* d = D; - const float* b = B; - -#if defined(MLAS_NEON_INTRINSICS) - vst4q_f32(d, ZeroFloat32x4x4); -#else - MlasStoreAlignedFloat32x4(d, ZeroFloat32x4); - MlasStoreAlignedFloat32x4(d + 4, ZeroFloat32x4); - MlasStoreAlignedFloat32x4(d + 8, ZeroFloat32x4); - MlasStoreAlignedFloat32x4(d + 12, ZeroFloat32x4); -#endif - - if ((CountX & 8) != 0) { - - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(b); - MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(b + 4); - - MlasStoreAlignedFloat32x4(d, t0); - MlasStoreAlignedFloat32x4(d + 4, t1); - - d += 8; - b += 8; - } - - if ((CountX & 4) != 0) { - - MlasStoreAlignedFloat32x4(d, MlasLoadFloat32x4(b)); - - d += 4; - b += 4; - } - - if ((CountX & 2) != 0) { - - float t0 = b[0]; - float t1 = b[1]; - - d[0] = t0; - d[1] = t1; - - d += 2; - b += 2; - } - - if ((CountX & 1) != 0) { - d[0] = b[0]; - } - - D += 16; - B += ldb; - y--; - - } while (y > 0); - } -} - -template -inline -void -MlasSgemmTransposePackBNx4( - float* D, - const float* B, - size_t ldb - ) -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - 4 columns of N rows from the source matrix are transposed to N columns of 4 - rows in the destination packed buffer. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - -Return Value: - - None. 
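The Nx4 transpose helper described above moves a strip of N rows by 4 columns of B into the packed buffer, where consecutive packed rows sit 16 floats apart (the D[0], D[16], D[32], D[48] stores). The same data movement written as plain scalar code (illustrative and unvectorized, not the MLAS routine):

    #include <cstddef>

    // Element (j, n) of the packed block comes from element (n, j) of the source
    // strip; packed rows are 16 floats apart to match the layout used above.
    void TransposePackNx4Ref(float* D, const float* B, std::size_t ldb, std::size_t N)
    {
        for (std::size_t n = 0; n < N; n++) {
            for (std::size_t j = 0; j < 4; j++) {
                D[j * 16 + n] = B[n * ldb + j];
            }
        }
    }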
- ---*/ -{ - for (unsigned n = 0; n < N / 4; n++) { - - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&B[ldb * 0]); - MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&B[ldb * 1]); - MLAS_FLOAT32X4 t2 = MlasLoadFloat32x4(&B[ldb * 2]); - MLAS_FLOAT32X4 t3 = MlasLoadFloat32x4(&B[ldb * 3]); - -#if defined(MLAS_NEON_INTRINSICS) - float32x4x2_t z0 = vzipq_f32(t0, t2); - float32x4x2_t z1 = vzipq_f32(t1, t3); - float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); - float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); - t0 = o0.val[0]; - t1 = o0.val[1]; - t2 = o1.val[0]; - t3 = o1.val[1]; -#else - MLAS_FLOAT32X4 z0 = MlasInterleaveLowFloat32x4(t0, t2); - MLAS_FLOAT32X4 z1 = MlasInterleaveHighFloat32x4(t0, t2); - MLAS_FLOAT32X4 z2 = MlasInterleaveLowFloat32x4(t1, t3); - MLAS_FLOAT32X4 z3 = MlasInterleaveHighFloat32x4(t1, t3); - t0 = MlasInterleaveLowFloat32x4(z0, z2); - t1 = MlasInterleaveHighFloat32x4(z0, z2); - t2 = MlasInterleaveLowFloat32x4(z1, z3); - t3 = MlasInterleaveHighFloat32x4(z1, z3); -#endif - - MlasStoreAlignedFloat32x4(&D[0], t0); - MlasStoreAlignedFloat32x4(&D[16], t1); - MlasStoreAlignedFloat32x4(&D[32], t2); - MlasStoreAlignedFloat32x4(&D[48], t3); - - D += 4; - B += ldb * 4; - } -} - -void -MlasSgemmTransposePackB( - float* D, - const float* B, - size_t ldb, - size_t CountY, - size_t CountX - ) -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - Columns of 16 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the SGEMM kernels. Any remaining - columns less than 16 elements wide are zero-padded. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountY - Supplies the number of rows of the source matrix to transpose. - - CountX - Supplies the number of columns of the source matrix to transpose. - -Return Value: - - None. - ---*/ -{ - // - // Transpose elements from matrix B into the packed buffer 16 rows at a - // time. - // - - while (CountY >= 16) { - - const float* b = B; - size_t x = CountX; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - - MLAS_SGEMM_TRANSPOSE_PACKB_BLOCK_ROUTINE* SgemmTransposePackB16x4Routine = - GetMlasPlatform().TransposePackB16x4Routine; - - while (x >= 4) { - - SgemmTransposePackB16x4Routine(&D[0], &b[0], ldb); - - D += 16 * 4; - b += 4; - x -= 4; - } - -#else - - while (x >= 4) { - - MlasSgemmTransposePackBNx4<16>(&D[0], &b[0], ldb); - - D += 16 * 4; - b += 4; - x -= 4; - } - -#endif - - while (x > 0) { - - float t0 = b[0]; - float t1 = b[ldb]; - float t2 = b[ldb * 2]; - float t3 = b[ldb * 3]; - float t4 = b[ldb * 4]; - float t5 = b[ldb * 5]; - float t6 = b[ldb * 6]; - float t7 = b[ldb * 7]; - float t8 = b[ldb * 8]; - float t9 = b[ldb * 9]; - float t10 = b[ldb * 10]; - float t11 = b[ldb * 11]; - float t12 = b[ldb * 12]; - float t13 = b[ldb * 13]; - float t14 = b[ldb * 14]; - float t15 = b[ldb * 15]; - - D[0] = t0; - D[1] = t1; - D[2] = t2; - D[3] = t3; - D[4] = t4; - D[5] = t5; - D[6] = t6; - D[7] = t7; - D[8] = t8; - D[9] = t9; - D[10] = t10; - D[11] = t11; - D[12] = t12; - D[13] = t13; - D[14] = t14; - D[15] = t15; - - D += 16; - b += 1; - x--; - } - - B += ldb * 16; - CountY -= 16; - } - - // - // Special case the handling of the less than 16 remaining rows. 
- // - - if (CountY > 0) { - - MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); - - size_t x = CountX; - - // - // Transpose 4 columns at a time. - // - - while (x >= 4) { - - float* d = D; - const float* b = B; - - if ((CountY & 8) != 0) { - - MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb); - - d += 8; - b += ldb * 8; - - } else { - - MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[24], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[28], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[40], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[44], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[56], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[60], ZeroFloat32x4); - } - - if ((CountY & 4) != 0) { - - MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb); - - d += 4; - b += ldb * 4; - - } else { - - MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[20], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[36], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[52], ZeroFloat32x4); - } - - MlasStoreAlignedFloat32x4(&d[0], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[16], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[32], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[48], ZeroFloat32x4); - - if ((CountY & 2) != 0) { - - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); - MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[ldb]); - -#if defined(MLAS_SSE2_INTRINSICS) - __m128 v0 = _mm_unpacklo_ps(t0, t1); - __m128 v1 = _mm_unpackhi_ps(t0, t1); - _mm_storel_pi((__m64*)&d[0], v0); - _mm_storeh_pi((__m64*)&d[16], v0); - _mm_storel_pi((__m64*)&d[32], v1); - _mm_storeh_pi((__m64*)&d[48], v1); -#else - MlasStoreLaneFloat32x4<0>(&d[0], t0); - MlasStoreLaneFloat32x4<0>(&d[1], t1); - MlasStoreLaneFloat32x4<1>(&d[16], t0); - MlasStoreLaneFloat32x4<1>(&d[17], t1); - MlasStoreLaneFloat32x4<2>(&d[32], t0); - MlasStoreLaneFloat32x4<2>(&d[33], t1); - MlasStoreLaneFloat32x4<3>(&d[48], t0); - MlasStoreLaneFloat32x4<3>(&d[49], t1); -#endif - - d += 2; - b += ldb * 2; - } - - if ((CountY & 1) != 0) { - -#if defined(MLAS_NEON_INTRINSICS) - MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); - - MlasStoreLaneFloat32x4<0>(&d[0], t0); - MlasStoreLaneFloat32x4<1>(&d[16], t0); - MlasStoreLaneFloat32x4<2>(&d[32], t0); - MlasStoreLaneFloat32x4<3>(&d[48], t0); -#else - d[0] = b[0]; - d[16] = b[1]; - d[32] = b[2]; - d[48] = b[3]; -#endif - } - - D += 16 * 4; - B += 4; - x -= 4; - } - - // - // Transpose the remaining columns. 
- // - - while (x > 0) { - - float* d = D; - const float* b = B; - - if ((CountY & 8) != 0) { - - float t0 = b[0]; - float t1 = b[ldb]; - float t2 = b[ldb * 2]; - float t3 = b[ldb * 3]; - float t4 = b[ldb * 4]; - float t5 = b[ldb * 5]; - float t6 = b[ldb * 6]; - float t7 = b[ldb * 7]; - - d[0] = t0; - d[1] = t1; - d[2] = t2; - d[3] = t3; - d[4] = t4; - d[5] = t5; - d[6] = t6; - d[7] = t7; - - d += 8; - b += ldb * 8; - - } else { - - MlasStoreAlignedFloat32x4(&d[8], ZeroFloat32x4); - MlasStoreAlignedFloat32x4(&d[12], ZeroFloat32x4); - } - - if ((CountY & 4) != 0) { - - float t0 = b[0]; - float t1 = b[ldb]; - float t2 = b[ldb * 2]; - float t3 = b[ldb * 3]; - - d[0] = t0; - d[1] = t1; - d[2] = t2; - d[3] = t3; - - d += 4; - b += ldb * 4; - - } else { - - MlasStoreAlignedFloat32x4(&d[4], ZeroFloat32x4); - } - - MlasStoreAlignedFloat32x4(d, ZeroFloat32x4); - - if ((CountY & 2) != 0) { - - float t0 = b[0]; - float t1 = b[ldb]; - - d[0] = t0; - d[1] = t1; - - d += 2; - b += ldb * 2; - } - - if ((CountY & 1) != 0) { - d[0] = b[0]; - } - - D += 16; - B += 1; - x--; - } - } -} - -#else //defined(MLAS_TARGET_WASM_SCALAR) - -void -MlasSgemmCopyPackB( - float* D, - const float* B, - size_t ldb, - size_t CountX, - size_t CountY - ) -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - - Columns of 4 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the SGEMM kernels. Any remaining - columns less than 4 elements wide are zero-padded. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountX - Supplies the number of columns of the source matrix to copy. - - CountY - Supplies the number of rows of the source matrix to copy. - -Return Value: - - None. - ---*/ -{ - // - // Copy data from matrix B into the destination buffer 4 columns at a - // time. - // - - while (CountX >= 4) { - - const float* b = B; - size_t y = CountY; - - do { - - std::copy_n(b, 4, D); - - D += 4; - b += ldb; - y--; - - } while (y > 0); - - B += 4; - CountX -= 4; - } - - // - // Special case the handling of the remaining columns less than 4 elements - // wide. - // - - if (CountX > 0) { - - size_t y = CountY; - - do { - - std::fill_n(D, 4, 0.0f); - - float* d = D; - const float* b = B; - - if ((CountX & 2) != 0) { - - float t0 = b[0]; - float t1 = b[1]; - - d[0] = t0; - d[1] = t1; - - d += 2; - b += 2; - } - - if ((CountX & 1) != 0) { - d[0] = b[0]; - } - - D += 4; - B += ldb; - y--; - - } while (y > 0); - } -} - -void -MlasSgemmTransposePackB( - float* D, - const float* B, - size_t ldb, - size_t CountY, - size_t CountX - ) -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - Columns of 4 elements from the source matrix are unrolled to be physically - contiguous for better locality inside the SGEMM kernels. Any remaining - columns less than 4 elements wide are zero-padded. - -Arguments: - - D - Supplies the address of the destination packed buffer. - - B - Supplies the address of the source matrix. - - ldb - Supplies the number of elements per row of the source matrix. - - CountY - Supplies the number of rows of the source matrix to transpose. - - CountX - Supplies the number of columns of the source matrix to transpose. - -Return Value: - - None. 
- ---*/ -{ - auto TransposePackByVector = [&](float *D, const float* B) { - - float b0 = B[0]; - float b1 = B[1]; - float b2 = B[2]; - float b3 = B[3]; - - D[0] = b0; - D[4] = b1; - D[8] = b2; - D[12] = b3; - }; - - // - // Transpose elements from matrix B into the packed buffer 4 rows at a - // time. - // - - while (CountY >= 4) { - - const float* b = B; - size_t x = CountX; - - while (x >= 4) { - - TransposePackByVector(&D[0], &b[ldb * 0]); - TransposePackByVector(&D[1], &b[ldb * 1]); - TransposePackByVector(&D[2], &b[ldb * 2]); - TransposePackByVector(&D[3], &b[ldb * 3]); - - D += 4 * 4; - b += 4; - x -= 4; - } - - while (x > 0) { - - float t0 = b[0]; - float t1 = b[ldb]; - float t2 = b[ldb * 2]; - float t3 = b[ldb * 3]; - - D[0] = t0; - D[1] = t1; - D[2] = t2; - D[3] = t3; - - D += 4; - b += 1; - x--; - } - - B += ldb * 4; - CountY -= 4; - } - - // - // Special case the handling of the less than 16 remaining rows. - // - - if (CountY > 0) { - - size_t x = CountX; - - // - // Transpose 4 columns at a time. - // - - while (x >= 4) { - - std::fill_n(D, 16, 0.0f); - - float* d = D; - const float* b = B; - - if ((CountY & 2) != 0) { - - TransposePackByVector(&d[0], &b[ldb * 0]); - TransposePackByVector(&d[1], &b[ldb * 1]); - - d += 2; - b += ldb * 2; - } - - if ((CountY & 1) != 0) { - TransposePackByVector(&d[0], &b[ldb * 0]); - } - - D += 4 * 4; - B += 4; - x -= 4; - } - - // - // Transpose the remaining columns. - // - - while (x > 0) { - - std::fill_n(D, 4, 0.0f); - - float* d = D; - const float* b = B; - - if ((CountY & 2) != 0) { - - float t0 = b[0]; - float t1 = b[ldb]; - - d[0] = t0; - d[1] = t1; - - d += 2; - b += ldb * 2; - } - - if ((CountY & 1) != 0) { - d[0] = b[0]; - } - - D += 4; - B += 1; - x--; - } - } -} - -#endif - -MLAS_FORCEINLINE -float* -MlasSgemmKernelLoop( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine steps through the rows of the input and output matrices calling - the kernel until all rows have been processed. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the number of rows from matrix A and matrix C to iterate - over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the next address of matrix C. 
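The loop described above is a plain driver: the inner kernel reports how many rows of C it produced and the caller advances A and C by that many rows until CountM is exhausted. The pattern in isolation, with a hypothetical kernel callback standing in for the platform kernels:

    #include <cstddef>

    // Kernel callback: returns the number of rows it handled this call.
    using RowKernelRef = std::size_t (*)(const float* A, const float* B, float* C,
                                         std::size_t CountK, std::size_t CountM,
                                         std::size_t CountN, std::size_t lda,
                                         std::size_t ldc);

    float* KernelLoopRef(RowKernelRef Kernel, const float* A, const float* B, float* C,
                         std::size_t CountK, std::size_t CountM, std::size_t CountN,
                         std::size_t lda, std::size_t ldc)
    {
        while (CountM > 0) {
            std::size_t RowsHandled = Kernel(A, B, C, CountK, CountM, CountN, lda, ldc);
            C += ldc * RowsHandled;   // advance to the next unwritten rows of C
            A += lda * RowsHandled;   // advance to the matching rows of A
            CountM -= RowsHandled;
        }
        return C;                     // next address of matrix C, as documented above
    }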
- ---*/ -{ - while (CountM > 0) { - - size_t RowsHandled; - -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); -#else - if (ZeroMode) { - RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); - } -#endif - - C += ldc * RowsHandled; - A += lda * RowsHandled; - CountM -= RowsHandled; - } - - return C; -} - -void -MlasSgemmOperation( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - float beta, - float* C, - size_t ldc - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - TransB - Supplies the transpose operation for matrix B. - - M - Supplies the number of rows of matrix A and matrix C. - - N - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). - - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - -Return Value: - - None. - ---*/ -{ - float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_STRIDEK]; - MLAS_DECLSPEC_ALIGN(float PanelB[MLAS_SGEMM_STRIDEN * MLAS_SGEMM_STRIDEK], 16 * sizeof(float)); - - // - // Handle the special case of K equals zero. Apply the beta multiplier to - // the output matrix and exit. - // - - if (K == 0) { - MlasSgemmMultiplyBeta(C, M, N, ldc, beta); - return; - } - - // - // Handle the special case of a small M. The data from matrix B is not - // referenced multiple times, so using a local packed buffer is a wasted - // memory copy. - // - - if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { - -#if defined(MLAS_TARGET_AMD64) - - MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; - - if (TransB == CblasNoTrans) { - SgemmKernelM1Routine = GetMlasPlatform().KernelM1Routine; - } else { - SgemmKernelM1Routine = GetMlasPlatform().KernelM1TransposeBRoutine; - } - - if (SgemmKernelM1Routine != nullptr) { - SgemmKernelM1Routine(A, B, C, K, N, ldb, beta); - return; - } - -#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_WASM) - - if (TransB == CblasNoTrans) { - MlasGemvFloatKernel(A, B, C, K, N, ldb, (beta == 0.0f)); - return; - } - -#endif - - } - - // - // Handle the case when both B and C are column-vectors that are contiguous in memory. - // Because transposition of such vectors doesn't change their layout, and - // Transpose(A*B) = Transpose(B) * Transpose(A), we can apply the same 'small-M' - // optimization as above, with A and B flipped. 
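The comment above leans on the identity Transpose(A*B) = Transpose(B)*Transpose(A): when B and C are contiguous vectors, transposing them does not change their memory layout, so the product reduces to dotting the vector against contiguous rows of A. A sketch of that "transposed B" style of single-vector kernel and how the N = 1 case maps onto it (names are illustrative):

    #include <cstddef>

    // Each output element is a dot product of the vector with one contiguous row
    // of Mat (leading dimension ldMat).
    void GemvTransBRef(const float* Vec, const float* Mat, float* C,
                       std::size_t CountK, std::size_t CountN, std::size_t ldMat)
    {
        for (std::size_t n = 0; n < CountN; n++) {
            float sum = 0.0f;
            for (std::size_t k = 0; k < CountK; k++) {
                sum += Vec[k] * Mat[n * ldMat + k];
            }
            C[n] = sum;
        }
    }

    // N == 1 with TransA == CblasNoTrans: C[m] = dot(row m of A, B), i.e. the
    // roles of A and B are swapped:  GemvTransBRef(B, A, C, K, M, lda);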
- // - - if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { - -#if defined(MLAS_TARGET_AMD64) - - MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; - - if (TransA == CblasNoTrans) { - SgemmKernelM1Routine = GetMlasPlatform().KernelM1TransposeBRoutine; - } else { - SgemmKernelM1Routine = GetMlasPlatform().KernelM1Routine; - } - - if (SgemmKernelM1Routine != nullptr) { - SgemmKernelM1Routine(B, A, C, K, M, lda, beta); - return; - } - -#endif - - } - - // - // Compute the strides to step through slices of the input matrices. - // - // Expand the N stride if K is small or expand the K stride if N is small - // for better utilization of the B panel. Avoid changing the K stride if - // the A panel needs to be used for transposing. - // - - size_t StrideN = MLAS_SGEMM_STRIDEN; - size_t StrideK = MLAS_SGEMM_STRIDEK; - - if (N >= K) { - - while (StrideK / 2 >= K) { - StrideN *= 2; - StrideK /= 2; - } - - } else if (TransA == CblasNoTrans) { - - while (StrideN > 16 && StrideN / 2 >= N) { - StrideK *= 2; - StrideN /= 2; - } - } - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - - for (size_t n = 0; n < N; n += CountN) { - - CountN = std::min(N - n, StrideN); - - // - // Multiply the output matrix by beta as needed. - // - - if (beta != 0.0f && beta != 1.0f) { - MlasSgemmMultiplyBeta(C + n, M, CountN, ldc, beta); - } - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - bool ZeroMode = (beta == 0.0f); - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, StrideK); - - // - // Copy or transpose a panel of matrix B to a local packed buffer. - // - - if (TransB == CblasNoTrans) { - MlasSgemmCopyPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); - } else { - MlasSgemmTransposePackB(PanelB, B + k + n * ldb, ldb, CountN, CountK); - } - - // - // Step through each slice of matrix A along the M dimension. - // - - float* c = C + n; - - if (TransA == CblasNoTrans) { - - MlasSgemmKernelLoop(A + k, PanelB, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); - - } else { - - const float* a = A + k * lda; - size_t RowsRemaining = M; - - while (RowsRemaining > 0) { - - // - // Transpose elements from matrix A into a local buffer. - // - - size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_SGEMM_TRANSA_ROWS)); - - MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); - - RowsRemaining -= RowsTransposed; - a += RowsTransposed; - - // - // Step through the rows of the local buffer. - // - - c = MlasSgemmKernelLoop(PanelA, PanelB, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); - } - } - - ZeroMode = false; - } - } -} - -void -MlasSgemmPackedOperation( - CBLAS_TRANSPOSE TransA, - size_t M, - size_t RangeStartN, - size_t RangeCountN, - size_t K, - float alpha, - const float* A, - size_t lda, - const void* PackedB, - size_t AlignedN, - float beta, - float* C, - size_t ldc - ) -/*++ - -Routine Description: - - This routine implements the single precision matrix/matrix multiply - operation (SGEMM). - -Arguments: - - TransA - Supplies the transpose operation for matrix A. - - M - Supplies the number of rows of matrix A and matrix C. - - RangeStartN - Supplies the starting column from packed matrix B. - - RangeCountN - Supplies the number of columns of matrix B and matrix C. - - K - Supplies the number of columns of matrix A and the number of rows of - matrix B. - - alpha - Supplies the scalar alpha multiplier (see SGEMM definition). 
- - A - Supplies the address of matrix A. - - lda - Supplies the first dimension of matrix A. - - PackedB - Supplies the address of packed matrix B. - - AlignedN - Supplies the total number of aligned columns for packed matrix B. - - ldb - Supplies the first dimension of matrix B. - - beta - Supplies the scalar beta multiplier (see SGEMM definition). - - C - Supplies the address of matrix C. - - ldc - Supplies the first dimension of matrix C. - -Return Value: - - None. - ---*/ -{ - float PanelA[MLAS_SGEMM_TRANSA_ROWS * MLAS_SGEMM_PACKED_STRIDEK]; - - // - // Step through each slice of matrix B along the N dimension. - // - - size_t CountN; - - for (size_t n = 0; n < RangeCountN; n += CountN) { - - const size_t SliceStartN = RangeStartN + n; - - CountN = std::min(RangeCountN - n, size_t(MLAS_SGEMM_PACKED_STRIDEN)); - - // - // Multiply the output matrix by beta as needed. - // - - if (beta != 0.0f && beta != 1.0f) { - MlasSgemmMultiplyBeta(C + n, M, CountN, ldc, beta); - } - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - bool ZeroMode = (beta == 0.0f); - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, size_t(MLAS_SGEMM_PACKED_STRIDEK)); - - // - // Step through each slice of matrix A along the M dimension. - // - - const float* pb = (const float*)PackedB + AlignedN * k + CountK * SliceStartN; - float* c = C + n; - - if (TransA == CblasNoTrans) { - - MlasSgemmKernelLoop(A + k, pb, c, CountK, M, CountN, lda, ldc, alpha, ZeroMode); - - } else { - - const float* a = A + k * lda; - size_t RowsRemaining = M; - - while (RowsRemaining > 0) { - - // - // Transpose elements from matrix A into a local buffer. - // - - size_t RowsTransposed = std::min(RowsRemaining, size_t(MLAS_SGEMM_TRANSA_ROWS)); - - MlasSgemmTransposeA(PanelA, a, lda, RowsTransposed, CountK); - - RowsRemaining -= RowsTransposed; - a += RowsTransposed; - - // - // Step through the rows of the local buffer. - // - - c = MlasSgemmKernelLoop(PanelA, pb, c, CountK, RowsTransposed, CountN, CountK, ldc, alpha, ZeroMode); - } - } - - ZeroMode = false; - } - } -} - -void -MlasSgemmThreaded( - const ptrdiff_t ThreadCountM, - const ptrdiff_t ThreadCountN, - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, - const size_t M, - const size_t N, - const size_t K, - - const MLAS_SGEMM_DATA_PARAMS* DataParams, - ptrdiff_t ThreadId - ) -/*++ - -Routine Description: - - This routine is invoked from a worker thread to execute a segment of a - SGEMM operation. - -Arguments: - - ThreadCountM - Supplies the total thread partition on the M dimension. - - ThreadCountN - Supplies the total thread partition on the N dimension. - - TransA - Supplies the transpose operation on A matrix - - TransB - Supplies the transpose operation on B matrix - - M, N, K - Supplies the shape of the multiplication - - DataParams - Supplies the data position and layout of the matrices - - ThreadId - Supplies the current index of the threaded operation. - -Return Value: - - None. - ---*/ -{ - - const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; - const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; - - // - // Partition the operation along the M dimension. - // - - size_t RangeStartM; - size_t RangeCountM; - - MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); - - // - // Partition the operation along the N dimension. 
- // - - size_t RangeStartN; - size_t RangeCountN; - - const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, - &RangeCountN); - - RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - RangeCountN = std::min(N - RangeStartN, RangeCountN); - - // - // Dispatch the partitioned operation. - // - - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const float* A = DataParams->A + RangeStartM * ((TransA == CblasNoTrans) ? lda : 1); - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - if (DataParams->BIsPacked) { - - MlasSgemmPackedOperation(TransA, RangeCountM, RangeStartN, RangeCountN, - K, DataParams->alpha, A, lda, DataParams->B, - BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, DataParams->beta, C, ldc); - - } else { - - const size_t ldb = DataParams->ldb; - - const float* B = (const float*)DataParams->B + RangeStartN * ((TransB == CblasNoTrans) ? 1 : ldb); - - MlasSgemmOperation(TransA, TransB, RangeCountM, RangeCountN, K, - DataParams->alpha, A, lda, B, ldb, DataParams->beta, C, ldc); - } -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// Chance of arithmetic overflow could be reduced -#pragma warning(disable : 26451) -#endif -void -MLASCALL -MlasGemmBatch( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_SGEMM_DATA_PARAMS* Data, - size_t BatchSize, - MLAS_THREADPOOL* ThreadPool - ) -{ - - // - // Compute the number of target threads given the complexity of the SGEMM - // operation. Small requests should run using the single threaded path. - // - - const double Complexity = double(M) * double(N) * double(K); - - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - // - // Segment the operation across multiple threads. - // - // N.B. Currently, the operation is segmented as a 1D partition, which - // works okay for operations involving skinny matrices. 
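The N-dimension split above works in whole blocks of MLAS_SGEMM_STRIDEN_THREAD_ALIGN columns: the block count is divided among the threads, each range is converted back to column units, and the trailing range is clamped to N. The arithmetic in isolation, with a simple even split standing in for MlasPartitionWork (whose exact policy is not reproduced here):

    #include <algorithm>
    #include <cstddef>

    // Stand-in for MlasPartitionWork: split TotalWork items across ThreadCount
    // workers, giving earlier workers one extra item when the division is uneven.
    static void PartitionWorkRef(std::size_t ThreadId, std::size_t ThreadCount,
                                 std::size_t TotalWork,
                                 std::size_t* Start, std::size_t* Count)
    {
        std::size_t Base = TotalWork / ThreadCount;
        std::size_t Extra = TotalWork % ThreadCount;
        *Start = ThreadId * Base + std::min(ThreadId, Extra);
        *Count = Base + (ThreadId < Extra ? 1 : 0);
    }

    // Convert this thread's share of the column blocks into an element range.
    static void PartitionColumnsRef(std::size_t ThreadIdN, std::size_t ThreadCountN,
                                    std::size_t N, std::size_t Align,
                                    std::size_t* RangeStartN, std::size_t* RangeCountN)
    {
        std::size_t BlockedN = (N + Align - 1) / Align;
        PartitionWorkRef(ThreadIdN, ThreadCountN, BlockedN, RangeStartN, RangeCountN);
        *RangeStartN *= Align;
        *RangeCountN *= Align;
        *RangeCountN = std::min(N - *RangeStartN, *RangeCountN);   // clamp the tail range
    }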
- // - - ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchSize - 1) / BatchSize; - ptrdiff_t ThreadCountM; - ptrdiff_t ThreadCountN; - - if (N > M) { - - const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - if (size_t(ThreadsPerGemm) > BlockedN) { - ThreadsPerGemm = ptrdiff_t(BlockedN); - } - - ThreadCountM = 1; - ThreadCountN = ThreadsPerGemm; - - } else { - - if (size_t(ThreadsPerGemm) > M) { - ThreadsPerGemm = ptrdiff_t(M); - } - - ThreadCountM = ThreadsPerGemm; - ThreadCountN = 1; - } - - MlasTrySimpleParallel(ThreadPool, - ThreadsPerGemm * static_cast(BatchSize), - [=](ptrdiff_t tid) - { - ptrdiff_t GemmIdx = tid / ThreadsPerGemm; - ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; - MlasSgemmThreaded(ThreadCountM, ThreadCountN, - TransA, TransB, M, N, K, &(Data[GemmIdx]), ThreadIdx); - }); -} -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif - -size_t -MLASCALL -MlasGemmPackBSize( - size_t N, - size_t K - ) -/*++ - -Routine Description: - - This routine computes the length in bytes for the packed matrix B buffer. - -Arguments: - - N - Supplies the number of columns of matrix B. - - K - Supplies the number of rows of matrix B. - -Return Value: - - Returns the size in bytes for the packed matrix B buffer. - ---*/ -{ - // - // Compute the number of bytes required to hold the packed buffer. - // - - const size_t AlignedN = - (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); - - const size_t BytesRequired = AlignedN * K * sizeof(float); - const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & - ~(BufferAlignment - 1); - - return AlignedBytesRequired; -} - -void -MLASCALL -MlasGemmPackB( - CBLAS_TRANSPOSE TransB, - size_t N, - size_t K, - const float* B, - size_t ldb, - void* PackedB - ) -/*++ - -Routine Description: - - This routine packs the contents of matrix B to the destination buffer. The - destination buffer should be sized based on MlasGemmPackBSize(). For best - performance, the destination buffer should be aligned to the value returned - from MlasGetPreferredBufferAlignment(). - -Arguments: - - TransB - Supplies the transpose operation for matrix B. - - N - Supplies the number of columns of matrix B. - - K - Supplies the number of rows of matrix B. - - B - Supplies the address of matrix B. - - ldb - Supplies the first dimension of matrix B. - - PackedB - Supplies the address of packed matrix B. - -Return Value: - - None. - ---*/ -{ - const size_t AlignedN = - (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); - - // - // Step through each slice of matrix B along the K dimension. - // - - size_t CountK; - - for (size_t k = 0; k < K; k += CountK) { - - CountK = std::min(K - k, size_t(MLAS_SGEMM_PACKED_STRIDEK)); - - if (TransB == CblasNoTrans) { - MlasSgemmCopyPackB((float*)PackedB, B + k * ldb, ldb, N, CountK); - } else { - MlasSgemmTransposePackB((float*)PackedB, B + k, ldb, N, CountK); - } - - PackedB = (float*)PackedB + AlignedN * CountK; - } -} diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp deleted file mode 100644 index f9cf1605787aa..0000000000000 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ /dev/null @@ -1,1896 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
- -Module Name: - - snchwc.cpp - -Abstract: - - This module implements the single precision operations using the NCHWc - blocking format. - ---*/ - -#include "mlasi.h" - -// -// Define the base thread context for NCWHc convolution or pooling operations. -// - -struct MLAS_NCHWC_WORK_BLOCK -{ - ptrdiff_t tids; - size_t BatchCount; - size_t InputChannels; - size_t InputShape[2]; - size_t InputSize; - size_t OutputChannels; - size_t OutputShape[2]; - size_t OutputSize; - size_t KernelShape[2]; - size_t DilationShape[2]; - size_t Padding[4]; - size_t StrideShape[2]; - size_t OutputCountLeftPad[2]; - size_t OutputCount[2]; - size_t OutputCountRightPad[2]; -}; - -// -// Define the worker thread context for a NCHWc convolution operation. -// - -struct MLAS_NCHWC_CONV_WORK_BLOCK : MLAS_NCHWC_WORK_BLOCK -{ - const float* Input; - const float* Filter; - const float* Bias; - const MLAS_ACTIVATION* Activation; - float* Output; - size_t GroupCount; - bool ZeroMode; -}; - -// -// Define the worker thread context for a NCHWc pooling operation. -// - -struct MLAS_NCHWC_POOL_WORK_BLOCK : MLAS_NCHWC_WORK_BLOCK -{ - const float* Input; - float* Output; - MLAS_POOLING_KIND PoolingKind; -}; - -// -// Define the convolution kernel flags. -// - -#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 -#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 -#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 -#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 - -size_t -MLASCALL -MlasNchwcGetBlockSize( - void - ) -/*++ - -Routine Description: - - This routine returns the NCHWc block size for the platform. - -Arguments: - - None. - -Return Value: - - Returns the NCHWc block size for the platform. If NCHWc support is not - available for the platform, then returns one. - - N.B. Using the value one as the flag to indicate no support avoids compiler - warnings in optimized builds when using this value in division or modulus - math. - ---*/ -{ -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - return GetMlasPlatform().NchwcBlockSize; -#else - return 1; -#endif -} - -void -MlasNchwcPrepareWorkBlock( - MLAS_NCHWC_WORK_BLOCK* WorkBlock, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape - ) -/*++ - -Routine Description: - - This routine prepares for a convolution or pooling operation by computing - required parameters given the shape attributes. - -Arguments: - - WorkBlock - Supplies the structure that contains the common convolution - and pooling parameters. - - InputShape - Supplies the shape of the input tensor. - - KernelShape - Supplies the shape of the kernel transform. - - DilationShape - Supplies the shape of the dilation. - - Padding - Supplies the number of padding elements at the edge of the input - tensor. - - StrideShape - Supplies the shape of the stride. - - OutputShape - Supplies the shape of the output tensor. - -Return Value: - - None. - ---*/ -{ - // - // Extract and skip over the the batch and channel counts. - // - - WorkBlock->BatchCount = size_t(InputShape[0]); - WorkBlock->InputChannels = size_t(InputShape[1]); - WorkBlock->OutputChannels = size_t(OutputShape[1]); - - InputShape += 2; - OutputShape += 2; - - // - // Extract the shape information along each dimension. 
- // - - size_t InputSize = 1; - size_t OutputSize = 1; - bool CanFlattenShape = true; - - for (size_t dim = 0; dim < 2; dim++) { - - const size_t InputValue = size_t(InputShape[dim]); - const size_t OutputValue = size_t(OutputShape[dim]); - - WorkBlock->InputShape[dim] = InputValue; - WorkBlock->OutputShape[dim] = OutputValue; - - InputSize *= InputValue; - OutputSize *= OutputValue; - - if (KernelShape != nullptr) { - WorkBlock->KernelShape[dim] = size_t(KernelShape[dim]); - } else { - WorkBlock->KernelShape[dim] = InputValue; - } - - if (DilationShape != nullptr) { - WorkBlock->DilationShape[dim] = size_t(DilationShape[dim]); - } else { - WorkBlock->DilationShape[dim] = 1; - } - - CanFlattenShape &= (WorkBlock->DilationShape[dim] == 1); - - if (Padding != nullptr) { - WorkBlock->Padding[dim] = size_t(Padding[dim]); - WorkBlock->Padding[dim + 2] = size_t(Padding[dim + 2]); - } else { - WorkBlock->Padding[dim] = 0; - WorkBlock->Padding[dim + 2] = 0; - } - - CanFlattenShape &= (WorkBlock->Padding[dim] == 0 && WorkBlock->Padding[dim + 2] == 0); - - if (StrideShape != nullptr) { - WorkBlock->StrideShape[dim] = size_t(StrideShape[dim]); - } else { - WorkBlock->StrideShape[dim] = 1; - } - - CanFlattenShape &= (WorkBlock->StrideShape[dim] == 1); - } - - WorkBlock->InputSize = InputSize; - WorkBlock->OutputSize = OutputSize; - - // - // Detect operations where the kernel is using the entire input width, - // has strides and dilations set to one, and no padding. These operations - // are transformed from outputting [N][1] to [1][N] by flattening the - // operation to a single line using striding equal to the original width. - // - // With the originally shape, the NCHWc kernels would process a single - // output per output line. After reshaping, the NCHWc kernels are able to - // process multiple outputs per output line which typically performs better, - // despite potentially using fewer threads due to the decreased output - // height. - // - - if (CanFlattenShape && (WorkBlock->InputShape[1] == WorkBlock->KernelShape[1])) { - - WorkBlock->StrideShape[1] = WorkBlock->InputShape[1]; - - WorkBlock->InputShape[1] *= WorkBlock->InputShape[0]; - WorkBlock->InputShape[0] = 1; - - WorkBlock->OutputShape[1] *= WorkBlock->OutputShape[0]; - WorkBlock->OutputShape[0] = 1; - - WorkBlock->KernelShape[1] *= WorkBlock->KernelShape[0]; - WorkBlock->KernelShape[0] = 1; - } - - // - // Compute the number of output elements affected by left and right padding. 
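The counts computed next split each output dimension into a left-padded region, an interior region that reads only real input, and a right-padded region. For a concrete feel for those formulas, a worked example with a 3-wide kernel, padding of 1 on each side, stride 1, and dilation 1 over an input extent of 8 (standalone, mirrors the expressions below):

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t Input = 8, Kernel = 3, Dilation = 1, Stride = 1;
        const std::size_t PadLeft = 1, Output = 8;   // (8 + 1 + 1 - 3) / 1 + 1 == 8

        const std::size_t Span = Dilation * (Kernel - 1) + 1;                    // 3
        const std::size_t WithLeftPad = (Input + PadLeft - Span) / Stride + 1;   // 7
        std::size_t LeftPad = (PadLeft + Stride - 1) / Stride;                   // 1
        if (LeftPad > WithLeftPad) {
            LeftPad = WithLeftPad;
        }
        const std::size_t Interior = WithLeftPad - LeftPad;                      // 6
        const std::size_t RightPad = Output - WithLeftPad;                       // 1

        // One output touches the left padding, six read only real input, and one
        // touches the right padding.
        assert(LeftPad == 1 && Interior == 6 && RightPad == 1);
        return 0;
    }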
- // - - for (size_t dim = 0; dim < 2; dim++) { - - const size_t SpanValue = - WorkBlock->DilationShape[dim] * (WorkBlock->KernelShape[dim] - 1) + 1; - const size_t StrideValue = WorkBlock->StrideShape[dim]; - const size_t PaddingLeftValue = WorkBlock->Padding[dim]; - const size_t InputValue = WorkBlock->InputShape[dim]; - - size_t OutputCountWithLeftPad; - - if (InputValue + PaddingLeftValue >= SpanValue) { - OutputCountWithLeftPad = (InputValue + PaddingLeftValue - SpanValue) / StrideValue + 1; - } else { - OutputCountWithLeftPad = 0; - } - - size_t OutputCountLeftPad = (PaddingLeftValue + StrideValue - 1) / StrideValue; - - if (OutputCountLeftPad > OutputCountWithLeftPad) { - OutputCountLeftPad = OutputCountWithLeftPad; - } - - const size_t OutputValue = WorkBlock->OutputShape[dim]; - - WorkBlock->OutputCountLeftPad[dim] = OutputCountLeftPad; - WorkBlock->OutputCount[dim] = OutputCountWithLeftPad - OutputCountLeftPad; - WorkBlock->OutputCountRightPad[dim] = OutputValue - OutputCountWithLeftPad; - } -} - -// -// Base implementation for neural network algorithms (convolution and pooling). -// - -struct MLAS_NCHWC_NN_ALGORITHM -{ - static constexpr size_t HeightShapeIndex = 0; - static constexpr size_t WidthShapeIndex = 1; - - const size_t BlockSize = MlasNchwcGetBlockSize(); - - // - // Capture these values from the work block for use as local constants. - // - - const size_t BatchCount; - const size_t InputChannels; - const size_t OutputChannels; - const size_t InputHeight; - const size_t InputWidth; - const size_t InputSize; - const size_t OutputHeight; - const size_t OutputWidth; - const size_t OutputSize; - const size_t KernelHeight; - const size_t KernelWidth; - const size_t KernelSize; - const size_t DilationHeight; - const size_t DilationWidth; - const size_t PaddingLeftY; - const size_t PaddingLeftX; - const size_t StrideHeight; - const size_t StrideWidth; - const size_t OutputCountLeftPadY; - const size_t OutputCountY; - const size_t OutputCountLeftPadX; - const size_t OutputCountX; - const size_t OutputCountRightPadX; - - MLAS_NCHWC_NN_ALGORITHM(const MLAS_NCHWC_WORK_BLOCK* WorkBlock) : - BatchCount(WorkBlock->BatchCount), - InputChannels(WorkBlock->InputChannels), - OutputChannels(WorkBlock->OutputChannels), - InputHeight(WorkBlock->InputShape[HeightShapeIndex]), - InputWidth(WorkBlock->InputShape[WidthShapeIndex]), - InputSize(WorkBlock->InputSize), - OutputHeight(WorkBlock->OutputShape[HeightShapeIndex]), - OutputWidth(WorkBlock->OutputShape[WidthShapeIndex]), - OutputSize(WorkBlock->OutputSize), - KernelHeight(WorkBlock->KernelShape[HeightShapeIndex]), - KernelWidth(WorkBlock->KernelShape[WidthShapeIndex]), - KernelSize(KernelHeight * KernelWidth), - DilationHeight(WorkBlock->DilationShape[HeightShapeIndex]), - DilationWidth(WorkBlock->DilationShape[WidthShapeIndex]), - PaddingLeftY(WorkBlock->Padding[HeightShapeIndex]), - PaddingLeftX(WorkBlock->Padding[WidthShapeIndex]), - StrideHeight(WorkBlock->StrideShape[HeightShapeIndex]), - StrideWidth(WorkBlock->StrideShape[WidthShapeIndex]), - OutputCountLeftPadY(WorkBlock->OutputCountLeftPad[HeightShapeIndex]), - OutputCountY(WorkBlock->OutputCount[HeightShapeIndex]), - OutputCountLeftPadX(WorkBlock->OutputCountLeftPad[WidthShapeIndex]), - OutputCountX(WorkBlock->OutputCount[WidthShapeIndex]), - OutputCountRightPadX(WorkBlock->OutputCountRightPad[WidthShapeIndex]) - { - } -}; - -constexpr size_t MLAS_NCHWC_NN_ALGORITHM::HeightShapeIndex; -constexpr size_t MLAS_NCHWC_NN_ALGORITHM::WidthShapeIndex; - -template -void 
-MlasNchwcThreaded( - void* Context, - ptrdiff_t Index - ) -{ - AlgorithmType((decltype(AlgorithmType::WorkBlock))Context).Execute(Index); -} - -// -// Base implementation for convolution algorithms. -// - -struct MLAS_NCHWC_CONV_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM -{ - // - // Capture these values from the work block for use as local constants. - // - - const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock; - const size_t GroupCount; - const MLAS_ACTIVATION* Activation; - const MLAS_ACTIVATION_KIND ActivationKind; - const bool ZeroMode; - - // - // Capture the buffer pointers from the work block. - // - // These fields are updated as the threads step through the convolution - // operation. - // - - const float* Input; - const float* Filter; - const float* Bias; - float* Output; - - MLAS_NCHWC_CONV_ALGORITHM(const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_NN_ALGORITHM(WorkBlock), - WorkBlock(WorkBlock), - GroupCount(WorkBlock->GroupCount), - Activation(WorkBlock->Activation), - ActivationKind(Activation->ActivationKind), - ZeroMode(WorkBlock->ZeroMode) - { - Input = WorkBlock->Input; - Filter = WorkBlock->Filter; - Bias = WorkBlock->Bias; - Output = WorkBlock->Output; - } - - unsigned - ComputeKernelFlags( - size_t ic, - size_t ChannelCount - ) - { - unsigned KernelFlags = 0; - - // - // Accumulate into the output buffer if this isn't the first input - // channel contributing to the output element or if the caller has - // requested that the output buffer not be zero initialized (Conv/Sum - // fusion). - // - - if (ic != 0 || !ZeroMode) { - KernelFlags |= MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT; - } - - if (ic + ChannelCount == InputChannels) { - - // - // Add the bias buffer into the output buffer if necessary. - // - - if (Bias != nullptr) { - KernelFlags |= MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION; - } - - // - // Test for fused ReLU activation or other types of activation run - // outside of the convolution kernel. - // - - if (ActivationKind == MlasReluActivation) { - KernelFlags |= MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION; - } else if (ActivationKind != MlasIdentityActivation) { - KernelFlags |= MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION; - } - } - - return KernelFlags; - } - - void - ComputeEffectiveKernel( - size_t ph, - size_t FilterStride, - const float** filter, - size_t* ih, - size_t* EffectiveKernelHeight - ) - { - // - // Compute the first input row and kernel height. If this output row - // uses padding from one or more input padding rows, then adjust the - // kernel parameters to keep within the input bounds. - // - - *ih = ph * StrideHeight - PaddingLeftY; - *EffectiveKernelHeight = KernelHeight; - - if ((ph - OutputCountLeftPadY) >= OutputCountY) { - - size_t ihStep = *ih; - - for (size_t kh = 0; kh < KernelHeight; kh++) { - - if (ihStep >= InputHeight) { - - if (ihStep == *ih) { - *ih += DilationHeight; - *filter += FilterStride; - } - - *EffectiveKernelHeight -= 1; - } - - ihStep += DilationHeight; - } - } - } - - void - DoActivation( - float* output, - size_t FilterCount, - size_t BlockedOutputWidth - ) - { - // - // Invoke activation doing an inplace update. - // - // The width of the output matrix is the number of written output - // elements. Pointwise convolution may write multiple logical rows - // at once, so this output count may be greater than OutputWidth. - // - // The convolution kernels write to one or more output positions - // across NCHWc output planes, so the stride is set to the blocked - // output size instead of the output width as done in NCHW convolution. 
- // - - MlasActivation(Activation, output, nullptr, FilterCount, - BlockedOutputWidth, BlockSize * OutputSize); - } -}; - -// -// Base implementation for grouped convolution algorithms. -// - -struct MLAS_NCHWC_GROUPED_CONV_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM -{ - // - // Slice the convolution operation such that multiple filter blocks are - // reused for a given set of input inside the kernel. - // - - static constexpr size_t FilterSetSize = 4; - - const size_t FilterSetCount; - - // - // Stores the current output line, filter cluster, and group that this thread - // is operating on. - // - - size_t ph; - size_t FilterSet; - size_t Group; - size_t WorkRemaining; - size_t FilterCount; - - MLAS_NCHWC_GROUPED_CONV_ALGORITHM(const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_CONV_ALGORITHM(WorkBlock), - FilterSetCount((OutputChannels + (BlockSize * FilterSetSize) - 1) / (BlockSize * FilterSetSize)) - { - } - - void ComputeFilterCount(void) - { - FilterCount = std::min(FilterSetSize, (OutputChannels / BlockSize) - FilterSet * FilterSetSize); - } - - void PrepareWork(ptrdiff_t Index) - { - const size_t TotalWork = BatchCount * GroupCount * FilterSetCount * OutputHeight; - - size_t WorkIndex; - - MlasPartitionWork(Index, WorkBlock->tids, TotalWork, &WorkIndex, &WorkRemaining); - - // - // Extract the current batch, group, filter cluster, and output line - // from the starting work index. - // - - ph = WorkIndex % OutputHeight; - const size_t BatchGroupFilterSet = WorkIndex / OutputHeight; - - FilterSet = BatchGroupFilterSet % FilterSetCount; - const size_t BatchGroup = BatchGroupFilterSet / FilterSetCount; - - Group = BatchGroup % GroupCount; - - // - // Advance the convolution buffer pointers to the current position - // computed above. - // - - Input += BatchGroup * InputChannels * InputSize; - - Output += BatchGroup * OutputChannels * OutputSize; - Output += BlockSize * FilterSet * FilterSetSize * OutputSize; - - Filter += Group * OutputChannels * InputChannels * KernelSize; - Filter += BlockSize * FilterSet * FilterSetSize * InputChannels * KernelSize; - - if (Bias != nullptr) { - Bias += Group * OutputChannels; - Bias += BlockSize * FilterSet * FilterSetSize; - } - - // - // Compute the number of filter set to use for the next iteration. - // - - ComputeFilterCount(); - } - - void CompleteWork(size_t WorkThisIteration) - { - // - // Adjust the amount of work remaining and check if the end of an output - // image has been reached. - // - - WorkRemaining -= WorkThisIteration; - - if ((ph += WorkThisIteration) == OutputHeight) { - - size_t BlockedFilterCount = BlockSize * FilterCount; - - Output += BlockedFilterCount * OutputSize; - Filter += BlockedFilterCount * InputChannels * KernelSize; - - if (Bias != nullptr) { - Bias += BlockedFilterCount; - } - - // - // Advance the input if the all filter sets have been processed. - // - - if (++FilterSet == FilterSetCount) { - - Input += InputChannels * InputSize; - - // - // Reset filter and bias if all groups have been processed. - // - - if (++Group == GroupCount) { - - Filter = WorkBlock->Filter; - Bias = WorkBlock->Bias; - - Group = 0; - } - - FilterSet = 0; - } - - ComputeFilterCount(); - - ph = 0; - } - } -}; - -constexpr size_t MLAS_NCHWC_GROUPED_CONV_ALGORITHM::FilterSetSize; - -// -// Implementation of the direct convolution algorithm where the input buffer is -// in NCHWc format. 
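PrepareWork above turns a starting linear work index over the batch, group, filter-set, and output-line space back into coordinates by peeling off one dimension at a time with modulo and division. That decomposition on its own (field names shortened for the example):

    #include <cstddef>

    struct GroupedConvPosition {
        std::size_t OutputLine;   // ph in the code above
        std::size_t FilterSet;
        std::size_t Group;
        std::size_t BatchGroup;   // combined batch/group index used to offset buffers
    };

    GroupedConvPosition DecodeWorkIndexRef(std::size_t WorkIndex,
                                           std::size_t OutputHeight,
                                           std::size_t FilterSetCount,
                                           std::size_t GroupCount)
    {
        GroupedConvPosition p;
        p.OutputLine = WorkIndex % OutputHeight;
        std::size_t BatchGroupFilterSet = WorkIndex / OutputHeight;
        p.FilterSet = BatchGroupFilterSet % FilterSetCount;
        p.BatchGroup = BatchGroupFilterSet / FilterSetCount;
        p.Group = p.BatchGroup % GroupCount;
        return p;
    }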
-// - -struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM -{ - MLAS_NCHWC_CONV_NCHWC_ALGORITHM(const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_GROUPED_CONV_ALGORITHM(WorkBlock) - { - } - - void Execute(ptrdiff_t Index) - { - // - // Setup the convolution state based on the thread index. - // - - PrepareWork(Index); - - // - // Loop until all of the work has been completed. - // - - const size_t StrideWidthBytes = BlockSize * StrideWidth * sizeof(float); - const size_t DilationWidthBytes = BlockSize * DilationWidth * sizeof(float); - const size_t FilterStrideBytes = BlockSize * InputChannels * KernelSize * sizeof(float); - const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); - const size_t InputWidthBytes = BlockSize * InputWidth * sizeof(float); - const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); - const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; - - const size_t BlockedOutputWidth = BlockSize * OutputWidth; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; -#else - MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; -#endif - - while (WorkRemaining > 0) { - - // - // Compute the number of output lines to process in this iteration. - // - - size_t WorkThisIteration = std::min(WorkRemaining, OutputHeight - ph); - - // - // Walk over each input image organized as a set of NCHWc blocks. - // - - for (size_t ic = 0; ic < InputChannels; ic += BlockSize) { - - unsigned KernelFlags = ComputeKernelFlags(ic, BlockSize); - - // - // Apply the convolution kernel to each row of the output batch. - // - - const float* input = Input + ic * InputSize; - float* output = Output + ph * BlockedOutputWidth; - - for (size_t work = 0; work < WorkThisIteration; work++) { - - // - // Constrain the effective kernel parameters if the output row - // uses one or more input padding rows. - // - - const float* filter = Filter + BlockSize * ic * KernelSize; - size_t ih; - size_t EffectiveKernelHeight; - - ComputeEffectiveKernel(ph + work, BlockSize * BlockSize * KernelWidth, - &filter, &ih, &EffectiveKernelHeight); - - // - // Invoke the convolution kernel. - // - - Kernel(input + BlockSize * (ih * InputWidth - PaddingLeftX), - filter, output, StrideWidthBytes, DilationWidthBytes, - FilterCount, InputStrideBytes, FilterStrideBytes, - OutputStrideBytes, EffectiveKernelHeight, KernelWidth, - input + BlockSize * (ih * InputWidth), InputWidthBytes, - DilatedInputWidthBytes, OutputCountLeftPadX, OutputCountX, - OutputCountRightPadX, Bias, KernelFlags); - - // - // Test for fused non-ReLU activation. - // - - if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION) != 0) { - DoActivation(output, FilterCount, BlockedOutputWidth); - } - - output += BlockedOutputWidth; - } - } - - // - // Advance the convolution state based on the completed work. - // - - CompleteWork(WorkThisIteration); - } - } -}; - -// -// Implementation of the direct convolution algorithm where the input buffer is -// in NCHW format. -// - -struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM -{ - MLAS_NCHWC_CONV_NCHW_ALGORITHM(const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_GROUPED_CONV_ALGORITHM(WorkBlock) - { - } - - void Execute(ptrdiff_t Index) - { - // - // Setup the convolution state based on the thread index. 
- // - - PrepareWork(Index); - - // - // Loop until all of the work has been completed. - // - - const size_t StrideWidthBytes = StrideWidth * sizeof(float); - const size_t DilationWidthBytes = DilationWidth * sizeof(float); - const size_t FilterStrideBytes = BlockSize * InputChannels * KernelSize * sizeof(float); - const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); - const size_t InputWidthBytes = InputWidth * sizeof(float); - const size_t DilatedInputWidthBytes = DilationHeight * InputWidth * sizeof(float); - const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; - - const size_t BlockedOutputWidth = BlockSize * OutputWidth; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; -#else - MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; -#endif - - while (WorkRemaining > 0) { - - // - // Constrain the effective kernel parameters if the output row uses - // one or more input padding rows. - // - - const float* filter = Filter; - size_t ih; - size_t EffectiveKernelHeight; - - ComputeEffectiveKernel(ph, BlockSize * KernelWidth, &filter, &ih, - &EffectiveKernelHeight); - - // - // Apply the convolution kernel to each channel of the input tensor. - // - - const float* input = Input; - float* output = Output + BlockSize * ph * OutputWidth; - - for (size_t ic = 0; ic < InputChannels; ic += 1) { - - unsigned KernelFlags = ComputeKernelFlags(ic, 1); - - // - // Invoke the convolution kernel. - // - - Kernel(input + (ih * InputWidth - PaddingLeftX), filter, output, - StrideWidthBytes, DilationWidthBytes, FilterCount, InputStrideBytes, - FilterStrideBytes, OutputStrideBytes, EffectiveKernelHeight, - KernelWidth, input + (ih * InputWidth), InputWidthBytes, - DilatedInputWidthBytes, OutputCountLeftPadX, OutputCountX, - OutputCountRightPadX, Bias, KernelFlags); - - // - // Test for fused non-ReLU activation. - // - - if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION) != 0) { - DoActivation(output, FilterCount, BlockedOutputWidth); - } - - input += InputSize; - filter += BlockSize * KernelSize; - } - - // - // Advance the convolution state based on the completed work. - // - - CompleteWork(1); - } - } -}; - -// -// Implementation of the pointwise convolution algorithm. -// -// Pointwise convolutions have a kernel size of one. To simplify this -// implementation, no input padding is allowed, which matches typical -// usage in models. -// - -struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM -{ - MLAS_NCHWC_CONV_POINTWISE_ALGORITHM(const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_GROUPED_CONV_ALGORITHM(WorkBlock) - { - } - - void Execute(ptrdiff_t Index) - { - // - // Setup the convolution state based on the thread index. - // - - PrepareWork(Index); - - // - // Loop until all of the work has been completed. 
- // - - const size_t StrideWidthBytes = BlockSize * StrideWidth * sizeof(float); - const size_t InputStrideBytes = BlockSize * InputSize * sizeof(float); - const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); - const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; -#else - MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; -#endif - - while (WorkRemaining > 0) { - - // - // Compute the number of output blocks that can be computed in this - // iteration. Unstrided convolutions can treat the input and output - // as a single line which in turn allows the kernel to use wider - // multiply/accumulate loops. Otherwise, a strided convolution can - // output a single line at a time. - // - - size_t WorkThisIteration; - - if (StrideHeight == 1 && StrideWidth == 1) { - WorkThisIteration = std::min(WorkRemaining, OutputHeight - ph); - } else { - WorkThisIteration = 1; - } - - const size_t OutputThisIteration = WorkThisIteration * OutputWidth; - - // - // Apply the convolution kernel to batches of the input tensor. - // - // Shrinking the batch size causes a slowdown from additional - // flushing of intermediate results to the output tensor. Extending - // the batch sizes causes a slowdown from processor cache thrashing. - // - - const float* input = Input + BlockSize * (ph * StrideHeight * InputWidth); - const float* filter = Filter; - float* output = Output + BlockSize * ph * OutputWidth; - - size_t InputChannelBatch; - - for (size_t ic = 0; ic < InputChannels; ic += InputChannelBatch) { - - constexpr size_t MaximumInputChannelBatch = 128; - - InputChannelBatch = std::min(InputChannels - ic, MaximumInputChannelBatch); - - unsigned KernelFlags = ComputeKernelFlags(ic, InputChannelBatch); - - // - // Invoke the convolution kernel. - // - - Kernel(input, filter, output, StrideWidthBytes, InputChannelBatch / - BlockSize, FilterCount, InputStrideBytes, FilterStrideBytes, - OutputStrideBytes, OutputThisIteration, Bias, KernelFlags); - - // - // Test for fused non-ReLU activation. - // - - if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION) != 0) { - DoActivation(output, FilterCount, BlockSize * OutputThisIteration); - } - - input += MaximumInputChannelBatch * InputSize; - filter += BlockSize * MaximumInputChannelBatch; - } - - // - // Advance the convolution state based on the completed work. - // - - CompleteWork(WorkThisIteration); - } - } -}; - -// -// Implementation of the depthwise separable convolution algorithm. -// -// Depthwise separable convolutions are a form of grouped convolution where -// the number of input and output channels per group are one. -// - -struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM -{ - MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM(const MLAS_NCHWC_CONV_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_CONV_ALGORITHM(WorkBlock) - { - } - - void Execute(ptrdiff_t Index) - { - const size_t GroupBlockCount = ((GroupCount + BlockSize - 1) / BlockSize); - - const size_t TotalWork = BatchCount * GroupBlockCount * OutputHeight; - - size_t WorkIndex; - size_t WorkRemaining; - - MlasPartitionWork(Index, WorkBlock->tids, TotalWork, &WorkIndex, &WorkRemaining); - - // - // Extract the current batch, group block, and output line from the - // starting work index. 
- // - - size_t ph = WorkIndex % OutputHeight; - const size_t BatchGroup = WorkIndex / OutputHeight; - - size_t Group = BatchGroup % GroupBlockCount; - - // - // Advance the convolution buffer pointers to the current position - // computed above. - // - - Input += BatchGroup * BlockSize * InputSize; - Output += WorkIndex * BlockSize * OutputWidth; - Filter += Group * BlockSize * KernelSize; - - if (Bias != nullptr) { - Bias += BlockSize * Group; - } - - // - // Loop until all of the work has been completed. - // - - const size_t StrideWidthBytes = BlockSize * StrideWidth * sizeof(float); - const size_t DilationWidthBytes = BlockSize * DilationWidth * sizeof(float); - const size_t InputWidthBytes = BlockSize * InputWidth * sizeof(float); - const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); - const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; - - const size_t BlockedOutputWidth = BlockSize * OutputWidth; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; -#else - MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; -#endif - - unsigned KernelFlags = ComputeKernelFlags(0, InputChannels); - - while (WorkRemaining > 0) { - - // - // Constrain the effective kernel parameters if the output row uses - // one or more input padding rows. - // - - const float* filter = Filter; - size_t ih; - size_t EffectiveKernelHeight; - - ComputeEffectiveKernel(ph, BlockSize * KernelWidth, &filter, &ih, &EffectiveKernelHeight); - - // - // Invoke the convolution kernel. - // - - Kernel(Input + BlockSize * (ih * InputWidth - PaddingLeftX), filter, - Output, StrideWidthBytes, DilationWidthBytes, InputStrideBytes, - EffectiveKernelHeight, KernelWidth, Input + BlockSize * (ih * InputWidth), - InputWidthBytes, DilatedInputWidthBytes, OutputCountLeftPadX, - OutputCountX, OutputCountRightPadX, Bias, KernelFlags); - - // - // Test for fused non-ReLU activation. - // - - if ((KernelFlags & MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION) != 0) { - DoActivation(Output, 1, BlockedOutputWidth); - } - - Output += BlockedOutputWidth; - - // - // Adjust the amount of work remaining and check if the end of an - // output image has been reached. - // - - WorkRemaining -= 1; - - if (++ph == OutputHeight) { - - Input += BlockSize * InputSize; - Filter += BlockSize * KernelSize; - - if (Bias != nullptr) { - Bias += BlockSize; - } - - if (++Group == GroupBlockCount) { - - Filter = WorkBlock->Filter; - Bias = WorkBlock->Bias; - - Group = 0; - } - - ph = 0; - } - } - } -}; - -// -// Implementation of the pooling algorithm. 
-// - -struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM -{ -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) - static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; -#endif - - const MLAS_NCHWC_POOL_WORK_BLOCK* WorkBlock; - - MLAS_NCHWC_POOL_ALGORITHM(const MLAS_NCHWC_POOL_WORK_BLOCK* WorkBlock) : - MLAS_NCHWC_NN_ALGORITHM(WorkBlock), - WorkBlock(WorkBlock) - { - } - - void Execute(ptrdiff_t Index) - { - const size_t TotalWork = - ((BatchCount * InputChannels + BlockSize - 1) / BlockSize) * OutputHeight; - - size_t WorkIndex; - size_t WorkRemaining; - - MlasPartitionWork(Index, WorkBlock->tids, TotalWork, &WorkIndex, &WorkRemaining); - - size_t ph = WorkIndex % OutputHeight; - const size_t BatchChannel = WorkIndex / OutputHeight; - - const float* Input = WorkBlock->Input + BatchChannel * BlockSize * InputSize; - float* Output = WorkBlock->Output + WorkIndex * BlockSize * OutputWidth; - - // - // Loop until all of the work has been completed. - // - - const size_t StrideWidthBytes = BlockSize * StrideWidth * sizeof(float); - const size_t DilationWidthBytes = BlockSize * DilationWidth * sizeof(float); - const size_t InputWidthBytes = BlockSize * InputWidth * sizeof(float); - const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); - const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; -#else - MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; -#endif - - while (WorkRemaining > 0) { - - // - // Compute the first input row and kernel height. If this output row - // uses padding from one or more input padding rows, then adjust the - // kernel parameters to keep within the input bounds. - // - - size_t ih = ph * StrideHeight - PaddingLeftY; - size_t EffectiveKernelHeight = KernelHeight; - - if ((ph - OutputCountLeftPadY) >= OutputCountY) { - - size_t ihStep = ih; - - for (size_t kh = 0; kh < KernelHeight; kh++) { - - if (ihStep >= InputHeight) { - - if (ihStep == ih) { - ih += DilationHeight; - } - - EffectiveKernelHeight -= 1; - } - - ihStep += DilationHeight; - } - } - - // - // Invoke the pooling kernel. - // - - Kernel(Input + BlockSize * (ih * InputWidth - PaddingLeftX), Output, - StrideWidthBytes, DilationWidthBytes, InputStrideBytes, - KernelSize, EffectiveKernelHeight, KernelWidth, - Input + BlockSize * (ih * InputWidth), InputWidthBytes, - DilatedInputWidthBytes, OutputCountLeftPadX, OutputCountX, - OutputCountRightPadX); - - Output += BlockSize * OutputWidth; - - // - // Adjust the amount of work remaining and check if the end of an output - // image has been reached. 
- // - - WorkRemaining -= 1; - - if (++ph == OutputHeight) { - - Input += BlockSize * InputSize; - - ph = 0; - } - } - } -}; - -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) - -MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = -{ - MlasPoolMaximumFloatKernel, - MlasPoolAverageExcludePadFloatKernel, - MlasPoolAverageIncludePadFloatKernel, -}; - -#endif - -void -MLASCALL -MlasNchwcConv( - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - size_t GroupCount, - const float* Input, - const float* Filter, - const float* Bias, - float* Output, - const MLAS_ACTIVATION* Activation, - bool ZeroMode, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the NCHWc convolution operation. - -Arguments: - - Dimensions - Supplies the number of dimensions. - - InputShape - Supplies the shape of the input tensor. - - KernelShape - Supplies the shape of the kernel transform. - - DilationShape - Supplies the shape of the dilation. - - Padding - Supplies the number of padding elements at the edge of the input - tensor. - - StrideShape - Supplies the shape of the stride. - - OutputShape - Supplies the shape of the output tensor. - - GroupCount - Supplies the number of channel groups. - - Input - Supplies the input tensor. - - Filter - Supplies the filter tensor. - - Bias - Optionally supplies the bias vector. - - Output - Supplies the output tensor. - - Activation - Supplies the parameters for the activation to apply to the - convolution output. - - ZeroMode - Supplies true if the output tensor must be zero initialized - first, else false if the output tensor is accumulated into. This flag is - used to implement Conv/Sum fusion. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_NCHWC_CONV_WORK_BLOCK WorkBlock; - - // - // Capture the convolution specific parameters to the work block. - // - - WorkBlock.Input = Input; - WorkBlock.Output = Output; - WorkBlock.GroupCount = GroupCount; - WorkBlock.Filter = Filter; - WorkBlock.Bias = Bias; - WorkBlock.Activation = Activation; - WorkBlock.ZeroMode = ZeroMode; - - // - // Capture the generic shape parameters to the work block. - // - - MlasNchwcPrepareWorkBlock(&WorkBlock, InputShape, KernelShape, - DilationShape, Padding, StrideShape, OutputShape); - - WorkBlock.InputChannels /= GroupCount; - WorkBlock.OutputChannels /= GroupCount; - - // - // Determine the type of convolution to perform based on the shape - // parameters. - // - // N.B. The caller must be aware of the selection algorithm in order to - // reorder the filter tensor in the expected format for the given algorithm. - // - - MLAS_THREADED_ROUTINE* ThreadedRoutine; - - if (WorkBlock.InputChannels >= MlasNchwcGetBlockSize()) { - if (WorkBlock.KernelShape[0] == 1 && WorkBlock.KernelShape[1] == 1 && - WorkBlock.Padding[0] == 0 && WorkBlock.Padding[1] == 0 && - WorkBlock.Padding[2] == 0 && WorkBlock.Padding[3] == 0) { - ThreadedRoutine = MlasNchwcThreaded; - } else { - ThreadedRoutine = MlasNchwcThreaded; - } - } else if (WorkBlock.InputChannels == 1 && WorkBlock.OutputChannels == 1) { - ThreadedRoutine = MlasNchwcThreaded; - } else { - ThreadedRoutine = MlasNchwcThreaded; - } - - // - // Schedule the operation across a set of worker threads. 
- // - - WorkBlock.tids = MlasGetMaximumThreadCount(ThreadPool); - - MlasExecuteThreaded(ThreadedRoutine, &WorkBlock, WorkBlock.tids, ThreadPool); -} - -void -MLASCALL -MlasNchwcPool( - MLAS_POOLING_KIND PoolingKind, - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* DilationShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - const float* Input, - float* Output, - MLAS_THREADPOOL* ThreadPool - ) -/*++ - -Routine Description: - - This routine implements the NCHWc pooling operation. - -Arguments: - - PoolingKind - Supplies the kind of pooling operation to perform. - - InputShape - Supplies the shape of the input tensor. - - KernelShape - Supplies the shape of the kernel transform. - - DilationShape - Supplies the shape of the dilation. - - Padding - Supplies the number of padding elements at the edge of the input - tensor. - - StrideShape - Supplies the shape of the stride. - - OutputShape - Supplies the shape of the output tensor. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - - ThreadPool - Supplies the thread pool object to use, else nullptr if the - base library threading support should be used. - -Return Value: - - None. - ---*/ -{ - MLAS_NCHWC_POOL_WORK_BLOCK WorkBlock; - - // - // Capture the pooling specific parameters to the work block. - // - - WorkBlock.Input = Input; - WorkBlock.Output = Output; - WorkBlock.PoolingKind = PoolingKind; - - // - // Capture the generic shape parameters to the work block. - // - - MlasNchwcPrepareWorkBlock(&WorkBlock, InputShape, KernelShape, - DilationShape, Padding, StrideShape, OutputShape); - - // - // Schedule the operation across a set of worker threads. - // - - WorkBlock.tids = MlasGetMaximumThreadCount(ThreadPool); - - MlasExecuteThreaded(MlasNchwcThreaded, &WorkBlock, WorkBlock.tids, ThreadPool); -} - -void -MLASCALL -MlasNchwcUpsampleNearest( - const int64_t* InputShape, - const int64_t* Scales, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements the NCHWc upsample nearest operation. - -Arguments: - - InputShape - Supplies the shape of the input tensor. - - Scales - Supplies the shape of the spatial scaling. - - Input - Supplies the input tensor. - - Output - Supplies the output tensor. - -Return Value: - - None. - ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - const size_t BatchCount = size_t(InputShape[0]); - const size_t ChannelCount = size_t(InputShape[1]); - const size_t InputHeight = size_t(InputShape[2]); - const size_t InputWidth = size_t(InputShape[3]); - - const size_t TotalInputHeight = BatchCount * ChannelCount * InputHeight; - - const size_t ScaleHeight = size_t(Scales[0]); - const size_t ScaleWidth = size_t(Scales[1]); - - const size_t OutputWidth = InputWidth * ScaleWidth; - - // - // Iterate over each line of the input tensor. - // - - for (size_t h = 0; h < TotalInputHeight; h += BlockSize) { - - float* OutputBaseRow = Output; - - // - // Scale the input tensor across the width dimension. 
- // - - for (size_t w = 0; w < InputWidth; w++) { - - if (BlockSize == 16) { - - MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(Input); - MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(Input + 4); - MLAS_FLOAT32X4 v2 = MlasLoadFloat32x4(Input + 8); - MLAS_FLOAT32X4 v3 = MlasLoadFloat32x4(Input + 12); - - for (size_t sw = 0; sw < ScaleWidth; sw++) { - - MlasStoreFloat32x4(Output, v0); - MlasStoreFloat32x4(Output + 4, v1); - MlasStoreFloat32x4(Output + 8, v2); - MlasStoreFloat32x4(Output + 12, v3); - - Output += BlockSize; - } - - } else { - - MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(Input); - MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(Input + 4); - - for (size_t sw = 0; sw < ScaleWidth; sw++) { - - MlasStoreFloat32x4(Output, v0); - MlasStoreFloat32x4(Output + 4, v1); - - Output += BlockSize; - } - } - - Input += BlockSize; - } - - // - // Scale the input tensor across the height dimension by duplicating - // the first output line. - // - - for (size_t sh = 1; sh < ScaleHeight; sh++) { - Output = std::copy_n(OutputBaseRow, OutputWidth * BlockSize, Output); - } - } -} - -MLAS_FORCEINLINE -void -MlasNchwcExtractInterpolation( - float InterpolationValue, - size_t InputLimit, - ptrdiff_t InputIndex[2], - MLAS_FLOAT32X4 Multipliers[2] - ) -{ - InputIndex[0] = ptrdiff_t(InterpolationValue); - InputIndex[1] = std::min(InputIndex[0] + 1, ptrdiff_t(InputLimit - 1)); - - float ScalarMultiplier0 = InterpolationValue - float(InputIndex[0]); - float ScalarMultiplier1 = 1.0f - ScalarMultiplier0; - - Multipliers[0] = MlasBroadcastFloat32x4(ScalarMultiplier0); - Multipliers[1] = MlasBroadcastFloat32x4(ScalarMultiplier1); -} - -void -MLASCALL -MlasNchwcUpsampleLinear( - size_t InputHeight, - size_t InputWidth, - size_t OutputWidth, - float InterpolationHeight, - const float* InterpolationWidth, - const float* Input, - float* Output - ) -/*++ - -Routine Description: - - This routine implements the NCHWc upsample linear operation for a single row. - - The integer portion of each interpolation float supplies the mapping from - output element to input element. The fractional portion supplies the relative - weights for the four points of the interpolation. - -Arguments: - - InputHeight - Supplies the input height. - - InputWidth - Supplies the input width. - - OutputWidth - Supplies the output width. - - InterpolationHeight - Supplies the height interpolation values for the target - row. - - InterpolationWidth - Supplies an array of computed interpolation values of - length OutputWidth. - - Input - Supplies the input spatial buffer. - - Output - Supplies the output row buffer. - -Return Value: - - None. 
- ---*/ -{ - const size_t BlockSize = MlasNchwcGetBlockSize(); - - ptrdiff_t InputIndexY[2]; - MLAS_FLOAT32X4 MultipliersY[2]; - - MlasNchwcExtractInterpolation(InterpolationHeight, InputHeight, InputIndexY, MultipliersY); - - const float* InputRowY0 = Input + InputIndexY[0] * InputWidth * BlockSize; - const float* InputRowY1 = Input + InputIndexY[1] * InputWidth * BlockSize; - - for (size_t ow = 0; ow < OutputWidth; ow++) { - - ptrdiff_t InputIndexX[2]; - MLAS_FLOAT32X4 MultipliersX[2]; - - MlasNchwcExtractInterpolation(InterpolationWidth[ow], InputWidth, InputIndexX, MultipliersX); - - MLAS_FLOAT32X4 MultiplierY0X0 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[0]); - MLAS_FLOAT32X4 MultiplierY0X1 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[1]); - MLAS_FLOAT32X4 MultiplierY1X0 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[0]); - MLAS_FLOAT32X4 MultiplierY1X1 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[1]); - - for (size_t bc = 0; bc < BlockSize; bc += 4) { - - MLAS_FLOAT32X4 v00 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[0] * BlockSize + bc); - MLAS_FLOAT32X4 v01 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[1] * BlockSize + bc); - MLAS_FLOAT32X4 v10 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[0] * BlockSize + bc); - MLAS_FLOAT32X4 v11 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[1] * BlockSize + bc); - - v00 = MlasMultiplyFloat32x4(MultiplierY1X1, v00); - v01 = MlasMultiplyFloat32x4(MultiplierY1X0, v01); - v10 = MlasMultiplyFloat32x4(MultiplierY0X1, v10); - v11 = MlasMultiplyFloat32x4(MultiplierY0X0, v11); - - MLAS_FLOAT32X4 Reduction0 = MlasAddFloat32x4(v00, v01); - MLAS_FLOAT32X4 Reduction1 = MlasAddFloat32x4(v10, v11); - - MLAS_FLOAT32X4 Reduction = MlasAddFloat32x4(Reduction0, Reduction1); - - MlasStoreFloat32x4(&Output[bc], Reduction); - } - - Output += BlockSize; - } -} - -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) - -// -// Convolution and pooling kernel stubs for architectures that do not yet have -// native support. 
-// - -void -MLASCALL -MlasConvNchwFloatKernel( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t FilterCount, - size_t InputStride, - size_t FilterStride, - size_t OutputStride, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad, - const float* Bias, - unsigned Flags - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - MLAS_UNREFERENCED_PARAMETER(Filter); - MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(DilationWidth); - MLAS_UNREFERENCED_PARAMETER(FilterCount); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(FilterStride); - MLAS_UNREFERENCED_PARAMETER(OutputStride); - MLAS_UNREFERENCED_PARAMETER(KernelHeight); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - MLAS_UNREFERENCED_PARAMETER(InputBase); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(DilatedInputWidth); - MLAS_UNREFERENCED_PARAMETER(OutputCountLeftPad); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(OutputCountRightPad); - MLAS_UNREFERENCED_PARAMETER(Bias); - MLAS_UNREFERENCED_PARAMETER(Flags); -} - -void -MLASCALL -MlasConvNchwcFloatKernel( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t FilterCount, - size_t InputStride, - size_t FilterStride, - size_t OutputStride, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad, - const float* Bias, - unsigned Flags - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - MLAS_UNREFERENCED_PARAMETER(Filter); - MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(DilationWidth); - MLAS_UNREFERENCED_PARAMETER(FilterCount); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(FilterStride); - MLAS_UNREFERENCED_PARAMETER(OutputStride); - MLAS_UNREFERENCED_PARAMETER(KernelHeight); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - MLAS_UNREFERENCED_PARAMETER(InputBase); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(DilatedInputWidth); - MLAS_UNREFERENCED_PARAMETER(OutputCountLeftPad); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(OutputCountRightPad); - MLAS_UNREFERENCED_PARAMETER(Bias); - MLAS_UNREFERENCED_PARAMETER(Flags); -} - -void -MLASCALL -MlasConvDepthwiseFloatKernel( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t InputStride, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad, - const float* Bias, - unsigned Flags - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - MLAS_UNREFERENCED_PARAMETER(Filter); - MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(DilationWidth); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(KernelHeight); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - MLAS_UNREFERENCED_PARAMETER(InputBase); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(DilatedInputWidth); - 
MLAS_UNREFERENCED_PARAMETER(OutputCountLeftPad); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(OutputCountRightPad); - MLAS_UNREFERENCED_PARAMETER(Bias); - MLAS_UNREFERENCED_PARAMETER(Flags); -} - -void -MLASCALL -MlasConvPointwiseFloatKernel( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t InputChannels, - size_t FilterCount, - size_t InputStride, - size_t FilterStride, - size_t OutputStride, - size_t OutputCount, - const float* Bias, - unsigned Flags - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - MLAS_UNREFERENCED_PARAMETER(Filter); - MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(InputChannels); - MLAS_UNREFERENCED_PARAMETER(FilterCount); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(FilterStride); - MLAS_UNREFERENCED_PARAMETER(OutputStride); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(Bias); - MLAS_UNREFERENCED_PARAMETER(Flags); -} - -void -MLASCALL -MlasPoolMaximumFloatKernel( - const float* Input, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t InputStride, - size_t ActualKernelSize, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(DilationWidth); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(ActualKernelSize); - MLAS_UNREFERENCED_PARAMETER(KernelHeight); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - MLAS_UNREFERENCED_PARAMETER(InputBase); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(DilatedInputWidth); - MLAS_UNREFERENCED_PARAMETER(OutputCountLeftPad); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(OutputCountRightPad); -} - -void -MLASCALL -MlasPoolAverageExcludePadFloatKernel( - const float* Input, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t InputStride, - size_t ActualKernelSize, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(DilationWidth); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(ActualKernelSize); - MLAS_UNREFERENCED_PARAMETER(KernelHeight); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - MLAS_UNREFERENCED_PARAMETER(InputBase); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(DilatedInputWidth); - MLAS_UNREFERENCED_PARAMETER(OutputCountLeftPad); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(OutputCountRightPad); -} - -void -MLASCALL -MlasPoolAverageIncludePadFloatKernel( - const float* Input, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t InputStride, - size_t ActualKernelSize, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad - ) -{ - MLAS_UNREFERENCED_PARAMETER(Input); - 
MLAS_UNREFERENCED_PARAMETER(Output); - MLAS_UNREFERENCED_PARAMETER(StrideWidth); - MLAS_UNREFERENCED_PARAMETER(DilationWidth); - MLAS_UNREFERENCED_PARAMETER(InputStride); - MLAS_UNREFERENCED_PARAMETER(ActualKernelSize); - MLAS_UNREFERENCED_PARAMETER(KernelHeight); - MLAS_UNREFERENCED_PARAMETER(KernelWidth); - MLAS_UNREFERENCED_PARAMETER(InputBase); - MLAS_UNREFERENCED_PARAMETER(InputWidth); - MLAS_UNREFERENCED_PARAMETER(DilatedInputWidth); - MLAS_UNREFERENCED_PARAMETER(OutputCountLeftPad); - MLAS_UNREFERENCED_PARAMETER(OutputCount); - MLAS_UNREFERENCED_PARAMETER(OutputCountRightPad); -} - -#endif diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp deleted file mode 100644 index a45494ef2e04f..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ /dev/null @@ -1,771 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm.cpp - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, - as well as some SQNBitGemm-related query functions. ---*/ - -#include "sqnbitgemm.h" -#include "sqnbitgemm_q8_block.h" - -#include - -namespace -{ - -enum SQNBitGemmVariant { - SQNBitGemmVariantInvalid = -1, - - // Valid variants - - SQNBitGemmVariant_BitWidth4_CompFp32 = 0, - SQNBitGemmVariant_BitWidth4_CompInt8, - - // End of valid variants - - // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. - // Its value is used as an array size. - SQNBitGemmVariantCount, -}; - -SQNBitGemmVariant -GetSQNBitGemmVariant( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - if (BlkBitWidth == 4 && - (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32 || - ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 - return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8) { - return SQNBitGemmVariant_BitWidth4_CompInt8; - } - } - - return SQNBitGemmVariantInvalid; -} - -} // namespace - -bool MLASCALL -MlasIsSQNBitGemmAvailable( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return false; - } - - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); - - switch (Variant) { - case SQNBitGemmVariant_BitWidth4_CompFp32: { - return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && - Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; - } - case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 - return - (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || - (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); - } - default: { - return false; - } - } -} - -namespace -{ - -size_t -SQNBitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return 0; - } - - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); - } - - return 0; -} - 
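[Editorial note, not part of the deleted file.] The availability and workspace query functions in this file (MlasIsSQNBitGemmAvailable above, MlasSQNBitGemmBatchWorkspaceSize just below) are the surface a caller checks before invoking the batch entry point. A minimal caller-side sketch, with purely illustrative shapes and assuming only the declarations from mlas_qnbit.h, might look like this:

    // Hypothetical caller-side sketch; MlasIsSQNBitGemmAvailable and
    // MlasSQNBitGemmBatchWorkspaceSize are the MLAS functions defined in this file.
    #include <cstddef>
    #include <vector>
    #include "mlas_qnbit.h"

    bool PrepareSQ4BitGemmWorkspace(std::vector<std::byte>& Workspace)
    {
        const size_t M = 1, N = 4096, K = 4096, BatchN = 1;  // illustrative GEMM shape
        const size_t BlkBitWidth = 4, BlkLen = 32;           // 4-bit weights, 32-value blocks

        // Fall back to another path if this platform has no 4-bit int8 kernels.
        if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, CompInt8)) {
            return false;
        }

        // The returned size already includes the alignment slack that
        // MlasSQNBitGemmBatch trims off when it aligns the workspace pointer.
        Workspace.resize(
            MlasSQNBitGemmBatchWorkspaceSize(M, N, K, BatchN, BlkBitWidth, BlkLen, CompInt8));
        return true;
    }

A real caller would then fill MLAS_SQNBIT_GEMM_DATA_PARAMS and pass the buffer to MlasSQNBitGemmBatch, which aligns the pointer itself as shown later in this file.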
-size_t -SQNBitGemmPerGemmWorkspaceAlignment( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return 1; - } - - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); - } - - return 1; -} - -size_t -SQNBitGemmPerGemmWorkspaceStride( - size_t M, - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); - const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); - return MlasDivRoundup(Size, Alignment) * Alignment; -} - -} // namespace - -size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (PerGemmWorkspaceStride == 0) { - return 0; - } - - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); - - const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; - - return WorkspaceSize + Alignment - 1; -} - -size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return 0; - } - - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->SQ4BitGemmPackQuantBDataSize( - N, K, BlkLen, ComputeType - ); - } - - return 0; -} - -struct PerGemmQuantAWorkspace { - PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) - : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) - { - QuantData = (std::byte*)PerGemmWorkspace; - QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); - BlockSum = QuantScale + M * BlockCountK; - } - std::byte* QuantData; // NxBlockCountKxBlkLen - float* QuantScale; // NxBlockCountK - float* BlockSum; // NxBlockCountK - void* PerGemmWorkspace_; // memory for above data - size_t M_, BlockCountK_, BlkLen_; -}; - -void MLASCALL -MlasSQNBitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const void* QuantBData, - void* PackedQuantBDataAndOrBlkSumWorkspace, - const void* QuantBScale, - bool has_zp_input, - const void* QuantBZeroPoint, - MLAS_THREADPOOL* ThreadPool -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return; - } - - if (BlkBitWidth == 4) { - if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); - Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(QuantBScale), - has_zp_input, - static_cast(QuantBZeroPoint), - packed_quant_b, - ThreadPool - ); - } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - // TODO: these assertions are true if called from 
matmul_nbits kernel but not from mlas tests. - //assert(QuantBScale == nullptr); - //assert(QuantBZeroPoint == nullptr); - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBDataAndOrBlkSumWorkspace), - ThreadPool - ); - return; - } - } -} - -namespace -{ - -MLAS_FORCEINLINE void -AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) -{ - for (size_t m = 0; m < CountM; m++) { - const float* bias = Bias; - float* sum = C; - for (size_t n = 0; n < CountN; n += 4) { - if (CountN - n < 4) { - for (size_t nn = n; nn < CountN; nn++) { - *sum += *bias; - sum++; - bias++; - } - break; - } - - MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); - acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); - MlasStoreFloat32x4(sum, acc_x); - bias += 4; - sum += 4; - } - C += ldc; - } -} - -typedef void(SQNBitGemmFn)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - -void -SQ4BitGemm_CompFp32( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ - constexpr size_t BlkBitWidth = 4; - - MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); - - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const float* A = DataParams->A + RangeStartM * lda; - - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const std::byte* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const float* a_row = A; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } - - constexpr size_t StrideN = 32; - size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); - MlasThreadedBufAlloc(bufsize); - auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); - - // - // Step through each slice of matrix B along the N dimension. 
- // - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, StrideN); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* a_row = A; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( - BlkLen, - dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks - ); - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true - ); -#else - auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); -#endif - - if (bias) { - AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); - } - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc - ); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } -} - -void -SQ4BitGemm_CompInt8( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ -#ifdef MLAS_TARGET_AMD64_IX86 - PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); - constexpr size_t BlkBitWidth = 4; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - - // quant A scale is embedded in QuantData if QuantScale is nullptr. - const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); - const size_t ldc = DataParams->ldc; - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; - const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; - - assert(RangeStartN % 4 == 0); - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const std::byte* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; - const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; -#else - constexpr size_t BlkBitWidth = 4; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - - const size_t lda = k_blks * Q8BlkSize(BlkLen); - const size_t ldc = DataParams->ldc; - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; - - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const std::byte* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; -#endif - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc - ); - } - - c_blk += RowsHandled * ldc; - a_row += RowsHandled * lda; - - RowsRemaining -= RowsHandled; - } - } -#ifdef MLAS_TARGET_AMD64_IX86 - else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) - { - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( - BlkLen, - QuantA, - QuantAScale, - b_col, - b_col_scale, - b_col_zp, - c_blk, - RangeCountM, - CountN, - K, - k_blks, - bias, - ldc, - ABlockSum, - b_blk_sum - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } -#endif - } -} - -typedef void(InitializeWorkspaceFn)( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* Workspace, - size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool -); - -void -InitializeWorkspace_CompInt8( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* Workspace, - size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool -) -{ - MLAS_UNREFERENCED_PARAMETER(N); - - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - - 
// TODO: try parallel on BatchN * M threads because BatchN is usually 1. - if (QuantizeARow) { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); - } else { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - - void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); - std::byte* QuantARowPtr = quant_a_data.QuantData; - float* QuantARowScalePtr = quant_a_data.QuantScale; - float* QuantARowBlkSum = quant_a_data.BlockSum; - for (size_t m = 0; m < M; ++m) { - QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); - ARowPtr += data.lda; - QuantARowPtr += BlockCountK * BlkLen; - QuantARowScalePtr += BlockCountK; - QuantARowBlkSum += BlockCountK; - } - }); - } -} - -struct Operations { - InitializeWorkspaceFn* InitializeWorkspace = nullptr; - SQNBitGemmFn* SQNBitGemm = nullptr; -}; - -constexpr auto OperationMap = []() { - std::array ops; - - ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; - - ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; - - return ops; -}(); -} // namespace - -void MLASCALL -MlasSQNBitGemmBatch( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const size_t BlkBitWidth, - const size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* Workspace, - MLAS_THREADPOOL* ThreadPool -) -{ - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); - assert(Variant != SQNBitGemmVariantInvalid); - - // - // Ensure `Workspace` has correct alignment. 
-    //
-    if (Workspace != nullptr) {
-        const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
-        const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
-        Workspace = reinterpret_cast<void*>(
-            (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
-        );
-    }
-
-    const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
-
-    if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
-        InitializeWorkspaceOperation != nullptr) {
-        InitializeWorkspaceOperation(
-            M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
-        );
-    }
-
-    const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
-
-    const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
-
-    if (ThreadPool == nullptr) {
-        for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
-            const auto* Data = &DataParams[gemm_i];
-            void* PerGemmWorkspace =
-                reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
-            if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) {
-                PackedQuantBDataStruct packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
-                const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
-                const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
-                const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
-                PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
-                ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N);
-            } else {
-                ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N);
-            }
-        }
-        return;
-    }
-
-    //
-    // Compute the number of target threads given the complexity of the SGEMM
-    // operation. Small requests should run using the single threaded path.
-    //
-
-    const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
-
-    ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
-
-    ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;
-
-    if (TargetThreadCount >= MaximumThreadCount) {
-        TargetThreadCount = MaximumThreadCount;
-    }
-
-    ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
-    if (ThreadsPerGemm < 1) {
-        ThreadsPerGemm = 1;
-    }
-
-    constexpr size_t StrideM = 128;
-
-    size_t nc = N;
-    if (ThreadsPerGemm > 1) {
-        // more than one thread per GEMM
-
-        const size_t BlockedM = MlasDivRoundup(M, StrideM);
-        const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
-        if (max_nc < nc) {
-            nc = std::min(
-                nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
-                        MLAS_QGEMM_STRIDEN_THREAD_ALIGN
-            );
-        }
-    }
-    const size_t StrideN = nc;
-
-    const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
-    const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
-    ThreadsPerGemm = ThreadCountM * ThreadCountN;
-
-    MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
-        const auto gemm_i = tid / ThreadsPerGemm;
-        const auto blk_i = tid % ThreadsPerGemm;
-        const auto* Data = &DataParams[gemm_i];
-
-        const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
-        const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
-
-        const size_t RangeStartM = ThreadIdM * StrideM;
-        const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
-
-        const size_t RangeStartN = ThreadIdN * StrideN;
-        const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
-
-        void* PerGemmWorkspace =
-            reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
-        if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) {
-            PackedQuantBDataStruct packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
-            const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
-            const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
-            const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
-
-            PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
-            ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
-        } else {
-            ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
-        }
-    });
-}
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h
deleted file mode 100644
index 2da336ca2f0ec..0000000000000
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.h
+++ /dev/null
@@ -1,340 +0,0 @@
-/*++
-
-Copyright (c) Microsoft Corporation. All rights reserved.
-
-Licensed under the MIT License.
-
-Module Name:
-
-    sqnbitgemm.h
-
-Abstract:
-
-    This module includes kernel function prototypes and helper functions for
-    implementing SQNBitGemm.
-
-    SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float
-    matrix and B is an n-bit quantized integer matrix. B is block quantized,
-    meaning values of B are divided into blocks and each block has its own
-    scale and optional zero point.
-
---*/
-
-#pragma once
-
-#include "mlas_qnbit.h"
-#include "mlasi.h"
-
-constexpr MLAS_FORCEINLINE size_t
-MlasQNBitQuantBBlkSumAlignment()
-{
-    // 16 floats.
this alignment is required by GemmFloatKernel - return 16 * sizeof(float); -} - -constexpr MLAS_FORCEINLINE size_t -MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) -{ - return BlkLen * BlkBitWidth / 8; -} - -MLAS_FORCEINLINE void* -MlasAlignAddress(void* addr, const size_t alignment) -{ - const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); - addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); - return addr; -} - -struct PackedQuantBDataStruct { - PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) - : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) - { - // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize - constexpr size_t BlkBitWidth = 4; - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - - // _mm256_load_si256 requires alignment on a 32-byte boundary - PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); - } - std::byte* PackedQuantBData; - float* PackedQuantBScale; - float* QuantBBlkSum; - - void* QuantBWorkspace_; - size_t N_, BlockCountK_, BlkLen_; -}; - -template -constexpr MLAS_FORCEINLINE size_t -MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) -{ - if constexpr (BlkBitWidth <= 4) { - return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte - } else { - return BlkCount; - } -} - -// -// Kernel dispatch structure. -// - -struct MLAS_SQNBIT_GEMM_DISPATCH { - // - // Quantized B data packing function prototypes. - // - - /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ - typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType - ); - - SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; - - /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ - typedef void(SQ4BitGemmPackQuantBData_Fn)( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool - ); - - SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; - - typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - const float* QuantBScaleBegin, - bool has_zp_input, - const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, - MLAS_THREADPOOL* ThreadPool - ); - - SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; - - // - // Workspace size calculation function prototypes. - // - - /** - * @brief Gets the required size in bytes of the per-GEMM intermediate workspace. - * Returns a size of zero if no intermediate workspace is needed. 
- * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( - size_t M, - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType - ); - - SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; - - /** - * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. - * - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType - ); - - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; - - // - // CompFp32 kernel function prototypes. - // - - /** - * @brief Multiply float matrix A with quantized 4-bit integer matrix B. - * B is block quantized and column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. - * - * @param BlkLen Number of values in a block. - * @param A Supplies the A matrix. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - */ - typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias - ); - - SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; - - /** - * @brief Dequantize B into the format expected by the Sgemm kernel. - * B is a quantized 4-bit integer matrix that is block quantized and column major. - * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. - * - * @param BlkLen Number of values in a block. - * @param[out] FpData Supplies the output buffer for the dequantized B float data. - * It should have enough space for - * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen - * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are - * useful, but the kernel implementation can be simplified with the extra space. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param CountN Number of columns of B. - * @param CountK Number of rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. 
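 *
 * (Worked example of the size formula above, using hypothetical values
 * CountN = 20, CountK = 70, BlkLen = 32: the column count pads up to 32 and the
 * row count pads up to 96, so FpData must hold 32 * 96 = 3072 floats, of which
 * only the first 32 * 70 = 2240 are meaningful.)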
- */ - typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB - ); - - Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; - - // - // CompInt8 kernel function prototypes. - // - - /** - * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. - * A and B are block quantized and B is column major. - * - * @param BlkLen Number of values in a block. - * @param QuantA Supplies the quantized A matrix. - Binary data containing block quantized int8 data and scale values. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - * @param ldc Number of elements between adjacent rows of C.. - * @param ABlockSum Supplies the blksum of A. - * @param QuantBBlkSum Supplies the blksum of B. - */ - typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( - size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias, - size_t ldc, - const float* ABlockSum, - const float* QuantBBlkSum - ); - - SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; - - /** - * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. - * A and B are block quantized and B is column major. - * - * @param BlkLen Number of values in a block. - * @param QuantA Supplies the quantized A matrix. - Binary data containing block quantized int8 data and scale values. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountM Number of rows of A and C to process, an upper bound. - * @param CountN Number of columns of B and C to process. - * @param CountK Number of columns of A and rows of B. - * @param BlockCountK Number of blocks in one row of A and one column of B. - * @param ldc Number of elements between adjacent rows of C. - * @param Bias Bias vector of length N. - * - * @return The number of rows of A and C that were processed, at most CountM. - */ - typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float* Bias - ); - - SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; - - /** - * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. - * - * @param BlkLen Number of values in a block. - * @param A Supplies the A matrix. 
- * @param CountK Number of columns of A. - * @param[out] QuantA Supplies the output quantized A matrix. - * Binary data containing block quantized int8 data and scale values. - */ - typedef void(QuantizeARow_CompInt8_Fn)( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA - ); - - QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; - - typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA, - float* QuantAScale, - float* AScaledGroupSum // scale_k * Sum_blklen(a_i) - ); - QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; -}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp deleted file mode 100644 index baaa4ba1a3b1f..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ /dev/null @@ -1,1369 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_avx2.cpp.h - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for x64 avx2. - ---*/ - -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_kernel_avx_common_int8.h" -#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" -#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" -#include "sqnbitgemm_kernel_avx2_int8_blklen64.h" - -#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" -#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" - -void -MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) -{ - size_t i = 0; - - // Process 16 elements at a time using AVX2 - for (; i + 15 < size; i += 16) { - // Load 16 FP16 values into an AVX2 register - __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast(src_fp16 + i)); - - // Convert FP16 values to FP32 - __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values)); - __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1)); - - // Store the converted FP32 values into the output vector - _mm256_storeu_ps(dst_fp32 + i, fp32_values1); - _mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2); - } - - // Process any remaining elements - const MLAS_FP16* fp16 = reinterpret_cast(src_fp16); - for (; i < size; ++i) { - dst_fp32[i] = fp16[i].ToFloat(); - } -} - -void -MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size) -{ - size_t i = 0; - - // Process 8 elements at a time using AVX2 - for (; i + 8 <= size; i += 8) { - __m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]); - __m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT); - _mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk); - } - - // Process any remaining elements - for (; i < size; ++i) { - MLAS_FP16 fp16(src_fp32[i]); - dst_fp16[i] = fp16.val; - } -} - -MLAS_FORCEINLINE -__m256 -load_float_n_avx2(const float* data, int n) -{ - assert(n <= 8); - if (n <= 0) { - return _mm256_setzero_ps(); - } - static const int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; - const __m256i load_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - n)); - return _mm256_maskload_ps(data, load_mask); -} - -MLAS_FORCEINLINE void -Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_avx2( - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - 
const std::byte* QuantBZeroPoint, - const size_t CountN, - const size_t CountK, - const size_t BlockCountK -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - - constexpr size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; - // TODO: constexpr use temaplte parameter - /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; - const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time - constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 - const __m128i low_mask = _mm_set1_epi8(0xF); - for (size_t col = 0; col < CountN; col += NCols8) { - const int cols = std::min((int)NCols8, (int)CountN - (int)col); - for (size_t k = 0; k < BlockCountK; k++) { - // count # of tiles plus blks of the current tile from top - const size_t tile_count = col / GemmFloatKernelWidth16; - float* dst_ptr = FpData + (tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16; - if (col % GemmFloatKernelWidth16 >= NCols8) { - // for the second half to 16 width tile - dst_ptr += NCols8; - } - const std::byte* b_data_ptr = QuantBData + col * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; - const float* scale_ptr = QuantBScale + col * BlockCountK + k; - const std::byte* zp_ptr = QuantBZeroPoint + col * zp_col_stride_in_bytes + k / 2; - bool is_lower = (k % 2) == 0; - - __m256i weight_16_epi16[NCols8]; - __m256 scale_8_ps[NCols8]; - UnrolledLoop([&](size_t col_) { - if ((int)col_ < cols) { - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - __m128i bvi = _mm_loadl_epi64((__m128i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); - const __m128i lower = _mm_and_si128(bvi, low_mask); - const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi, 4), low_mask), 8); - __m128i weight_16_epi8 = _mm_add_epi8(upper, lower); - - if (HasZeroPoint) { - std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); - uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - weight_16_epi8 = _mm_sub_epi8(weight_16_epi8, _mm_set1_epi8(zp)); - } else { - const __m128i eight = _mm_set1_epi8(8); - weight_16_epi8 = _mm_sub_epi8(weight_16_epi8, eight); - } - weight_16_epi16[col_] = _mm256_cvtepi8_epi16(weight_16_epi8); - scale_8_ps[col_] = _mm256_set1_ps(*(scale_ptr + col_ * BlockCountK)); - } else { - weight_16_epi16[col_] = _mm256_setzero_si256(); - scale_8_ps[col_] = _mm256_setzero_ps(); - } - }); - for (int i_of_2 = 0; i_of_2 < 2; i_of_2++) { - __m256 weight_8_ps[8]; - for (size_t col_ = 0; col_ < 8; col_++) { - if ((int)col_ < cols) { - if (i_of_2 == 0) { - __m256i weight_i_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_16_epi16[col_], 0)); - weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_8_epi32), scale_8_ps[col_]); - } else { - __m256i weight_i_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_16_epi16[col_], 1)); - weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_8_epi32), scale_8_ps[col_]); - } - } else { - weight_8_ps[col_] = _mm256_setzero_ps(); - } - } - // transpose and store - __m256 a0 = _mm256_unpacklo_ps(weight_8_ps[0], weight_8_ps[1]); - __m256 a1 = _mm256_unpackhi_ps(weight_8_ps[0], weight_8_ps[1]); - __m256 a2 = _mm256_unpacklo_ps(weight_8_ps[2], weight_8_ps[3]); - __m256 a3 = _mm256_unpackhi_ps(weight_8_ps[2], weight_8_ps[3]); - __m256 a4 = _mm256_unpacklo_ps(weight_8_ps[4], weight_8_ps[5]); - __m256 a5 = _mm256_unpackhi_ps(weight_8_ps[4], weight_8_ps[5]); - __m256 a6 = _mm256_unpacklo_ps(weight_8_ps[6], weight_8_ps[7]); - __m256 a7 = _mm256_unpackhi_ps(weight_8_ps[6], weight_8_ps[7]); - - __m256 b0 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b1 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 b2 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b3 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 b4 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b5 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 b6 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b7 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(3, 2, 3, 2)); - - // next i_of_2th row - const size_t ij_offset_in_k = i_of_2 * 8 * GemmFloatKernelWidth16; - __m256 weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x31); - _mm256_storeu_ps(dst_ptr + 
ij_offset_in_k + 7 * GemmFloatKernelWidth16, weight_transposed_8_ps); - } - } - } -} - -template -MLAS_FORCEINLINE void -Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( - const size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - const size_t CountN, - const size_t CountK, - const size_t BlockCountK -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time - constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 - constexpr size_t SubblkLen32 = 32; // process SubblkLen32 rows of QuantB at a time - - const size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t subblk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubblkLen32); - const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; - // TODO: constexpr use temaplte parameter - /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; - const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] int count_half_4 = 0; - - const __m256i low_mask = _mm256_set1_epi8(0xF); - for (size_t col = 0; col < CountN; col += NCols8) { - // TODO: handle last tile with cols < NCols8 - const size_t cols = std::min(NCols8, CountN - col); - for (size_t k = 0; k < BlockCountK; k++) { - // count # of tiles plus blks of the current tile from top - const size_t tile_count = col / GemmFloatKernelWidth16; - float* dst_ptr = FpData + (tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16; - if (col % GemmFloatKernelWidth16 >= NCols8) { - // for the second half to 16 width tile - dst_ptr += NCols8; - } - const std::byte* b_data_ptr = QuantBData + col * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; - const float* scale_ptr = QuantBScale + col * BlockCountK + k; - const std::byte* zp_ptr = QuantBZeroPoint + col * zp_col_stride_in_bytes + k / 2; - bool is_lower = (k % 2) == 0; - - for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) { - __m256i weight_32_epi8[NCols8]; - __m256 scale_8_ps[NCols8]; - if constexpr (IsBlkLen64Layout) { - count_half_4 = 4 * (subblk % 2); - } - UnrolledLoop([&](size_t col_) { - if (col_ < cols) { - if constexpr (IsBlkLen64Layout) { - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 - // at the end of subblk loop, increment b_data_ptr by 2 * subblk_data_size_in_bytes if subblk % 2 == 1 - // so that all v0-64 of the pack are dequantized. - const __m256i bvi = _mm256_loadu_si256((__m256i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); - weight_32_epi8[col_] = _mm256_and_si256(_mm256_srli_epi16(bvi, count_half_4), low_mask); - } else { - // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - __m128i bvi = _mm_loadu_si128((__m128i const*)(b_data_ptr + col_ * b_data_col_stride_in_bytes)); - __m128i lower = _mm_and_si128(bvi, _mm256_castsi256_si128(low_mask)); - __m128i upper = _mm_and_si128(_mm_srli_epi16(bvi, 4), _mm256_castsi256_si128(low_mask)); - weight_32_epi8[col_] = _mm256_set_m128i(upper, lower); - } - - if (HasZeroPoint) { - std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); - uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - weight_32_epi8[col_] = _mm256_sub_epi8(weight_32_epi8[col_], _mm256_set1_epi8(zp)); - } else { - const __m256i eight = _mm256_set1_epi8(8); - weight_32_epi8[col_] = _mm256_sub_epi8(weight_32_epi8[col_], eight); - } - - scale_8_ps[col_] = _mm256_set1_ps(*(scale_ptr + col_ * BlockCountK)); - } else { - weight_32_epi8[col_] = _mm256_setzero_si256(); - scale_8_ps[col_] = _mm256_setzero_ps(); - } - }); - for (int i_of_4 = 0; i_of_4 < 4; i_of_4++) { - __m256 weight_8_ps[8]; - for (size_t col_ = 0; col_ < 8; col_++) { - if (col_ < cols) { - if (i_of_4 == 0) { - __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 0)); - __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 0)); - weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); - } else if (i_of_4 == 1) { - __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 0)); - __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 1)); - weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); - } else if (i_of_4 == 2) { - __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 1)); - __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 0)); - weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); - } else if (i_of_4 == 3) { - __m256i weight_i_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(weight_32_epi8[col_], 1)); - __m256i weight_i_j_8_epi32 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(weight_i_16_epi16, 1)); - weight_8_ps[col_] = _mm256_mul_ps(_mm256_cvtepi32_ps(weight_i_j_8_epi32), scale_8_ps[col_]); - } - } else { - weight_8_ps[col_] = _mm256_setzero_ps(); - } - } - // transpose and store - __m256 a0 = _mm256_unpacklo_ps(weight_8_ps[0], weight_8_ps[1]); - __m256 a1 = _mm256_unpackhi_ps(weight_8_ps[0], weight_8_ps[1]); - __m256 a2 = _mm256_unpacklo_ps(weight_8_ps[2], weight_8_ps[3]); - __m256 a3 = _mm256_unpackhi_ps(weight_8_ps[2], weight_8_ps[3]); - __m256 a4 = _mm256_unpacklo_ps(weight_8_ps[4], weight_8_ps[5]); - __m256 a5 = _mm256_unpackhi_ps(weight_8_ps[4], weight_8_ps[5]); - __m256 a6 = _mm256_unpacklo_ps(weight_8_ps[6], weight_8_ps[7]); - __m256 a7 = _mm256_unpackhi_ps(weight_8_ps[6], weight_8_ps[7]); - - __m256 b0 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b1 = _mm256_shuffle_ps(a0, a2, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 b2 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b3 = _mm256_shuffle_ps(a1, a3, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 b4 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b5 = _mm256_shuffle_ps(a4, a6, _MM_SHUFFLE(3, 2, 3, 2)); - __m256 b6 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(1, 0, 1, 0)); - __m256 b7 = _mm256_shuffle_ps(a5, a7, _MM_SHUFFLE(3, 2, 3, 2)); - - const size_t ij_offset_in_k = i_of_4 * 8 * GemmFloatKernelWidth16; - __m256 weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x20); - _mm256_storeu_ps(dst_ptr + 
ij_offset_in_k + 2 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x20); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b0, b4, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b1, b5, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b2, b6, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, weight_transposed_8_ps); - weight_transposed_8_ps = _mm256_permute2f128_ps(b3, b7, 0x31); - _mm256_storeu_ps(dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, weight_transposed_8_ps); - } - dst_ptr += SubblkLen32 * GemmFloatKernelWidth16; - if constexpr (IsBlkLen64Layout) { - b_data_ptr += (subblk % 2) * 2 * subblk_data_size_in_bytes; - } else { - b_data_ptr += subblk_data_size_in_bytes; - } - } // subblk - } - } -} - -MLAS_FORCEINLINE void -Q4BitBlkDequantBForSgemm_CompFp32_avx2( - const size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - const size_t CountN, - const size_t CountK, - const size_t BlockStrideQuantB -) -{ - if (BlkLen == 16) { - Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_avx2( - FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB - ); - } else if (BlkLen == 32) { - Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( - BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB - ); - } else { - Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_avx2( - BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB - ); - } -} - -template -MLAS_FORCEINLINE -void -SQ4BitGemmKernel_CompInt8_avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - if (BlkLen == 16) { - MlasQ4Int8GemmKernelBlkLen16Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - CountK, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - CountK, - BlockCountK, - Bias, - ldc - ); - } else { - MlasQ4Int8GemmKernelBlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } -} - -template -MLAS_FORCEINLINE -void -SQ4BitGemmM1Kernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t /*CountK*/, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint) { - if (BlkLen == 16) { - } else if (BlkLen == 32) { - MlasQ4Int8GemmM1KernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - MlasQ4Int8GemmKernelBlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, 
- Bias - ); - } - } else { - if (BlkLen == 16) { - } else if (BlkLen == 32) { - MlasQ4Int8GemmM1KernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - MlasQ4Int8GemmKernelBlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } - } -} - -MLAS_FORCEINLINE -size_t -SQ4BitGemmKernel_BlkSum_CompInt8_avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias, - size_t ldc, - const float* ABlockSum, - const float* QuantBBlkSum -) -{ - if (BlkLen >= 32 && CountM == 1) { - SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); - return CountM; - } - - SQ4BitGemmKernel_CompInt8_avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - CountK, - BlockCountK, - Bias, - ldc - ); - float* c_blk = C; - const float* b_blk_sum = QuantBBlkSum; - - size_t RowsRemaining = CountM; - const float* a_blksum_row = ABlockSum; - while (RowsRemaining > 0) { - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false - ); - - c_blk += ldc * RowsHandled; - a_blksum_row += BlockCountK * RowsHandled; - RowsRemaining -= RowsHandled; - } - return CountM; -} - -size_t -SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias, - size_t ldc, - const float* ABlockSum, - const float* QuantBBlkSum -) -{ - if (BlkLen >= 32 && CountM == 1) { - SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); - return CountM; - } - - SQ4BitGemmKernel_CompInt8_avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - CountK, - BlockCountK, - Bias, - ldc - ); - float* c_blk = C; - const float* b_blk_sum = QuantBBlkSum; - - size_t RowsRemaining = CountM; - const float* a_blksum_row = ABlockSum; - while (RowsRemaining > 0) { - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false - ); - - c_blk += ldc * RowsHandled; - a_blksum_row += BlockCountK * RowsHandled; - RowsRemaining -= RowsHandled; - } - return CountM; -} - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkLen16_CompFp32_avx2( - size_t BlkLen, - const float* ARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* sum_ptr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* bias_ptr -) -{ - if constexpr (!HasZeroPoint) { - // Suppress unused variable warnings - (void)QuantBZeroPointColPtr; - (void)StrideQuantBZeroPoint; - } - - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t SubBlkLen16 = 16; - constexpr size_t SubBlkStep8 = 
MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen16); - static_assert(SubBlkStep8 == 8); // 16 * 4 / 8 - - __m256 acc[NCols]; - UnrolledLoop([&](size_t i) { - acc[i] = _mm256_setzero_ps(); - }); - - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint == true - - for (size_t k = 0; k < CountK; k += BlkLen) { - size_t ck = std::min(CountK - k, BlkLen); - - float scale_v[NCols]; - UnrolledLoop([&](size_t i) { - scale_v[i] = *(s + StrideQuantBScale * i); - }); - - std::byte* b_blk_data_col_ptr[NCols]; - UnrolledLoop([&](size_t i) { - b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + StrideQuantBData * i); - }); - - [[maybe_unused]] uint8_t offset[NCols]; - // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = std::to_integer(zp); - }); - } - - for (size_t kk = 0; kk < ck; kk += SubBlkLen16) { - int kklen = std::min((int)SubBlkLen16, (int)(ck - kk)); - - // Load A row vectors - int n_to_read = std::min(kklen, 8); - __m256 av_lo = load_float_n_avx2(ARowPtr + k + kk, n_to_read); - n_to_read = std::min(kklen - 8, 8); - __m256 av_hi = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); - - UnrolledLoop([&](size_t i) { - // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // SubBlkLen = 32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // Load B col vectors. 
get SubBlkLen(16) 4 bits quantized features from each column - __m128i bvi4 = _mm_loadl_epi64((__m128i const*)(b_blk_data_col_ptr[i])); - b_blk_data_col_ptr[i] += SubBlkStep8; - - // TODO: avoid _mm_set1_epi8 - //__m128i lower_mask_epi8 = _mm_cmpeq_epi16(bvi4, bvi4); // can use any __m128i - // lower_mask_epi8 = _mm_srli_epi16(lower_mask_epi8, 13); - // lower_mask_epi8 = _mm_packus_epi16(lower_mask_epi8, lower_mask_epi8); - __m128i lower_mask_epi8 = _mm_set1_epi8(0x0F); // Mask to isolate the lower 4 bits - - const __m128i lower = _mm_and_si128(bvi4, lower_mask_epi8); - const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4, 4), lower_mask_epi8), 8); - __m256i bv_epi16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper, lower)); // unpacked 16 weights of epi16 - - // Subtract zero-point from the integers - if constexpr (HasZeroPoint) { - // Subtract zero-point from the integers - __m256i zp = _mm256_set1_epi16(offset[i]); - bv_epi16 = _mm256_sub_epi16(bv_epi16, zp); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi16(8); - bv_epi16 = _mm256_sub_epi16(bv_epi16, eight); - } - - // Convert to 16 epi16 to 16 float32 - const __m128i bv_lo = _mm256_extractf128_si256(bv_epi16, 0); - const __m128i bv_hi = _mm256_extractf128_si256(bv_epi16, 1); - - __m256 bvf_lo = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(bv_lo)); - __m256 bvf_hi = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(bv_hi)); - - // multiply by scale - __m256 scale_ps = _mm256_set1_ps(scale_v[i]); - bvf_lo = _mm256_mul_ps(bvf_lo, scale_ps); - bvf_hi = _mm256_mul_ps(bvf_hi, scale_ps); - - // c[m,n] += a[m,k] * b[k,n] - acc[i] = _mm256_fmadd_ps(bvf_lo, av_lo, acc[i]); - acc[i] = _mm256_fmadd_ps(bvf_hi, av_hi, acc[i]); - }); - } // kk - - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - s++; - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } // k - - if constexpr (NCols == 4) { - __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (bias_ptr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); - } - _mm_storeu_ps(sum_ptr, acc_x); - } else { - UnrolledLoop([&](size_t i) { - __m128 vlow = _mm256_castps256_ps128(acc[i]); - __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit - - // Add the two 128-bit vectors together - __m128 vsum = _mm_add_ps(vlow, vhigh); - // Horizontally add the elements of the resulting 128-bit vector - vsum = _mm_hadd_ps(vsum, vsum); - vsum = _mm_hadd_ps(vsum, vsum); - - _mm_store_ss(&sum_ptr[i], vsum); - sum_ptr[i] += bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; - }); - } -} - -// TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf -template -void -SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - - const float* ARowPtr = A; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols4; - - while (nblk >= 0) { - ComputeDotProducts_BlkLen16_CompFp32_avx2( - BlkLen16, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - - nblk -= NCols4; - } - - // left over columns less than `NCols`? - nblk += NCols4; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkLen16_CompFp32_avx2<1, HasZeroPoint>( - BlkLen16, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} - -// TODO: flow MlasQ4GemmKernelBlkLen32PlusAvx512f to improve perf -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( - size_t BlkLen, - const float* ARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* sum_ptr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* bias_ptr -) -{ - if constexpr (!HasZeroPoint) { - // Suppress unused variable warnings - (void)QuantBZeroPointColPtr; - (void)StrideQuantBZeroPoint; - } - - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t SubBlkLen32 = 32; - constexpr size_t SubBlkStep16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen32); - static_assert(SubBlkStep16 == 16); // 32 * 4 / 8 - - __m256i lowMask = _mm256_set1_epi8(0x0F); - - __m256 acc[NCols]; - UnrolledLoop([&](size_t i) { - acc[i] = _mm256_setzero_ps(); - }); - - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - [[maybe_unused]] int count_half_4 = 0; - // only used if HasZeroPoint == true - - for (size_t k = 0; k < CountK; k += BlkLen) { - size_t ck = std::min(CountK - k, BlkLen); - - float scale_v[NCols]; - UnrolledLoop([&](size_t i) { - scale_v[i] = *(s + StrideQuantBScale * i); - }); - - std::byte* b_blk_data_col_ptr[NCols]; - UnrolledLoop([&](size_t i) { - b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + StrideQuantBData * i); - }); - - [[maybe_unused]] uint8_t offset[NCols]; - // not ready for "Manual conversion to float" in neon yet. - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = std::to_integer(zp); - }); - } - - for (size_t kk = 0; kk < ck; kk += SubBlkLen32) { - int kklen = std::min((int)SubBlkLen32, (int)(ck - kk)); - - // Load 4 float8 from A - int n_to_read = std::min(kklen, 8); - __m256 av0_8_ps = load_float_n_avx2(ARowPtr + k + kk, n_to_read); - - n_to_read = std::min(kklen - 8, 8); - __m256 av1_8_ps = load_float_n_avx2(ARowPtr + k + kk + 8, n_to_read); - - n_to_read = std::min(kklen - 16, 8); - __m256 av2_8_ps = load_float_n_avx2(ARowPtr + k + kk + 16, n_to_read); - - n_to_read = std::min(kklen - 24, 8); - __m256 av3_8_ps = load_float_n_avx2(ARowPtr + k + kk + 24, n_to_read); - - if constexpr (IsBlkLen64Layout) { - count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); - } - UnrolledLoop([&](size_t i) { - // Load B col vectors. get SubBlkLen32 4b quantized weights from each column - __m256i bv_32_epi8; - if constexpr (IsBlkLen64Layout) { - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 - // increment b_data_ptr by 2 * SubBlkStep16 if kk % (2 * SubBlkLen32) == 1 - // so that all v0-63 of the pack are processed. - const __m256i bvi4 = _mm256_loadu_si256((__m256i const*)(b_blk_data_col_ptr[i])); - bv_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bvi4, count_half_4), lowMask); - b_blk_data_col_ptr[i] += count_half_4 / 2 * SubBlkStep16; - } else { - // SubBlkLen = 32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - __m128i bvi4 = _mm_loadu_si128((const __m128i*)(b_blk_data_col_ptr[i])); - b_blk_data_col_ptr[i] += SubBlkStep16; - - bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); - bv_32_epi8 = _mm256_and_si256(lowMask, bv_32_epi8); - } - - // Subtract zero-point from the integers - if constexpr (HasZeroPoint) { - // Subtract zero-point from the integers - __m256i zp = _mm256_set1_epi8(offset[i]); - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, zp); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, eight); - } - - // Convert to 16 float32 - const __m256i bv0_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bv_32_epi8, 0)); - const __m256i bv1_16_epi16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bv_32_epi8, 1)); - - __m256 bv0_8_ps = - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv0_16_epi16, 0))); - __m256 bv1_8_ps = - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv0_16_epi16, 1))); - __m256 bv2_8_ps = - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv1_16_epi16, 0))); - __m256 bv3_8_ps = - _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(bv1_16_epi16, 1))); - - // multiply by scale - __m256 scale_ps = _mm256_set1_ps(scale_v[i]); - bv0_8_ps = _mm256_mul_ps(bv0_8_ps, scale_ps); - bv1_8_ps = _mm256_mul_ps(bv1_8_ps, scale_ps); - bv2_8_ps = _mm256_mul_ps(bv2_8_ps, scale_ps); - bv3_8_ps = _mm256_mul_ps(bv3_8_ps, scale_ps); - - // c[m,n] += a[m,k] * b[k,n] - acc[i] = _mm256_fmadd_ps(bv0_8_ps, av0_8_ps, acc[i]); - acc[i] = _mm256_fmadd_ps(bv1_8_ps, av1_8_ps, acc[i]); - acc[i] = _mm256_fmadd_ps(bv2_8_ps, av2_8_ps, acc[i]); - acc[i] = _mm256_fmadd_ps(bv3_8_ps, av3_8_ps, acc[i]); - }); - } // kk - - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - s++; - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } // k - - if constexpr (NCols == 4) { - __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (bias_ptr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); - } - _mm_storeu_ps(sum_ptr, acc_x); - } else { - UnrolledLoop([&](size_t i) { - __m128 vlow = _mm256_castps256_ps128(acc[i]); - __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit - - // Add the two 128-bit vectors together - __m128 vsum = _mm_add_ps(vlow, vhigh); - // Horizontally add the elements of the resulting 128-bit vector - vsum = _mm_hadd_ps(vsum, vsum); - vsum = _mm_hadd_ps(vsum, vsum); - - _mm_store_ss(&sum_ptr[i], vsum); - sum_ptr[i] += bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; - }); - } -} - -// TODO: flow MlasQ4GemmKernelBlkLen16Avx512f to improve perf -template -void -SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - - const float* ARowPtr = A; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols4; - while (nblk >= 0) { - if (BlkLen >= 64) { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - } else { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - } - - // move to next `NCols` columns - - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - - nblk -= NCols4; - } - - // left over columns less than `NCols`? - nblk += NCols4; - for (int64_t n = 0; n < nblk; ++n) { - if (BlkLen >= 64) { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, true>( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - } else { - ComputeDotProducts_BlkLen32Plus_CompFp32_avx2<1, HasZeroPoint, false>( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} - -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32_avx2( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (BlkLen == 16) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen16_CompFp32_avx2( - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } - } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_avx2( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } - } -} - -void MLASCALL -QuantizeARow_CompInt8_avx2( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA, - float* QuantAScale, - float* AScaledBlkSum // scale_k * Sum_blklen(a_i) -) -{ - // port from MlasQ80BlkQuantRow - assert(BlkLen % 16 == 0); - const __m256 signBit = _mm256_set1_ps(-0.0f); - const __m256i one_16_epi16 = _mm256_srli_epi16( - _mm256_cmpeq_epi16(_mm256_castps_si256(signBit), _mm256_castps_si256(signBit)), 15); - int8_t* blob = reinterpret_cast(QuantA); - float* scale_ptr = QuantAScale; - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t step = std::min(BlkLen, CountK - k); - - __m256 maxAbs = _mm256_setzero_ps(); - for (size_t kk = 0; kk < step; kk += 8) { - const int klen = std::min(8, (int)(step - kk)); - - __m256 v0 = load_float_n_avx2(A + k + kk, klen); - - // Compute max(abs(e)) for the block - maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, v0)); - } - - __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(maxAbs, 1), _mm256_castps256_ps128(maxAbs)); - max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); - max4 = _mm_max_ss(max4, _mm_shuffle_ps(max4, max4, 1)); - const float maxScalar = _mm_cvtss_f32(max4); - - // Quantize these floats - const float scale = maxScalar / 127.f; - *scale_ptr = scale; - scale_ptr++; - - const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps(inverse_scale); - __m128i* dst = reinterpret_cast<__m128i*>(blob); - - __m256i sum_16_epi16 = _mm256_setzero_si256(); - for (size_t kk = 0; kk < step; kk += 16) { - const int klen = std::min(16, (int)(step - kk)); - - int n_to_read = std::min(klen, 8); - __m256 v0 = load_float_n_avx2(A + k + kk, n_to_read); - v0 = _mm256_mul_ps(v0, mul); - v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); - - __m256 v1; - n_to_read = std::min(klen - 8, 8); - if (n_to_read <= 0) { - v1 = _mm256_setzero_ps(); - } else { - v1 = load_float_n_avx2(A + k + kk + 8, n_to_read); - v1 = _mm256_mul_ps(v1, mul); - v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); - } - - __m128i i_16_epi8 = convert_2_ps_to_epi8(v0, v1); - _mm_storeu_si128(dst++, i_16_epi8); - - // accumulate Sum(a_i) - __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8); - sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); - } - if (step < BlkLen) { - memset(blob + step, 0, BlkLen - step); - } - - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); - AScaledBlkSum++; - blob += BlkLen; - } -} - -static void -SQ4BitGemmPackQuantBDataAndBlkSum( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - const float* QuantBScaleBegin, - bool has_zp_input, - const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, - MLAS_THREADPOOL* ThreadPool -) -{ - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - - // TODO: always use SubBlkLen = 64 in CompInt8 - size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { - SubBlkLen = 64; - } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); -} - -// -// Kernel dispatch structure definition. 
-// -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; - - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; - - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - - d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; - d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; - - return d; -}(); - -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; - - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; - - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - - d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; - d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; - - return d; -}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h deleted file mode 100644 index 80d67806ea6e8..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ /dev/null @@ -1,727 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - - -MLAS_FORCEINLINE __m256 -load_and_broadcast_4_scale_2(const float* scale) -{ - // 3 2 1 0 3 2 1 0 (7) - __m256 scale_2_4_ps = _mm256_broadcast_ps((__m128 const*)scale); - - // 2 1 0 0 2 1 0 0 (1) - __m256 scale_2_4_ps_shifted = _mm256_castsi256_ps( - _mm256_bslli_epi128(_mm256_castps_si256(scale_2_4_ps), 4) - ); - - // 3 2 1 0 2 1 0 0: (3) cross lane - __m256 scale_2_4_ps_permutted = _mm256_permute2f128_ps( - scale_2_4_ps_shifted, scale_2_4_ps, 0b00110000 - ); - - // in accumulate_r1_4blk_dot and accumulate_r2_4blk_dot - // _mm256_hadd_epi16 inter leaved dot sum, resulting: - // a31b31|a30b30|a11b11|a10b10|a21b21|a20b20|a01b01|a00b00 - // therefore we need weight to be: - // 3 3 1 1 2 2 0 0 (1) - return _mm256_permute_ps(scale_2_4_ps_permutted, 0b11110101); -} - -MLAS_FORCEINLINE -__m256i -load_16_epi8_as_epi16(const std::byte* ablob) -{ - const __m128i av_epi8 = _mm_lddqu_si128(reinterpret_cast(ablob)); - __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); - return av_epi16; -} - -MLAS_FORCEINLINE void -accumulate_r1_4blk_dot( - const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, - const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, - const float* scale_a, const float* scale_b, - __m256& acc) -{ - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av0_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av1_32_epi8); - const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - 
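// (Note on the idiom above: _mm256_cmpeq_epi16(x, x) yields all-ones 16-bit
// lanes, and shifting each lane right by 15 leaves the constant 1, so
// one_16_epi16 is a vector of sixteen ones. The _mm256_madd_epi16 with this
// vector below adds each adjacent pair of 16-bit dot-product sums into a
// 32-bit lane, widening the accumulation before it is converted to float.)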
const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); - const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); - - // load 4 scales - __m256 scale_a_4_ps = load_and_broadcast_4_scale_2(scale_a); - __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); - __m256 scale_8_ps = _mm256_mul_ps(scale_a_4_ps, scale_b_4_ps); - acc = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc); -} - -MLAS_FORCEINLINE void -accumulate_r2_4blk_dot( - const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, - const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, - const float* scale_a0, const float* scale_a1, const float* scale_b, - __m256& acc0, __m256& acc1 -) -{ - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); - const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); - - // load 4 scales - __m256 scale_a0_4_ps = load_and_broadcast_4_scale_2(scale_a0); - __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); - __m256 scale_8_ps = _mm256_mul_ps(scale_a0_4_ps, scale_b_4_ps); - acc0 = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc0); - - const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); - const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); - const __m256i sum_16_inter_leaved_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); - const __m256i sum_8_inter_leaved_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16_); - const __m256 sum_inter_leaved_ps_ = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32_); - - __m256 scale_a1_4_ps = load_and_broadcast_4_scale_2(scale_a1); - scale_8_ps = _mm256_mul_ps(scale_a1_4_ps, scale_b_4_ps); - acc1 = _mm256_fmadd_ps(sum_inter_leaved_ps_, scale_8_ps, acc1); -} - -static MLAS_FORCEINLINE __m256i -load_4b_packed_1blk_blklen16(const std::byte* QuantBDataPtr) -{ - // | 0 8 |...| 7 15 | - const __m128i bv_packed_64 = _mm_loadl_epi64(reinterpret_cast(QuantBDataPtr)); - const __m128i low_mask = _mm_set1_epi8(0xF); - const __m128i lower_8_epu8 = _mm_and_si128(bv_packed_64, low_mask); // 0~7 - const __m128i upper_8_epu8 = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bv_packed_64, 4), low_mask), 8); // 8~15 - const __m256i bv_16_epu16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper_8_epu8, lower_8_epu8)); // 0~15 - return bv_16_epu16; -} - -static MLAS_FORCEINLINE void -load_4b_packed_4blk_blklen16(const std::byte* QuantBDataPtr, __m256i& bv0_32_epi8, __m256i& bv1_32_epi8) -{ - // | 0 8 |...| 7 15 | 16 24 |...| 23 31 ||| 32 40 |...| 39 47 | 48 56 |...| 55 63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - // 0~7, 16~22, 32~39, 48~55 - __m256i bv0_32_epi8_ = _mm256_and_si256(bv_packed, low_mask); - // 8~15, 24~31, 40~47, 56~63: (1) - __m256i bv1_32_epi8_ = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8_), 4); - // 0~7, 32~39, 16~22, 48~55 <- cross lane (3) - bv0_32_epi8_ = _mm256_permute4x64_epi64(bv0_32_epi8_, 0b11011000); - // 40~47, 8~15, 
56~63, 24~31 <- cross lane (3) - bv1_32_epi8_ = _mm256_permute4x64_epi64(bv1_32_epi8_, 0b01110010); - - // 0~7, 8~15, 16~22, 24~31: (1) - bv0_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b11001100); - - // 40~47, 32~39, 56~63, 48~55: (1) - bv1_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b00110011); - - // 32~39, 40~47, 48~55, 56~63: (1) - bv1_32_epi8 = _mm256_shuffle_epi32(bv1_32_epi8, 0b01001110); -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r2c1blk4_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m256& acc0, - __m256& acc1 -) -{ - __m256i bv0_32_epi8, bv1_32_epi8; - load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); - accumulate_r2_4blk_dot(av00_32_epi8, av01_32_epi8, av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, - scale_a0, scale_a1, scale_b, acc0, acc1); -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r1c1blk4_avx2( - const __m256i& av0_32_epi8, - const __m256i& av1_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m256& acc -) -{ - __m256i bv0_32_epi8, bv1_32_epi8; - load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); - accumulate_r1_4blk_dot(av0_32_epi8, av1_32_epi8, bv0_32_epi8, bv1_32_epi8, scale_a, scale_b, acc); -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r2c1blk1_avx2( - const __m256i& av0_32_epi8, - const __m256i& av1_32_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale0, - const float& combined_scale1, - __m256& acc0, - __m256& acc1 -) -{ - const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); - - __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av0_32_epi8); - __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); - acc0 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale0), prod_8_ps, acc0); - - prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av1_32_epi8); - prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); - acc1 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale1), prod_8_ps, acc1); -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r1c1blk1_avx2( - const __m256i& av_16_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale, - __m256& acc -) -{ - // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); - - __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av_16_epi8); - __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); - acc = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale), prod_8_ps, acc); -} - -MLAS_FORCEINLINE void -Q4Int8GemmR2xC4BlkLen16Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc[NCols4 * NRows2] = { - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() - }; - - // process 4 blks of 64 4b weights a time - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - - accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); - accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); - accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); - - QuantAPtr += BlkLen16 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; - } - - while (k_blks_remaining-- > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); - const __m256i av1_16_epi16 = 
load_16_epi8_as_epi16(QuantABlk0 + lda); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - { - const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); - } - - { - const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); - } - - { - const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); - } - - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); - } - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes8; - QuantBScalePtr++; - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 4 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); - - // process 4 blks of 64 4b weights a time - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + 32; - const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + 32; - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); - - accumulate_blklen16_r2c1blk4_avx2( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - // increment block pointers - QuantAPtr += BlkLen16 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; - } - - while (k_blks_remaining-- > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); - const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc0, acc1); - - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes8; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc0); - *(SumPtr + ldc) = hsum_float_8(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; 
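// Each B column owns BlockCountK blocks: StrideQuantBData bytes of packed
// 4-bit data (8 bytes per 16-value block) and StrideQuantBScale float scales,
// so the column pointers above advance by one full column, while the output
// pointer below moves by a single element (the second row is addressed at
// SumPtr + ldc).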
- - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen16Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, - QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, - QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, - QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); - // increment block pointers - QuantAPtr += BlkLen16 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; - } - - while (k_blks_remaining-- > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); - - const float& scale_a00 = *QuantAScalePtr; - { - // Col0 - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); - } - { - // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); - } - { - // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); - } - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc[3]); - } - 
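// Tail-loop bookkeeping: the increments below advance A by one 16-value int8
// block plus one scale, and B by one 8-byte packed block plus one scale; the
// per-column offsets were already applied above through StrideQuantBData and
// StrideQuantBScale.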
QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes8; - QuantBScalePtr++; - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -MLAS_FORCEINLINE void -Q4Int8GemmR1xC1BlkLen16Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 4 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - accumulate_blklen16_r1c1blk4_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr, - QuantAScalePtr, QuantBScalePtr, acc0); - - // increment block pointers - QuantAPtr += BlkLen16 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; - } - - while (k_blks_remaining-- > 0) { - const __m256i av_16_epi16 = load_16_epi8_as_epi16(QuantAPtr); - const float& scale_a00 = *QuantAScalePtr; - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_16_epi16, QuantBDataPtr, scale_00, acc0); - - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes8; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -MLAS_FORCEINLINE - size_t - MlasQ4Int8GemmKernelBlkLen16Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t ldc - ) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen16Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2xC1BlkLen16Avx2( - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen16Avx2( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen16Avx2( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? 
Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h deleted file mode 100644 index af6f52090adcb..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ /dev/null @@ -1,1049 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - - -MLAS_FORCEINLINE void -accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, - const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) -{ - const __m256i dot_16_epi16 = _mm256_maddubs_epi16( - bv_32_epi8, av_32_epi8 - ); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -} - -#if !defined(__GNUC__) || (__GNUC__ > 10) -MLAS_FORCEINLINE void -accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) -{ - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -} -#endif - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). - const __m256i low_mask = _mm256_set1_epi8(0x0F); - //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); - // low_mask = _mm256_packus_epi16(low_mask, low_mask); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - // TODO: this (the second line below) is faster and does not keep low_mask in use. 
- // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - { - const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); - const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); - const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - } - { - const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); - const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); - const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); - } - } else { -#endif - //{ - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - // generating constant 1s is faster here. - // __m256i one = _mm256_set1_epi16(1); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - //} - //{ - const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); - const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); - const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); - const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); - const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); - - __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); - __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); - //} -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_b, - __m256& acc0 -) -{ - // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); - const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); - const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); // 00110011 - - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - } else { -#endif - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk1_avx2( - const __m256i& av00_32_epi8, - const __m256i& av10_32_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale00, - const float& combined_scale10, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); - accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); - } else { -#endif - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); - accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk1_avx2( - const __m256i& av00_32_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale00, - __m256& acc0 -) -{ - // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); - } else { -#endif - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -MLAS_FORCEINLINE void -Q4Int8Gemm2x4x2BlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen32; - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale2 = 2; - const size_t StrideQuantBScale1 = 1; - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc[NCols4 * NRows2] = { - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() - }; - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); - - { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - } - { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); - } - - { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); - } - - { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); - } - - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; - QuantBScalePtr += PerAccuBlk2 * NCols4; - } // k_blks_remaining - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - // load A - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - { - // Col0 - const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); - } - - { - // Col1 - const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); - } - - { - // Col2 - const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); - } - - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); - } - } // k_blks_remaining - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBDataCol; - QuantBScaleColPtr += NCols4 * BlockCountK; - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); - - accumulate_blklen32_r2c1blk2_avx2( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - } - - if (k_blks_remaining > 0) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); - } - - *SumPtr = hsum_float_8(acc0); - *(SumPtr + ldc) = hsum_float_8(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += BlockCountK; - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmXx4BlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen32; - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale2 = 2; - const size_t StrideQuantBScale1 = 1; - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - - { - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, - QuantAScalePtr, QuantBScalePtr, acc[0]); - } - { - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, - QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] - ); - } - { - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, - QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] - ); - } - { - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, - QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] - ); - } - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; - QuantBScalePtr += PerAccuBlk2 * NCols4; - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. 
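// The paired loop above consumes two blocks per iteration, so at most one
// block remains when BlockCountK is odd. The leftover block goes through the
// single-block path below, which folds the A and B block scales into one
// scalar combined scale per column before accumulating.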
- if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - - const float& scale_a00 = *QuantAScalePtr; - { - // Col0 - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); - } - { - // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); - } - { - // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); - } - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); - } - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBDataCol; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmXxXBlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr, - QuantAScalePtr, QuantBScalePtr, acc0 - ); - - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. 
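// Scalar corner of the tiling (fewer than 2 rows and 4 columns left): after
// the optional trailing block below, the single accumulator register is
// reduced with hsum_float_8 and the bias, if present, is added to the one
// output element.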
- if (k_blks_remaining > 0) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const float& scale_a00 = *QuantAScalePtr; - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += BlockCountK; - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE - size_t - MlasQ4Int8GemmKernelBlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t ldc - ) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8Gemm2x4x2BlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8Gemm2xXBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmXx4BlkLen32Avx2( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmXxXBlkLen32Avx2( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} - -// this function is to explore larger NCols. With Avx2 it does not improve performance. -// Leave it here until the same is implemented in avx512. -template accumulator> -MLAS_FORCEINLINE -size_t -MlasQ4Int8TileGemmKernelBlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc -) -{ - // We process 32 quantized values in a batch. 
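// Rough structure of this exploratory kernel: rows of A are walked one at a
// time, columns of B are consumed NCols4 at a time through the `accumulator`
// functor, quantized A is stored as Q8 blocks that pair a float scale with
// BlkLen32 int8 values (Q8BlkScale / Q8BlkData), and B carries packed 4-bit
// zero points when HasZeroPoint is set. As the comment above notes, it did
// not improve performance on AVX2 and is kept until the same idea is tried
// for AVX-512.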
- constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const __m256i zero = _mm256_setzero_si256(); - const __m128i low_mask = _mm_set1_epi8(0xF); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - for (size_t m = 0; m < CountM; m++) { - // for each row of A, reset B pointers - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - int64_t nblk = (int64_t)(CountN)-NCols4; - while (nblk >= 0) { - const std::byte* QuantAPtr = QuantA + m * lda; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4]; - - acc[0] = _mm256_setzero_ps(); - acc[1] = _mm256_setzero_ps(); - acc[2] = _mm256_setzero_ps(); - acc[3] = _mm256_setzero_ps(); - - if constexpr (NCols4 == 8) { - acc[4] = _mm256_setzero_ps(); - acc[5] = _mm256_setzero_ps(); - acc[6] = _mm256_setzero_ps(); - acc[7] = _mm256_setzero_ps(); - } - - size_t k_blks_remaining = BlockCountK; - - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); - - // Col1 - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); - - // Col2 - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); - - // Col3 - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); - - if constexpr (NCols4 == 8) { - // Col4 - const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; - const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); - - // Col5 - const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; - const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); - - // Col6 - const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; - const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); - - // Col7 - const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; - const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); - } - - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } // k_blks_remaining - - // TODO: use a loop in case PerAccuBlk2 is not 2. 
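// At most one block can remain after the paired loop above. Only av_0 and its
// block scale are used, and each column issues a single accumulator call with
// the first-of-pair flag, which appears to select the low half of the packed
// zero-point byte shared by a block pair when HasZeroPoint is set.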
- if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); - - // Col1 - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); - - // Col2 - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); - - // Col3 - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); - - if constexpr (NCols4 == 8) { - // Col4 - const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); - - // Col5 - const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); - - // Col6 - const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); - - // Col7 - const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); - } - } // k_blks_remaining - - if constexpr (NCols4 == 8) { - __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); - if (BiasPtr != nullptr) { - acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); - acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); - } - _mm_storeu_ps(SumPtr, acc_0); - _mm_storeu_ps(SumPtr+4, acc_1); - } else { - __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_x); - } - - // move to next NCols columns - - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - nblk -= NCols4; - } // while (nblk >= 0) - - nblk += NCols4; - for (int64_t n = 0; n < nblk; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); - - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } // m - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h deleted file mode 100644 index 174ebc580904c..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ /dev/null @@ -1,541 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - -template -static MLAS_FORCEINLINE void -accumulate_blklen64_r2c1blk1_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m256& acc0, - __m256& acc1 -) -{ - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); - sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); - __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); - __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); - - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); - - sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); - sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av11_32_epi8); - sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); - - acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); - - } else { -#endif - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - - __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); - __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); - - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); - - dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); - dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); - sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); - - acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen64_r1c1blk1_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m256& acc0 -) -{ - // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); - sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); - __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); - - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); - } else { -#endif - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); - __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); - - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) - } -#endif -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR2xC4BlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 64; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc[NCols4 * NRows2] = { - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() - }; - - // process 1 blks of 64 4b weights a time - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 
lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr += NCols4; - } // k_blks_remaining - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE -Q4Int8GemmR2xC1BlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 64; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); - - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); 
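// Note: av_00_epi8/av_01_epi8 are the two 32-byte halves of the current 64-value
// sub-block for row m; av_10_epi8 (just loaded) and av_11_epi8 (next) are the same
// sub-block for row m + 1, read lda bytes further on. The
// accumulate_blklen64_r2c1blk1_avx2 call below updates acc0 for row m and acc1 for
// row m + 1 in a single pass over the shared B sub-block.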
- const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc0); - *(SumPtr + ldc) = hsum_float_8(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 64; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr += NCols4; - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * BlockCountK; 
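// Column stride bookkeeping: each packed B column spans BlockCountK * BlkDataSizeInBytes
// bytes of quantized data and BlockCountK float scales, so stepping over NCols4 columns
// advances each pointer by NCols4 times its per-column stride. Zero points are not read
// in these BlkLen == 64 kernels; they are presumably folded into the block-sum correction
// applied by the *_BlkSum_CompInt8 driver (see SQ4BitGemmKernel_BlkSum_CompInt8_avx512 below).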
- BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC1BlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 64; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - accumulate_blklen64_r1c1blk1_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 - ); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE size_t -MlasQ4Int8GemmKernelBlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2xC1BlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen64Avx2( - BlkLen, - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen64Avx2( - BlkLen, - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp deleted file mode 100644 index 13bd369a065bb..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ /dev/null @@ -1,372 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_avx512.cpp.h - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for x64 avx512. - ---*/ - -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_kernel_avx_common_int8.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" - -// -// CompFp32 kernel implementation. 
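The MlasQ4Int8GemmKernelBlkLen64Avx2 driver above splits the output into four tile classes so that the R2xC4 kernel handles the bulk of the work and the R2xC1/R1xC4/R1xC1 kernels mop up the row and column remainders. A minimal sketch of that partitioning arithmetic (NRows2 = 2 and NCols4 = 4 as in the kernels; TilePlan/MakeTilePlan are illustrative names):

#include <cstddef>

struct TilePlan {
    size_t multipleRows, remainingRows;  // rows handled by the R2x* kernels vs. the R1x* tail
    size_t multipleCols, remainingCols;  // cols handled by the *xC4 kernels vs. the *xC1 tail
};

inline TilePlan MakeTilePlan(size_t CountM, size_t CountN) {
    constexpr size_t NRows2 = 2, NCols4 = 4;
    TilePlan p;
    p.remainingRows = CountM % NRows2;
    p.multipleRows = CountM - p.remainingRows;
    p.remainingCols = CountN % NCols4;
    p.multipleCols = CountN - p.remainingCols;
    return p;
}

// Example: CountM = 7, CountN = 10 -> R2xC4 covers the 6x8 body, R2xC1 covers 6x2,
// R1xC4 covers 1x8, and R1xC1 covers the final 1x2 corner.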
-// - -#include "sqnbitgemm_kernel_avx_common_fp32.h" - -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32_avx512( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (BlkLen == 16) { - if (QuantBZeroPoint != nullptr) { - MlasQ4GemmKernelBlkLen16Avx512f( - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } else { - MlasQ4GemmKernelBlkLen16Avx512f( - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } - } else if (BlkLen == 32) { - if (QuantBZeroPoint != nullptr) { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } else { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } - } else /*if (BlkLen >= 64)*/ { - if (QuantBZeroPoint != nullptr) { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } else { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } - } -} - -// -// CompInt8 kernel implementation. -// - -MLAS_FORCEINLINE -size_t -SQ4BitGemmKernel_BlkSum_CompInt8_avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* /*QuantBZeroPoint*/, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t ldc, - const float* ABlockSum, - const float* QuantBBlkSum -) -{ - if (BlkLen == 16) { - MlasQ4Int8GemmKernelBlkLen16Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen == 64) { - MlasQ4Int8GemmKernelBlkLen64Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } else { - MlasQ4Int8GemmKernelBlkLen128Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } - - float* c_blk = C; - const float* b_blk_sum = QuantBBlkSum; - - size_t RowsRemaining = CountM; - const float* a_blksum_row = ABlockSum; - while (RowsRemaining > 0) { - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false - ); - - c_blk += ldc * RowsHandled; - a_blksum_row += BlockCountK * RowsHandled; - RowsRemaining -= RowsHandled; - } - return CountM; -} - -void MLASCALL -QuantizeARow_CompInt8_avx512( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA, - float* QuantAScale, - float* AScaledBlkSum // scale_k * Sum_blklen(a_i) -) -{ - // port from MlasQ80BlkQuantRow - assert(BlkLen % 16 == 0); - const __m512 
signBit = _mm512_set1_ps(-0.0f); - const __m256i one_16_epi16 = _mm256_set1_epi16(1); - int8_t* blob = reinterpret_cast(QuantA); - float* scale_ptr = QuantAScale; - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t step = std::min(BlkLen, CountK - k); - - __m512 maxAbs = _mm512_setzero_ps(); - for (size_t kk = 0; kk < step; kk += 16) { - const size_t klen = std::min(size_t(16), step - kk); - - uint32_t mask = 0xffff >> (16 - klen); - __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - - // Compute max(abs(e)) for the block - maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0)); - } - - __m256 max8 = - _mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0)); - __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8)); - max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); - max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); - const float maxScalar = _mm_cvtss_f32(max4); - - // Quantize these floats - const float scale = maxScalar / 127.f; - *scale_ptr = scale; - scale_ptr++; - - const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; - const __m512 mul = _mm512_set1_ps(inverse_scale); - __m128i* dst = reinterpret_cast<__m128i*>(blob); - - __m256i sum_16_epi16 = _mm256_setzero_si256(); - for (size_t kk = 0; kk < step; kk += 16) { - const size_t klen = std::min(size_t(16), step - kk); - - uint32_t mask = 0xffff >> (16 - klen); - __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - v0 = _mm512_mul_ps(v0, mul); - - // Round to nearest integer - v0 = _mm512_roundscale_ps(v0, _MM_ROUND_NEAREST); - - // Convert floats to integers - __m512i i0 = _mm512_cvtps_epi32(v0); - - // Convert int32 to int8 - __m128i i0_8 = _mm512_cvtepi32_epi8(i0); - _mm_storeu_si128(dst++, i0_8); - - // accumulate Sum(a_i) - __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i0_8); - sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); - - } - if (step < BlkLen) { - memset(blob + step, 0, BlkLen - step); - } - - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); - AScaledBlkSum++; - blob += BlkLen; - } -} - -static void -SQ4BitGemmPackQuantBDataAndBlkSum512( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - const float* QuantBScaleBegin, - bool has_zp_input, - const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, - MLAS_THREADPOOL* ThreadPool -) -{ - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - - size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); - if (ComputeType == CompInt8) { - SubBlkLen = 128; - } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); -} - -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; - - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; - - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; - - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - - d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; - d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; - - return d; -}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h deleted file mode 100644 index 7d9dc36854621..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h +++ /dev/null @@ -1,1171 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - - -MLAS_FORCEINLINE void -accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, - const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) -{ - const __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) - ); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -} - -MLAS_FORCEINLINE void -accumulate_2blk_dot( - const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, - const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, - const float& combined_scale0, const float& combined_scale1, - const __m256i& one_16_epi16, - __m256& acc) -{ - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) - ); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) - ); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale_8_ps = _mm256_set_ps( - combined_scale1, combined_scale1, combined_scale0, combined_scale0, - combined_scale1, combined_scale1, combined_scale0, combined_scale0 - ); - acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... 
| v46 v62 | v47 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - // TODO: will this (the second line below) be faster and not keep low_mask in use? - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - - //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); - //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8) - ); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) - ); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256d scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_mul( - _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), - _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0))); - - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - - - const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8) - ); - const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8) - ); - const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); - const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); - const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); - - __m256d scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); - __m256 scale_8_ps_ = _mm256_mul( - _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), - _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0))); - acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float& combined_scale00, - const float& combined_scale01, - const float& combined_scale10, - const float& combined_scale11, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). 
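// (For reference: _mm256_cmpeq_epi16(x, x) yields all-ones lanes, and shifting each
// 16-bit lane right by 15 leaves 0x0001 per lane, i.e. a vector of ones, without a
// memory load; the commented-out shift-by-12 plus packus variant a few lines below
// would build the 0x0F nibble mask the same way.)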
- // however, it is faster to generate one_16_epi16 than calling _mm256_set1_ep16(1); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); - //low_mask = _mm256_packus_epi16(low_mask, low_mask); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 - // TODO: will this (the second line below) be faster and not keep low_mask in use? - // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - - //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. - ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); - //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); - - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - - // generating constant 1s is fater here. - // __m256i one = _mm256_set1_epi16(1); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - - // performance gains 7% by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls. - // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); - // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); - // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1); - // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1); - accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); - accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk1_avx2( - const __m256i& av00_32_epi8, - const __m256i& av10_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float& combined_scale00, - const float& combined_scale10, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - - const int8_t zp = get_zp(true, QuantBZeroPointPtr); - const __m256i bzp = _mm256_set1_epi8(zp); - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); - accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float& combined_scale00, - const float& combined_scale01, - __m256& acc0) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 - // TODO: will this be faster and save a use of low_mask? - // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - - //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. - ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); - //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); - - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); - //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); - accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk1_avx2( - const __m256i& av00_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float& combined_scale00, - __m256& acc0 -) -{ - // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - - const int8_t zp = get_zp(true, QuantBZeroPointPtr); - const __m256i bzp = _mm256_set1_epi8(zp); - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); -} - -template -MLAS_FORCEINLINE void -Q4Int8Gemm2x4BlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4 * NRows2] = { - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() - }; - - size_t k_blks_remaining = BlockCountK; - - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; - const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); - - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); - const float& scale_a10 = Q8BlkScale(QuantABlk10); - const float& scale_a11 = Q8BlkScale(QuantABlk11); - - { - // Col0 - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - const float& scale_10 = scale_a10 * 
QuantBScalePtr[0]; - const float& scale_11 = scale_a11 * QuantBScalePtr[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); - } - - { - // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); - } - - { - // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); - } - - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); - } - - // increment block pointers - QuantAPtr += Q8Blk32Size * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } // k_blks_remaining - - // TODO: use a loop in case PerAccuBlk2 is not 2. 
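The accumulate_* helpers used above build the int8 dot product with _mm256_maddubs_epi16, whose first operand must be unsigned; applying _mm256_sign_epi8 to both inputs (abs of b, and a with b's sign copied onto it) keeps every lane product equal to a*b. Adjacent products are summed into 16-bit lanes by maddubs and then widened to 32-bit sums via _mm256_madd_epi16 against a vector of ones. A scalar sketch of the per-lane identity (SignTrickProduct is an illustrative name):

#include <cstdint>
#include <cstdlib>

// maddubs-style product of one lane: the first operand must be unsigned, so the
// kernels use |b| and copy b's sign onto a; |b| * copysign-like(a, b) == a * b.
inline int32_t SignTrickProduct(int8_t a, int8_t b) {
    const int32_t abs_b = std::abs(static_cast<int32_t>(b));       // _mm256_sign_epi8(bv, bv)
    const int32_t a_signed_by_b = (b > 0) ? a : (b < 0 ? -a : 0);  // _mm256_sign_epi8(av, bv)
    return abs_b * a_signed_by_b;  // equals a * b (both sides are 0 when b == 0)
}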
- if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); - - const float& scale_a00 = Q8BlkScale(QuantABlk0); - const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); - - { - // Col0 - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); - } - - { - // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); - } - - { - // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); - } - - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); - } - } // k_blks_remaining - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - // accumulate_blklen32_r2c1_avx2 - const std::byte* QuantAPtr = QuantA + m * lda; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); - const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); - - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); - const float& scale_a10 = Q8BlkScale(QuantABlk10); - const float& scale_a11 = Q8BlkScale(QuantABlk11); - - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - const float& scale_10 = scale_a10 * QuantBScalePtr[0]; - const float& scale_11 = scale_a11 * QuantBScalePtr[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); - - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. 
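// When BlockCountK is odd, a single 32-value block remains after the paired loop above;
// it is handled below with the one-block accumulator, using only the first scale of the
// current column (QuantBScalePtr[0]) for both rows.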
- if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); - - const float& scale_a00 = Q8BlkScale(QuantABlk0); - const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); - - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); - } - - *SumPtr = hsum_float_8(acc0); - *(SumPtr + ldc) = hsum_float_8(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmXx4BlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); - { - // Col0 - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); - } - { - // Col1 - const float& scale_00 = 
scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); - } - { - // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); - } - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); - } - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a00 = Q8BlkScale(QuantABlk0); - { - // Col0 - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); - } - { - // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); - } - { - // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); - } - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); - } - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmXxXBlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); - - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); - - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a00 = Q8BlkScale(QuantABlk0); - - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE - size_t - MlasQ4Int8TileGemmKernelBlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc - ) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8Gemm2x4BlkLen32Avx2( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - lda, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8Gemm2xXBlkLen32Avx2( - QuantA, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - lda, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmXx4BlkLen32Avx2( - QuantA + multipleRows * lda, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - lda, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmXxXBlkLen32Avx2( - QuantA + multipleRows * lda, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - lda, - ldc); - } - - return CountM; -} - -// this function is to explore larger NCols. With Avx2 it does not improve performance. -// Leave it here until the same is implemented in avx512. -template accumulator> -MLAS_FORCEINLINE -size_t -MlasQ4Int8GemmKernelBlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc -) -{ - // We process 32 quantized values in a batch. 
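-    // Descriptive note (added for clarity, not part of the original source): each quantized
-    // block covers BlkLen32 = 32 weights of B. The 4-bit values are packed two per byte
-    // (16 bytes per block, see MlasQNBitBlkDataSizeInBytes), with one float scale per block
-    // and, when HasZeroPoint is set, a packed 4-bit zero point per block. A 4-bit value q is
-    // dequantized as scale_b * (q - zero_point) (typically 8 when no zero point is stored),
-    // so the partial dot products below are accumulated with the combined scale scale_a * scale_b.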
- constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const __m256i zero = _mm256_setzero_si256(); - const __m128i low_mask = _mm_set1_epi8(0xF); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - for (size_t m = 0; m < CountM; m++) { - // for each row of A, reset B pointers - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - int64_t nblk = (int64_t)(CountN)-NCols4; - while (nblk >= 0) { - const std::byte* QuantAPtr = QuantA + m * lda; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4]; - - acc[0] = _mm256_setzero_ps(); - acc[1] = _mm256_setzero_ps(); - acc[2] = _mm256_setzero_ps(); - acc[3] = _mm256_setzero_ps(); - - if constexpr (NCols4 == 8) { - acc[4] = _mm256_setzero_ps(); - acc[5] = _mm256_setzero_ps(); - acc[6] = _mm256_setzero_ps(); - acc[7] = _mm256_setzero_ps(); - } - - size_t k_blks_remaining = BlockCountK; - - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); - - // Col1 - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); - - // Col2 - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); - - // Col3 - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); - - if constexpr (NCols4 == 8) { - // Col4 - const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; - const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); - - // Col5 - const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; - const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); - - // Col6 - const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; - const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); - - // Col7 - const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; - const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); - } - - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } // k_blks_remaining - - // TODO: use a loop in case PerAccuBlk2 is not 2. 
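-            // Note (added for clarity): because PerAccuBlk2 == 2 and the loop above runs while
-            // k_blks_remaining > 1, at most one K block can be left over here, so a single
-            // pass (one 32-value A block against one 16-byte B block per column) is sufficient
-            // without the loop mentioned in the TODO.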
- if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); - - // Col1 - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); - - // Col2 - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); - - // Col3 - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); - - if constexpr (NCols4 == 8) { - // Col4 - const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); - - // Col5 - const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); - - // Col6 - const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); - - // Col7 - const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); - } - } // k_blks_remaining - - if constexpr (NCols4 == 8) { - __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); - if (BiasPtr != nullptr) { - acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); - acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); - } - _mm_storeu_ps(SumPtr, acc_0); - _mm_storeu_ps(SumPtr+4, acc_1); - } else { - __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_x); - } - - // move to next NCols columns - - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - nblk -= NCols4; - } // while (nblk >= 0) - - nblk += NCols4; - for (int64_t n = 0; n < nblk; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); - - // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } // m - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h deleted file mode 100644 index 60a887345d0e0..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ /dev/null @@ -1,581 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" - -//static MLAS_FORCEINLINE __m512i -//combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) -//{ -// __m512i result = _mm512_castsi256_si512(a); -// result = _mm512_inserti64x4(result, b, 1); -// return result; -//} - -//static MLAS_FORCEINLINE void -//load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) -//{ -// // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | v64 v96 | ... 
| v95 v127 | -// const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); -// const __m512i low_mask = _mm512_set1_epi8(0x0F); -// __m512i bv0_64_epi8_ = _mm512_and_si512(bv_packed, low_mask); // 0~31, 64~95 -// __m512i bv1_64_epi8_ = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 32~63, 96~127 -// -// // Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 -// __m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); -// __m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); -// __m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); -// __m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); -// -// // Compose new __m512i variables -// bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); -// bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); -//} - -static MLAS_FORCEINLINE void -dot_accumulate_1blk( - const __m512i& bv0_64_epi8, - const __m512i& bv1_64_epi8, - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const float combined_scale, - __m512& acc -) -{ - __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); - __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); - __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); - __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); - __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); - const __m512i zeros = _mm512_setzero_si512(); - const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); -} - -static MLAS_FORCEINLINE void -dot_accumulate_1blkvnni( - const __m512i& bv0_64_epi8, - const __m512i& bv1_64_epi8, - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const float combined_scale, - __m512& acc -) -{ - __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); - __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(dot0_16_epi32, bv1_64_epi8, av1_64_epi8); - __m512 sum_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen128_r1c1blk1_avx512( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - if constexpr (vnni) { - dot_accumulate_1blkvnni( - bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, - (*scale_a) * (*scale_b), acc - ); - } else { - dot_accumulate_1blk( - bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, - (*scale_a) * (*scale_b), acc - ); - } -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen128_r2c1blk1_avx512( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const __m512i& av10_64_epi8, - const __m512i& av11_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - if constexpr (vnni) { - dot_accumulate_1blkvnni( - bv0_64_epi8, bv1_64_epi8, av00_64_epi8, 
av01_64_epi8, - (*scale_a0) * (*scale_b), acc0 - ); - dot_accumulate_1blkvnni( - bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, - (*scale_a1) * (*scale_b), acc1 - ); - } else { - dot_accumulate_1blk( - bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, - (*scale_a0) * (*scale_b), acc0 - ); - dot_accumulate_1blk( - bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, - (*scale_a1) * (*scale_b), acc1 - ); - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR2xC4BlkLen128Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 128; - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4 * NRows2] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; - - // process 1 blks of 64 4b weights a time - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); - - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr += 
NCols4; - } // k_blks_remaining - -#if 1 - *SumPtr = _mm512_reduce_add_ps(acc[0]); - *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); - *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); - *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); - *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); - *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); - *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); - *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - *SumPtr += *BiasPtr; - *(SumPtr + 1) += *(BiasPtr + 1); - *(SumPtr + 2) += *(BiasPtr + 2); - *(SumPtr + 3) += *(BiasPtr + 3); - *(SumPtr + ldc) += *BiasPtr; - *(SumPtr + ldc + 1) += *(BiasPtr + 1); - *(SumPtr + ldc + 2) += *(BiasPtr + 2); - *(SumPtr + ldc + 3) += *(BiasPtr + 3); - } -#else - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); -#endif - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE -Q4Int8GemmR2xC1BlkLen128Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 128; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); - - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); - - 
accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_16(acc0); - *(SumPtr + ldc) = hsum_float_16(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen128Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 128; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr +=NCols4; - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC1BlkLen128Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 128; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(); - for (size_t k = 0; k < BlockCountK; ++k) { - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - - accumulate_blklen128_r1c1blk1_avx512( - av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 - ); - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_16(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE size_t -MlasQ4Int8GemmKernelBlkLen128Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen128Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2xC1BlkLen128Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen128Avx512( - BlkLen, - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen128Avx512( - BlkLen, - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h deleted file mode 100644 index bb14babd6c2b1..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ /dev/null @@ -1,812 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" - - - -static MLAS_FORCEINLINE void -load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) -{ - // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | - const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - const __m512i low_mask = _mm512_set1_epi8(0x0F); - bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 - bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r1c1blk8_avx512( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc0) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 - { - const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 - const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_castsi512_ps( - _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) - ); // 0123456701234567 - - __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 - - const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1,2~2,3~3 - const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 4~4,5~5,6~6,7~7 - - const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 - const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 - const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); - const __m512i one_32_epi16 = generate_ones_32_epi16(); - const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0044115522663377 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r2c1blk4_avx512( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const __m512i& av10_64_epi8, - const __m512i& av11_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 - { - const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 - const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_castsi512_ps( - _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) - ); // 0123456701234567 - - // TODO: load from memory - __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); - - const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); - const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); - - const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); - const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); - const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); - const __m512i one_32_epi16 = generate_ones_32_epi16(); - const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } - { - const __m256 scale_a1_ps 
= _mm256_loadu_ps(scale_a1); // 01234567 - const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); - __m512 scale_a1b_16_ps = _mm512_castsi512_ps( - _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) - ); // 0123456701234567 - - __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); - scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); - - const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); - const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); - - const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); - const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); - const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); - const __m512i one_32_epi16 = generate_ones_32_epi16(); - const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); - } -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r1c1blk8_avx512vnni( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc0 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 - { - const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 - const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_castsi512_ps( - _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) - ); // 0123456701234567 - - __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 - - const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000111122223333 - const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 4444555566667777 - - const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 - const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 - const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } -} - -static MLAS_FORCEINLINE void -accumulate_blklen16_r2c1blk4_avx512vnni( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const __m512i& av10_64_epi8, - const __m512i& av11_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 - { - const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 - const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_castsi512_ps( - _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) - ); // 0123456701234567 - - // TODO: load from memory - __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); - - const __m512i dot0_16_epi32 = 
_mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 - const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 - - const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); - const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); - const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } - { - const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); // 01234567 - const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); - __m512 scale_a1b_16_ps = _mm512_castsi512_ps( - _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) - ); // 0123456701234567 - - __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); - scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); - - const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000111122223333 - const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 4444555566667777 - - const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); - const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); - const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR2xC4BlkLen16Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk8 = 8; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4 * NRows2] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i 
av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - - if constexpr (vnni) { - accumulate_blklen16_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, - acc[0], acc[NCols4] - ); - accumulate_blklen16_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, - acc[1], acc[NCols4 + 1] - ); - accumulate_blklen16_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, - acc[2], acc[NCols4 + 2] - ); - accumulate_blklen16_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, - acc[3], acc[NCols4 + 3] - ); - } else { - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, - acc[0], acc[NCols4] - ); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, - acc[1], acc[NCols4 + 1] - ); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, - acc[2], acc[NCols4 + 2] - ); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, - acc[3], acc[NCols4 + 3] - ); - } - - // increment block pointers - QuantAPtr += BlkLen16 * PerAccuBlk8; - QuantAScalePtr += PerAccuBlk8; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; - QuantBScalePtr += PerAccuBlk8; - } // k_blks_remaining - - __m256 acc2[NCols4 * NRows2] = { - h_add_512(acc[0]), - h_add_512(acc[1]), - h_add_512(acc[2]), - h_add_512(acc[3]), - h_add_512(acc[4]), - h_add_512(acc[5]), - h_add_512(acc[6]), - h_add_512(acc[7]) - }; - - while (k_blks_remaining-- > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); - const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - { - // Col0 - const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); - } - - { - // Col1 - const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, - acc2[1], acc2[NCols4 + 1]); - } - - { - // Col2 - const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * 
StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, - acc2[2], acc2[NCols4 + 2]); - } - - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen16_r2c1blk1_avx2( - av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, - acc2[3], acc2[NCols4 + 3]); - } - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; - } // k_blks_remaining - - __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); - __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE -Q4Int8GemmR2C1BlkLen16Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk8 = 8; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - - if constexpr (vnni) { - accumulate_blklen16_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr, acc0, acc1 - ); - } else { - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 - ); - } - - // increment block pointers - QuantAPtr += BlkLen16 * PerAccuBlk8; - QuantAScalePtr += PerAccuBlk8; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; - QuantBScalePtr += PerAccuBlk8; - } - - __m256 acc20 = h_add_512(acc0); - __m256 acc21 = h_add_512(acc1); - while (k_blks_remaining-- > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); - const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc20, acc21); - - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc20); - *(SumPtr + ldc) = hsum_float_8(acc21); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen16Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk8 = 8; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - - if constexpr (vnni) { - accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - 
accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); - } else { - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); - } - - QuantAPtr += BlkLen16 * PerAccuBlk8; - QuantAScalePtr += PerAccuBlk8; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; - QuantBScalePtr += PerAccuBlk8; - } - - __m256 acc2[NCols4] = { - h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) - }; - - while (k_blks_remaining-- > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); - - const float& scale_a00 = *QuantAScalePtr; - { - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc2[3]); - } - - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; - - } - - __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC1BlkLen16Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk8 = 8; - - const size_t lda = BlockCountK * BlkLen16; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - - if constexpr (vnni) { - accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - } else { - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - } - - QuantAPtr += BlkLen16 * PerAccuBlk8; - QuantAScalePtr += PerAccuBlk8; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; - QuantBScalePtr += PerAccuBlk8; - } - - __m256 acc2 = h_add_512(acc0); - while (k_blks_remaining-- > 0) { - const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantAPtr); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); - - QuantAPtr += BlkLen16; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc2); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE - size_t -MlasQ4Int8GemmKernelBlkLen16Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc - ) -{ - constexpr size_t BlkLen16 = 16; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen16Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2C1BlkLen16Avx512( - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen16Avx512( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen16Avx512( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h deleted file mode 100644 index e9df6b952bd27..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ /dev/null @@ -1,852 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" - -static MLAS_FORCEINLINE void -load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) -{ - // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | - const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - const __m512i low_mask = _mm512_set1_epi8(0x0F); - bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 - bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 -} - -static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3}; - -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk4_avx512( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc0) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 - { - const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 - const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 - - __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - // __m512i idx = _mm512_loadu_epi8(&index_array[0]); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 - - const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1 - const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3 - - const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 - const __m512i one_32_epi16 = generate_ones_32_epi16(); - const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk4_avx512( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const __m512i& av10_64_epi8, - const __m512i& av11_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 - { - const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 - const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 - - __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - // __m512i idx = _mm512_loadu_epi8(&index_array[0]); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 - - const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1 - const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3 - - const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 - const __m512i one_32_epi16 = generate_ones_32_epi16(); - const __m512i 
sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } - { - const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 - const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); - __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 - - __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - // __m512i idx = _mm512_loadu_epi8(&index_array[0]); - scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 - - const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1 - const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3 - - const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 - const __m512i one_32_epi16 = generate_ones_32_epi16(); - const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); - } -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk4_avx512vnni( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc0 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 - { - const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 - const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 - - __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - //__m512i idx = _mm512_loadu_epi8(&index_array[0]); - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 - - const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111 - const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333 - - const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 - const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 - const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk4_avx512vnni( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const __m512i& av10_64_epi8, - const __m512i& av11_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - //__m512i idx = _mm512_loadu_epi8(&index_array[0]); - - const 
__m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 - { - const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 - const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); - __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 - - scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 - - const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 - const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 - - const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 - const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 - const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); - } - { - const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 - const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); - __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 - - scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 - - const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111 - const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333 - - const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 - const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 - const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 - const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); - } -} - -MLAS_FORCEINLINE void -accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) -{ - __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk1_avx512( - const __m256i& av00_32_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale00, - __m256& acc0 -) -{ - if constexpr (vnni) { - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); - } else { - accumulate_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); - } -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk1_avx512( - const __m256i& av00_32_epi8, - const __m256i& av10_32_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale00, - const float& combined_scale10, - __m256& acc0, - __m256& acc1 -) -{ - if constexpr (vnni) { - // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - - accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); - accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); - } else { - accumulate_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR2xC4BlkLen32Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = PerAccuBlk4 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4 * NRows2] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - - if constexpr (vnni) { - accumulate_blklen32_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, - acc[0], acc[NCols4] - ); - accumulate_blklen32_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, - acc[1], acc[NCols4 + 1] - ); - accumulate_blklen32_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, - acc[2], acc[NCols4 + 2] - ); - accumulate_blklen32_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, - acc[3], 
acc[NCols4 + 3] - ); - } else { - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, - acc[0], acc[NCols4] - ); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, - acc[1], acc[NCols4 + 1] - ); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, - acc[2], acc[NCols4 + 2] - ); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, - acc[3], acc[NCols4 + 3] - ); - } - - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += StrideQuantBData * NCols4; - QuantBScalePtr += PerAccuBlk4 * NCols4; - } // k_blks_remaining - - __m256 acc2[NCols4 * NRows2] = { - h_add_512(acc[0]), - h_add_512(acc[1]), - h_add_512(acc[2]), - h_add_512(acc[3]), - h_add_512(acc[4]), - h_add_512(acc[5]), - h_add_512(acc[6]), - h_add_512(acc[7]) - }; - - while (k_blks_remaining-- > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - { - // Col0 - const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); - } - - { - // Col1 - const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; - accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); - } - - { - // Col2 - const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; - accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); - } - - { - // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; - accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); - } - QuantAPtr += BlkLen32; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; - QuantBScalePtr += NCols4; - } // k_blks_remaining - - __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); - __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; - 
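// FoldAccumulators (defined in one of the included common headers) packs the horizontal sums of
// the four per-column accumulators into one __m128 so a single _mm_storeu_ps writes four adjacent
// outputs per row. A minimal sketch of the same semantics (illustrative only, not the library's
// implementation; hypothetical name, assumes <immintrin.h>):
static inline __m128 fold4_sketch(__m256 a0, __m256 a1, __m256 a2, __m256 a3)
{
    // Reduce each 8-lane accumulator to a scalar, then pack the four scalars in
    // column order {c0, c1, c2, c3}.
    auto hsum = [](__m256 v) {
        __m128 s = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
        s = _mm_add_ps(s, _mm_movehl_ps(s, s));
        s = _mm_add_ss(s, _mm_movehdup_ps(s));
        return _mm_cvtss_f32(s);
    };
    return _mm_setr_ps(hsum(a0), hsum(a1), hsum(a2), hsum(a3));
}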
QuantBScaleColPtr += NCols4 * BlockCountK; - - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE -Q4Int8GemmR2C1BlkLen32Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - - if constexpr (vnni) { - accumulate_blklen32_r2c1blk4_avx512vnni( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 - ); - } else { - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 - ); - } - - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; - } - - __m256 acc20 = h_add_512(acc0); - __m256 acc21 = h_add_512(acc1); - while (k_blks_remaining-- > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_a10 = *(QuantAScalePtr + BlockCountK); - - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); - - QuantAPtr += BlkLen32; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes16; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc20); - *(SumPtr + ldc) = hsum_float_8(acc21); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - 
} - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen32Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen32; - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - - if constexpr (vnni) { - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); - } else { - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); - } - - QuantAPtr += BlkLen32 * PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; - QuantBScalePtr += PerAccuBlk4 * NCols4; - } - - __m256 acc2[NCols4] = { - h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) - }; - 
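// All of these kernels unpack B with the same three-instruction idiom (see
// load_4blk_4b_packed_blklen32 above): mask out the low nibbles, then subtract-and-shift to
// recover the high nibbles. The subtraction matters because AVX-512 has no 8-bit shift:
// clearing the low nibbles first guarantees the 16-bit shift cannot drag a neighbouring byte's
// bits into a lane. Standalone sketch (illustrative only; hypothetical name, assumes
// <immintrin.h>, <cstddef> and AVX-512BW):
static inline void unpack_64x4bit_sketch(const std::byte* packed,
                                         __m512i& lo64_epi8, __m512i& hi64_epi8)
{
    // 64 bytes hold 128 4-bit values, interleaved as | v0 v64 | v1 v65 | ... | v63 v127 |.
    const __m512i bv = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(packed));
    const __m512i low_mask = _mm512_set1_epi8(0x0F);
    lo64_epi8 = _mm512_and_si512(bv, low_mask);                        // v0 .. v63
    hi64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv, lo64_epi8), 4);  // v64 .. v127
}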
- while (k_blks_remaining-- > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - - const float& scale_a00 = *QuantAScalePtr; - { - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; - accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; - accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; - accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); - } - - QuantAPtr += BlkLen32; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; - QuantBScalePtr += NCols4; - - } - - __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC1BlkLen32Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk4 = 4; - - const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM < NRows2); - assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - - if constexpr (vnni) { - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - } - else { - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - } - - QuantAPtr += BlkLen32 * 
PerAccuBlk4; - QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; - } - - __m256 acc2 = h_add_512(acc0); - while (k_blks_remaining-- > 0) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - - const float& scale_a00 = *QuantAScalePtr; - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); - - QuantAPtr += BlkLen32; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes16; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_8(acc2); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE -size_t -MlasQ4Int8GemmKernelBlkLen32Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen32Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2C1BlkLen32Avx512( - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen32Avx512( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen32Avx512( - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? 
Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h deleted file mode 100644 index 2a65ac4af0c1d..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ /dev/null @@ -1,840 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - -static MLAS_FORCEINLINE __m256 -h_add_512(__m512 a) -{ - return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); -} - -static MLAS_FORCEINLINE float -hsum_float_16(const __m512 x) -{ - __m256 hi = h_add_512(x); - __m128 hi128 = _mm256_extractf128_ps(hi, 1); - __m128 lo128 = _mm256_castps256_ps128(hi); - hi128 = _mm_add_ps(hi128, lo128); - hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); - hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); - return _mm_cvtss_f32(hi128); -} - -static MLAS_FORCEINLINE __m512i -combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) -{ - __m512i result = _mm512_castsi256_si512(a); - result = _mm512_inserti64x4(result, b, 1); - return result; -} - -static MLAS_FORCEINLINE void -load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) -{ - // | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | - const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - const __m512i low_mask = _mm512_set1_epi8(0x0F); - bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 - bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 - - //// Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 - //__m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); - //__m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); - //__m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); - //__m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); - - //// Compose new __m512i variables - //bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); - //bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); -} - -static MLAS_FORCEINLINE __m512i -load_1blk_4b_packed_blklen64(const std::byte* QuantBDataPtr) -{ - // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - __m256i bv1_32_epi8 = _mm256_srli_epi16( - _mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - __m512i bv_64_epi8 = combine_two_m256i_to_m512i(bv0_32_epi8, bv1_32_epi8); - return bv_64_epi8; -} - -static MLAS_FORCEINLINE __m512i -horizontal_add_epi32(__m512i a, __m512i b) -{ - __m512i t1 = _mm512_unpacklo_epi32(a, b); - __m512i t2 = _mm512_unpackhi_epi32(a, b); - __m512i sum = _mm512_add_epi32(t1, t2); - return sum; -} - -static MLAS_FORCEINLINE __m512i -generate_ones_32_epi16() -{ - const __m512i zeros = _mm512_setzero_si512(); - return _mm512_srli_epi16(_mm512_ternarylogic_epi64(zeros, zeros, zeros, 1), 15); -} - -static MLAS_FORCEINLINE void -dot_accumulate_2blk( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const float* scale_a, - const __m512i& bv0_64_epi8, - const __m512i& bv1_64_epi8, - const __m512& scale_b_16_ps, - //const __m512i& one_32_epi16, - __m512& acc) -{ - __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); - __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); - - __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); - __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); - __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // sum for blk: 0 0 1 1 0 0 1 1... - __m512i one_32_epi16 = generate_ones_32_epi16(); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // sum for blk: 0 1 0 1... - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - - __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); - __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); - - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); -} - -static MLAS_FORCEINLINE void -dot_accumulate_2blkvnni( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const float* scale_a, - const __m512i& bv0_64_epi8, - const __m512i& bv1_64_epi8, - const __m512& scale_b_16_ps, - // const __m512i& one_32_epi16, - __m512& acc -) -{ - __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); - __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); - - __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); - __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); - __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... 
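// dot_accumulate_2blk and dot_accumulate_2blkvnni compute the same widened dot product of
// unsigned 4-bit weight bytes against signed 8-bit activation bytes: the non-VNNI path pairs
// _mm512_maddubs_epi16 with a _mm512_madd_epi16 against ones, while the VNNI path fuses both
// steps into _mm512_dpbusd_epi32. With weights in 0..15 and activations in -128..127 each
// adjacent pair sum is bounded by 2 * 15 * 128 = 3840, well inside int16, so the saturating
// maddubs step cannot clip. Sketch of the core step only, ignoring the per-block scale handling
// (illustrative; hypothetical name, assumes <immintrin.h>, AVX-512BW and AVX-512VNNI):
static inline __m512i dot_u8s8_sketch(__m512i bv_u8, __m512i av_s8, bool use_vnni)
{
    if (use_vnni) {
        // u8*s8 products summed in groups of 4 straight into 16 int32 lanes.
        return _mm512_dpbusd_epi32(_mm512_setzero_si512(), bv_u8, av_s8);
    }
    // u8*s8 products summed in pairs to int16, then adjacent int16 pairs summed to int32.
    const __m512i pairs_s16 = _mm512_maddubs_epi16(bv_u8, av_s8);
    return _mm512_madd_epi16(pairs_s16, _mm512_set1_epi16(1));
}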
- __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - - __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); - __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); - - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen64_r2c1blk2_avx512( - const __m512i& av00_64_epi8, - const __m512i& av01_64_epi8, - const __m512i& av10_64_epi8, - const __m512i& av11_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); - - if constexpr (vnni) { - dot_accumulate_2blkvnni( - av00_64_epi8, av01_64_epi8, scale_a0, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc0 - ); - - dot_accumulate_2blkvnni( - av10_64_epi8, av11_64_epi8, scale_a1, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc1 - ); - } else { - dot_accumulate_2blk( - av00_64_epi8, av01_64_epi8, scale_a0, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc0 - ); - - dot_accumulate_2blk( - av10_64_epi8, av11_64_epi8, scale_a1, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc1 - ); - } -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen64_r1c1blk2_avx512( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc -) -{ - __m512i bv0_64_epi8, bv1_64_epi8; - load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - - const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); - - if constexpr (vnni) { - dot_accumulate_2blkvnni( - av0_64_epi8, av1_64_epi8, scale_a, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc - ); - } else { - dot_accumulate_2blk( - av0_64_epi8, av1_64_epi8, scale_a, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc - ); - } -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen64_r2c1blk1_avx512( - const __m512i& av0_64_epi8, - const __m512i& av1_64_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a0, - const float* scale_a1, - const float* scale_b, - __m512& acc0, - __m512& acc1 -) -{ - __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); - - const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); - const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); - - if constexpr (vnni) { - { - __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); - __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); - - __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); - __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); - - acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); - } - - { - __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); - __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); - - __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); - __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); - - acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); - } - } else { - const __m512i zeros = _mm512_setzero_si512(); - // const __m512i one_32_epi16_ = 
_mm512_andnot_epi32(zeros, zeros); - // const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); - - const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); - { - __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); - - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - - __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); - __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); - - acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); - } - - { - __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - - __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); - __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); - - acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); - } - } -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen64_r1c1blk1_avx512( - const __m512i& av_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m512& acc -) -{ - __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); - - const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); - const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); - - if constexpr (vnni) { - __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av_32_epi8); - __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); - - __m128 scale_a_ps = _mm_broadcast_ss(scale_a); - __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); - - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); - } else { - const __m512i one_32_epi16 = _mm512_set1_epi16(1); - - __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); - - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - - __m128 scale_a_ps = _mm_broadcast_ss(scale_a); - __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); - - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR2xC4BlkLen64Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen64 = 64; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); - - // process 2 blks of 128 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen64; - const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr 
= QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4 * NRows2] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 128 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += StrideQuantBData * NCols4; - QuantBScalePtr += PerAccuBlk2 * NCols4; - } // k_blks_remaining - - while (k_blks_remaining-- > 0) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); - - QuantAPtr += BlkLen64; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes * NCols4; - QuantBScalePtr += NCols4; - } - -#if 1 - *SumPtr = _mm512_reduce_add_ps(acc[0]); - *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); - *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); - *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); - *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); - *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); - *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); - *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - *SumPtr += *BiasPtr; - *(SumPtr + 1) += *(BiasPtr + 1); - *(SumPtr + 2) += *(BiasPtr + 2); - *(SumPtr + 3) += *(BiasPtr + 3); - *(SumPtr + ldc) += *BiasPtr; - 
*(SumPtr + ldc + 1) += *(BiasPtr + 1); - *(SumPtr + ldc + 2) += *(BiasPtr + 2); - *(SumPtr + ldc + 3) += *(BiasPtr + 3); - } -#else - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); -#endif - // move to next NCols columns - QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -void MLAS_FORCEINLINE -Q4Int8GemmR2xC1BlkLen64Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkLen64 = 64; - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - - // process 2 blks of 128 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - //assert(CountM % NRows2 == 0); - //assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - float* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 128 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - } - - while (k_blks_remaining-- > 0) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - QuantAPtr += BlkLen64; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_16(acc0); - 
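// hsum_float_16(x) above is equivalent to _mm512_reduce_add_ps(x): h_add_512 folds the two
// 256-bit halves together, and the 128-bit tail then sums the remaining eight lanes, the same
// per-element reduction the R2xC4 kernel gets from _mm512_reduce_add_ps in its #if 1 path.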
*(SumPtr + ldc) = hsum_float_16(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - *(SumPtr + ldc) += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen64Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkLen64 = 64; - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - - // process 2 blks of 128 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen; - //const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; - //const size_t StrideQuantBScale = BlockCountK; - - //assert(CountM < NRows2); - //assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 128 4b weights a time - for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { - const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; - QuantBScalePtr += PerAccuBlk2 * NCols4; - } - - while (k_blks_remaining-- > 0) { - const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); - - QuantAPtr += BlkLen64; - 
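// (Editorial note) In this single-block tail loop the pointers advance by exactly one
// BlkLen=64 block per iteration: 64 int8 values and one scale for A, and NCols4 packed
// blocks plus NCols4 scales for B, one per output column.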
QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes * NCols4; - QuantBScalePtr += NCols4; - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; - QuantBScaleColPtr += NCols4 * BlockCountK; - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC1BlkLen64Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkLen64 = 64; - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - - // process 2 blks of 128 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - //assert(CountM < NRows2); - //assert(CountN < NCols4); - - for (size_t m = 0; m < CountM; m++) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - - __m512 acc0 = _mm512_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - // process 2 blks of 128 4b weights a time - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - - accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - } - - while (k_blks_remaining-- > 0) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - - accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - - QuantAPtr += BlkLen64; - QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; - } - - *SumPtr = hsum_float_16(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } - } -} - -template -MLAS_FORCEINLINE size_t -MlasQ4Int8GemmKernelBlkLen64Avx512( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - - const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); - const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - size_t remainingRows = CountM % NRows2; - size_t multipleRows = CountM - remainingRows; - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleRows > 0 && multipleCols > 0) { - if (NRows2 == 2) - Q4Int8GemmR2xC4BlkLen64Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - else - Q4Int8GemmR1xC4BlkLen64Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - multipleRows, - multipleCols, - BlockCountK, - Bias, - ldc - ); - } - if (remainingCols > 0 && multipleRows > 0) { - if (NRows2 == 2) - Q4Int8GemmR2xC1BlkLen64Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - else - Q4Int8GemmR1xC1BlkLen64Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc - ); - } - - if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen64Avx512( - BlkLen, - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData, - QuantBScale, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - ldc); - } - if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen64Avx512( - BlkLen, - QuantA + multipleRows * lda, - QuantAScale + multipleRows * lda_scale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - ldc); - } - - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp deleted file mode 100644 index 6a5c01162c51b..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ /dev/null @@ -1,357 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_avx512.cpp.h - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for x64 avx512vnni. 
- ---*/ - -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_kernel_avx_common_fp32.h" -#include "sqnbitgemm_kernel_avx_common_int8.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" -#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" - -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (BlkLen == 16) { - if (QuantBZeroPoint != nullptr) { - MlasQ4GemmKernelBlkLen16Avx512f( - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } else { - MlasQ4GemmKernelBlkLen16Avx512f( - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } - } else if (BlkLen == 32) { - if (QuantBZeroPoint != nullptr) { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } else { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } - } else /*if (BlkLen >= 64)*/ { - if (QuantBZeroPoint != nullptr) { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } else { - MlasQ4GemmKernelBlkLen32PlusAvx512f( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0, - 0 - ); - } - } -} - -MLAS_FORCEINLINE -void -SQ4BitGemmM1Kernel_CompInt8_avx512vnni( - size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - assert(false); - } else { - constexpr bool HasZeroPoint = false; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } - } -} - -MLAS_FORCEINLINE -size_t -SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* /*QuantBZeroPoint*/, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t ldc, - const float* ABlockSum, - const float* QuantBBlkSum -) -{ - if (BlkLen == 16) { - MlasQ4Int8GemmKernelBlkLen16Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - 
CountN, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx512( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen == 64) { - MlasQ4Int8GemmKernelBlkLen64Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } else { - MlasQ4Int8GemmKernelBlkLen128Avx512( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } - - float* c_blk = C; - const float* b_blk_sum = QuantBBlkSum; - - size_t RowsRemaining = CountM; - const float* a_blksum_row = ABlockSum; - while (RowsRemaining > 0) { - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false - ); - - c_blk += ldc * RowsHandled; - a_blksum_row += BlockCountK * RowsHandled; - RowsRemaining -= RowsHandled; - } - return CountM; -} - -void MLASCALL -QuantizeARow_CompInt8_avx512( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA, - float* QuantAScale, - float* AScaledBlkSum // scale_k * Sum_blklen(a_i) -); - -static void -SQ4BitGemmPackQuantBDataAndBlkSum512vnni( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - const float* QuantBScaleBegin, - bool has_zp_input, - const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, - MLAS_THREADPOOL* ThreadPool -) -{ - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - - size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { - SubBlkLen = 128; - } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); -} - -// -// Kernel dispatch structure definition. -// -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; - - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; - - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; - - d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - - d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; - d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; - - return d; -}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h deleted file mode 100644 index 177f5518bb891..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ /dev/null @@ -1,679 +0,0 @@ -#pragma once -#include "sqnbitgemm.h" -#include "sqnbitgemm_q8_block.h" - -// -// Quantized B data packing function implementation. 
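// (Editorial aside, illustrative sketch only) The GemmFloatKernel loop in
// SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni above adds a correction term
// C += ABlockSum * QuantBBlkSum. ABlockSum stores scale_a * sum(a_i) per block of A
// (see QuantizeARow_CompInt8_avx512) and QuantBBlkSum stores -scale_b * zero_point per
// block of B (see ComputePackBlkSum below), so the product restores the zero-point term
// that the raw int8 dot products omit. Scalar reference, ignoring the width-16 packing of
// QuantBBlkSum; all names here are illustrative:
static inline void AddBlkSumCorrectionSketch(
    float* C, size_t ldc, size_t CountM, size_t CountN, size_t BlockCountK,
    const float* ABlockSum,     // CountM x BlockCountK, scale_a * sum(a_i) per block
    const float* QuantBBlkSum   // BlockCountK x CountN, -scale_b * zp per block
)
{
    for (size_t m = 0; m < CountM; ++m) {
        for (size_t n = 0; n < CountN; ++n) {
            float corr = 0.0f;
            for (size_t k = 0; k < BlockCountK; ++k) {
                corr += ABlockSum[m * BlockCountK + k] * QuantBBlkSum[k * CountN + n];
            }
            C[m * ldc + n] += corr;  // same accumulate-into-C update as the SGEMM call above
        }
    }
}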
-// - -static size_t -SQ4BitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - constexpr size_t BlkBitWidth = 4; - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ComputeType == CompInt8) { - size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t ScaleSize = N * BlockCountK * sizeof(float); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - - // _mm256_load_si256 requires alignment on a 32-byte boundary - constexpr size_t PackedQuantBDataAlignment = 32; - PackedQuantBDataSize += PackedQuantBDataAlignment - 1; - constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); - BlkSumSize += BlkSumAlignment - 1; - - return PackedQuantBDataSize + ScaleSize + BlkSumSize; - } else { - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; - } -} - -static void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - - const size_t SubBlkDataSize = SubBlkLen / 2; - const size_t SubBlkBytePairCount = SubBlkLen / 4; - - // - // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - - // - // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | - // => - // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // - - // - // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | - // => - // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - // - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { - for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } - - QuantBData += SubBlkDataSize; - PackedQuantBData += SubBlkDataSize; - } - } - ); -} - -static size_t -GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) -{ - size_t T = n / 4, t = n % 4; - bool te = T == N / 4; - size_t scale_dst_offset = T * 4 * SubOrBlkCountK; - if (te) { - scale_dst_offset += t * SubOrBlkCountK + k_sub_or_blk; - } else { - scale_dst_offset += k_sub_or_blk * 4 + t; - } - return scale_dst_offset; -} - -static size_t -GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) -{ - size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; - bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; - size_t scale_dst_offset = T * 4 * BlockCountK; - if (te) { - scale_dst_offset += t * BlockCountK + k_blk; - } else { - scale_dst_offset += k_subblk * blks_per_sub * 4; - if (be) { - scale_dst_offset += b * 4 + t; - } else { - scale_dst_offset += t * blks_per_sub + b; - } - } - return scale_dst_offset; -} - -static void -PackQuantB( - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool, - const size_t N, - const size_t BlockCountK, - const size_t BlkLen, - const size_t SubBlkLen) -{ - constexpr size_t BlkBitWidth = 4; - const size_t BlkBytePairCount = BlkLen / 4; - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - - const size_t SubBlkDataSize = SubBlkLen / 2; - const size_t SubBlkBytePairCount = SubBlkLen / 4; - const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); - const size_t Iterations = N * SubBlkCountK; // one iteration per sub block - - // for avx2 - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // for the remaining blk, it shall be: - // dst blklen32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | - - // for avx512 - // dst: | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | - // for the remaining blk, it shall be: - // dst blklen64: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // dst blklen32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / SubBlkCountK; - const size_t k_subblk = tid % SubBlkCountK; - - const size_t src_data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + src_data_offset; - - size_t PackBytePairCount = SubBlkBytePairCount; - size_t PackDataSize = SubBlkDataSize; - - auto pack_subblk = []( - const std::byte* QuantBData, std::byte* PackedQuantBData, - size_t pack_byte_pair_count, size_t pack_data_size) { - for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } }; - - if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && - SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { - // this is the last subblk of the column. check if it extends out of the - // BlockCountK. If it does, we shall pack per blocks so that can compute - // on each block instead of each subblk. - PackBytePairCount = BlkBytePairCount; - PackDataSize = BlkDataSize; - const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; - for (size_t k = 0; k < k_blks_remaining; k++) { - const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; - if (BlkLen == 16) { - // not to do the compute order layout yet - std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; - pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); - } else if (BlkLen >= SubBlkLen) { - // shall not reach here with avx2 - assert(SubBlkLen == 128); - } else { - int blks_per_sub = (int)(SubBlkLen / BlkLen); - const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); - std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; - pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); - } - } - } else { - if (BlkLen == 16) { - // not to do the compute order layout yet - std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; - pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); - } else if (BlkLen >= SubBlkLen) { - const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); - std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; - pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); - } else { - int blks_per_sub = (int)(SubBlkLen / BlkLen); - const size_t k_blk = k_subblk * blks_per_sub; - const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); - std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; - pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); - } - } - } - ); -} - -//#include - -static void -ComputePackBlkSum( - size_t BlkLen, - size_t SubBlkLen, - size_t N, - float* QuantBScaleBegin, - const std::byte* QuantBZPBegin, - float* BlockSumBegin, - MLAS_THREADPOOL* ThreadPool, - const 
size_t BlockCountK) -{ - std::vector QuantBScaleBeginCopy(N * BlockCountK); - std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); - MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t src_blk_offset = n * BlockCountK + k_blk; - const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; - uint8_t zp = 8; - if (QuantBZPBegin) { - size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); - size_t src_zp_offset = ZPCountK * n + k_blk / 2; - bool low_zp = k_blk % 2 == 0; - const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; - const std::byte low_mask{0X0F}; - zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); - } - - // BlockSum is a width 16 row major matrix - const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; - *(BlockSumBegin + dst_offset) = -QuantBScale * zp; - if (BlkLen == 16) { // TODO - - } else if (BlkLen >= SubBlkLen) { - const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); - *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; - } else { - int blks_per_sub = (int)(SubBlkLen / BlkLen); - size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); - *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; - } - } - ); -} - -static void -PackQuantBDataAndBlkSum( - size_t N, - size_t BlockCountK, - size_t BlkLen, - size_t SubBlkLen, - const std::byte* QuantBDataBegin, - const float* QuantBScaleBegin, - bool has_zp_input, - const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, - MLAS_THREADPOOL* ThreadPool -) -{ - if (QuantBDataBegin) { - PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); - } - - if (QuantBScaleBegin) { - std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); - } - - if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { - ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); - } -} - -// -// Workspace size calculation function implementation. 
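// (Editorial aside, illustrative sketch only) ComputePackBlkSum above writes the block sums
// into a "width 16 row major" layout: the N dimension is tiled in groups of 16 columns, and
// within a group the 16 values for a given k_blk are stored contiguously. The inline index
// computation is equivalent to this helper; the name is illustrative:
static inline size_t BlkSumDstIndexSketch(size_t n, size_t k_blk, size_t BlockCountK)
{
    return ((n / 16) * BlockCountK + k_blk) * 16 + (n % 16);
}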
-// - -static size_t -SQ4BitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(N); - - switch(ComputeType) { - case CompInt8: { - // workspace buffer is used for block quantization of A to int8 - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // QuantData + Scale + BlkSum - const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); - return PerGemmWorkspaceSize; - } - default: { - return 0; - } - } -} - -static size_t -SQ4BitGemmPerGemmWorkspaceAlignment( - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(BlkLen); - - switch (ComputeType) { - case CompInt8: { - return Q8BlkAlignment(); - } - default: { - return 1; - } - } -} - -void -Q4BitBlkDequantBForSgemm_CompFp32_avx2( - const size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - const size_t CountN, - const size_t CountK, - const size_t BlockStrideQuantB -); - -size_t -SQ4BitGemmKernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float* Bias -); - -// -// General helpers. -// - -namespace -{ - -template -MLAS_FORCEINLINE void -UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) -{ - (f(Indices), ...); -} - -template -MLAS_FORCEINLINE void -UnrolledLoop(IterationFn&& f) -{ - UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); -} - -// this function is used to dot product 2 pairs of 32 epi8s. it is used with Int8 precision -// and blklen >= 64. In this case, 64 of 4b weights are filled with one load. -static MLAS_FORCEINLINE __m256 -dot_quad_avx512vnni( - const __m256i bv0_32_epi8, const __m256i bv1_32_epi8, const __m256i av0_32_epi8, const __m256i av1_32_epi8 -) -{ - const __m256i zero = _mm256_setzero_si256(); - __m256i sum_8_epi32 = _mm256_dpbusd_epi32(zero, _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - sum_8_epi32 = _mm256_dpbusd_epi32(sum_8_epi32, _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - return _mm256_cvtepi32_ps(sum_8_epi32); -} - -static MLAS_FORCEINLINE __m256 -dot_quad_avx2( - const __m256i b0, const __m256i b1, const __m256i a0, const __m256i a1 -) -{ - // Perform multiplication and create 16-bit values - const __m256i ones = _mm256_set1_epi16(1); - __m256i sum_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(b0, b0), _mm256_sign_epi8(a0, b0)); - __m256i summed_pair_epi32 = _mm256_madd_epi16(ones, sum_epi16); - - sum_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(b1, b1), _mm256_sign_epi8(a1, b1)); - summed_pair_epi32 = _mm256_add_epi32(_mm256_madd_epi16(ones, sum_epi16), summed_pair_epi32); - return _mm256_cvtepi32_ps(summed_pair_epi32); -} - -// TODO: refactor load_and_mul_sum_s8_quads_with_zp_avx512vnni, load_and_mul_sum_s8_quads_with_zp_avx2 -// and accumulate_mul_sum_avx512vnni, accumulate_mul_sum_avx2 -static MLAS_FORCEINLINE void -load_and_mul_sum_s8_quads_with_zp_avx512vnni( - const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const int8_t zp, const __m256 scale0, __m256& acc0 -) -{ - // load B - // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - - // supprisingly this code that works with __m128i is 2-3% faster than the blobk below with __m256i - // to unpack bv_packed0. Also passing in low_mask is faster than creating it here by 2%. - // const __m128i low_mask = _mm_set1_epi8(15); - const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... - const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... - __m256i bv_0_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); - - //__m256i bv_0_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - // const __m256i low_mask = _mm256_set1_epi8(15); - // bv_0_epi8 = _mm256_and_si256(low_mask, bv_0_epi8); - - const __m256i bzp0 = _mm256_set1_epi8(zp); - bv_0_epi8 = _mm256_sub_epi8(bv_0_epi8, bzp0); - // quantized dot product - __m256i dot_0_epi32 = _mm256_dpbusd_epi32( - zero, _mm256_sign_epi8(bv_0_epi8, bv_0_epi8), _mm256_sign_epi8(av_0_epi8, bv_0_epi8) - ); - const __m256 sum_ps = _mm256_cvtepi32_ps(dot_0_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); -} - -static MLAS_FORCEINLINE void -load_and_mul_sum_s8_quads_with_zp_avx2( - const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i, const int8_t zp, const __m256 scale0, __m256& acc0 -) -{ - // load B - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - - // supprisingly this code that works with __m128i is 2-3% faster than the blobk below with __m256i - // to unpack bv_packed0. Also passing in low_mask is faster than creating it here by 2%. - // const __m128i low_mask = _mm_set1_epi8(15); - const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... - const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... - __m256i bv_0_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); - - //__m256i bv_0_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - // const __m256i low_mask = _mm256_set1_epi8(15); - // bv_0_epi8 = _mm256_and_si256(low_mask, bv_0_epi8); - - const __m256i bzp0 = _mm256_set1_epi8(zp); - bv_0_epi8 = _mm256_sub_epi8(bv_0_epi8, bzp0); - // quantized dot product - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv_0_epi8, bv_0_epi8), _mm256_sign_epi8(av_0_epi8, bv_0_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); -} - -template -void MLAS_FORCEINLINE -get_2_zps(const std::byte* QuantBZeroPointPtr, int8_t& zp0, int8_t& zp1) -{ - if constexpr (HasZeroPoint) { - zp0 = std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}); - zp1 = std::to_integer((*QuantBZeroPointPtr) >> 4); - } else { - zp0 = 8; - zp1 = 8; - (void)QuantBZeroPointPtr; - } -} - -template -int8_t MLAS_FORCEINLINE -get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) -{ - if constexpr (!HasZeroPoint) { - // Suppress unused variable warnings - (void)QuantBZeroPointPtr; - } - - if constexpr (HasZeroPoint) { - return is_lower_half_byte_zp ? 
std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : std::to_integer((*QuantBZeroPointPtr) >> 4); - } else { - return 8; - } -} - -// this function load and unpack 32 4b weights (packed for BlkLen32) and dot product it with 32 -// epi8 input. dot products are accumulated into acc0. -// This function is called for Int8 precision with BlkLen = 32. -template -using AccumulateFunctionType = void (*)( - const __m256i, const __m128i*, const __m128i, const __m256i, const std::byte*, bool, const float, __m256& -); - -template -static MLAS_FORCEINLINE void -accumulate_mul_sum_avx512vnni( - const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const std::byte* QuantBZeroPointPtr, bool is_lower_half_byte_zp, const float combined_scale, __m256& acc0 -) -{ - const __m256 scale0 = _mm256_set1_ps(combined_scale); - const int8_t zp = get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr); - load_and_mul_sum_s8_quads_with_zp_avx512vnni( - av_0_epi8, reinterpret_cast(QuantBDataPtr), - low_mask, zero, - zp, scale0, acc0 - ); -} - -template -static MLAS_FORCEINLINE void -accumulate_mul_sum_avx2( - const __m256i av_0_epi8, const __m128i* QuantBDataPtr, const __m128i low_mask, const __m256i zero, const std::byte* QuantBZeroPointPtr, bool is_lower_half_byte_zp, const float combined_scale, __m256& acc0 -) -{ - const __m256 scale0 = _mm256_set1_ps(combined_scale); - const int8_t zp = get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr); - load_and_mul_sum_s8_quads_with_zp_avx2( - av_0_epi8, reinterpret_cast(QuantBDataPtr), - low_mask, zero, - zp, scale0, acc0 - ); -} - -/** - * @brief Horizontally sum 4 vectors and store - * the results in the returned vector - */ -static MLAS_FORCEINLINE __m128 -FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3) -{ - __m256 acc_lo01 = _mm256_unpacklo_ps(acc0, acc1); - __m256 acc_hi01 = _mm256_unpackhi_ps(acc0, acc1); - __m256 acc_lo23 = _mm256_unpacklo_ps(acc2, acc3); - __m256 acc_hi23 = _mm256_unpackhi_ps(acc2, acc3); - - __m256 acc_lo0123 = _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)) - ); - __m256 acc_hi0123 = _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(acc_lo01), _mm256_castps_pd(acc_lo23)) - ); - acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)) - ); - acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(acc_hi01), _mm256_castps_pd(acc_hi23)) - ); - acc_lo0123 = _mm256_add_ps(acc_lo0123, acc_hi0123); - - __m128 acc_y = - _mm_add_ps(_mm256_extractf128_ps(acc_lo0123, 0), _mm256_extractf128_ps(acc_lo0123, 1)); - return acc_y; -} - -static MLAS_FORCEINLINE float -hsum_float_8(const __m256 x) -{ - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -/** - * @brief Horizontally sum 4 vectors and store - * the results in the returned vector - */ -static MLAS_FORCEINLINE __m128 -FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, const __m512& acc3) -{ - __m512 acc_lo01 = _mm512_unpacklo_ps(acc0, acc1); - __m512 acc_hi01 = _mm512_unpackhi_ps(acc0, acc1); - __m512 acc_lo23 = _mm512_unpacklo_ps(acc2, acc3); - __m512 acc_hi23 = _mm512_unpackhi_ps(acc2, acc3); - - 
__m512 acc_lo0123 = _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23)) - ); - __m512 acc_hi0123 = _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(acc_lo01), _mm512_castps_pd(acc_lo23)) - ); - acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23)) - ); - acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); - acc_hi0123 = _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(acc_hi01), _mm512_castps_pd(acc_hi23)) - ); - acc_lo0123 = _mm512_add_ps(acc_lo0123, acc_hi0123); - - __m256 acc_y = - _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); - return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); -} - -static MLAS_FORCEINLINE __m128i -convert_2_ps_to_epi8(__m256 v0, __m256 v1) -{ - __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); - __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); - - __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); - __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); - - return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); -} - -// horizontally add 8 int32_t -static MLAS_FORCEINLINE int -hsum_8_epi32(const __m256i a_8_epi32) -{ - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} -} // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h deleted file mode 100644 index 5cd380e591098..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h +++ /dev/null @@ -1,639 +0,0 @@ -#pragma once -#include "sqnbitgemm.h" - -template -MLAS_FORCEINLINE - size_t - MlasQ4GemmKernelBlkLen16Avx512f( - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc - ) -{ - // We process 32 quantized values in a batch. 
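    // (Editorial aside, illustrative sketch only) Throughout this kernel each packed byte of B
    // carries two 4-bit weights: the low nibble is one weight and the high nibble another, and
    // each is re-centered by subtracting the per-block zero point (8 when no explicit zero
    // point is stored). A scalar reference for unpacking one byte:
    auto unpack_nibble_pair = [](uint8_t packed, int zero_point, int& w_lo, int& w_hi) {
        w_lo = static_cast<int>(packed & 0x0F) - zero_point;  // low nibble
        w_hi = static_cast<int>(packed >> 4) - zero_point;    // high nibble
    };
    (void)unpack_nibble_pair;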
- // assert(BlkLen % 32 == 0) - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols = 4; - constexpr size_t BlkLen16 = 16; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const __m128i lowMask = _mm_set1_epi8(0xF); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - for (size_t m = 0; m < CountM; m++) { - //*// - ////const float* BiasPtr = Bias; - - // for each row of A, reset B pointers - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - ////float* SumPtr = CRowPtr; - //*// - - auto* sum_ptr = C; - const auto* bias_ptr = Bias; - - int64_t nblk = (int64_t)(CountN)-4; - while (nblk >= 0) { - __m512 acc_lo0 = _mm512_setzero_ps(); - __m512 acc_lo1 = _mm512_setzero_ps(); - __m512 acc_lo2 = _mm512_setzero_ps(); - __m512 acc_lo3 = _mm512_setzero_ps(); - - //*// - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - //*// - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx = 0; - } - - for (size_t k = 0; k < CountK; k += BlkLen16) { - size_t kklen = std::min(CountK - k, BlkLen16); - - const float scale_v0 = *(s); - const float scale_v1 = *(s + StrideQuantBScale * 1); - const float scale_v2 = *(s + StrideQuantBScale * 2); - const float scale_v3 = *(s + StrideQuantBScale * 3); - - const __m128i* b0ptr = (const __m128i*)(b_blk_data_ptr); - const __m128i* b1ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 1); - const __m128i* b2ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 2); - const __m128i* b3ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 3); - - // Load A row vector of 16 floats - uint32_t mask = 0xffff >> (BlkLen16 - kklen); - __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k); - - // Load B col vectors of 16 of 4b - // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - const __m128i bvi4_0 = _mm_loadl_epi64(b0ptr++); - const __m128i bvi4_1 = _mm_loadl_epi64(b1ptr++); - const __m128i bvi4_2 = _mm_loadl_epi64(b2ptr++); - const __m128i bvi4_3 = _mm_loadl_epi64(b3ptr++); - - // expand 4b into byte array - __m128i lower = _mm_and_si128(bvi4_0, lowMask); - __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_0, 4), lowMask), 8); - __m128i bytes0 = _mm_add_epi8(upper, lower); - - lower = _mm_and_si128(bvi4_1, lowMask); - upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_1, 4), lowMask), 8); - __m128i bytes1 = _mm_add_epi8(upper, lower); - - lower = _mm_and_si128(bvi4_2, lowMask); - upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_2, 4), lowMask), 8); - __m128i bytes2 = _mm_add_epi8(upper, lower); - - lower = _mm_and_si128(bvi4_3, lowMask); - upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_3, 4), lowMask), 8); - __m128i bytes3 = _mm_add_epi8(upper, lower); - - // Subtract zero-point from the integers - if constexpr (HasZeroPoint) { - // Subtract zero-point from the integers - bool is_lower = (QuantBZeroPointIdx & 1) == 0; - - // TODO: void condition on is_lower - std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - - bytes0 = _mm_sub_epi8(bytes0, _mm_set1_epi8(zp)); - - zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes1 = _mm_sub_epi8(bytes1, _mm_set1_epi8(zp)); - - zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes2 = _mm_sub_epi8(bytes2, _mm_set1_epi8(zp)); - - zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes3 = _mm_sub_epi8(bytes3, _mm_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m128i eight = _mm_set1_epi8(8); - bytes0 = _mm_sub_epi8(bytes0, eight); - bytes1 = _mm_sub_epi8(bytes1, eight); - bytes2 = _mm_sub_epi8(bytes2, eight); - bytes3 = _mm_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - const __m256i vx16_0 = _mm256_cvtepi8_epi16(bytes0); - const __m256i vx16_1 = _mm256_cvtepi8_epi16(bytes1); - const __m256i vx16_2 = _mm256_cvtepi8_epi16(bytes2); - const __m256i vx16_3 = _mm256_cvtepi8_epi16(bytes3); - - // Convert to 32-bit int -> float 32 - __m512 bvf_0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_0)); - __m512 bvf_1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_1)); - __m512 bvf_2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_2)); - __m512 bvf_3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_3)); - - __m512 scale_ps = _mm512_set1_ps(scale_v0); - bvf_0 = _mm512_mul_ps(bvf_0, scale_ps); - scale_ps = _mm512_set1_ps(scale_v1); - bvf_1 = _mm512_mul_ps(bvf_1, scale_ps); - scale_ps = _mm512_set1_ps(scale_v2); - bvf_2 = _mm512_mul_ps(bvf_2, scale_ps); - scale_ps = _mm512_set1_ps(scale_v3); - bvf_3 = _mm512_mul_ps(bvf_3, scale_ps); - - acc_lo0 = _mm512_fmadd_ps(bvf_0, av_lo, acc_lo0); - acc_lo1 = _mm512_fmadd_ps(bvf_1, av_lo, acc_lo1); - acc_lo2 = _mm512_fmadd_ps(bvf_2, av_lo, acc_lo2); - acc_lo3 = _mm512_fmadd_ps(bvf_3, av_lo, acc_lo3); - - //*// - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - s++; - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - //*// - - } // k - - __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); - if (Bias != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); - } - _mm_storeu_ps(sum_ptr, acc_x); - - // move to next 4 columns - sum_ptr += 4; - bias_ptr += 4; - nblk -= 4; - - //*// - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - ////BiasPtr += BiasPtr != nullptr ? NCols : 0; - ////SumPtr += NCols; - - ////nblk -= NCols; - //*// - } - - // left over columns less than 4 ? 
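        // (Editorial aside, illustrative sketch only) The masked A loads in this kernel build a
        // __mmask16 with the low kklen bits set, so a short tail block loads only its valid
        // floats and zeroes the remaining lanes. Equivalent helper:
        auto tail_mask16 = [](size_t kklen) -> __mmask16 {
            return static_cast<__mmask16>(0xffffu >> (16 - kklen));  // kklen in [1, 16]
        };
        (void)tail_mask16;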
- nblk += 4; - if (nblk > 0) { - __m512 acc_lo[4]{}; - - //*// - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - //*// - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx = 0; - } - - for (size_t k = 0; k < CountK; k += BlkLen16) { - size_t klen = std::min(CountK - k, BlkLen16); - - float scale_v[4]; - const __m128i* b_ptr[4]; - for (int64_t nn = 0; nn < nblk; nn++) { - //*// - scale_v[nn] = *(s + StrideQuantBScale * nn); - b_ptr[nn] = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * nn); - //*// - } - - uint32_t mask = 0xffff >> (BlkLen16 - klen); - __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k); - - for (int64_t nn = 0; nn < nblk; nn++) { - // Load B col vectors of 16 of 4b - // SubBlkLen = 16: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - const __m128i bvi4_0 = _mm_loadl_epi64(b_ptr[nn]++); - - // expand 4b into byte array - __m128i lower = _mm_and_si128(bvi4_0, lowMask); - __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi4_0, 4), lowMask), 8); - __m128i bytes = _mm_add_epi8(upper, lower); - - if constexpr (HasZeroPoint) { - // Subtract zero-point from the integers - bool is_lower = (QuantBZeroPointIdx & 1) == 0; - - // TODO: void condition on is_lower - std::byte zp_packed = QuantBZeroPointColPtr[nn * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes = _mm_sub_epi8(bytes, _mm_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m128i eight = _mm_set1_epi8(8); - bytes = _mm_sub_epi8(bytes, eight); - } - - // Convert to 16-bit int - const __m256i vx16 = _mm256_cvtepi8_epi16(bytes); - - // Convert to 32-bit int -> float 32 - __m512 bvf = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16)); - __m512 scale_16_ps = _mm512_set1_ps(scale_v[nn]); - bvf = _mm512_mul_ps(bvf, scale_16_ps); - - acc_lo[nn] = _mm512_fmadd_ps(bvf, av_lo, acc_lo[nn]); - } - - //*// - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - s++; - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - //*// - } // k - - for (int64_t nn = 0; nn < nblk; nn++) { - sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); - sum_ptr[nn] += Bias == nullptr ? 0.0f : bias_ptr[nn]; - } - } - - // Prepare pointers for the next row - C += ldc; - A += lda; - } - return CountM; -} - -template -MLAS_FORCEINLINE - size_t - MlasQ4GemmKernelBlkLen32PlusAvx512f( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc - ) -{ - // We process 32 quantized values in a batch. 
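    // (Editorial aside, illustrative sketch only) This CompFp32 kernel dequantizes each 4-bit
    // weight to float (subtract the zero point, multiply by the per-block scale) and accumulates
    // against A entirely in fp32. Scalar reference for one block's contribution, where b_vals is
    // assumed to hold the already-unpacked 4-bit values:
    auto dequant_dot_fp32 = [](const float* a, const uint8_t* b_vals, size_t count,
                               int zero_point, float scale) -> float {
        float acc = 0.0f;
        for (size_t i = 0; i < count; ++i) {
            acc += a[i] * ((static_cast<int>(b_vals[i]) - zero_point) * scale);  // dequantize, then FMA
        }
        return acc;
    };
    (void)dequant_dot_fp32;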
- // assert(BlkLen % 32 == 0) - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols = 4; - constexpr size_t MLAS_QUANT4_BLK_UNIT32 = 32; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const __m256i lowMask = _mm256_set1_epi8(0xF); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - - for (size_t m = 0; m < CountM; m++) { - //*// - ////const float* BiasPtr = Bias; - - // for each row of A, reset B pointers - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - ////float* SumPtr = CRowPtr; - //*// - - auto* sum_ptr = C; - const auto* bias_ptr = Bias; - - int64_t nblk = (int64_t)(CountN)-4; - while (nblk >= 0) { - __m512 acc_lo0 = _mm512_setzero_ps(); - __m512 acc_lo1 = _mm512_setzero_ps(); - __m512 acc_lo2 = _mm512_setzero_ps(); - __m512 acc_lo3 = _mm512_setzero_ps(); - - //*// - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - //*// - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx = 0; - } - - for (size_t k = 0; k < CountK; k += BlkLen) { - size_t ck = std::min(CountK - k, BlkLen); - - const float scale_v0 = *(s); - const float scale_v1 = *(s + StrideQuantBScale * 1); - const float scale_v2 = *(s + StrideQuantBScale * 2); - const float scale_v3 = *(s + StrideQuantBScale * 3); - - const __m128i* b0ptr = (const __m128i*)(b_blk_data_ptr); - const __m128i* b1ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 1); - const __m128i* b2ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 2); - const __m128i* b3ptr = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * 3); - - for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT32) { - size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT32, ck - kk); - - // Load A row vectors - uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT32 - kklen); - __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - - mask = mask >> 16; - __m512 av_hi = mask == 0 ? _mm512_setzero_ps() - : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); - - // Load B col vectors - __m256i bytes0, bytes1, bytes2, bytes3; - if constexpr (IsBlkLen64Layout) { - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // load 64 weights at once, parse to get v0 - v31 if subblk is even, otherwise get v32 - v63 - // increment b_data_ptr by 2 * MLAS_QUANT4_BLK_UNIT32 if subblk is odd - // so that all v0-63 of the pack are processed. 
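                    // (Editorial aside, illustrative sketch only) For this layout the shift used
                    // below is 0 when kk addresses the first 32 weights of a 64-weight pack (the
                    // low nibbles) and 4 when it addresses the second 32 (the high nibbles),
                    // which is exactly what count_half_4 computes from kk:
                    auto sub_blk_nibble_shift = [](size_t kk) -> int {
                        return 4 * static_cast<int>((kk % 64) / 32);  // 64 = 2 * MLAS_QUANT4_BLK_UNIT32
                    };
                    (void)sub_blk_nibble_shift;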
- const __m256i bvi4_0 = _mm256_loadu_si256((__m256i const*)(b0ptr)); - const __m256i bvi4_1 = _mm256_loadu_si256((__m256i const*)(b1ptr)); - const __m256i bvi4_2 = _mm256_loadu_si256((__m256i const*)(b2ptr)); - const __m256i bvi4_3 = _mm256_loadu_si256((__m256i const*)(b3ptr)); - const int count_half_4 = - 4 * ((kk % (2 * MLAS_QUANT4_BLK_UNIT32)) / MLAS_QUANT4_BLK_UNIT32); - bytes0 = _mm256_and_si256(_mm256_srli_epi16(bvi4_0, count_half_4), lowMask); - bytes1 = _mm256_and_si256(_mm256_srli_epi16(bvi4_1, count_half_4), lowMask); - bytes2 = _mm256_and_si256(_mm256_srli_epi16(bvi4_2, count_half_4), lowMask); - bytes3 = _mm256_and_si256(_mm256_srli_epi16(bvi4_3, count_half_4), lowMask); - b0ptr += count_half_4 / 2; - b1ptr += count_half_4 / 2; - b2ptr += count_half_4 / 2; - b3ptr += count_half_4 / 2; - } else { - const __m128i bvi4_0 = _mm_loadu_si128(b0ptr++); - const __m128i bvi4_1 = _mm_loadu_si128(b1ptr++); - const __m128i bvi4_2 = _mm_loadu_si128(b2ptr++); - const __m128i bvi4_3 = _mm_loadu_si128(b3ptr++); - - // expand 4b into byte array - bytes0 = _mm256_set_m128i(_mm_srli_epi16(bvi4_0, 4), bvi4_0); - bytes1 = _mm256_set_m128i(_mm_srli_epi16(bvi4_1, 4), bvi4_1); - bytes2 = _mm256_set_m128i(_mm_srli_epi16(bvi4_2, 4), bvi4_2); - bytes3 = _mm256_set_m128i(_mm_srli_epi16(bvi4_3, 4), bvi4_3); - bytes0 = _mm256_and_si256(lowMask, bytes0); - bytes1 = _mm256_and_si256(lowMask, bytes1); - bytes2 = _mm256_and_si256(lowMask, bytes2); - bytes3 = _mm256_and_si256(lowMask, bytes3); - } - - // Subtract zero-point from the integers - if constexpr (HasZeroPoint) { - // Subtract zero-point from the integers - bool is_lower = (QuantBZeroPointIdx & 1) == 0; - - // TODO: void condition on is_lower - std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - - bytes0 = _mm256_sub_epi8(bytes0, _mm256_set1_epi8(zp)); - - zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes1 = _mm256_sub_epi8(bytes1, _mm256_set1_epi8(zp)); - - zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes2 = _mm256_sub_epi8(bytes2, _mm256_set1_epi8(zp)); - - zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes3 = _mm256_sub_epi8(bytes3, _mm256_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes0 = _mm256_sub_epi8(bytes0, eight); - bytes1 = _mm256_sub_epi8(bytes1, eight); - bytes2 = _mm256_sub_epi8(bytes2, eight); - bytes3 = _mm256_sub_epi8(bytes3, eight); - } - - // Convert to 16-bit int - const __m256i vx16_lo0 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 0)); - const __m256i vx16_hi0 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes0, 1)); - const __m256i vx16_lo1 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 0)); - const __m256i vx16_hi1 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes1, 1)); - const __m256i vx16_lo2 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 0)); - const __m256i vx16_hi2 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes2, 1)); - const __m256i vx16_lo3 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 0)); - const __m256i vx16_hi3 = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes3, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo0)); - __m512 bvf_hi0 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi0)); - __m512 bvf_lo1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo1)); - __m512 bvf_hi1 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi1)); - __m512 bvf_lo2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo2)); - __m512 bvf_hi2 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi2)); - __m512 bvf_lo3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo3)); - __m512 bvf_hi3 = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi3)); - - __m512 scale_ps = _mm512_set1_ps(scale_v0); - bvf_lo0 = _mm512_mul_ps(bvf_lo0, scale_ps); - bvf_hi0 = _mm512_mul_ps(bvf_hi0, scale_ps); - scale_ps = _mm512_set1_ps(scale_v1); - bvf_lo1 = _mm512_mul_ps(bvf_lo1, scale_ps); - bvf_hi1 = _mm512_mul_ps(bvf_hi1, scale_ps); - scale_ps = _mm512_set1_ps(scale_v2); - bvf_lo2 = _mm512_mul_ps(bvf_lo2, scale_ps); - bvf_hi2 = _mm512_mul_ps(bvf_hi2, scale_ps); - scale_ps = _mm512_set1_ps(scale_v3); - bvf_lo3 = _mm512_mul_ps(bvf_lo3, scale_ps); - bvf_hi3 = _mm512_mul_ps(bvf_hi3, scale_ps); - - acc_lo0 = _mm512_fmadd_ps(bvf_lo0, av_lo, acc_lo0); - acc_lo0 = _mm512_fmadd_ps(bvf_hi0, av_hi, acc_lo0); - acc_lo1 = _mm512_fmadd_ps(bvf_lo1, av_lo, acc_lo1); - acc_lo1 = _mm512_fmadd_ps(bvf_hi1, av_hi, acc_lo1); - acc_lo2 = _mm512_fmadd_ps(bvf_lo2, av_lo, acc_lo2); - acc_lo2 = _mm512_fmadd_ps(bvf_hi2, av_hi, acc_lo2); - acc_lo3 = _mm512_fmadd_ps(bvf_lo3, av_lo, acc_lo3); - acc_lo3 = _mm512_fmadd_ps(bvf_hi3, av_hi, acc_lo3); - } // kk - - //*// - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - s++; - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - //*// - - } // k - - __m128 acc_x = FoldAccumulators(acc_lo0, acc_lo1, acc_lo2, acc_lo3); - if (Bias != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias_ptr)); - } - _mm_storeu_ps(sum_ptr, acc_x); - - // move to next 4 columns - sum_ptr += 4; - bias_ptr += 4; - nblk -= 4; - - //*// - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - ////BiasPtr += BiasPtr != nullptr ? NCols : 0; - ////SumPtr += NCols; - - ////nblk -= NCols; - //*// - } - - // left over columns less than 4 ? 
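        // (Editorial aside, illustrative sketch only) Both the main loop above and the remainder
        // path below split a 32-bit tail mask into two __mmask16 halves; when 16 or fewer values
        // remain in the block, the high half is zero and av_hi is simply left as zeroes.
        // Equivalent helper:
        auto split_tail_mask32 = [](size_t kklen, __mmask16& lo, __mmask16& hi) {
            const uint32_t mask32 = 0xffffffffu >> (32 - kklen);  // kklen in [1, 32]
            lo = static_cast<__mmask16>(mask32 & 0xffffu);
            hi = static_cast<__mmask16>(mask32 >> 16);            // zero when kklen <= 16
        };
        (void)split_tail_mask32;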
- nblk += 4; - if (nblk > 0) { - __m512 acc_lo[4]{}; - - //*// - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - //*// - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx = 0; - } - - for (size_t k = 0; k < CountK; k += BlkLen) { - size_t ck = std::min(CountK - k, BlkLen); - - float scale_v[4]; - const __m128i* b_ptr[4]; - for (int64_t nn = 0; nn < nblk; nn++) { - //*// - scale_v[nn] = *(s + StrideQuantBScale * nn); - b_ptr[nn] = (const __m128i*)(b_blk_data_ptr + StrideQuantBData * nn); - //*// - } - - for (size_t kk = 0; kk < ck; kk += MLAS_QUANT4_BLK_UNIT32) { - size_t kklen = std::min((size_t)MLAS_QUANT4_BLK_UNIT32, ck - kk); - - uint32_t mask = 0xffffffff >> (MLAS_QUANT4_BLK_UNIT32 - kklen); - __m512 av_lo = _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk); - - mask = mask >> 16; - __m512 av_hi = mask == 0 - ? _mm512_setzero_ps() - : _mm512_maskz_loadu_ps(__mmask16(mask), A + k + kk + 16); - - for (int64_t nn = 0; nn < nblk; nn++) { - __m256i bytes; - if constexpr (IsBlkLen64Layout) { - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // load 64 weights at once, parse to get v0 - v31 if subblk is even, otherwise get v32 - v63 - // increment b_data_ptr by 2 * MLAS_QUANT4_BLK_UNIT32 if subblk is odd - // so that all v0-63 of the pack are processed. - const __m256i bvi4 = _mm256_loadu_si256((__m256i const*)(b_ptr[nn])); - const int count_half_4 = - 4 * ((kk % (2 * MLAS_QUANT4_BLK_UNIT32)) / MLAS_QUANT4_BLK_UNIT32); - bytes = _mm256_and_si256(_mm256_srli_epi16(bvi4, count_half_4), lowMask); - b_ptr[nn] += count_half_4 / 2; - } else { - const __m128i bvi4 = _mm_loadu_si128(b_ptr[nn]++); - bytes = _mm256_set_m128i(_mm_srli_epi16(bvi4, 4), bvi4); - bytes = _mm256_and_si256(lowMask, bytes); - } - if constexpr (HasZeroPoint) { - // Subtract zero-point from the integers - bool is_lower = (QuantBZeroPointIdx & 1) == 0; - - // TODO: void condition on is_lower - std::byte zp_packed = QuantBZeroPointColPtr[nn * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - bytes = _mm256_sub_epi8(bytes, _mm256_set1_epi8(zp)); - } else { - // Subtract 8 from the integers - const __m256i eight = _mm256_set1_epi8(8); - bytes = _mm256_sub_epi8(bytes, eight); - } - - // Convert to 16-bit int - const __m256i vx16_lo = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 0)); - const __m256i vx16_hi = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(bytes, 1)); - - // Convert to 32-bit int -> float 32 - __m512 bvf_lo = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_lo)); - __m512 bvf_hi = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(vx16_hi)); - __m512 scale_16_ps = _mm512_set1_ps(scale_v[nn]); - bvf_lo = _mm512_mul_ps(bvf_lo, scale_16_ps); - bvf_hi = _mm512_mul_ps(bvf_hi, scale_16_ps); - - acc_lo[nn] = _mm512_fmadd_ps(bvf_lo, av_lo, acc_lo[nn]); - acc_lo[nn] = _mm512_fmadd_ps(bvf_hi, av_hi, acc_lo[nn]); - } - } // kk - - //*// - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - s++; - - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - //*// - } // k - - for (int64_t nn = 0; nn < nblk; nn++) { - sum_ptr[nn] = _mm512_reduce_add_ps(acc_lo[nn]); - sum_ptr[nn] += Bias == nullptr ? 
0.0f : bias_ptr[nn]; - } - } - - // Prepare pointers for the next row - C += ldc; - A += lda; - } - return CountM; -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h deleted file mode 100644 index 895ce6cd091c2..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ /dev/null @@ -1,736 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" -#include "sqnbitgemm_q8_block.h" - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( - size_t BlkLen, - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - if constexpr (!HasZeroPoint) { - // Suppress unused variable warnings - (void)QuantBZeroPointColPtr; - (void)StrideQuantBZeroPoint; - } - - assert(BlkLen == 16); - constexpr size_t SubBlkLen = 16; - const __m128i low_mask = _mm_set1_epi8(0xF); - - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); - - __m256 acc[NCols]; - UnrolledLoop([&](size_t i) { - acc[i] = _mm256_setzero_ps(); - }); - - const std::byte* ablob = QuantARowPtr; - const auto* b = QuantBDataColPtr; - const float* s = QuantBScaleColPtr; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint == true - - for (size_t k = 0; k < CountK; k += BlkLen) { - const float a_scale = Q8BlkScale(ablob); - ablob += sizeof(float); - - float scale_v[NCols]; - UnrolledLoop([&](size_t i) { - scale_v[i] = (*(s + StrideQuantBScale * i)) * a_scale; - }); - - std::byte* bptr[NCols]; - UnrolledLoop([&](size_t i) { - bptr[i] = (std::byte*)(b + StrideQuantBData * i); - }); - - [[maybe_unused]] uint8_t offset[NCols]; - // not ready for "Manual conversion to float" in neon yet. following neon to unpack to uint8_t. - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? 
(zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = std::to_integer(zp); - }); - } - - // Load A row vector - const __m128i av_epi8 = _mm_lddqu_si128((const __m128i*)ablob); - __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); - ablob += BlkLen; - - // Load 4 B column vectors (quantized to int4 blobs) - __m128i bvi[NCols]; - UnrolledLoop([&](size_t i) { - bvi[i] = _mm_loadl_epi64((__m128i const*)bptr[i]); - bptr[i] += SubBlkStep; - }); - - // expand 4b into byte array - __m256i bv_epi16[NCols]; - UnrolledLoop([&](size_t i) { - const __m128i lower = _mm_and_si128(bvi[i], low_mask); - const __m128i upper = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bvi[i], 4), low_mask), 8); - bv_epi16[i] = _mm256_cvtepi8_epi16(_mm_add_epi8(upper, lower)); - }); - - // Subtract zero-point from the integers - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - bv_epi16[i] = _mm256_sub_epi16(bv_epi16[i], _mm256_set1_epi16(offset[i])); - }); - } else { - const __m256i eight = _mm256_set1_epi16(8); - UnrolledLoop([&](size_t i) { - bv_epi16[i] = _mm256_sub_epi16(bv_epi16[i], eight); - }); - } - - UnrolledLoop([&](size_t i) { - __m256i prod_8_epi32 = _mm256_madd_epi16(bv_epi16[i], av_epi16); - - const __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); - acc[i] = _mm256_fmadd_ps(_mm256_set1_ps(scale_v[i]), prod_8_ps, acc[i]); - }); - - b += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - s++; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } - - if constexpr (NCols == 4) { - __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_x); - } else { - UnrolledLoop([&](size_t i) { - __m128 vlow = _mm256_castps256_ps128(acc[i]); - __m128 vhigh = _mm256_extractf128_ps(acc[i], 1); // Extract high 128 bit - - // Add the two 128-bit vectors together - __m128 vsum = _mm_add_ps(vlow, vhigh); - // Horizontally add the elements of the resulting 128-bit vector - vsum = _mm_hadd_ps(vsum, vsum); - vsum = _mm_hadd_ps(vsum, vsum); - - _mm_store_ss(&SumPtr[i], vsum); - SumPtr[i] += BiasPtr == nullptr ? 
0.0f : BiasPtr[i]; - }); - } -} - -template -void -SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t NCols4 = 4; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t BlkLen16 = 16; - - const std::byte* QuantARowPtr = QuantA; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols4; - - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16( - BlkLen16, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, - SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols4 : 0; - SumPtr += NCols4; - - nblk -= NCols4; - } - - // left over columns less than `NCols`? - nblk += NCols4; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen16<1, HasZeroPoint>( - BlkLen16, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, - SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} - -template accumulator> -void -SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - // port from neon implementation - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 32; - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const __m256i zero = _mm256_setzero_si256(); - const __m128i low_mask = _mm_set1_epi8(0xF); - const size_t NCols = 4; - int64_t nblk = (int64_t)(CountN)-4; - while (nblk >= 0) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 - acc0 = _mm256_setzero_ps(), - acc1 = _mm256_setzero_ps(), - acc2 = _mm256_setzero_ps(), - acc3 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - - const float& scale_a0 = *QuantAScalePtr; - const float& scale_a1 = *(QuantAScalePtr + 1); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); - - // Col1 - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); - - // Col2 - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); - - // Col3 - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * 
StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); - - // increment block pointers - QuantAPtr += BlkLen * 2; - QuantAScalePtr += 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - - const float& scale_a0 = *QuantAScalePtr; - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - - // Col1 - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); - - // Col2 - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); - - // Col3 - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); - } - - __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); - if (BiasPtr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_x); - - // move to next NCols columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NCols : 0; - SumPtr += NCols; - nblk -= NCols; - } - - nblk += NCols; - for (int64_t n = 0; n < nblk; n++) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - - const float& scale_a0 = *QuantAScalePtr; - const float& scale_a1 = *(QuantAScalePtr + 1); - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); - - // increment block pointers - QuantAPtr += BlkLen * 2; - QuantAScalePtr += 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - - const float& scale_a0 = *QuantAScalePtr; - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} - -using DotQuadFunctionType = __m256 (*)( - const __m256i, const __m256i, const __m256i, const __m256i -); - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4( - size_t BlkLen, - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - // TODO: make it work with all BlkLens - assert(BlkLen >= 64); - constexpr size_t SubBlkLen64 = 64; - // const __m256i zero = _mm256_setzero_si256(); - const __m256i low_mask = _mm256_set1_epi8(0xF); - - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen64); - - __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(), acc2 = _mm256_setzero_ps(), acc3 = _mm256_setzero_ps(); - - const std::byte* ablob = QuantARowPtr; - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* blk_scale_ptr = QuantBScaleColPtr; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint == true - - for (size_t k = 0; k < CountK; k += BlkLen) { - size_t ck = std::min(CountK - k, BlkLen); - - const float a_scale = Q8BlkScale(ablob); - ablob += sizeof(float); - - float - scale_v0 = (*(blk_scale_ptr + StrideQuantBScale * 0)) * a_scale, - scale_v1 = (*(blk_scale_ptr + StrideQuantBScale * 1)) * a_scale, - scale_v2 = (*(blk_scale_ptr + StrideQuantBScale * 2)) * a_scale, - scale_v3 = (*(blk_scale_ptr + StrideQuantBScale * 3)) * a_scale; - - const std::byte* bptr0 = (b_blk_data_ptr + StrideQuantBData * 0); - const std::byte* bptr1 = (b_blk_data_ptr + StrideQuantBData * 1); - const std::byte* bptr2 = (b_blk_data_ptr + StrideQuantBData * 2); - const std::byte* bptr3 = (b_blk_data_ptr + StrideQuantBData * 3); - - uint8_t zp0, zp1, zp2, zp3; - if constexpr (HasZeroPoint) { - // TODO: this block causes near 30% of the computation. - bool is_lower = (QuantBZeroPointIdx & 1) == 0; - std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp0 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - zp_packed = QuantBZeroPointColPtr[1 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp1 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - zp_packed = QuantBZeroPointColPtr[2 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp2 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - zp_packed = QuantBZeroPointColPtr[3 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp3 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - } else { - zp0 = 8; - zp1 = 8; - zp2 = 8; - zp3 = 8; - } - - for (size_t kk = 0; kk < ck; kk += SubBlkLen64) { - // Load A row vector - const __m256i av0_32_epi8 = _mm256_loadu_si256((const __m256i*)ablob); - ablob += 32; - const __m256i av1_32_epi8 = _mm256_loadu_si256((const __m256i*)ablob); - ablob += 32; - - // Load B column vectors (quantized to int4 blobs) - // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - __m256i bv = _mm256_loadu_si256((__m256i const*)bptr0); - bptr0 += SubBlkStep; - __m256i bv0_32_epi8 = _mm256_and_si256(bv, low_mask); - __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); - __m256i zp_epi8 = _mm256_set1_epi8(zp0); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); - __m256 sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); - acc0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sum_ps, acc0); - - bv = _mm256_loadu_si256((__m256i const*)bptr1); - bptr1 += SubBlkStep; - bv0_32_epi8 = _mm256_and_si256(bv, low_mask); - bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); - zp_epi8 = _mm256_set1_epi8(zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); - sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); - acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v1), sum_ps, acc1); - - bv = _mm256_loadu_si256((__m256i const*)bptr2); - bptr2 += SubBlkStep; - bv0_32_epi8 = _mm256_and_si256(bv, low_mask); - bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); - zp_epi8 = _mm256_set1_epi8(zp2); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); - sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); - acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v2), sum_ps, acc2); - - bv = _mm256_loadu_si256((__m256i const*)bptr3); - bptr3 += SubBlkStep; - bv0_32_epi8 = _mm256_and_si256(bv, low_mask); - bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); - zp_epi8 = _mm256_set1_epi8(zp3); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, zp_epi8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, zp_epi8); - sum_ps = dot_quad(bv0_32_epi8, bv1_32_epi8, av0_32_epi8, av1_32_epi8); - acc3 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v3), sum_ps, acc3); - } // kk - - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - blk_scale_ptr++; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } // k - - __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); - if (BiasPtr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_x); -} - -// TODO: is this able to be inlined if DotQuadFunctionType is a function pointer? 
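// For reference, one way a DotQuadFunctionType instance could be written with AVX2
// (an illustrative sketch, assuming the two B arguments hold the low and high 32 int8
// values of a 64-element sub-block and the two A arguments hold the matching int8
// activations; requires <immintrin.h> and is not necessarily the kernel the dispatch
// actually selects):
MLAS_FORCEINLINE __m256
DotQuadSketch(const __m256i bv0_epi8, const __m256i bv1_epi8,
              const __m256i av0_epi8, const __m256i av1_epi8)
{
    const __m256i ones_epi16 = _mm256_set1_epi16(1);

    // _mm256_maddubs_epi16 needs an unsigned * signed operand pair, so take |A| and
    // move A's sign onto B before multiplying.
    const __m256i abs_a0 = _mm256_sign_epi8(av0_epi8, av0_epi8);
    const __m256i b0_signed = _mm256_sign_epi8(bv0_epi8, av0_epi8);
    const __m256i abs_a1 = _mm256_sign_epi8(av1_epi8, av1_epi8);
    const __m256i b1_signed = _mm256_sign_epi8(bv1_epi8, av1_epi8);

    // int8 pair products -> int16, then int16 pairs -> int32, then add both halves.
    const __m256i sum_epi32 = _mm256_add_epi32(
        _mm256_madd_epi16(_mm256_maddubs_epi16(abs_a0, b0_signed), ones_epi16),
        _mm256_madd_epi16(_mm256_maddubs_epi16(abs_a1, b1_signed), ones_epi16));

    return _mm256_cvtepi32_ps(sum_epi32);  // eight float partial sums for the caller to scale
}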
-template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1( - size_t BlkLen, - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - // TODO: make it work with all BlkLens - assert(BlkLen >= 64); - constexpr size_t SubBlkLen64 = 64; - // const __m256i zero = _mm256_setzero_si256(); - const __m256i low_mask = _mm256_set1_epi8(0xF); - - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkStep = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen64); - - __m256 acc0 = _mm256_setzero_ps(); - - const std::byte* ablob = QuantARowPtr; - const std::byte* b_blk_data_ptr = QuantBDataColPtr; - const float* blk_scale_ptr = QuantBScaleColPtr; - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint == true - - for (size_t k = 0; k < CountK; k += BlkLen) { - size_t ck = std::min(CountK - k, BlkLen); - - const float a_scale = Q8BlkScale(ablob); - ablob += sizeof(float); - - float scale_v0 = (*(blk_scale_ptr + StrideQuantBScale * 0)) * a_scale; - - const std::byte* bptr0 = (b_blk_data_ptr + StrideQuantBData * 0); - - uint8_t zp0; - if constexpr (HasZeroPoint) { - // TODO: this block causes near 30% of the computation. - // The solution proposed by @yufenglee is to factor out scaleB * zp - // while packing A. Will do this in next PR. - bool is_lower = (QuantBZeroPointIdx & 1) == 0; - std::byte zp_packed = QuantBZeroPointColPtr[0 * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - zp0 = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); - } else { - zp0 = 8; - } - - for (size_t kk = 0; kk < ck; kk += SubBlkLen64) { - // Load A row vector - const __m256i a_byte_lo = _mm256_loadu_si256((const __m256i*)ablob); - ablob += 32; - const __m256i a_byte_hi = _mm256_loadu_si256((const __m256i*)ablob); - ablob += 32; - - // Load B column vectors (quantized to int4 blobs) - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - __m256i bv = _mm256_loadu_si256((__m256i const*)bptr0); - bptr0 += SubBlkStep; - __m256i bv_lo_epi8 = _mm256_and_si256(bv, low_mask); - __m256i bv_hi_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv, 4), low_mask); - __m256i zp_epi8 = _mm256_set1_epi8(zp0); - bv_lo_epi8 = _mm256_sub_epi8(bv_lo_epi8, zp_epi8); - bv_hi_epi8 = _mm256_sub_epi8(bv_hi_epi8, zp_epi8); - __m256 sum_ps = dot_quad(bv_lo_epi8, bv_hi_epi8, a_byte_lo, a_byte_hi); - //__m256 sum_ps = mul_sum_s8_quads_float_avx2(bv_lo_epi8, bv_hi_epi8, a_byte_lo, a_byte_hi); - acc0 = _mm256_fmadd_ps(_mm256_set1_ps(scale_v0), sum_ps, acc0); - } // kk - - b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - blk_scale_ptr++; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } // k - - *SumPtr = hsum_float_8(acc0); - *SumPtr += BiasPtr == nullptr ? 
0.0f : *BiasPtr; -} - -template -void -SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - - const std::byte* QuantARowPtr = QuantA; - float* CRowPtr = C; - - const size_t BlockCountK = BlockStrideQuantB; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - const size_t NCols = 4; - int64_t nblk = static_cast(CountN) - NCols; - - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols4( - BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, - SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols : 0; - SumPtr += NCols; - - nblk -= NCols; - } - - // left over columns less than `NCols`? - nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8_SubBlkLen64_NCols1( - BlkLen, - QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp deleted file mode 100644 index 3f32cc6c5312d..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ /dev/null @@ -1,194 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_neon.cpp - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for ARM NEON. - ---*/ - -#include - -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" -#include "sqnbitgemm_q8_block.h" - -namespace sqnbitgemm_neon -{ - -namespace -{ - -// -// Quantized B data packing function implementation. 
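// Rough worked example of the packed size computed by the function below, under the
// assumption that MlasQNBitBlkDataSizeInBytes(4, BlkLen) is simply BlkLen / 2 bytes
// (two 4-bit values per byte):
//   N = 4096, K = 4096, BlkLen = 32  =>  BlockCountK = 128
//   PackedQuantBDataSize = 4096 * 128 * 16 bytes = 8 MiB,
// i.e. exactly half the size of an int8 matrix of the same shape.
constexpr size_t PackedQ4BDataSizeSketch(size_t N, size_t K, size_t BlkLen)
{
    return N * ((K + BlkLen - 1) / BlkLen) * (BlkLen / 2);
}
static_assert(PackedQ4BDataSizeSketch(4096, 4096, 32) == 8 * 1024 * 1024, "8 MiB expected");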
-// - -size_t -SQ4BitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - - constexpr size_t BlkBitWidth = 4; - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; -} - -void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - const size_t SubBlkLen = (ComputeType == CompInt8) - ? ((BlkLen == 16) ? 16 : 32) - : 16; - - const size_t SubBlkDataSize = SubBlkLen / 2; - const size_t SubBlkBytePairCount = SubBlkLen / 4; - - // - // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - - // - // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | - // => - // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { - for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } - - QuantBData += SubBlkDataSize; - PackedQuantBData += SubBlkDataSize; - } - } - ); -} - -// -// Workspace size calculation function implementation. -// - -size_t -SQ4BitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(N); - - switch (ComputeType) { - case CompInt8: { - // workspace buffer is used for block quantization of A to int8 - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); - return PerGemmWorkspaceSize; - } - default: { - return 0; - } - } -} - -size_t -SQ4BitGemmPerGemmWorkspaceAlignment( - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(BlkLen); - - switch (ComputeType) { - case CompInt8: { - return Q8BlkAlignment(); - } - default: { - return 1; - } - } -} - -} // namespace - -} // namespace sqnbitgemm_neon - -// -// Kernel dispatch structure definition. 
-// - -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; - - d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - - d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; - - d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; - d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; - - return d; -}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h deleted file mode 100644 index ef9345d7ac484..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h +++ /dev/null @@ -1,144 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_neon.h - -Abstract: - - This module includes function declarations and common helper functions for - SQNBitGemm ARM NEON kernels. - ---*/ - -#pragma once - -#include - -#include -#include -#include - -#include "mlasi.h" - -namespace sqnbitgemm_neon -{ - -// -// Function declarations for SQNBitGemm ARM NEON kernel entry points. -// Refer to the prototypes in sqnbitgemm.h for documentation. -// These are declared here so they can be used to initialize the -// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate -// files. -// - -// CompFp32 declarations - -void -SQ4BitGemmM1Kernel_CompFp32( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias -); - -void -Q4BitBlkDequantBForSgemm_CompFp32( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockCountK -); - -// CompInt8 declarations - -void -QuantizeARow_CompInt8( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA -); - -size_t -SQ4BitGemmKernel_CompInt8( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - size_t ldc, - const float* Bias -); - -// -// General helpers. 
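// A sketch of how the unrolled-loop helper declared just below is typically used:
// the lambda body is instantiated once per compile-time index, so a fixed-width set
// of accumulators can be handled without a runtime loop (requires <arm_neon.h>;
// ZeroAccumulatorsSketch is an illustrative name, not part of the header).
template <size_t NCols>
MLAS_FORCEINLINE void
ZeroAccumulatorsSketch(float32x4_t (&acc)[NCols])
{
    UnrolledLoop<NCols>([&](size_t i) { acc[i] = vdupq_n_f32(0.0f); });
}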
-// - -template -MLAS_FORCEINLINE void -UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) -{ - (f(Indices), ...); -} - -template -MLAS_FORCEINLINE void -UnrolledLoop(IterationFn&& f) -{ - UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); -} - -template -MLAS_FORCEINLINE void -LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) -{ - static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); - - assert(count <= Capacity); - - size_t vi = 0; // vector index - - // handle 4 values at a time - while (count > 3) { - dst[vi] = vld1q_f32(src); - - vi += 1; - src += 4; - count -= 4; - } - - // handle remaining values - if (count > 0) { - dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); - - if (count > 1) { - dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); - - if (count > 2) { - dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); - } - } - } -} - -} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp deleted file mode 100644 index 12ddc42506e98..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ /dev/null @@ -1,647 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_neon_fp32.cpp - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for ARM NEON specific to - input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. - ---*/ - -#include - -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" - -namespace sqnbitgemm_neon -{ - -namespace -{ - -// -// CompFp32 kernel implementation. -// - -MLAS_FORCEINLINE void -Transpose4x4(float32x4_t& a0, float32x4_t& a1, float32x4_t& a2, float32x4_t& a3) -{ - // aN: aN_0 aN_1 aN_2 aN_3 - - float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 - float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 - float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 - float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 - - // a0_0 a1_0 a2_0 a3_0 - a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_1 a1_1 a2_1 a3_1 - a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); - // a0_2 a1_2 a3_2 a3_2 - a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); - // a0_3 a1_3 a2_3 a3_3 - a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); -} - -MLAS_FORCEINLINE float32x4_t -FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) -{ - Transpose4x4(a0, a1, a2, a3); - return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); -} - -namespace fp32_conversion -{ - -// Manual conversion to float takes place in two steps: -// 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. -// This target float range is convenient because the 4-bit source values can be placed directly into the -// target float bits. -// 2. Subtract the conversion offset of 16 from the float result. - -// The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. 
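// Informal check of the two-step conversion described above: the template's exponent
// field is 131, i.e. 2^(131 - 127) = 16.0f, and after the two widening shifts
// (left by 3, then by 16) the 4-bit value lands in mantissa bits 22..19, so the
// assembled float is 16 * (1 + v / 16) = 16.0f + v for v in [0, 15] -- hence the
// conversion offset of 16.0f that is subtracted afterwards.
inline float AssembleFloatFromNibble(uint32_t v)  // v in [0, 15]
{
    const uint32_t bits = 0x41800000u | (v << 19);  // 0x4180 is the 16-bit high-half template
    float f;
    std::memcpy(&f, &bits, sizeof(f));  // <cstring>
    return f;                           // equals 16.0f + static_cast<float>(v)
}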
-constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; -// sign|exponent|partial mantissa -// +|131: 2^4|~~~~ <- 4 bits go here - -const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); - -constexpr float offset = 16.0f; - -} // namespace fp32_conversion - -template -MLAS_FORCEINLINE void -ComputeDotProducts_BlkBitWidth4_CompFp32( - size_t BlkLen, - const float* ARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - float* SumPtr, - size_t CountK, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const float* BiasPtr -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; - - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - - assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); - - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - float32x4_t acc[NCols]{}; - - const std::byte* QuantBData = QuantBDataColPtr; - const float* QuantBScale = QuantBScaleColPtr; - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - // only used if HasZeroPoint is true - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - float scale[NCols]; - UnrolledLoop( - [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } - ); - - [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset. - // only used if HasZeroPoint is true - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const std::byte zp_packed = - QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; - const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) - ? 
(zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[i] = fp32_conversion::offset + std::to_integer(zp); - }); - } - - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { - // load A row vector elements - - // load `SubBlkLen` elements from A, padded with 0's if there aren't enough - const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); - float32x4_t av[4]{}; - LoadFloatData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(scale[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // c[m,n] += a[m,k] * b[k,n] - UnrolledLoop<4>([&](size_t j) { - UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); - }); - } - - // increment pointers to next block - QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - QuantBScale += 1; - if constexpr (HasZeroPoint) { - QuantBZeroPointIdx += 1; - } - } - - if constexpr (NCols == 4) { - float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - - if (BiasPtr != nullptr) { - sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); - } - - vst1q_f32(SumPtr, sum); - } else { - for (size_t i = 0; i < NCols; ++i) { - SumPtr[i] = vaddvq_f32(acc[i]); - if (BiasPtr != nullptr) { - SumPtr[i] += BiasPtr[i]; - } - } - } -} - -template -void -SQ4BitGemmM1Kernel_CompFp32_Impl( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockCountK, - 
const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; - - const float* ARowPtr = A; - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - - int64_t nblk = static_cast(CountN) - NCols; - - while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next `NCols` columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? NCols : 0; - SumPtr += NCols; - - nblk -= NCols; - } - - // left over columns less than `NCols`? - nblk += NCols; - for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( - BlkLen, - ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - BiasPtr - ); - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } -} - -} // namespace - -void -SQ4BitGemmM1Kernel_CompFp32( - size_t BlkLen, - const float* A, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockCountK, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockCountK, - Bias - ); - } else { - constexpr bool HasZeroPoint = false; - SQ4BitGemmM1Kernel_CompFp32_Impl( - BlkLen, - A, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockCountK, - Bias - ); - } -} - -namespace -{ - -// Block dequantize a 16 x NCols section of B from column major source to row major destination. 
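// Scalar reference for the value each dequantized element takes in the routine below
// (a sketch: zp is the unpacked 4-bit zero point, or the implicit default of 8 when
// the B matrix carries no zero-point data):
inline float DequantizeQ4Sketch(uint8_t q /* 0..15 */, float scale, uint8_t zp)
{
    return scale * static_cast<float>(static_cast<int>(q) - static_cast<int>(zp));
}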
-template -MLAS_FORCEINLINE void -Q4BitBlkDequantB_16xNCols( - const std::byte* QuantBDataPtr, - size_t StrideQuantBData, - const float* QuantBColScalePtr, // pointer to NCols scales of adjacent columns - [[maybe_unused]] const float* QuantBColOffsetPtr, // pointer to NCols offsets of adjacent columns - // only used if HasZeroPoint is true - float* DstColPtr -) -{ - const uint8x8_t LowMask = vdup_n_u8(0x0F); - - // load B column vectors - uint8x8_t bv_packed[NCols]; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBDataPtr) + i * StrideQuantBData - ); - }); - - uint8x8_t bv_u8[NCols][2]; - UnrolledLoop([&](size_t i) { - bv_u8[i][0] = vand_u8(bv_packed[i], LowMask); - bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); - }); - - // shift left 3 and widen to 16 bits - uint16x8_t bv_u16[NCols][2]; - UnrolledLoop([&](size_t i) { - constexpr int shift = 3; - bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); - bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); - }); - - // combine 4 bits with float high half template - UnrolledLoop([&](size_t i) { - bv_u16[i][0] = vorrq_u16(bv_u16[i][0], fp32_conversion::float_high_half_template_v); - bv_u16[i][1] = vorrq_u16(bv_u16[i][1], fp32_conversion::float_high_half_template_v); - }); - - // `SubBlkLen` floats of B - float32x4_t bv[NCols][4]; - - // shift left 16, widen to 32 bits, and reinterpret as float - UnrolledLoop([&](size_t i) { - constexpr int shift = 16; - bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); - bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); - - bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); - bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); - }); - - // subtract float conversion offset and zero point - if constexpr (HasZeroPoint) { - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(QuantBColOffsetPtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } else { - const float32x4_t offset_v = vdupq_n_f32(fp32_conversion::offset + 8.0f); - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); - } - - // multiply by scale - UnrolledLoop([&](size_t i) { - const float32x4_t scale_v = vdupq_n_f32(QuantBColScalePtr[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); - }); - - // write, transposed, 16 x NCols values - if constexpr (NCols == 4) { - UnrolledLoop<4>([&](size_t j) { - Transpose4x4(bv[0][j], bv[1][j], bv[2][j], bv[3][j]); - - vst1q_f32(&DstColPtr[(j * 4 + 0) * 16], bv[0][j]); - vst1q_f32(&DstColPtr[(j * 4 + 1) * 16], bv[1][j]); - vst1q_f32(&DstColPtr[(j * 4 + 2) * 16], bv[2][j]); - vst1q_f32(&DstColPtr[(j * 4 + 3) * 16], bv[3][j]); - }); - } else { - UnrolledLoop([&](size_t i) { - UnrolledLoop<4>([&](size_t j) { - DstColPtr[(j * 4 + 0) * 16 + i] = vgetq_lane_f32(bv[i][j], 0); - DstColPtr[(j * 4 + 1) * 16 + i] = vgetq_lane_f32(bv[i][j], 1); - DstColPtr[(j * 4 + 2) * 16 + i] = vgetq_lane_f32(bv[i][j], 2); - DstColPtr[(j * 4 + 3) * 16 + i] = vgetq_lane_f32(bv[i][j], 3); - }); - }); - } -} - -template -void -Q4BitBlkDequantBForSgemm_CompFp32_Impl( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockCountK -) -{ - constexpr size_t BlkBitWidth = 4; - - float* Dst = FpData; - - const std::byte* QuantBDataCol = QuantBData; - const 
float* QuantBScaleCol = QuantBScale; - [[maybe_unused]] const std::byte* QuantBZeroPointCol = QuantBZeroPoint; // only used if HasZeroPoint is true - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - [[maybe_unused]] const size_t StrideQuantBZeroPoint = // only used if HasZeroPoint is true - MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - // - // Proceed down 16 column-wide regions of B. Dequantize and write output 16 x 16 elements at a time. - // - - // scales of blocks from 16 adjacent columns - float scale[16]; - // float conversion offsets (including zero point) of blocks from 16 adjacent columns - [[maybe_unused]] float offset[16]; // only used if HasZeroPoint is true - - size_t n_cols_remaining = CountN; - while (n_cols_remaining > 15) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < 16; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - constexpr size_t NCols = 4; - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? &offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < 16; nn += NCols) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += NCols; - if constexpr (HasZeroPoint) { - OffsetPtr += NCols; - } - DstColPtr += NCols; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - - n_cols_remaining -= 16; - - QuantBDataCol += 16 * StrideQuantBData; - QuantBScaleCol += 16 * BlockCountK; - if constexpr (HasZeroPoint) { - QuantBZeroPointCol += 16 * StrideQuantBZeroPoint; - } - } - - if (n_cols_remaining > 0) { - for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, ++k_blk_idx) { - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - scale[nn] = QuantBScaleCol[nn * BlockCountK + k_blk_idx]; - - if constexpr (HasZeroPoint) { - const std::byte zp_packed = - QuantBZeroPointCol[nn * StrideQuantBZeroPoint + k_blk_idx / 2]; - const std::byte zp = ((k_blk_idx & 1) == 1) - ? (zp_packed >> 4) - : (zp_packed & std::byte{0x0F}); - offset[nn] = fp32_conversion::offset + std::to_integer(zp); - } - } - - const size_t kklen = std::min(CountK - k, BlkLen); - - for (size_t kk = 0; kk < kklen; kk += 16) { - // zero out the 16x16 block in Dst first to ensure zero padding - const float32x4_t zero_v = vdupq_n_f32(0.0f); - UnrolledLoop<16 * 4>([&](size_t i) { - vst1q_f32(Dst + 4 * i, zero_v); - }); - - const float* ScalePtr = &scale[0]; - const float* OffsetPtr = HasZeroPoint ? 
&offset[0] : nullptr; - - float* DstColPtr = Dst; - - for (size_t nn = 0; nn < n_cols_remaining; ++nn) { - const std::byte* QuantBDataPtr = QuantBDataCol + nn * StrideQuantBData + (k + kk) * BlkBitWidth / 8; - - Q4BitBlkDequantB_16xNCols<1, HasZeroPoint>( - QuantBDataPtr, - StrideQuantBData, - ScalePtr, - OffsetPtr, - DstColPtr - ); - - ScalePtr += 1; - if constexpr (HasZeroPoint) { - OffsetPtr += 1; - } - DstColPtr += 1; - } - - Dst += 16 * std::min(kklen - kk, size_t{16}); - } - } - } -} - -} // namespace - -void -Q4BitBlkDequantBForSgemm_CompFp32( - size_t BlkLen, - float* FpData, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t CountK, - size_t BlockCountK -) -{ - if (QuantBZeroPoint != nullptr) { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockCountK - ); - } else { - Q4BitBlkDequantBForSgemm_CompFp32_Impl( - BlkLen, - FpData, - QuantBData, - QuantBScale, - QuantBZeroPoint, - CountN, - CountK, - BlockCountK - ); - } -} - -} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp deleted file mode 100644 index 0d62ea37b7e26..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ /dev/null @@ -1,1402 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_neon_int8.cpp - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for ARM NEON specific to - input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. - ---*/ - -#include - -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" -#include "sqnbitgemm_q8_block.h" - -namespace sqnbitgemm_neon -{ - -// -// CompInt8 kernel implementation. -// - -namespace -{ - -template -MLAS_FORCEINLINE void -QuantizeBlock( - size_t BlkLen, - const float* A, - size_t ElementCount, - std::byte* QuantA -) -{ - static_assert(SubBlkLen >= 16 && SubBlkLen % 16 == 0); - - assert(BlkLen % SubBlkLen == 0); - - // - // Scan block values first to determine scale. - // - - float amax = 0.0f; // max of absolute values of A block - - size_t k; - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - float32x4_t abs_a[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - abs_a[i] = vabsq_f32(a[i]); - }); - - // find amax of SubBlkLen elements - for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { - for (size_t i = 0; i < interval; ++i) { - abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); - } - } - - // update existing amax - amax = std::max(amax, vmaxvq_f32(abs_a[0])); - } - - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; - - Q8BlkScale(QuantA) = scale; - - // - // Compute quantized block values. 
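// Scalar equivalent of the vectorized quantization step below (illustrative only):
// each value is multiplied by 1 / scale and rounded to the nearest integer with ties
// away from zero, mirroring vcvtaq_s32_f32; no clamp is needed because
// scale = amax / 127 already bounds |a * scale_reciprocal| by 127.
inline int8_t QuantizeOneSketch(float a, float scale_reciprocal)
{
    return static_cast<int8_t>(std::lroundf(a * scale_reciprocal));  // <cmath>
}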
- // - - int8_t* QuantAData = Q8BlkData(QuantA); - - for (k = 0; k < ElementCount; k += SubBlkLen) { - const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - - float32x4_t a[SubBlkLen / 4]{}; - LoadFloatData(A + k, SubBlkElementCount, a); - - UnrolledLoop([&](size_t i) { - a[i] = vmulq_n_f32(a[i], scale_reciprocal); - }); - - int32x4_t a_s32[SubBlkLen / 4]; - UnrolledLoop([&](size_t i) { - a_s32[i] = vcvtaq_s32_f32(a[i]); - }); - - UnrolledLoop([&](size_t i) { - QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); - QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); - QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); - QuantAData[k + i * 4 + 3] = static_cast(vgetq_lane_s32(a_s32[i], 3)); - }); - } - - // - // Zero out any remaining sub-block elements. - // - - for (; k < BlkLen; k += SubBlkLen) { - const int8x16_t Zeros = vdupq_n_s8(0); - UnrolledLoop([&](size_t i) { - vst1q_s8(QuantAData + k + i * 16, Zeros); - }); - } -} - -} // namespace - -void -QuantizeARow_CompInt8( - size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA -) -{ - const float* ADataBlkPtr = A; - std::byte* QuantABlkPtr = QuantA; - - for (size_t k = 0; k < CountK; k += BlkLen) { - const size_t k_blk_len = std::min(CountK - k, BlkLen); - - QuantizeBlock<16>(BlkLen, ADataBlkPtr, k_blk_len, QuantABlkPtr); - - ADataBlkPtr += BlkLen; - QuantABlkPtr += Q8BlkSize(BlkLen); - } -} - -namespace -{ - -// -// The ComputeRxC functions compute an R row by C column tile of the output matrix. -// - -template -MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute4x2_BlkLen16( - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - const float* BiasPtr, - float* SumPtr, - size_t BlockCountK, - size_t StrideQuantA, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - size_t ldc -) -{ - constexpr size_t BlkLen = 16; - - const std::byte* QuantAPtr = QuantARowPtr; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc00{}, acc01{}, acc10{}, acc11{}, acc20{}, acc21{}, acc30{}, acc31{}; - - for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { - const std::byte* QuantABlkRow0 = QuantAPtr; - const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; - const std::byte* QuantABlkRow2 = QuantAPtr + StrideQuantA * 2; - const std::byte* QuantABlkRow3 = QuantAPtr + StrideQuantA * 3; - - const float QuantBScaleCol0 = *QuantBScalePtr; - const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); - - // compute combined scales - const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; - const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; - const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; - const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; - const float scale20 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol0; - const float scale21 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol1; - const float scale30 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol0; - const float scale31 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol1; - - // load B zero point - int8_t bzp_col0; - int8_t bzp_col1; - if constexpr (HasZeroPoint) { - const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; - const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr 
+ StrideQuantBZeroPoint); - if ((k_blk_idx & 1) == 0) { - bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); - bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); - } else { - bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); - bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); - } - } else { - bzp_col0 = 8; - bzp_col1 = 8; - } - - const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); - const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); - const int8_t* QuantADataPtrRow2 = Q8BlkData(QuantABlkRow2); - const int8_t* QuantADataPtrRow3 = Q8BlkData(QuantABlkRow3); - - // TODO handling only 16 elements per accumulator at a time here, probably can do better - { - // load B - const uint8x8_t bv_packed_col0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x8_t bv_packed_col1 = vld1_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); - - const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); - - int8x16_t bv_col0 = vreinterpretq_s8_u8( - vcombine_u8( - vand_u8(bv_packed_col0, LowMaskU8x8), - vshr_n_u8(bv_packed_col0, 4) - ) - ); - int8x16_t bv_col1 = vreinterpretq_s8_u8( - vcombine_u8( - vand_u8(bv_packed_col1, LowMaskU8x8), - vshr_n_u8(bv_packed_col1, 4) - ) - ); - - // subtract B zero point - bv_col0 = vsubq_s8(bv_col0, vdupq_n_s8(bzp_col0)); - bv_col1 = vsubq_s8(bv_col1, vdupq_n_s8(bzp_col1)); - - // rows 0 and 1 of A - { - // load A - const int8x16_t av_row0 = vld1q_s8(QuantADataPtrRow0 + 0); - const int8x16_t av_row1 = vld1q_s8(QuantADataPtrRow1 + 0); - - // quantized dot product - const int32x4_t dot00 = vdotq_s32(int32x4_t{}, av_row0, bv_col0); - const int32x4_t dot01 = vdotq_s32(int32x4_t{}, av_row0, bv_col1); - const int32x4_t dot10 = vdotq_s32(int32x4_t{}, av_row1, bv_col0); - const int32x4_t dot11 = vdotq_s32(int32x4_t{}, av_row1, bv_col1); - - // convert to float - const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); - const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); - const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); - const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); - - // multiply by scale and update accumulator - acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); - acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); - acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); - acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); - } - - // rows 2 and 3 of A - { - // load A - const int8x16_t av_row2 = vld1q_s8(QuantADataPtrRow2 + 0); - const int8x16_t av_row3 = vld1q_s8(QuantADataPtrRow3 + 0); - - // quantized dot product - const int32x4_t dot20 = vdotq_s32(int32x4_t{}, av_row2, bv_col0); - const int32x4_t dot21 = vdotq_s32(int32x4_t{}, av_row2, bv_col1); - const int32x4_t dot30 = vdotq_s32(int32x4_t{}, av_row3, bv_col0); - const int32x4_t dot31 = vdotq_s32(int32x4_t{}, av_row3, bv_col1); - - // convert to float - const float32x4_t dot_f32_20 = vcvtq_f32_s32(dot20); - const float32x4_t dot_f32_21 = vcvtq_f32_s32(dot21); - const float32x4_t dot_f32_30 = vcvtq_f32_s32(dot30); - const float32x4_t dot_f32_31 = vcvtq_f32_s32(dot31); - - // multiply by scale and update accumulator - acc20 = vfmaq_f32(acc20, dot_f32_20, vdupq_n_f32(scale20)); - acc21 = vfmaq_f32(acc21, dot_f32_21, vdupq_n_f32(scale21)); - acc30 = vfmaq_f32(acc30, dot_f32_30, vdupq_n_f32(scale30)); - acc31 = vfmaq_f32(acc31, dot_f32_31, vdupq_n_f32(scale31)); - } - } - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen); - QuantBDataPtr += 8; - QuantBScalePtr += 1; - - if constexpr (HasZeroPoint) { - 
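// The B zero points are packed two per byte along k: block 2j uses the low
// nibble and block 2j+1 the high nibble of byte j, which is why the pointer
// advance just below only happens after odd block indices. A small sketch of
// the assumed unpacking (illustrative helper, not the MLAS API):
#include <cstddef>
#include <cstdint>

static int8_t LoadPackedZeroPointSketch(const std::byte* zp_col, size_t k_blk_idx)
{
    const std::byte packed = zp_col[k_blk_idx / 2];
    return (k_blk_idx & 1) == 0
               ? std::to_integer<int8_t>(packed & std::byte{0x0F})
               : std::to_integer<int8_t>(packed >> 4);
}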
QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; - } - } - - SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00); - SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01); - SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10); - SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11); - SumPtr[ldc * 2 + 0] = vaddvq_f32(acc20); - SumPtr[ldc * 2 + 1] = vaddvq_f32(acc21); - SumPtr[ldc * 3 + 0] = vaddvq_f32(acc30); - SumPtr[ldc * 3 + 1] = vaddvq_f32(acc31); - - if (BiasPtr != nullptr) { - SumPtr[ldc * 0 + 0] += BiasPtr[0]; - SumPtr[ldc * 0 + 1] += BiasPtr[1]; - SumPtr[ldc * 1 + 0] += BiasPtr[0]; - SumPtr[ldc * 1 + 1] += BiasPtr[1]; - SumPtr[ldc * 2 + 0] += BiasPtr[0]; - SumPtr[ldc * 2 + 1] += BiasPtr[1]; - SumPtr[ldc * 3 + 0] += BiasPtr[0]; - SumPtr[ldc * 3 + 1] += BiasPtr[1]; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( - size_t BlkLen, - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - const float* BiasPtr, - float* SumPtr, - size_t BlockCountK, - size_t StrideQuantA, - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - size_t ldc -) -{ - // process blocks in 32-element sub-blocks - const size_t SubBlksPerBlk = BlkLen / 32; - - const std::byte* QuantAPtr = QuantARowPtr; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc00{}, acc01{}, acc10{}, acc11{}, acc20{}, acc21{}, acc30{}, acc31{}; - - for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { - const std::byte* QuantABlkRow0 = QuantAPtr; - const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; - const std::byte* QuantABlkRow2 = QuantAPtr + StrideQuantA * 2; - const std::byte* QuantABlkRow3 = QuantAPtr + StrideQuantA * 3; - - const float QuantBScaleCol0 = *QuantBScalePtr; - const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); - - // compute combined scales - const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0; - const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; - const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; - const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; - const float scale20 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol0; - const float scale21 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol1; - const float scale30 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol0; - const float scale31 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol1; - - // load B zero point - int8_t bzp_col0; - int8_t bzp_col1; - if constexpr (HasZeroPoint) { - const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr; - const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint); - if ((k_blk_idx & 1) == 0) { - bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 & std::byte{0x0F}); - bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 & std::byte{0x0F}); - } else { - bzp_col0 = std::to_integer(QuantBZeroPointByteCol0 >> 4); - bzp_col1 = std::to_integer(QuantBZeroPointByteCol1 >> 4); - } - } else { - bzp_col0 = 8; - bzp_col1 = 8; - } - - const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); - const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); - const int8_t* QuantADataPtrRow2 = Q8BlkData(QuantABlkRow2); - const int8_t* QuantADataPtrRow3 = Q8BlkData(QuantABlkRow3); - - for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) { - // load B - const 
uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed_col1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - int8x16_t bv_col0_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col0, LowMaskU8x16)); - int8x16_t bv_col0_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col0, 4)); - int8x16_t bv_col1_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col1, LowMaskU8x16)); - int8x16_t bv_col1_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col1, 4)); - - // subtract B zero point - bv_col0_0 = vsubq_s8(bv_col0_0, vdupq_n_s8(bzp_col0)); - bv_col0_1 = vsubq_s8(bv_col0_1, vdupq_n_s8(bzp_col0)); - bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1)); - bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1)); - - // rows 0 and 1 of A - { - // load A - const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0); - const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16); - const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0); - const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16); - - // quantized dot product - const int32x4_t dot00 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1); - const int32x4_t dot01 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1); - const int32x4_t dot10 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1); - const int32x4_t dot11 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1); - - // convert to float - const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); - const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); - const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); - const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); - - // multiply by scale and update accumulator - acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); - acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); - acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); - acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); - } - - // rows 2 and 3 of A - { - // load A - const int8x16_t av_row2_0 = vld1q_s8(QuantADataPtrRow2 + 0); - const int8x16_t av_row2_1 = vld1q_s8(QuantADataPtrRow2 + 16); - const int8x16_t av_row3_0 = vld1q_s8(QuantADataPtrRow3 + 0); - const int8x16_t av_row3_1 = vld1q_s8(QuantADataPtrRow3 + 16); - - // quantized dot product - const int32x4_t dot20 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row2_0, bv_col0_0), av_row2_1, bv_col0_1); - const int32x4_t dot21 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row2_0, bv_col1_0), av_row2_1, bv_col1_1); - const int32x4_t dot30 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row3_0, bv_col0_0), av_row3_1, bv_col0_1); - const int32x4_t dot31 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row3_0, bv_col1_0), av_row3_1, bv_col1_1); - - // convert to float - const float32x4_t dot_f32_20 = vcvtq_f32_s32(dot20); - const float32x4_t dot_f32_21 = vcvtq_f32_s32(dot21); - const float32x4_t dot_f32_30 = vcvtq_f32_s32(dot30); - const float32x4_t dot_f32_31 = vcvtq_f32_s32(dot31); - - // multiply by scale and update accumulator - acc20 = vfmaq_f32(acc20, dot_f32_20, vdupq_n_f32(scale20)); - acc21 = vfmaq_f32(acc21, dot_f32_21, vdupq_n_f32(scale21)); - acc30 = vfmaq_f32(acc30, dot_f32_30, vdupq_n_f32(scale30)); - acc31 = vfmaq_f32(acc31, dot_f32_31, vdupq_n_f32(scale31)); - } - - // increment block data pointers to next sub-block - QuantADataPtrRow0 += 32; - QuantADataPtrRow1 += 32; - QuantADataPtrRow2 += 32; - QuantADataPtrRow3 += 32; - 
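// Scalar picture of one accumulator update in the 32-element sub-block loop
// above: unpack two 4-bit B values per byte (low nibbles are elements 0..15,
// high nibbles elements 16..31), subtract the B zero point, take the int8 dot
// product with the quantized A values, and scale by the product of the A and B
// block scales. Names are illustrative only, not the MLAS API.
#include <cstddef>
#include <cstdint>

static float DotSubBlockSketch(const int8_t* a,          // 32 quantized A values
                               const uint8_t* b_packed,  // 16 bytes = 32 packed 4-bit B values
                               int8_t b_zero_point,      // 8 when B has no stored zero point
                               float scale_a, float scale_b)
{
    int32_t acc = 0;
    for (size_t i = 0; i < 16; ++i) {
        const int32_t b_lo = static_cast<int32_t>(b_packed[i] & 0x0F) - b_zero_point;
        const int32_t b_hi = static_cast<int32_t>(b_packed[i] >> 4) - b_zero_point;
        acc += a[i] * b_lo;       // pairs with the low-nibble half
        acc += a[i + 16] * b_hi;  // pairs with the high-nibble half
    }
    return static_cast<float>(acc) * (scale_a * scale_b);
}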
QuantBDataPtr += 16; - } - - // increment other block pointers - - QuantAPtr += Q8BlkSize(BlkLen); - QuantBScalePtr += 1; - - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; - } - } - - SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00); - SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01); - SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10); - SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11); - SumPtr[ldc * 2 + 0] = vaddvq_f32(acc20); - SumPtr[ldc * 2 + 1] = vaddvq_f32(acc21); - SumPtr[ldc * 3 + 0] = vaddvq_f32(acc30); - SumPtr[ldc * 3 + 1] = vaddvq_f32(acc31); - - if (BiasPtr != nullptr) { - SumPtr[ldc * 0 + 0] += BiasPtr[0]; - SumPtr[ldc * 0 + 1] += BiasPtr[1]; - SumPtr[ldc * 1 + 0] += BiasPtr[0]; - SumPtr[ldc * 1 + 1] += BiasPtr[1]; - SumPtr[ldc * 2 + 0] += BiasPtr[0]; - SumPtr[ldc * 2 + 1] += BiasPtr[1]; - SumPtr[ldc * 3 + 0] += BiasPtr[0]; - SumPtr[ldc * 3 + 1] += BiasPtr[1]; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - const float* BiasPtr, - float* SumPtr, - size_t BlockCountK -) -{ - constexpr size_t BlkLen = 16; - - const std::byte* QuantAPtr = QuantARowPtr; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? 
std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1)); - - // load B - const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16); - const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01))); - int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01))); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - bv1 = vsubq_s8(bv1, bzp1); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(int32x4_t{}, av0, bv0); - const int32x4_t dot1 = vdotq_s32(int32x4_t{}, av1, bv1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 8 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0)); - - // load B - const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); - - const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); - - const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8); - const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4); - - int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp0); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(int32x4_t{}, av0, bv0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - const float* BiasPtr, - float* SumPtr, - size_t BlockCountK -) -{ - constexpr size_t BlkLen = 32; - - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - const std::byte* QuantAPtr = QuantARowPtr; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]); - - // load B zero point - const 
int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - const int8x16_t bzp1 = vdupq_n_s8( - HasZeroPoint ? std::to_integer((*QuantBZeroPointPtr) >> 4) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1)); - const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - bv_lo1 = vsubq_s8(bv_lo1, bzp1); - bv_hi1 = vsubq_s8(bv_hi1, bzp1); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo0, bv_lo0), av_hi0, bv_hi0); - const int32x4_t dot1 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo1, bv_lo1), av_hi1, bv_hi1); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale1); - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen) * 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr)); - - // load B zero point - const int8x16_t bzp0 = vdupq_n_s8( - HasZeroPoint ? 
std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8 - ); - - // load A - const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0)); - const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - - int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - - // subtract B zero point - bv_lo0 = vsubq_s8(bv_lo0, bzp0); - bv_hi0 = vsubq_s8(bv_hi0, bzp0); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo0, bv_lo0), av_hi0, bv_hi0); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale0); - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } -} - -template -MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( - size_t BlkLen, - const std::byte* QuantARowPtr, - const std::byte* QuantBDataColPtr, - const float* QuantBScaleColPtr, - const std::byte* QuantBZeroPointColPtr, - const float* BiasPtr, - float* SumPtr, - size_t BlockCountK -) -{ - const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); - - // process blocks in 32-element sub-blocks - const size_t SubBlksPerBlk = BlkLen / 32; - - const std::byte* QuantAPtr = QuantARowPtr; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - float32x4_t acc0{}, acc1{}; - - for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { - const std::byte* QuantABlk0 = QuantAPtr; - - // compute combined scale - const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]); - - // load B zero point - const int8x16_t bzp = [&]() -> int8x16_t { - if constexpr (HasZeroPoint) { - return vdupq_n_s8( - ((k_blk_idx & 1) == 0) ? 
std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}) - : std::to_integer((*QuantBZeroPointPtr) >> 4) - ); - } else { - return vdupq_n_s8(8); - } - }(); - - const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr); - - for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) { - // load A - const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0); - const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16); - const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32); - const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48); - - // load B - const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); - const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + 16); - - int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16)); - int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4)); - int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16)); - int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4)); - - // subtract B zero point - bv0 = vsubq_s8(bv0, bzp); - bv1 = vsubq_s8(bv1, bzp); - bv2 = vsubq_s8(bv2, bzp); - bv3 = vsubq_s8(bv3, bzp); - - // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av0, bv0), av1, bv1); - const int32x4_t dot1 = vdotq_s32(vdotq_s32(int32x4_t{}, av2, bv2), av3, bv3); - - // convert to float - const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); - const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1); - - // multiply by scale and update accumulator - acc0 = vfmaq_f32(acc0, dot_f32_0, scale); - acc1 = vfmaq_f32(acc1, dot_f32_1, scale); - - // increment block data pointers to next sub-block - QuantADataPtr += 16 * 4; - QuantBDataPtr += 16 * 2; - } - - // increment block pointers - - QuantAPtr += Q8BlkSize(BlkLen); - QuantBScalePtr += 1; - - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1; - } - } - - *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } -} - -template -MLAS_FORCEINLINE void -AdvanceColPtrs( - size_t StrideQuantBData, - size_t StrideQuantBScale, - size_t StrideQuantBZeroPoint, - const std::byte*& QuantBDataColPtr, - const float*& QuantBScaleColPtr, - const std::byte*& QuantBZeroPointColPtr, - const float*& BiasPtr, - float*& SumPtr -) -{ - QuantBDataColPtr += NumCols * StrideQuantBData; - QuantBScaleColPtr += NumCols * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NumCols * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NumCols : 0; - SumPtr += NumCols; -} - -template -MLAS_FORCEINLINE void -AdvanceRowPtrs( - size_t StrideQuantA, - size_t ldc, - const std::byte*& QuantARowPtr, - float*& SumRowPtr -) -{ - QuantARowPtr += NumRows * StrideQuantA; - SumRowPtr += NumRows * ldc; -} - -template -void -SQ4BitGemmKernel_CompInt8_BlkLen16( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - size_t ldc, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 16; - - const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const std::byte* QuantARowPtr = QuantA; - - float* SumRowPtr = C; - - size_t m_remaining = CountM; - while (m_remaining > 3) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - const float* BiasPtr = Bias; - - float* SumPtr = SumRowPtr; - - size_t n_remaining = CountN; - while (n_remaining > 1) { - // Compute 4x2 tiles of output - SQ4BitGemm_CompInt8_Compute4x2_BlkLen16( - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK, - StrideQuantA, - StrideQuantBData, - StrideQuantBScale, - StrideQuantBZeroPoint, - ldc - ); - - // Move to next 2 columns - AdvanceColPtrs<2, HasZeroPoint>( - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr - ); - - n_remaining -= 2; - } - - if (n_remaining > 0) { - // Compute last 4x1 tile of output - for (size_t i = 0; i < 4; ++i) { - SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( - QuantARowPtr + StrideQuantA * i, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr + ldc * i, - BlockCountK - ); - } - } - - // Move to next 4 rows - AdvanceRowPtrs<4>( - StrideQuantA, ldc, - QuantARowPtr, SumRowPtr - ); - - m_remaining -= 4; - } - - while (m_remaining > 0) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - const float* BiasPtr = Bias; - - float* SumPtr = SumRowPtr; - - size_t n_remaining = CountN; - while (n_remaining > 0) { - // Compute 1x1 tiles of output - SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK - ); - - // Move to next column - AdvanceColPtrs<1, HasZeroPoint>( - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr - ); - - n_remaining -= 1; - } - - // Move to next row - AdvanceRowPtrs<1>( - StrideQuantA, ldc, - QuantARowPtr, SumRowPtr - ); - - m_remaining -= 1; - } -} - -template -void -SQ4BitGemmKernel_CompInt8_BlkLen32( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - size_t ldc, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 32; - - const size_t StrideQuantA = BlockCountK * 
Q8BlkSize(BlkLen); - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const std::byte* QuantARowPtr = QuantA; - - float* SumRowPtr = C; - - size_t m_remaining = CountM; - while (m_remaining > 3) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - const float* BiasPtr = Bias; - - float* SumPtr = SumRowPtr; - - size_t n_remaining = CountN; - while (n_remaining > 1) { - // Compute 4x2 tiles of output - SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( - BlkLen, - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK, - StrideQuantA, - StrideQuantBData, - StrideQuantBScale, - StrideQuantBZeroPoint, - ldc - ); - - // Move to next 2 columns - AdvanceColPtrs<2, HasZeroPoint>( - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr - ); - - n_remaining -= 2; - } - - if (n_remaining > 0) { - // Compute last 4x1 tile of output - for (size_t i = 0; i < 4; ++i) { - SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( - QuantARowPtr + StrideQuantA * i, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr + ldc * i, - BlockCountK - ); - } - } - - // Move to next 4 rows - AdvanceRowPtrs<4>( - StrideQuantA, ldc, - QuantARowPtr, SumRowPtr - ); - - m_remaining -= 4; - } - - while (m_remaining > 0) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - const float* BiasPtr = Bias; - - float* SumPtr = SumRowPtr; - - size_t n_remaining = CountN; - while (n_remaining > 0) { - // Compute 1x1 tiles of output - SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK - ); - - // Move to next column - AdvanceColPtrs<1, HasZeroPoint>( - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr - ); - - n_remaining -= 1; - } - - // Move to next row - AdvanceRowPtrs<1>( - StrideQuantA, ldc, - QuantARowPtr, SumRowPtr - ); - - m_remaining -= 1; - } -} - -template -void -SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - size_t ldc, - const float* Bias -) -{ - constexpr size_t BlkBitWidth = 4; - - const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen); - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - const std::byte* QuantARowPtr = QuantA; - - float* SumRowPtr = C; - - size_t m_remaining = CountM; - while (m_remaining > 3) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - const float* BiasPtr = Bias; - - float* SumPtr = SumRowPtr; - - size_t n_remaining = CountN; - 
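// The quantized-A strides set up at the top of these kernels rely on the Q8
// block layout from sqnbitgemm_q8_block.h: each block is assumed to be one
// float scale followed by BlkLen int8 values. A sketch of that assumed layout
// (the "Sketch" names are illustrative, not the real accessors):
#include <cstddef>
#include <cstdint>

constexpr size_t Q8BlkSizeSketch(size_t BlkLen) {
    return sizeof(float) + BlkLen * sizeof(int8_t);  // scale + data
}
inline float& Q8BlkScaleSketch(std::byte* blk) {
    return *reinterpret_cast<float*>(blk);  // scale is stored first
}
inline int8_t* Q8BlkDataSketch(std::byte* blk) {
    return reinterpret_cast<int8_t*>(blk + sizeof(float));  // data follows the scale
}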
while (n_remaining > 1) { - // Compute 4x2 tiles of output - SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( - BlkLen, - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK, - StrideQuantA, - StrideQuantBData, - StrideQuantBScale, - StrideQuantBZeroPoint, - ldc - ); - - // Move to next 2 columns - AdvanceColPtrs<2, HasZeroPoint>( - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr - ); - - n_remaining -= 2; - } - - if (n_remaining > 0) { - // Compute last 4x1 tile of output - for (size_t i = 0; i < 4; ++i) { - SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( - BlkLen, - QuantARowPtr + StrideQuantA * i, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr + ldc * i, - BlockCountK - ); - } - } - - // Move to next 4 rows - AdvanceRowPtrs<4>( - StrideQuantA, ldc, - QuantARowPtr, SumRowPtr - ); - - m_remaining -= 4; - } - - while (m_remaining > 0) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - const float* BiasPtr = Bias; - - float* SumPtr = SumRowPtr; - - size_t n_remaining = CountN; - while (n_remaining > 0) { - // Compute 1x1 tiles of output - SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( - BlkLen, - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK - ); - - // Move to next column - AdvanceColPtrs<1, HasZeroPoint>( - StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, - QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr - ); - - n_remaining -= 1; - } - - // Move to next row - AdvanceRowPtrs<1>( - StrideQuantA, ldc, - QuantARowPtr, SumRowPtr - ); - - m_remaining -= 1; - } -} - -template -void -SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - size_t ldc, - const float* Bias -) -{ - if (BlkLen == 16) { - SQ4BitGemmKernel_CompInt8_BlkLen16( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - BlockCountK, - ldc, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmKernel_CompInt8_BlkLen32( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - BlockCountK, - ldc, - Bias - ); - } else { - SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - BlockCountK, - ldc, - Bias - ); - } -} - -} // namespace - -size_t -SQ4BitGemmKernel_CompInt8( - size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - size_t ldc, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - BlockCountK, - ldc, - Bias - ); - } else { - constexpr bool HasZeroPoint = false; - SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - BlockCountK, - 
ldc, - Bias - ); - } - - return CountM; -} - -} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h deleted file mode 100644 index 45c3963365e6b..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ /dev/null @@ -1,759 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk1_zp_avx2( - const __m256i& av_32_epi8, - const std::byte* QuantBDataPtr, - const float& combined_scale, - const std::byte* QuantBZeroPointPtr, - __m256& acc, - const __m256i& low_mask -) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); - bv_32_epi8 = _mm256_and_si256(low_mask, bv_32_epi8); - - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); - } else { -#endif - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_zp_avx2( - const __m256i& av0_32_epi8, - const __m256i& av1_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - const std::byte* QuantBZeroPointPtr, - __m256& acc0, - const __m256i& low_mask -) -{ - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - { - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - - { - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - } else { -#endif - { - 
bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - - { - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_zp_is_8_avx2( - const __m256i& av0_32_epi8, - const __m256i& av1_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m256& acc0, - const __m256i& low_mask, - const __m256i& bzp8 -) -{ - // accumulate_blklen32_r1c1blk2_zp_is_8_avx2 is much faster than - // accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2: - // BlkBitWidth:4/BlkLen:32/M:1/N:2560/K:2560/Threads:8/Symmetric:1/HasBias:0/ComputeType:4 - // 36591 vs 40270 ns (the main is 51836 ns). both are not as good as main with genai. - // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 - - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - } else { -#endif - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); - const __m256i sum_8_epi32 = 
_mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps( - _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) - ); - - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( - const __m256i& av0_32_epi8, - const __m256i& av1_32_epi8, - const __m256& scale_a0_8_ps, - const __m256& scale_a1_8_ps, - const std::byte* QuantBDataPtr, - const float* scale_b, - __m256& acc0, - const __m256i& low_mask, - const __m256i& bzp8 -) -{ - // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - -#if !defined(__GNUC__) || (__GNUC__ > 10) - if constexpr (vnni) { - { - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - { - __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - } else { -#endif - { - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - { - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } -#if !defined(__GNUC__) || (__GNUC__ > 10) - } -#endif -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmM1C4BlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // 
process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; - - assert(CountN % NCols4 == 0); - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C; - - const __m256i low_mask = _mm256_set1_epi8(0x0F); - //const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); - const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale2 = 2; - const size_t StrideQuantBScale1 = 1; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); - //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); - - //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc[0], low_mask, bzp8); - //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale, acc[1], low_mask, bzp8); - //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); - //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); - if constexpr (HasZeroPoint) { - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], 
low_mask); - - } else { - const __m256i bzp8 = _mm256_set1_epi8(8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); - } - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; - QuantBScalePtr += PerAccuBlk2 * NCols4; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a00 = *QuantAScalePtr; - { - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); - } - { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); - } - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBDataCol; - QuantBScaleColPtr += NCols4 * BlockCountK; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmM1C1BlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk2 = 2; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountN < NCols4); - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C; - - const __m256i low_mask = _mm256_set1_epi8(0x0F); - [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); - //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); - - if constexpr (HasZeroPoint) { - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); - } else { - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); - } - - // increment block pointers - QuantAPtr += BlkLen32 * PerAccuBlk2; - QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } - } - - // TODO: use a loop in case PerAccuBlk2 is not 2. - if (k_blks_remaining > 0) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const float& scale_a00 = *QuantAScalePtr; - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} - -template -MLAS_FORCEINLINE -void -MlasQ4Int8GemmM1KernelBlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias - ) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleCols > 0) { - Q4Int8GemmM1C4BlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - multipleCols, - BlockCountK, - Bias); - } - - if (remainingCols > 0) { - Q4Int8GemmM1C1BlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, - C + multipleCols, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr); - } -} - -//#define SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout 1 -void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - // port from neon implementation - constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = 32; -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout -#else - constexpr bool HasZeroPoint = false; -#endif - - float* CRowPtr = C; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - //const size_t StrideQuantBScale = BlockCountK; - const float* BiasPtr = Bias; - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - - float* SumPtr = CRowPtr; - -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - const __m256i low_mask = _mm256_set1_epi8(0x0F); - const __m256i bzp8 = _mm256_set1_epi8(8); - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); - (void)StrideQuantBZeroPoint; -#else - const __m256i zero = _mm256_setzero_si256(); - const __m128i low_mask = _mm_set1_epi8(0xF); -#endif - const size_t NCols = 4; - constexpr size_t StrideQuantBScale2 = 2; - constexpr size_t StrideQuantBScale1 = 1; - - int64_t nblk = (int64_t)(CountN)-4; - while (nblk >= 0) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - (void)QuantBZeroPointPtr; -#endif - __m256 - acc0 = _mm256_setzero_ps(), - acc1 = _mm256_setzero_ps(), - acc2 = _mm256_setzero_ps(), - acc3 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 
+ BlkLen; - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); - const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen))); -#else - const float& scale_a0 = QuantAScalePtr[0]; - const float& scale_a1 = QuantAScalePtr[1]; -#endif - - // Col0 -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); -#else - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); -#endif - - // Col1 -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale2, acc1, low_mask, bzp8); -#else - const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale2)[0]; - const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale2)[1]; - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); - accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); -#endif - - // Col2 -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale2, acc2, low_mask, bzp8); -#else - const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale2)[0]; - const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale2)[1]; - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); - accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); -#endif - // Col3 -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale2, acc3, low_mask, bzp8); -#else - const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale2)[0]; - const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale2)[1]; - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); - accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); 
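For readers tracing the intrinsics in the deleted kernel above, a scalar sketch of what one accumulate call contributes may help. This is illustrative only and not part of the removed source; the function name is hypothetical, and it assumes the BlkLen32 packing in which byte i of a 16-byte block holds element i in its low nibble and element i+16 in its high nibble, with a per-block zero point that defaults to 8 when B carries no zero points.

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for one 32-element block of the int4 x int8
// dot product that the accumulate_* AVX2 helpers implement with intrinsics:
//   acc += (sum_i (b_i - zero_point) * a_i) * scale_a * scale_b
inline void AccumulateBlockRef(
    const int8_t a[32],          // quantized activations for this block
    const uint8_t b_packed[16],  // 32 4-bit weights, two per byte
    int b_zero_point,            // per-block zero point; 8 when B has none
    float scale_a,               // per-block activation scale
    float scale_b,               // per-block weight scale
    float& acc)
{
    int32_t sum = 0;
    for (size_t i = 0; i < 16; ++i) {
        // Assumed nibble pairing: low nibble = element i, high = element i+16.
        const int lo = (b_packed[i] & 0x0F) - b_zero_point;
        const int hi = ((b_packed[i] >> 4) & 0x0F) - b_zero_point;
        sum += lo * a[i];
        sum += hi * a[i + 16];
    }
    acc += static_cast<float>(sum) * scale_a * scale_b;
}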
-#endif - // increment block pointers - QuantAPtr += BlkLen * 2; - QuantAScalePtr += 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2 * NCols; - } - - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - - const float& scale_a0 = *QuantAScalePtr; - - // Col0 - const float& scale_0 = scale_a0 * QuantBScalePtr[0]; -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_0, acc0, low_mask, bzp8); -#else - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_0, acc0); -#endif - - // Col1 - const float& scale_1 = scale_a0 * (QuantBScalePtr + StrideQuantBScale1)[0]; -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + StrideQuantBData, scale_1, acc1, low_mask, bzp8); -#else - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_1, acc1); -#endif - - // Col2 - const float& scale_2 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_2, acc2, low_mask, bzp8); -#else - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_2, acc2); -#endif - - // Col3 - const float& scale_3 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_3, acc3, low_mask, bzp8); -#else - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_3, acc3); -#endif - } - - __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); - if (BiasPtr != nullptr) { - acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); - } - _mm_storeu_ps(SumPtr, acc_x); - - // move to next NCols columns - - QuantBDataColPtr += NCols * StrideQuantBData; - QuantBScaleColPtr += NCols * BlockCountK; - - BiasPtr += BiasPtr != nullptr ? 
NCols : 0; - SumPtr += NCols; - nblk -= NCols; - } - - nblk += NCols; - for (int64_t n = 0; n < nblk; n++) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - (void)QuantBZeroPointPtr; -#endif - __m256 acc0 = _mm256_setzero_ps(); - - size_t k_blks_remaining = BlockCountK; - for (; k_blks_remaining > 1; k_blks_remaining -= 2) { - const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; - - // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk0)); - const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk1)); -#else - const float& scale_a0 = QuantAScalePtr[0]; - const float& scale_a1 = QuantAScalePtr[1]; -#endif - - // Col0 -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); -#else - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; - const float& scale_01 = scale_a1 * QuantBScalePtr[1]; - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); - accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); -#endif - // increment block pointers - QuantAPtr += BlkLen * 2; - QuantAScalePtr += 2; - QuantBDataPtr += 16 * 2; - QuantBScalePtr += 2; - } - - if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - - const float& scale_a0 = *QuantAScalePtr; - - // Col0 - const float& scale_00 = scale_a0 * QuantBScalePtr[0]; -#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout - accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_00, acc0, low_mask, bzp8); -#else - accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); -#endif - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += BlockCountK; - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h deleted file mode 100644 index e9c3812bde899..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ /dev/null @@ -1,312 +0,0 @@ -#pragma once -#include -#include -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_avx_common.h" - - -static MLAS_FORCEINLINE void -accumulate_blklen64_r1c1blk1_zp_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - const std::byte* QuantBZeroPointPtr, - const bool is_lower_half_byte_zp, - __m256& acc0, - const __m256i& low_mask -) -{ - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - - const __m256i bzp8 = _mm256_set1_epi8(get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr)); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); - __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); - - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -} - -static MLAS_FORCEINLINE void -accumulate_blklen64_r1c1blk1_zp_is_8_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const float* scale_a, - const float* scale_b, - __m256& acc0, - const __m256i& low_mask, - const __m256i& bzp8 -) -{ - // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); - __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); - - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmM1C4BlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t SubblkLen64 = 64; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen64; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - - assert(CountN % NCols4 == 0); - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C; - - const __m256i low_mask = _mm256_set1_epi8(0x0F); - const size_t StrideQuantBData1 = 1 * SubblkDataSizeInBytes; - const size_t StrideQuantBScale1 = 1; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; - for (size_t k = 0; k < BlockCountK; ++k) { - [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - if constexpr (HasZeroPoint) { - accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc[0], low_mask); - accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 
StrideQuantBScale1, QuantBZeroPointPtr + StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[1], low_mask); - accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[2], low_mask); - accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[3], low_mask); - } else { - const __m256i bzp8 = _mm256_set1_epi8(8); - accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); - accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale1, acc[1], low_mask, bzp8); - accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, acc[2], low_mask, bzp8); - accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, acc[3], low_mask, bzp8); - } - - // increment block pointers - QuantAPtr += SubblkLen64; - QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr += NCols4; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += k % 2; - } - } - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - if (BiasPtr != nullptr) { - acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); - } - - _mm_storeu_ps(SumPtr, acc_r0); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } -} - -template -MLAS_FORCEINLINE void -Q4Int8GemmM1C1BlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias) -{ - constexpr size_t BlkBitWidth4 = 4; - [[maybe_unused]] constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t SubblkLen = 64; - - const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t PerBlkSubblkCount = BlkLen / SubblkLen; - const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - assert(CountN < NCols4); - - const __m256i low_mask = _mm256_set1_epi8(0x0F); - [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); - - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C; - - for (size_t n = 0; n < CountN; n++) { - const std::byte* QuantAPtr = QuantA; - const float* QuantAScalePtr = QuantAScale; - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc0 = _mm256_setzero_ps(); - for (size_t k = 0; k < BlockCountK; ++k) { - [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - if constexpr (HasZeroPoint) { - accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc0, low_mask); - } else { - accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); - } - - // increment block pointers - QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; - } - QuantAScalePtr++; - QuantBScalePtr++; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += k % 2; - } - } - - *SumPtr = hsum_float_8(acc0); - if (BiasPtr) { - *SumPtr += *BiasPtr; - } - - // move to next column - QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - - BiasPtr += BiasPtr != nullptr ? 
1 : 0; - SumPtr += 1; - } -} - -template -MLAS_FORCEINLINE void -MlasQ4Int8GemmKernelBlkLen64Avx2( - const size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias -) -{ - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - size_t remainingCols = CountN % NCols4; - size_t multipleCols = CountN - remainingCols; - - if (multipleCols > 0) { - Q4Int8GemmM1C4BlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - multipleCols, - BlockCountK, - Bias); - } - - if (remainingCols > 0) { - Q4Int8GemmM1C1BlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, - C + multipleCols, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr); - } -} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h b/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h deleted file mode 100644 index 80af2f46790df..0000000000000 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h +++ /dev/null @@ -1,70 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_q8_block.h - -Abstract: - - This module includes helper functions for manipulating blocks of quantized - int8 (Q8) values. - ---*/ - -#pragma once - -#include -#include -#include - -#include "mlasi.h" - -MLAS_FORCEINLINE -const float& -Q8BlkScale(const std::byte* BlkPtr) -{ - return *reinterpret_cast(BlkPtr); -} - -MLAS_FORCEINLINE -float& -Q8BlkScale(std::byte* BlkPtr) -{ - return *reinterpret_cast(BlkPtr); -} - -MLAS_FORCEINLINE -const int8_t* -Q8BlkData(const std::byte* BlkPtr) -{ - return reinterpret_cast(BlkPtr + sizeof(float)); -} - -MLAS_FORCEINLINE -int8_t* -Q8BlkData(std::byte* BlkPtr) -{ - return reinterpret_cast(BlkPtr + sizeof(float)); -} - -MLAS_FORCEINLINE -constexpr size_t -Q8BlkSize(size_t BlkLen) -{ - const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(BlkSize % alignof(float) == 0); - return BlkSize; -} - -MLAS_FORCEINLINE -constexpr size_t -Q8BlkAlignment() -{ - return alignof(float); -} diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp deleted file mode 100644 index 9750337237b00..0000000000000 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ /dev/null @@ -1,184 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - tanh.cpp - -Abstract: - - This module implements routines to compute the hyperbolic tangent function. - - This implementation uses the same polynomial coefficients and algorithm as - found in Eigen. Our usage requires building platform specific versions of - the algorithm to target different instruction sets. 
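The sqnbitgemm_q8_block.h helpers removed above describe a simple layout for each quantized-activation block: a leading float scale followed by BlkLen int8 values. A minimal sketch of reading such a block, with hypothetical names, under that assumption:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical view over one Q8 block: [ float scale | int8 data[BlkLen] ].
struct Q8BlockViewRef {
    float scale;
    const int8_t* data;
};

inline size_t Q8BlkSizeRef(size_t BlkLen)
{
    return sizeof(float) + BlkLen * sizeof(int8_t);
}

inline Q8BlockViewRef Q8BlkViewRef(const std::byte* BlkPtr)
{
    Q8BlockViewRef view;
    std::memcpy(&view.scale, BlkPtr, sizeof(float));  // leading per-block scale
    view.data = reinterpret_cast<const int8_t*>(BlkPtr + sizeof(float));
    return view;
}

// Example: dequantize one block of length BlkLen into out[].
inline void Q8BlkDequantRef(const std::byte* BlkPtr, size_t BlkLen, float* out)
{
    const Q8BlockViewRef blk = Q8BlkViewRef(BlkPtr);
    for (size_t i = 0; i < BlkLen; ++i) {
        out[i] = blk.scale * static_cast<float>(blk.data[i]);
    }
}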
The implementation below - targets the base instruction set (typically SSE2) while assembly - implementations target newer instruction sets (such as FMA3). - ---*/ - -#include "mlasi.h" - -// -// Bundles the floating point constants for use by kernels written in assembly. -// - -MLAS_INTERNAL_DATA const struct { - float LowerRange; - float UpperRange; - float alpha_13; - float alpha_11; - float alpha_9; - float alpha_7; - float alpha_5; - float alpha_3; - float alpha_1; - float beta_6; - float beta_4; - float beta_2; - float beta_0; -} MlasTanhConstants = { - -9.0f, - 9.0f, - -2.76076847742355e-16f, - 2.00018790482477e-13f, - -8.60467152213735e-11f, - 5.12229709037114e-08f, - 1.48572235717979e-05f, - 6.37261928875436e-04f, - 4.89352455891786e-03f, - 1.19825839466702e-06f, - 1.18534705686654e-04f, - 2.26843463243900e-03f, - 4.89352518554385e-03f, -}; - -void -MLASCALL -MlasTanhKernel( - const float* Input, - float* Output, - size_t N - ) -/*++ - -Routine Description: - - This routine implements the generic kernel for the hyperbolic tangent function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ - while (N >= 4) { - - MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input); - - Value = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasTanhConstants.LowerRange), Value); - Value = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasTanhConstants.UpperRange), Value); - - MLAS_FLOAT32X4 ValueSquared = MlasMultiplyFloat32x4(Value, Value); - - MLAS_FLOAT32X4 p; - p = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_13), - MlasBroadcastFloat32x4(MlasTanhConstants.alpha_11)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_9)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_7)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_5)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_3)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_1)); - p = MlasMultiplyFloat32x4(p, Value); - - MLAS_FLOAT32X4 q; - q = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_6), - MlasBroadcastFloat32x4(MlasTanhConstants.beta_4)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_2)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_0)); - - MlasStoreFloat32x4(Output, MlasDivideFloat32x4(p, q)); - - Input += 4; - Output += 4; - N -= 4; - } - - while (N > 0) { - - float Value = *Input++; - - // This odd two-step process exists to ensure an input value of NaN carries through - // without modification because "std::min" and "std::max" return unreliable results - // when NaNs are involved, and it's clear from the test's reference outputs that - // they want a NaN on output whenever the input is a NaN. - float v_tmp; - v_tmp = (Value < MlasTanhConstants.LowerRange) ? MlasTanhConstants.LowerRange : Value; - Value = (v_tmp > MlasTanhConstants.UpperRange) ? 
MlasTanhConstants.UpperRange : v_tmp; - - float ValueSquared = Value * Value; - - float p; - p = ValueSquared * MlasTanhConstants.alpha_13 + MlasTanhConstants.alpha_11; - p = p * ValueSquared + MlasTanhConstants.alpha_9; - p = p * ValueSquared + MlasTanhConstants.alpha_7; - p = p * ValueSquared + MlasTanhConstants.alpha_5; - p = p * ValueSquared + MlasTanhConstants.alpha_3; - p = p * ValueSquared + MlasTanhConstants.alpha_1; - p = p * Value; - - float q; - q = ValueSquared * MlasTanhConstants.beta_6 + MlasTanhConstants.beta_4; - q = q * ValueSquared + MlasTanhConstants.beta_2; - q = q * ValueSquared + MlasTanhConstants.beta_0; - - *Output++ = (p / q); - - N -= 1; - } -} - -void -MLASCALL -MlasComputeTanh( - const float* Input, - float* Output, - size_t N - ) -/*++ - -Routine Description: - - This routine computes the hyperbolic tangent function. - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().TanhKernelRoutine(Input, Output, N); -#else - MlasTanhKernel(Input, Output, N); -#endif -} diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp deleted file mode 100644 index dc5daf998d3be..0000000000000 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ /dev/null @@ -1,133 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - threading.cpp - -Abstract: - - This module implements platform specific threading support. - ---*/ - -#include "mlasi.h" - -void -MlasExecuteThreaded( - MLAS_THREADED_ROUTINE* ThreadedRoutine, - void* Context, - ptrdiff_t Iterations, - MLAS_THREADPOOL* ThreadPool - ) -{ - // - // Execute the routine directly if only one iteration is specified. - // - - if (Iterations == 1) { - ThreadedRoutine(Context, 0); - return; - } - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) - MLAS_UNREFERENCED_PARAMETER(ThreadPool); - - // - // Fallback to OpenMP or a serialized implementation. - // - - // - // Execute the routine for the specified number of iterations. - // - for (ptrdiff_t tid = 0; tid < Iterations; tid++) { - ThreadedRoutine(Context, tid); - } -#else - // - // Schedule the threaded iterations using the thread pool object. - // - - MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, [&](ptrdiff_t tid) { - ThreadedRoutine(Context, tid); - }); -#endif -} - - -void -MlasTrySimpleParallel( - MLAS_THREADPOOL * ThreadPool, - const std::ptrdiff_t Iterations, - const std::function& Work) -{ - // - // Execute the routine directly if only one iteration is specified. - // - if (Iterations == 1) { - Work(0); - return; - } - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) - MLAS_UNREFERENCED_PARAMETER(ThreadPool); - - // - // Fallback to OpenMP or a serialized implementation. - // - - // - // Execute the routine for the specified number of iterations. - // - for (ptrdiff_t tid = 0; tid < Iterations; tid++) { - Work(tid); - } -#else - // - // Schedule the threaded iterations using the thread pool object. - // - - MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work); -#endif -} - - -void -MlasTryBatchParallel( - MLAS_THREADPOOL * ThreadPool, - const std::ptrdiff_t Iterations, - const std::function& Work) -{ - // - // Execute the routine directly if only one iteration is specified. 
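The tanh.cpp kernel removed above evaluates tanh(x) as x * p(x^2) / q(x^2) after clamping x to [-9, 9]. A scalar sketch of that evaluation, using the coefficient values from MlasTanhConstants; the function name is hypothetical, and unlike the original two-step clamp it does not try to preserve NaN inputs:

#include <algorithm>

// Hypothetical scalar sketch of the rational tanh approximation above.
inline float TanhApproxRef(float x)
{
    x = std::min(std::max(x, -9.0f), 9.0f);   // LowerRange / UpperRange
    const float x2 = x * x;

    float p = -2.76076847742355e-16f;          // alpha_13
    p = p * x2 + 2.00018790482477e-13f;        // alpha_11
    p = p * x2 - 8.60467152213735e-11f;        // alpha_9
    p = p * x2 + 5.12229709037114e-08f;        // alpha_7
    p = p * x2 + 1.48572235717979e-05f;        // alpha_5
    p = p * x2 + 6.37261928875436e-04f;        // alpha_3
    p = p * x2 + 4.89352455891786e-03f;        // alpha_1
    p = p * x;

    float q = 1.19825839466702e-06f;           // beta_6
    q = q * x2 + 1.18534705686654e-04f;        // beta_4
    q = q * x2 + 2.26843463243900e-03f;        // beta_2
    q = q * x2 + 4.89352518554385e-03f;        // beta_0

    return p / q;
}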
- // - if (Iterations == 1) { - Work(0); - return; - } - -#if defined(BUILD_MLAS_NO_ONNXRUNTIME) - MLAS_UNREFERENCED_PARAMETER(ThreadPool); - - // - // Fallback to OpenMP or a serialized implementation. - // - - // - // Execute the routine for the specified number of iterations. - // - for (ptrdiff_t tid = 0; tid < Iterations; tid++) { - Work(tid); - } -#else - // - // Schedule the threaded iterations using the thread pool object. - // - - MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0); -#endif - -} \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp deleted file mode 100644 index a758a0e59fb4f..0000000000000 --- a/onnxruntime/core/mlas/lib/transpose.cpp +++ /dev/null @@ -1,928 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - transpose.cpp - -Abstract: - - This module implements the transpose operation. - ---*/ - -#include "mlasi.h" - -#if defined(MLAS_SSE2_INTRINSICS) - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const uint32_t* Input, - size_t InputStride, - uint32_t* Output, - size_t OutputStride - ) -{ - __m128i a0 = _mm_loadu_si128((const __m128i*)&Input[InputStride * 0]); - __m128i a1 = _mm_loadu_si128((const __m128i*)&Input[InputStride * 1]); - __m128i a2 = _mm_loadu_si128((const __m128i*)&Input[InputStride * 2]); - __m128i a3 = _mm_loadu_si128((const __m128i*)&Input[InputStride * 3]); - - __m128i b0 = _mm_unpacklo_epi32(a0, a2); - __m128i b1 = _mm_unpackhi_epi32(a0, a2); - __m128i b2 = _mm_unpacklo_epi32(a1, a3); - __m128i b3 = _mm_unpackhi_epi32(a1, a3); - - __m128i c0 = _mm_unpacklo_epi32(b0, b2); - __m128i c1 = _mm_unpackhi_epi32(b0, b2); - __m128i c2 = _mm_unpacklo_epi32(b1, b3); - __m128i c3 = _mm_unpackhi_epi32(b1, b3); - - _mm_storeu_si128((__m128i*)&Output[OutputStride * 0], c0); - _mm_storeu_si128((__m128i*)&Output[OutputStride * 1], c1); - _mm_storeu_si128((__m128i*)&Output[OutputStride * 2], c2); - _mm_storeu_si128((__m128i*)&Output[OutputStride * 3], c3); -} - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const uint16_t* Input, - size_t InputStride, - uint16_t* Output, - size_t OutputStride - ) -{ - __m128i a0 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 0]); - __m128i a1 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 1]); - __m128i a2 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 2]); - __m128i a3 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 3]); - - __m128i b0 = _mm_unpacklo_epi16(a0, a2); - __m128i b1 = _mm_unpacklo_epi16(a1, a3); - - __m128i c0 = _mm_unpacklo_epi16(b0, b1); - __m128i c1 = _mm_unpackhi_epi16(b0, b1); - - _mm_storel_pi((__m64*)&Output[OutputStride * 0], _mm_castsi128_ps(c0)); - _mm_storeh_pi((__m64*)&Output[OutputStride * 1], _mm_castsi128_ps(c0)); - _mm_storel_pi((__m64*)&Output[OutputStride * 2], _mm_castsi128_ps(c1)); - _mm_storeh_pi((__m64*)&Output[OutputStride * 3], _mm_castsi128_ps(c1)); -} - -MLAS_FORCEINLINE -void -MlasTranspose8x8Block( - const uint8_t* Input, - size_t InputStride, - uint8_t* Output, - size_t OutputStride - ) -{ - __m128i a0 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 0]); - __m128i a1 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 1]); - __m128i b0 = _mm_unpacklo_epi8(a0, a1); - - __m128i a2 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 2]); - __m128i a3 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 3]); - __m128i b1 = _mm_unpacklo_epi8(a2, a3); - - __m128i a4 = 
_mm_loadl_epi64((const __m128i*)&Input[InputStride * 4]); - __m128i a5 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 5]); - __m128i b2 = _mm_unpacklo_epi8(a4, a5); - - __m128i a6 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 6]); - __m128i a7 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 7]); - __m128i b3 = _mm_unpacklo_epi8(a6, a7); - - __m128i c0 = _mm_unpacklo_epi16(b0, b1); - __m128i c1 = _mm_unpackhi_epi16(b0, b1); - __m128i c2 = _mm_unpacklo_epi16(b2, b3); - __m128i c3 = _mm_unpackhi_epi16(b2, b3); - - __m128 d0 = _mm_castsi128_ps(_mm_unpacklo_epi32(c0, c2)); - _mm_storel_pi((__m64*)&Output[OutputStride * 0], d0); - _mm_storeh_pi((__m64*)&Output[OutputStride * 1], d0); - - __m128 d1 = _mm_castsi128_ps(_mm_unpackhi_epi32(c0, c2)); - _mm_storel_pi((__m64*)&Output[OutputStride * 2], d1); - _mm_storeh_pi((__m64*)&Output[OutputStride * 3], d1); - - __m128 d2 = _mm_castsi128_ps(_mm_unpacklo_epi32(c1, c3)); - _mm_storel_pi((__m64*)&Output[OutputStride * 4], d2); - _mm_storeh_pi((__m64*)&Output[OutputStride * 5], d2); - - __m128 d3 = _mm_castsi128_ps(_mm_unpackhi_epi32(c1, c3)); - _mm_storel_pi((__m64*)&Output[OutputStride * 6], d3); - _mm_storeh_pi((__m64*)&Output[OutputStride * 7], d3); -} - -#elif defined(MLAS_NEON_INTRINSICS) - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const uint32_t* Input, - size_t InputStride, - uint32_t* Output, - size_t OutputStride - ) -{ - uint32x4_t a0 = vld1q_u32(&Input[InputStride * 0]); - uint32x4_t a1 = vld1q_u32(&Input[InputStride * 1]); - uint32x4_t a2 = vld1q_u32(&Input[InputStride * 2]); - uint32x4_t a3 = vld1q_u32(&Input[InputStride * 3]); - - uint32x4x2_t b0 = vzipq_u32(a0, a2); - uint32x4x2_t b1 = vzipq_u32(a1, a3); - - uint32x4x2_t c0 = vzipq_u32(b0.val[0], b1.val[0]); - uint32x4x2_t c1 = vzipq_u32(b0.val[1], b1.val[1]); - - vst1q_u32(&Output[OutputStride * 0], c0.val[0]); - vst1q_u32(&Output[OutputStride * 1], c0.val[1]); - vst1q_u32(&Output[OutputStride * 2], c1.val[0]); - vst1q_u32(&Output[OutputStride * 3], c1.val[1]); -} - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const uint16_t* Input, - size_t InputStride, - uint16_t* Output, - size_t OutputStride - ) -{ - uint16x4_t a0 = vld1_u16(&Input[InputStride * 0]); - uint16x4_t a1 = vld1_u16(&Input[InputStride * 1]); - uint16x4_t a2 = vld1_u16(&Input[InputStride * 2]); - uint16x4_t a3 = vld1_u16(&Input[InputStride * 3]); - - uint16x4x2_t b0 = vzip_u16(a0, a2); - uint16x4x2_t b1 = vzip_u16(a1, a3); - - uint16x4x2_t c0 = vzip_u16(b0.val[0], b1.val[0]); - uint16x4x2_t c1 = vzip_u16(b0.val[1], b1.val[1]); - - vst1_u16(&Output[OutputStride * 0], c0.val[0]); - vst1_u16(&Output[OutputStride * 1], c0.val[1]); - vst1_u16(&Output[OutputStride * 2], c1.val[0]); - vst1_u16(&Output[OutputStride * 3], c1.val[1]); -} - -MLAS_FORCEINLINE -void -MlasTranspose8x8Block( - const uint8_t* Input, - size_t InputStride, - uint8_t* Output, - size_t OutputStride - ) -{ - uint8x8_t a0 = vld1_u8(&Input[InputStride * 0]); - uint8x8_t a1 = vld1_u8(&Input[InputStride * 1]); - uint8x8x2_t b0 = vzip_u8(a0, a1); - - uint8x8_t a2 = vld1_u8(&Input[InputStride * 2]); - uint8x8_t a3 = vld1_u8(&Input[InputStride * 3]); - uint8x8x2_t b1 = vzip_u8(a2, a3); - - uint8x8_t a4 = vld1_u8(&Input[InputStride * 4]); - uint8x8_t a5 = vld1_u8(&Input[InputStride * 5]); - uint8x8x2_t b2 = vzip_u8(a4, a5); - - uint8x8_t a6 = vld1_u8(&Input[InputStride * 6]); - uint8x8_t a7 = vld1_u8(&Input[InputStride * 7]); - uint8x8x2_t b3 = vzip_u8(a6, a7); - - uint16x4x2_t c0 = vzip_u16(vreinterpret_u16_u8(b0.val[0]), 
vreinterpret_u16_u8(b1.val[0])); - uint16x4x2_t c1 = vzip_u16(vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1])); - uint16x4x2_t c2 = vzip_u16(vreinterpret_u16_u8(b2.val[0]), vreinterpret_u16_u8(b3.val[0])); - uint16x4x2_t c3 = vzip_u16(vreinterpret_u16_u8(b2.val[1]), vreinterpret_u16_u8(b3.val[1])); - - uint32x2x2_t d0 = vzip_u32(vreinterpret_u32_u16(c0.val[0]), vreinterpret_u32_u16(c2.val[0])); - uint32x2x2_t d1 = vzip_u32(vreinterpret_u32_u16(c0.val[1]), vreinterpret_u32_u16(c2.val[1])); - uint32x2x2_t d2 = vzip_u32(vreinterpret_u32_u16(c1.val[0]), vreinterpret_u32_u16(c3.val[0])); - uint32x2x2_t d3 = vzip_u32(vreinterpret_u32_u16(c1.val[1]), vreinterpret_u32_u16(c3.val[1])); - - vst1_u8(&Output[OutputStride * 0], vreinterpret_u8_u32(d0.val[0])); - vst1_u8(&Output[OutputStride * 1], vreinterpret_u8_u32(d0.val[1])); - vst1_u8(&Output[OutputStride * 2], vreinterpret_u8_u32(d1.val[0])); - vst1_u8(&Output[OutputStride * 3], vreinterpret_u8_u32(d1.val[1])); - vst1_u8(&Output[OutputStride * 4], vreinterpret_u8_u32(d2.val[0])); - vst1_u8(&Output[OutputStride * 5], vreinterpret_u8_u32(d2.val[1])); - vst1_u8(&Output[OutputStride * 6], vreinterpret_u8_u32(d3.val[0])); - vst1_u8(&Output[OutputStride * 7], vreinterpret_u8_u32(d3.val[1])); -} - -#elif defined(MLAS_TARGET_POWER) - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const uint32_t* Input, - size_t InputStride, - uint32_t* Output, - size_t OutputStride - ) -{ - __vector unsigned int a0 = vec_vsx_ld(0, Input); - __vector unsigned int a1 = vec_vsx_ld(0, &Input[InputStride]); - __vector unsigned int a2 = vec_vsx_ld(0, &Input[InputStride * 2]); - __vector unsigned int a3 = vec_vsx_ld(0, &Input[InputStride * 3]); - - __vector unsigned int b0 = vec_mergeh(a0, a1); - __vector unsigned int b1 = vec_mergeh(a2, a3); - __vector unsigned int b2 = vec_mergel(a0, a1); - __vector unsigned int b3 = vec_mergel(a2, a3); - - __vector unsigned int c0 = vec_xxpermdi(b0, b1, 0); - __vector unsigned int c1 = vec_xxpermdi(b0, b1, 3); - __vector unsigned int c2 = vec_xxpermdi(b2, b3, 0); - __vector unsigned int c3 = vec_xxpermdi(b2, b3, 3); - - vec_vsx_st(c0, 0, Output); - vec_vsx_st(c1, 0, &Output[OutputStride]); - vec_vsx_st(c2, 0, &Output[OutputStride * 2]); - vec_vsx_st(c3, 0, &Output[OutputStride * 3]); -} - -MLAS_FORCEINLINE -void -MlasTranspose16x16Block( - const uint8_t* Input, - size_t InputStride, - uint8_t* Output, - size_t OutputStride - ) -{ - __vector unsigned char a0 = vec_vsx_ld(0, Input); - __vector unsigned char a1 = vec_vsx_ld(0, &Input[InputStride]); - __vector unsigned char a2 = vec_vsx_ld(0, &Input[InputStride * 2]); - __vector unsigned char a3 = vec_vsx_ld(0, &Input[InputStride * 3]); - __vector unsigned char a4 = vec_vsx_ld(0, &Input[InputStride * 4]); - __vector unsigned char a5 = vec_vsx_ld(0, &Input[InputStride * 5]); - __vector unsigned char a6 = vec_vsx_ld(0, &Input[InputStride * 6]); - __vector unsigned char a7 = vec_vsx_ld(0, &Input[InputStride * 7]); - __vector unsigned char a8 = vec_vsx_ld(0, &Input[InputStride * 8]); - __vector unsigned char a9 = vec_vsx_ld(0, &Input[InputStride * 9]); - __vector unsigned char a10 = vec_vsx_ld(0, &Input[InputStride * 10]); - __vector unsigned char a11 = vec_vsx_ld(0, &Input[InputStride * 11]); - __vector unsigned char a12 = vec_vsx_ld(0, &Input[InputStride * 12]); - __vector unsigned char a13 = vec_vsx_ld(0, &Input[InputStride * 13]); - __vector unsigned char a14 = vec_vsx_ld(0, &Input[InputStride * 14]); - __vector unsigned char a15 = vec_vsx_ld(0, &Input[InputStride * 15]); - - 
__vector unsigned char b0 = vec_mergeh(a0, a1); - __vector unsigned char b1 = vec_mergeh(a2, a3); - __vector unsigned char b2 = vec_mergeh(a4, a5); - __vector unsigned char b3 = vec_mergeh(a6, a7); - __vector unsigned char b4 = vec_mergeh(a8, a9); - __vector unsigned char b5 = vec_mergeh(a10, a11); - __vector unsigned char b6 = vec_mergeh(a12, a13); - __vector unsigned char b7 = vec_mergeh(a14, a15); - __vector unsigned char c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); - __vector unsigned char c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); - __vector unsigned char c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); - __vector unsigned char c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); - - __vector unsigned char d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - __vector unsigned char d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - __vector unsigned char e0 = vec_xxpermdi(d0, d1, 0); - __vector unsigned char e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[0]); - vec_vsx_st(e1, 0, &Output[OutputStride]); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 2]); - vec_vsx_st(e1, 0, &Output[OutputStride * 3]); - - c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); - c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); - c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); - c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 4]); - vec_vsx_st(e1, 0, &Output[OutputStride * 5]); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = 
vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 6]); - vec_vsx_st(e1, 0, &Output[OutputStride * 7]); - - b0 = vec_mergel(a0, a1); - b1 = vec_mergel(a2, a3); - b2 = vec_mergel(a4, a5); - b3 = vec_mergel(a6, a7); - b4 = vec_mergel(a8, a9); - b5 = vec_mergel(a10, a11); - b6 = vec_mergel(a12, a13); - b7 = vec_mergel(a14, a15); - - c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); - c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); - c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); - c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 8]); - vec_vsx_st(e1, 0, &Output[OutputStride * 9]); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 10]); - vec_vsx_st(e1, 0, &Output[OutputStride * 11]); - - c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); - c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); - c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); - c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 12]); - vec_vsx_st(e1, 0, &Output[OutputStride * 13]); - - d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); - d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); - e0 = vec_xxpermdi(d0, d1, 0); - e1 = vec_xxpermdi(d0, d1, 3); - vec_vsx_st(e0, 0, &Output[OutputStride * 14]); - vec_vsx_st(e1, 0, &Output[OutputStride * 15]); -} - -#elif defined(MLAS_LSX_INTRINSICS) - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const 
uint32_t* Input, - size_t InputStride, - uint32_t* Output, - size_t OutputStride - ) -{ - __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); - __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); - __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); - __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); - - __m128i b0 = __lsx_vilvl_w(a2, a0); - __m128i b1 = __lsx_vilvh_w(a2, a0); - __m128i b2 = __lsx_vilvl_w(a3, a1); - __m128i b3 = __lsx_vilvh_w(a3, a1); - __m128i c0 = __lsx_vilvl_w(b2, b0); - __m128i c1 = __lsx_vilvh_w(b2, b0); - __m128i c2 = __lsx_vilvl_w(b3, b1); - __m128i c3 = __lsx_vilvh_w(b3, b1); - - __lsx_vst(c0, (__m128i*)&Output[OutputStride * 0], 0); - __lsx_vst(c1, (__m128i*)&Output[OutputStride * 1], 0); - __lsx_vst(c2, (__m128i*)&Output[OutputStride * 2], 0); - __lsx_vst(c3, (__m128i*)&Output[OutputStride * 3], 0); -} - -MLAS_FORCEINLINE -void -MlasTranspose4x4Block( - const uint16_t* Input, - size_t InputStride, - uint16_t* Output, - size_t OutputStride - ) -{ - __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); - __lsx_vinsgr2vr_d(a0, 0 , 1); - __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); - __lsx_vinsgr2vr_d(a1, 0 , 1); - __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); - __lsx_vinsgr2vr_d(a2, 0 , 1); - __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); - __lsx_vinsgr2vr_d(a3, 0 , 1); - - __m128i b0 = __lsx_vilvl_h(a2, a0); - __m128i b1 = __lsx_vilvl_h(a3, a1); - __m128i c0 = __lsx_vilvl_h(b1, b0); - __m128i c1 = __lsx_vilvh_h(b1, b0); - - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(c0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(c0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(c1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(c1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); -} - -MLAS_FORCEINLINE -void -MlasTranspose8x8Block( - const uint8_t* Input, - size_t InputStride, - uint8_t* Output, - size_t OutputStride - ) -{ - __m128i a0 = __lsx_vld((const __m128i*)&Input[InputStride * 0], 0); - __lsx_vinsgr2vr_d(a0, 0, 1); - __m128i a1 = __lsx_vld((const __m128i*)&Input[InputStride * 1], 0); - __lsx_vinsgr2vr_d(a1, 0, 1); - __m128i b0 = __lsx_vilvl_b(a1, a0); - - __m128i a2 = __lsx_vld((const __m128i*)&Input[InputStride * 2], 0); - __lsx_vinsgr2vr_d(a2, 0, 1); - __m128i a3 = __lsx_vld((const __m128i*)&Input[InputStride * 3], 0); - __lsx_vinsgr2vr_d(a3, 0, 1); - __m128i b1 = __lsx_vilvl_b(a3, a2); - - __m128i a4 = __lsx_vld((const __m128i*)&Input[InputStride * 4], 0); - __lsx_vinsgr2vr_d(a4, 0, 1); - __m128i a5 = __lsx_vld((const __m128i*)&Input[InputStride * 5], 0); - __lsx_vinsgr2vr_d(a5, 0, 1); - __m128i b2 = __lsx_vilvl_b(a5, a4); - - __m128i a6 = __lsx_vld((const __m128i*)&Input[InputStride * 6], 0); - __lsx_vinsgr2vr_d(a6, 0, 1); - __m128i a7 = __lsx_vld((const __m128i*)&Input[InputStride * 7], 0); - __lsx_vinsgr2vr_d(a7, 0, 1); - __m128i b3 = __lsx_vilvl_b(a7, a6); - __m128i c0 = __lsx_vilvl_h(b1, b0); - __m128i c1 = __lsx_vilvh_h(b1, b0); - __m128i c2 = __lsx_vilvl_h(b3, b2); - __m128i c3 = __lsx_vilvh_h(b3, b2); - - __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); - 
__lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); - - __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); - - __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); - - __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); -} - -#endif - -template -MLAS_FORCEINLINE -void -MlasTranspose4xNVector( - const ElementType* Input, - size_t InputStride, - ElementType* Output, - size_t OutputStride - ) -{ - ElementType a0 = Input[InputStride * 0]; - ElementType a1 = Input[InputStride * 1]; - ElementType a2 = Input[InputStride * 2]; - ElementType a3 = Input[InputStride * 3]; - - Output[OutputStride * 0] = a0; - Output[OutputStride * 1] = a1; - Output[OutputStride * 2] = a2; - Output[OutputStride * 3] = a3; -} - -#if defined(MLAS_TARGET_POWER) -template -MLAS_FORCEINLINE -void -MlasTranspose16xNVector( - const ElementType* Input, - size_t InputStride, - ElementType* Output, - size_t OutputStride - ) -{ - MlasTranspose4xNVector(&Input[InputStride * 0], InputStride, &Output[OutputStride * 0], OutputStride); - MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride); - MlasTranspose4xNVector(&Input[InputStride * 8], InputStride, &Output[OutputStride * 8], OutputStride); - MlasTranspose4xNVector(&Input[InputStride * 12], InputStride, &Output[OutputStride * 12], OutputStride); -} -#endif - -template -MLAS_FORCEINLINE -void -MlasTranspose8xNVector( - const ElementType* Input, - size_t InputStride, - ElementType* Output, - size_t OutputStride - ) -{ - MlasTranspose4xNVector(&Input[InputStride * 0], InputStride, &Output[OutputStride * 0], OutputStride); - MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride); -} - -void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. - - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. - -Return Value: - - None. 
- ---*/ -{ - size_t n = N; - - // - // Transpose elements from the input matrix to the output matrix 4 columns - // at a time. - // - - while (n >= 4) { - - const uint32_t* s = Input; - uint32_t* d = Output; - size_t m = M; - -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_LSX_INTRINSICS) - - while (m >= 4) { - - MlasTranspose4x4Block(s, N, d, M); - - s += N * 4; - d += 4; - m -= 4; - } - -#endif - - while (m > 0) { - - MlasTranspose4xNVector(s, 1, d, M); - - s += N; - d += 1; - m -= 1; - } - - Input += 4; - Output += M * 4; - n -= 4; - } - - // - // Transpose elements from the input matrix to the output matrix for the - // remaining columns. - // - - while (n > 0) { - - const uint32_t* s = Input; - uint32_t* d = Output; - size_t m = M; - - while (m >= 4) { - - MlasTranspose4xNVector(s, N, d, 1); - - s += N * 4; - d += 4; - m -= 4; - } - - while (m > 0) { - - d[0] = s[0]; - - s += N; - d += 1; - m -= 1; - } - - Input += 1; - Output += M; - n -= 1; - } -} - -void -MLASCALL -MlasTranspose( - const float* Input, - float* Output, - size_t M, - size_t N - ) -{ - MlasTranspose( - reinterpret_cast(Input), - reinterpret_cast(Output), - M, - N); -} - - -void -MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, - size_t N - ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. - - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. - -Return Value: - - None. - ---*/ -{ - size_t n = N; - - // - // Transpose elements from the input matrix to the output matrix 4 columns - // at a time. - // - - while (n >= 4) { - - const uint16_t* s = Input; - uint16_t* d = Output; - size_t m = M; - -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) - - while (m >= 4) { - - MlasTranspose4x4Block(s, N, d, M); - - s += N * 4; - d += 4; - m -= 4; - } - -#endif - - while (m > 0) { - - MlasTranspose4xNVector(s, 1, d, M); - - s += N; - d += 1; - m -= 1; - } - - Input += 4; - Output += M * 4; - n -= 4; - } - - // - // Transpose elements from the input matrix to the output matrix for the - // remaining columns. - // - - while (n > 0) { - - const uint16_t* s = Input; - uint16_t* d = Output; - size_t m = M; - - while (m >= 4) { - - MlasTranspose4xNVector(s, N, d, 1); - - s += N * 4; - d += 4; - m -= 4; - } - - while (m > 0) { - - d[0] = s[0]; - - s += N; - d += 1; - m -= 1; - } - - Input += 1; - Output += M; - n -= 1; - } -} - - -void -MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N - ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. - - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. - -Return Value: - - None. - ---*/ -{ - size_t n = N; - - // - // Transpose elements from the input matrix to the output matrix 8 columns - // at a time. 
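All of the deleted MlasTranspose overloads share one structure: walk the output a few columns at a time, transpose full tiles with a SIMD NxN block kernel (the LSX/SSE2/NEON blocks above), and fall back to element copies for the leftover rows and columns. The following is a simplified, purely scalar C++ sketch of that strategy, not the MLAS implementation; the 4x4 block helper stands in for the vectorized block kernels.

#include <cstddef>
#include <cstdint>

// Scalar stand-in for the SIMD 4x4 block kernels shown above.
static void Transpose4x4Scalar(const uint32_t* Input, size_t InputStride,
                               uint32_t* Output, size_t OutputStride)
{
    for (size_t i = 0; i < 4; i++) {
        for (size_t j = 0; j < 4; j++) {
            Output[j * OutputStride + i] = Input[i * InputStride + j];
        }
    }
}

// Transpose an M x N matrix into N x M using the same "4 columns at a time,
// 4x4 tiles, scalar tails" structure as the MlasTranspose routines above.
void TransposeBlocked(const uint32_t* Input, uint32_t* Output, size_t M, size_t N)
{
    size_t n = N;
    while (n >= 4) {
        const uint32_t* s = Input;
        uint32_t* d = Output;
        size_t m = M;
        while (m >= 4) {                 // full 4x4 tiles
            Transpose4x4Scalar(s, N, d, M);
            s += N * 4;
            d += 4;
            m -= 4;
        }
        while (m > 0) {                  // leftover rows: one input row -> one output column slice
            for (size_t j = 0; j < 4; j++) {
                d[j * M] = s[j];
            }
            s += N;
            d += 1;
            m -= 1;
        }
        Input += 4;
        Output += M * 4;
        n -= 4;
    }
    while (n > 0) {                      // leftover columns: one input column -> one output row
        const uint32_t* s = Input;
        uint32_t* d = Output;
        for (size_t m = 0; m < M; m++) {
            d[m] = s[m * N];
        }
        Input += 1;
        Output += M;
        n -= 1;
    }
}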
- // -#if defined(MLAS_TARGET_POWER) - while (n >= 16) { - - const uint8_t* s = Input; - uint8_t* d = Output; - size_t m = M; - while (m >= 16) { - - MlasTranspose16x16Block(s, N, d, M); - - s += N * 16; - d += 16; - m -= 16; - } - - while (m > 0) { - - MlasTranspose16xNVector(s, 1, d, M); - - s += N; - d += 1; - m -= 1; - } - - Input += 16; - Output += M * 16; - n -= 16; - } -#endif - while (n >= 8) { - - const uint8_t* s = Input; - uint8_t* d = Output; - size_t m = M; - -#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) - - while (m >= 8) { - - MlasTranspose8x8Block(s, N, d, M); - - s += N * 8; - d += 8; - m -= 8; - } - -#endif - - while (m > 0) { - - MlasTranspose8xNVector(s, 1, d, M); - - s += N; - d += 1; - m -= 1; - } - - Input += 8; - Output += M * 8; - n -= 8; - } - - // - // Transpose elements from the input matrix to the output matrix for the - // remaining columns. - // - - while (n > 0) { - - const uint8_t* s = Input; - uint8_t* d = Output; - size_t m = M; - - while (m >= 8) { - - MlasTranspose8xNVector(s, N, d, 1); - - s += N * 8; - d += 8; - m -= 8; - } - - while (m > 0) { - - d[0] = s[0]; - - s += N; - d += 1; - m -= 1; - } - - Input += 1; - Output += M; - n -= 1; - } -} - -void -MLASCALL -MlasTranspose( - const int8_t* Input, - int8_t* Output, - size_t M, - size_t N) -{ - MlasTranspose( - reinterpret_cast(Input), - reinterpret_cast(Output), - M, - N); -} diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp deleted file mode 100644 index 43a12b37e4ffa..0000000000000 --- a/onnxruntime/core/mlas/lib/wasm_simd/SgemmKernelWasmSimd.cpp +++ /dev/null @@ -1,540 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelWasmSimd.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - ---*/ - -#include "mlasi.h" - -template -size_t -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. 
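The kernel body below reads matrix B sixteen floats at a time and advances by sixteen per k step, which implies B has been packed into 16-column panels, each laid out as CountK rows of 16 contiguous values. The sketch below illustrates such a panel packing; it is inferred from the kernel's access pattern and is not the actual MlasSgemmCopyPackB code, and PackBPanels16 is a hypothetical name.

#include <algorithm>
#include <cstddef>

// Illustrative sketch: pack a row-major B (CountK x ldb) into 16-column panels
// so an inner kernel can load 16 consecutive floats per k step and simply
// advance by 16 afterwards. Columns beyond CountN are zero-padded.
void PackBPanels16(float* PackedB, const float* B, size_t ldb,
                   size_t CountK, size_t CountN)
{
    for (size_t n = 0; n < CountN; n += 16) {
        const size_t width = std::min<size_t>(16, CountN - n);
        for (size_t k = 0; k < CountK; k++) {
            for (size_t j = 0; j < 16; j++) {
                *PackedB++ = (j < width) ? B[k * ldb + n + j] : 0.0f;
            }
        }
    }
}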
- ---*/ -{ - MLAS_FLOAT32X4 Row0Block0; - MLAS_FLOAT32X4 Row0Block1; - MLAS_FLOAT32X4 Row0Block2; - MLAS_FLOAT32X4 Row0Block3; - - MLAS_FLOAT32X4 Row1Block0; - MLAS_FLOAT32X4 Row1Block1; - MLAS_FLOAT32X4 Row1Block2; - MLAS_FLOAT32X4 Row1Block3; - -#if defined(_WIN32) - - if (!ProcessTwoRows) { - UNREFERENCED_PARAMETER(lda); - UNREFERENCED_PARAMETER(ldc); - } - -#endif - - MLAS_FLOAT32X4 Alpha = MlasBroadcastFloat32x4(alpha); - - do { - - MLAS_FLOAT32X4 BElements0; - MLAS_FLOAT32X4 BElements1; - MLAS_FLOAT32X4 BElements2; - MLAS_FLOAT32X4 BElements3; - - float Row0AElements0; - float Row0AElements1; - float Row1AElements0; - float Row1AElements1; - - // - // Clear the block accumulators. - // - - Row0Block0 = MlasZeroFloat32x4(); - Row0Block1 = MlasZeroFloat32x4(); - Row0Block2 = MlasZeroFloat32x4(); - Row0Block3 = MlasZeroFloat32x4(); - - if (ProcessTwoRows) { - Row1Block0 = MlasZeroFloat32x4(); - Row1Block1 = MlasZeroFloat32x4(); - Row1Block2 = MlasZeroFloat32x4(); - Row1Block3 = MlasZeroFloat32x4(); - } - - // - // Compute the 16x1 or 16x2 output block. - // - - const float* a = A; - size_t k = CountK; - - while (k >= 2) { - - Row0AElements0 = a[0]; - Row0AElements1 = a[1]; - - if (ProcessTwoRows) { - Row1AElements0 = a[lda]; - Row1AElements1 = a[lda + 1]; - } - - BElements0 = MlasLoadFloat32x4(B + 0); - BElements1 = MlasLoadFloat32x4(B + 4); - BElements2 = MlasLoadFloat32x4(B + 8); - BElements3 = MlasLoadFloat32x4(B + 12); - - Row0Block0 = MlasMultiplyAddFloat32x4(BElements0, Row0AElements0, Row0Block0); - Row0Block1 = MlasMultiplyAddFloat32x4(BElements1, Row0AElements0, Row0Block1); - Row0Block2 = MlasMultiplyAddFloat32x4(BElements2, Row0AElements0, Row0Block2); - Row0Block3 = MlasMultiplyAddFloat32x4(BElements3, Row0AElements0, Row0Block3); - - if (ProcessTwoRows) { - Row1Block0 = MlasMultiplyAddFloat32x4(BElements0, Row1AElements0, Row1Block0); - Row1Block1 = MlasMultiplyAddFloat32x4(BElements1, Row1AElements0, Row1Block1); - Row1Block2 = MlasMultiplyAddFloat32x4(BElements2, Row1AElements0, Row1Block2); - Row1Block3 = MlasMultiplyAddFloat32x4(BElements3, Row1AElements0, Row1Block3); - } - - BElements0 = MlasLoadFloat32x4(B + 16); - BElements1 = MlasLoadFloat32x4(B + 20); - BElements2 = MlasLoadFloat32x4(B + 24); - BElements3 = MlasLoadFloat32x4(B + 28); - - Row0Block0 = MlasMultiplyAddFloat32x4(BElements0, Row0AElements1, Row0Block0); - Row0Block1 = MlasMultiplyAddFloat32x4(BElements1, Row0AElements1, Row0Block1); - Row0Block2 = MlasMultiplyAddFloat32x4(BElements2, Row0AElements1, Row0Block2); - Row0Block3 = MlasMultiplyAddFloat32x4(BElements3, Row0AElements1, Row0Block3); - - if (ProcessTwoRows) { - Row1Block0 = MlasMultiplyAddFloat32x4(BElements0, Row1AElements1, Row1Block0); - Row1Block1 = MlasMultiplyAddFloat32x4(BElements1, Row1AElements1, Row1Block1); - Row1Block2 = MlasMultiplyAddFloat32x4(BElements2, Row1AElements1, Row1Block2); - Row1Block3 = MlasMultiplyAddFloat32x4(BElements3, Row1AElements1, Row1Block3); - } - - a += 2; - B += 32; - k -= 2; - } - - if (k > 0) { - - Row0AElements0 = a[0]; - - if (ProcessTwoRows) { - Row1AElements0 = a[lda]; - } - - BElements0 = MlasLoadFloat32x4(B + 0); - BElements1 = MlasLoadFloat32x4(B + 4); - BElements2 = MlasLoadFloat32x4(B + 8); - BElements3 = MlasLoadFloat32x4(B + 12); - - Row0Block0 = MlasMultiplyAddFloat32x4(BElements0, Row0AElements0, Row0Block0); - Row0Block1 = MlasMultiplyAddFloat32x4(BElements1, Row0AElements0, Row0Block1); - Row0Block2 = MlasMultiplyAddFloat32x4(BElements2, Row0AElements0, Row0Block2); - Row0Block3 = 
MlasMultiplyAddFloat32x4(BElements3, Row0AElements0, Row0Block3); - - if (ProcessTwoRows) { - Row1Block0 = MlasMultiplyAddFloat32x4(BElements0, Row1AElements0, Row1Block0); - Row1Block1 = MlasMultiplyAddFloat32x4(BElements1, Row1AElements0, Row1Block1); - Row1Block2 = MlasMultiplyAddFloat32x4(BElements2, Row1AElements0, Row1Block2); - Row1Block3 = MlasMultiplyAddFloat32x4(BElements3, Row1AElements0, Row1Block3); - } - - B += 16; - } - - // - // Multiply by the alpha value. - // - - Row0Block0 = MlasMultiplyFloat32x4(Row0Block0, Alpha); - Row0Block1 = MlasMultiplyFloat32x4(Row0Block1, Alpha); - Row0Block2 = MlasMultiplyFloat32x4(Row0Block2, Alpha); - Row0Block3 = MlasMultiplyFloat32x4(Row0Block3, Alpha); - - if (ProcessTwoRows) { - Row1Block0 = MlasMultiplyFloat32x4(Row1Block0, Alpha); - Row1Block1 = MlasMultiplyFloat32x4(Row1Block1, Alpha); - Row1Block2 = MlasMultiplyFloat32x4(Row1Block2, Alpha); - Row1Block3 = MlasMultiplyFloat32x4(Row1Block3, Alpha); - } - - if (CountN >= 16) { - - // - // Store the entire output block. - // - - if (!ZeroMode) { - Row0Block0 = MlasAddFloat32x4(Row0Block0, MlasLoadFloat32x4(C)); - Row0Block1 = MlasAddFloat32x4(Row0Block1, MlasLoadFloat32x4(C + 4)); - Row0Block2 = MlasAddFloat32x4(Row0Block2, MlasLoadFloat32x4(C + 8)); - Row0Block3 = MlasAddFloat32x4(Row0Block3, MlasLoadFloat32x4(C + 12)); - } - - MlasStoreFloat32x4(C, Row0Block0); - MlasStoreFloat32x4(C + 4, Row0Block1); - MlasStoreFloat32x4(C + 8, Row0Block2); - MlasStoreFloat32x4(C + 12, Row0Block3); - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0 = MlasAddFloat32x4(Row1Block0, MlasLoadFloat32x4(C + ldc)); - Row1Block1 = MlasAddFloat32x4(Row1Block1, MlasLoadFloat32x4(C + ldc + 4)); - Row1Block2 = MlasAddFloat32x4(Row1Block2, MlasLoadFloat32x4(C + ldc + 8)); - Row1Block3 = MlasAddFloat32x4(Row1Block3, MlasLoadFloat32x4(C + ldc + 12)); - } - - MlasStoreFloat32x4(C + ldc, Row1Block0); - MlasStoreFloat32x4(C + ldc + 4, Row1Block1); - MlasStoreFloat32x4(C + ldc + 8, Row1Block2); - MlasStoreFloat32x4(C + ldc + 12, Row1Block3); - } - - } else { - - // - // Store the partial output block. 
- // - - if ((CountN & 8) != 0) { - - if (!ZeroMode) { - Row0Block0 = MlasAddFloat32x4(Row0Block0, MlasLoadFloat32x4(C)); - Row0Block1 = MlasAddFloat32x4(Row0Block1, MlasLoadFloat32x4(C + 4)); - } - - MlasStoreFloat32x4(C, Row0Block0); - MlasStoreFloat32x4(C + 4, Row0Block1); - Row0Block0 = Row0Block2; - Row0Block1 = Row0Block3; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0 = MlasAddFloat32x4(Row1Block0, MlasLoadFloat32x4(C + ldc)); - Row1Block1 = MlasAddFloat32x4(Row1Block1, MlasLoadFloat32x4(C + ldc + 4)); - } - - MlasStoreFloat32x4(C + ldc, Row1Block0); - MlasStoreFloat32x4(C + ldc + 4, Row1Block1); - Row1Block0 = Row1Block2; - Row1Block1 = Row1Block3; - } - - C += 8; - } - - if ((CountN & 4) != 0) { - - if (!ZeroMode) { - Row0Block0 = MlasAddFloat32x4(Row0Block0, MlasLoadFloat32x4(C)); - } - - MlasStoreFloat32x4(C, Row0Block0); - Row0Block0 = Row0Block1; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block0 = MlasAddFloat32x4(Row1Block0, MlasLoadFloat32x4(C + ldc)); - } - - MlasStoreFloat32x4(C + ldc, Row1Block0); - Row1Block0 = Row1Block1; - } - - C += 4; - } - - float Row0Block00 = MlasExtractLaneFloat32x4<0>(Row0Block0); - float Row0Block01 = MlasExtractLaneFloat32x4<1>(Row0Block0); - float Row1Block00; - float Row1Block01; - - if (ProcessTwoRows) { - Row1Block00 = MlasExtractLaneFloat32x4<0>(Row1Block0); - Row1Block01 = MlasExtractLaneFloat32x4<1>(Row1Block0); - } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Row0Block00 = Row0Block00 + C[0]; - Row0Block01 = Row0Block01 + C[1]; - } - - - C[0] = Row0Block00; - C[1] = Row0Block01; - Row0Block00 = MlasExtractLaneFloat32x4<2>(Row0Block0); - Row0Block01 = MlasExtractLaneFloat32x4<3>(Row0Block0); - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block00 = Row1Block00 + C[ldc]; - Row1Block01 = Row1Block01 + C[ldc + 1]; - } - - C[ldc] = Row1Block00; - C[ldc + 1] = Row1Block01; - Row1Block00 = MlasExtractLaneFloat32x4<2>(Row1Block0); - Row1Block01 = MlasExtractLaneFloat32x4<3>(Row1Block0); - } - - C += 2; - } - - if ((CountN & 1) != 0) { - - if (!ZeroMode) { - Row0Block00 = Row0Block00 + C[0]; - } - - C[0] = Row0Block00; - - if (ProcessTwoRows) { - - if (!ZeroMode) { - Row1Block00 = Row1Block00 + C[ldc]; - } - - C[ldc] = Row1Block00; - } - } - - break; - } - - C += 16; - CountN -= 16; - - } while (CountN > 0); - - return ProcessTwoRows ? 2 : 1; -} - -template -size_t -MlasSgemmKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. 
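Stripped of the 16-wide vectorization, the two-row unrolling, and the partial-store cases, the kernel above implements one contract: accumulate A[k] times the matching slice of B, scale by alpha, then either overwrite C (ZeroMode) or add into it. A minimal scalar sketch of that contract follows; for clarity it takes a plain row-major B with leading dimension ldb instead of the packed panels used by the real kernel.

#include <cstddef>

// Simplified scalar sketch of the SGEMM inner-kernel contract shown above,
// handling one row of A and arbitrary CountN.
template <bool ZeroMode>
void SgemmKernelScalar(const float* A, const float* B, float* C,
                       size_t CountK, size_t CountN, size_t ldb, float alpha)
{
    for (size_t n = 0; n < CountN; n++) {
        float acc = 0.0f;
        for (size_t k = 0; k < CountK; k++) {
            acc += A[k] * B[k * ldb + n];        // accumulate A[k] * B[k][n]
        }
        acc *= alpha;                            // multiply by the alpha scaler
        C[n] = ZeroMode ? acc : C[n] + acc;      // overwrite or accumulate into C
    }
}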
- ---*/ -{ - size_t RowsHandled; - - if (CountM >= 2) { - RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountN, lda, ldc, alpha); - } else { - RowsHandled = MlasSgemmKernel(A, B, C, CountK, CountN, lda, ldc, alpha); - } - - return RowsHandled; -} - -size_t -MLASCALL -MlasSgemmKernelZero( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - return MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); -} - -size_t -MLASCALL -MlasSgemmKernelAdd( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountM, - size_t CountN, - size_t lda, - size_t ldc, - float alpha - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - alpha - Supplies the scaler multiplier (see SGEMM definition). - -Return Value: - - Returns the number of rows handled. - ---*/ -{ - return MlasSgemmKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); -} diff --git a/onnxruntime/core/mlas/lib/wasm_simd/SgemvKernelWasmSimd.cpp b/onnxruntime/core/mlas/lib/wasm_simd/SgemvKernelWasmSimd.cpp deleted file mode 100644 index a46efd4093a34..0000000000000 --- a/onnxruntime/core/mlas/lib/wasm_simd/SgemvKernelWasmSimd.cpp +++ /dev/null @@ -1,158 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemvKernelWasmSimd.cpp - -Abstract: - - This module implements the kernels for the single precision matrix/vector - multiply operation (SGEMV). 
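The MlasGemvFloatKernel implementation removed below handles the M == 1 case by treating C as a running vector: each scalar A[k] scales row k of B (which is not transposed) and is accumulated into C, with the first k step used to initialize C when ZeroMode is set. A scalar sketch of that strategy, without the 4-row unrolling and WASM SIMD vectorization of the deleted code:

#include <cstddef>

// Scalar sketch of the GEMV (M == 1) strategy used below:
// C[0..CountN) accumulates A[k] * B[k][0..CountN) for every k.
void GemvFloatScalar(const float* A, const float* B, float* C,
                     size_t CountK, size_t CountN, size_t ldb, bool ZeroMode)
{
    if (ZeroMode) {
        for (size_t n = 0; n < CountN; n++) {
            C[n] = 0.0f;                 // the real kernel folds this into the first k step
        }
    }
    for (size_t k = 0; k < CountK; k++) {
        const float a = A[k];
        const float* b = B + k * ldb;    // row k of B
        for (size_t n = 0; n < CountN; n++) {
            C[n] += a * b[n];
        }
    }
}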
- ---*/ - -#include "mlasi.h" - -void -MLASCALL -MlasGemvFloatKernel( - const float* A, - const float* B, - float* C, - size_t CountK, - size_t CountN, - size_t ldb, - bool ZeroMode - ) -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. This handles the special case of M=1. - - The elements in matrix B are not transposed. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb - Supplies the first dimension of matrix B. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - None. - ---*/ -{ - if (ZeroMode && CountK > 0) { - float* c = C; - const float* b = B; - const MLAS_FLOAT32X4 A0 = MlasBroadcastFloat32x4(A); - auto N = CountN; - for (; N >= 4; N -= 4) { - MlasStoreFloat32x4(c, MlasMultiplyFloat32x4(A0, MlasLoadFloat32x4(b))); - b += 4; - c += 4; - } - for (; N > 0; N--) { - c[0] = A[0] * b[0]; - c++; - b++; - } - CountK--; - B += ldb; - A++; - } - - for (; CountK >= 4; CountK -= 4) { - float* c = C; - const float* b = B; - const float* b2 = b + ldb * 2; - - const MLAS_FLOAT32X4 A0 = MlasBroadcastFloat32x4(A); - const MLAS_FLOAT32X4 A1 = MlasBroadcastFloat32x4(A + 1); - const MLAS_FLOAT32X4 A2 = MlasBroadcastFloat32x4(A + 2); - const MLAS_FLOAT32X4 A3 = MlasBroadcastFloat32x4(A + 3); - - auto N = CountN; - constexpr size_t kWide = 8; - for(; N >= kWide; N -= kWide) { - MLAS_FLOAT32X4 vec_c0 = MlasMultiplyAddFloat32x4(A0, MlasLoadFloat32x4(b), MlasLoadFloat32x4(c)); - MLAS_FLOAT32X4 vec_c1 = MlasMultiplyAddFloat32x4(A0, MlasLoadFloat32x4(b + 4), MlasLoadFloat32x4(c + 4)); - - vec_c0 = MlasMultiplyAddFloat32x4(A1, MlasLoadFloat32x4(b + ldb), vec_c0); - vec_c1 = MlasMultiplyAddFloat32x4(A1, MlasLoadFloat32x4(b + ldb + 4), vec_c1); - - vec_c0 = MlasMultiplyAddFloat32x4(A2, MlasLoadFloat32x4(b2), vec_c0); - vec_c1 = MlasMultiplyAddFloat32x4(A2, MlasLoadFloat32x4(b2 + 4), vec_c1); - - vec_c0 = MlasMultiplyAddFloat32x4(A3, MlasLoadFloat32x4(b2 + ldb), vec_c0); - vec_c1 = MlasMultiplyAddFloat32x4(A3, MlasLoadFloat32x4(b2 + ldb + 4), vec_c1); - - MlasStoreFloat32x4(c, vec_c0); - MlasStoreFloat32x4(c + 4, vec_c1); - - b += kWide; - b2 += kWide; - c += kWide; - } - - for (; N >= 4; N -= 4) { - MLAS_FLOAT32X4 vec_c0 = MlasMultiplyAddFloat32x4(MlasLoadFloat32x4(b), A0, MlasLoadFloat32x4(c)); - vec_c0 = MlasMultiplyAddFloat32x4(MlasLoadFloat32x4(b + ldb), A1, vec_c0); - vec_c0 = MlasMultiplyAddFloat32x4(MlasLoadFloat32x4(b2), A2, vec_c0); - vec_c0 = MlasMultiplyAddFloat32x4(MlasLoadFloat32x4(b2 + ldb), A3, vec_c0); - MlasStoreFloat32x4(c, vec_c0); - b += 4; - b2 += 4; - c += 4; - } - - for (; N > 0; N--) { - c[0] += A[0] * b[0] + A[1] * b[ldb] + A[2] * b2[0] + A[3] * b2[ldb]; - b++; - b2++; - c++; - } - - B += 4 * ldb; - A += 4; - } - - for (; CountK > 0; CountK--) { - float* c = C; - const float* b = B; - const MLAS_FLOAT32X4 A0 = MlasBroadcastFloat32x4(A); - auto N = CountN; - for (; N >= 4; N -= 4) { - MlasStoreFloat32x4(c, MlasMultiplyAddFloat32x4(MlasLoadFloat32x4(b), A0, MlasLoadFloat32x4(c))); - b += 4; - c += 4; - } - for (; N > 0; N--) { - c[0] += A[0] * b[0]; - c++; - b++; - } - B += ldb; - A++; - } -} diff --git 
a/onnxruntime/core/mlas/lib/x86/SgemmKernelAvx.S b/onnxruntime/core/mlas/lib/x86/SgemmKernelAvx.S deleted file mode 100644 index 7af2a9e118e5e..0000000000000 --- a/onnxruntime/core/mlas/lib/x86/SgemmKernelAvx.S +++ /dev/null @@ -1,435 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelAvx.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the SGEMM kernel. -// - - .equ .LSgemmKernelFrame_SavedEdi, 0 - .equ .LSgemmKernelFrame_SavedEsi, 4 - .equ .LSgemmKernelFrame_SavedEbx, 8 - .equ .LSgemmKernelFrame_SavedEbp, 12 - .equ .LSgemmKernelFrame_ReturnAddress, 16 - .equ .LSgemmKernelFrame_MatrixA, 20 - .equ .LSgemmKernelFrame_MatrixB, 24 - .equ .LSgemmKernelFrame_MatrixC, 28 - .equ .LSgemmKernelFrame_CountK, 32 - .equ .LSgemmKernelFrame_CountM, 36 - .equ .LSgemmKernelFrame_CountN, 40 - .equ .LSgemmKernelFrame_lda, 44 - .equ .LSgemmKernelFrame_ldc, 48 - .equ .LSgemmKernelFrame_alpha, 52 - .equ .LSgemmKernelFrame_ZeroMode, 56 - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for a 16xN block of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - ebx - Supplies the length in bytes of a row from matrix A. - - ecx - Supplies the address into the matrix A data. - - edx - Supplies the address into the matrix B data. - - ymm4-ymm7 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvxBy16 RowCount, VectorOffset, BroadcastOffset - -.if \RowCount\() == 1 - vbroadcastss ymm3,DWORD PTR [ecx+\BroadcastOffset\()] - vmulps ymm1,ymm3,YMMWORD PTR [edx+\VectorOffset\()] - vaddps ymm4,ymm1,ymm4 - vmulps ymm3,ymm3,YMMWORD PTR [edx+\VectorOffset\()+32] - vaddps ymm5,ymm3,ymm5 -.else - vmovaps ymm0,YMMWORD PTR [edx+\VectorOffset\()] - vmovaps ymm1,YMMWORD PTR [edx+\VectorOffset\()+32] - vbroadcastss ymm3,DWORD PTR [ecx+\BroadcastOffset\()] - vmulps ymm2,ymm3,ymm0 - vaddps ymm4,ymm2,ymm4 - vmulps ymm2,ymm3,ymm1 - vaddps ymm5,ymm2,ymm5 - vbroadcastss ymm3,DWORD PTR [ecx+ebx+\BroadcastOffset\()] - vmulps ymm2,ymm3,ymm0 - vaddps ymm6,ymm2,ymm6 - vmulps ymm2,ymm3,ymm1 - vaddps ymm7,ymm2,ymm7 -.endif - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for a 8xN block of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - ebx - Supplies the length in bytes of a row from matrix A. - - ecx - Supplies the address into the matrix A data. - - edx - Supplies the address into the matrix B data. - - ymm4-ymm7 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockAvxBy8 RowCount, VectorOffset, BroadcastOffset - -.if \RowCount\() == 1 - vbroadcastss ymm3,DWORD PTR [ecx+\BroadcastOffset\()] - vmulps ymm3,ymm3,YMMWORD PTR [edx+\VectorOffset\()] - vaddps ymm5,ymm3,ymm5 -.else - vmovaps ymm0,YMMWORD PTR [edx+\VectorOffset\()] - vbroadcastss ymm3,DWORD PTR [ecx+\BroadcastOffset\()] - vmulps ymm3,ymm3,ymm0 - vaddps ymm5,ymm3,ymm5 - vbroadcastss ymm3,DWORD PTR [ecx+ebx+\BroadcastOffset\()] - vmulps ymm3,ymm3,ymm0 - vaddps ymm7,ymm3,ymm7 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - ebx - Supplies the number of bytes to the next row of matrix A. - - ecx - Supplies the address into the matrix A data. - - edx - Supplies the address into the matrix B data. - - edi - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - ymm4-ymm7 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvxLoop ComputeBlock, RowCount - - sub edi,4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - \ComputeBlock\() \RowCount\(), 0, 0 - \ComputeBlock\() \RowCount\(), 16*4, 4 - sub edx,-32*4 # advance matrix B by 32 columns - \ComputeBlock\() \RowCount\(), 0, 8 - \ComputeBlock\() \RowCount\(), 16*4, 12 - sub edx,-32*4 # advance matrix B by 32 columns - add ecx,4*4 # advance matrix A by 4 columns - sub edi,4 - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add edi,4 # correct for over-subtract above - jz .LOutputBlock\@ - -.LComputeBlockBy1Loop\@: - \ComputeBlock\() \RowCount\(), 0, 0 - add edx,16*4 # advance matrix B by 16 columns - add ecx,4 # advance matrix A by 1 column - dec edi - jne .LComputeBlockBy1Loop\@ - -.LOutputBlock\@: - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to iterate - over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - Alpha - Supplies the scalar multiplier (see SGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmFloatKernelAvx - - push ebp - push ebx - push esi - push edi - mov edx,.LSgemmKernelFrame_MatrixB[esp] - mov esi,.LSgemmKernelFrame_MatrixC[esp] - mov ebp,.LSgemmKernelFrame_CountN[esp] - -// -// Process 2 rows of the matrices. 
-// - - cmp DWORD PTR .LSgemmKernelFrame_CountM[esp],2 - jb .LProcessCountMLessThan2 - mov BYTE PTR .LSgemmKernelFrame_CountM[esp],2 - mov eax,.LSgemmKernelFrame_ldc[esp] - mov ebx,.LSgemmKernelFrame_lda[esp] - shl eax,2 # convert ldc to bytes - shl ebx,2 # convert lda to bytes - cmp ebp,8 - jbe .LProcessRemainingCountN2 - -.LProcessNextColumnLoop16x2: - mov edi,.LSgemmKernelFrame_CountK[esp] - mov ecx,.LSgemmKernelFrame_MatrixA[esp] - vxorps xmm4,xmm4,xmm4 # clear block accumulators - vxorps xmm5,xmm5,xmm5 - vxorps xmm6,xmm6,xmm6 - vxorps xmm7,xmm7,xmm7 - ComputeBlockAvxLoop ComputeBlockAvxBy16, 2 - vbroadcastss ymm2,DWORD PTR .LSgemmKernelFrame_alpha[esp] - vmulps ymm4,ymm4,ymm2 # multiply by alpha - vmulps ymm5,ymm5,ymm2 - vmulps ymm6,ymm6,ymm2 - vmulps ymm7,ymm7,ymm2 - sub ebp,16 - jb .LOutputMasked16x2Block - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateOutput16x2 - vaddps ymm4,ymm4,YMMWORD PTR [esi] - vaddps ymm5,ymm5,YMMWORD PTR [esi+32] - vaddps ymm6,ymm6,YMMWORD PTR [esi+eax] - vaddps ymm7,ymm7,YMMWORD PTR [esi+eax+32] - -.LSkipAccumulateOutput16x2: - vmovups YMMWORD PTR [esi],ymm4 - vmovups YMMWORD PTR [esi+32],ymm5 - vmovups YMMWORD PTR [esi+eax],ymm6 - vmovups YMMWORD PTR [esi+eax+32],ymm7 - add esi,16*4 # advance matrix C by 16 columns - cmp ebp,8 - ja .LProcessNextColumnLoop16x2 - test ebp,ebp - jz .LExitKernel - -.LProcessRemainingCountN2: - mov edi,.LSgemmKernelFrame_CountK[esp] - mov ecx,.LSgemmKernelFrame_MatrixA[esp] - vxorps xmm5,xmm5,xmm5 # clear block accumulators - vxorps xmm7,xmm7,xmm7 - ComputeBlockAvxLoop ComputeBlockAvxBy8, 2 - vbroadcastss ymm2,DWORD PTR .LSgemmKernelFrame_alpha[esp] - vmulps ymm5,ymm5,ymm2 # multiply by alpha - vmulps ymm7,ymm7,ymm2 - cmp ebp,8 - jb .LOutputMasked8x2Block - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateOutput8x2 - vaddps ymm5,ymm5,YMMWORD PTR [esi] - vaddps ymm7,ymm7,YMMWORD PTR [esi+eax] - -.LSkipAccumulateOutput8x2: - vmovups YMMWORD PTR [esi],ymm5 - vmovups YMMWORD PTR [esi+eax],ymm7 - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - movzx eax,BYTE PTR .LSgemmKernelFrame_CountM[esp] - vzeroupper - pop edi - pop esi - pop ebx - pop ebp - ret - -.LOutputMasked16x2Block: - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateMasked16x2Block - vaddps ymm4,ymm4,YMMWORD PTR [esi] - vaddps ymm6,ymm6,YMMWORD PTR [esi+eax] - -.LSkipAccumulateMasked16x2Block: - vmovups YMMWORD PTR [esi],ymm4 - vmovups YMMWORD PTR [esi+eax],ymm6 - add esi,8*4 # advance matrix C by 8 columns - add ebp,8 # correct for over-subtract above - -.LOutputMasked8x2Block: - neg ebp - LoadGlobalOffsetTable bx - mov ebx,DWORD PTR C_UNDERSCORE(MlasMaskMoveTableAvx)@GOT[ebx] - vmovdqu ymm0,YMMWORD PTR [ebx+ebp*4+8*4] - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateMasked8x2Block - vmaskmovps ymm4,ymm0,YMMWORD PTR [esi] - vmaskmovps ymm6,ymm0,YMMWORD PTR [esi+eax] - vaddps ymm5,ymm5,ymm4 - vaddps ymm7,ymm7,ymm6 - -.LSkipAccumulateMasked8x2Block: - vmaskmovps YMMWORD PTR [esi],ymm0,ymm5 - vmaskmovps YMMWORD PTR [esi+eax],ymm0,ymm7 - jmp .LExitKernel - -// -// Process 1 row of the matrices. 
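The masked output paths above load a lane mask from MlasMaskMoveTableAvx, indexed by the negated remaining-column count, and use vmaskmovps so the final partial block of C can be read and written without touching memory past the end of the row. The intrinsics sketch below shows the same idea; the table layout assumed here (eight all-ones dwords followed by eight zeros) is inferred from the indexing above and is an assumption, as is the MaskedTailStore name.

#include <immintrin.h>
#include <cstddef>
#include <cstdint>

// Hypothetical mask table in the spirit of MlasMaskMoveTableAvx: loading eight
// dwords starting at index (8 - n) yields a mask whose first n lanes are set.
alignas(32) static const uint32_t MaskMoveTable[16] = {
    0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu,
    0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu,
    0, 0, 0, 0, 0, 0, 0, 0,
};

// Store the first n (n < 8) lanes of an accumulator into C, optionally adding
// the existing C values first, without reading or writing past C[n - 1].
void MaskedTailStore(float* C, __m256 Accumulator, size_t n, bool ZeroMode)
{
    const __m256i mask = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(&MaskMoveTable[8 - n]));
    if (!ZeroMode) {
        const __m256 existing = _mm256_maskload_ps(C, mask);   // masked read of C
        Accumulator = _mm256_add_ps(Accumulator, existing);
    }
    _mm256_maskstore_ps(C, mask, Accumulator);                  // masked write of C
}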
-// - -.LProcessCountMLessThan2: - mov BYTE PTR .LSgemmKernelFrame_CountM[esp],1 - mov ebx,.LSgemmKernelFrame_MatrixA[esp] - vbroadcastss ymm2,DWORD PTR .LSgemmKernelFrame_alpha[esp] - cmp ebp,8 - jbe .LProcessRemainingCountN1 - -.LProcessNextColumnLoop16x1: - mov edi,.LSgemmKernelFrame_CountK[esp] - mov ecx,ebx # reload matrix A - vxorps xmm4,xmm4,xmm4 # clear block accumulators - vxorps xmm5,xmm5,xmm5 - ComputeBlockAvxLoop ComputeBlockAvxBy16, 1 - vmulps ymm4,ymm4,ymm2 # multiply by alpha - vmulps ymm5,ymm5,ymm2 - sub ebp,16 - jb .LOutputMasked16x1Block - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulate16x1Block - vaddps ymm4,ymm4,YMMWORD PTR [esi] - vaddps ymm5,ymm5,YMMWORD PTR [esi+32] - -.LSkipAccumulate16x1Block: - vmovups YMMWORD PTR [esi],ymm4 - vmovups YMMWORD PTR [esi+32],ymm5 - add esi,16*4 # advance matrix C by 16 columns - cmp ebp,8 - ja .LProcessNextColumnLoop16x1 - test ebp,ebp - jz .LExitKernel - -.LProcessRemainingCountN1: - mov edi,.LSgemmKernelFrame_CountK[esp] - mov ecx,ebx # reload matrix A - vxorps xmm5,xmm5,xmm5 # clear block accumulators - ComputeBlockAvxLoop ComputeBlockAvxBy8, 1 - vmulps ymm5,ymm5,ymm2 # multiply by alpha - cmp ebp,8 - jb .LOutputMasked8x1Block - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulate8x1Block - vaddps ymm5,ymm5,YMMWORD PTR [esi] - -.LSkipAccumulate8x1Block: - vmovups YMMWORD PTR [esi],ymm5 - jmp .LExitKernel - -.LOutputMasked16x1Block: - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateMasked16x1Block - vaddps ymm4,ymm4,YMMWORD PTR [esi] - -.LSkipAccumulateMasked16x1Block: - vmovups YMMWORD PTR [esi],ymm4 - add esi,8*4 # advance matrix C by 8 columns - add ebp,8 # correct for over-subtract above - -.LOutputMasked8x1Block: - neg ebp - LoadGlobalOffsetTable bx - mov ebx,DWORD PTR C_UNDERSCORE(MlasMaskMoveTableAvx)@GOT[ebx] - vmovdqu ymm0,YMMWORD PTR [ebx+ebp*4+8*4] - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateMasked8x1Block - vmaskmovps ymm4,ymm0,YMMWORD PTR [esi] - vaddps ymm5,ymm5,ymm4 - -.LSkipAccumulateMasked8x1Block: - vmaskmovps YMMWORD PTR [esi],ymm0,ymm5 - jmp .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/x86/SgemmKernelSse2.S b/onnxruntime/core/mlas/lib/x86/SgemmKernelSse2.S deleted file mode 100644 index f42175ec6c5a5..0000000000000 --- a/onnxruntime/core/mlas/lib/x86/SgemmKernelSse2.S +++ /dev/null @@ -1,406 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelSse2.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses SSE2 instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the SGEMM kernel. -// - - .equ .LSgemmKernelFrame_SavedEdi, 0 - .equ .LSgemmKernelFrame_SavedEsi, 4 - .equ .LSgemmKernelFrame_SavedEbx, 8 - .equ .LSgemmKernelFrame_SavedEbp, 12 - .equ .LSgemmKernelFrame_ReturnAddress, 16 - .equ .LSgemmKernelFrame_MatrixA, 20 - .equ .LSgemmKernelFrame_MatrixB, 24 - .equ .LSgemmKernelFrame_MatrixC, 28 - .equ .LSgemmKernelFrame_CountK, 32 - .equ .LSgemmKernelFrame_CountM, 36 - .equ .LSgemmKernelFrame_CountN, 40 - .equ .LSgemmKernelFrame_lda, 44 - .equ .LSgemmKernelFrame_ldc, 48 - .equ .LSgemmKernelFrame_alpha, 52 - .equ .LSgemmKernelFrame_ZeroMode, 56 - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for a Nx1 block of the output matrix. 
- -Arguments: - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - Shuffle - Supplies the shuffle mask to extract the element from matrix A. - -Implicit Arguments: - - ebx - Supplies the length in bytes of a row from matrix A. - - ecx - Supplies the address into the matrix A data. - - edx - Supplies the address into the matrix B data. - - xmm2 - Supplies up to four elements loaded from matrix A. - - xmm4-xmm7 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockSseBy4 VectorOffset, Shuffle - - pshufd xmm3,xmm1,\Shuffle\() - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()] - mulps xmm0,xmm3 - addps xmm4,xmm0 - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()+16] - mulps xmm0,xmm3 - addps xmm5,xmm0 - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()+32] - mulps xmm0,xmm3 - addps xmm6,xmm0 - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()+48] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - .endm - - .macro ComputeBlockSseBy3 VectorOffset, Shuffle - - pshufd xmm3,xmm1,\Shuffle\() - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()] - mulps xmm0,xmm3 - addps xmm5,xmm0 - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()+16] - mulps xmm0,xmm3 - addps xmm6,xmm0 - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()+32] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - .endm - - .macro ComputeBlockSseBy2 VectorOffset, Shuffle - - pshufd xmm3,xmm1,\Shuffle\() - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()] - mulps xmm0,xmm3 - addps xmm6,xmm0 - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()+16] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - .endm - - .macro ComputeBlockSseBy1 VectorOffset, Shuffle - - pshufd xmm3,xmm1,\Shuffle\() - movaps xmm0,XMMWORD PTR [edx+\VectorOffset\()] - mulps xmm0,xmm3 - addps xmm7,xmm0 - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - ebx - Supplies the number of bytes to the next row of matrix A. - - ecx - Supplies the address into the matrix A data. - - edx - Supplies the address into the matrix B data. - - edi - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - xmm4-xmm7 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockSseLoop RowCount - - sub edi,4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - movups xmm1,XMMWORD PTR [ecx] - ComputeBlockSseBy\RowCount\() 0, 0x00 - ComputeBlockSseBy\RowCount\() 16*4, 0x55 - sub edx,-32*4 # advance matrix B by 32 columns - ComputeBlockSseBy\RowCount\() 0, 0xAA - ComputeBlockSseBy\RowCount\() 16*4, 0xFF - sub edx,-32*4 # advance matrix B by 32 columns - add ecx,4*4 # advance matrix A by 4 columns - sub edi,4 - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add edi,4 # correct for over-subtract above - jz .LOutputBlock\@ - -.LComputeBlockBy1Loop\@: - movss xmm1,DWORD PTR [ecx] - ComputeBlockSseBy\RowCount\() 0, 0x00 - add edx,16*4 # advance matrix B by 16 columns - add ecx,4 # advance matrix A by 1 column - dec edi - jne .LComputeBlockBy1Loop\@ - -.LOutputBlock\@: - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A - Supplies the address of matrix A. - - B - Supplies the address of matrix B. 
The matrix data has been packed using - MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C - Supplies the address of matrix C. - - CountK - Supplies the number of columns from matrix A and the number of - rows from matrix B to iterate over. - - CountM - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - Alpha - Supplies the scalar multiplier (see SGEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY MlasGemmFloatKernelSse - - push ebp - push ebx - push esi - push edi - mov edx,.LSgemmKernelFrame_MatrixB[esp] - mov esi,.LSgemmKernelFrame_MatrixC[esp] - mov ebp,.LSgemmKernelFrame_CountN[esp] - -// -// Process 1 row of the matrices. -// - - mov eax,.LSgemmKernelFrame_CountK[esp] - mov ebx,.LSgemmKernelFrame_MatrixA[esp] - cmp ebp,12 - jbe .LProcessRemainingCountN - -.LProcessNextColumnLoop16x1: - mov edi,eax # reload CountK - mov ecx,ebx # reload matrix A - xorps xmm4,xmm4 # clear block accumulators - xorps xmm5,xmm5 - xorps xmm6,xmm6 - xorps xmm7,xmm7 - ComputeBlockSseLoop 4 - movss xmm2,DWORD PTR .LSgemmKernelFrame_alpha[esp] - shufps xmm2,xmm2,0 - mulps xmm4,xmm2 # multiply by alpha - mulps xmm5,xmm2 - mulps xmm6,xmm2 - mulps xmm7,xmm2 - sub ebp,16 - jb .LOutputMasked16x1Block - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateOutput16x1 - movups xmm0,XMMWORD PTR [esi] - movups xmm1,XMMWORD PTR [esi+16] - movups xmm2,XMMWORD PTR [esi+32] - movups xmm3,XMMWORD PTR [esi+48] - addps xmm4,xmm0 - addps xmm5,xmm1 - addps xmm6,xmm2 - addps xmm7,xmm3 - -.LSkipAccumulateOutput16x1: - movups XMMWORD PTR [esi],xmm4 - movups XMMWORD PTR [esi+16],xmm5 - movups XMMWORD PTR [esi+32],xmm6 - movups XMMWORD PTR [esi+48],xmm7 - add esi,16*4 # advance matrix C by 16 columns - cmp ebp,12 - ja .LProcessNextColumnLoop16x1 - test ebp,ebp - jnz .LProcessRemainingCountN - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,1 # return 1 row handled - pop edi - pop esi - pop ebx - pop ebp - ret - -// -// Process the remaining 1 to 12 columns of the matrices. 
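The ComputeBlockSse macros above broadcast one element of matrix A to all four lanes with pshufd (immediates 0x00, 0x55, 0xAA, 0xFF select elements 0 through 3) and multiply-add the result against consecutive 4-float slices of the packed B panel, using separate mulps/addps because SSE2 has no fused multiply-add. An equivalent intrinsics fragment, shown only to illustrate that broadcast-and-accumulate pattern (it uses the float-domain shuffle and unaligned loads, unlike the aligned integer-domain forms above):

#include <immintrin.h>

// One k-step of a 16-wide block (four __m128 accumulators), mirroring
// ComputeBlockSseBy4: broadcast lane `Lane` of a, multiply against B, accumulate.
template <int Lane>
void ComputeBlockSse(__m128 a, const float* B, __m128 Acc[4])
{
    const __m128 broadcast = _mm_shuffle_ps(a, a, _MM_SHUFFLE(Lane, Lane, Lane, Lane));
    for (int i = 0; i < 4; i++) {
        const __m128 b = _mm_loadu_ps(B + i * 4);
        Acc[i] = _mm_add_ps(Acc[i], _mm_mul_ps(broadcast, b));  // no FMA on SSE2
    }
}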
-// - -.LProcessRemainingCountN: - mov edi,eax # reload CountK - mov ecx,ebx # reload matrix A - movss xmm4,DWORD PTR .LSgemmKernelFrame_alpha[esp] - shufps xmm4,xmm4,0 - xorps xmm5,xmm5 # clear block accumulators - xorps xmm6,xmm6 - xorps xmm7,xmm7 - cmp ebp,4 - jbe .LProcessRemainingCountN4OrLess - cmp ebp,8 - jbe .LProcessRemainingCountN8OrLess - -.LProcessRemainingCountN12OrLess: - ComputeBlockSseLoop 3 - mulps xmm5,xmm4 # multiply by alpha - mulps xmm6,xmm4 - mulps xmm7,xmm4 - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateLeadingN12OrLess - movups xmm0,XMMWORD PTR [esi] - movups xmm1,XMMWORD PTR [esi+16] - addps xmm5,xmm0 - addps xmm6,xmm1 - -.LSkipAccumulateLeadingN12OrLess: - movups XMMWORD PTR [esi],xmm5 - movups XMMWORD PTR [esi+16],xmm6 - add esi,8*4 # advance matrix C by 8 columns - jmp .LOutputTrailingBlock - -.LProcessRemainingCountN8OrLess: - ComputeBlockSseLoop 2 - mulps xmm6,xmm4 # multiply by alpha - mulps xmm7,xmm4 - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateLeadingN8OrLess - movups xmm0,XMMWORD PTR [esi] - addps xmm6,xmm0 - -.LSkipAccumulateLeadingN8OrLess: - movups XMMWORD PTR [esi],xmm6 - add esi,4*4 # advance matrix C by 4 columns - jmp .LOutputTrailingBlock - -.LProcessRemainingCountN4OrLess: - ComputeBlockSseLoop 1 - mulps xmm7,xmm4 # multiply by alpha - jmp .LOutputTrailingBlock - -.LOutputMasked16x1Block: - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateLeading16x1Block - movups xmm0,XMMWORD PTR [esi] - movups xmm1,XMMWORD PTR [esi+16] - movups xmm2,XMMWORD PTR [esi+32] - addps xmm4,xmm0 - addps xmm5,xmm1 - addps xmm6,xmm2 - -.LSkipAccumulateLeading16x1Block: - movups XMMWORD PTR [esi],xmm4 - movups XMMWORD PTR [esi+16],xmm5 - movups XMMWORD PTR [esi+32],xmm6 - add esi,12*4 # advance matrix C by 12 columns - -.LOutputTrailingBlock: - test ebp,3 - jz .LOutputTrailingBlock4Elements - test ebp,2 - jz .LOutputTrailingBlock1Element - -.LOutputTrailingBlock2Elements: - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateTrailingBlock2Elements - movsd xmm0,QWORD PTR [esi] - addps xmm7,xmm0 - -.LSkipAccumulateTrailingBlock2Elements: - movsd QWORD PTR [esi],xmm7 - test ebp,1 - jz .LExitKernel - shufps xmm7,xmm7,0xAA # shuffle third float down - add esi,2*4 # advance matrix C by 2 columns - -.LOutputTrailingBlock1Element: - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateTrailingBlock1Element - movss xmm0,DWORD PTR [esi] - addss xmm7,xmm0 - -.LSkipAccumulateTrailingBlock1Element: - movss DWORD PTR [esi],xmm7 - jmp .LExitKernel - -.LOutputTrailingBlock4Elements: - cmp BYTE PTR .LSgemmKernelFrame_ZeroMode[esp],0 - jnz .LSkipAccumulateTrailingBlock4Elements - movups xmm0,XMMWORD PTR [esi] - addps xmm7,xmm0 - -.LSkipAccumulateTrailingBlock4Elements: - movups XMMWORD PTR [esi],xmm7 - jmp .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/x86/asmmacro.h b/onnxruntime/core/mlas/lib/x86/asmmacro.h deleted file mode 100644 index 4b80eea735ba9..0000000000000 --- a/onnxruntime/core/mlas/lib/x86/asmmacro.h +++ /dev/null @@ -1,79 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - asmmacro.h - -Abstract: - - This module implements common macros for the assembly modules. 
- ---*/ - -#if defined(__APPLE__) -#define C_UNDERSCORE(symbol) _##symbol -#else -#define C_UNDERSCORE(symbol) symbol -#endif - -/*++ - -Macro Description: - - This macro emits the assembler directives to annotate a new function. - -Arguments: - - FunctionName - Supplies the name of the function. - ---*/ - - .macro FUNCTION_ENTRY FunctionName - - .p2align 4 -#if defined(__APPLE__) - .globl _\FunctionName\() -_\FunctionName\(): -#else - .globl \FunctionName\() - .type \FunctionName\(),@function -\FunctionName\(): -#endif - - .endm - -/*++ - -Macro Description: - - This macro emits the code to load the global offset table address into the - supplied register. - -Arguments: - - TargetReg - Specifies the target register. - ---*/ - - .macro LoadGlobalOffsetTable, TargetReg - -// -// The LLVM integrated assembler doesn't support the Intel syntax for OFFSET: -// -// add ebx,OFFSET _GLOBAL_OFFSET_TABLE_ -// -// Workaround this by temporarily switching to AT&T syntax. -// - - .att_syntax - - calll __x86.get_pc_thunk.\TargetReg\() - addl $_GLOBAL_OFFSET_TABLE_,%e\TargetReg\() - - .intel_syntax noprefix - - .endm diff --git a/onnxruntime/core/mlas/lib/x86/x86.get_pc_thunk.S b/onnxruntime/core/mlas/lib/x86/x86.get_pc_thunk.S deleted file mode 100644 index fd81ddb7e3737..0000000000000 --- a/onnxruntime/core/mlas/lib/x86/x86.get_pc_thunk.S +++ /dev/null @@ -1,33 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - x86.get_pc_thunk.S - -Abstract: - - This module implements __x86.get_pc_thunk.* to avoid external dependency. - ---*/ - - .intel_syntax noprefix - -/*++ - -Routine Description: - - The routine loads its return address -- which is the address of the - instruction that immediately follows -- into the ebx register. - ---*/ - - .p2align 4 - .weak __x86.get_pc_thunk.bx - .type __x86.get_pc_thunk.bx,@function -__x86.get_pc_thunk.bx: - mov ebx, [esp] - ret diff --git a/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h b/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h deleted file mode 100644 index f02fc3bca47e2..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/AssembleAvx512Vnni.h +++ /dev/null @@ -1,246 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - AssembleAvx512Vnni.h - -Abstract: - - This module contains macros to build VNNI instructions for toolchains that - do not natively support this newer instruction set extension. - ---*/ - -// -// Map friendly register names to the encoded register index. 
-// - - .equ .LZmmIndex_zmm0, 0 - .equ .LZmmIndex_zmm1, 1 - .equ .LZmmIndex_zmm2, 2 - .equ .LZmmIndex_zmm3, 3 - .equ .LZmmIndex_zmm4, 4 - .equ .LZmmIndex_zmm5, 5 - .equ .LZmmIndex_zmm6, 6 - .equ .LZmmIndex_zmm7, 7 - .equ .LZmmIndex_zmm8, 8 - .equ .LZmmIndex_zmm9, 9 - .equ .LZmmIndex_zmm10, 10 - .equ .LZmmIndex_zmm11, 11 - .equ .LZmmIndex_zmm12, 12 - .equ .LZmmIndex_zmm13, 13 - .equ .LZmmIndex_zmm14, 14 - .equ .LZmmIndex_zmm15, 15 - .equ .LZmmIndex_zmm16, 16 - .equ .LZmmIndex_zmm17, 17 - .equ .LZmmIndex_zmm18, 18 - .equ .LZmmIndex_zmm19, 19 - .equ .LZmmIndex_zmm20, 20 - .equ .LZmmIndex_zmm21, 21 - .equ .LZmmIndex_zmm22, 22 - .equ .LZmmIndex_zmm23, 23 - .equ .LZmmIndex_zmm24, 24 - .equ .LZmmIndex_zmm25, 25 - .equ .LZmmIndex_zmm26, 26 - .equ .LZmmIndex_zmm27, 27 - .equ .LZmmIndex_zmm28, 28 - .equ .LZmmIndex_zmm29, 29 - .equ .LZmmIndex_zmm30, 30 - .equ .LZmmIndex_zmm31, 31 - - .equ .LGprIndex_rax, 0 - .equ .LGprIndex_rcx, 1 - .equ .LGprIndex_rdx, 2 - .equ .LGprIndex_rbx, 3 - .equ .LGprIndex_rbp, 5 - .equ .LGprIndex_rsi, 6 - .equ .LGprIndex_rdi, 7 - .equ .LGprIndex_r8, 8 - .equ .LGprIndex_r9, 9 - .equ .LGprIndex_r10, 10 - .equ .LGprIndex_r11, 11 - .equ .LGprIndex_r12, 12 - .equ .LGprIndex_r13, 13 - .equ .LGprIndex_r14, 14 - .equ .LGprIndex_r15, 15 - -/*++ - -Macro Description: - - This macro builds a VNNI instruction of the form: - - instr zmm1,zmm2,zmm3 - -Arguments: - - Opcode - Specifies the opcode for the VNNI instruction. - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - ---*/ - - .macro VnniZmmZmmZmm Opcode, DestReg, Src1Reg, Src2Reg - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + ((((.LZmmIndex_\Src2Reg\() >> 4) & 1) ^ 1) << 6) - .set Payload0, Payload0 + ((((.LZmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) - .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 4) & 1) ^ 1) << 4) - - .set Payload1, 0x05 # "66" prefix - .set Payload1, Payload1 + (((.LZmmIndex_\Src1Reg\() & 15) ^ 15) << 3) - - .set Payload2, 0x40 # 512-bit vector length - .set Payload2, Payload2 + ((((.LZmmIndex_\Src1Reg\() >> 4) & 1) ^ 1) << 3) - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LZmmIndex_\DestReg\() & 7) << 3) - .set ModRMByte, ModRMByte + (.LZmmIndex_\Src2Reg\() & 7) - - .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte - - .endm - - .macro VpdpbusdZmmZmmZmm DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 0x50, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbusdsZmmZmmZmm DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 0x51, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpwssdZmmZmmZmm DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 0x52, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpwssdsZmmZmmZmm DestReg, Src1Reg, Src2Reg - - VnniZmmZmmZmm 0x53, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro builds a VNNI instruction of the form: - - instr zmm1,zmm2,DWORD PTR [BaseReg+IndexReg*Scale+ByteOffset]{1to16} - -Arguments: - - Opcode - Specifies the opcode for the VNNI instruction. - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - BaseReg - Specifies the base register of the broadcast operand. - - ByteOffset - Specifies the DWORD aligned byte offset for the broadcast - operand. 
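AssembleAvx512Vnni.h hand-assembles EVEX-encoded VNNI instructions because older assemblers do not know these mnemonics: the VnniZmmZmmZmm macro above composes the 0x62 prefix, three payload bytes, the opcode, and a register-form ModRM byte from the raw zmm register indices. The C++ function below mirrors that byte construction purely for illustration; the real emission happens in the assembler macros, and EncodeVnniZmmZmmZmm is a hypothetical name.

#include <array>
#include <cstdint>

// Mirror of VnniZmmZmmZmm: build the 6-byte EVEX encoding of
// "opcode zmmDest, zmmSrc1, zmmSrc2" (e.g. opcode 0x50 = vpdpbusd).
// Register arguments are raw zmm indices (0..31).
std::array<uint8_t, 6> EncodeVnniZmmZmmZmm(uint8_t Opcode, unsigned Dest,
                                           unsigned Src1, unsigned Src2)
{
    uint8_t p0 = 0x02;                                  // "0F 38" opcode map
    p0 |= (((Dest >> 3) & 1) ^ 1) << 7;                 // EVEX.R (inverted)
    p0 |= (((Src2 >> 4) & 1) ^ 1) << 6;                 // EVEX.X (inverted)
    p0 |= (((Src2 >> 3) & 1) ^ 1) << 5;                 // EVEX.B (inverted)
    p0 |= (((Dest >> 4) & 1) ^ 1) << 4;                 // EVEX.R' (inverted)

    uint8_t p1 = 0x05;                                  // "66" prefix, fixed bit
    p1 |= ((Src1 & 15) ^ 15) << 3;                      // EVEX.vvvv (inverted)

    uint8_t p2 = 0x40;                                  // 512-bit vector length
    p2 |= (((Src1 >> 4) & 1) ^ 1) << 3;                 // EVEX.V' (inverted)

    uint8_t modrm = 0xC0;                               // register form
    modrm |= (Dest & 7) << 3;
    modrm |= Src2 & 7;

    return {{0x62, p0, p1, p2, Opcode, modrm}};
}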
- - IndexReg - Specifies the optional index register of the broadcast operand. - - Scale - Specifies the scaling factor of the optional index register. - ---*/ - - .macro VnniZmmZmmBroadcast Opcode, DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) -.ifnes "\IndexReg\()", "" - .set Payload0, Payload0 + ((((.LGprIndex_\IndexReg\() >> 3) & 1) ^ 1) << 6) -.else - .set Payload0, Payload0 + 0x40 # zero logical index register -.endif - .set Payload0, Payload0 + ((((.LGprIndex_\BaseReg\() >> 3) & 1) ^ 1) << 5) - .set Payload0, Payload0 + ((((.LZmmIndex_\DestReg\() >> 4) & 1) ^ 1) << 4) - - .set Payload1, 0x05 # "66" prefix - .set Payload1, Payload1 + (((.LZmmIndex_\Src1Reg\() & 15) ^ 15) << 3) - - .set Payload2, 0x50 # 512-bit vector length, broadcast - .set Payload2, Payload2 + ((((.LZmmIndex_\Src1Reg\() >> 4) & 1) ^ 1) << 3) - - .set ModRMByte, 0x00 # memory form - .set ModRMByte, ModRMByte + ((.LZmmIndex_\DestReg\() & 7) << 3) -.ifnes "\IndexReg\()", "" - .set ModRMByte, ModRMByte + 0x04 # indicate SIB byte needed -.else - .set ModRMByte, ModRMByte + (.LGprIndex_\BaseReg\() & 7) -.endif -.if \ByteOffset\() != 0 - .set ModRMByte, ModRMByte + 0x40 # indicate disp8 byte offset -.endif - -.ifnes "\IndexReg\()", "" - .set SibByte, 0 -.ifeqs "\Scale\()", "2" - .set SibByte, SibByte + (1 << 6) -.else -.ifeqs "\Scale\()", "4" - .set SibByte, SibByte + (2 << 6) -.else -.ifeqs "\Scale\()", "8" - .set SibByte, SibByte + (3 << 6) -.else -.ifnes "\Scale\()", "1" - .err -.endif -.endif -.endif -.endif - .set SibByte, SibByte + ((.LGprIndex_\IndexReg\() & 7) << 3) - .set SibByte, SibByte + (.LGprIndex_\BaseReg\() & 7) -.endif - - .byte 0x62, Payload0, Payload1, Payload2, \Opcode\(), ModRMByte -.ifnes "\IndexReg\()", "" - .byte SibByte -.endif -.if \ByteOffset\() != 0 - .byte (\ByteOffset\() >> 2) -.endif - - .endm - - .macro VpdpbusdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 0x50, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() - - .endm - - .macro VpdpbusdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 0x51, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() - - .endm - - .macro VpdpwssdZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 0x52, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() - - .endm - - .macro VpdpwssdsZmmZmmBroadcast DestReg, Src1Reg, BaseReg, ByteOffset, IndexReg, Scale - - VnniZmmZmmBroadcast 0x53, \DestReg\(), \Src1Reg\(), \BaseReg\(), \ByteOffset\(), \IndexReg\(), \Scale\() - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/AssembleAvxVnni.h b/onnxruntime/core/mlas/lib/x86_64/AssembleAvxVnni.h deleted file mode 100644 index 676102391d935..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/AssembleAvxVnni.h +++ /dev/null @@ -1,328 +0,0 @@ -/*++ - -Copyright (c) 2020 Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - AssembleAvxVnni.h - -Abstract: - - This module contains macros to build VNNI instructions for toolchains that - do not natively support this newer instruction set extension. - ---*/ - -// -// Map friendly register names to the encoded register index. 
-// - - .equ .LYmmIndex_ymm0, 0 - .equ .LYmmIndex_ymm1, 1 - .equ .LYmmIndex_ymm2, 2 - .equ .LYmmIndex_ymm3, 3 - .equ .LYmmIndex_ymm4, 4 - .equ .LYmmIndex_ymm5, 5 - .equ .LYmmIndex_ymm6, 6 - .equ .LYmmIndex_ymm7, 7 - .equ .LYmmIndex_ymm8, 8 - .equ .LYmmIndex_ymm9, 9 - .equ .LYmmIndex_ymm10, 10 - .equ .LYmmIndex_ymm11, 11 - .equ .LYmmIndex_ymm12, 12 - .equ .LYmmIndex_ymm13, 13 - .equ .LYmmIndex_ymm14, 14 - .equ .LYmmIndex_ymm15, 15 - - .equ .LXmmIndex_xmm0, 0 - .equ .LXmmIndex_xmm1, 1 - .equ .LXmmIndex_xmm2, 2 - .equ .LXmmIndex_xmm3, 3 - .equ .LXmmIndex_xmm4, 4 - .equ .LXmmIndex_xmm5, 5 - .equ .LXmmIndex_xmm6, 6 - .equ .LXmmIndex_xmm7, 7 - .equ .LXmmIndex_xmm8, 8 - .equ .LXmmIndex_xmm9, 9 - .equ .LXmmIndex_xmm10, 10 - .equ .LXmmIndex_xmm11, 11 - .equ .LXmmIndex_xmm12, 12 - .equ .LXmmIndex_xmm13, 13 - .equ .LXmmIndex_xmm14, 14 - .equ .LXmmIndex_xmm15, 15 - -/*++ - -Macro Description: - - This macro builds a VNNI instruction of the form: - - instr ymm1,ymm2,ymm3 - -Arguments: - - Opcode - Specifies the opcode for the VNNI instruction. - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - ---*/ - - .macro VnniYmmYmmYmm Opcode, DestReg, Src1Reg, Src2Reg - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LYmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + (1 << 6) - .set Payload0, Payload0 + ((((.LYmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) - - .set Payload1, 0x05 # "66" prefix - .set Payload1, Payload1 + (((.LYmmIndex_\Src1Reg\() & 15) ^ 15) << 3) - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LYmmIndex_\DestReg\() & 7) << 3) - .set ModRMByte, ModRMByte + (.LYmmIndex_\Src2Reg\() & 7) - - .byte 0xC4, Payload0, Payload1, \Opcode\(), ModRMByte - - .endm - - .macro VpdpbusdYmmYmmYmm DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 0x50, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbusdsYmmYmmYmm DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 0x51, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpwssdYmmYmmYmm DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 0x52, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpwssdsYmmYmmYmm DestReg, Src1Reg, Src2Reg - - VnniYmmYmmYmm 0x53, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro builds a VNNI instruction of the form: - - instr xmm1,xmm2,xmm3 - -Arguments: - - Opcode - Specifies the opcode for the VNNI instruction. - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. 
- ---*/ - - .macro VnniXmmXmmXmm Opcode, DestReg, Src1Reg, Src2Reg - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LXmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + (1 << 6) - .set Payload0, Payload0 + ((((.LXmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) - - .set Payload1, 0x05 # "66" prefix - .set Payload1, Payload1 + (((.LXmmIndex_\Src1Reg\() & 15) ^ 15) << 3) - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LXmmIndex_\DestReg\() & 7) << 3) - .set ModRMByte, ModRMByte + (.LXmmIndex_\Src2Reg\() & 7) - - .byte 0xC4, Payload0, Payload1, \Opcode\(), ModRMByte - - .endm - - .macro VpdpbusdXmmXmmXmm DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 0x50, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbusdsXmmXmmXmm DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 0x51, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpwssdXmmXmmXmm DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 0x52, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpwssdsXmmXmmXmm DestReg, Src1Reg, Src2Reg - - VnniXmmXmmXmm 0x53, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro builds a VNNI instruction of the form: - - instr ymm1,ymm2,ymm3 - -Arguments: - - Opcode - Specifies the opcode for the VNNI instruction. - - Prefix - Specifies the opcode prefix for payload 1 - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. - ---*/ - .macro Avx2VnniYmmYmmYmm Opcode, Prefix, DestReg, Src1Reg, Src2Reg - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LYmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + (1 << 6) - .set Payload0, Payload0 + ((((.LYmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) - - .set Payload1, 0x04 + \Prefix\() # 256-bit length and opcode prefix - .set Payload1, Payload1 + (((.LYmmIndex_\Src1Reg\() & 15) ^ 15) << 3) - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LYmmIndex_\DestReg\() & 7) << 3) - .set ModRMByte, ModRMByte + (.LYmmIndex_\Src2Reg\() & 7) - - .byte 0xC4, Payload0, Payload1, \Opcode\(), ModRMByte - - .endm - - .macro VpdpbssdYmmYmmYmm DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 0x50, 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbssdsYmmYmmYmm DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 0x51, 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbsudYmmYmmYmm DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 0x50, 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbsudsYmmYmmYmm DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 0x51, 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbuudYmmYmmYmm DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 0x50, 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbuudsYmmYmmYmm DestReg, Src1Reg, Src2Reg - - Avx2VnniYmmYmmYmm 0x51, 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro builds a VNNI instruction of the form: - - instr xmm1,xmm2,xmm3 - -Arguments: - - Opcode - Specifies the opcode for the VNNI instruction. - - Prefix - Specifies the opcode prefix for payload 1 - - DestReg - Specifies the destination register. - - Src1Reg - Specifies the first source register. - - Src2Reg - Specifies the second source register. 
- ---*/ - .macro Avx2VnniXmmXmmXmm Opcode, Prefix, DestReg, Src1Reg, Src2Reg - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LYmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + (1 << 6) - .set Payload0, Payload0 + ((((.LYmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) - - .set Payload1, 0x00 + \Prefix\() # 128-bit length and opcode prefix - .set Payload1, Payload1 + (((.LYmmIndex_\Src1Reg\() & 15) ^ 15) << 3) - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LYmmIndex_\DestReg\() & 7) << 3) - .set ModRMByte, ModRMByte + (.LYmmIndex_\Src2Reg\() & 7) - - .byte 0xC4, Payload0, Payload1, \Opcode\(), ModRMByte - - .endm - - .macro VpdpbssdXmmXmmXmm DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 0x50, 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbssdsXmmXmmXmm DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 0x51, 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbsudXmmXmmXmm DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 0x50, 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbsudsXmmXmmXmm DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 0x51, 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbuudXmmXmmXmm DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 0x50, 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - .macro VpdpbuudsXmmXmmXmm DestReg, Src1Reg, Src2Reg - - Avx2VnniXmmXmmXmm 0x51, 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S deleted file mode 100644 index 3004599bcb3d4..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S +++ /dev/null @@ -1,943 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymKernelAvx2.asm - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - - This implementation uses AVX2 and AVX VNNI instructions. - ---*/ - -#include "asmmacro.h" -#include "ConvSymKernelCommon.h" -#include "AssembleAvxVnni.h" - - .intel_syntax noprefix - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate a single row of the - output block. - -Arguments: - - Vec1Reg - Supplies the low block accumulator register. - - Vec2Reg - Supplies the high block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from the filter buffer. - - ymm1 - Supplies the second vector loaded from the filter buffer. - - ymm2 - Supplies the broadcast value loaded from the input buffer. - - ymm3 - Supplies a scratch register for intermediate results. - - ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. - ---*/ - - .macro MultiplyAccumulateRowAvx2 Vec1Reg, Vec2Reg - - vpmaddubsw ymm3,ymm2,ymm0 - vpmaddwd ymm3,ymm3,ymm12 - vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 - vpmaddubsw ymm2,ymm2,ymm1 - vpmaddwd ymm2,ymm2,ymm12 - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 - - .endm - - .macro MultiplyAccumulateRowAvxVnni Vec1Reg, Vec2Reg - - VpdpbusdsYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbusdsYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - RowCount - Supplies the number of rows to produce. 
- - VectorOffset - Supplies the byte offset from the filter to fetch elements. - - BroadcastOffset - Supplies the byte offset from the input to fetch elements. - -Implicit Arguments: - - rdx - Supplies the address of the filter buffer. - - r10 - Supplies the address of the base of the input buffer. - -Implicit Arguments (Avx2): - - r11-r13 - Supplies the relative byte offsets from the base of the input - buffer to access the second through fourth rows. - - ymm4-ymm11 - Supplies the block accumulators. - - ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. - -Implicit Arguments (AvxVnni): - - r11-r15 - Supplies the relative byte offsets from the base of the input - buffer to access the second through sixth rows. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock Isa, RowCount, VectorOffset, BroadcastOffset - - vmovdqu ymm0,YMMWORD PTR [rdx+\VectorOffset\()] - vmovdqu ymm1,YMMWORD PTR [rdx+\VectorOffset\()+32] - EmitIfCountGE \RowCount\(),1,"vpbroadcastd ymm2,DWORD PTR [r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(),1,"MultiplyAccumulateRow\Isa\() ymm4,ymm5" - EmitIfCountGE \RowCount\(),2,"vpbroadcastd ymm2,DWORD PTR [r10+r11+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(),2,"MultiplyAccumulateRow\Isa\() ymm6,ymm7" - EmitIfCountGE \RowCount\(),3,"vpbroadcastd ymm2,DWORD PTR [r10+r12+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(),3,"MultiplyAccumulateRow\Isa\() ymm8,ymm9" - EmitIfCountGE \RowCount\(),4,"vpbroadcastd ymm2,DWORD PTR [r10+r13+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(),4,"MultiplyAccumulateRow\Isa\() ymm10,ymm11" - EmitIfCountGE \RowCount\(),5,"vpbroadcastd ymm2,DWORD PTR [r10+r14+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(),5,"MultiplyAccumulateRow\Isa\() ymm12,ymm13" - EmitIfCountGE \RowCount\(),6,"vpbroadcastd ymm2,DWORD PTR [r10+r15+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(),6,"MultiplyAccumulateRow\Isa\() ymm14,ymm15" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple times - and advancing the input and filter data pointers. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - RowCount - Supplies the number of rows to produce. - - UnrollLoop - Supplies a non-blank value if the loop should be unrolled to - improve performance. - -Implicit Arguments: - - rax - Supplies the number of input channels. - - rdx - Supplies the address of the filter buffer. - - r10 - Supplies the address of the base of the input buffer. 
- ---*/ - - .macro ComputeBlockLoop Isa, RowCount, UnrollLoop - -.ifeqs "\UnrollLoop\()","UnrollLoop" - sub rax,4*4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - ComputeBlock \Isa\(),\RowCount\(),0*64,0 - ComputeBlock \Isa\(),\RowCount\(),1*64,4 - ComputeBlock \Isa\(),\RowCount\(),2*64,8 - ComputeBlock \Isa\(),\RowCount\(),3*64,12 - add r10,4*4 # advance input base address - add rdx,4*16*4 # advance filter address - sub rax,4*4 # decrement elements remaining - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add rax,4*4 # correct for over-subtract above - jz .LComputeBlockLoopExit\@ -.endif - -.LComputeBlockBy1Loop\@: - ComputeBlock \Isa\(),\RowCount\(),0*64,0 - add r10,4 # advance input base address - add rdx,16*4 # advance filter address - sub rax,4 # decrement elements remaining - jnz .LComputeBlockBy1Loop\@ - -.LComputeBlockLoopExit\@: - - .endm - -/*++ - -Macro Description: - - This macro generates code to convert the block accumulators from the matrix - multiply loop to float values. - -Arguments: - - RegList - Supplies the list of vector registers to operate on. - -Implicit Arguments: - - ymm0 - Supplies the integer bias vector. - - ymm1 - Supplies the output scale vector. - ---*/ - - .macro ConvertAccumulatorToFloatRegList RegList - -// -// Offset each value by the per-channel bias value, convert to floating point, -// and apply the output scale. -// - - EmitForEachRegister "\RegList\()","vpaddd \RegItem\(),\RegItem\(),ymm0" - EmitForEachRegister "\RegList\()","vcvtdq2ps \RegItem\(),\RegItem\()" - EmitForEachRegister "\RegList\()","vmulps \RegItem\(),\RegItem\(),ymm1" - - .endm - -/*++ - -Macro Description: - - This macro generates code to convert float values to 32-bit integers in the - range 0 to 255. - -Arguments: - - RegList - Supplies the list of vector registers to operate on. - -Implicit Arguments: - - ymm0 - Supplies the broadcasted minimum clip float value. - - This is set to static_cast(0 - ZeroPointValue). - - ymm1 - Supplies the broadcasted maximum clip float value. - - This is set to static_cast(255 - ZeroPointValue). - - ymm2 - Supplies the broadcasted zero point integer value. - ---*/ - - .macro ConvertFloatToIntegerRegList RegList - -// -// Clip the float values to the integer range covered by the output zero point. -// This also keeps values outside the range INT_MIN to INT_MAX from converting -// to INT_MIN. -// - - EmitForEachRegister "\RegList\()","vmaxps \RegItem\(),\RegItem\(),ymm0" - EmitForEachRegister "\RegList\()","vminps \RegItem\(),\RegItem\(),ymm1" - -// -// Convert the float value to integer and add the zero point offset. -// - - EmitForEachRegister "\RegList\()","vcvtps2dq \RegItem\(),\RegItem\()" - EmitForEachRegister "\RegList\()","vpaddd \RegItem\(),\RegItem\(),ymm2" - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner kernel to compute a convolution - for the elements of an output row for a set of filter rows. - -Arguments: - - Isa - Supplies the instruction set architecture string. - ---*/ - - .macro ConvSymKernelFunction Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. 
Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - KernelSize (rcx) - Supplies the size of the kernel. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (r8) - Supplies the number of input channels. - - This implementation requires the count to be a multiple of 4. - - OutputChannels (r9) - Supplies the number of output channels. - - ChannelCount - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 8 or 16. - - OutputCount - Supplies the number of output elements this iteration produces. - -.ifeqs "\Isa\()","AvxVnni" - This implementation requires the count to be in the range 1 to 6. -.else - This implementation requires the count to be in the range 1 to 4. -.endif - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymKernel\Isa\() - - push rbp - push rbx - push r12 - push r13 - sub rsp,.LConvSymKernelFrame_SavedR13 -.ifeqs "\Isa\()","AvxVnni" - mov .LConvSymKernelFrame_SavedR14[rsp],r14 - mov .LConvSymKernelFrame_SavedR15[rsp],r15 -.endif - - mov .LConvSymKernelFrame_InputChannels[rsp],r8 - mov .LConvSymKernelFrame_OutputChannels[rsp],r9 - mov r8,rdx # shuffle registers to Windows ABI - mov r9,rcx - mov rcx,rdi - mov rdx,rsi - - lea rdi,[r9*8] - mov ebx,DWORD PTR .LConvSymKernelFrame_OutputCount[rsp] - mov rsi,.LConvSymKernelFrame_InputChannels[rsp] - mov ebp,DWORD PTR .LConvSymKernelFrame_KernelFlags[rsp] - vpxor xmm4,xmm4,xmm4 - vpxor xmm5,xmm5,xmm5 - vpxor xmm6,xmm6,xmm6 - vpxor xmm7,xmm7,xmm7 - vpxor xmm8,xmm8,xmm8 - vpxor xmm9,xmm9,xmm9 - vpxor xmm10,xmm10,xmm10 - vpxor xmm11,xmm11,xmm11 -.ifeqs "\Isa\()","AvxVnni" - vpxor xmm12,xmm12,xmm12 - vpxor xmm13,xmm13,xmm13 - vpxor xmm14,xmm14,xmm14 - vpxor xmm15,xmm15,xmm15 -.else - vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] - vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] -.endif - -// -// Process an input block of length InputChannels for each element of the kernel. -// - -.LProcessNextInputBlock\@: - test bpl,MLAS_CONV_SYM_FLAG_INPUT_DIRECT - jz .LInputIndirection\@ - -// -// The input buffer points directly at the input data and this is effectively a -// GEMM operation (such as a pointwise convolution or an Im2Col transform). 
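As the routine header above notes, when MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear the Input argument is a table of row pointers rather than tensor data: one pointer per (output, kernel) position, each pointing at an InputChannels-long vector or at a shared padding vector. A minimal C++ sketch of that indexing (illustrative only; the helper name is made up):

    #include <cstddef>
    #include <cstdint>

    // Layout described in the routine header: pointers are grouped in batches of
    // KernelSize, and the batches repeat OutputCount times.
    const uint8_t* IndirectionEntry(const uint8_t* const* indirection,
                                    size_t kernel_size,
                                    size_t output_index,
                                    size_t kernel_index) {
        return indirection[output_index * kernel_size + kernel_index];
    }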
-// - -.LInputDirect\@: - xor r10,r10 - mov r11,rsi - lea r12,[r11+r11] - lea r13,[r12+r11] -.ifeqs "\Isa\()","AvxVnni" - lea r14,[r13+r11] - lea r15,[r14+r11] -.endif - cmp ebx,2 - cmovb r11,r10 # use first row if output count is small - cmovbe r12,r10 - cmp ebx,4 - cmovb r13,r10 -.ifeqs "\Isa\()","AvxVnni" - cmovbe r14,r10 - cmp ebx,6 - cmovb r15,r10 -.endif - mov r10,rcx - jmp .LComputeBlockLoopStart\@ - -.LInputIndirection\@: - lea r11,[rcx+rdi] - lea r12,[rcx+rdi*2] - lea r13,[r11+rdi*2] -.ifeqs "\Isa\()","AvxVnni" - lea r14,[r12+rdi*2] - lea r15,[r13+rdi*2] -.endif - cmp ebx,2 - cmovb r11,rcx # use first row if output count is small - cmovbe r12,rcx - cmp ebx,4 - cmovb r13,rcx -.ifeqs "\Isa\()","AvxVnni" - cmovbe r14,rcx - cmp ebx,6 - cmovb r15,rcx -.endif - mov r10,QWORD PTR [rcx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] -.ifeqs "\Isa\()","AvxVnni" - mov r14,QWORD PTR [r14] - mov r15,QWORD PTR [r15] -.endif - add rcx,8 # advance indirection buffer address - sub r11,r10 # compute deltas from base address - sub r12,r10 - sub r13,r10 -.ifeqs "\Isa\()","AvxVnni" - sub r14,r10 - sub r15,r10 -.endif - -.LComputeBlockLoopStart\@: - mov rax,rsi # reload input channels - cmp ebx,2 # output count <= 2? - jbe .LComputeBlockLoopBy2\@ -.ifeqs "\Isa\()","AvxVnni" - cmp ebx,4 # output count <= 4? - jbe .LComputeBlockLoopBy4\@ - ComputeBlockLoop \Isa\(),6,UnrollLoop -.else - ComputeBlockLoop \Isa\(),4,UnrollLoop -.endif - -.LComputeBlockLoopDone\@: - dec r9 # decrement input blocks remaining - jnz .LProcessNextInputBlock\@ - -// -// Apply the bias and convert the block accumulators to intermediate float values. -// - - mov rdx,.LConvSymKernelFrame_PostProcessParams[rsp] - mov rsi,.LConvSymKernelFrame_OutputChannels[rsp] - mov r11d,DWORD PTR .LConvSymKernelFrame_ChannelCount[rsp] - mov rcx,.LConvSymPostProcessParams_Bias[rdx] - mov r9,.LConvSymPostProcessParams_Scale[rdx] - lea r10,[rsi*2+rsi] # compute fourth row output offset - add r10,r8 - vmovdqu ymm0,YMMWORD PTR [rcx] # load low bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz .LBroadcastScaleValue\@ - vmovups ymm1,YMMWORD PTR [r9] # load low scale vector - jmp .LConvertLowAccumulatorsToFloat\@ - -.LBroadcastScaleValue\@: - vbroadcastss ymm1,DWORD PTR [r9] - -.LConvertLowAccumulatorsToFloat\@: -.ifeqs "\Isa\()","AvxVnni" - ConvertAccumulatorToFloatRegList "ymm4,ymm6,ymm8,ymm10,ymm12,ymm14" -.else - ConvertAccumulatorToFloatRegList "ymm4,ymm6,ymm8,ymm10" -.endif - cmp r11d,8 # output single vector? - jbe .LConvertFloatsToIntegers\@ - vmovdqu ymm0,YMMWORD PTR [rcx+8*4] # load high bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz .LConvertHighAccumulatorsToFloat\@ - vmovups ymm1,YMMWORD PTR [r9+8*4] # load high scale vector - -.LConvertHighAccumulatorsToFloat\@: -.ifeqs "\Isa\()","AvxVnni" - ConvertAccumulatorToFloatRegList "ymm5,ymm7,ymm9,ymm11,ymm13,ymm15" -.else - ConvertAccumulatorToFloatRegList "ymm5,ymm7,ymm9,ymm11" -.endif - -// -// Convert the intermediate float values to 32-bit integers in the range 0 to 255. -// - -.LConvertFloatsToIntegers\@: - vbroadcastss ymm0,DWORD PTR .LConvSymPostProcessParams_MinimumValue[rdx] - vbroadcastss ymm1,DWORD PTR .LConvSymPostProcessParams_MaximumValue[rdx] - vpbroadcastd ymm2,DWORD PTR .LConvSymPostProcessParams_OutputZeroPoint[rdx] -.ifeqs "\Isa\()","AvxVnni" - ConvertFloatToIntegerRegList "ymm4,ymm6,ymm8,ymm10,ymm12,ymm14" -.else - ConvertFloatToIntegerRegList "ymm4,ymm6,ymm8,ymm10" -.endif - cmp r11d,8 # output single vector? 
- jbe .LStoreQuantizedOutputBy8\@ -.ifeqs "\Isa\()","AvxVnni" - ConvertFloatToIntegerRegList "ymm5,ymm7,ymm9,ymm11,ymm13,ymm15" -.else - ConvertFloatToIntegerRegList "ymm5,ymm7,ymm9,ymm11" -.endif - -// -// Pack with saturation and store 16 bytes to the output buffer. -// - -.LStoreQuantizedOutputBy16\@: -.ifeqs "\Isa\()","AvxVnni" - cmp ebx,5 - ja .LStoreQuantizedOutput6By16\@ - je .LStoreQuantizedOutput5By16\@ -.endif - cmp ebx,3 - ja .LStoreQuantizedOutput4By16\@ - je .LStoreQuantizedOutput3By16\@ - cmp ebx,1 - ja .LStoreQuantizedOutput2By16\@ - jmp .LStoreQuantizedOutput1By16\@ - -.ifeqs "\Isa\()","AvxVnni" -.LStoreQuantizedOutput6By16\@: - vextracti128 xmm0,ymm14,1 - vpackusdw xmm14,xmm14,xmm0 - vextracti128 xmm1,ymm15,1 - vpackusdw xmm15,xmm15,xmm1 - vpackuswb xmm14,xmm14,xmm15 - vmovdqu XMMWORD PTR [r10+rsi*2],xmm14 - -.LStoreQuantizedOutput5By16\@: - vextracti128 xmm0,ymm12,1 - vpackusdw xmm12,xmm12,xmm0 - vextracti128 xmm1,ymm13,1 - vpackusdw xmm13,xmm13,xmm1 - vpackuswb xmm12,xmm12,xmm13 - vmovdqu XMMWORD PTR [r10+rsi],xmm12 -.endif - -.LStoreQuantizedOutput4By16\@: - vextracti128 xmm0,ymm10,1 - vpackusdw xmm10,xmm10,xmm0 - vextracti128 xmm1,ymm11,1 - vpackusdw xmm11,xmm11,xmm1 - vpackuswb xmm10,xmm10,xmm11 - vmovdqu XMMWORD PTR [r10],xmm10 - -.LStoreQuantizedOutput3By16\@: - vextracti128 xmm0,ymm8,1 - vpackusdw xmm8,xmm8,xmm0 - vextracti128 xmm1,ymm9,1 - vpackusdw xmm9,xmm9,xmm1 - vpackuswb xmm8,xmm8,xmm9 - vmovdqu XMMWORD PTR [r8+rsi*2],xmm8 - -.LStoreQuantizedOutput2By16\@: - vextracti128 xmm0,ymm6,1 - vpackusdw xmm6,xmm6,xmm0 - vextracti128 xmm1,ymm7,1 - vpackusdw xmm7,xmm7,xmm1 - vpackuswb xmm6,xmm6,xmm7 - vmovdqu XMMWORD PTR [r8+rsi],xmm6 - -.LStoreQuantizedOutput1By16\@: - vextracti128 xmm0,ymm4,1 - vpackusdw xmm4,xmm4,xmm0 - vextracti128 xmm1,ymm5,1 - vpackusdw xmm5,xmm5,xmm1 - vpackuswb xmm4,xmm4,xmm5 - vmovdqu XMMWORD PTR [r8],xmm4 - -// -// Restore non-volatile registers and return. -// - -.LExitKernel\@: - vzeroupper -.ifeqs "\Isa\()","AvxVnni" - mov r14,.LConvSymKernelFrame_SavedR14[rsp] - mov r15,.LConvSymKernelFrame_SavedR15[rsp] -.endif - add rsp,.LConvSymKernelFrame_SavedR13 - pop r13 - pop r12 - pop rbx - pop rbp - ret - -// -// Pack with saturation and store 8 bytes to the output buffer. 
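The StoreQuantizedOutput*By16 paths above narrow two 8-lane int32 vectors per row to 16 bytes with vpackusdw/vpackuswb, and the *By8 paths below do the same for a single vector. Because every lane was already clipped to the representable output range, the net effect per channel is a clamp-and-narrow, as in this scalar C++ sketch (illustrative only):

    #include <algorithm>
    #include <cstdint>

    // Scalar equivalent of one 16-channel quantized store: each int32 lane was
    // already clipped upstream, so the saturating packs reduce to clamping into
    // [0, 255] and narrowing to a byte.
    void StoreQuantizedRow16(const int32_t acc[16], uint8_t out[16]) {
        for (int c = 0; c < 16; ++c) {
            out[c] = static_cast<uint8_t>(std::min(std::max(acc[c], 0), 255));
        }
    }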
-// - -.LStoreQuantizedOutputBy8\@: -.ifeqs "\Isa\()","AvxVnni" - cmp ebx,5 - ja .LStoreQuantizedOutput6By8\@ - je .LStoreQuantizedOutput5By8\@ -.endif - cmp ebx,3 - ja .LStoreQuantizedOutput4By8\@ - je .LStoreQuantizedOutput3By8\@ - cmp ebx,1 - ja .LStoreQuantizedOutput2By8\@ - jmp .LStoreQuantizedOutput1By8\@ - -.ifeqs "\Isa\()","AvxVnni" -.LStoreQuantizedOutput6By8\@: - vextracti128 xmm0,ymm14,1 - vpackusdw xmm14,xmm14,xmm0 - vpackuswb xmm14,xmm14,xmm14 - vmovq QWORD PTR [r10+rsi*2],xmm14 - -.LStoreQuantizedOutput5By8\@: - vextracti128 xmm0,ymm12,1 - vpackusdw xmm12,xmm12,xmm0 - vpackuswb xmm12,xmm12,xmm12 - vmovq QWORD PTR [r10+rsi],xmm12 -.endif - -.LStoreQuantizedOutput4By8\@: - vextracti128 xmm0,ymm10,1 - vpackusdw xmm10,xmm10,xmm0 - vpackuswb xmm10,xmm10,xmm10 - vmovq QWORD PTR [r10],xmm10 - -.LStoreQuantizedOutput3By8\@: - vextracti128 xmm0,ymm8,1 - vpackusdw xmm8,xmm8,xmm0 - vpackuswb xmm8,xmm8,xmm8 - vmovq QWORD PTR [r8+rsi*2],xmm8 - -.LStoreQuantizedOutput2By8\@: - vextracti128 xmm0,ymm6,1 - vpackusdw xmm6,xmm6,xmm0 - vpackuswb xmm6,xmm6,xmm6 - vmovq QWORD PTR [r8+rsi],xmm6 - -.LStoreQuantizedOutput1By8\@: - vextracti128 xmm0,ymm4,1 - vpackusdw xmm4,xmm4,xmm0 - vpackuswb xmm4,xmm4,xmm4 - vmovq QWORD PTR [r8],xmm4 - jmp .LExitKernel\@ - -// -// Process the tail output counts out of line with a reduced block size. -// - -.ifeqs "\Isa\()","AvxVnni" -.LComputeBlockLoopBy4\@: - ComputeBlockLoop \Isa\(),4 - jmp .LComputeBlockLoopDone\@ -.endif - -.LComputeBlockLoopBy2\@: - ComputeBlockLoop \Isa\(),2 - jmp .LComputeBlockLoopDone\@ - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate a single cell of the - output block. - -Arguments: - - AccumReg - Supplies the register to accumulate into. - - Mult1Reg - Supplies the first multiplication operand register. This register - may be trashed on return. - - Mult2Reg - Supplies the second multiplication operand register. - ---*/ - - .macro DepthwiseMultiplyAccumulateCellAvx2 AccumReg, Mult1Reg, Mult2Reg - - vpmaddwd \Mult1Reg\(),\Mult1Reg\(),\Mult2Reg\() - vpaddd \AccumReg\(),\AccumReg\(),\Mult1Reg\() - - .endm - - .macro DepthwiseMultiplyAccumulateCellAvxVnni AccumReg, Mult1Reg, Mult2Reg - - VpdpbusdsYmmYmmYmm \AccumReg\(),\Mult1Reg\(),\Mult2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner kernel to compute a depthwise - convolution for the elements of an output row for a set of filter rows. - -Arguments: - - Isa - Supplies the instruction set architecture string. - ---*/ - - .macro ConvSymDepthwiseKernelFunction Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise convolution for the - elements of an output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the indirection buffer. - - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - KernelSize (rcx) - Supplies the size of the kernel. - - Channels (r8) - Supplies the number of input and output channels. - - ChannelOffset (r9) - Supplies the byte offset from the indirection buffer base - address for this iteration. - - ChannelCount - Supplies the number of channels this iteration produces. - - This implementation requires the count to be 16. - - OutputCount - Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 4. 
- - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymDepthwiseKernel\Isa\() - - push rbp - push rbx - push r12 - push r13 - sub rsp,.LConvSymDepthwiseKernelFrame_SavedR13 - - mov .LConvSymDepthwiseKernelFrame_Channels[rsp],r8 - mov .LConvSymDepthwiseKernelFrame_ChannelOffset[rsp],r9 - mov r8,rdx # shuffle registers to Windows ABI - mov r9,rcx - mov rcx,rdi - mov rdx,rsi - - lea rdi,[r9*8] - mov ebx,DWORD PTR .LConvSymDepthwiseKernelFrame_OutputCount[rsp] - mov rsi,.LConvSymDepthwiseKernelFrame_Channels[rsp] - mov rax,.LConvSymDepthwiseKernelFrame_ChannelOffset[rsp] - mov ebp,DWORD PTR .LConvSymDepthwiseKernelFrame_KernelFlags[rsp] - vpxor xmm4,xmm4,xmm4 - vpxor xmm5,xmm5,xmm5 - vpxor xmm6,xmm6,xmm6 - vpxor xmm7,xmm7,xmm7 - vpxor xmm8,xmm8,xmm8 - vpxor xmm9,xmm9,xmm9 - vpxor xmm10,xmm10,xmm10 - vpxor xmm11,xmm11,xmm11 - -// -// Process an input block of length Channels for each element of the kernel. -// - -.LProcessNextInputBlock\@: - vpmovsxbd ymm0,QWORD PTR [rdx] - vpmovsxbd ymm1,QWORD PTR [rdx+8] - lea r11,[rcx+rdi] - lea r12,[rcx+rdi*2] - lea r13,[r11+rdi*2] - cmp ebx,2 - cmovb r11,rcx # use first row if output count is small - cmovbe r12,rcx - cmp ebx,4 - cmovb r13,rcx - mov r10,QWORD PTR [rcx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] - add rcx,8 # advance indirection buffer address - vpmovzxbd ymm2,QWORD PTR [r10+rax] - vpmovzxbd ymm3,QWORD PTR [r10+rax+8] - DepthwiseMultiplyAccumulateCell\Isa\() ymm4,ymm2,ymm0 - vpmovzxbd ymm2,QWORD PTR [r11+rax] - DepthwiseMultiplyAccumulateCell\Isa\() ymm5,ymm3,ymm1 - vpmovzxbd ymm3,QWORD PTR [r11+rax+8] - DepthwiseMultiplyAccumulateCell\Isa\() ymm6,ymm2,ymm0 - vpmovzxbd ymm2,QWORD PTR [r12+rax] - DepthwiseMultiplyAccumulateCell\Isa\() ymm7,ymm3,ymm1 - vpmovzxbd ymm3,QWORD PTR [r12+rax+8] - DepthwiseMultiplyAccumulateCell\Isa\() ymm8,ymm2,ymm0 - vpmovzxbd ymm2,QWORD PTR [r13+rax] - DepthwiseMultiplyAccumulateCell\Isa\() ymm9,ymm3,ymm1 - vpmovzxbd ymm3,QWORD PTR [r13+rax+8] - DepthwiseMultiplyAccumulateCell\Isa\() ymm10,ymm2,ymm0 - add rdx,rsi # advance filter to next kernel - DepthwiseMultiplyAccumulateCell\Isa\() ymm11,ymm3,ymm1 - dec r9 # decrement input blocks remaining - jnz .LProcessNextInputBlock\@ - -// -// Apply the bias and convert the block accumulators to intermediate float values. -// - - mov rdx,.LConvSymDepthwiseKernelFrame_PostProcessParams[rsp] - mov rcx,.LConvSymPostProcessParams_Bias[rdx] - mov r9,.LConvSymPostProcessParams_Scale[rdx] - vmovdqu ymm0,YMMWORD PTR [rcx] # load low bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz .LBroadcastScaleValue\@ - vmovups ymm1,YMMWORD PTR [r9] # load low scale vector - jmp .LConvertLowAccumulatorsToFloat\@ - -.LBroadcastScaleValue\@: - vbroadcastss ymm1,DWORD PTR [r9] - -.LConvertLowAccumulatorsToFloat\@: - ConvertAccumulatorToFloatRegList "ymm4,ymm6,ymm8,ymm10" - vmovdqu ymm0,YMMWORD PTR [rcx+8*4] # load high bias vector - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz .LConvertHighAccumulatorsToFloat\@ - vmovups ymm1,YMMWORD PTR [r9+8*4] # load high scale vector - -.LConvertHighAccumulatorsToFloat\@: - ConvertAccumulatorToFloatRegList "ymm5,ymm7,ymm9,ymm11" - -// -// Convert the intermediate float values to 32-bit integers in the range 0 to 255. 
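Taken together, the bias/scale step above and the clip/convert step below implement the usual requantization of a symmetric int8 convolution. A scalar C++ sketch of one output channel (illustrative only; the clip bounds follow the MinimumValue/MaximumValue convention documented earlier):

    #include <cmath>
    #include <cstdint>

    // acc is the raw int32 accumulator; bias, scale, and zero_point come from the
    // post-process parameter block. The clip bounds are pre-computed as
    // float(0 - zero_point) and float(255 - zero_point).
    int32_t Requantize(int32_t acc, int32_t bias, float scale, int32_t zero_point) {
        float value = static_cast<float>(acc + bias) * scale;
        value = std::fmax(value, static_cast<float>(0 - zero_point));
        value = std::fmin(value, static_cast<float>(255 - zero_point));
        // vcvtps2dq rounds to nearest even; nearbyint matches that under the
        // default rounding mode. The zero point is added back afterwards.
        return static_cast<int32_t>(std::nearbyint(value)) + zero_point;
    }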
-// - -.LConvertFloatsToIntegers\@: - vbroadcastss ymm0,DWORD PTR .LConvSymPostProcessParams_MinimumValue[rdx] - vbroadcastss ymm1,DWORD PTR .LConvSymPostProcessParams_MaximumValue[rdx] - vpbroadcastd ymm2,DWORD PTR .LConvSymPostProcessParams_OutputZeroPoint[rdx] - ConvertFloatToIntegerRegList "ymm4,ymm6,ymm8,ymm10" - ConvertFloatToIntegerRegList "ymm5,ymm7,ymm9,ymm11" - -// -// Pack with saturation and store 16 bytes to the output buffer. -// - -.LStoreQuantizedOutputBy16\@: - lea r10,[rsi*2+rsi] - cmp ebx,3 - ja .LStoreQuantizedOutput4By16\@ - je .LStoreQuantizedOutput3By16\@ - cmp ebx,1 - ja .LStoreQuantizedOutput2By16\@ - jmp .LStoreQuantizedOutput1By16\@ - -.LStoreQuantizedOutput4By16\@: - vextracti128 xmm0,ymm10,1 - vpackusdw xmm10,xmm10,xmm0 - vextracti128 xmm1,ymm11,1 - vpackusdw xmm11,xmm11,xmm1 - vpackuswb xmm10,xmm10,xmm11 - vmovdqu XMMWORD PTR [r8+r10],xmm10 - -.LStoreQuantizedOutput3By16\@: - vextracti128 xmm0,ymm8,1 - vpackusdw xmm8,xmm8,xmm0 - vextracti128 xmm1,ymm9,1 - vpackusdw xmm9,xmm9,xmm1 - vpackuswb xmm8,xmm8,xmm9 - vmovdqu XMMWORD PTR [r8+rsi*2],xmm8 - -.LStoreQuantizedOutput2By16\@: - vextracti128 xmm0,ymm6,1 - vpackusdw xmm6,xmm6,xmm0 - vextracti128 xmm1,ymm7,1 - vpackusdw xmm7,xmm7,xmm1 - vpackuswb xmm6,xmm6,xmm7 - vmovdqu XMMWORD PTR [r8+rsi],xmm6 - -.LStoreQuantizedOutput1By16\@: - vextracti128 xmm0,ymm4,1 - vpackusdw xmm4,xmm4,xmm0 - vextracti128 xmm1,ymm5,1 - vpackusdw xmm5,xmm5,xmm1 - vpackuswb xmm4,xmm4,xmm5 - vmovdqu XMMWORD PTR [r8],xmm4 - -// -// Restore non-volatile registers and return. -// - -.LExitKernel\@: - vzeroupper - add rsp,.LConvSymDepthwiseKernelFrame_SavedR13 - pop r13 - pop r12 - pop rbx - pop rbp - ret - - .endm - -// -// Generate the convolution kernels. -// - -ConvSymKernelFunction Avx2 -ConvSymDepthwiseKernelFunction Avx2 - -ConvSymKernelFunction AvxVnni -ConvSymDepthwiseKernelFunction AvxVnni - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx512Core.S deleted file mode 100644 index 27f2a8aec7e66..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx512Core.S +++ /dev/null @@ -1,890 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymKernelAvx512Core.asm - -Abstract: - - This module implements the kernels for the symmetric quantized integer - convolution operation. - - This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions. - ---*/ - -#include "asmmacro.h" -#include "ConvSymKernelCommon.h" -#include "AssembleAvx512Vnni.h" - - .intel_syntax noprefix - -/*++ - -Macro Description: - - This macro generates code to setup registers that is common between - convolution kernel types. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - r9 - Supplies the size of the kernel. - -Output: - - rbx - Supplies the address of the input buffer. - - rdi - Supplies the input indirection buffer stride. - -.ifeqs , - zmm7 - Supplies a 512-bit with the broadcasted word value 0x0001. -.endif - - zmm8-zmm31 - Supplies the zeroed block accumulators. - - k1-k4 - Supplies the opmask registers loaded with a 64-bit channel bitmask - for KernelFrame.ChannelCount. 
- ---*/ - - .macro SetupRegistersCommon Isa, KernelFrame - - mov rbx,rcx # preserve base input address - lea rdi,[r9*8] # indirection buffer offset to next output -.ifeqs "\Isa\()","Avx512Core" - mov esi,1 - vpbroadcastw zmm7,esi # generate 512-bit word vector [0x0001] -.endif - EmitForEachRegister "zmm8,zmm9,zmm10,zmm11","vpxord \RegItem\(),\RegItem\(),\RegItem\()" - mov ecx,DWORD PTR \KernelFrame\()_ChannelCount[rsp] - EmitForEachRegister "zmm12,zmm13,zmm14,zmm15","vpxord \RegItem\(),\RegItem\(),\RegItem\()" - dec ecx # convert shift count to 0..63 - mov eax,2 - shl rax,cl # compute 2 << ChannelShiftCount - dec rax # convert to 64-bit channel bitmask - EmitForEachRegister "zmm16,zmm17,zmm18,zmm19","vpxord \RegItem\(),\RegItem\(),\RegItem\()" - kmovw k1,eax # k1 = channel bitmask[0..15] - shr rax,16 - EmitForEachRegister "zmm20,zmm21,zmm22,zmm23","vpxord \RegItem\(),\RegItem\(),\RegItem\()" - kmovw k2,eax # k2 = channel bitmask[16..31] - shr rax,16 - EmitForEachRegister "zmm24,zmm25,zmm26,zmm27","vpxord \RegItem\(),\RegItem\(),\RegItem\()" - kmovw k3,eax # k3 = channel bitmask[32..47] - shr eax,16 - EmitForEachRegister "zmm28,zmm29,zmm30,zmm31","vpxord \RegItem\(),\RegItem\(),\RegItem\()" - kmovw k4,eax # k4 = channel bitmask[48..63] - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate a single cell of the - output block. - -Arguments: - - AccumReg - Supplies the register to accumulate into. - - Mult1Reg - Supplies the first multiplication operand register. - - Mult2Reg - Supplies the second multiplication operand register. - -Implicit Arguments: - - zmm5 - Supplies a scratch register for intermediate results. - - zmm7 - Supplies a 512-bit with the broadcasted word value 0x0001. - ---*/ - - .macro MultiplyAccumulateCellAvx512Core AccumReg, Mult1Reg, Mult2Reg - - vpmaddubsw zmm5,\Mult1Reg\(),\Mult2Reg\() - vpmaddwd zmm5,zmm5,zmm7 - vpaddd \AccumReg\(),\AccumReg\(),zmm5 - - .endm - - .macro MultiplyAccumulateCellAvx512Vnni AccumReg, Mult1Reg, Mult2Reg - - VpdpbusdsZmmZmmZmm \AccumReg\(),\Mult1Reg\(),\Mult2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - ColumnCount - Supplies the number of columns to produce. - - VectorOffset - Supplies the byte offset from the filter to fetch elements. - - BroadcastOffset - Supplies the byte offset from the input to fetch elements. - -Implicit Arguments: - - rdx - Supplies the address of the filter buffer. - - rsi - Supplies the filter stride to access the packed data for the next 16 - output channels. - - rbp - Supplies three times the above filter stride. - - r10 - Supplies the address of the base of the input buffer. - - r11-r15 - Supplies the relative byte offsets from the base of the input - buffer to access the second through sixth rows. - - zmm8-zmm31 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlock Isa, ColumnCount, VectorOffset, BroadcastOffset - - EmitIfCountGE \ColumnCount\(),16,"vmovdqu32 zmm0,ZMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCountGE \ColumnCount\(),32,"vmovdqu32 zmm1,ZMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCountGE \ColumnCount\(),48,"vmovdqu32 zmm2,ZMMWORD PTR [rdx+rsi*2+\VectorOffset\()]" - EmitIfCountGE \ColumnCount\(),64,"vmovdqu32 zmm3,ZMMWORD PTR [rdx+rbp+\VectorOffset\()]" - vpbroadcastd zmm4,DWORD PTR [r10+\BroadcastOffset\()] - EmitIfCountGE \ColumnCount\(),16,"MultiplyAccumulateCell\Isa\() zmm8,zmm4,zmm0" - EmitIfCountGE \ColumnCount\(),32,"MultiplyAccumulateCell\Isa\() zmm9,zmm4,zmm1" - EmitIfCountGE \ColumnCount\(),48,"MultiplyAccumulateCell\Isa\() zmm10,zmm4,zmm2" - EmitIfCountGE \ColumnCount\(),64,"MultiplyAccumulateCell\Isa\() zmm11,zmm4,zmm3" - vpbroadcastd zmm4,DWORD PTR [r10+r11+\BroadcastOffset\()] - EmitIfCountGE \ColumnCount\(),16,"MultiplyAccumulateCell\Isa\() zmm12,zmm4,zmm0" - EmitIfCountGE \ColumnCount\(),32,"MultiplyAccumulateCell\Isa\() zmm13,zmm4,zmm1" - EmitIfCountGE \ColumnCount\(),48,"MultiplyAccumulateCell\Isa\() zmm14,zmm4,zmm2" - EmitIfCountGE \ColumnCount\(),64,"MultiplyAccumulateCell\Isa\() zmm15,zmm4,zmm3" - vpbroadcastd zmm4,DWORD PTR [r10+r12+\BroadcastOffset\()] - EmitIfCountGE \ColumnCount\(),16,"MultiplyAccumulateCell\Isa\() zmm16,zmm4,zmm0" - EmitIfCountGE \ColumnCount\(),32,"MultiplyAccumulateCell\Isa\() zmm17,zmm4,zmm1" - EmitIfCountGE \ColumnCount\(),48,"MultiplyAccumulateCell\Isa\() zmm18,zmm4,zmm2" - EmitIfCountGE \ColumnCount\(),64,"MultiplyAccumulateCell\Isa\() zmm19,zmm4,zmm3" - vpbroadcastd zmm4,DWORD PTR [r10+r13+\BroadcastOffset\()] - EmitIfCountGE \ColumnCount\(),16,"MultiplyAccumulateCell\Isa\() zmm20,zmm4,zmm0" - EmitIfCountGE \ColumnCount\(),32,"MultiplyAccumulateCell\Isa\() zmm21,zmm4,zmm1" - EmitIfCountGE \ColumnCount\(),48,"MultiplyAccumulateCell\Isa\() zmm22,zmm4,zmm2" - EmitIfCountGE \ColumnCount\(),64,"MultiplyAccumulateCell\Isa\() zmm23,zmm4,zmm3" - vpbroadcastd zmm4,DWORD PTR [r10+r14+\BroadcastOffset\()] - EmitIfCountGE \ColumnCount\(),16,"MultiplyAccumulateCell\Isa\() zmm24,zmm4,zmm0" - EmitIfCountGE \ColumnCount\(),32,"MultiplyAccumulateCell\Isa\() zmm25,zmm4,zmm1" - EmitIfCountGE \ColumnCount\(),48,"MultiplyAccumulateCell\Isa\() zmm26,zmm4,zmm2" - EmitIfCountGE \ColumnCount\(),64,"MultiplyAccumulateCell\Isa\() zmm27,zmm4,zmm3" - vpbroadcastd zmm4,DWORD PTR [r10+r15+\BroadcastOffset\()] - EmitIfCountGE \ColumnCount\(),16,"MultiplyAccumulateCell\Isa\() zmm28,zmm4,zmm0" - EmitIfCountGE \ColumnCount\(),32,"MultiplyAccumulateCell\Isa\() zmm29,zmm4,zmm1" - EmitIfCountGE \ColumnCount\(),48,"MultiplyAccumulateCell\Isa\() zmm30,zmm4,zmm2" - EmitIfCountGE \ColumnCount\(),64,"MultiplyAccumulateCell\Isa\() zmm31,zmm4,zmm3" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple times - and advancing the input and filter data pointers. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - ColumnCount - Supplies the number of columns to produce. - -Implicit Arguments: - - rax - Supplies the number of byte elements to process (multiple of 4). - - rdx - Supplies the address of the filter buffer. - - rsi - Supplies the filter stride to access the packed data for the next 16 - output channels. - - rbp - Supplies three times the above filter stride. - - r10 - Supplies the address of the base of the input buffer. 
- - r11-r15 - Supplies the relative byte offsets from the base of the input - buffer to access the second through sixth rows. - - zmm8-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop Isa, ColumnCount - -.LComputeBlockBy1Loop\@: - ComputeBlock \Isa\(),\ColumnCount\(),0*64,0 - add r10,4 # advance input base address - add rdx,16*4 # advance filter address - sub rax,4 # decrement elements remaining - jnz .LComputeBlockBy1Loop\@ - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner kernel to compute a convolution - for the elements of an output row for a set of filter rows. - -Arguments: - - Isa - Supplies the instruction set architecture string. - ---*/ - - .macro ConvSymKernelFunction Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points - directly at the input tensor. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an - indirection buffer. Every pointer in the indirection buffer points at a - InputChannels length vector (either from the input tensor or a vector of - padding values). These are grouped in batches of length KernelSize. - These batches are then repeated OutputCount times. - - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - KernelSize (rcx) - Supplies the size of the kernel. - - If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. - - InputChannels (r8) - Supplies the number of input channels. - - This implementation requires the count to be a multiple of 4. - - OutputChannels (r9) - Supplies the number of output channels. - - ChannelCount - Supplies the number of channels this iteration produces. - - This implementation requires the count to be in the range 1 to 64. - - OutputCount - Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 6. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymKernel\Isa\() - - push rbp - push rbx - push r12 - push r13 - push r14 - push r15 - sub rsp,.LConvSymKernelFrame_SavedR15 - - mov .LConvSymKernelFrame_InputChannels[rsp],r8 - mov .LConvSymKernelFrame_OutputChannels[rsp],r9 - mov r8,rdx # shuffle registers to Windows ABI - mov r9,rcx - mov rcx,rdi - mov rdx,rsi - - SetupRegistersCommon \Isa\(),.LConvSymKernelFrame - - mov rsi,.LConvSymKernelFrame_InputChannels[rsp] - mov ecx,DWORD PTR .LConvSymKernelFrame_ChannelCount[rsp] - shl rsi,4 # 16 output channels per filter block - imul rsi,r9 # compute filter stride - lea rbp,[rsi*2+rsi] - -// -// Process an input block of length InputChannels for each element of the kernel. -// -// To keep code size small, this kernel always computes a fixed number of output -// rows. If the output count is less than this fixed number, then the first row -// is duplicated into the unused slots and the results are discarded. 
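The cmovb/cmovbe sequences used below (and in the AVX2 kernel earlier) realize that duplication by collapsing the offsets of the unneeded rows onto the first row. For the direct-input case, where the row stride equals InputChannels, the selection is equivalent to this C++ sketch (illustrative only):

    #include <cstddef>

    // Rows at or beyond output_count reuse offset 0, so they recompute row 0 and
    // their results are discarded when stored.
    void ClampRowOffsets(size_t row_stride, unsigned output_count, size_t offsets[6]) {
        for (unsigned row = 0; row < 6; ++row) {
            offsets[row] = (row < output_count) ? row * row_stride : 0;
        }
    }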
-// - -.LProcessNextInputBlock\@: - mov eax,DWORD PTR .LConvSymKernelFrame_OutputCount[rsp] - test BYTE PTR .LConvSymKernelFrame_KernelFlags[rsp],MLAS_CONV_SYM_FLAG_INPUT_DIRECT - jz .LInputIndirection\@ - -// -// The input buffer points directly at the input data and this is effectively a -// GEMM operation (such as a pointwise convolution or an Im2Col transform). -// - -.LInputDirect\@: - xor r10,r10 - mov r11,.LConvSymKernelFrame_InputChannels[rsp] - lea r12,[r11+r11] - lea r13,[r12+r11] - lea r14,[r13+r11] - lea r15,[r14+r11] - cmp eax,2 - cmovb r11,r10 # use first row if output count is small - cmovbe r12,r10 - cmp eax,4 - cmovb r13,r10 - cmovbe r14,r10 - cmp eax,6 - cmovb r15,r10 - mov r10,rbx - jmp .LComputeBlockLoopStart\@ - -.LInputIndirection\@: - lea r11,[rbx+rdi] - lea r12,[rbx+rdi*2] - lea r13,[r11+rdi*2] - lea r14,[r12+rdi*2] - lea r15,[r13+rdi*2] - cmp eax,2 - cmovb r11,rbx # use first row if output count is small - cmovbe r12,rbx - cmp eax,4 - cmovb r13,rbx - cmovbe r14,rbx - cmp eax,6 - cmovb r15,rbx - mov r10,QWORD PTR [rbx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] - mov r14,QWORD PTR [r14] - mov r15,QWORD PTR [r15] - add rbx,8 # advance indirection buffer address - sub r11,r10 # compute deltas from base address - sub r12,r10 - sub r13,r10 - sub r14,r10 - sub r15,r10 - -.LComputeBlockLoopStart\@: - mov rax,.LConvSymKernelFrame_InputChannels[rsp] - cmp ecx,16 - jbe .LComputeBlockLoopBy16\@ - cmp ecx,32 - jbe .LComputeBlockLoopBy32\@ - cmp ecx,48 - jbe .LComputeBlockLoopBy48\@ - -.LComputeBlockLoopBy64\@: - ComputeBlockLoop \Isa\(),64 - jmp .LComputeBlockLoopDone\@ - -.LComputeBlockLoopBy48\@: - ComputeBlockLoop \Isa\(),48 - jmp .LComputeBlockLoopDone\@ - -.LComputeBlockLoopBy32\@: - ComputeBlockLoop \Isa\(),32 - jmp .LComputeBlockLoopDone\@ - -.LComputeBlockLoopBy16\@: - ComputeBlockLoop \Isa\(),16 - -.LComputeBlockLoopDone\@: - dec r9 # decrement input blocks remaining - jnz .LProcessNextInputBlock\@ - -// -// Post-process the block accumulators. -// - - mov ebx,DWORD PTR .LConvSymKernelFrame_OutputCount[rsp] - mov rsi,.LConvSymKernelFrame_OutputChannels[rsp] - mov rdx,.LConvSymKernelFrame_PostProcessParams[rsp] - mov ebp,DWORD PTR .LConvSymKernelFrame_KernelFlags[rsp] - call MlasConvSymPostProcessAvx512Core - -// -// Restore non-volatile registers and return. -// - -.LExitKernel\@: - vzeroupper - add rsp,.LConvSymKernelFrame_SavedR15 - pop r15 - pop r14 - pop r13 - pop r12 - pop rbx - pop rbp - ret - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner kernel to compute a depthwise - convolution for the elements of an output row for a set of filter rows. - -Arguments: - - Isa - Supplies the instruction set architecture string. - ---*/ - - .macro ConvSymDepthwiseKernelFunction Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a depthwise convolution for the - elements of an output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the input indirection buffer. - - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - KernelSize (rcx) - Supplies the size of the kernel. - - Channels (r8) - Supplies the number of input and output channels. - - ChannelOffset (r9) - Supplies the byte offset from the indirection buffer base - address for this iteration. - - ChannelCount - Supplies the number of channels this iteration produces. 
- - This implementation requires the count to be in the range 1 to 64. - - OutputCount - Supplies the number of output elements this iteration produces. - - This implementation requires the count to be in the range 1 to 6. - - PostProcessParams - Supplies the address of the post process parameter block. - - KernelFlags - Supplies additional flags controlling the operation. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvSymDepthwiseKernel\Isa\() - - push rbp - push rbx - push r12 - push r13 - push r14 - push r15 - sub rsp,.LConvSymDepthwiseKernelFrame_SavedR15 - - mov .LConvSymDepthwiseKernelFrame_Channels[rsp],r8 - mov .LConvSymDepthwiseKernelFrame_ChannelOffset[rsp],r9 - mov r8,rdx # shuffle registers to Windows ABI - mov r9,rcx - mov rcx,rdi - mov rdx,rsi - - SetupRegistersCommon \Isa\(),.LConvSymDepthwiseKernelFrame - - mov rsi,.LConvSymDepthwiseKernelFrame_Channels[rsp] - mov ebp,DWORD PTR .LConvSymDepthwiseKernelFrame_OutputCount[rsp] - mov rax,.LConvSymDepthwiseKernelFrame_ChannelOffset[rsp] - mov ecx,DWORD PTR .LConvSymDepthwiseKernelFrame_ChannelCount[rsp] - -// -// Process an input block of length Channels for each element of the kernel. -// -// To keep code size small, this kernel always computes a fixed number of output -// rows. If the output count is less than this fixed number, then the first row -// is duplicated into the unused slots and the results are discarded. -// - -.LProcessNextInputBlock\@: - lea r11,[rbx+rdi] - lea r12,[rbx+rdi*2] - lea r13,[r11+rdi*2] - lea r14,[r12+rdi*2] - lea r15,[r13+rdi*2] - cmp ebp,2 - cmovb r11,rbx # use first row if output count is small - cmovbe r12,rbx - cmp ebp,4 - cmovb r13,rbx - cmovbe r14,rbx - cmp ebp,6 - cmovb r15,rbx - mov r10,QWORD PTR [rbx] - mov r11,QWORD PTR [r11] - mov r12,QWORD PTR [r12] - mov r13,QWORD PTR [r13] - mov r14,QWORD PTR [r14] - mov r15,QWORD PTR [r15] - add rbx,8 - cmp ecx,16 - jbe .LComputeDepthwiseBlockBy16\@ - cmp ecx,32 - jbe .LComputeDepthwiseBlockBy32\@ - cmp ecx,48 - jbe .LComputeDepthwiseBlockBy48\@ - -.LComputeDepthwiseBlockBy64\@: - vpmovzxbd zmm2{k4}{z},XMMWORD PTR [rdx+3*16] - vpmovzxbd zmm0{k4}{z},XMMWORD PTR [r10+rax+3*16] - vpmovzxbd zmm1{k4}{z},XMMWORD PTR [r11+rax+3*16] - MultiplyAccumulateCell\Isa\() zmm11,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm15,zmm1,zmm2 - vpmovzxbd zmm0{k4}{z},XMMWORD PTR [r12+rax+3*16] - vpmovzxbd zmm1{k4}{z},XMMWORD PTR [r13+rax+3*16] - MultiplyAccumulateCell\Isa\() zmm19,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm23,zmm1,zmm2 - vpmovzxbd zmm0{k4}{z},XMMWORD PTR [r14+rax+3*16] - vpmovzxbd zmm1{k4}{z},XMMWORD PTR [r15+rax+3*16] - MultiplyAccumulateCell\Isa\() zmm27,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm31,zmm1,zmm2 - -.LComputeDepthwiseBlockBy48\@: - vpmovzxbd zmm2{k3}{z},XMMWORD PTR [rdx+2*16] - vpmovzxbd zmm0{k3}{z},XMMWORD PTR [r10+rax+2*16] - vpmovzxbd zmm1{k3}{z},XMMWORD PTR [r11+rax+2*16] - MultiplyAccumulateCell\Isa\() zmm10,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm14,zmm1,zmm2 - vpmovzxbd zmm0{k3}{z},XMMWORD PTR [r12+rax+2*16] - vpmovzxbd zmm1{k3}{z},XMMWORD PTR [r13+rax+2*16] - MultiplyAccumulateCell\Isa\() zmm18,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm22,zmm1,zmm2 - vpmovzxbd zmm0{k3}{z},XMMWORD PTR [r14+rax+2*16] - vpmovzxbd zmm1{k3}{z},XMMWORD PTR [r15+rax+2*16] - MultiplyAccumulateCell\Isa\() zmm26,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm30,zmm1,zmm2 - -.LComputeDepthwiseBlockBy32\@: - vpmovzxbd zmm2{k2}{z},XMMWORD PTR [rdx+1*16] - vpmovzxbd zmm0{k2}{z},XMMWORD PTR [r10+rax+1*16] - vpmovzxbd zmm1{k2}{z},XMMWORD PTR 
[r11+rax+1*16] - MultiplyAccumulateCell\Isa\() zmm9,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm13,zmm1,zmm2 - vpmovzxbd zmm0{k2}{z},XMMWORD PTR [r12+rax+1*16] - vpmovzxbd zmm1{k2}{z},XMMWORD PTR [r13+rax+1*16] - MultiplyAccumulateCell\Isa\() zmm17,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm21,zmm1,zmm2 - vpmovzxbd zmm0{k2}{z},XMMWORD PTR [r14+rax+1*16] - vpmovzxbd zmm1{k2}{z},XMMWORD PTR [r15+rax+1*16] - MultiplyAccumulateCell\Isa\() zmm25,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm29,zmm1,zmm2 - -.LComputeDepthwiseBlockBy16\@: - vpmovzxbd zmm2{k1}{z},XMMWORD PTR [rdx] - vpmovzxbd zmm0{k1}{z},XMMWORD PTR [r10+rax] - vpmovzxbd zmm1{k1}{z},XMMWORD PTR [r11+rax] - MultiplyAccumulateCell\Isa\() zmm8,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm12,zmm1,zmm2 - vpmovzxbd zmm0{k1}{z},XMMWORD PTR [r12+rax] - vpmovzxbd zmm1{k1}{z},XMMWORD PTR [r13+rax] - MultiplyAccumulateCell\Isa\() zmm16,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm20,zmm1,zmm2 - vpmovzxbd zmm0{k1}{z},XMMWORD PTR [r14+rax] - vpmovzxbd zmm1{k1}{z},XMMWORD PTR [r15+rax] - MultiplyAccumulateCell\Isa\() zmm24,zmm0,zmm2 - MultiplyAccumulateCell\Isa\() zmm28,zmm1,zmm2 - add rdx,rsi # advance filter to next kernel - dec r9 # decrement input blocks remaining - jnz .LProcessNextInputBlock\@ - -// -// Post-process the block accumulators. -// - - mov ebx,ebp - mov rdx,.LConvSymDepthwiseKernelFrame_PostProcessParams[rsp] - mov ebp,DWORD PTR .LConvSymDepthwiseKernelFrame_KernelFlags[rsp] - call MlasConvSymPostProcessAvx512Core - -// -// Restore non-volatile registers and return. -// - -.LExitKernel\@: - vzeroupper - add rsp,.LConvSymDepthwiseKernelFrame_SavedR15 - pop r15 - pop r14 - pop r13 - pop r12 - pop rbx - pop rbp - ret - - .endm - -/*++ - -Macro Description: - - This macro generates code to convert the block accumulators from the matrix - multiply loop to float values. - -Arguments: - - RegList - Supplies the list of vector registers to operate on. - - ScaleReg - Supplies the output scale vector. - -Implicit Arguments: - - zmm4 - Supplies the integer bias vector. - ---*/ - - .macro ConvertAccumulatorToFloatRegList RegList, ScaleReg - -// -// Offset each value by the per-channel bias value, convert to floating point, -// and apply the output scale. -// - - EmitForEachRegister "\RegList\()","vpaddd \RegItem\(),\RegItem\(),zmm4" - EmitForEachRegister "\RegList\()","vcvtdq2ps \RegItem\(),\RegItem\()" - EmitForEachRegister "\RegList\()","vmulps \RegItem\(),\RegItem\(),\ScaleReg\()" - - .endm - -/*++ - -Macro Description: - - This macro generates code to convert float values to 32-bit integers in the - range 0 to 255. - -Arguments: - - RegList - Supplies the list of vector registers to operate on. - -Implicit Arguments: - - zmm0 - Supplies the broadcasted minimum clip float value. - - This is set to static_cast(0 - ZeroPointValue). - - zmm1 - Supplies the broadcasted maximum clip float value. - - This is set to static_cast(255 - ZeroPointValue). - - zmm2 - Supplies the broadcasted zero point integer value. - ---*/ - - .macro ConvertFloatToIntegerRegList RegList - -// -// Clip the float values to the integer range covered by the output zero point. -// This also keeps values outside the range INT_MIN to INT_MAX from converting -// to INT_MIN. -// - - EmitForEachRegister "\RegList\()","vmaxps \RegItem\(),\RegItem\(),zmm0" - EmitForEachRegister "\RegList\()","vminps \RegItem\(),\RegItem\(),zmm1" - -// -// Convert the float value to integer and add the zero point offset. 
-// - - EmitForEachRegister "\RegList\()","vcvtps2dq \RegItem\(),\RegItem\()" - EmitForEachRegister "\RegList\()","vpaddd \RegItem\(),\RegItem\(),zmm2" - - .endm - -/*++ - -Routine Description: - - This routine post processes the block accumulators produced by the convolution - kernels, including type conversion, requantization, and storing to the output - buffer. - -Arguments: - -Return Value: - - None. - ---*/ - -MlasConvSymPostProcessAvx512Core: - -// -// Apply the bias and convert the block accumulators to intermediate float values. -// - - mov r10,.LConvSymPostProcessParams_Bias[rdx] - mov r11,.LConvSymPostProcessParams_Scale[rdx] - test bpl,MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE - jz .LPostProcess.BroadcastScaleValue - vmovups zmm0{k1}{z},ZMMWORD PTR [r11] - vmovups zmm1{k2}{z},ZMMWORD PTR [r11+16*4] - vmovups zmm2{k3}{z},ZMMWORD PTR [r11+32*4] - vmovups zmm3{k4}{z},ZMMWORD PTR [r11+48*4] - jmp .LPostProcess.ConvertAccumulatorsToFloat - -.LPostProcess.BroadcastScaleValue: - vbroadcastss zmm0,DWORD PTR [r11] - vmovups zmm1,zmm0 - vmovups zmm2,zmm0 - vmovups zmm3,zmm0 - -.LPostProcess.ConvertAccumulatorsToFloat: - cmp ecx,16 - jbe .LPostProcess.ConvertAccumulatorsToFloatBy16 - cmp ecx,32 - jbe .LPostProcess.ConvertAccumulatorsToFloatBy32 - cmp ecx,48 - jbe .LPostProcess.ConvertAccumulatorsToFloatBy48 - -.LPostProcess.ConvertAccumulatorsToFloatBy64: - vmovdqu32 zmm4{k4}{z},ZMMWORD PTR [r10+48*4] - ConvertAccumulatorToFloatRegList "zmm11,zmm15,zmm19,zmm23,zmm27,zmm31",zmm3 - -.LPostProcess.ConvertAccumulatorsToFloatBy48: - vmovdqu32 zmm4{k3}{z},ZMMWORD PTR [r10+32*4] - ConvertAccumulatorToFloatRegList "zmm10,zmm14,zmm18,zmm22,zmm26,zmm30",zmm2 - -.LPostProcess.ConvertAccumulatorsToFloatBy32: - vmovdqu32 zmm4{k2}{z},ZMMWORD PTR [r10+16*4] - ConvertAccumulatorToFloatRegList "zmm9,zmm13,zmm17,zmm21,zmm25,zmm29",zmm1 - -.LPostProcess.ConvertAccumulatorsToFloatBy16: - vmovdqu32 zmm4{k1}{z},ZMMWORD PTR [r10] - ConvertAccumulatorToFloatRegList "zmm8,zmm12,zmm16,zmm20,zmm24,zmm28",zmm0 - -// -// Convert the intermediate float values to 32-bit integers in the range 0 to 255. -// - - vbroadcastss zmm0,DWORD PTR .LConvSymPostProcessParams_MinimumValue[rdx] - vbroadcastss zmm1,DWORD PTR .LConvSymPostProcessParams_MaximumValue[rdx] - vpbroadcastd zmm2,DWORD PTR .LConvSymPostProcessParams_OutputZeroPoint[rdx] - cmp ecx,16 - jbe .LPostProcess.ConvertFloatsToIntegerBy16 - cmp ecx,32 - jbe .LPostProcess.ConvertFloatsToIntegerBy32 - cmp ecx,48 - jbe .LPostProcess.ConvertFloatsToIntegerBy48 - -.LPostProcess.ConvertFloatsToIntegerBy64: - ConvertFloatToIntegerRegList "zmm11,zmm15,zmm19,zmm23,zmm27,zmm31" - -.LPostProcess.ConvertFloatsToIntegerBy48: - ConvertFloatToIntegerRegList "zmm10,zmm14,zmm18,zmm22,zmm26,zmm30" - -.LPostProcess.ConvertFloatsToIntegerBy32: - ConvertFloatToIntegerRegList "zmm9,zmm13,zmm17,zmm21,zmm25,zmm29" - -.LPostProcess.ConvertFloatsToIntegerBy16: - ConvertFloatToIntegerRegList "zmm8,zmm12,zmm16,zmm20,zmm24,zmm28" - -// -// Pack with saturation and store 1 to 64 bytes to the output buffer. 
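The masked vpmovusdb stores that follow write only the channels selected by the k1-k4 opmasks, which SetupRegistersCommon derived from ChannelCount at entry. A C++ sketch of that bitmask split (illustrative only; assumes ChannelCount is in the range 1 to 64, as the kernels require):

    #include <cstdint>

    // (2 << (channel_count - 1)) - 1 sets the low channel_count bits, including
    // the full-mask case channel_count == 64 where the shift wraps to zero and
    // the decrement yields all ones. The four 16-bit slices become k1..k4.
    void ChannelMasks(unsigned channel_count, uint16_t masks[4]) {
        uint64_t bitmask = (uint64_t{2} << (channel_count - 1)) - 1;
        for (int i = 0; i < 4; ++i) {
            masks[i] = static_cast<uint16_t>(bitmask >> (16 * i));
        }
    }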
-// - -.LPostProcess.StoreQuantizedOutput: - lea r9,[rsi*2+rsi] - add r9,r8 - cmp ebx,5 - ja .LPostProcess.StoreQuantizedOutput6 - je .LPostProcess.StoreQuantizedOutput5 - cmp ebx,3 - ja .LPostProcess.StoreQuantizedOutput4 - je .LPostProcess.StoreQuantizedOutput3 - cmp ebx,1 - ja .LPostProcess.StoreQuantizedOutput2 - jmp .LPostProcess.StoreQuantizedOutput1 - -.LPostProcess.StoreQuantizedOutput6: - vpmovusdb XMMWORD PTR [r9+rsi*2]{k1},zmm28 - vpmovusdb XMMWORD PTR [r9+rsi*2+16]{k2},zmm29 - vpmovusdb XMMWORD PTR [r9+rsi*2+32]{k3},zmm30 - vpmovusdb XMMWORD PTR [r9+rsi*2+48]{k4},zmm31 - -.LPostProcess.StoreQuantizedOutput5: - vpmovusdb XMMWORD PTR [r9+rsi]{k1},zmm24 - vpmovusdb XMMWORD PTR [r9+rsi+16]{k2},zmm25 - vpmovusdb XMMWORD PTR [r9+rsi+32]{k3},zmm26 - vpmovusdb XMMWORD PTR [r9+rsi+48]{k4},zmm27 - -.LPostProcess.StoreQuantizedOutput4: - vpmovusdb XMMWORD PTR [r9]{k1},zmm20 - vpmovusdb XMMWORD PTR [r9+16]{k2},zmm21 - vpmovusdb XMMWORD PTR [r9+32]{k3},zmm22 - vpmovusdb XMMWORD PTR [r9+48]{k4},zmm23 - -.LPostProcess.StoreQuantizedOutput3: - vpmovusdb XMMWORD PTR [r8+rsi*2]{k1},zmm16 - vpmovusdb XMMWORD PTR [r8+rsi*2+16]{k2},zmm17 - vpmovusdb XMMWORD PTR [r8+rsi*2+32]{k3},zmm18 - vpmovusdb XMMWORD PTR [r8+rsi*2+48]{k4},zmm19 - -.LPostProcess.StoreQuantizedOutput2: - vpmovusdb XMMWORD PTR [r8+rsi]{k1},zmm12 - vpmovusdb XMMWORD PTR [r8+rsi+16]{k2},zmm13 - vpmovusdb XMMWORD PTR [r8+rsi+32]{k3},zmm14 - vpmovusdb XMMWORD PTR [r8+rsi+48]{k4},zmm15 - -.LPostProcess.StoreQuantizedOutput1: - vpmovusdb XMMWORD PTR [r8]{k1},zmm8 - vpmovusdb XMMWORD PTR [r8+16]{k2},zmm9 - vpmovusdb XMMWORD PTR [r8+32]{k3},zmm10 - vpmovusdb XMMWORD PTR [r8+48]{k4},zmm11 - ret - -// -// Generate the convolution kernels. -// - -ConvSymKernelFunction Avx512Core -ConvSymDepthwiseKernelFunction Avx512Core - -ConvSymKernelFunction Avx512Vnni -ConvSymDepthwiseKernelFunction Avx512Vnni - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelCommon.h deleted file mode 100644 index 2f3f8c8336983..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelCommon.h +++ /dev/null @@ -1,67 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ConvSymKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the symmetric - quantized integer convolution operation. - ---*/ - -// -// Define the convolution kernel flags. -// - -#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 0x00000001 -#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 0x00000002 - -// -// Define the structure of the post process parameter block. -// - - .equ .LConvSymPostProcessParams_Bias, 0 - .equ .LConvSymPostProcessParams_Scale, 8 - .equ .LConvSymPostProcessParams_MinimumValue, 16 - .equ .LConvSymPostProcessParams_MaximumValue, 20 - .equ .LConvSymPostProcessParams_OutputZeroPoint, 24 - -// -// Stack frame layout for the symmetric convolution kernels. 
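For reference, the .LConvSymPostProcessParams_* offsets defined above correspond to a parameter block that can be pictured as the following C++ struct (a mirror for illustration only; the authoritative definition lives in the MLAS C++ sources, and the field kinds here are inferred from how the kernels load each member):

    #include <cstddef>
    #include <cstdint>

    struct ConvSymPostProcessParams {
        const int32_t* Bias;      // offset 0: per-channel bias, loaded indirectly
        const float* Scale;       // offset 8: per-channel or broadcast scale
        float MinimumValue;       // offset 16: broadcast clip low, float(0 - zero point)
        float MaximumValue;       // offset 20: broadcast clip high, float(255 - zero point)
        int32_t OutputZeroPoint;  // offset 24: broadcast and added after conversion
    };

    static_assert(offsetof(ConvSymPostProcessParams, Scale) == 8,
                  "matches .LConvSymPostProcessParams_Scale");
    static_assert(offsetof(ConvSymPostProcessParams, OutputZeroPoint) == 24,
                  "matches .LConvSymPostProcessParams_OutputZeroPoint");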
-// - - .equ .LConvSymKernelFrame_InputChannels, 0 - .equ .LConvSymKernelFrame_OutputChannels, 8 - .equ .LConvSymKernelFrame_Padding, 16 - .equ .LConvSymKernelFrame_SavedR15, 24 - .equ .LConvSymKernelFrame_SavedR14, 32 - .equ .LConvSymKernelFrame_SavedR13, 40 - .equ .LConvSymKernelFrame_SavedR12, 48 - .equ .LConvSymKernelFrame_SavedRbx, 56 - .equ .LConvSymKernelFrame_SavedRbp, 64 - .equ .LConvSymKernelFrame_ReturnAddress, 72 - .equ .LConvSymKernelFrame_ChannelCount, 80 - .equ .LConvSymKernelFrame_OutputCount, 88 - .equ .LConvSymKernelFrame_PostProcessParams, 96 - .equ .LConvSymKernelFrame_KernelFlags, 104 - - .equ .LConvSymDepthwiseKernelFrame_Channels, 0 - .equ .LConvSymDepthwiseKernelFrame_ChannelOffset, 8 - .equ .LConvSymDepthwiseKernelFrame_Padding, 16 - .equ .LConvSymDepthwiseKernelFrame_SavedR15, 24 - .equ .LConvSymDepthwiseKernelFrame_SavedR14, 32 - .equ .LConvSymDepthwiseKernelFrame_SavedR13, 40 - .equ .LConvSymDepthwiseKernelFrame_SavedR12, 48 - .equ .LConvSymDepthwiseKernelFrame_SavedRbx, 56 - .equ .LConvSymDepthwiseKernelFrame_SavedRbp, 64 - .equ .LConvSymDepthwiseKernelFrame_ReturnAddress, 72 - .equ .LConvSymDepthwiseKernelFrame_ChannelCount, 80 - .equ .LConvSymDepthwiseKernelFrame_OutputCount, 88 - .equ .LConvSymDepthwiseKernelFrame_PostProcessParams, 96 - .equ .LConvSymDepthwiseKernelFrame_KernelFlags, 104 diff --git a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/DgemmKernelAvx.S deleted file mode 100644 index 8f791a734adda..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelAvx.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelAvx.s - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" -#include "DgemmKernelCommon.h" -#include "FgemmKernelAvxCommon.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMM kernel. -// - -FgemmKernelAvxFunction MlasGemmDoubleKernelAvx - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/DgemmKernelAvx512F.S deleted file mode 100644 index 23f8afcb2bd1e..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelAvx512F.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelAvx512F.s - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - - This implementation uses AVX512F instructions. - ---*/ - -#include "asmmacro.h" -#include "DgemmKernelCommon.h" -#include "FgemmKernelAvx512FCommon.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMM kernel. -// - -FgemmKernelAvx512FFunction MlasGemmDoubleKernelAvx512F - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/DgemmKernelCommon.h deleted file mode 100644 index ba53a91b68c45..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelCommon.h +++ /dev/null @@ -1,50 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the double - precision matrix/matrix multiply operation (DGEMM). 
- ---*/ - -// -// Define the double precision parameters. -// - - .equ .LFgemmElementShift, 3 - .equ .LFgemmElementSize, 1 << .LFgemmElementShift - -#include "FgemmKernelCommon.h" - -// -// Define the typed instructions for double precision. -// - -FGEMM_TYPED_INSTRUCTION(addpf, addpd) -FGEMM_TYPED_INSTRUCTION(movsf, movsd) -FGEMM_TYPED_INSTRUCTION(movupf, movupd) - -FGEMM_TYPED_INSTRUCTION(vaddpf, vaddpd) -FGEMM_TYPED_INSTRUCTION(vbroadcastsf, vbroadcastsd) -FGEMM_TYPED_INSTRUCTION(vfmadd213pf, vfmadd213pd) -FGEMM_TYPED_INSTRUCTION(vfmadd231pf, vfmadd231pd) -FGEMM_TYPED_INSTRUCTION(vmaskmovpf, vmaskmovpd) -FGEMM_TYPED_INSTRUCTION(vmovapf, vmovapd) -FGEMM_TYPED_INSTRUCTION(vmovsf, vmovsd) -FGEMM_TYPED_INSTRUCTION(vmovupf, vmovupd) -FGEMM_TYPED_INSTRUCTION(vmulpf, vmulpd) -FGEMM_TYPED_INSTRUCTION(vxorpf, vxorpd) - - .macro vfmadd231pf_bcst DestReg, SrcReg, Address - - vfmadd231pd \DestReg\(), \SrcReg\(), \Address\(){1to8} - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/DgemmKernelFma3.S deleted file mode 100644 index 707882af7828b..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelFma3.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelFma3.s - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" -#include "DgemmKernelCommon.h" -#include "FgemmKernelFma3Common.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMM kernel. -// - -FgemmKernelFma3Function MlasGemmDoubleKernelFma3 - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelSse2.S b/onnxruntime/core/mlas/lib/x86_64/DgemmKernelSse2.S deleted file mode 100644 index 929eaf3510af6..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/DgemmKernelSse2.S +++ /dev/null @@ -1,230 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - DgemmKernelSse2.s - -Abstract: - - This module implements the kernels for the double precision matrix/matrix - multiply operation (DGEMM). - - This implementation uses SSE2 instructions. - ---*/ - -#include "asmmacro.h" -#include "DgemmKernelCommon.h" -#include "FgemmKernelSse2Common.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for a 8xN block of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rsi - Supplies the address into the matrix B data. - - xmm0-xmm1 - Supplies up to two elements loaded from matrix A and matrix A - plus one row. - - xmm8-xmm15 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockSseBy8 RowCount - - movapd xmm4,XMMWORD PTR [rsi] - movapd xmm5,XMMWORD PTR [rsi+16] -.if \RowCount\() == 2 - movapd xmm6,xmm4 - movapd xmm7,xmm5 -.endif - mulpd xmm4,xmm0 - mulpd xmm5,xmm0 - addpd xmm8,xmm4 - addpd xmm9,xmm5 -.if \RowCount\() == 2 - mulpd xmm6,xmm1 - mulpd xmm7,xmm1 - addpd xmm12,xmm6 - addpd xmm13,xmm7 -.endif - movapd xmm4,XMMWORD PTR [rsi+32] - movapd xmm5,XMMWORD PTR [rsi+48] -.if \RowCount\() == 2 - movapd xmm6,xmm4 - movapd xmm7,xmm5 -.endif - mulpd xmm4,xmm0 - mulpd xmm5,xmm0 - addpd xmm10,xmm4 - addpd xmm11,xmm5 -.if \RowCount\() == 2 - mulpd xmm6,xmm1 - mulpd xmm7,xmm1 - addpd xmm14,xmm6 - addpd xmm15,xmm7 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rdx - Supplies the address of matrix C. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - rax - Supplies the length in bytes of a row from matrix C. - - r15 - Stores the ZeroMode argument from the stack frame. - ---*/ - - .macro ProcessCountM RowCount, Fallthrough - -.LProcessNextColumnLoop8xN\@: - EmitIfCountGE \RowCount\(), 1, "xorpd xmm8,xmm8" - EmitIfCountGE \RowCount\(), 1, "xorpd xmm9,xmm9" - EmitIfCountGE \RowCount\(), 1, "xorpd xmm10,xmm10" - EmitIfCountGE \RowCount\(), 1, "xorpd xmm11,xmm11" - EmitIfCountGE \RowCount\(), 2, "xorpd xmm12,xmm12" - EmitIfCountGE \RowCount\(), 2, "xorpd xmm13,xmm13" - EmitIfCountGE \RowCount\(), 2, "xorpd xmm14,xmm14" - EmitIfCountGE \RowCount\(), 2, "xorpd xmm15,xmm15" - mov rbp,rcx # reload CountK - -.LCompute8xNBlockBy1Loop\@: - EmitIfCountGE \RowCount\(), 1, "movsd xmm0,[rdi]" - EmitIfCountGE \RowCount\(), 1, "movlhps xmm0,xmm0" - EmitIfCountGE \RowCount\(), 2, "movsd xmm1,[rdi+r10]" - EmitIfCountGE \RowCount\(), 2, "movlhps xmm1,xmm1" - ComputeBlockSseBy8 \RowCount\() - add rsi,8*8 # advance matrix B by 8 columns - add rdi,8 # advance matrix A by 1 column - dec rbp - jne .LCompute8xNBlockBy1Loop\@ - -.LOutput8xNBlock\@: - movsd xmm2,.LFgemmKernelFrame_alpha[rsp] - movlhps xmm2,xmm2 - EmitIfCountGE \RowCount\(), 1, "mulpd xmm8,xmm2" - # multiply by alpha - EmitIfCountGE \RowCount\(), 1, "mulpd xmm9,xmm2" - EmitIfCountGE \RowCount\(), 1, "mulpd xmm10,xmm2" - EmitIfCountGE \RowCount\(), 1, "mulpd xmm11,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulpd xmm12,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulpd xmm13,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulpd xmm14,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulpd xmm15,xmm2" - sub r9,8 - jb .LOutputPartial8xNBlock\@ - AccumulateAndStoreBlock \RowCount\(), 4 - add rdx,8*8 # advance matrix C by 8 columns - mov rdi,r11 # reload matrix A - test r9,r9 - jnz .LProcessNextColumnLoop8xN\@ - jmp .LExitKernel - -// -// Output a partial 8xN block to the matrix. 
-// - -.LOutputPartial8xNBlock\@: - add r9,8 # correct for over-subtract above - cmp r9,2 - jb .LOutputPartial1xNBlock\@ - cmp r9,4 - jb .LOutputPartialLessThan4xNBlock\@ - cmp r9,6 - jb .LOutputPartialLessThan6xNBlock\@ - AccumulateAndStoreBlock \RowCount\(), 3 - test r9d,1 # check if remaining count is small - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movapd xmm8,xmm11" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "movapd xmm12,xmm15" - add rdx,6*8 # advance matrix C by 6 columns - jmp .LOutputPartial1xNBlock\@ - -.LOutputPartialLessThan6xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 2 - test r9d,1 # check if remaining count is small - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movapd xmm8,xmm10" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "movapd xmm12,xmm14" - add rdx,4*8 # advance matrix C by 4 columns - jmp .LOutputPartial1xNBlock\@ - -.LOutputPartialLessThan4xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 1 - test r9d,1 # check if remaining count is small - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movapd xmm8,xmm9" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "movapd xmm12,xmm13" - add rdx,2*8 # advance matrix C by 2 columns - -.LOutputPartial1xNBlock\@: - test r15b,r15b # ZeroMode? - jnz .LSkipAccumulateOutput1xN\@ - EmitIfCountGE \RowCount\(), 1, "addsd xmm8,[rdx]" - EmitIfCountGE \RowCount\(), 2, "addsd xmm12,[rdx+rax]" - -.LSkipAccumulateOutput1xN\@: - EmitIfCountGE \RowCount\(), 1, "movsd [rdx],xmm8" - EmitIfCountGE \RowCount\(), 2, "movsd [rdx+rax],xmm12" -.ifb \Fallthrough\() - jmp .LExitKernel -.endif - - .endm - -// -// Generate the GEMM kernel. -// - -FgemmKernelSse2Function MlasGemmDoubleKernelSse - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S deleted file mode 100644 index 92b7976d7db79..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/ErfKernelFma3.S +++ /dev/null @@ -1,517 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - ErfKernelFma3.s - -Abstract: - - This module implements a kernel for computing the error function for a - buffer of elements. - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -// -// Structure layout for the erf constants block. -// - - .equ ErfUpperAbsRange, 0 - .equ ErfSplitBoundary, 4 - .equ ErfSMALL_P0, 8 - .equ ErfSMALL_P1, 12 - .equ ErfSMALL_P2, 16 - .equ ErfSMALL_P3, 20 - .equ ErfSMALL_P4, 24 - .equ ErfSMALL_P5_Minus_One, 28 - .equ ErfReserve0, 32 - .equ ErfBIG_P0, 36 - .equ ErfBIG_P1, 40 - .equ ErfBIG_P2, 44 - .equ ErfBIG_P3, 48 - .equ ErfBIG_P4, 52 - .equ ErfBIG_P5, 56 - .equ ErfBIG_P6_Minus_One, 60 - .equ ErfNegZero, 64 - .equ ErfOne, 68 - - .equ ExpConstOffset, 72 - .equ Exp_UpperRange, 0 + ExpConstOffset - .equ Exp_LowerRange, 4 + ExpConstOffset - .equ Exp_Log2Reciprocal, 8 + ExpConstOffset - .equ Exp_log2_hi, 12 + ExpConstOffset - .equ Exp_log2_lo, 16 + ExpConstOffset - .equ Exp_P0, 20 + ExpConstOffset - .equ Exp_P1, 24 + ExpConstOffset - .equ Exp_P2, 28 + ExpConstOffset - .equ Exp_P3, 32 + ExpConstOffset - .equ Exp_P4, 36 + ExpConstOffset - .equ Exp_P5, 40 + ExpConstOffset - .equ Exp_P6, 44 + ExpConstOffset - .equ Exp_C, 48 + ExpConstOffset - .equ Exp_X7F, 52 + ExpConstOffset - -// -// Stack frame layout for the erf kernel. 
-// - .equ ErfBuffer0, 0 - .equ ErfBuffer1, 128 - .equ ErfKernelFrame_CountN, 256 - .equ ErfKernelFrame_ReturnAddress, 256+8 - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the error function. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Supplies the output buffer. - - N (rdx) - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ - - .globl C_UNDERSCORE(MlasErfKernelFma3) -C_UNDERSCORE(MlasErfKernelFma3): - sub rsp,ErfKernelFrame_ReturnAddress - lea rax,C_UNDERSCORE(MlasErfConstants)[rip] - - sub rdx,8*4 - jb .LErfProcessRemainingCount - -.LComputeErf4x8Loop: - vbroadcastss ymm15,ErfNegZero[rax] - vmovups ymm0,YMMWORD PTR [rdi] # original input vx0 - vmovups ymm1,YMMWORD PTR [rdi+32] # original input vx1 - vmovups ymm2,YMMWORD PTR [rdi+64] # original input vx2 - vmovups ymm3,YMMWORD PTR [rdi+96] # original input vx3 - - vandps ymm4,ymm0,ymm15 # vsign0 - vandps ymm5,ymm1,ymm15 # vsign1 - vandps ymm6,ymm2,ymm15 # vsign2 - vandps ymm7,ymm3,ymm15 # vsign3 - vandnps ymm0,ymm15,ymm0 # abs(vx0) va0 - vandnps ymm1,ymm15,ymm1 # abs(vx1) va1 - vandnps ymm2,ymm15,ymm2 # abs(vx2) va2 - vandnps ymm3,ymm15,ymm3 # abs(vx3) va3 - - vbroadcastss ymm14,ErfUpperAbsRange[rax] - vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4 - vmovups YMMWORD PTR ErfBuffer0[rsp+32],ymm5 - vmovups YMMWORD PTR ErfBuffer0[rsp+64],ymm6 - vmovups YMMWORD PTR ErfBuffer0[rsp+96],ymm7 - - vbroadcastss ymm8,ErfSMALL_P0[rax] - vminps ymm0,ymm0,ymm14 # force abs value in range - vminps ymm1,ymm1,ymm14 - vminps ymm2,ymm2,ymm14 - vminps ymm3,ymm3,ymm14 - vmovaps ymm9,ymm8 - vmovaps ymm10,ymm8 - vmovaps ymm11,ymm8 - - vbroadcastss ymm15,ErfSMALL_P1[rax] - vmulps ymm4,ymm0,ymm0 # vs0 (square) - vmulps ymm5,ymm1,ymm1 # vs1 - vmulps ymm6,ymm2,ymm2 # vs2 - vmulps ymm7,ymm3,ymm3 # vs3 - - vbroadcastss ymm14,ErfSMALL_P2[rax] - vfmadd213ps ymm8,ymm4,ymm15 - vfmadd213ps ymm9,ymm5,ymm15 - vfmadd213ps ymm10,ymm6,ymm15 - vfmadd213ps ymm11,ymm7,ymm15 - - vbroadcastss ymm13,ErfSMALL_P3[rax] - vfmadd213ps ymm8,ymm4,ymm14 - vfmadd213ps ymm9,ymm5,ymm14 - vfmadd213ps ymm10,ymm6,ymm14 - vfmadd213ps ymm11,ymm7,ymm14 - - vbroadcastss ymm15,ErfSMALL_P4[rax] - vfmadd213ps ymm8,ymm4,ymm13 - vfmadd213ps ymm9,ymm5,ymm13 - vfmadd213ps ymm10,ymm6,ymm13 - vfmadd213ps ymm11,ymm7,ymm13 - - vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax] - vfmadd213ps ymm8,ymm4,ymm15 - vfmadd213ps ymm9,ymm5,ymm15 - vfmadd213ps ymm10,ymm6,ymm15 - vfmadd213ps ymm11,ymm7,ymm15 - - vfmadd213ps ymm8,ymm4,ymm14 - vfmadd213ps ymm9,ymm5,ymm14 - vfmadd213ps ymm10,ymm6,ymm14 - vfmadd213ps ymm11,ymm7,ymm14 - - vbroadcastss ymm12,ErfSplitBoundary[rax] - vfmadd213ps ymm8,ymm0,ymm0 - vfmadd213ps ymm9,ymm1,ymm1 - vfmadd213ps ymm10,ymm2,ymm2 - vfmadd213ps ymm11,ymm3,ymm3 - - vcmpgtps ymm4,ymm0,ymm12 # vmask0 - vcmpgtps ymm5,ymm1,ymm12 # vmask1 - vcmpgtps ymm6,ymm2,ymm12 # vmask2 - vcmpgtps ymm7,ymm3,ymm12 # vmask3 - - vandnps ymm8,ymm4,ymm8 - vandnps ymm9,ymm5,ymm9 - vandnps ymm10,ymm6,ymm10 - vandnps ymm11,ymm7,ymm11 - - vbroadcastss ymm15,ErfBIG_P1[rax] - vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8 - vmovups YMMWORD PTR ErfBuffer1[rsp+32],ymm9 - vmovups YMMWORD PTR ErfBuffer1[rsp+64],ymm10 - vmovups YMMWORD PTR ErfBuffer1[rsp+96],ymm11 - -.BiggerNumbers: - vbroadcastss ymm8,ErfBIG_P0[rax] - vandps ymm0,ymm4,ymm0 - vandps ymm1,ymm5,ymm1 - vandps ymm2,ymm6,ymm2 - vandps ymm3,ymm7,ymm3 - vmovaps ymm9,ymm8 - vmovaps ymm10,ymm8 - vmovaps ymm11,ymm8 - - vbroadcastss ymm14,ErfBIG_P2[rax] - vfmadd213ps ymm8,ymm0,ymm15 - 
vfmadd213ps ymm9,ymm1,ymm15 - vfmadd213ps ymm10,ymm2,ymm15 - vfmadd213ps ymm11,ymm3,ymm15 - - vbroadcastss ymm13,ErfBIG_P3[rax] - vfmadd213ps ymm8,ymm0,ymm14 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm15,ErfBIG_P4[rax] - vfmadd213ps ymm8,ymm0,ymm13 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vbroadcastss ymm14,ErfBIG_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 - vfmadd213ps ymm9,ymm1,ymm15 - vfmadd213ps ymm10,ymm2,ymm15 - vfmadd213ps ymm11,ymm3,ymm15 - - vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax] - vfmadd213ps ymm8,ymm0,ymm14 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm15,ErfNegZero[rax] - vfmadd213ps ymm8,ymm0,ymm13 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vbroadcastss ymm14,Exp_LowerRange[rax] - vfmadd213ps ymm8,ymm0,ymm0 - vfmadd213ps ymm9,ymm1,ymm1 - vfmadd213ps ymm10,ymm2,ymm2 - vfmadd213ps ymm11,ymm3,ymm3 - - vbroadcastss ymm4,Exp_Log2Reciprocal[rax] - vxorps ymm8,ymm8,ymm15 - vxorps ymm9,ymm9,ymm15 - vxorps ymm10,ymm10,ymm15 - vxorps ymm11,ymm11,ymm15 - - vbroadcastss ymm13,Exp_C[rax] - vmovaps ymm5,ymm4 - vmovaps ymm6,ymm4 - vmovaps ymm7,ymm4 - - # expf(ymm8 -- ymm11) - vmaxps ymm8,ymm8,ymm14 - vmaxps ymm9,ymm9,ymm14 - vmaxps ymm10,ymm10,ymm14 - vmaxps ymm11,ymm11,ymm14 - - vbroadcastss ymm0,Exp_log2_hi[rax] - vfmadd213ps ymm4,ymm8,ymm13 - vfmadd213ps ymm5,ymm9,ymm13 - vfmadd213ps ymm6,ymm10,ymm13 - vfmadd213ps ymm7,ymm11,ymm13 - - vbroadcastss ymm15,Exp_log2_lo[rax] - vmovaps ymm1,ymm0 - vmovaps ymm2,ymm0 - vmovaps ymm3,ymm0 - - vsubps ymm4,ymm4,ymm13 # vr = round() - vsubps ymm5,ymm5,ymm13 - vsubps ymm6,ymm6,ymm13 - vsubps ymm7,ymm7,ymm13 - - vfmadd213ps ymm0,ymm4,ymm8 # vf = vr * log2_hi + ve - vfmadd213ps ymm1,ymm5,ymm9 - vfmadd213ps ymm2,ymm6,ymm10 - vfmadd213ps ymm3,ymm7,ymm11 - - vbroadcastss ymm8,Exp_P0[rax] - vfmadd231ps ymm0,ymm4,ymm15 # vf += vr * log_2_lo - vfmadd231ps ymm1,ymm5,ymm15 - vfmadd231ps ymm2,ymm6,ymm15 - vfmadd231ps ymm3,ymm7,ymm15 - vmovaps ymm9,ymm8 - vmovaps ymm10,ymm8 - vmovaps ymm11,ymm8 - - vbroadcastss ymm14,Exp_P1[rax] - vbroadcastss ymm13,Exp_P2[rax] - vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p1 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm12,Exp_P3[rax] - vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p2 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vbroadcastss ymm15,Exp_P4[rax] - vfmadd213ps ymm8,ymm0,ymm12 # *+ exp_p3 - vfmadd213ps ymm9,ymm1,ymm12 - vfmadd213ps ymm10,ymm2,ymm12 - vfmadd213ps ymm11,ymm3,ymm12 - - vbroadcastss ymm14,Exp_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 # *+ exp_p4 - vfmadd213ps ymm9,ymm1,ymm15 - vfmadd213ps ymm10,ymm2,ymm15 - vfmadd213ps ymm11,ymm3,ymm15 - - vbroadcastss ymm13,Exp_P6[rax] - vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p5 - vfmadd213ps ymm9,ymm1,ymm14 - vfmadd213ps ymm10,ymm2,ymm14 - vfmadd213ps ymm11,ymm3,ymm14 - - vbroadcastss ymm12,Exp_X7F[rax] - vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p6 - vfmadd213ps ymm9,ymm1,ymm13 - vfmadd213ps ymm10,ymm2,ymm13 - vfmadd213ps ymm11,ymm3,ymm13 - - vcvttps2dq ymm4,ymm4 - vcvttps2dq ymm5,ymm5 - vcvttps2dq ymm6,ymm6 - vcvttps2dq ymm7,ymm7 - - - vbroadcastss ymm15,ErfOne[rax] - vpaddd ymm4,ymm4,ymm12 # +127 - vpaddd ymm5,ymm5,ymm12 - vpaddd ymm6,ymm6,ymm12 - vpaddd ymm7,ymm7,ymm12 - - vpslld ymm4,ymm4,23 - vpslld ymm5,ymm5,23 - vpslld ymm6,ymm6,23 - 
vpslld ymm7,ymm7,23 - - vmulps ymm8,ymm8,ymm4 # 2^i * exp(vf) - vmulps ymm9,ymm9,ymm5 - vmulps ymm10,ymm10,ymm6 - vmulps ymm11,ymm11,ymm7 - - vsubps ymm8,ymm15,ymm8 - vsubps ymm9,ymm15,ymm9 - vsubps ymm10,ymm15,ymm10 - vsubps ymm11,ymm15,ymm11 - - # merge small numbers' result - vorps ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp] - vorps ymm9,ymm9,YMMWORD PTR ErfBuffer1[rsp+32] - vorps ymm10,ymm10,YMMWORD PTR ErfBuffer1[rsp+64] - vorps ymm11,ymm11,YMMWORD PTR ErfBuffer1[rsp+96] - - # copy sign - vorps ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp] - vorps ymm1,ymm9,YMMWORD PTR ErfBuffer0[rsp+32] - vorps ymm2,ymm10,YMMWORD PTR ErfBuffer0[rsp+64] - vorps ymm3,ymm11,YMMWORD PTR ErfBuffer0[rsp+96] - - vmovups YMMWORD PTR [rsi],ymm0 - vmovups YMMWORD PTR [rsi+32],ymm1 - vmovups YMMWORD PTR [rsi+64],ymm2 - vmovups YMMWORD PTR [rsi+96],ymm3 - - add rdi,32*4 # advance by 4*8 elements - add rsi,32*4 - sub rdx,32 - jae .LComputeErf4x8Loop - -.LErfProcessRemainingCount: - add rdx,32 # correct for over-subtract above - jz .LErfBatchExp - -.LErfProcess1x8: - mov DWORD PTR ErfKernelFrame_CountN[rsp],edx - vbroadcastss ymm3,DWORD PTR ErfKernelFrame_CountN[rsp] - - vpcmpgtd ymm3,ymm3,YMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - vbroadcastss ymm15,ErfNegZero[rax] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi] # original input vx0 - - vandps ymm4,ymm0,ymm15 # vsign0 - vandnps ymm0,ymm15,ymm0 # abs(vx0) va0 - - vbroadcastss ymm14,ErfUpperAbsRange[rax] - vmovups YMMWORD PTR ErfBuffer0[rsp],ymm4 - - vbroadcastss ymm8,ErfSMALL_P0[rax] - vminps ymm0,ymm0,ymm14 # force abs value in range - - vbroadcastss ymm15,ErfSMALL_P1[rax] - vmulps ymm4,ymm0,ymm0 # vs0 (square) - - vbroadcastss ymm14,ErfSMALL_P2[rax] - vfmadd213ps ymm8,ymm4,ymm15 - - vbroadcastss ymm13,ErfSMALL_P3[rax] - vfmadd213ps ymm8,ymm4,ymm14 - - vbroadcastss ymm15,ErfSMALL_P4[rax] - vfmadd213ps ymm8,ymm4,ymm13 - - vbroadcastss ymm14,ErfSMALL_P5_Minus_One[rax] - vfmadd213ps ymm8,ymm4,ymm15 - - vfmadd213ps ymm8,ymm4,ymm14 - - vbroadcastss ymm12,ErfSplitBoundary[rax] - vfmadd213ps ymm8,ymm0,ymm0 - - vcmpgtps ymm4,ymm0,ymm12 # vmask0 - - vandnps ymm8,ymm4,ymm8 - - vmovups YMMWORD PTR ErfBuffer1[rsp],ymm8 - -.BiggerNumbersRemaining: - vbroadcastss ymm15,ErfBIG_P1[rax] - vbroadcastss ymm8,ErfBIG_P0[rax] - vandps ymm0,ymm4,ymm0 - - vbroadcastss ymm14,ErfBIG_P2[rax] - vfmadd213ps ymm8,ymm0,ymm15 - - vbroadcastss ymm13,ErfBIG_P3[rax] - vfmadd213ps ymm8,ymm0,ymm14 - - vbroadcastss ymm15,ErfBIG_P4[rax] - vfmadd213ps ymm8,ymm0,ymm13 - - vbroadcastss ymm14,ErfBIG_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 - - vbroadcastss ymm13,ErfBIG_P6_Minus_One[rax] - vfmadd213ps ymm8,ymm0,ymm14 - - vbroadcastss ymm15,ErfNegZero[rax] - vfmadd213ps ymm8,ymm0,ymm13 - - vbroadcastss ymm14,Exp_LowerRange[rax] - vfmadd213ps ymm8,ymm0,ymm0 - - vbroadcastss ymm4,Exp_Log2Reciprocal[rax] - vxorps ymm8,ymm8,ymm15 - - vbroadcastss ymm13,Exp_C[rax] - - # expf(ymm8 -- ymm11) - vmaxps ymm8,ymm8,ymm14 - - vbroadcastss ymm0,Exp_log2_hi[rax] - vfmadd213ps ymm4,ymm8,ymm13 - - vbroadcastss ymm15,Exp_log2_lo[rax] - - vsubps ymm4,ymm4,ymm13 # vr = round() - - vfmadd213ps ymm0,ymm4,ymm8 # vf = vr * log2_hi + ve - - vbroadcastss ymm8,Exp_P0[rax] - - vfmadd231ps ymm0,ymm4,ymm15 # vf += vr * log_2_lo - - vbroadcastss ymm14,Exp_P1[rax] - - vbroadcastss ymm13,Exp_P2[rax] - vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p1 - - vbroadcastss ymm12,Exp_P3[rax] - vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p2 - - vbroadcastss ymm15,Exp_P4[rax] - vfmadd213ps ymm8,ymm0,ymm12 # *+ exp_p3 - - vbroadcastss ymm14,Exp_P5[rax] - vfmadd213ps ymm8,ymm0,ymm15 # 
*+ exp_p4 - - vbroadcastss ymm13,Exp_P6[rax] - vfmadd213ps ymm8,ymm0,ymm14 # *+ exp_p5 - - vbroadcastss ymm12,Exp_X7F[rax] - vfmadd213ps ymm8,ymm0,ymm13 # *+ exp_p6 - - vcvttps2dq ymm4,ymm4 - - vbroadcastss ymm15,ErfOne[rax] - vpaddd ymm4,ymm4,ymm12 # +127 - - vpslld ymm4,ymm4,23 - - vmulps ymm8,ymm8,ymm4 # 2^i * exp(vf) - - vsubps ymm8,ymm15,ymm8 - - # merge small numbers' result - vorps ymm8,ymm8,YMMWORD PTR ErfBuffer1[rsp] - - # copy sign - vorps ymm0,ymm8,YMMWORD PTR ErfBuffer0[rsp] - - vmaskmovps YMMWORD PTR [rsi],ymm3,ymm0 - - add rdi,8*4 - add rsi,8*4 - sub rdx,8 - jg .LErfProcess1x8 - -.LErfBatchExp: - vzeroupper - add rsp,ErfKernelFrame_ReturnAddress - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelAvx512FCommon.h b/onnxruntime/core/mlas/lib/x86_64/FgemmKernelAvx512FCommon.h deleted file mode 100644 index 9f243ee8c0829..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelAvx512FCommon.h +++ /dev/null @@ -1,529 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelAvx512FCommon.h - -Abstract: - - This module implements the kernels for the floating point matrix/matrix - multiply operation (SGEMM and DGEMM). - - This implementation uses AVX512F instructions. - ---*/ - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 2 ZMMWORDs by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 3 rows. - - r13 - Supplies the address into the matrix A data plus 6 rows. - - r14 - Supplies the address into the matrix A data plus 9 rows. - - rsi - Supplies the address into the matrix B data. - - r10 - Supplies the length in bytes of a row from matrix A. - - zmm4-zmm27 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockAvx512FBy2 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.ifnb \PrefetchOffset\() - prefetcht0 [rsi+\VectorOffset\()+\PrefetchOffset\()] - prefetcht0 [rsi+r12+\VectorOffset\()+\PrefetchOffset\()] -.endif -.if \RowCount\() == 1 - vbroadcastsf zmm3,[rdi+\BroadcastOffset\()] - vfmadd231pf zmm4,zmm3,ZMMWORD PTR [rsi+\VectorOffset\()] - vfmadd231pf zmm5,zmm3,ZMMWORD PTR [rsi+r12+\VectorOffset\()] -.else - vmovapf zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] - vmovapf zmm1,ZMMWORD PTR [rsi+r12+\VectorOffset\()] - EmitIfCountGE \RowCount\(), 1, "vbroadcastsf zmm3,[rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "vfmadd231pf zmm4,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 1, "vfmadd231pf zmm5,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 2, "vbroadcastsf zmm3,[rdi+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "vfmadd231pf zmm6,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 2, "vfmadd231pf zmm7,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 3, "vbroadcastsf zmm3,[rdi+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "vfmadd231pf zmm8,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 3, "vfmadd231pf zmm9,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 4, "vbroadcastsf zmm3,[rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "vfmadd231pf zmm10,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 4, "vfmadd231pf zmm11,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 5, "vbroadcastsf zmm3,[rbx+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "vfmadd231pf zmm12,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 5, "vfmadd231pf zmm13,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 6, "vbroadcastsf zmm3,[rbx+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "vfmadd231pf zmm14,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 6, "vfmadd231pf zmm15,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 12, "vbroadcastsf zmm3,[r13+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm16,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm17,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 12, "vbroadcastsf zmm3,[r13+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm18,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm19,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 12, "vbroadcastsf zmm3,[r13+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm20,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm21,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 12, "vbroadcastsf zmm3,[r14+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm22,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm23,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 12, "vbroadcastsf zmm3,[r14+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm24,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm25,zmm3,zmm1" - EmitIfCountGE \RowCount\(), 12, "vbroadcastsf zmm3,[r14+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm26,zmm3,zmm0" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf zmm27,zmm3,zmm1" -.endif - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 1 ZMMWORD by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. 
- -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 3 rows. - - r13 - Supplies the address into the matrix A data plus 6 rows. - - r14 - Supplies the address into the matrix A data plus 9 rows. - - rsi - Supplies the address into the matrix B data. - - r10 - Supplies the length in bytes of a row from matrix A. - - zmm4-zmm27 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvx512FBy1 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.ifnb \PrefetchOffset\() - prefetcht0 [rsi+\VectorOffset\()+\PrefetchOffset\()] -.endif - vmovapf zmm0,ZMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \RowCount\(), 1, "vfmadd231pf_bcst zmm5,zmm0,[rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "vfmadd231pf_bcst zmm7,zmm0,[rdi+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "vfmadd231pf_bcst zmm9,zmm0,[rdi+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "vfmadd231pf_bcst zmm11,zmm0,[rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "vfmadd231pf_bcst zmm13,zmm0,[rbx+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "vfmadd231pf_bcst zmm15,zmm0,[rbx+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf_bcst zmm17,zmm0,[r13+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf_bcst zmm19,zmm0,[r13+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf_bcst zmm21,zmm0,[r13+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf_bcst zmm23,zmm0,[r14+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf_bcst zmm25,zmm0,[r14+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 12, "vfmadd231pf_bcst zmm27,zmm0,[r14+r10*2+\BroadcastOffset\()]" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - zmm4-zmm27 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvx512FLoop ComputeBlock, RowCount - -.if \RowCount\() > 3 - lea rbx,[r10*2+r10] -.if \RowCount\() == 12 - lea r13,[rdi+rbx*2] # compute matrix A plus 6 rows - lea r14,[r13+rbx] # compute matrix A plus 9 rows -.endif - add rbx,rdi # compute matrix A plus 3 rows -.endif - ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 3 -.if \RowCount\() > 3 - lea rbx,[rax*2+rax] -.if \RowCount\() == 12 - lea r13,[rdx+rbx*2] # compute matrix C plus 6 rows - lea r14,[r13+rbx] # compute matrix C plus 9 rows -.endif - add rbx,rdx # compute matrix C plus 3 rows -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rdx - Supplies the address of matrix C. 
- - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - rax - Supplies the length in bytes of a row from matrix C. - - r15 - Stores the ZeroMode argument from the stack frame. - ---*/ - - .macro ProcessCountM RowCount - - cmp r9,.LFgemmZmmElementCount - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop2xN\@: - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm16,zmm4" - # clear upper block accumulators - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm17,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm18,zmm4" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm19,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm20,zmm4" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm21,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm22,zmm4" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm23,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm24,zmm4" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm25,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm26,zmm4" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm27,zmm5" - ComputeBlockAvx512FLoop ComputeBlockAvx512FBy2, \RowCount\() - add rsi,r12 # advance matrix B by 64*CountK bytes - test r15b,r15b # ZeroMode? - jnz .LMultiplyAlpha2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf zmm4,zmm31,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf zmm6,zmm31,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf zmm8,zmm31,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf zmm10,zmm31,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf zmm12,zmm31,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf zmm14,zmm31,ZMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm16,zmm31,ZMMWORD PTR [r13]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm18,zmm31,ZMMWORD PTR [r13+rax]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm20,zmm31,ZMMWORD PTR [r13+rax*2]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm22,zmm31,ZMMWORD PTR [r14]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm24,zmm31,ZMMWORD PTR [r14+rax]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm26,zmm31,ZMMWORD PTR [r14+rax*2]" - jmp .LStore2xNBlock\@ - -.LMultiplyAlpha2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmulpf zmm4,zmm4,zmm31" - EmitIfCountGE \RowCount\(), 2, "vmulpf zmm6,zmm6,zmm31" - EmitIfCountGE \RowCount\(), 3, "vmulpf zmm8,zmm8,zmm31" - EmitIfCountGE \RowCount\(), 4, "vmulpf zmm10,zmm10,zmm31" - EmitIfCountGE \RowCount\(), 5, "vmulpf zmm12,zmm12,zmm31" - EmitIfCountGE \RowCount\(), 6, "vmulpf zmm14,zmm14,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm16,zmm16,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm18,zmm18,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm20,zmm20,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm22,zmm22,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm24,zmm24,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm26,zmm26,zmm31" - -.LStore2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf ZMMWORD PTR [rdx],zmm4" - EmitIfCountGE \RowCount\(), 2, "vmovupf ZMMWORD PTR [rdx+rax],zmm6" - EmitIfCountGE \RowCount\(), 3, "vmovupf ZMMWORD PTR [rdx+rax*2],zmm8" - EmitIfCountGE \RowCount\(), 4, "vmovupf ZMMWORD PTR [rbx],zmm10" - EmitIfCountGE \RowCount\(), 5, "vmovupf ZMMWORD PTR [rbx+rax],zmm12" - EmitIfCountGE \RowCount\(), 6, "vmovupf ZMMWORD PTR [rbx+rax*2],zmm14" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR 
[r13],zmm16" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r13+rax],zmm18" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r13+rax*2],zmm20" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r14],zmm22" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r14+rax],zmm24" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r14+rax*2],zmm26" - add rdx,64 # advance matrix C by ZMMWORD -.if \RowCount\() > 3 - add rbx,64 # advance matrix C plus 3 rows by ZMMWORD -.if \RowCount\() == 12 - add r13,64 # advance matrix C plus 6 rows by ZMMWORD - add r14,64 # advance matrix C plus 9 rows by ZMMWORD -.endif -.endif - sub r9,.LFgemmZmmElementCount - -.LOutput1xNBlock\@: - sub r9,.LFgemmZmmElementCount - jae .LOutput1xNBlockWithMask\@ - lea rcx,[r9+.LFgemmZmmElementCount] - # correct for over-subtract above - mov ebp,1 - shl ebp,cl - dec ebp - kmovw k1,ebp # update mask for remaining columns - xor r9,r9 # no more columns remaining - -.LOutput1xNBlockWithMask\@: - test r15b,r15b # ZeroMode? - jnz .LMultiplyAlpha1xNBlockWithMask\@ - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf zmm5{k1},zmm31,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf zmm7{k1},zmm31,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf zmm9{k1},zmm31,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf zmm11{k1},zmm31,ZMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf zmm13{k1},zmm31,ZMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf zmm15{k1},zmm31,ZMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm17{k1},zmm31,ZMMWORD PTR [r13]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm19{k1},zmm31,ZMMWORD PTR [r13+rax]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm21{k1},zmm31,ZMMWORD PTR [r13+rax*2]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm23{k1},zmm31,ZMMWORD PTR [r14]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm25{k1},zmm31,ZMMWORD PTR [r14+rax]" - EmitIfCountGE \RowCount\(), 12, "vfmadd213pf zmm27{k1},zmm31,ZMMWORD PTR [r14+rax*2]" - jmp .LStore1xNBlockWithMask\@ - -.LMultiplyAlpha1xNBlockWithMask\@: - EmitIfCountGE \RowCount\(), 1, "vmulpf zmm5,zmm5,zmm31" - EmitIfCountGE \RowCount\(), 2, "vmulpf zmm7,zmm7,zmm31" - EmitIfCountGE \RowCount\(), 3, "vmulpf zmm9,zmm9,zmm31" - EmitIfCountGE \RowCount\(), 4, "vmulpf zmm11,zmm11,zmm31" - EmitIfCountGE \RowCount\(), 5, "vmulpf zmm13,zmm13,zmm31" - EmitIfCountGE \RowCount\(), 6, "vmulpf zmm15,zmm15,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm17,zmm17,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm19,zmm19,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm21,zmm21,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm23,zmm23,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm25,zmm25,zmm31" - EmitIfCountGE \RowCount\(), 12, "vmulpf zmm27,zmm27,zmm31" - -.LStore1xNBlockWithMask\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf ZMMWORD PTR [rdx]{k1},zmm5" - EmitIfCountGE \RowCount\(), 2, "vmovupf ZMMWORD PTR [rdx+rax]{k1},zmm7" - EmitIfCountGE \RowCount\(), 3, "vmovupf ZMMWORD PTR [rdx+rax*2]{k1},zmm9" - EmitIfCountGE \RowCount\(), 4, "vmovupf ZMMWORD PTR [rbx]{k1},zmm11" - EmitIfCountGE \RowCount\(), 5, "vmovupf ZMMWORD PTR [rbx+rax]{k1},zmm13" - EmitIfCountGE \RowCount\(), 6, "vmovupf ZMMWORD PTR [rbx+rax*2]{k1},zmm15" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r13]{k1},zmm17" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r13+rax]{k1},zmm19" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR 
[r13+rax*2]{k1},zmm21" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r14]{k1},zmm23" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r14+rax]{k1},zmm25" - EmitIfCountGE \RowCount\(), 12, "vmovupf ZMMWORD PTR [r14+rax*2]{k1},zmm27" - add rdx,64 # advance matrix C by ZMMWORD - mov rdi,r11 # reload matrix A - vzeroall - cmp r9,.LFgemmZmmElementCount - ja .LProcessNextColumnLoop2xN\@ - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN\@: - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm17,zmm5" - # clear upper block accumulators - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm19,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm21,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm23,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm25,zmm5" - EmitIfCountGE \RowCount\(), 12, "vmovapf zmm27,zmm5" - ComputeBlockAvx512FLoop ComputeBlockAvx512FBy1, \RowCount\() - jmp .LOutput1xNBlock\@ - - .endm - -/*++ - -Macro Description: - - This macro generates the inner kernel to compute matrix multiplication. - -Arguments: - - FunctionName - Supplies the name for the generated function. - ---*/ - - .macro FgemmKernelAvx512FFunction FunctionName - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - Alpha (xmm0) - Supplies the scalar alpha multiplier (see GEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY \FunctionName\() - - push rbp - push rbx - push r15 - mov .LFgemmKernelFrame_SavedR12[rsp],r12 - mov .LFgemmKernelFrame_SavedR13[rsp],r13 - mov .LFgemmKernelFrame_SavedR14[rsp],r14 - mov r11,rdi - mov r10,.LFgemmKernelFrame_lda[rsp] - shl r10,.LFgemmElementShift # convert lda to bytes - mov rax,.LFgemmKernelFrame_ldc[rsp] - shl rax,.LFgemmElementShift # convert ldc to bytes - mov r12,rcx - shl r12,6 # compute 64*CountK bytes - mov ebp,-1 - kmovw k1,ebp # update mask to write all columns - movzx r15,BYTE PTR .LFgemmKernelFrame_ZeroMode[rsp] - vbroadcastsf zmm31,xmm0 - vzeroall - -// -// Process CountM rows of the matrices. -// - - cmp r8,12 - jb .LProcessCountMLessThan12 - mov r8d,12 # return 12 rows handled - ProcessCountM 12 - -.LProcessCountMLessThan12: - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - mov r8d,6 # return 6 rows handled - ProcessCountM 6 - -// -// Restore non-volatile registers and return. 
-// - -.LExitKernel: - mov eax,r8d - mov r12,.LFgemmKernelFrame_SavedR12[rsp] - mov r13,.LFgemmKernelFrame_SavedR13[rsp] - mov r14,.LFgemmKernelFrame_SavedR14[rsp] - pop r15 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelAvxCommon.h b/onnxruntime/core/mlas/lib/x86_64/FgemmKernelAvxCommon.h deleted file mode 100644 index 69c8d17e2797b..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelAvxCommon.h +++ /dev/null @@ -1,451 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelAvxCommon.h - -Abstract: - - This module implements the kernels for the floating point matrix/matrix - multiply operation (SGEMM and DGEMM). - - This implementation uses AVX instructions. - ---*/ - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 2 rows. - - rsi - Supplies the address into the matrix B data. - - r10 - Supplies the length in bytes of a row from matrix A. - - ymm8-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvxBy16 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.if \RowCount\() == 1 - vbroadcastsf ymm3,[rdi+\BroadcastOffset\()] - vmulpf ymm4,ymm3,YMMWORD PTR [rsi+\VectorOffset\()] - vaddpf ymm8,ymm8,ymm4 - vmulpf ymm5,ymm3,YMMWORD PTR [rsi+\VectorOffset\()+32] - vaddpf ymm9,ymm9,ymm5 -.else - vmovapf ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - vmovapf ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32] - EmitIfCountGE \RowCount\(), 1, "vbroadcastsf ymm3,[rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm4,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm8,ymm8,ymm4" - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm5,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm9,ymm9,ymm5" - EmitIfCountGE \RowCount\(), 2, "vbroadcastsf ymm3,[rdi+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm6,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm10,ymm10,ymm6" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm7,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm11,ymm11,ymm7" - EmitIfCountGE \RowCount\(), 3, "vbroadcastsf ymm3,[rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm4,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm12,ymm12,ymm4" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm5,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm13,ymm13,ymm5" - EmitIfCountGE \RowCount\(), 4, "vbroadcastsf ymm3,[rbx+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm6,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm14,ymm14,ymm6" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm7,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm15,ymm15,ymm7" -.endif - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 1 YMMWORD by N rows of the output - matrix. 
- -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 2 rows. - - rsi - Supplies the address into the matrix B data. - - r10 - Supplies the length in bytes of a row from matrix A. - - ymm8-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvxBy8 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.if \RowCount\() == 1 - vbroadcastsf ymm3,[rdi+\BroadcastOffset\()] - vmulpf ymm5,ymm3,YMMWORD PTR [rsi+\VectorOffset\()] - vaddpf ymm9,ymm9,ymm5 -.else - vmovapf ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \RowCount\(), 1, "vbroadcastsf ymm3,[rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm5,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm9,ymm9,ymm5" - EmitIfCountGE \RowCount\(), 2, "vbroadcastsf ymm3,[rdi+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm7,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm11,ymm11,ymm7" - EmitIfCountGE \RowCount\(), 3, "vbroadcastsf ymm3,[rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm5,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm13,ymm13,ymm5" - EmitIfCountGE \RowCount\(), 4, "vbroadcastsf ymm3,[rbx+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm7,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm15,ymm15,ymm7" -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvxLoop ComputeBlock, RowCount - -.if \RowCount\() > 2 - lea rbx,[rdi+r10*2] # compute matrix A plus 2 rows -.endif - ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 2 -.if \RowCount\() > 2 - lea rbx,[rdx+rax*2] # compute matrix C plus 2 rows -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rdx - Supplies the address of matrix C. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - rax - Supplies the length in bytes of a row from matrix C. 
- - r15 - Stores the ZeroMode argument from the stack frame. - ---*/ - - .macro ProcessCountM RowCount, Fallthrough - - cmp r9,.LFgemmYmmElementCount - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop2xN\@: - EmitIfCountGE \RowCount\(), 1, "vxorpf xmm8,xmm8,xmm8" - EmitIfCountGE \RowCount\(), 1, "vxorpf xmm9,xmm9,xmm9" - EmitIfCountGE \RowCount\(), 2, "vxorpf xmm10,xmm10,xmm10" - EmitIfCountGE \RowCount\(), 2, "vxorpf xmm11,xmm11,xmm11" - EmitIfCountGE \RowCount\(), 3, "vxorpf xmm12,xmm12,xmm12" - EmitIfCountGE \RowCount\(), 3, "vxorpf xmm13,xmm13,xmm13" - EmitIfCountGE \RowCount\(), 4, "vxorpf xmm14,xmm14,xmm14" - EmitIfCountGE \RowCount\(), 4, "vxorpf xmm15,xmm15,xmm15" - ComputeBlockAvxLoop ComputeBlockAvxBy16, \RowCount\() - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm8,ymm8,ymm2" - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm9,ymm9,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm10,ymm10,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm11,ymm11,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm12,ymm12,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm13,ymm13,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm14,ymm14,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm15,ymm15,ymm2" - sub r9,2*.LFgemmYmmElementCount - jb .LOutputMasked2xNBlock\@ - test r15b,r15b # ZeroMode? - jnz .LStore2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm8,ymm8,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm9,ymm9,YMMWORD PTR [rdx+32]" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm10,ymm10,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm11,ymm11,YMMWORD PTR [rdx+rax+32]" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm12,ymm12,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm13,ymm13,YMMWORD PTR [rbx+32]" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm14,ymm14,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm15,ymm15,YMMWORD PTR [rbx+rax+32]" - -.LStore2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx],ymm8" - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx+32],ymm9" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax],ymm10" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax+32],ymm11" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rbx],ymm12" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rbx+32],ymm13" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx+rax],ymm14" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx+rax+32],ymm15" - add rdx,2*32 # advance matrix C by 2 YMMWORDs - mov rdi,r11 # reload matrix A - cmp r9,.LFgemmYmmElementCount - ja .LProcessNextColumnLoop2xN\@ - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN\@: - EmitIfCountGE \RowCount\(), 1, "vxorpf xmm9,xmm9,xmm9" - EmitIfCountGE \RowCount\(), 2, "vxorpf xmm11,xmm11,xmm11" - EmitIfCountGE \RowCount\(), 3, "vxorpf xmm13,xmm13,xmm13" - EmitIfCountGE \RowCount\(), 4, "vxorpf xmm15,xmm15,xmm15" - ComputeBlockAvxLoop ComputeBlockAvxBy8, \RowCount\() - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm9,ymm9,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm11,ymm11,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm13,ymm13,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm15,ymm15,ymm2" - cmp r9,.LFgemmYmmElementCount - jb .LOutputMasked1xNBlock\@ - test r15b,r15b # ZeroMode? 
- jnz .LStore1xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm9,ymm9,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm11,ymm11,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm13,ymm13,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm15,ymm15,YMMWORD PTR [rbx+rax]" - -.LStore1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx],ymm9" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax],ymm11" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rbx],ymm13" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx+rax],ymm15" - jmp .LExitKernel - -.LOutputMasked2xNBlock\@: - test r15b,r15b # ZeroMode? - jnz .LStoreMasked2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm8,ymm8,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm10,ymm10,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm12,ymm12,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm14,ymm14,YMMWORD PTR [rbx+rax]" - -.LStoreMasked2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx],ymm8" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax],ymm10" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rbx],ymm12" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx+rax],ymm14" - add rdx,32 # advance matrix C by YMMWORD -.if \RowCount\() > 2 - add rbx,32 # advance matrix C plus 2 rows by YMMWORD -.endif - add r9,.LFgemmYmmElementCount # correct for over-subtract above - -.LOutputMasked1xNBlock\@: - neg r9 - lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovdqu ymm0,YMMWORD PTR [rdi+r9*.LFgemmElementSize] - test r15b,r15b # ZeroMode? - jnz .LStoreMasked1xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vmaskmovpf ymm8,ymm0,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vmaskmovpf ymm10,ymm0,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vmaskmovpf ymm12,ymm0,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "vmaskmovpf ymm14,ymm0,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 1, "vaddpf ymm9,ymm9,ymm8" - EmitIfCountGE \RowCount\(), 2, "vaddpf ymm11,ymm11,ymm10" - EmitIfCountGE \RowCount\(), 3, "vaddpf ymm13,ymm13,ymm12" - EmitIfCountGE \RowCount\(), 4, "vaddpf ymm15,ymm15,ymm14" - -.LStoreMasked1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmaskmovpf YMMWORD PTR [rdx],ymm0,ymm9" - EmitIfCountGE \RowCount\(), 2, "vmaskmovpf YMMWORD PTR [rdx+rax],ymm0,ymm11" - EmitIfCountGE \RowCount\(), 3, "vmaskmovpf YMMWORD PTR [rbx],ymm0,ymm13" - EmitIfCountGE \RowCount\(), 4, "vmaskmovpf YMMWORD PTR [rbx+rax],ymm0,ymm15" -.ifb \Fallthrough\() - jmp .LExitKernel -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates the inner kernel to compute matrix multiplication. - -Arguments: - - FunctionName - Supplies the name for the generated function. - ---*/ - - .macro FgemmKernelAvxFunction FunctionName - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. 
The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - Alpha (xmm0) - Supplies the scalar alpha multiplier (see GEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - - FUNCTION_ENTRY \FunctionName\() - - push rbp - push rbx - push r15 - mov r11,rdi - mov r10,.LFgemmKernelFrame_lda[rsp] - shl r10,.LFgemmElementShift # convert lda to bytes - mov rax,.LFgemmKernelFrame_ldc[rsp] - shl rax,.LFgemmElementShift # convert ldc to bytes - movzx r15,BYTE PTR .LFgemmKernelFrame_ZeroMode[rsp] - vmovsf .LFgemmKernelFrame_alpha[rsp],xmm0 - vbroadcastsf ymm2,.LFgemmKernelFrame_alpha[rsp] - -// -// Process 4 rows of the matrices. -// - - cmp r8,4 - jb .LProcessCountMLessThan4 - mov r8d,4 # return 4 rows handled - ProcessCountM 4, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - vzeroupper - mov eax,r8d - pop r15 - pop rbx - pop rbp - ret - -// -// Process 2 rows of the matrices. -// - -.LProcessCountMLessThan4: - cmp r8,2 - jb .LProcessCountMLessThan2 - mov r8d,2 # return 2 rows handled - ProcessCountM 2 - -// -// Process 1 row of the matrices. -// - -.LProcessCountMLessThan2: - ProcessCountM 1 - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/FgemmKernelCommon.h deleted file mode 100644 index f3e8890fd3aeb..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelCommon.h +++ /dev/null @@ -1,124 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the floating - point matrix/matrix multiply operation (SGEMM and DGEMM). - ---*/ - -// -// Stack frame layout for the floating point kernels. -// - - .equ .LFgemmKernelFrame_SavedR12, -32 - .equ .LFgemmKernelFrame_SavedR13, -24 - .equ .LFgemmKernelFrame_SavedR14, -16 - .equ .LFgemmKernelFrame_alpha, -8 - .equ .LFgemmKernelFrame_SavedR15, 0 - .equ .LFgemmKernelFrame_SavedRbx, 8 - .equ .LFgemmKernelFrame_SavedRbp, 16 - .equ .LFgemmKernelFrame_ReturnAddress, 24 - .equ .LFgemmKernelFrame_lda, 32 - .equ .LFgemmKernelFrame_ldc, 40 - .equ .LFgemmKernelFrame_ZeroMode, 48 - -// -// Define the number of elements per vector register. -// - - .equ .LFgemmXmmElementCount, 16 / .LFgemmElementSize - .equ .LFgemmYmmElementCount, 32 / .LFgemmElementSize - .equ .LFgemmZmmElementCount, 64 / .LFgemmElementSize - -// -// Define the typed instruction template. -// - -#define FGEMM_TYPED_INSTRUCTION(Untyped, Typed) \ - .macro Untyped Operand:vararg; Typed \Operand\(); .endm; - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - - AdvanceMatrixAPlusRows - Supplies a non-zero value if the data pointer - in rbx should also be advanced as part of the loop. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. 
- - rbx - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoop ComputeBlock, RowCount, AdvanceMatrixAPlusRows - - mov rbp,rcx # reload CountK - sub rbp,4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - \ComputeBlock\() \RowCount\(), 0, .LFgemmElementSize*0, 64*4 - \ComputeBlock\() \RowCount\(), 2*32, .LFgemmElementSize*1, 64*4 - add_immed rsi,2*2*32 # advance matrix B by 128 bytes - \ComputeBlock\() \RowCount\(), 0, .LFgemmElementSize*2, 64*4 - \ComputeBlock\() \RowCount\(), 2*32, .LFgemmElementSize*3, 64*4 - add_immed rsi,2*2*32 # advance matrix B by 128 bytes - add rdi,4*.LFgemmElementSize # advance matrix A by 4 elements -.if \RowCount\() > 3 - add rbx,4*.LFgemmElementSize # advance matrix A plus rows by 4 elements -.if \RowCount\() == 12 - add r13,4*.LFgemmElementSize - add r14,4*.LFgemmElementSize -.endif -.endif - sub rbp,4 - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add rbp,4 # correct for over-subtract above - jz .LOutputBlock\@ - -.LComputeBlockBy1Loop\@: - \ComputeBlock\() \RowCount\(), 0, 0 - add rsi,2*32 # advance matrix B by 64 bytes - add rdi,.LFgemmElementSize # advance matrix A by 1 element -.if \RowCount\() > 3 - add rbx,.LFgemmElementSize # advance matrix A plus rows by 1 element -.if \RowCount\() == 12 - add r13,.LFgemmElementSize - add r14,.LFgemmElementSize -.endif -.endif - dec rbp - jne .LComputeBlockBy1Loop\@ - -.LOutputBlock\@: - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelFma3Common.h b/onnxruntime/core/mlas/lib/x86_64/FgemmKernelFma3Common.h deleted file mode 100644 index 77fc1f1d9fee9..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelFma3Common.h +++ /dev/null @@ -1,512 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelFma3Common.h - -Abstract: - - This module implements the kernels for the floating point matrix/matrix - multiply operation (SGEMM and DGEMM). - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 2 YMMWORDs by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - r10 - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. 
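The ComputeBlockLoop macro above drives the per-K block computation: it runs the supplied block macro four times per iteration (advancing A by four elements and B by four column blocks), then finishes any K remainder one element at a time. A structural sketch of that loop shape, with a hypothetical computeBlock callback standing in for the macro expansion:

#include <cstddef>
#include <functional>

// Structural sketch of the unroll-by-4 K loop in ComputeBlockLoop; computeBlock(k)
// stands in for one expansion of the ComputeBlock macro (one accumulator update).
void ComputeBlockLoopRef(size_t countK, const std::function<void(size_t)>& computeBlock) {
    size_t k = 0;
    for (; k + 4 <= countK; k += 4) {   // .LComputeBlockBy4Loop: four updates per pass
        computeBlock(k + 0);
        computeBlock(k + 1);
        computeBlock(k + 2);
        computeBlock(k + 3);
    }
    for (; k < countK; ++k) {           // .LComputeBlockBy1Loop: K remainder
        computeBlock(k);
    }
}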
- ---*/ - - .macro ComputeBlockFma3By2 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.ifnb \PrefetchOffset\() - prefetcht0 [rsi+\VectorOffset\()+\PrefetchOffset\()] -.endif -.if \RowCount\() == 1 - vbroadcastsf ymm3,[rdi+\BroadcastOffset\()] - vfmadd231pf ymm4,ymm3,YMMWORD PTR [rsi+\VectorOffset\()] - vfmadd231pf ymm5,ymm3,YMMWORD PTR [rsi+\VectorOffset\()+32] -.else - vmovapf ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - vmovapf ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32] - EmitIfCountGE \RowCount\(), 1, "vbroadcastsf ymm3,[rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "vfmadd231pf ymm4,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 1, "vfmadd231pf ymm5,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 2, "vbroadcastsf ymm3,[rdi+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "vfmadd231pf ymm6,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 2, "vfmadd231pf ymm7,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 3, "vbroadcastsf ymm3,[rdi+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "vfmadd231pf ymm8,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 3, "vfmadd231pf ymm9,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 4, "vbroadcastsf ymm3,[rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "vfmadd231pf ymm10,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 4, "vfmadd231pf ymm11,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 5, "vbroadcastsf ymm3,[rbx+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "vfmadd231pf ymm12,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 5, "vfmadd231pf ymm13,ymm3,ymm1" - EmitIfCountGE \RowCount\(), 6, "vbroadcastsf ymm3,[rbx+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "vfmadd231pf ymm14,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 6, "vfmadd231pf ymm15,ymm3,ymm1" -.endif - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for 1 YMMWORD by N rows of the output - matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - - PrefetchOffset - Optionally supplies the byte offset from matrix B to - prefetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rbx - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - r10 - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. 
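The ComputeBlockFma3By2 macro above performs one update per value of k: each active row broadcasts a single element of A and multiplies it against the same two 8-float vectors of packed B with fused multiply-add. An intrinsics sketch of the two-row case (illustrative name, single precision only; the macro scales this pattern up to six rows):

#include <immintrin.h>

// One k step of the broadcast/FMA block compute for two rows and 2 YMMWORDs of B.
// Requires -mavx2 -mfma (or equivalent) to compile.
void Fma3BlockBy2Ref(const float* aRow0, const float* aRow1, const float* bPacked,
                     __m256 acc[2][2]) {
    const __m256 b0 = _mm256_loadu_ps(bPacked);      // packed B, columns 0..7
    const __m256 b1 = _mm256_loadu_ps(bPacked + 8);  // packed B, columns 8..15
    const __m256 a0 = _mm256_broadcast_ss(aRow0);    // A[row0][k] broadcast to all lanes
    acc[0][0] = _mm256_fmadd_ps(a0, b0, acc[0][0]);
    acc[0][1] = _mm256_fmadd_ps(a0, b1, acc[0][1]);
    const __m256 a1 = _mm256_broadcast_ss(aRow1);    // A[row1][k]
    acc[1][0] = _mm256_fmadd_ps(a1, b0, acc[1][0]);
    acc[1][1] = _mm256_fmadd_ps(a1, b1, acc[1][1]);
}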
- ---*/ - - .macro ComputeBlockFma3By1 RowCount, VectorOffset, BroadcastOffset, PrefetchOffset - -.ifnb \PrefetchOffset\() - prefetcht0 [rsi+\VectorOffset\()+\PrefetchOffset\()] -.endif -.if \RowCount\() == 1 - vbroadcastsf ymm3,[rdi+\BroadcastOffset\()] - vfmadd231pf ymm5,ymm3,YMMWORD PTR [rsi+\VectorOffset\()] -.else - vmovapf ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \RowCount\(), 1, "vbroadcastsf ymm3,[rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "vfmadd231pf ymm5,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 2, "vbroadcastsf ymm3,[rdi+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "vfmadd231pf ymm7,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 3, "vbroadcastsf ymm3,[rdi+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "vfmadd231pf ymm9,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 4, "vbroadcastsf ymm3,[rbx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "vfmadd231pf ymm11,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 5, "vbroadcastsf ymm3,[rbx+r10+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "vfmadd231pf ymm13,ymm3,ymm0" - EmitIfCountGE \RowCount\(), 6, "vbroadcastsf ymm3,[rbx+r10*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "vfmadd231pf ymm15,ymm3,ymm0" -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple - times and advancing the matrix A and matrix B data pointers. - -Arguments: - - ComputeBlock - Supplies the macro to compute a single block. - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockFma3Loop ComputeBlock, RowCount - -.if \RowCount\() > 3 - lea rbx,[r10*2+r10] - add rbx,rdi # compute matrix A plus 3 rows -.endif - ComputeBlockLoop \ComputeBlock\(), \RowCount\(), \RowCount\() > 3 - vbroadcastsf ymm2,[rsp+.LFgemmKernelFrame_alpha] -.if \RowCount\() > 3 - lea rbx,[rax*2+rax] - add rbx,rdx # compute matrix C plus 3 rows -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernelAndZeroUpper label. - -Implicit Arguments: - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rdx - Supplies the address of matrix C. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - rax - Supplies the length in bytes of a row from matrix C. - - r15 - Stores the ZeroMode argument from the stack frame. 
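The .LOutputMasked1xNBlock paths in these ProcessCountM macros handle a final CountN remainder of fewer than eight elements by loading a lane mask from MlasMaskMoveTableAvx and using vmaskmov for the partial loads and stores. Assuming the table is laid out as eight all-ones dwords followed by eight zero dwords (an assumption about data defined elsewhere in MLAS, not shown in this diff), the negate-and-index sequence is equivalent to:

#include <cstddef>
#include <cstdint>

// Sketch of the mask-table trick: pointing at the middle of a {-1 x 8, 0 x 8} table
// and backing up by `remaining` dwords yields 8 dwords whose first `remaining`
// lanes are all-ones, i.e. a vmaskmov mask enabling only the leftover columns.
static const uint32_t MaskMoveTableRef[16] = {
    0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu,
    0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu,
    0, 0, 0, 0, 0, 0, 0, 0,
};

inline const uint32_t* TailMask(size_t remaining) {    // remaining in [0, 8]
    return &MaskMoveTableRef[8] - remaining;           // mirrors "lea table+8*4; neg count"
}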
- ---*/ - - .macro ProcessCountM RowCount, Fallthrough - - cmp r9,.LFgemmYmmElementCount - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop2xN\@: - ComputeBlockFma3Loop ComputeBlockFma3By2, \RowCount\() - EmitIfCountGE \RowCount\(), 1, "prefetcht0 [rdx+64]" - EmitIfCountGE \RowCount\(), 2, "prefetcht0 [rdx+rax+64]" - EmitIfCountGE \RowCount\(), 3, "prefetcht0 [rdx+rax*2+64]" - EmitIfCountGE \RowCount\(), 4, "prefetcht0 [rbx+64]" - EmitIfCountGE \RowCount\(), 5, "prefetcht0 [rbx+rax+64]" - EmitIfCountGE \RowCount\(), 6, "prefetcht0 [rbx+rax*2+64]" - sub r9,2*.LFgemmYmmElementCount - jb .LOutputMasked2xNBlock\@ - test r15b,r15b # ZeroMode? - jnz .LMultiplyAlpha2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf ymm4,ymm2,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf ymm5,ymm2,YMMWORD PTR [rdx+32]" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf ymm6,ymm2,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf ymm7,ymm2,YMMWORD PTR [rdx+rax+32]" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf ymm8,ymm2,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf ymm9,ymm2,YMMWORD PTR [rdx+rax*2+32]" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf ymm10,ymm2,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf ymm11,ymm2,YMMWORD PTR [rbx+32]" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf ymm12,ymm2,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf ymm13,ymm2,YMMWORD PTR [rbx+rax+32]" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf ymm14,ymm2,YMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf ymm15,ymm2,YMMWORD PTR [rbx+rax*2+32]" - jmp .LStore2xNBlock\@ - -.LMultiplyAlpha2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm4,ymm4,ymm2" - # multiply by alpha - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm5,ymm5,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm6,ymm6,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm7,ymm7,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm8,ymm8,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm9,ymm9,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm10,ymm10,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm11,ymm11,ymm2" - EmitIfCountGE \RowCount\(), 5, "vmulpf ymm12,ymm12,ymm2" - EmitIfCountGE \RowCount\(), 5, "vmulpf ymm13,ymm13,ymm2" - EmitIfCountGE \RowCount\(), 6, "vmulpf ymm14,ymm14,ymm2" - EmitIfCountGE \RowCount\(), 6, "vmulpf ymm15,ymm15,ymm2" - -.LStore2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx+32],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax+32],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rdx+rax*2+32],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx],ymm10" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx+32],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovupf YMMWORD PTR [rbx+rax],ymm12" - EmitIfCountGE \RowCount\(), 5, "vmovupf YMMWORD PTR [rbx+rax+32],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovupf YMMWORD PTR [rbx+rax*2],ymm14" - EmitIfCountGE \RowCount\(), 6, "vmovupf YMMWORD PTR [rbx+rax*2+32],ymm15" - add rdx,2*32 # advance matrix C by 2 YMMWORDs - mov rdi,r11 # reload matrix A - vzeroall - cmp r9,.LFgemmYmmElementCount - ja .LProcessNextColumnLoop2xN\@ - test r9,r9 - jz .LExitKernel - -.LProcessRemainingCountN\@: - ComputeBlockFma3Loop 
ComputeBlockFma3By1, \RowCount\() - cmp r9,.LFgemmYmmElementCount - jb .LOutputMasked1xNBlock\@ - test r15b,r15b # ZeroMode? - jnz .LMultiplyAlpha1xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf ymm5,ymm2,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf ymm7,ymm2,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf ymm9,ymm2,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf ymm11,ymm2,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf ymm13,ymm2,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf ymm15,ymm2,YMMWORD PTR [rbx+rax*2]" - jmp .LStore1xNBlock\@ - -.LMultiplyAlpha1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm5,ymm5,ymm2" - # multiply by alpha - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm7,ymm7,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm9,ymm9,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm11,ymm11,ymm2" - EmitIfCountGE \RowCount\(), 5, "vmulpf ymm13,ymm13,ymm2" - EmitIfCountGE \RowCount\(), 6, "vmulpf ymm15,ymm15,ymm2" - -.LStore1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rdx+rax*2],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovupf YMMWORD PTR [rbx+rax],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovupf YMMWORD PTR [rbx+rax*2],ymm15" - jmp .LExitKernelAndZeroUpper - -.LOutputMasked2xNBlock\@: - test r15b,r15b # ZeroMode? - jnz .LMultiplyAlphaMasked2xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf ymm4,ymm2,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf ymm6,ymm2,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf ymm8,ymm2,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf ymm10,ymm2,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf ymm12,ymm2,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf ymm14,ymm2,YMMWORD PTR [rbx+rax*2]" - jmp .LStoreMasked2xNBlock\@ - -.LMultiplyAlphaMasked2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm4,ymm4,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm6,ymm6,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm8,ymm8,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm10,ymm10,ymm2" - EmitIfCountGE \RowCount\(), 5, "vmulpf ymm12,ymm12,ymm2" - EmitIfCountGE \RowCount\(), 6, "vmulpf ymm14,ymm14,ymm2" - -.LStoreMasked2xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovupf YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 2, "vmovupf YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 3, "vmovupf YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 4, "vmovupf YMMWORD PTR [rbx],ymm10" - EmitIfCountGE \RowCount\(), 5, "vmovupf YMMWORD PTR [rbx+rax],ymm12" - EmitIfCountGE \RowCount\(), 6, "vmovupf YMMWORD PTR [rbx+rax*2],ymm14" - add rdx,32 # advance matrix C by YMMWORD -.if \RowCount\() > 3 - add rbx,32 # advance matrix C plus 3 rows by YMMWORD -.endif - add r9,.LFgemmYmmElementCount # correct for over-subtract above - -.LOutputMasked1xNBlock\@: - neg r9 - lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovdqu ymm0,YMMWORD PTR [rdi+r9*.LFgemmElementSize] - test r15b,r15b # ZeroMode? 
- jnz .LMultiplyAlphaMasked1xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vmaskmovpf ymm4,ymm0,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vmaskmovpf ymm6,ymm0,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vmaskmovpf ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vmaskmovpf ymm10,ymm0,YMMWORD PTR [rbx]" - EmitIfCountGE \RowCount\(), 5, "vmaskmovpf ymm12,ymm0,YMMWORD PTR [rbx+rax]" - EmitIfCountGE \RowCount\(), 6, "vmaskmovpf ymm14,ymm0,YMMWORD PTR [rbx+rax*2]" - EmitIfCountGE \RowCount\(), 1, "vfmadd213pf ymm5,ymm2,ymm4" - EmitIfCountGE \RowCount\(), 2, "vfmadd213pf ymm7,ymm2,ymm6" - EmitIfCountGE \RowCount\(), 3, "vfmadd213pf ymm9,ymm2,ymm8" - EmitIfCountGE \RowCount\(), 4, "vfmadd213pf ymm11,ymm2,ymm10" - EmitIfCountGE \RowCount\(), 5, "vfmadd213pf ymm13,ymm2,ymm12" - EmitIfCountGE \RowCount\(), 6, "vfmadd213pf ymm15,ymm2,ymm14" - jmp .LStoreMasked1xNBlock\@ - -.LMultiplyAlphaMasked1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmulpf ymm5,ymm5,ymm2" - EmitIfCountGE \RowCount\(), 2, "vmulpf ymm7,ymm7,ymm2" - EmitIfCountGE \RowCount\(), 3, "vmulpf ymm9,ymm9,ymm2" - EmitIfCountGE \RowCount\(), 4, "vmulpf ymm11,ymm11,ymm2" - EmitIfCountGE \RowCount\(), 5, "vmulpf ymm13,ymm13,ymm2" - EmitIfCountGE \RowCount\(), 6, "vmulpf ymm15,ymm15,ymm2" - -.LStoreMasked1xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmaskmovpf YMMWORD PTR [rdx],ymm0,ymm5" - EmitIfCountGE \RowCount\(), 2, "vmaskmovpf YMMWORD PTR [rdx+rax],ymm0,ymm7" - EmitIfCountGE \RowCount\(), 3, "vmaskmovpf YMMWORD PTR [rdx+rax*2],ymm0,ymm9" - EmitIfCountGE \RowCount\(), 4, "vmaskmovpf YMMWORD PTR [rbx],ymm0,ymm11" - EmitIfCountGE \RowCount\(), 5, "vmaskmovpf YMMWORD PTR [rbx+rax],ymm0,ymm13" - EmitIfCountGE \RowCount\(), 6, "vmaskmovpf YMMWORD PTR [rbx+rax*2],ymm0,ymm15" -.ifb \Fallthrough\() - jmp .LExitKernelAndZeroUpper -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates the inner kernel to compute matrix multiplication. - -Arguments: - - FunctionName - Supplies the name for the generated function. - ---*/ - - .macro FgemmKernelFma3Function FunctionName - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - Alpha (xmm0) - Supplies the scalar alpha multiplier (see GEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. 
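The routine description above states the contract shared by all of these kernels: CountM is only an upper bound, and the return value reports how many rows were actually consumed (up to six for the FMA3 variant, four for AVX, two for SSE2). A sketch of the calling convention from the driver's side, with a hypothetical function-pointer signature (the real driver and signature live in the MLAS C++ sources):

#include <cstddef>

// Hypothetical kernel signature matching the documented contract.
using FgemmKernelFn = size_t (*)(const float* A, const float* B, float* C,
                                 size_t countK, size_t countM, size_t countN,
                                 size_t lda, size_t ldc, float alpha, bool zeroMode);

// The caller loops, advancing A and C by however many rows the kernel handled.
void RunKernelOverRows(FgemmKernelFn kernel, const float* A, const float* B, float* C,
                       size_t countK, size_t countM, size_t countN,
                       size_t lda, size_t ldc, float alpha, bool zeroMode) {
    while (countM > 0) {
        const size_t rowsHandled =
            kernel(A, B, C, countK, countM, countN, lda, ldc, alpha, zeroMode);
        A += rowsHandled * lda;
        C += rowsHandled * ldc;
        countM -= rowsHandled;
    }
}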
- ---*/ - - FUNCTION_ENTRY \FunctionName\() - - push rbp - push rbx - push r15 - mov r11,rdi - mov r10,.LFgemmKernelFrame_lda[rsp] - shl r10,.LFgemmElementShift # convert lda to bytes - mov rax,.LFgemmKernelFrame_ldc[rsp] - shl rax,.LFgemmElementShift # convert ldc to bytes - movzx r15,BYTE PTR .LFgemmKernelFrame_ZeroMode[rsp] - vmovsf .LFgemmKernelFrame_alpha[rsp],xmm0 - vzeroall - -// -// Process CountM rows of the matrices. -// - - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - mov r8d,6 # return 6 rows handled - ProcessCountM 6, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernelAndZeroUpper: - vzeroupper - -.LExitKernel: - mov eax,r8d - pop r15 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelSse2Common.h b/onnxruntime/core/mlas/lib/x86_64/FgemmKernelSse2Common.h deleted file mode 100644 index 2f71864f00cf0..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/FgemmKernelSse2Common.h +++ /dev/null @@ -1,173 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - FgemmKernelSse2Common.h - -Abstract: - - This module implements the kernels for the floating point matrix/matrix - multiply operation (SGEMM and DGEMM). - - This implementation uses SSE2 instructions. - ---*/ - -/*++ - -Macro Description: - - This stores the block accumulators to the output matrix with an optional - accumulation of the existing contents of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorCount - Supplies the number of vector columns to process. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdx - Supplies the address of matrix C. - - r15 - Stores the ZeroMode argument from the stack frame. - - xmm8-xmm15 - Supplies the block accumulators. - ---*/ - - .macro AccumulateAndStoreBlock RowCount, VectorCount - - test r15b,r15b # ZeroMode? 
- jnz .LSkipAccumulateOutput\@ - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "movupf xmm0,XMMWORD PTR [rdx]" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "movupf xmm1,XMMWORD PTR [rdx+16]" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "movupf xmm2,XMMWORD PTR [rdx+32]" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "movupf xmm3,XMMWORD PTR [rdx+48]" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "movupf xmm4,XMMWORD PTR [rdx+rax]" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "movupf xmm5,XMMWORD PTR [rdx+rax+16]" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "movupf xmm6,XMMWORD PTR [rdx+rax+32]" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "movupf xmm7,XMMWORD PTR [rdx+rax+48]" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "addpf xmm8,xmm0" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "addpf xmm9,xmm1" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "addpf xmm10,xmm2" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "addpf xmm11,xmm3" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "addpf xmm12,xmm4" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "addpf xmm13,xmm5" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "addpf xmm14,xmm6" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "addpf xmm15,xmm7" - -.LSkipAccumulateOutput\@: - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 1, "movupf XMMWORD PTR [rdx],xmm8" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 2, "movupf XMMWORD PTR [rdx+16],xmm9" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 3, "movupf XMMWORD PTR [rdx+32],xmm10" - EmitIfCount2GE \RowCount\(), 1, \VectorCount\(), 4, "movupf XMMWORD PTR [rdx+48],xmm11" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 1, "movupf XMMWORD PTR [rdx+rax],xmm12" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 2, "movupf XMMWORD PTR [rdx+rax+16],xmm13" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 3, "movupf XMMWORD PTR [rdx+rax+32],xmm14" - EmitIfCount2GE \RowCount\(), 2, \VectorCount\(), 4, "movupf XMMWORD PTR [rdx+rax+48],xmm15" - - .endm - -/*++ - -Macro Description: - - This macro generates the inner kernel to compute matrix multiplication. - -Arguments: - - FunctionName - Supplies the name for the generated function. - ---*/ - - .macro FgemmKernelSse2Function FunctionName - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasSgemmCopyPackB or MlasSgemmTransposePackB. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - lda - Supplies the first dimension of matrix A. - - ldc - Supplies the first dimension of matrix C. - - Alpha (xmm0) - Supplies the scalar alpha multiplier (see GEMM definition). - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. 
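The EmitIfCountGE / EmitIfCount2GE macros used throughout these kernels (as in AccumulateAndStoreBlock above) emit an instruction only when the compile-time row and vector counts reach a threshold, so a single macro body expands into a specialized kernel per CountM. A C++ analogue of that pattern, as a sketch with illustrative names:

#include <cstddef>

// "if constexpr" plays the role of EmitIfCountGE: guarded statements are only
// generated for instantiations whose RowCount is large enough.
template <int RowCount>
void StoreRowsRef(float* c, const float (*acc)[8], int ldc) {
    for (int j = 0; j < 8; ++j) c[j] = acc[0][j];
    if constexpr (RowCount >= 2) {
        for (int j = 0; j < 8; ++j) c[ldc + j] = acc[1][j];
    }
    if constexpr (RowCount >= 3) {
        for (int j = 0; j < 8; ++j) c[2 * ldc + j] = acc[2][j];
    }
    if constexpr (RowCount >= 4) {
        for (int j = 0; j < 8; ++j) c[3 * ldc + j] = acc[3][j];
    }
}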
- ---*/ - - FUNCTION_ENTRY \FunctionName\() - - push rbp - push rbx - push r15 - mov r11,rdi - mov r10,.LFgemmKernelFrame_lda[rsp] - shl r10,.LFgemmElementShift # convert lda to bytes - mov rax,.LFgemmKernelFrame_ldc[rsp] - shl rax,.LFgemmElementShift # convert ldc to bytes - movzx r15,BYTE PTR .LFgemmKernelFrame_ZeroMode[rsp] - movsf .LFgemmKernelFrame_alpha[rsp],xmm0 - -// -// Process CountM rows of the matrices. -// - - cmp r8,2 - jb .LProcessCountM1 - mov r8d,2 # return 2 rows handled - ProcessCountM 2, Fallthrough - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - mov eax,r8d - pop r15 - pop rbx - pop rbp - ret - -// -// Process 1 row of the matrices. -// - -.LProcessCountM1: - ProcessCountM 1 - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S deleted file mode 100644 index f1ee717363e12..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/LogisticKernelFma3.S +++ /dev/null @@ -1,125 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - LogisticKernelFma3.s - -Abstract: - - This module implements a kernel for computing the logistic function for a - buffer of elements. - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" -#include "TransKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the logistic function. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Supplies the output buffer. - - N (rdx) - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasComputeLogisticF32KernelFma3 - - lea rax,C_UNDERSCORE(MlasLogisticConstants)[rip] - vbroadcastss ymm4,.LLogisticConstants_LowerRange[rax] - vbroadcastss ymm5,.LLogisticConstants_UpperRange[rax] - vbroadcastss ymm6,.LLogisticConstants_alpha_9[rax] - vbroadcastss ymm7,.LLogisticConstants_alpha_7[rax] - vbroadcastss ymm8,.LLogisticConstants_alpha_5[rax] - vbroadcastss ymm9,.LLogisticConstants_alpha_3[rax] - vbroadcastss ymm10,.LLogisticConstants_alpha_1[rax] - vbroadcastss ymm11,.LLogisticConstants_beta_10[rax] - vbroadcastss ymm12,.LLogisticConstants_beta_6[rax] - vbroadcastss ymm13,.LLogisticConstants_beta_4[rax] - vbroadcastss ymm14,.LLogisticConstants_beta_2[rax] - vbroadcastss ymm15,.LLogisticConstants_beta_0[rax] - - sub rdx,8 - jb .LProcessRemainingCount - -.LComputeLogisticBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rdi] # clamp lower bound - vmovaps ymm2,ymm7 - vminps ymm0,ymm5,ymm0 # clamp upper bound - vmulps ymm1,ymm0,ymm0 # x2 - vbroadcastss ymm3,.LLogisticConstants_beta_8[rax] - vfmadd231ps ymm2,ymm1,ymm6 # p = x2 * alpha_9 + alpha_7 - vfmadd213ps ymm2,ymm1,ymm8 # p = x2 * p + alpha_5 - vfmadd213ps ymm2,ymm1,ymm9 # p = x2 * p + alpha_3 - vfmadd213ps ymm2,ymm1,ymm10 # p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm11 # q = x2 * beta_10 + beta_8 - vfmadd213ps ymm3,ymm1,ymm12 # q = x2 * q + beta_6 - vfmadd213ps ymm3,ymm1,ymm13 # q = x2 * q + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 # q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 # q = x2 * q + beta_0 - vmulps ymm2,ymm0,ymm2 # p = x * p - vbroadcastss ymm0,.LLogisticConstants_one_half[rax] - vdivps ymm2,ymm2,ymm3 - vxorps ymm3,ymm3,ymm3 - vaddps ymm0,ymm2,ymm0 # logistic = p / q + 0.5 - vmaxps ymm0,ymm3,ymm0 # clamp lower bound - add rdi,8*4 # advance input by 8 elements - vmovups YMMWORD PTR 
[rsi],ymm0 - add rsi,8*4 # advance output by 8 elements - sub rdx,8 - jae .LComputeLogisticBy8Loop - -.LProcessRemainingCount: - add rdx,8 # correct for over-subtract above - jz .LExitKernel - neg rdx - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovups ymm2,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rdi] - vmaxps ymm0,ymm4,ymm0 # clamp lower bound - vminps ymm0,ymm5,ymm0 # clamp upper bound - vmulps ymm1,ymm0,ymm0 # x2 - vbroadcastss ymm3,.LLogisticConstants_beta_8[rax] - vfmadd231ps ymm7,ymm1,ymm6 # p = x2 * alpha_9 + alpha_7 - vfmadd213ps ymm7,ymm1,ymm8 # p = x2 * p + alpha_5 - vfmadd213ps ymm7,ymm1,ymm9 # p = x2 * p + alpha_3 - vfmadd213ps ymm7,ymm1,ymm10 # p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm11 # q = x2 * beta_10 + beta_8 - vfmadd213ps ymm3,ymm1,ymm12 # q = x2 * q + beta_6 - vfmadd213ps ymm3,ymm1,ymm13 # q = x2 * q + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 # q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 # q = x2 * q + beta_0 - vmulps ymm7,ymm0,ymm7 # p = x * p - vbroadcastss ymm0,.LLogisticConstants_one_half[rax] - vdivps ymm7,ymm7,ymm3 - vxorps ymm3,ymm3,ymm3 - vaddps ymm0,ymm7,ymm0 # logistic = p / q + 0.5 - vmaxps ymm0,ymm3,ymm0 # clamp lower bound - vmaskmovps YMMWORD PTR [rsi],ymm2,ymm0 - -.LExitKernel: - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmx.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmx.S deleted file mode 100644 index c4868167b2a84..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmx.S +++ /dev/null @@ -1,516 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8S8KernelAmx.s - -Abstract: - - This module implements the packing functions for the quantized integer matrix/matrix - multiply operation (QGEMM). - - These packing functions are suited for AMX Qgemm kernel. The implementation only - uses AVX2 instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the U8S8 CopyPackB routine. -// - - .equ .LGemmU8S8CopyPackBFrame_SavedR12, 0 - .equ .LGemmU8S8CopyPackBFrame_SavedRbx, 8 - .equ .LGemmU8S8CopyPackBFrame_SavedRbp, 16 - .equ .LGemmU8S8CopyPackBFrame_ReturnAddress, 24 - .equ .LGemmU8S8CopyPackBFrame_BIsSigned, 32 - - .text - -/*++ - -Routine Description: - - This routine copies elements from the source B matrix to the destination - packed buffer. - - This implementation is almost identical to MlasGemmU8S8CopyPackBAvx2 - where it traverse B vertically, take a block of 4 row 16 col, transpose - and store it, then go down 4 row to grab the next 4x16 block. The only - difference here is that we need K to be aligned to 64 to the fill - an AMX tile. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - B (rsi) - Supplies the address of the source matrix. - - ldb (rdx) - Supplies the number of elements per row of the source matrix. - - CountN (rcx) - Supplies the number of columns of the source matrix to copy. - - CountK (r8) - Supplies the number of rows of the source matrix to copy. - - ColumnSumBuffer (r9) - Supplies the address of the buffer to receive the sums - of the elements along each of the columns. - - BIsSigned - Supplies true if the source matrix is signed data, else false if - the source matrix is unsigned data. - -Return Value: - - None. 
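The packing layout described above interleaves each group of four consecutive K rows within a 16-column block, accumulates per-column sums for zero-point compensation, and zero-pads K up to the next multiple of 64 so a full AMX tile can always be loaded. A scalar sketch of that layout (ignoring the signed/unsigned bit flip; padding and out-of-range columns are written as plain zeros here):

#include <cstddef>
#include <cstdint>

// Scalar sketch of the CopyPackB layout: within each 16-column block the packed
// buffer is ordered [k / 4][column][k % 4], K is padded to a multiple of 64,
// and ColumnSumBuffer receives the sum of every real element in each column.
void CopyPackBRef(uint8_t* D, const uint8_t* B, size_t ldb,
                  size_t countN, size_t countK, int32_t* columnSums) {
    const size_t alignedK = (countK + 63) & ~size_t(63);
    for (size_t n = 0; n < countN; ++n) {
        columnSums[n] = 0;
    }
    for (size_t n0 = 0; n0 < countN; n0 += 16) {
        const size_t blockWidth = (countN - n0 < 16) ? (countN - n0) : 16;
        for (size_t k0 = 0; k0 < alignedK; k0 += 4) {      // interleave 4 K rows at a time
            for (size_t n = 0; n < 16; ++n) {
                for (size_t k = k0; k < k0 + 4; ++k) {
                    uint8_t value = 0;
                    if (k < countK && n < blockWidth) {
                        value = B[k * ldb + n0 + n];
                        columnSums[n0 + n] += value;
                    }
                    *D++ = value;
                }
            }
        }
    }
}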
- ---*/ - - FUNCTION_ENTRY MlasGemmU8S8CopyPackBAmx - - push rbp - push rbx - push r12 - - mov r10,rdx - lea r11,[r10+r10*2] # compute ldb * 3 - lea r12,[r8+3] # compute extra padding for 64|K - shr r12,2 - neg r12 - and r12,15 - vpcmpeqw ymm0,ymm0,ymm0 # generate word vector [0xFFFF] - vpsrlw ymm0,ymm0,15 # generate word vector [0x0001] - vpsllw ymm1,ymm0,8 # generate word vector [0x0100] - vpor ymm1,ymm0,ymm1 # generate word vector [0x0101] - -// -// Compute the bit flip vector to adjust input from U8 to S8. -// - - vpxor xmm2,xmm2,xmm2 # generate word vector [0x0000] - cmp BYTE PTR .LGemmU8S8CopyPackBFrame_BIsSigned[rsp],0 - jnz .LCopyPackB.SkipUnsignedBitFlipVector - vpsllw ymm2,ymm1,7 # generate word vector [0x8080] - -.LCopyPackB.SkipUnsignedBitFlipVector: - -// -// Process 16 columns of matrix B in a loop. -// - - sub rcx,16 # CountN -= 16 - jb .LCopyPackB.ProcessRemainingColumns - -.LCopyPackB.ProcessNextColumnN16: - vpxord xmm30,xmm30,xmm30 # clear column accumulators - vpxord xmm31,xmm31,xmm31 - mov rdx,rsi # rdx -> B start of 16 columns - add rsi,16 # advance next matrix B by 16 columns - mov rbx,r8 # reload rows remaining - sub rbx,4 - jb .LCopyPackB.ProcessRemainingRowsN16 - -.LCopyPackB.ProcessNextRowLoopN16: - vmovdqu64 xmm16,XMMWORD PTR [rdx] # load 4 rows - vmovdqu64 xmm17,XMMWORD PTR [rdx+r10] - vmovdqu64 xmm18,XMMWORD PTR [rdx+r10*2] - vmovdqu64 xmm19,XMMWORD PTR [rdx+r11] - lea rdx,[rdx+r10*4] # advance matrix B by 4 rows - -.LCopyPackB.InterleaveRowDataN16: - vpunpcklbw xmm3,xmm16,xmm17 # interleave row data - vpunpckhbw xmm17,xmm16,xmm17 - vpunpcklbw xmm16,xmm18,xmm19 - vpunpckhbw xmm19,xmm18,xmm19 - vpunpcklwd xmm18,xmm3,xmm16 - vpunpckhwd xmm3,xmm3,xmm16 - vpunpcklwd xmm16,xmm17,xmm19 - vpunpckhwd xmm17,xmm17,xmm19 - vinserti64x2 ymm18,ymm18,xmm3,1 - vinserti64x2 ymm16,ymm16,xmm17,1 - vpxord ymm18,ymm18,ymm2 # optionally adjust unsigned data - vpxord ymm16,ymm16,ymm2 - vmovdqu64 YMMWORD PTR [rdi],ymm18 # store interleaved rows - vmovdqu64 YMMWORD PTR [rdi+32],ymm16 - vpmaddubsw ymm18,ymm1,ymm18 # horizontal byte+byte=word per row - vpmaddwd ymm18,ymm18,ymm0 # horizontal word+word=dword per row - vpaddd ymm30,ymm30,ymm18 # accumulate per column - vpmaddubsw ymm16,ymm1,ymm16 - vpmaddwd ymm16,ymm16,ymm0 - vpaddd ymm31,ymm31,ymm16 - add rdi,64 # advance matrix D by 64 bytes - sub rbx,4 # subtract rows remaining - jae .LCopyPackB.ProcessNextRowLoopN16 - -// -// Process the less than 4 remaining rows where the row has 16 columns. -// - -.LCopyPackB.ProcessRemainingRowsN16: - add rbx,4 # correct for over-subtract above - jz .LCopyPackB.StoreColumnSumBufferN16 - vmovdqu64 xmm16,XMMWORD PTR [rdx] - vmovaps xmm17,xmm2 - vmovaps xmm18,xmm2 - vmovaps xmm19,xmm2 - xor ebx,ebx # no more rows remaining - test r8b,2 # (CountK & 2) != 0? - jz .LCopyPackB.InterleaveRowDataN16 - vmovdqu64 xmm17,XMMWORD PTR [rdx+r10] - test r8b,1 # (CountK & 1) != 0? 
- jz .LCopyPackB.InterleaveRowDataN16 - vmovdqu64 xmm18,XMMWORD PTR [rdx+r10*2] - jmp .LCopyPackB.InterleaveRowDataN16 - -.LCopyPackB.StoreColumnSumBufferN16: - vmovdqu64 YMMWORD PTR [r9],ymm30 - vmovdqu64 YMMWORD PTR [r9+32],ymm31 - test r12,r12 - jz .LCopyPackB.N16K64PaddingFinished - mov rax, r12 - vpxord xmm30,xmm30,xmm30 - -.LCopyPackB.N16K64Padding: - vmovdqu64 YMMWORD PTR [rdi],ymm30 # store 0 - vmovdqu64 YMMWORD PTR [rdi+32],ymm30 - add rdi,64 - dec rax - jnz .LCopyPackB.N16K64Padding - -.LCopyPackB.N16K64PaddingFinished: - add r9,16*4 # advance column sum buffer by 16 dwords - sub rcx,16 # subtract columns remaining - jae .LCopyPackB.ProcessNextColumnN16 - -.LCopyPackB.ProcessRemainingColumns: - add rcx,16 # correct for over-subtract above - jnz .LCopyPackB.ProcessColumnNUnaligned - -// -// Restore non-volatile registers and return. -// - -.LCopyPackB.ExitRoutine: - vzeroupper - - pop r12 - pop rbx - pop rbp - ret - -// -// Process the remaining columns of matrix B. -// - -.LCopyPackB.ProcessColumnNUnaligned: - vpxord xmm30,xmm30,xmm30 # clear column accumulators - vpxord xmm31,xmm31,xmm31 - neg ecx - and ecx,63 - mov rbx,-1 - shr rbx,cl # mask for left over N - kmovq k1,rbx # mask - sub r8,4 - jb .LCopyPackB.ProcessRemainingRowsNUnaligned - -.LCopyPackB.ProcessNextRowLoopNUnaligned: - vmovdqu64 xmm16,xmm2 - vmovdqu8 xmm16 {k1},XMMWORD PTR [rsi] # load 4 rows - vmovdqu64 xmm17,xmm2 - vmovdqu8 xmm17 {k1},XMMWORD PTR [rsi+r10] - vmovdqu64 xmm18,xmm2 - vmovdqu8 xmm18 {k1},XMMWORD PTR [rsi+r10*2] - vmovdqu64 xmm19,xmm2 - vmovdqu8 xmm19 {k1},XMMWORD PTR [rsi+r11] - lea rsi,[rsi+r10*4] # advance next matrix B by 4 rows - -.LCopyPackB.InterleaveRowDataUnaligned: - vpunpcklbw xmm3,xmm16,xmm17 # interleave row data - vpunpckhbw xmm17,xmm16,xmm17 - vpunpcklbw xmm16,xmm18,xmm19 - vpunpckhbw xmm19,xmm18,xmm19 - vpunpcklwd xmm18,xmm3,xmm16 - vpunpckhwd xmm3,xmm3,xmm16 - vpunpcklwd xmm16,xmm17,xmm19 - vpunpckhwd xmm17,xmm17,xmm19 - vinserti64x2 ymm18,ymm18,xmm3,1 - vinserti64x2 ymm16,ymm16,xmm17,1 - vpxord ymm18,ymm18,ymm2 # optionally adjust unsigned data - vpxord ymm16,ymm16,ymm2 - vmovdqu64 YMMWORD PTR [rdi],ymm18 # store interleaved rows - vmovdqu64 YMMWORD PTR [rdi+32],ymm16 - vpmaddubsw ymm18,ymm1,ymm18 # horizontal byte+byte=word per row - vpmaddwd ymm18,ymm18,ymm0 # horizontal word+word=dword per row - vpaddd ymm30,ymm30,ymm18 # accumulate per column - vpmaddubsw ymm16,ymm1,ymm16 - vpmaddwd ymm16,ymm16,ymm0 - vpaddd ymm31,ymm31,ymm16 - add rdi,64 # advance matrix D by 64 bytes - sub r8,4 # subtract rows remaining - jae .LCopyPackB.ProcessNextRowLoopNUnaligned - -// -// Process the less than 4 remaining rows where the row has less than 16 columns. -// - -.LCopyPackB.ProcessRemainingRowsNUnaligned: - add r8,4 - jz .LCopyPackB.StoreColumnSumBufferNUnaligned - - vmovaps xmm16,xmm2 - vmovdqu8 xmm16 {k1},XMMWORD PTR [rsi] - vmovaps xmm17,xmm2 - vmovaps xmm18,xmm2 - vmovaps xmm19,xmm2 - mov rbx,r8 - xor r8b,r8b # no more rows remaining - test bl,2 # (CountK & 2) != 0? - jz .LCopyPackB.InterleaveRowDataUnaligned - vmovdqu8 xmm17 {k1},XMMWORD PTR [rsi+r10] - test bl,1 # (CountK & 1) != 0? 
- jz .LCopyPackB.InterleaveRowDataUnaligned - vmovdqu8 xmm18 {k1},XMMWORD PTR [rsi+r10*2] - jmp .LCopyPackB.InterleaveRowDataUnaligned - -.LCopyPackB.StoreColumnSumBufferNUnaligned: - vmovdqu64 YMMWORD PTR [r9],ymm30 - vmovdqu64 YMMWORD PTR [r9+32],ymm31 - test r12,r12 - jz .LCopyPackB.ExitRoutine - mov rax, r12 - vpxord xmm30,xmm30,xmm30 - -.LCopyPackB.K64Padding: - vmovdqu64 YMMWORD PTR [rdi],ymm30 # store 0 - vmovdqu64 YMMWORD PTR [rdi+32],ymm30 - add rdi,64 - dec rax - jne .LCopyPackB.K64Padding - jmp .LCopyPackB.ExitRoutine - - -// -// Stack frame layout for the U8S8 CopyPackA routine. -// - .equ .LGemmU8S8CopyPackAFrame_SavedR13, 0 - .equ .LGemmU8S8CopyPackAFrame_SavedR12, 8 - .equ .LGemmU8S8CopyPackAFrame_SavedRbx, 16 - .equ .LGemmU8S8CopyPackAFrame_SavedRbp, 24 - .equ .LGemmU8S8CopyPackAFrame_ReturnAddress, 32 - -/*++ - -Routine Description: - - This routine copies elements from the source matrix A to the destination - packed buffer. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - A (rsi) - Supplies the address of the source matrix. - - lda (rdx) - Supplies the number of elements per row of the source matrix. - - CountM (rcx) - Supplies the number of rows of the source matrix to copy. - - CountK (r8) - Supplies the number of columns of the source matrix to copy. - - RowSumBuffer (r9) - Supplies the address of the buffer to receive the sums - of the elements along each of the rows. - by the zero point offset. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasGemmU8S8CopyPackAAmx - - push rbp - push rbx - push r12 - push r13 - - mov r10,rdx # lda - mov r11,rcx # m = CountM - lea r12,[r8+63] - and r12,NOT 63 # align CountK up to 64 - vpternlogd zmm30,zmm30,zmm30,255 # generate word vector [0xFFFF] - vpsrlw zmm30,zmm30,15 # generate word vector [0x0001] - vpsllw zmm31,zmm30,8 # generate word vector [0x0100] - vpord zmm31,zmm30,zmm31 # generate word vector [0x0101] - lea r13,[r10+r10*2] # compute ldb * 3 - lea rax,[r12+r12*2] # compute AlignedCountK * 3 - mov ecx,r8d # CountK - neg ecx - and ecx,63 - mov rbx,-1 - shr rbx,cl # mask for left over k < 64 - kmovq k1,rbx # mask - -// -// Process 4 rows of matrix A in a loop. 
-// - - sub r11,4 # m -= 4 - jb .LCopyPackA.ProcessRemainingRows - -.LCopyPackA.ProcessNextRowM4: - vpxor xmm0,xmm0,xmm0 # clear row accumulators - vpxor xmm1,xmm1,xmm1 - vpxor xmm2,xmm2,xmm2 - vpxor xmm3,xmm3,xmm3 - mov rdx,rsi # src = A - mov rcx,rdi # dst = D - lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows - lea rdi,[rdi+r12*4] # advance next matrix D by 4 rows - mov rbx,r8 # k = CountK - sub rbx,64 - jb .LCopyPackA.ProcessRemainingColumnsM4 - -.LCopyPackA.ProcessNextColumnLoopM4: - vmovdqu64 zmm16,ZMMWORD PTR [rdx] - vmovdqu64 zmm17,ZMMWORD PTR [rdx+r10] - vmovdqu64 zmm18,ZMMWORD PTR [rdx+r10*2] - vmovdqu64 zmm19,ZMMWORD PTR [rdx+r13] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vmovdqu64 ZMMWORD PTR [rcx+r12],zmm17 - vmovdqu64 ZMMWORD PTR [rcx+r12*2],zmm18 - vmovdqu64 ZMMWORD PTR [rcx+rax],zmm19 - vpmaddubsw zmm16,zmm16,zmm31 # horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 # add words to row accumulators - vpmaddubsw zmm17,zmm17,zmm31 - vpaddw zmm1,zmm1,zmm17 - vpmaddubsw zmm18,zmm18,zmm31 - vpaddw zmm2,zmm2,zmm18 - vpmaddubsw zmm19,zmm19,zmm31 - vpaddw zmm3,zmm3,zmm19 - add rdx,64 # src += 64 - add rcx,64 # dst += 64 - sub rbx,64 # k -= 64 - jae .LCopyPackA.ProcessNextColumnLoopM4 - -.LCopyPackA.ProcessRemainingColumnsM4: - add rbx,64 # correct for over-subtract above - jz .LCopyPackA.ReduceRowSumBufferM4 - vmovdqu8 zmm16{k1}{z},ZMMWORD PTR [rdx] - vmovdqu8 zmm17{k1}{z},ZMMWORD PTR [rdx+r10] - vmovdqu8 zmm18{k1}{z},ZMMWORD PTR [rdx+r10*2] - vmovdqu8 zmm19{k1}{z},ZMMWORD PTR [rdx+r13] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vmovdqu64 ZMMWORD PTR [rcx+r12],zmm17 - vmovdqu64 ZMMWORD PTR [rcx+r12*2],zmm18 - vmovdqu64 ZMMWORD PTR [rcx+rax],zmm19 - vpmaddubsw zmm16,zmm16,zmm31 # horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 # add words to row accumulators - vpmaddubsw zmm17,zmm17,zmm31 - vpaddw zmm1,zmm1,zmm17 - vpmaddubsw zmm18,zmm18,zmm31 - vpaddw zmm2,zmm2,zmm18 - vpmaddubsw zmm19,zmm19,zmm31 - vpaddw zmm3,zmm3,zmm19 - -// -// Reduce the sums for the four rows of output. -// - -.LCopyPackA.ReduceRowSumBufferM4: - vpmaddwd zmm0,zmm0,zmm30 # horizontal word+word=dword per row - vpmaddwd zmm1,zmm1,zmm30 - vpmaddwd zmm2,zmm2,zmm30 - vpmaddwd zmm3,zmm3,zmm30 - vextracti64x4 ymm16,zmm0,1 # fold zmm -> ymm - vextracti64x4 ymm17,zmm1,1 - vextracti64x4 ymm18,zmm2,1 - vextracti64x4 ymm19,zmm3,1 - vpaddd ymm0,ymm0,ymm16 - vpaddd ymm1,ymm1,ymm17 - vpaddd ymm2,ymm2,ymm18 - vpaddd ymm3,ymm3,ymm19 - vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 - vphaddd ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 - vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 - vextracti128 xmm1,ymm0,1 # fold ymm -> xmm - vpaddd xmm0,xmm0,xmm1 - vmovdqu XMMWORD PTR [r9],xmm0 - add r9,4*4 # advance row sum buffer by 4 dwords - sub r11,4 # m -= 4 - jae .LCopyPackA.ProcessNextRowM4 - -.LCopyPackA.ProcessRemainingRows: - add r11,4 # correct for over-subtract above - jz .LCopyPackA.ExitRoutine - -// -// Process a single row of matrix A in a loop. 
-// - -.LCopyPackA.ProcessNextRowM1: - vpxor xmm0,xmm0,xmm0 # clear row accumulator - mov rdx,rsi # src = A - mov rcx,rdi # dst = D - add rsi,r10 # A to next row - add rdi,r12 # D to next row - mov rbx,r8 # k = CountK - sub rbx,64 # k -= 64 - jb .LCopyPackA.ProcessRemainingColumnsM1 - -.LCopyPackA.ProcessNextColumnLoopM1: - vmovdqu64 zmm16,ZMMWORD PTR [rdx] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vpmaddubsw zmm16,zmm16,zmm31 # horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 # add words to row accumulators - add rdx,64 # src += 64 - add rcx,64 # dst += 64 - sub rbx,64 # k -= 64 - jae .LCopyPackA.ProcessNextColumnLoopM1 - -.LCopyPackA.ProcessRemainingColumnsM1: - add rbx,64 # correct for over-subtract above - jz .LCopyPackA.ReduceRowSumBufferM1 - - vmovdqu8 zmm16{k1}{z},ZMMWORD PTR [rdx] - vmovdqu64 ZMMWORD PTR [rcx],zmm16 - vpmaddubsw zmm16,zmm16,zmm31 # horizontal byte+byte=word per row - vpaddw zmm0,zmm0,zmm16 # add words to row accumulators - -// -// Reduce the sum for the single row of output. -// - -.LCopyPackA.ReduceRowSumBufferM1: - vpmaddwd zmm0,zmm0,zmm30 # horizontal word+word=dword per row - vextracti64x4 ymm16,zmm0,1 # fold zmm -> ymm - vpaddd ymm0,ymm0,ymm16 - vextracti128 xmm1,ymm0,1 # fold ymm -> xmm - vpaddd xmm0,xmm0,xmm1 # reduction - vphaddd xmm0,xmm0,xmm0 - vphaddd xmm0,xmm0,xmm0 - vmovd DWORD PTR [r9],xmm0 - add r9,4 # advance row sum buffer by 1 dword - dec r11 # decrement rows remaining - jnz .LCopyPackA.ProcessNextRowM1 - -// -// Restore non-volatile registers and return. -// - -.LCopyPackA.ExitRoutine: - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret - - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S deleted file mode 100644 index 7d042e2d8f476..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAmxCommon.S +++ /dev/null @@ -1,234 +0,0 @@ -/*++ - -Copyright (c) 2023 Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - AssembleAmx.h - -Abstract: - - This module contains macros to build AMX instructions for toolchains that - do not natively support this newer instruction set extension. - ---*/ - -// -// Map friendly register names to the encoded register index. -// - - .equ .LTmmIndex_tmm0, 0 - .equ .LTmmIndex_tmm1, 1 - .equ .LTmmIndex_tmm2, 2 - .equ .LTmmIndex_tmm3, 3 - .equ .LTmmIndex_tmm4, 4 - .equ .LTmmIndex_tmm5, 5 - .equ .LTmmIndex_tmm6, 6 - .equ .LTmmIndex_tmm7, 7 - -/*++ - -Macro Description: - - This macro builds a AMX instruction of the form: - - instr tmm1,tmm2,tmm3 - -Arguments: - - prefix - Specifies the opcode for the AMX instruction. - - DestReg - Specifies the destination AMX tile. - - Src1Reg - Specifies the first source AMX tile. - - Src2Reg - Specifies the second source AMX tile. 
- ---*/ - - .macro DPTmmTmmTmm prefix, DestReg, Src1Reg, Src2Reg - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LTmmIndex_\DestReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + (1 << 6) - .set Payload0, Payload0 + ((((.LTmmIndex_\Src2Reg\() >> 3) & 1) ^ 1) << 5) - - .set Payload1, \prefix\() - .set Payload1, Payload1 + (((.LTmmIndex_\Src2Reg\() & 15) ^ 15) << 3) - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LTmmIndex_\DestReg\() & 7) << 3) - .set ModRMByte, ModRMByte + (.LTmmIndex_\Src1Reg\() & 7) - - .byte 0xC4, Payload0, Payload1, 0x5E, ModRMByte - - .endm - - - .macro TdpbssdTmmTmmTmm DestReg, Src1Reg, Src2Reg - - DPTmmTmmTmm 0x03, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - - .macro TdpbsudTmmTmmTmm DestReg, Src1Reg, Src2Reg - - DPTmmTmmTmm 0x02, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - - .macro TdpbusdTmmTmmTmm DestReg, Src1Reg, Src2Reg - - DPTmmTmmTmm 0x01, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - - - .macro TdpbuudTmmTmmTmm DestReg, Src1Reg, Src2Reg - - DPTmmTmmTmm 0x00, \DestReg\(), \Src1Reg\(), \Src2Reg\() - - .endm - -/*++ - -Macro Description: - - This macro builds a AMX tile release instruction. - -Arguments: - - - ---*/ - -// .macro TileReleaseMacro - -// .byte 0xC4, 0xE2, 0x78, 0x49, 0xC0 - -// .endm - - -/*++ - -Macro Description: - - This macro builds an AMX tile zero instruction of the form: - - instr tmm1 - -Arguments: - - SrcReg - Specifies the source AMX tile. - ---*/ - - .macro TileZeroMacro SrcReg - - .set ModRMByte, 0xC0 # register form - .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3) - .byte 0xC4, 0xE2, 0x7B, 0x49, ModRMByte - - .endm - -/*++ - -Macro Description: - - This macro builds an AMX memory instruction of the form: - - instr tmm, base, stride - -Arguments: - - instr - Specifies the opcode for the AMX instruction. - - SrcReg - Specifies the target AMX tile. - - BaseReg - Specifies the base address of memory location. - - Stride - Specifies the stride for the memory instruction - ---*/ - - .macro TileLoadMacro instr, SrcReg, BaseReg, Stride - - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + ((((.LTmmIndex_\SrcReg\() >> 3) & 1) ^ 1) << 7) - .set Payload0, Payload0 + ((((3 >> 3) & 1) ^ 1) << 6) - .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5) - - .set ModRMByte, 0x00 # memory form - .set ModRMByte, ModRMByte + (1 << 2) # SibBye required - .set ModRMByte, ModRMByte + ((.LTmmIndex_\SrcReg\() & 7) << 3) - - .set SibByte, 0x00 # scale factor 1(SS) - .set SibByte, SibByte + ((3 & 7) << 3) - .set SibByte, SibByte + (0 & 7) - - .byte 0xC4, Payload0, \instr\(), 0x4B, ModRMByte, SibByte - - .endm - - - .macro TileloaddTmmMem DstReg, BaseReg, Stride - TileLoadMacro 0x7B, \DstReg\(), \BaseReg\(), \Stride\() - .endm - - .macro TileloaddT1TmmMem DstReg, BaseReg, Stride - TileLoadMacro 0x79, \DstReg\(), \BaseReg\(), \Stride\() - .endm - - - .macro TileStoredMemTmm SrcReg, BaseReg, Stride - TileLoadMacro 0x7A, \SrcReg\(), \BaseReg\(), \Stride\() - .endm - - -/*++ - -Macro Description: - - This macro builds an AMX tile configuration instruction of the form: - - instr base - -Arguments: - - instr - Specifies the opcode for the AMX instruction. - - BaseReg - Specifies the memory address of the tile configuration. 
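The DPTmmTmmTmm macro above hand-assembles the VEX-encoded AMX dot-product instructions byte by byte so that older assemblers without AMX support can still build these kernels. The same byte construction written out in C++, as a sketch for cross-checking the encoding (function name is illustrative):

#include <array>
#include <cstdint>

// Mirrors DPTmmTmmTmm: builds the five bytes of "tdpb??d tmmDest, tmmSrc1, tmmSrc2".
// pp is the prefix argument passed by the Tdpb*dTmmTmmTmm wrappers
// (3 = ssd, 2 = sud, 1 = usd, 0 = uud).
std::array<uint8_t, 5> EncodeTdpTmm(uint8_t pp, unsigned dest, unsigned src1, unsigned src2) {
    uint8_t payload0 = 0x02;                            // select the 0F 38 opcode map
    payload0 |= uint8_t((((dest >> 3) & 1) ^ 1) << 7);  // inverted high bit of dest (VEX.R)
    payload0 |= uint8_t(1 << 6);                        // VEX.X fixed to 1
    payload0 |= uint8_t((((src2 >> 3) & 1) ^ 1) << 5);  // inverted high bit of src2 (VEX.B)
    uint8_t payload1 = pp;                              // low bits select the instruction form
    payload1 |= uint8_t(((src2 & 15) ^ 15) << 3);       // VEX.vvvv = ones-complement of src2
    uint8_t modrm = 0xC0;                               // register-direct form
    modrm |= uint8_t((dest & 7) << 3);                  // ModRM.reg = destination tile
    modrm |= uint8_t(src1 & 7);                         // ModRM.rm  = first source tile
    return {0xC4, payload0, payload1, 0x5E, modrm};     // 0x5E is the tdp* opcode byte
}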
- ---*/ - - .macro tilecfgMacro instr, BaseReg - .set Payload0, 0x02 # "0F 38" prefix - .set Payload0, Payload0 + (1 << 7) - .set Payload0, Payload0 + (1 << 6) - .set Payload0, Payload0 + ((((0 >> 3) & 1) ^ 1) << 5) - - .set ModRMByte, 0x00 # memory form & no reg - .set ModRMByte, ModRMByte + (0 & 7) - - .byte 0xC4, Payload0, \instr\(), 0x49, ModRMByte - - .endm - - - .macro ldtilecfgMacro BaseReg - - tilecfgMacro 0x78, \BaseReg\() - - .endm - - - .macro sttilecfgMacro BaseReg - - tilecfgMacro 0x79, \BaseReg\() - - .endm - diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S deleted file mode 100644 index 3066eadaf153c..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S +++ /dev/null @@ -1,782 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8S8KernelAvx2.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX2 instructions. - Support for AVX-VNNI-INT8 for certain code paths. - ---*/ - -#include "asmmacro.h" -#include "AssembleAvxVnni.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the Int8 CopyPackA routine. -// - - .equ .LGemmInt8CopyPackAFrame_PaddedMatrixAData, -72 - .equ .LGemmInt8CopyPackAFrame_Padding, -8 - .equ .LGemmInt8CopyPackAFrame_SavedR13, 0 - .equ .LGemmInt8CopyPackAFrame_SavedR12, 8 - .equ .LGemmInt8CopyPackAFrame_SavedRbx, 16 - .equ .LGemmInt8CopyPackAFrame_SavedRbp, 24 - .equ .LGemmInt8CopyPackAFrame_ReturnAddress, 32 - -// -// Stack frame layout for the Int8 CopyPackB routine. -// - - .equ .LGemmInt8CopyPackBFrame_PaddedMatrixBData, -72 - .equ .LGemmInt8CopyPackBFrame_Padding, -8 - .equ .LGemmInt8CopyPackBFrame_SavedRbx, 0 - .equ .LGemmInt8CopyPackBFrame_SavedRbp, 8 - .equ .LGemmInt8CopyPackBFrame_ReturnAddress, 16 - .equ .LGemmInt8CopyPackBFrame_BIsSigned, 24 - - .text - -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - A (rsi) - Supplies the address of the source matrix. - - lda (rdx) - Supplies the number of elements per row of the source matrix. - - CountM (rcx) - Supplies the number of rows of the source matrix to copy. - - CountK (r8) - Supplies the number of columns of the source matrix to copy. - - RowSumBuffer (r9) - Supplies the address of the buffer to receive the sums - of the elements along each of the rows. - by the zero point offset. - -Return Value: - - None. - ---*/ - -.macro MlasGemmCopyPackAAvx2 ASigned - - push rbp - push rbx - push r12 - push r13 - - mov r10,rdx - mov r11,rcx - lea r12,[r8+3] - and r12,NOT 3 # align CountK up to quad count - vpcmpeqw ymm8,ymm8,ymm8 # generate word vector [0xFFFF] - vpsrlw ymm8,ymm8,15 # generate word vector [0x0001] - vpsllw ymm9,ymm8,8 # generate word vector [0x0100] - vpor ymm9,ymm8,ymm9 # generate word vector [0x0101] - -// -// Compute the conditional load/store mask for an unaligned CountK. -// - - mov eax,r8d - and eax,15 # isolate unaligned count - add eax,3 - shr eax,2 # align unaligned count to quad count - neg rax - lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovdqu xmm10,XMMWORD PTR [rbx+rax*4] - -// -// Zero initialize the padded stack buffers. 
-// - - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0 - -// -// Process 4 rows of matrix A in a loop. -// - - sub r11,4 - jb .LCopyPackA.ProcessRemainingRows\@ - -.LCopyPackA.ProcessNextRowM4\@: - vpxor xmm0,xmm0,xmm0 # clear row accumulators - vpxor xmm1,xmm1,xmm1 - vpxor xmm2,xmm2,xmm2 - vpxor xmm3,xmm3,xmm3 - lea r13,[r10+r10*2] # compute ldb * 3 - lea rax,[r12+r12*2] # compute output stride * 3 - mov rdx,rsi - mov rcx,rdi - lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows - lea rdi,[rdi+r12*4] # advance next matrix D by 4 rows - mov rbx,r8 # reload columns remaining - sub rbx,32 - jb .LCopyPackA.ProcessRemainingColumnsM4\@ - -.LCopyPackA.ProcessNextColumnLoopM4\@: - vmovdqu ymm4,YMMWORD PTR [rdx] - vmovdqu ymm5,YMMWORD PTR [rdx+r10] - vmovdqu ymm6,YMMWORD PTR [rdx+r10*2] - vmovdqu ymm7,YMMWORD PTR [rdx+r13] - vmovdqu YMMWORD PTR [rcx],ymm4 - vmovdqu YMMWORD PTR [rcx+r12],ymm5 - vmovdqu YMMWORD PTR [rcx+r12*2],ymm6 - vmovdqu YMMWORD PTR [rcx+rax],ymm7 -.if \ASigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 - VpdpbssdYmmYmmYmm ymm1,ymm5,ymm9 - VpdpbssdYmmYmmYmm ymm2,ymm6,ymm9 - VpdpbssdYmmYmmYmm ymm3,ymm7,ymm9 -.else - vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 # add words to row accumulators - vpmaddubsw ymm5,ymm5,ymm9 - vpaddw ymm1,ymm1,ymm5 - vpmaddubsw ymm6,ymm6,ymm9 - vpaddw ymm2,ymm2,ymm6 - vpmaddubsw ymm7,ymm7,ymm9 - vpaddw ymm3,ymm3,ymm7 -.endif - add rdx,32 # advance matrix A by 32 bytes - add rcx,32 # advance matrix D by 32 bytes - sub rbx,32 # subtract columns remaining - jae .LCopyPackA.ProcessNextColumnLoopM4\@ - -.LCopyPackA.ProcessRemainingColumnsM4\@: - add rbx,32 # correct for over-subtract above - jz .LCopyPackA.ReduceRowSumBufferM4\@ - test bl,16 # (CountK & 16) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan16M4\@ - vmovdqu xmm4,XMMWORD PTR [rdx] - vmovdqu xmm5,XMMWORD PTR [rdx+r10] - vmovdqu xmm6,XMMWORD PTR [rdx+r10*2] - vmovdqu xmm7,XMMWORD PTR [rdx+r13] - vmovdqu XMMWORD PTR [rcx],xmm4 - vmovdqu XMMWORD PTR [rcx+r12],xmm5 - vmovdqu XMMWORD PTR [rcx+r12*2],xmm6 - vmovdqu XMMWORD PTR [rcx+rax],xmm7 -.if \ASigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 - VpdpbssdYmmYmmYmm ymm1,ymm5,ymm9 - VpdpbssdYmmYmmYmm ymm2,ymm6,ymm9 - VpdpbssdYmmYmmYmm ymm3,ymm7,ymm9 -.else - vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 # add words to row accumulators - vpmaddubsw xmm5,xmm5,xmm9 - vpaddw ymm1,ymm1,ymm5 - vpmaddubsw xmm6,xmm6,xmm9 - vpaddw ymm2,ymm2,ymm6 - vpmaddubsw xmm7,xmm7,xmm9 - vpaddw ymm3,ymm3,ymm7 -.endif - add rdx,16 # advance matrix A by 16 bytes - add rcx,16 # advance matrix D by 16 bytes - test bl,15 # test for unaligned columns - jz .LCopyPackA.ReduceRowSumBufferM4\@ - -// -// Copy the unaligned CountK columns to a zero padded stack buffer. -// - -.LCopyPackA.CopyRemainingCountKLessThan16M4\@: - lea rbp,.LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp] - test bl,8 # (CountK & 8) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan8M4\@ - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r10] - mov QWORD PTR [rbp+16],rax - mov rax,QWORD PTR [rdx+r10*2] - mov QWORD PTR [rbp+32],rax - mov rax,QWORD PTR [rdx+r13] - mov QWORD PTR [rbp+48],rax - add rdx,8 - add rbp,8 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan8M4\@: - test bl,4 # (CountK & 4) != 0? 
- jz .LCopyPackA.CopyRemainingCountKLessThan4M4\@ - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r10] - mov DWORD PTR [rbp+16],eax - mov eax,DWORD PTR [rdx+r10*2] - mov DWORD PTR [rbp+32],eax - mov eax,DWORD PTR [rdx+r13] - mov DWORD PTR [rbp+48],eax - add rdx,4 - add rbp,4 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan4M4\@: - test bl,2 # (CountK & 2) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan2M4\@ - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r10] - mov WORD PTR [rbp+16],ax - movzx eax,WORD PTR [rdx+r10*2] - mov WORD PTR [rbp+32],ax - movzx eax,WORD PTR [rdx+r13] - mov WORD PTR [rbp+48],ax - add rdx,2 - add rbp,2 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan2M4\@: - test bl,1 # (CountK & 1) != 0? - jz .LCopyPackA.ProcessPaddedMatrixADataM4\@ - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r10] - mov BYTE PTR [rbp+16],al - movzx eax,BYTE PTR [rdx+r10*2] - mov BYTE PTR [rbp+32],al - movzx eax,BYTE PTR [rdx+r13] - mov BYTE PTR [rbp+48],al - -// -// Process the remaining CountK columns using the zero padded stack buffer. -// - -.LCopyPackA.ProcessPaddedMatrixADataM4\@: - vmovdqu xmm4,XMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp] - vmovdqu xmm5,XMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp+16] - vmovdqu xmm6,XMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp+32] - vmovdqu xmm7,XMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp+48] - lea rax,[rcx+r12*2] # compute matrix D plus 2 rows - vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 - vpmaskmovd XMMWORD PTR [rcx+r12],xmm10,xmm5 - vpmaskmovd XMMWORD PTR [rax],xmm10,xmm6 - vpmaskmovd XMMWORD PTR [rax+r12],xmm10,xmm7 -.if \ASigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 - VpdpbssdYmmYmmYmm ymm1,ymm5,ymm9 - VpdpbssdYmmYmmYmm ymm2,ymm6,ymm9 - VpdpbssdYmmYmmYmm ymm3,ymm7,ymm9 -.else - vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 # add words to row accumulators - vpmaddubsw xmm5,xmm5,xmm9 - vpaddw ymm1,ymm1,ymm5 - vpmaddubsw xmm6,xmm6,xmm9 - vpaddw ymm2,ymm2,ymm6 - vpmaddubsw xmm7,xmm7,xmm9 - vpaddw ymm3,ymm3,ymm7 -.endif - -// -// Reduce the sums for the four rows of output. -// - -.LCopyPackA.ReduceRowSumBufferM4\@: -.if \ASigned\() == 1 - vphaddd ymm0,ymm0,ymm1 -.else - vpmaddwd ymm0,ymm0,ymm8 # horizontal word+word=dword per row - vpmaddwd ymm1,ymm1,ymm8 - vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 - vpmaddwd ymm2,ymm2,ymm8 - vpmaddwd ymm3,ymm3,ymm8 -.endif - vphaddd ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 - vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 - vextracti128 xmm1,ymm0,1 # extract high dwords - vpaddd xmm0,xmm0,xmm1 # reduce low/high dwords - vmovdqu XMMWORD PTR [r9],xmm0 - add r9,4*4 # advance row sum buffer by 4 dwords - sub r11,4 # subtract rows remaining - jae .LCopyPackA.ProcessNextRowM4\@ - -.LCopyPackA.ProcessRemainingRows\@: - add r11,4 # correct for over-subtract above - jz .LCopyPackA.ExitRoutine\@ - -// -// Process a single row of matrix A in a loop. 
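The unaligned CountK tail above is peeled off by testing one bit of the remainder at a time, largest chunk first, and copying into the zero-padded staging buffer. The same pattern in plain C, as a sketch only (the real code performs each test for four source rows at once):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Copy the last (remaining < 16) bytes of a row by size bits: 8, 4, 2, 1.
    static void CopyCountKTail(uint8_t* dst, const uint8_t* src, size_t remaining) {
        if (remaining & 8) { std::memcpy(dst, src, 8); dst += 8; src += 8; }
        if (remaining & 4) { std::memcpy(dst, src, 4); dst += 4; src += 4; }
        if (remaining & 2) { std::memcpy(dst, src, 2); dst += 2; src += 2; }
        if (remaining & 1) { *dst = *src; }
    }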
-// - -.LCopyPackA.ProcessNextRowM1\@: - vpxor xmm0,xmm0,xmm0 # clear row accumulator - mov rdx,rsi - mov rcx,rdi - add rsi,r10 - add rdi,r12 - mov rbx,r8 # reload columns remaining - sub rbx,32 - jb .LCopyPackA.ProcessRemainingColumnsM1\@ - -.LCopyPackA.ProcessNextColumnLoopM1\@: - vmovdqu ymm4,YMMWORD PTR [rdx] - vmovdqu YMMWORD PTR [rcx],ymm4 -.if \ASigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 -.else - vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 # add words to row accumulators -.endif - add rdx,32 # advance matrix A by 32 bytes - add rcx,32 # advance matrix D by 32 bytes - sub rbx,32 # subtract columns remaining - jae .LCopyPackA.ProcessNextColumnLoopM1\@ - -.LCopyPackA.ProcessRemainingColumnsM1\@: - add rbx,32 # correct for over-subtract above - jz .LCopyPackA.ReduceRowSumBufferM1\@ - test bl,16 # (CountK & 16) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan16M1\@ - vmovdqu xmm4,XMMWORD PTR [rdx] - vmovdqu XMMWORD PTR [rcx],xmm4 -.if \ASigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 -.else - vpmaddubsw xmm4,xmm4,xmm9 # horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 # add words to row accumulators -.endif - add rdx,16 # advance matrix A by 16 bytes - add rcx,16 # advance matrix D by 16 bytes - test bl,15 # test for unaligned columns - jz .LCopyPackA.ReduceRowSumBufferM1\@ - -// -// Copy the unaligned CountK columns to a zero padded stack buffer. -// - -.LCopyPackA.CopyRemainingCountKLessThan16M1\@: - lea rbp,.LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp] - test bl,8 # (CountK & 8) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan8M1\@ - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 - add rbp,8 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan8M1\@: - test bl,4 # (CountK & 4) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan4M1\@ - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 - add rbp,4 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan4M1\@: - test bl,2 # (CountK & 2) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan2M1\@ - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 - add rbp,2 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan2M1\@: - test bl,1 # (CountK & 1) != 0? - jz .LCopyPackA.ProcessPaddedMatrixADataM1\@ - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -// -// Process the remaining CountK columns using the zero padded stack buffer. -// - -.LCopyPackA.ProcessPaddedMatrixADataM1\@: - vmovdqu xmm4,XMMWORD PTR .LGemmInt8CopyPackAFrame_PaddedMatrixAData[rsp] - vpmaskmovd XMMWORD PTR [rcx],xmm10,xmm4 -.if \ASigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm9 -.else - vpmaddubsw ymm4,ymm4,ymm9 # horizontal byte+byte=word per row - vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns -.endif - -// -// Reduce the sum for the single row of output. -// - -.LCopyPackA.ReduceRowSumBufferM1\@: -.if \ASigned\() == 0 - vpmaddwd ymm0,ymm0,ymm8 # horizontal word+word=dword per row -.endif - vextracti128 xmm1,ymm0,1 # extract high dwords - vpaddd xmm0,xmm0,xmm1 # reduction - vphaddd xmm0,xmm0,xmm0 - vphaddd xmm0,xmm0,xmm0 - vmovd DWORD PTR [r9],xmm0 - add r9,4 # advance row sum buffer by 1 dword - dec r11 # decrement rows remaining - jnz .LCopyPackA.ProcessNextRowM1\@ - -// -// Restore non-volatile registers and return. 
-// - -.LCopyPackA.ExitRoutine\@: - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret -.endm - - FUNCTION_ENTRY MlasGemmU8S8CopyPackAAvx2 - MlasGemmCopyPackAAvx2 0 - - FUNCTION_ENTRY MlasGemmS8CopyPackAAvx2Vnni - MlasGemmCopyPackAAvx2 1 - -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - B (rsi) - Supplies the address of the source matrix. - - ldb (rdx) - Supplies the number of elements per row of the source matrix. - - CountN (rcx) - Supplies the number of columns of the source matrix to copy. - - CountK (r8) - Supplies the number of rows of the source matrix to copy. - - ColumnSumBuffer (r9) - Supplies the address of the buffer to receive the sums - of the elements along each of the columns. - - BIsSigned - Supplies true if the source matrix is signed data, else false if - the source matrix is unsigned data. - -Return Value: - - None. - ---*/ - -.macro MlasGemmCopyPackBAvx2 IsVnni, BSigned - - push rbp - push rbx - - mov r10,rdx - lea r11,[r10+r10*2] # compute ldb * 3 - vpcmpeqw ymm7,ymm7,ymm7 # generate word vector [0xFFFF] - vpsrlw ymm7,ymm7,15 # generate word vector [0x0001] - vpsllw ymm8,ymm7,8 # generate word vector [0x0100] - vpor ymm8,ymm7,ymm8 # generate word vector [0x0101] - -// -// Compute the bit flip vector to adjust input from U8 to S8. -// - - vpxor xmm9,xmm9,xmm9 # generate word vector [0x0000] -.if \IsVnni\() == 0 - cmp BYTE PTR .LGemmInt8CopyPackBFrame_BIsSigned[rsp],0 - jnz .LCopyPackB.SkipUnsignedBitFlipVector\@ - vpsllw ymm9,ymm8,7 # generate word vector [0x8080] -.endif -.LCopyPackB.SkipUnsignedBitFlipVector\@: - -// -// Process 16 columns of matrix B in a loop. 
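The routine description above covers the column-sum side of packing B. A scalar sketch of that bookkeeping for the VNNI entry points, where the data is stored unmodified, is below; the name and signature are illustrative, not MLAS API. The non-VNNI U8S8 entry additionally XORs unsigned data with 0x80 before storing and summing, so that vpmaddubsw can treat it as signed, with the 128 offset compensated when the B zero point is fixed up elsewhere in MLAS.

    #include <cstddef>
    #include <cstdint>

    // Sketch: per-column sums of a CountK x CountN slice of B. The real
    // routine also interleaves four source rows per 64-byte packed block.
    void ColumnSumsOfB(const uint8_t* B, size_t ldb, size_t CountK, size_t CountN,
                       int32_t* ColumnSumBuffer, bool BIsSigned) {
        for (size_t n = 0; n < CountN; ++n) {
            int32_t ColumnSum = 0;
            for (size_t k = 0; k < CountK; ++k) {
                const uint8_t v = B[k * ldb + n];
                ColumnSum += BIsSigned ? static_cast<int32_t>(static_cast<int8_t>(v))
                                       : static_cast<int32_t>(v);
            }
            ColumnSumBuffer[n] = ColumnSum;
        }
    }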
-// - - sub rcx,16 - jb .LCopyPackB.ProcessRemainingColumns\@ - -.LCopyPackB.ProcessNextColumnN16\@: - vpxor xmm0,xmm0,xmm0 # clear column accumulators - vpxor xmm1,xmm1,xmm1 - mov rdx,rsi - add rsi,16 # advance next matrix B by 16 columns - mov rbx,r8 # reload rows remaining - sub rbx,4 - jb .LCopyPackB.ProcessRemainingRowsN16\@ - -.LCopyPackB.ProcessNextRowLoopN16\@: - vmovdqu xmm2,XMMWORD PTR [rdx] # load 4 rows - vmovdqu xmm3,XMMWORD PTR [rdx+r10] - vmovdqu xmm4,XMMWORD PTR [rdx+r10*2] - vmovdqu xmm5,XMMWORD PTR [rdx+r11] - lea rdx,[rdx+r10*4] # advance matrix B by 4 rows - -.LCopyPackB.InterleaveRowDataN16\@: - vpunpcklbw xmm6,xmm2,xmm3 # interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vpunpcklbw xmm2,xmm4,xmm5 - vpunpckhbw xmm5,xmm4,xmm5 - vpunpcklwd xmm4,xmm6,xmm2 - vpunpckhwd xmm6,xmm6,xmm2 - vpunpcklwd xmm2,xmm3,xmm5 - vpunpckhwd xmm3,xmm3,xmm5 - vinserti128 ymm4,ymm4,xmm6,1 - vinserti128 ymm2,ymm2,xmm3,1 -.if \IsVnni\() == 0 - vpxor ymm4,ymm4,ymm9 # optionally adjust unsigned data - vpxor ymm2,ymm2,ymm9 -.endif - vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows - vmovdqu YMMWORD PTR [rdi+32],ymm2 -.if \IsVnni\() == 1 - .if \BSigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbssdYmmYmmYmm ymm1,ymm2,ymm8 - .else - VpdpbuudYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbuudYmmYmmYmm ymm1,ymm2,ymm8 - .endif -.else - vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row - vpmaddwd ymm4,ymm4,ymm7 # horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 # accumulate per column - vpmaddubsw ymm2,ymm8,ymm2 - vpmaddwd ymm2,ymm2,ymm7 - vpaddd ymm1,ymm1,ymm2 -.endif - add rdi,64 # advance matrix D by 64 bytes - sub rbx,4 # subtract rows remaining - jae .LCopyPackB.ProcessNextRowLoopN16\@ - -// -// Process the less than 4 remaining rows where the row has 16 columns. -// - -.LCopyPackB.ProcessRemainingRowsN16\@: - add rbx,4 # correct for over-subtract above - jz .LCopyPackB.StoreColumnSumBufferN16\@ - vmovdqu xmm2,XMMWORD PTR [rdx] - vmovaps xmm3,xmm9 - vmovaps xmm4,xmm9 - vmovaps xmm5,xmm9 - xor ebx,ebx # no more rows remaining - test r8b,2 # (CountK & 2) != 0? - jz .LCopyPackB.InterleaveRowDataN16\@ - vmovdqu xmm3,XMMWORD PTR [rdx+r10] - test r8b,1 # (CountK & 1) != 0? - jz .LCopyPackB.InterleaveRowDataN16\@ - vmovdqu xmm4,XMMWORD PTR [rdx+r10*2] - jmp .LCopyPackB.InterleaveRowDataN16\@ - -.LCopyPackB.StoreColumnSumBufferN16\@: - vmovdqu YMMWORD PTR [r9],ymm0 - vmovdqu YMMWORD PTR [r9+32],ymm1 - add r9,16*4 # advance column sum buffer by 16 dwords - sub rcx,16 # subtract columns remaining - jae .LCopyPackB.ProcessNextColumnN16\@ - -.LCopyPackB.ProcessRemainingColumns\@: - add rcx,16 # correct for over-subtract above - jnz .LCopyPackB.ProcessColumnNUnaligned\@ - -// -// Restore non-volatile registers and return. -// - -.LCopyPackB.ExitRoutine\@: - vzeroupper - - pop rbx - pop rbp - ret - -// -// Process the remaining columns of matrix B. -// - -.LCopyPackB.ProcessColumnNUnaligned\@: - vpxor xmm0,xmm0,xmm0 # clear column accumulators - vpxor xmm1,xmm1,xmm1 - vmovdqu YMMWORD PTR .LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp],ymm9 - vmovdqu YMMWORD PTR .LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp+32],ymm9 - sub r8,4 - jb .LCopyPackB.ProcessRemainingRowsNUnaligned\@ - -.LCopyPackB.ProcessNextRowLoopNUnaligned\@: - mov rdx,rsi - lea rbp,.LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp] - test cl,8 # (CountN & 8) != 0? 
- jz .LCopyPackB.CopyRemainingCountNLessThan8K4\@ - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r10] - mov QWORD PTR [rbp+16],rax - mov rax,QWORD PTR [rdx+r10*2] - mov QWORD PTR [rbp+32],rax - mov rax,QWORD PTR [rdx+r11] - mov QWORD PTR [rbp+48],rax - add rdx,8 # advance matrix B - add rbp,8 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan8K4\@: - test cl,4 # (CountN & 4) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan4K4\@ - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r10] - mov DWORD PTR [rbp+16],eax - mov eax,DWORD PTR [rdx+r10*2] - mov DWORD PTR [rbp+32],eax - mov eax,DWORD PTR [rdx+r11] - mov DWORD PTR [rbp+48],eax - add rdx,4 # advance matrix B - add rbp,4 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan4K4\@: - test cl,2 # (CountN & 2) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan2K4\@ - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r10] - mov WORD PTR [rbp+16],ax - movzx eax,WORD PTR [rdx+r10*2] - mov WORD PTR [rbp+32],ax - movzx eax,WORD PTR [rdx+r11] - mov WORD PTR [rbp+48],ax - add rdx,2 # advance matrix B - add rbp,2 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan2K4\@: - test cl,1 # (CountN & 1) != 0? - jz .LCopyPackB.ProcessPaddedMatrixBData\@ - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r10] - mov BYTE PTR [rbp+16],al - movzx eax,BYTE PTR [rdx+r10*2] - mov BYTE PTR [rbp+32],al - movzx eax,BYTE PTR [rdx+r11] - mov BYTE PTR [rbp+48],al - -.LCopyPackB.ProcessPaddedMatrixBData\@: - vmovdqu xmm2,XMMWORD PTR .LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp] - vmovdqu xmm3,XMMWORD PTR .LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp+16] - vmovdqu xmm4,XMMWORD PTR .LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp+32] - vmovdqu xmm5,XMMWORD PTR .LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp+48] - vpunpcklbw xmm6,xmm2,xmm3 # interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vpunpcklbw xmm2,xmm4,xmm5 - vpunpckhbw xmm5,xmm4,xmm5 - vpunpcklwd xmm4,xmm6,xmm2 - vpunpckhwd xmm6,xmm6,xmm2 - vpunpcklwd xmm2,xmm3,xmm5 - vpunpckhwd xmm3,xmm3,xmm5 - vinserti128 ymm4,ymm4,xmm6,1 - vinserti128 ymm2,ymm2,xmm3,1 -.if \IsVnni\() == 0 - vpxor ymm4,ymm4,ymm9 # optionally adjust unsigned data - vpxor ymm2,ymm2,ymm9 -.endif - vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows - vmovdqu YMMWORD PTR [rdi+32],ymm2 -.if \IsVnni\() == 1 - .if \BSigned\() == 1 - VpdpbssdYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbssdYmmYmmYmm ymm1,ymm2,ymm8 - .else - VpdpbuudYmmYmmYmm ymm0,ymm4,ymm8 - VpdpbuudYmmYmmYmm ymm1,ymm2,ymm8 - .endif -.else - vpmaddubsw ymm4,ymm8,ymm4 # horizontal byte+byte=word per row - vpmaddwd ymm4,ymm4,ymm7 # horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 # accumulate per column - vpmaddubsw ymm2,ymm8,ymm2 - vpmaddwd ymm2,ymm2,ymm7 - vpaddd ymm1,ymm1,ymm2 -.endif - lea rsi,[rsi+r10*4] # advance next matrix B by 4 rows - add rdi,64 # advance matrix D by 64 bytes - sub r8,4 # subtract rows remaining - jae .LCopyPackB.ProcessNextRowLoopNUnaligned\@ - -.LCopyPackB.ProcessRemainingRowsNUnaligned\@: - add r8,4 - jz .LCopyPackB.StoreColumnSumBufferNUnaligned\@ - -// -// Process the less than 4 remaining rows where the row has less than 16 columns. 
-// - - lea rbp,.LGemmInt8CopyPackBFrame_PaddedMatrixBData[rsp] - vmovdqu YMMWORD PTR [rbp],ymm9 - vmovdqu YMMWORD PTR [rbp+32],ymm9 - -.LCopyPackB.CopyUnalignedRowLoop\@: - lea r11,[rbp+16] # advance next padded buffer by 16 bytes - mov rdx,rsi - test cl,8 # (CountN & 8) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan8KSmall\@ - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 # advance matrix B - add rbp,8 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan8KSmall\@: - test cl,4 # (CountN & 4) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan4KSmall\@ - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 # advance matrix B - add rbp,4 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan4KSmall\@: - test cl,2 # (CountN & 2) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan2KSmall\@ - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 # advance matrix B - add rbp,2 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan2KSmall\@: - test cl,1 # (CountN & 1) != 0? - jz .LCopyPackB.DoneCopyRemainingCountNKSmall\@ - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -.LCopyPackB.DoneCopyRemainingCountNKSmall\@: - dec r8 - jz .LCopyPackB.ProcessPaddedMatrixBData\@ - add rsi,r10 # advance next matrix B by 1 row - mov rbp,r11 - jmp .LCopyPackB.CopyUnalignedRowLoop\@ - -.LCopyPackB.StoreColumnSumBufferNUnaligned\@: - vmovdqu YMMWORD PTR [r9],ymm0 - vmovdqu YMMWORD PTR [r9+32],ymm1 - jmp .LCopyPackB.ExitRoutine\@ - -.endm - - FUNCTION_ENTRY MlasGemmU8S8CopyPackBAvx2 - MlasGemmCopyPackBAvx2 0, 0 # sign variable not checked if IsVnni = 0 - - FUNCTION_ENTRY MlasGemmU8CopyPackBAvx2Vnni - MlasGemmCopyPackBAvx2 1, 0 - - FUNCTION_ENTRY MlasGemmS8CopyPackBAvx2Vnni - MlasGemmCopyPackBAvx2 1, 1 - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S deleted file mode 100644 index 2bdef12aebf22..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S +++ /dev/null @@ -1,599 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8U8KernelAvx2.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX2 instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the U8U8 CopyPackA routine. -// - - .equ .LGemmU8U8CopyPackAFrame_PaddedMatrixAData, -72 - .equ .LGemmU8U8CopyPackAFrame_Padding, -8 - .equ .LGemmU8U8CopyPackAFrame_SavedR13, 0 - .equ .LGemmU8U8CopyPackAFrame_SavedR12, 8 - .equ .LGemmU8U8CopyPackAFrame_SavedRbx, 16 - .equ .LGemmU8U8CopyPackAFrame_SavedRbp, 24 - .equ .LGemmU8U8CopyPackAFrame_ReturnAddress, 32 - -// -// Stack frame layout for the U8U8 CopyPackB routine. -// - - .equ .LGemmU8U8CopyPackBFrame_PaddedMatrixBData, -40 - .equ .LGemmU8U8CopyPackBFrame_Padding, -8 - .equ .LGemmU8U8CopyPackBFrame_SavedRbx, 0 - .equ .LGemmU8U8CopyPackBFrame_SavedRbp, 8 - .equ .LGemmU8U8CopyPackBFrame_ReturnAddress, 16 - - .text - -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - - The kernel expects that elements from matrix A have been zero extended to - 16-bits and padded to a multiple of 32-bits (two pairs of 16-bit values). 
- The kernel can then efficiently broadcast 32-bits from the packed buffer - and avoid expensive shuffling inside the kernel. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - A (rsi) - Supplies the address of the source matrix. - - lda (rdx) - Supplies the number of elements per row of the source matrix. - - CountM (rcx) - Supplies the number of rows of the source matrix to copy. - - CountK (r8) - Supplies the number of columns of the source matrix to copy. - - RowSumBuffer (r9) - Supplies the address of the buffer to receive the sums - of the elements along each of the rows. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasGemmU8U8CopyPackAAvx2 - - push rbp - push rbx - push r12 - push r13 - - mov r10,rdx - mov r11,rcx - lea r12,[r8+1] - and r12,NOT 1 # align CountK up to pair count - vpcmpeqw ymm8,ymm8,ymm8 # generate word vector [0xFFFF] - vpsrlw ymm8,ymm8,15 # generate word vector [0x0001] - -// -// Compute the conditional load/store mask for an unaligned CountK. -// - - mov eax,r8d - and eax,15 # isolate unaligned count - inc eax - shr eax,1 # align unaligned count to pair count - neg rax - lea rbx,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovdqu ymm9,YMMWORD PTR [rbx+rax*4] - -// -// Zero initialize the padded stack buffers. -// - - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp],ymm0 - vmovdqu YMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32],ymm0 - -// -// Process 4 rows of matrix A in a loop. -// -// Zero extend the source bytes to 16-bits and write to the packed buffer. -// -// The packed buffer has the same data ordering as the source bytes, but CountK -// is aligned up to a multiple of 2 to maintain 32-bit alignment. All padding -// bytes are zero filled. -// - - sub r11,4 - jb .LCopyPackA.ProcessRemainingRows - -.LCopyPackA.ProcessNextRowM4: - vpxor xmm0,xmm0,xmm0 # clear row accumulators - vpxor xmm1,xmm1,xmm1 - vpxor xmm2,xmm2,xmm2 - vpxor xmm3,xmm3,xmm3 - mov rdx,rsi - mov rcx,rdi - lea rsi,[rsi+r10*4] # advance next matrix A by 4 rows - lea rdi,[rdi+r12*8] # advance next matrix D by 4 rows - mov rbx,r8 # reload columns remaining - sub rbx,16 - jb .LCopyPackA.ProcessRemainingColumnsM4 - -.LCopyPackA.ProcessNextColumnLoopM4: - lea rax,[rdx+r10*2] # compute matrix A plus 2 rows - vpmovzxbw ymm4,XMMWORD PTR [rdx] - vpmovzxbw ymm5,XMMWORD PTR [rdx+r10] - vpmovzxbw ymm6,XMMWORD PTR [rax] - vpmovzxbw ymm7,XMMWORD PTR [rax+r10] - lea rax,[rcx+r12*4] # compute matrix D plus 2 rows - vmovdqu YMMWORD PTR [rcx],ymm4 - vmovdqu YMMWORD PTR [rcx+r12*2],ymm5 - vmovdqu YMMWORD PTR [rax],ymm6 - vmovdqu YMMWORD PTR [rax+r12*2],ymm7 - vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns - vpaddw ymm1,ymm1,ymm5 - vpaddw ymm2,ymm2,ymm6 - vpaddw ymm3,ymm3,ymm7 - add rdx,16 # advance matrix A by 16 bytes - add rcx,16*2 # advance matrix D by 16 words - sub rbx,16 # subtract columns remaining - jae .LCopyPackA.ProcessNextColumnLoopM4 - -.LCopyPackA.ProcessRemainingColumnsM4: - add rbx,16 # correct for over-subtract above - jz .LCopyPackA.ReduceRowSumBufferM4 - -// -// Copy the unaligned CountK columns to a zero padded stack buffer. -// - - lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] - test bl,8 # (CountK & 8) != 0? 
- jz .LCopyPackA.CopyRemainingCountKLessThan8M4 - lea r13,[rdx+r10*2] # compute matrix A plus 2 rows - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r10] - mov QWORD PTR [rbp+16],rax - mov rax,QWORD PTR [r13] - mov QWORD PTR [rbp+32],rax - mov rax,QWORD PTR [r13+r10] - mov QWORD PTR [rbp+48],rax - add rdx,8 - add rbp,8 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan8M4: - test bl,4 # (CountK & 4) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan4M4 - lea r13,[rdx+r10*2] # compute matrix A plus 2 rows - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r10] - mov DWORD PTR [rbp+16],eax - mov eax,DWORD PTR [r13] - mov DWORD PTR [rbp+32],eax - mov eax,DWORD PTR [r13+r10] - mov DWORD PTR [rbp+48],eax - add rdx,4 - add rbp,4 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan4M4: - test bl,2 # (CountK & 2) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan2M4 - lea r13,[rdx+r10*2] # compute matrix A plus 2 rows - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r10] - mov WORD PTR [rbp+16],ax - movzx eax,WORD PTR [r13] - mov WORD PTR [rbp+32],ax - movzx eax,WORD PTR [r13+r10] - mov WORD PTR [rbp+48],ax - add rdx,2 - add rbp,2 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan2M4: - test bl,1 # (CountK & 1) != 0? - jz .LCopyPackA.ProcessPaddedMatrixADataM4 - lea r13,[rdx+r10*2] # compute matrix A plus 2 rows - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r10] - mov BYTE PTR [rbp+16],al - movzx eax,BYTE PTR [r13] - mov BYTE PTR [rbp+32],al - movzx eax,BYTE PTR [r13+r10] - mov BYTE PTR [rbp+48],al - -// -// Process the remaining CountK columns using the zero padded stack buffer. -// - -.LCopyPackA.ProcessPaddedMatrixADataM4: - vpmovzxbw ymm4,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] - vpmovzxbw ymm5,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+16] - vpmovzxbw ymm6,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+32] - vpmovzxbw ymm7,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp+48] - lea rax,[rcx+r12*4] # compute matrix D plus 2 rows - vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 - vpmaskmovd YMMWORD PTR [rcx+r12*2],ymm9,ymm5 - vpmaskmovd YMMWORD PTR [rax],ymm9,ymm6 - vpmaskmovd YMMWORD PTR [rax+r12*2],ymm9,ymm7 - vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns - vpaddw ymm1,ymm1,ymm5 - vpaddw ymm2,ymm2,ymm6 - vpaddw ymm3,ymm3,ymm7 - -// -// Reduce the sums for the four rows of output. -// - -.LCopyPackA.ReduceRowSumBufferM4: - vpmaddwd ymm0,ymm0,ymm8 # horizontal word+word=dword per row - vpmaddwd ymm1,ymm1,ymm8 - vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 - vpmaddwd ymm2,ymm2,ymm8 - vpmaddwd ymm3,ymm3,ymm8 - vphaddd ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 - vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 - vextracti128 xmm1,ymm0,1 # extract high dwords - vpaddd xmm0,xmm0,xmm1 # reduce low/high dwords - vmovdqu XMMWORD PTR [r9],xmm0 - add r9,4*4 # advance row sum buffer by 4 dwords - sub r11,4 # subtract rows remaining - jae .LCopyPackA.ProcessNextRowM4 - -.LCopyPackA.ProcessRemainingRows: - add r11,4 # correct for over-subtract above - jz .LCopyPackA.ExitRoutine - -// -// Process a single row of matrix A in a loop. 
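The U8U8 packing described above widens A to 16 bits and pads each row to an even element count so the kernel can broadcast aligned 32-bit pairs without shuffling. A scalar sketch of that layout, with illustrative (non-MLAS) names:

    #include <cstddef>
    #include <cstdint>

    // Sketch: zero-extend each byte of A to 16 bits, pad rows to a pair count,
    // and record per-row sums of the unsigned source bytes.
    void CopyPackA_U8U8(uint16_t* D, const uint8_t* A, size_t lda,
                        size_t CountM, size_t CountK, int32_t* RowSumBuffer) {
        const size_t PairedK = (CountK + 1) & ~static_cast<size_t>(1);
        for (size_t m = 0; m < CountM; ++m) {
            int32_t RowSum = 0;
            for (size_t k = 0; k < CountK; ++k) {
                D[m * PairedK + k] = A[m * lda + k];   // zero-extend u8 -> u16
                RowSum += A[m * lda + k];
            }
            if (CountK & 1) {
                D[m * PairedK + CountK] = 0;           // zero-fill the padding element
            }
            RowSumBuffer[m] = RowSum;
        }
    }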
-// - -.LCopyPackA.ProcessNextRowM1: - vpxor xmm0,xmm0,xmm0 # clear row accumulator - mov rdx,rsi - mov rcx,rdi - add rsi,r10 - lea rdi,[rdi+r12*2] - mov rbx,r8 # reload columns remaining - sub rbx,16 - jb .LCopyPackA.ProcessRemainingColumnsM1 - -.LCopyPackA.ProcessNextColumnLoopM1: - vpmovzxbw ymm4,XMMWORD PTR [rdx] - vmovdqu YMMWORD PTR [rcx],ymm4 - vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns - add rdx,16 # advance matrix A by 16 bytes - add rcx,16*2 # advance matrix D by 16 words - sub rbx,16 # subtract columns remaining - jae .LCopyPackA.ProcessNextColumnLoopM1 - -.LCopyPackA.ProcessRemainingColumnsM1: - add rbx,16 # correct for over-subtract above - jz .LCopyPackA.ReduceRowSumBufferM1 - -// -// Copy the unaligned CountK columns to a zero padded stack buffer. -// - - lea rbp,.LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] - test bl,8 # (CountK & 8) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan8M1 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 - add rbp,8 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan8M1: - test bl,4 # (CountK & 4) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan4M1 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 - add rbp,4 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan4M1: - test bl,2 # (CountK & 2) != 0? - jz .LCopyPackA.CopyRemainingCountKLessThan2M1 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 - add rbp,2 # advance padded buffer destination - -.LCopyPackA.CopyRemainingCountKLessThan2M1: - test bl,1 # (CountK & 1) != 0? - jz .LCopyPackA.ProcessPaddedMatrixADataM1 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -// -// Process the remaining CountK columns using the zero padded stack buffer. -// - -.LCopyPackA.ProcessPaddedMatrixADataM1: - vpmovzxbw ymm4,XMMWORD PTR .LGemmU8U8CopyPackAFrame_PaddedMatrixAData[rsp] - vpmaskmovd YMMWORD PTR [rcx],ymm9,ymm4 - vpaddw ymm0,ymm0,ymm4 # accumulate per row along columns - -// -// Reduce the sum for the single row of output. -// - -.LCopyPackA.ReduceRowSumBufferM1: - vpmaddwd ymm0,ymm0,ymm8 # horizontal word+word=dword per row - vextracti128 xmm1,ymm0,1 # extract high dwords - vpaddd xmm0,xmm0,xmm1 # reduction - vphaddd xmm0,xmm0,xmm0 - vphaddd xmm0,xmm0,xmm0 - vmovd DWORD PTR [r9],xmm0 - add r9,4 # advance row sum buffer by 1 dword - dec r11 # decrement rows remaining - jnz .LCopyPackA.ProcessNextRowM1 - -// -// Restore non-volatile registers and return. -// - -.LCopyPackA.ExitRoutine: - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret - -/*++ - -Routine Description: - - This routine copies elements from the source matrix to the destination - packed buffer. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - B (rsi) - Supplies the address of the source matrix. - - ldb (rdx) - Supplies the number of elements per row of the source matrix. - - CountN (rcx) - Supplies the number of columns of the source matrix to copy. - - CountK (r8) - Supplies the number of rows of the source matrix to copy. - - ColumnSumBuffer (r9) - Supplies the address of the buffer to receive the sums - of the elements along each of the columns. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasGemmU8U8CopyPackBAvx2 - - push rbp - push rbx - - mov r10,rdx - vpcmpeqw ymm5,ymm5,ymm5 # generate word vector [0xFFFF] - vpsrlw ymm5,ymm5,15 # generate word vector [0x0001] - -// -// Zero initialize the padded stack buffers. 
-// - - vpxor xmm0,xmm0,xmm0 - vmovdqu YMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp],ymm0 - -// -// Process 16 columns of matrix B in a loop. -// - - sub rcx,16 - jb .LCopyPackB.ProcessRemainingColumns - -.LCopyPackB.ProcessNextColumnN16: - vpxor xmm0,xmm0,xmm0 # clear column accumulators - vpxor xmm1,xmm1,xmm1 - mov rdx,rsi - add rsi,16 # advance next matrix B by 16 columns - mov rbx,r8 # reload rows remaining - sub rbx,2 - jb .LCopyPackB.ProcessRemainingRowsN16 - -.LCopyPackB.ProcessNextRowLoopN16: - vmovdqu xmm2,XMMWORD PTR [rdx] # load 2 rows - vmovdqu xmm3,XMMWORD PTR [rdx+r10] - lea rdx,[rdx+r10*2] # advance matrix B by 2 rows - vpunpcklbw xmm4,xmm2,xmm3 # interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows - vmovdqu XMMWORD PTR [rdi+16],xmm3 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - add rdi,32 # advance matrix D by 32 bytes - vpmaddwd ymm4,ymm4,ymm5 # horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 # accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - sub rbx,2 # subtract rows remaining - jae .LCopyPackB.ProcessNextRowLoopN16 - -.LCopyPackB.ProcessRemainingRowsN16: - add rbx,2 # correct for over-subtract above - jz .LCopyPackB.StoreColumnSumBufferN16 - vpmovzxbw ymm4,XMMWORD PTR [rdx] - vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows - vextracti128 xmm3,ymm4,1 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - vpmaddwd ymm4,ymm4,ymm5 # horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 # accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - add rdi,32 # advance matrix D by 32 bytes - -.LCopyPackB.StoreColumnSumBufferN16: - vmovdqu YMMWORD PTR [r9],ymm0 - vmovdqu YMMWORD PTR [r9+32],ymm1 - add r9,64 # advance column sum buffer by 16 dwords - sub rcx,16 # subtract columns remaining - jae .LCopyPackB.ProcessNextColumnN16 - -.LCopyPackB.ProcessRemainingColumns: - add rcx,16 # correct for over-subtract above - jnz .LCopyPackB.ProcessColumnNUnaligned - -// -// Restore non-volatile registers and return. -// - -.LCopyPackB.ExitRoutine: - vzeroupper - - pop rbx - pop rbp - ret - -// -// Process the remaining columns of matrix B. -// - -.LCopyPackB.ProcessColumnNUnaligned: - vpxor xmm0,xmm0,xmm0 # clear column accumulators - vpxor xmm1,xmm1,xmm1 - sub r8,2 - jb .LCopyPackB.ProcessRemainingRowsNUnaligned - -.LCopyPackB.ProcessNextRowLoopNUnaligned: - mov rdx,rsi - lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] - test cl,8 # (CountN & 8) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan8K2 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - mov rax,QWORD PTR [rdx+r10] - mov QWORD PTR [rbp+16],rax - add rdx,8 # advance matrix B - add rbp,8 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan8K2: - test cl,4 # (CountN & 4) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan4K2 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - mov eax,DWORD PTR [rdx+r10] - mov DWORD PTR [rbp+16],eax - add rdx,4 # advance matrix B - add rbp,4 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan4K2: - test cl,2 # (CountN & 2) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan2K2 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - movzx eax,WORD PTR [rdx+r10] - mov WORD PTR [rbp+16],ax - add rdx,2 # advance matrix B - add rbp,2 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan2K2: - test cl,1 # (CountN & 1) != 0? 
- jz .LCopyPackB.ProcessPaddedMatrixBDataK2 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - movzx eax,BYTE PTR [rdx+r10] - mov BYTE PTR [rbp+16],al - -.LCopyPackB.ProcessPaddedMatrixBDataK2: - vmovdqu xmm2,XMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] - vmovdqu xmm3,XMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp+16] - vpunpcklbw xmm4,xmm2,xmm3 # interleave row data - vpunpckhbw xmm3,xmm2,xmm3 - vmovdqu XMMWORD PTR [rdi],xmm4 # store interleaved rows - vmovdqu XMMWORD PTR [rdi+16],xmm3 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - vpmaddwd ymm4,ymm4,ymm5 # horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 # accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - lea rsi,[rsi+r10*2] # advance next matrix B by 2 rows - add rdi,32 # advance matrix D by 32 bytes - sub r8,2 # subtract rows remaining - jae .LCopyPackB.ProcessNextRowLoopNUnaligned - -.LCopyPackB.ProcessRemainingRowsNUnaligned: - add r8,2 - jz .LCopyPackB.StoreColumnSumBufferNUnaligned - mov rdx,rsi - lea rbp,.LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] - test cl,8 # (CountN & 8) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan8K1 - mov rax,QWORD PTR [rdx] - mov QWORD PTR [rbp],rax - add rdx,8 # advance matrix B - add rbp,8 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan8K1: - test cl,4 # (CountN & 4) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan4K1 - mov eax,DWORD PTR [rdx] - mov DWORD PTR [rbp],eax - add rdx,4 # advance matrix B - add rbp,4 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan4K1: - test cl,2 # (CountN & 2) != 0? - jz .LCopyPackB.CopyRemainingCountNLessThan2K1 - movzx eax,WORD PTR [rdx] - mov WORD PTR [rbp],ax - add rdx,2 # advance matrix B - add rbp,2 # advance padded buffer destination - -.LCopyPackB.CopyRemainingCountNLessThan2K1: - test cl,1 # (CountN & 1) != 0? - jz .LCopyPackB.ProcessPaddedMatrixBDataK1 - movzx eax,BYTE PTR [rdx] - mov BYTE PTR [rbp],al - -.LCopyPackB.ProcessPaddedMatrixBDataK1: - vpmovzxbw ymm4,XMMWORD PTR .LGemmU8U8CopyPackBFrame_PaddedMatrixBData[rsp] - vmovdqu YMMWORD PTR [rdi],ymm4 # store interleaved rows - vextracti128 xmm3,ymm4,1 - vpmovzxbw ymm4,xmm4 - vpmovzxbw ymm3,xmm3 - vpmaddwd ymm4,ymm4,ymm5 # horizontal word+word=dword per row - vpaddd ymm0,ymm0,ymm4 # accumulate per column - vpmaddwd ymm3,ymm3,ymm5 - vpaddd ymm1,ymm1,ymm3 - -.LCopyPackB.StoreColumnSumBufferNUnaligned: - vmovdqu YMMWORD PTR [r9],ymm0 - vmovdqu YMMWORD PTR [r9+32],ymm1 - jmp .LCopyPackB.ExitRoutine - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S deleted file mode 100644 index af2a475ea0c59..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S +++ /dev/null @@ -1,934 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelAvx2.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX2 and AVX VNNI instructions. - AVX-VNNI-INT8 support also included. - ---*/ - -#include "asmmacro.h" -#include "AssembleAvxVnni.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the Int8 kernel. 
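The row sums and column sums produced by the packing routines above exist so that the kernels in this module can apply the quantization zero points without revisiting A or B. The underlying identity, written as a reference function (a sketch only; MLAS folds the correction terms into RowSumBuffer and ColumnSumBuffer before the kernel runs, with the exact sign conventions handled by the calling code):

    #include <cstddef>
    #include <cstdint>

    // For one output element:
    //   sum_k (a[k] - za) * (b[k] - zb)
    //     = sum_k a[k]*b[k] - zb * sum_k a[k] - za * sum_k b[k] + K * za * zb
    // so only the raw dot product, a per-row sum of A, and a per-column sum of
    // B are needed to apply both zero points.
    int32_t QuantizedDotReference(const uint8_t* a, const int8_t* b, size_t K,
                                  int32_t za, int32_t zb) {
        int32_t raw = 0, rowSum = 0, colSum = 0;
        for (size_t k = 0; k < K; ++k) {
            raw += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
            rowSum += a[k];
            colSum += b[k];
        }
        return raw - zb * rowSum - za * colSum + static_cast<int32_t>(K) * za * zb;
    }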
-// - - .equ .LGemmInt8KernelFrame_type, -8 - .equ .LGemmInt8KernelFrame_SavedR13, 0 - .equ .LGemmInt8KernelFrame_SavedR12, 8 - .equ .LGemmInt8KernelFrame_SavedRbx, 16 - .equ .LGemmInt8KernelFrame_SavedRbp, 24 - .equ .LGemmInt8KernelFrame_ReturnAddress, 32 - .equ .LGemmInt8KernelFrame_ldc, 40 - .equ .LGemmInt8KernelFrame_RowSumBuffer, 48 - .equ .LGemmInt8KernelFrame_ColumnSumBuffer, 56 - .equ .LGemmInt8KernelFrame_ZeroPointB, 64 - .equ .LGemmInt8KernelFrame_ZeroMode, 72 - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single row of the - output block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 16). - - Vec2Reg - Supplies the low block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from matrix B. - - ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount - is 16). - - ymm2 - Supplies the broadcast value loaded from matrix A. - - ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. - ---*/ - - .macro MultiplyAccumulateRowU8S8Avx2 ColumnCount, Vec1Reg, Vec2Reg - - vpmaddubsw ymm3,ymm2,ymm0 - vpmaddwd ymm3,ymm3,ymm12 -.if \ColumnCount\() == 16 - vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 - vpmaddubsw ymm2,ymm2,ymm1 - vpmaddwd ymm2,ymm2,ymm12 - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 -.else - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - r8 - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm11 - Supplies the block accumulators. - - ymm12 - Supplies a 256-bit with the broadcasted word value 0x0001. 
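This macro is the AVX2 stand-in for a fused byte dot product: vpmaddubsw multiplies u8 x s8 byte pairs into saturating 16-bit sums, and vpmaddwd against the [0x0001] word vector in ymm12 folds those into 32-bit lanes, which is what a single AVX-VNNI vpdpbusd does without the intermediate saturation. What one accumulator lane receives per step, as a scalar sketch (not MLAS code):

    #include <cstdint>

    // One 32-bit lane per inner step on the non-VNNI path: four unsigned A
    // bytes dotted with four signed B bytes. Note the real vpmaddubsw
    // saturates its 16-bit intermediates, which the VNNI forms avoid.
    int32_t Dot4_U8S8(const uint8_t a[4], const int8_t b[4]) {
        const int32_t w0 = static_cast<int32_t>(a[0]) * b[0] + static_cast<int32_t>(a[1]) * b[1];  // vpmaddubsw pair 0
        const int32_t w1 = static_cast<int32_t>(a[2]) * b[2] + static_cast<int32_t>(a[3]) * b[3];  // vpmaddubsw pair 1
        return w0 + w1;                                                                            // vpmaddwd with [1,1]
    }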
- ---*/ - - .macro ComputeBlockAvx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset, ASigned, BSigned - -.if \RowCount\() == 1 - vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()] - vpmaddubsw ymm3,ymm2,YMMWORD PTR [rsi+\VectorOffset\()] - vpmaddwd ymm3,ymm3,ymm12 -.if \ColumnCount\() == 16 - vpaddd ymm4,ymm4,ymm3 - vpmaddubsw ymm2,ymm2,YMMWORD PTR [rsi+\VectorOffset\()+32] - vpmaddwd ymm2,ymm2,ymm12 - vpaddd ymm5,ymm5,ymm2 -.else - vpaddd ymm5,ymm5,ymm3 -.endif -.else - vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm4, ymm5" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm6, ymm7" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm8, ymm9" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8S8Avx2 \ColumnCount\(), ymm10, ymm11" -.endif - - .endm - -/*++ -Macro Description: - - This macro generates code to multiply and accumulator a single row of the - output block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 16). - - Vec2Reg - Supplies the low block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from matrix B. - - ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount - is 16). - - ymm2 - Supplies the broadcast value loaded from matrix A. - ---*/ - - .macro MultiplyAccumulateRowAvxVnni ColumnCount, Vec1Reg, Vec2Reg, ASigned, BSigned - -.if \ASigned\() == 1 - .if \BSigned\() == 1 - .if \ColumnCount\() == 16 - VpdpbssdYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbssdYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbssdYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .else - .if \ColumnCount\() == 16 - VpdpbsudYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbsudYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbsudYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .endif -.else - .if \BSigned\() == 1 - .if \ColumnCount\() == 16 - VpdpbusdYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbusdYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbusdYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .else - .if \ColumnCount\() == 16 - VpdpbuudYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbuudYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbuudYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - r8 - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. 
- - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockAvxVnni ColumnCount, RowCount, VectorOffset, BroadcastOffset, ASigned, BSigned - - vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm4, ymm5, \ASigned\(), \BSigned\()" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm6, ymm7, \ASigned\(), \BSigned\()" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm8, ymm9, \ASigned\(), \BSigned\()" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm10, ymm11, \ASigned\(), \BSigned\()" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm12, ymm13, \ASigned\(), \BSigned\()" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm14, ymm15, \ASigned\(), \BSigned\()" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple times - and advancing the matrix A and matrix B data pointers. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - r8 - Supplies the address into the matrix A data plus 3 rows. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm11 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockLoop Isa, ColumnCount, RowCount, ASigned, BSigned - - mov rbp,rcx # reload row length remaining - -.if (\ColumnCount\() == 16) && (\RowCount\() == 1) - sub rbp,4*4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - ComputeBlock\Isa\() \ColumnCount\(), \RowCount\(), 0*64, 0, \ASigned\(), \BSigned\() - ComputeBlock\Isa\() \ColumnCount\(), \RowCount\(), 1*64, 4, \ASigned\(), \BSigned\() - ComputeBlock\Isa\() \ColumnCount\(), \RowCount\(), 2*64, 8, \ASigned\(), \BSigned\() - ComputeBlock\Isa\() \ColumnCount\(), \RowCount\(), 3*64, 12, \ASigned\(), \BSigned\() - add rdi,4*4 # advance matrix A by 4 quads - add rsi,4*64 # advance matrix B - sub rbp,4*4 - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add rbp,4*4 # correct for over-subtract above - jz .LComputeBlockLoopExit\@ -.endif - -.LComputeBlockBy1Loop\@: - ComputeBlock\Isa\() \ColumnCount\(), \RowCount\(), 0, 0, \ASigned\(), \BSigned\() - add rdi,4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add r8,4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,64 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ - -.LComputeBlockLoopExit\@: - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single row of the - output block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - Vec1Reg - Supplies the high block accumulator register (when ColumnCount - is 16). - - Vec2Reg - Supplies the low block accumulator register. - -Implicit Arguments: - - ymm0 - Supplies the first vector loaded from matrix B. - - ymm1 - Supplies the second vector loaded from matrix B (when ColumnCount - is 16). - - ymm2 - Supplies the broadcast value loaded from matrix A. - ---*/ - - .macro MultiplyAccumulateRowU8U8Avx2 ColumnCount, Vec1Reg, Vec2Reg - - vpmaddwd ymm3,ymm2,ymm0 -.if \ColumnCount\() == 16 - vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 - vpmaddwd ymm2,ymm2,ymm1 - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 -.else - vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm3 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - r8 - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockU8U8Avx2 ColumnCount, RowCount, VectorOffset, BroadcastOffset - - vpmovzxbw ymm0,XMMWORD PTR [rsi+\VectorOffset\()] - EmitIfCountGE \ColumnCount\(), 16, "vpmovzxbw ymm1,XMMWORD PTR [rsi+\VectorOffset\()+16]" - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm4, ymm5" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm6, ymm7" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm8, ymm9" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm10, ymm11" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR [r8+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm12, ymm13" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowU8U8Avx2 \ColumnCount\(), ymm14, ymm15" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple times - and advancing the matrix A and matrix B data pointers. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - r8 - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.if (\ColumnCount\() == 16) && ((\RowCount\() & 1) == 0) - sub rbp,2*4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy2Loop\@: - ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 - ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 32, 4 - add rdi,2*4 # advance matrix A by 2 pairs -.if \RowCount\() > 3 - add r8,2*4 # advance matrix A plus 3 rows by 2 pairs -.endif - add rsi,2*32 # advance matrix B - sub rbp,2*4 - jae .LComputeBlockBy2Loop\@ - -.LProcessRemainingBlocks\@: - add rbp,2*4 # correct for over-subtract above - jz .LComputeBlockLoopExit\@ - ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 - add rsi,32 # advance matrix B -.else -.LComputeBlockBy1Loop\@: - ComputeBlockU8U8\Isa\() \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 pair -.if \RowCount\() > 3 - add r8,4 # advance matrix A plus 3 rows by 1 pair -.endif - add rsi,32 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ -.endif - -.LComputeBlockLoopExit\@: - - .endm - -/*++ - -Macro Description: - - This macro generates code to produce an output block for a set of columns - and rows. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address into the matrix A data. 
- - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r11 - Supplies the address of the row sum buffer. - - r12 - Supplies the address of the column sum buffer. - - ymm4-ymm15 - Supplies the block accumulators. - ---*/ - - .macro ProduceOutputBlock ColumnCount, RowCount, ASigned, BSigned - -// -// Initialize the accumulators with the row and column sums. -// - - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm5,DWORD PTR [r11]" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm7,DWORD PTR [r11+4]" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm9,DWORD PTR [r11+8]" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r11+12]" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r11+16]" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r11+20]" -.if \ColumnCount\() == 16 - vmovdqu ymm0,YMMWORD PTR [r12] - vmovdqu ymm1,YMMWORD PTR [r12+32] - add r12,16*4 # advance ColumnSumBuffer by 16 columns -.else - vmovdqu ymm1,YMMWORD PTR [r12] -.endif - test r13,r13 # per column zero points? - jz .LSkipScaleByZeroPointB\@ -.if \ColumnCount\() == 16 - vmovdqu ymm2,YMMWORD PTR [r13] - vmovdqu ymm3,YMMWORD PTR [r13+32] - add r13,16*4 # advance ZeroPointB by 16 columns -.else - vmovdqu ymm3,YMMWORD PTR [r13] -.endif - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld ymm4,ymm5,ymm2" - EmitIfCountGE \RowCount\(), 1, "vpmulld ymm5,ymm5,ymm3" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm0,ymm4" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm1,ymm5" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld ymm6,ymm7,ymm2" - EmitIfCountGE \RowCount\(), 2, "vpmulld ymm7,ymm7,ymm3" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm0,ymm6" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm1,ymm7" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld ymm8,ymm9,ymm2" - EmitIfCountGE \RowCount\(), 3, "vpmulld ymm9,ymm9,ymm3" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm0,ymm8" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm1,ymm9" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld ymm10,ymm11,ymm2" - EmitIfCountGE \RowCount\(), 4, "vpmulld ymm11,ymm11,ymm3" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm0,ymm10" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm1,ymm11" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld ymm12,ymm13,ymm2" - EmitIfCountGE \RowCount\(), 5, "vpmulld ymm13,ymm13,ymm3" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm0,ymm12" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm1,ymm13" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld ymm14,ymm15,ymm2" - EmitIfCountGE \RowCount\(), 6, "vpmulld ymm15,ymm15,ymm3" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm0,ymm14" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm1,ymm15" - jmp .LAccumulatorsInitialized\@ - -.LSkipScaleByZeroPointB\@: - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd ymm8,ymm9,ymm0" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd ymm10,ymm11,ymm0" - 
EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd ymm12,ymm13,ymm0" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd ymm14,ymm15,ymm0" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm1" - -.LAccumulatorsInitialized\@: - -// -// Iterate over the length of a matrix A row to produce the output accumulators. -// - -.if \RowCount\() > 3 - lea r8,[rcx*2+rcx] - add r8,rdi # compute matrix A plus 3 rows -.endif - cmp DWORD PTR .LGemmInt8KernelFrame_type[rsp],0 - jg .LProduceWithU8U8Avx2\@ -.if \RowCount\() <= 4 - jl .LProduceWithInt8AvxVnni\@ - ComputeBlockLoop Avx2, \ColumnCount\(), \RowCount\(), \ASigned\(), \BSigned\() - jmp .LExitProduceOutputBlock\@ -.endif - -.LProduceWithInt8AvxVnni\@: - ComputeBlockLoop AvxVnni, \ColumnCount\(), \RowCount\(), \ASigned\(), \BSigned\() - jmp .LExitProduceOutputBlock\@ - -.LProduceWithU8U8Avx2\@: - ComputeBlockLoopU8U8 Avx2, \ColumnCount\(), \RowCount\() - -.LExitProduceOutputBlock\@: -.if \RowCount\() > 3 - lea r8,[rax*2+rax] - add r8,rdx # compute matrix C plus 3 rows -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - rdx - Supplies the address of matrix C. - - rbx - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rcx - Supplies the length in bytes of a row from matrix A. - - r10b - Supplies the zero mode flag. - - r11 - Supplies the address of the row sum buffer. - - r12 - Supplies the address of the column sum buffer. - ---*/ - - .macro ProcessCountM RowCount, ASigned, BSigned - - cmp r9,8 - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop16xN\@: - ProduceOutputBlock 16, \RowCount\(), \ASigned\(), \BSigned\() - sub r9,16 - jb .LOutputMasked16xNBlock\@ - test r10b,r10b # ZeroMode? 
- jnz .LSkipAccumulateOutput16xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax+32]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2+32]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8+32]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax+32]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2+32]" - -.LSkipAccumulateOutput16xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx+32],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax+32],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2+32],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8+32],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax+32],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2+32],ymm15" - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,rbx # reload matrix A - cmp r9,8 - ja .LProcessNextColumnLoop16xN\@ - test r9,r9 - jnz .LProcessRemainingCountN\@ - -.LExitProcessCountM\@: - mov eax,\RowCount\() - jmp .LExitKernel - -.LProcessRemainingCountN\@: - ProduceOutputBlock 8, \RowCount\(), \ASigned\(), \BSigned\() - cmp r9,8 - jb .LOutputMasked8xNBlock\@ - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput8xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,YMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,YMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,YMMWORD PTR [r8+rax*2]" - -.LSkipAccumulateOutput8xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm5" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm7" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm9" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm11" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm13" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm15" - jmp .LExitProcessCountM\@ - -.LOutputMasked16xNBlock\@: - test r10b,r10b # ZeroMode? 
- jnz .LSkipAccumulateOutputMasked16xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm4,ymm4,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm6,ymm6,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm8,ymm8,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm10,ymm10,YMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm12,ymm12,YMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm14,ymm14,YMMWORD PTR [r8+rax*2]" - -.LSkipAccumulateOutputMasked16xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu YMMWORD PTR [rdx],ymm4" - EmitIfCountGE \RowCount\(), 2, "vmovdqu YMMWORD PTR [rdx+rax],ymm6" - EmitIfCountGE \RowCount\(), 3, "vmovdqu YMMWORD PTR [rdx+rax*2],ymm8" - EmitIfCountGE \RowCount\(), 4, "vmovdqu YMMWORD PTR [r8],ymm10" - EmitIfCountGE \RowCount\(), 5, "vmovdqu YMMWORD PTR [r8+rax],ymm12" - EmitIfCountGE \RowCount\(), 6, "vmovdqu YMMWORD PTR [r8+rax*2],ymm14" - add rdx,8*4 # advance matrix C by 8 columns -.if \RowCount\() > 3 - add r8,8*4 # advance matrix C plus 3 rows by 8 columns -.endif - add r9,8 # correct for over-subtract above - -.LOutputMasked8xNBlock\@: - neg r9 - lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutputMasked8xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpmaskmovd ymm6,ymm0,YMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpmaskmovd ymm10,ymm0,YMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 5, "vpmaskmovd ymm12,ymm0,YMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 6, "vpmaskmovd ymm14,ymm0,YMMWORD PTR [r8+rax*2]" - EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm4" - EmitIfCountGE \RowCount\(), 2, "vpaddd ymm7,ymm7,ymm6" - EmitIfCountGE \RowCount\(), 3, "vpaddd ymm9,ymm9,ymm8" - EmitIfCountGE \RowCount\(), 4, "vpaddd ymm11,ymm11,ymm10" - EmitIfCountGE \RowCount\(), 5, "vpaddd ymm13,ymm13,ymm12" - EmitIfCountGE \RowCount\(), 6, "vpaddd ymm15,ymm15,ymm14" - -.LSkipAccumulateOutputMasked8xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5" - EmitIfCountGE \RowCount\(), 2, "vpmaskmovd YMMWORD PTR [rdx+rax],ymm0,ymm7" - EmitIfCountGE \RowCount\(), 3, "vpmaskmovd YMMWORD PTR [rdx+rax*2],ymm0,ymm9" - EmitIfCountGE \RowCount\(), 4, "vpmaskmovd YMMWORD PTR [r8],ymm0,ymm11" - EmitIfCountGE \RowCount\(), 5, "vpmaskmovd YMMWORD PTR [r8+rax],ymm0,ymm13" - EmitIfCountGE \RowCount\(), 6, "vpmaskmovd YMMWORD PTR [r8+rax*2],ymm0,ymm15" - jmp .LExitProcessCountM\@ - - .endm - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmCopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmCopyPackBAvx2. - - C (rdx) - Supplies the address of matrix C. - - PackedCountK (rcx) - Supplies the number of packed columns from matrix A - and the number of packed rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. 
- - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A. These values have - been pre-scaled by the zero point offset of matrix B if the offset is - per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. - ---*/ - -.macro MlasGemmInt8KernelAvx2 ASigned, BSigned - - push rbp - push rbx - push r12 - push r13 - - mov DWORD PTR .LGemmInt8KernelFrame_type[rsp],eax - mov rbx,rdi - mov rax,.LGemmInt8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - shl rcx,2 # convert to row length - movzx r10,BYTE PTR .LGemmInt8KernelFrame_ZeroMode[rsp] - mov r11,.LGemmInt8KernelFrame_RowSumBuffer[rsp] - mov r12,.LGemmInt8KernelFrame_ColumnSumBuffer[rsp] - mov r13,.LGemmInt8KernelFrame_ZeroPointB[rsp] - vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] - vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] - cmp DWORD PTR .LGemmInt8KernelFrame_type[rsp],0 - je .LCheckCountM4OrMore\@ # U8S8 AVX2 kernel requires extra registers - -// -// Process CountM rows of the matrices. -// - -.LCheckCountM6OrMore\@: - cmp r8,5 - ja .LProcessCountM6\@ - je .LProcessCountM5\@ - -.LCheckCountM4OrMore\@: - cmp r8,3 - ja .LProcessCountM4\@ - je .LProcessCountM3\@ - cmp r8,1 - je .LProcessCountM1\@ - -.LProcessCountM2\@: - ProcessCountM 2, \ASigned\(), \BSigned\() - -.LProcessCountM4\@: - ProcessCountM 4, \ASigned\(), \BSigned\() - -.LProcessCountM6\@: - ProcessCountM 6, \ASigned\(), \BSigned\() - -.LProcessCountM1\@: - ProcessCountM 1, \ASigned\(), \BSigned\() - -.LProcessCountM3\@: - ProcessCountM 3, \ASigned\(), \BSigned\() - -.LProcessCountM5\@: - ProcessCountM 5, \ASigned\(), \BSigned\() - -.endm - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - vzeroupper - - pop r13 - pop r12 - pop rbx - pop rbp - ret - -// -// Reduce code size for the various types of kernels by sharing the outer logic -// and switching on the selector codes (using sign bit to discriminate). -// - - FUNCTION_ENTRY MlasGemmU8S8KernelAvxVnni - - mov eax,-1 - MlasGemmInt8KernelAvx2 0, 1 - - FUNCTION_ENTRY MlasGemmU8U8KernelAvx2Vnni - - mov eax,-1 - MlasGemmInt8KernelAvx2 0, 0 - - FUNCTION_ENTRY MlasGemmU8U8KernelAvx2 - - mov eax,1 - MlasGemmInt8KernelAvx2 0, 0 - - FUNCTION_ENTRY MlasGemmU8S8KernelAvx2 - - xor eax,eax - MlasGemmInt8KernelAvx2 0, 1 - - FUNCTION_ENTRY MlasGemmS8S8KernelAvx2Vnni - - mov eax,-1 - MlasGemmInt8KernelAvx2 1, 1 - - FUNCTION_ENTRY MlasGemmS8U8KernelAvx2Vnni - - mov eax,-1 - MlasGemmInt8KernelAvx2 1, 0 - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S deleted file mode 100644 index 279f406b88940..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx512Core.S +++ /dev/null @@ -1,709 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
- -Licensed under the MIT License. - -Module Name: - - QgemmU8X8KernelAvx512Core.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/matrix - multiply operation (QGEMM). - - This implementation uses AVX512 core (BW/DQ/VL) and AVX512 VNNI instructions. - ---*/ - -#include "asmmacro.h" -#include "AssembleAvx512Vnni.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the U8X8 kernel. -// - - .equ .LGemmU8X8KernelFrame_type, -8 - .equ .LGemmU8X8KernelFrame_SavedR14, 0 - .equ .LGemmU8X8KernelFrame_SavedR13, 8 - .equ .LGemmU8X8KernelFrame_SavedR12, 16 - .equ .LGemmU8X8KernelFrame_SavedRbx, 24 - .equ .LGemmU8X8KernelFrame_SavedRbp, 32 - .equ .LGemmU8X8KernelFrame_ReturnAddress, 40 - .equ .LGemmU8X8KernelFrame_ldc, 48 - .equ .LGemmU8X8KernelFrame_RowSumBuffer, 56 - .equ .LGemmU8X8KernelFrame_ColumnSumBuffer, 64 - .equ .LGemmU8X8KernelFrame_ZeroPointB, 72 - .equ .LGemmU8X8KernelFrame_ZeroMode, 80 - - .text - -/*++ - -Macro Description: - - This macro generates code to load packed data from matrix B. - -Arguments: - - VecReg - Supplies the register to load the data into. - - AddressOperand - Supplies the address operand. - ---*/ - - .macro LoadPackedMatrixBU8S8 VecReg, AddressOperand - - vmovdqu32 \VecReg\(),ZMMWORD PTR \AddressOperand\() - - .endm - - .macro LoadPackedMatrixBU8U8 VecReg, AddressOperand - - vpmovzxbw \VecReg\(),YMMWORD PTR \AddressOperand\() - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulator a single cell of the - output block. - -Arguments: - - AccumReg - Supplies the register to accumulate into. - - Mult1Reg - Supplies the first multiplication operand register. - - Mult2Reg - Supplies the second multiplication operand register. - -Implicit Arguments: - - zmm4 - Supplies a scratch register for intermediate results. - - zmm13 - Supplies a 512-bit with the broadcasted word value 0x0001. - ---*/ - - .macro MultiplyAccumulateCellU8S8Avx512Core AccumReg, Mult1Reg, Mult2Reg - - vpmaddubsw zmm4,\Mult1Reg\(),\Mult2Reg\() - vpmaddwd zmm4,zmm4,zmm13 - vpaddd \AccumReg\(),\AccumReg\(),zmm4 - - .endm - - .macro MultiplyAccumulateCellU8S8Avx512Vnni AccumReg, Mult1Reg, Mult2Reg - - VpdpbusdsZmmZmmZmm \AccumReg\(),\Mult1Reg\(),\Mult2Reg\() - - .endm - - .macro MultiplyAccumulateCellU8U8Avx512Core AccumReg, Mult1Reg, Mult2Reg - - vpmaddwd zmm4,\Mult1Reg\(),\Mult2Reg\() - vpaddd \AccumReg\(),\AccumReg\(),zmm4 - - .endm - -/*++ - -Macro Description: - - This macro generates code to multiply and accumulate each row of the output - block. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - BroadcastOffset - Supplies the byte offset from matrix A to fetch elements. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. - - r8 - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. 
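In scalar terms, each multiply-accumulate cell above adds the dot product of four unsigned bytes of A with four signed bytes of B into one 32-bit accumulator lane; the AVX512 core variant emulates this with vpmaddubsw plus vpmaddwd against a vector of 0x0001 words, while the VNNI variant does it in a single instruction. A minimal C++ sketch (the saturation behaviour of the real instructions is ignored; names are illustrative):

    #include <cstdint>

    // One accumulator lane: acc += dot(a[0..3], b[0..3]) with a unsigned and
    // b signed, which is what vpdpbusds computes directly and what the
    // vpmaddubsw/vpmaddwd pair approximates (those saturate intermediates).
    int32_t MultiplyAccumulateQuad(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
        for (int i = 0; i < 4; ++i) {
            acc += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        }
        return acc;
    }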
- ---*/ - - .macro ComputeBlock Type, Isa, ColumnCount, RowCount, VectorOffset, BroadcastOffset - -.if \ColumnCount\() >= 48 - LoadPackedMatrixB\Type\() zmm0,"[rsi+\VectorOffset\()]" - LoadPackedMatrixB\Type\() zmm1,"[rsi+r14+\VectorOffset\()]" - LoadPackedMatrixB\Type\() zmm2,"[rsi+r14*2+\VectorOffset\()]" -.elseif \ColumnCount\() >= 32 - LoadPackedMatrixB\Type\() zmm1,"[rsi+\VectorOffset\()]" - LoadPackedMatrixB\Type\() zmm2,"[rsi+r14+\VectorOffset\()]" -.else - LoadPackedMatrixB\Type\() zmm2,"[rsi+\VectorOffset\()]" -.endif - EmitIfCountGE \RowCount\(), 1, "vpbroadcastd zmm3,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm26,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm20,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm14,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 2, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm27,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm21,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm15,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 3, "vpbroadcastd zmm3,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm28,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm22,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm16,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 4, "vpbroadcastd zmm3,DWORD PTR [r8+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm29,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm23,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm17,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 5, "vpbroadcastd zmm3,DWORD PTR [r8+rcx+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm30,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm24,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm18,zmm3,zmm2" - EmitIfCountGE \RowCount\(), 6, "vpbroadcastd zmm3,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "MultiplyAccumulateCell\Type\()\Isa\() zmm31,zmm3,zmm0" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "MultiplyAccumulateCell\Type\()\Isa\() zmm25,zmm3,zmm1" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "MultiplyAccumulateCell\Type\()\Isa\() zmm19,zmm3,zmm2" - - .endm - -/*++ - -Macro Description: - - This macro generates code to execute the block compute macro multiple times - and advancing the matrix A and matrix B data pointers. - -Arguments: - - Isa - Supplies the instruction set architecture string. - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rdi - Supplies the address into the matrix A data. 
- - r8 - Supplies the address into the matrix A data plus 3 rows. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - - zmm14-zmm31 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlockLoopU8S8 Isa, ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.if (\RowCount\() == 1) || ((\RowCount\() & 1) == 0) - sub rbp,4*4 - jb .LProcessRemainingBlocks\@ - -.LComputeBlockBy4Loop\@: - ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0*64, 0 - ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 1*64, 4 - ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 2*64, 8 - ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 3*64, 12 - add rdi,4*4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add r8,4*4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,4*64 # advance matrix B - sub rbp,4*4 # decrement quads remaining - jae .LComputeBlockBy4Loop\@ - -.LProcessRemainingBlocks\@: - add rbp,4*4 # correct for over-subtract above - jz .LComputeBlockLoopExit\@ -.endif - -.LComputeBlockBy1Loop\@: - ComputeBlock U8S8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 quad -.if \RowCount\() > 3 - add r8,4 # advance matrix A plus 3 rows by 1 quad -.endif - add rsi,64 # advance matrix B - sub rbp,4 # decrement quads remaining - jnz .LComputeBlockBy1Loop\@ - -.LComputeBlockLoopExit\@: - - .endm - - .macro ComputeBlockLoopU8U8 Isa, ColumnCount, RowCount - - mov rbp,rcx # reload row length remaining - -.LComputeBlockBy1Loop\@: - ComputeBlock U8U8, \Isa\(), \ColumnCount\(), \RowCount\(), 0, 0 - add rdi,4 # advance matrix A by 1 pair -.if \RowCount\() > 3 - add r8,4 # advance matrix A plus 3 rows by 1 pair -.endif - add rsi,32 # advance matrix B - sub rbp,4 - jnz .LComputeBlockBy1Loop\@ - - .endm - -/*++ - -Macro Description: - - This macro generates code to produce an output block for a set of columns - and rows. - -Arguments: - - ColumnCount - Supplies the number of columns to produce. - - RowCount - Supplies the number of rows to produce. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address into the matrix A data. - - rsi - Supplies the address into the matrix B data. - - rcx - Supplies the length in bytes of a row from matrix A. - - r11 - Supplies the address of the row sum buffer. - - r12 - Supplies the address of the column sum buffer. - ---*/ - - .macro ProduceOutputBlock ColumnCount, RowCount - -// -// Initialize the accumulators with the row and column sums. -// - -.if \ColumnCount\() >= 32 -.if \ColumnCount\() >= 48 - vmovdqu32 zmm2,ZMMWORD PTR [r12] - vmovdqu32 zmm1,ZMMWORD PTR [r12+64] - vmovdqu32 zmm0,ZMMWORD PTR [r12+128] -.else - vmovdqu32 zmm1,ZMMWORD PTR [r12] - vmovdqu32 zmm0,ZMMWORD PTR [r12+64] -.endif - add_immed r12,\ColumnCount\()*4 # advance ColumnSumBuffer by N columns -.else - vmovdqu32 zmm0,ZMMWORD PTR [r12] -.endif - test r13,r13 # per column zero points? 
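The initialization emitted around this point seeds every accumulator with ColumnSum[n] plus RowSum[m], where the row sum is additionally multiplied by the per-column ZeroPointB[n] when per-column zero points are supplied. A scalar C++ sketch (names are illustrative, not MLAS APIs):

    #include <cstddef>
    #include <cstdint>

    // Seed the rows x cols accumulator block from the row and column sums.
    // When ZeroPointB is null the row sums were already pre-scaled by the
    // per-tensor zero point, so they are added as-is.
    void InitAccumulators(int32_t* acc, size_t rows, size_t cols,
                          const int32_t* RowSum, const int32_t* ColumnSum,
                          const int32_t* ZeroPointB) {
        for (size_t m = 0; m < rows; ++m) {
            for (size_t n = 0; n < cols; ++n) {
                int32_t row_term = ZeroPointB ? RowSum[m] * ZeroPointB[n] : RowSum[m];
                acc[m * cols + n] = ColumnSum[n] + row_term;
            }
        }
    }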
- jz .LSkipScaleByZeroPointB\@ -.if \ColumnCount\() >= 32 -.if \ColumnCount\() >= 48 - vmovdqu32 zmm5,ZMMWORD PTR [r13] - vmovdqu32 zmm4,ZMMWORD PTR [r13+64] - vmovdqu32 zmm3,ZMMWORD PTR [r13+128] -.else - vmovdqu32 zmm4,ZMMWORD PTR [r13] - vmovdqu32 zmm3,ZMMWORD PTR [r13+64] -.endif - add_immed r13,\ColumnCount\()*4 # advance ZeroPointB by N columns -.else - vmovdqu32 zmm3,ZMMWORD PTR [r13] -.endif - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld zmm14,zmm3,DWORD PTR [r11]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld zmm20,zmm4,DWORD PTR [r11]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpmulld zmm26,zmm5,DWORD PTR [r11]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,zmm14" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,zmm20" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,zmm26" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld zmm15,zmm3,DWORD PTR [r11+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpmulld zmm21,zmm4,DWORD PTR [r11+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpmulld zmm27,zmm5,DWORD PTR [r11+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,zmm15" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,zmm21" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,zmm27" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpmulld zmm16,zmm3,DWORD PTR [r11+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpmulld zmm22,zmm4,DWORD PTR [r11+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpmulld zmm28,zmm5,DWORD PTR [r11+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,zmm16" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,zmm22" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,zmm28" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpmulld zmm17,zmm3,DWORD PTR [r11+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpmulld zmm23,zmm4,DWORD PTR [r11+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpmulld zmm29,zmm5,DWORD PTR [r11+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,zmm17" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,zmm23" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,zmm29" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpmulld zmm18,zmm3,DWORD PTR [r11+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpmulld zmm24,zmm4,DWORD PTR [r11+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpmulld zmm30,zmm5,DWORD PTR [r11+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,zmm18" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,zmm24" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,zmm30" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpmulld zmm19,zmm3,DWORD PTR [r11+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpmulld zmm25,zmm4,DWORD PTR [r11+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpmulld zmm31,zmm5,DWORD PTR [r11+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,zmm19" - EmitIfCount2GE \RowCount\(), 6, 
\ColumnCount\(), 32, "vpaddd zmm25,zmm1,zmm25" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,zmm31" - jmp .LAccumulatorsInitialized\@ - -.LSkipScaleByZeroPointB\@: - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd zmm14,zmm0,DWORD PTR [r11]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd zmm20,zmm1,DWORD PTR [r11]{1to16}" - EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 48, "vpaddd zmm26,zmm2,DWORD PTR [r11]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd zmm15,zmm0,DWORD PTR [r11+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 32, "vpaddd zmm21,zmm1,DWORD PTR [r11+4]{1to16}" - EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 48, "vpaddd zmm27,zmm2,DWORD PTR [r11+4]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 16, "vpaddd zmm16,zmm0,DWORD PTR [r11+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 32, "vpaddd zmm22,zmm1,DWORD PTR [r11+8]{1to16}" - EmitIfCount2GE \RowCount\(), 3, \ColumnCount\(), 48, "vpaddd zmm28,zmm2,DWORD PTR [r11+8]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 16, "vpaddd zmm17,zmm0,DWORD PTR [r11+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 32, "vpaddd zmm23,zmm1,DWORD PTR [r11+12]{1to16}" - EmitIfCount2GE \RowCount\(), 4, \ColumnCount\(), 48, "vpaddd zmm29,zmm2,DWORD PTR [r11+12]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 16, "vpaddd zmm18,zmm0,DWORD PTR [r11+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 32, "vpaddd zmm24,zmm1,DWORD PTR [r11+16]{1to16}" - EmitIfCount2GE \RowCount\(), 5, \ColumnCount\(), 48, "vpaddd zmm30,zmm2,DWORD PTR [r11+16]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 16, "vpaddd zmm19,zmm0,DWORD PTR [r11+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 32, "vpaddd zmm25,zmm1,DWORD PTR [r11+20]{1to16}" - EmitIfCount2GE \RowCount\(), 6, \ColumnCount\(), 48, "vpaddd zmm31,zmm2,DWORD PTR [r11+20]{1to16}" - -.LAccumulatorsInitialized\@: - -// -// Iterate over the length of a matrix A row to produce the output accumulators. -// - -.if \RowCount\() > 3 - lea r8,[rcx*2+rcx] - add r8,rdi # compute matrix A plus 3 rows -.endif - cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0 - je .LProduceWithU8S8Avx512Core\@ - jg .LProduceWithU8U8Avx512Core\@ - ComputeBlockLoopU8S8 Avx512Vnni, \ColumnCount\(), \RowCount\() - jmp .LExitProduceOutputBlock\@ - -.LProduceWithU8U8Avx512Core\@: - ComputeBlockLoopU8U8 Avx512Core, \ColumnCount\(), \RowCount\() - jmp .LExitProduceOutputBlock\@ - -.LProduceWithU8S8Avx512Core\@: - ComputeBlockLoopU8S8 Avx512Core, \ColumnCount\(), \RowCount\() - -.LExitProduceOutputBlock\@: -.if \RowCount\() > 3 - lea r8,[rax*2+rax] - add r8,rdx # compute matrix C plus 3 rows -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - -Implicit Arguments: - - rax - Supplies the length in bytes of a row from matrix C. - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - rdx - Supplies the address of matrix C. - - rbx - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rcx - Supplies the length in bytes of a row from matrix A. - - r10b - Supplies the zero mode flag. - - r11 - Supplies the address of the row sum buffer. - - r12 - Supplies the address of the column sum buffer. 
- - r14 - Supplies the stride in bytes of between packed blocks of matrix B. - ---*/ - - .macro ProcessCountM RowCount - - cmp r9,32 - ja .LProcessNextColumnLoop48xN\@ - cmp r9,16 - jbe .LProcessRemainingCountN\@ - -.LProcessNextColumnLoop32xN\@: - ProduceOutputBlock 32, \RowCount\() - add rsi,r14 # advance matrix B by packed block stride - -.LOutput32xNBlock\@: - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput32xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm20,zmm20,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm21,zmm21,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm22,zmm22,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm23,zmm23,ZMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm24,zmm24,ZMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm25,zmm25,ZMMWORD PTR [r8+rax*2]" - -.LSkipAccumulateOutput32xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm20" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm21" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm22" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm23" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm24" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm25" - add rdx,16*4 # advance matrix C by 16 columns -.if \RowCount\() > 3 - add r8,16*4 # advance matrix C plus 3 rows by 16 columns -.endif - sub r9,16 - -.LOutput16xNBlock\@: - sub r9,16 - jae .LOutput16xNBlockWithMask\@ - lea rcx,[r9+16] # correct for over-subtract above - mov ebp,1 - shl ebp,cl - dec ebp - kmovw k1,ebp # update mask for remaining columns - xor r9,r9 # no more columns remaining - -.LOutput16xNBlockWithMask\@: - test r10b,r10b # ZeroMode? - jnz .LSkipAccumulateOutput16xNBlockWithMask\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm14{k1},zmm14,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm15{k1},zmm15,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm16{k1},zmm16,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm17{k1},zmm17,ZMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm18{k1},zmm18,ZMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm19{k1},zmm19,ZMMWORD PTR [r8+rax*2]" - -.LSkipAccumulateOutput16xNBlockWithMask\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx]{k1},zmm14" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax]{k1},zmm15" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2]{k1},zmm16" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8]{k1},zmm17" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax]{k1},zmm18" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2]{k1},zmm19" - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,rbx # reload matrix A - cmp r9,32 - ja .LProcessNextColumnLoop48xN\@ - cmp r9,16 - ja .LProcessNextColumnLoop32xN\@ - test r9,r9 - jnz .LProcessRemainingCountN\@ - mov eax,\RowCount\() - jmp .LExitKernel - -.LProcessRemainingCountN\@: - ProduceOutputBlock 16, \RowCount\() - jmp .LOutput16xNBlock\@ - -.LProcessNextColumnLoop48xN\@: - ProduceOutputBlock 48, \RowCount\() - lea rsi,[rsi+r14*2] # advance matrix B by packed block stride - test r10b,r10b # ZeroMode? 
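The column tail in this AVX512 path is handled with an opmask rather than a mask table: the shl/dec/kmovw sequence above builds (1 << remaining) - 1, which enables exactly the low `remaining` of the 16 dword lanes. A small C++ sketch of that mask computation:

    #include <cstdint>

    // Opmask for a partial 16-lane store; remaining_columns is in [0, 16).
    uint16_t TailMask16(unsigned remaining_columns) {
        return static_cast<uint16_t>((1u << remaining_columns) - 1u);
    }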
- jnz .LSkipAccumulateOutput48xNBlock\@ - EmitIfCountGE \RowCount\(), 1, "vpaddd zmm26,zmm26,ZMMWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "vpaddd zmm27,zmm27,ZMMWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 3, "vpaddd zmm28,zmm28,ZMMWORD PTR [rdx+rax*2]" - EmitIfCountGE \RowCount\(), 4, "vpaddd zmm29,zmm29,ZMMWORD PTR [r8]" - EmitIfCountGE \RowCount\(), 5, "vpaddd zmm30,zmm30,ZMMWORD PTR [r8+rax]" - EmitIfCountGE \RowCount\(), 6, "vpaddd zmm31,zmm31,ZMMWORD PTR [r8+rax*2]" - -.LSkipAccumulateOutput48xNBlock\@: - EmitIfCountGE \RowCount\(), 1, "vmovdqu32 ZMMWORD PTR [rdx],zmm26" - EmitIfCountGE \RowCount\(), 2, "vmovdqu32 ZMMWORD PTR [rdx+rax],zmm27" - EmitIfCountGE \RowCount\(), 3, "vmovdqu32 ZMMWORD PTR [rdx+rax*2],zmm28" - EmitIfCountGE \RowCount\(), 4, "vmovdqu32 ZMMWORD PTR [r8],zmm29" - EmitIfCountGE \RowCount\(), 5, "vmovdqu32 ZMMWORD PTR [r8+rax],zmm30" - EmitIfCountGE \RowCount\(), 6, "vmovdqu32 ZMMWORD PTR [r8+rax*2],zmm31" - add rdx,16*4 # advance matrix C by 16 columns -.if \RowCount\() > 3 - add r8,16*4 # advance matrix C plus 3 rows by 16 columns -.endif - sub r9,16 - jmp .LOutput32xNBlock\@ - - .endm - -// -// Reduce code size for the various types of kernels by sharing the outer logic -// and switching on the selector codes (using sign bit to discriminate). -// - - FUNCTION_ENTRY MlasGemmU8U8KernelAvx512Core - - mov eax,1 - jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core) - - FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Core - - xor eax,eax - jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core) - - FUNCTION_ENTRY MlasGemmU8S8KernelAvx512Vnni - - mov eax,-1 - jmp C_UNDERSCORE(MlasGemmU8X8KernelAvx512Core) - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. - -Arguments: - - A (rdi) - Supplies the address of matrix A. The matrix data has been packed - using MlasGemmU8X8CopyPackAAvx2. - - B (rsi) - Supplies the address of matrix B. The matrix data has been packed - using MlasGemmU8X8CopyPackBAvx2. - - C (rdx) - Supplies the address of matrix C. - - PackedCountK (rcx) - Supplies the number of packed columns from matrix A and - the number of packed rows from matrix B to iterate over. - - CountM (r8) - Supplies the maximum number of rows that can be processed for - matrix A and matrix C. The actual number of rows handled for this - invocation depends on the kernel implementation. - - CountN (r9) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldc - Supplies the first dimension of matrix C. - - RowSumBuffer - Supplies the sum of each row from matrix A. These values have - been pre-scaled by the zero point offset of matrix B if the offset is - per-tensor (ZeroPointB is nullptr). Otherwise, these values must be - scaled by the per-column zero point offsets of matrix B. These values are - accumulated into every row of matrix C. - - ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied - by the zero point offset of matrix A. These values are accumulated into - every column of matrix C. - - ZeroPointB - Optionally supplies the per-column zero point offsets of matrix - B, else nullptr if the matrix B is using per-tensor quantization. - - ZeroMode - Supplies true if the output matrix must be zero initialized, - else false if the output matrix is accumulated into. - -Return Value: - - Returns the number of rows handled. 
- ---*/ - - FUNCTION_ENTRY MlasGemmU8X8KernelAvx512Core - - push rbp - push rbx - push r12 - push r13 - push r14 - - mov DWORD PTR .LGemmU8X8KernelFrame_type[rsp],eax - mov rbx,rdi - mov rax,.LGemmU8X8KernelFrame_ldc[rsp] - shl rax,2 # convert ldc to bytes - shl rcx,2 # convert to row length - movzx r10,BYTE PTR .LGemmU8X8KernelFrame_ZeroMode[rsp] - mov r11,.LGemmU8X8KernelFrame_RowSumBuffer[rsp] - mov r12,.LGemmU8X8KernelFrame_ColumnSumBuffer[rsp] - mov r13,.LGemmU8X8KernelFrame_ZeroPointB[rsp] - mov ebp,-1 - kmovw k1,ebp # update mask to write all columns - neg ebp - vpbroadcastw zmm13,ebp # generate 512-bit word vector [0x0001] - lea rbp,[rcx*8] - lea r14,[rbp*2] - cmp DWORD PTR .LGemmU8X8KernelFrame_type[rsp],0 - cmovg r14,rbp # select matrix B packed stride - -// -// Process CountM rows of the matrices. -// - - cmp r8,5 - ja .LProcessCountM6 - je .LProcessCountM5 - cmp r8,3 - ja .LProcessCountM4 - je .LProcessCountM3 - cmp r8,1 - je .LProcessCountM1 - -.LProcessCountM2: - ProcessCountM 2 - -.LProcessCountM4: - ProcessCountM 4 - -.LProcessCountM6: - ProcessCountM 6 - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - vzeroupper - - pop r14 - pop r13 - pop r12 - pop rbx - pop rbp - ret - -.LProcessCountM1: - ProcessCountM 1 - -.LProcessCountM3: - ProcessCountM 3 - -.LProcessCountM5: - ProcessCountM 5 - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx2.S deleted file mode 100644 index ef6b0afedda38..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx2.S +++ /dev/null @@ -1,345 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemvU8S8KernelAvx2.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/vector - multiply operation (QGEMV). - - This implementation uses AVX2 instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the U8S8 kernel. -// - - .equ .LGemvU8S8KernelFrame_mask, -8 - .equ .LGemvU8S8KernelFrame_SavedRbx, 0 - .equ .LGemvU8S8KernelFrame_SavedRbp, 8 - .equ .LGemvU8S8KernelFrame_ReturnAddress, 16 - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix/vector multiplication. - -Arguments: - - A (rdi) - Supplies the address of vector A. - - B (rsi) - Supplies the address of matrix B. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from vector A and the number - of rows from matrix B to iterate over. - - CountN (r8) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb (r9) - Supplies the first dimension of matrix B. - -Return Value: - - None. - ---*/ - - .globl C_UNDERSCORE(MlasGemvU8S8KernelAvx2) -C_UNDERSCORE(MlasGemvU8S8KernelAvx2): - - push rbp - push rbx - - mov r10,rdx - mov r11,rsp # set ZeroMode to any non-zero value - vpcmpeqw ymm6,ymm6,ymm6 # generate word vector [0xFFFF] - vpsrlw ymm6,ymm6,15 # generate word vector [0x0001] - -// -// Process 4 rows of matrix B in a loop. -// - - sub rcx,4 - jb .LProcessRemainingRows - -.LProcessRowLoop4: - mov rdx,rsi # reload matrix B - lea rsi,[rsi+r9*4] # advance matrix B by 4 rows - mov rbx,r10 # reload matrix C - mov rbp,r8 # reload CountN - vpbroadcastd ymm0,DWORD PTR [rdi] - add rdi,4 # advance matrix A by 4 bytes - -// -// Process sets of 32 columns from the 4 rows in a loop. 
-// -// Some permute operations are deferred until the final store of the 4x32 block -// as these permutes are expensive. -// - -.LProcessColumnLoop4By32: - cmp rbp,32 - jb .LProcessColumnLoop4By8 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovdqu ymm2,YMMWORD PTR [rdx] - vmovdqu ymm3,YMMWORD PTR [rdx+r9] - vmovdqu ymm4,YMMWORD PTR [rax] - vmovdqu ymm5,YMMWORD PTR [rax+r9] - vpunpcklbw ymm1,ymm2,ymm3 # interleave row data bytes - vpunpckhbw ymm2,ymm2,ymm3 - vpunpcklbw ymm3,ymm4,ymm5 - vpunpckhbw ymm4,ymm4,ymm5 - vpunpcklwd ymm5,ymm1,ymm3 # interleave row data words - vpunpckhwd ymm1,ymm1,ymm3 - vpunpcklwd ymm3,ymm2,ymm4 - vpunpckhwd ymm2,ymm2,ymm4 - vpmaddubsw ymm5,ymm0,ymm5 # multiply and reduce - vpmaddwd ymm5,ymm5,ymm6 - vpmaddubsw ymm1,ymm0,ymm1 - vpmaddwd ymm1,ymm1,ymm6 - vpmaddubsw ymm3,ymm0,ymm3 - vpmaddwd ymm3,ymm3,ymm6 - vpmaddubsw ymm2,ymm0,ymm2 - vpmaddwd ymm2,ymm2,ymm6 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By32 - vpaddd ymm5,ymm5,YMMWORD PTR [rbx] - vpaddd ymm1,ymm1,YMMWORD PTR [rbx+32] - vpaddd ymm3,ymm3,YMMWORD PTR [rbx+64] - vpaddd ymm2,ymm2,YMMWORD PTR [rbx+96] - -.LSkipAccumulateOutput4By32: - cmp rcx,4 # final 4x32 block? - jae .LStoreOutput4By32 - vperm2i128 ymm4,ymm5,ymm1,0x31 # interleave vector results - vperm2i128 ymm5,ymm5,ymm1,0x20 - vperm2i128 ymm1,ymm3,ymm2,0x20 - vperm2i128 ymm2,ymm3,ymm2,0x31 - vmovaps ymm3,ymm4 - -.LStoreOutput4By32: - vmovdqu YMMWORD PTR [rbx],ymm5 - vmovdqu YMMWORD PTR [rbx+32],ymm1 - vmovdqu YMMWORD PTR [rbx+64],ymm3 - vmovdqu YMMWORD PTR [rbx+96],ymm2 - add rdx,32 # advance matrix B by 32 bytes - add rbx,32*4 # advance matrix C by 32 columns - sub rbp,32 # decrement CountN - jnz .LProcessColumnLoop4By32 - -.LAdvanceRowLoop4: - xor r11,r11 # clear ZeroMode - sub rcx,4 # decrement CountK - jae .LProcessRowLoop4 - -.LProcessRemainingRows: - add rcx,4 # correct for over-subtract above - jnz .LProcessRemainingSmallK - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - vzeroupper - - pop rbx - pop rbp - ret - -// -// Process sets of 8 columns from the 4 rows in a loop. -// - -.LProcessColumnLoop4By8: - cmp ebp,8 - jb .LProcessColumn4By4 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovq xmm2,QWORD PTR [rdx] - vmovq xmm3,QWORD PTR [rdx+r9] - vmovq xmm4,QWORD PTR [rax] - vmovq xmm5,QWORD PTR [rax+r9] - vpunpcklbw xmm2,xmm2,xmm3 # interleave row data bytes - vpunpcklbw xmm4,xmm4,xmm5 - vpunpcklwd xmm1,xmm2,xmm4 # interleave row data words - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm1,ymm1,xmm2,1 # concatenate vector - vpmaddubsw ymm1,ymm0,ymm1 # multiply and reduce - vpmaddwd ymm1,ymm1,ymm6 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By8 - vpaddd ymm1,ymm1,YMMWORD PTR [rbx] - -.LSkipAccumulateOutput4By8: - vmovdqu YMMWORD PTR [rbx],ymm1 - add rdx,8 # advance matrix B by 8 bytes - add rbx,8*4 # advance matrix C by 8 columns - sub ebp,8 # decrement CountN - jnz .LProcessColumnLoop4By8 - jmp .LAdvanceRowLoop4 - -// -// Process a set of 4 columns from the 4 rows. -// - -.LProcessColumn4By4: - test ebp,4 # (CountN & 4) != 0? - jz .LProcessColumn4BySmallN - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovd xmm1,DWORD PTR [rdx] - vpinsrd xmm1,xmm1,DWORD PTR [rdx+r9],1 - vpinsrd xmm1,xmm1,DWORD PTR [rax],2 - vpinsrd xmm1,xmm1,DWORD PTR [rax+r9],3 - vpshufb xmm1,xmm1,XMMWORD PTR C_UNDERSCORE(MlasTranspose4x4BytesAvx)[rip] - vpmaddubsw xmm1,xmm0,xmm1 # multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 # ZeroMode? 
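Structurally, this GEMV kernel never zero-initializes C: ZeroMode starts as a non-zero value, so the first pass over a block of rows overwrites the output, and it is cleared afterwards so later passes accumulate. A scalar C++ reference of that driver shape (per-row rather than per-4-row blocks for brevity; names are illustrative, not MLAS APIs):

    #include <cstddef>
    #include <cstdint>

    // Reference u8 x s8 GEMV: C[n] = sum_k A[k] * B[k][n]. The first pass
    // overwrites C (ZeroMode non-zero) and subsequent passes accumulate,
    // mirroring the assembly, which clears ZeroMode after its first block.
    void GemvU8S8Reference(const uint8_t* A, const int8_t* B, int32_t* C,
                           size_t CountK, size_t CountN, size_t ldb) {
        bool zero_mode = true;
        for (size_t k = 0; k < CountK; ++k) {
            for (size_t n = 0; n < CountN; ++n) {
                int32_t product = static_cast<int32_t>(A[k]) *
                                  static_cast<int32_t>(B[k * ldb + n]);
                C[n] = (zero_mode ? 0 : C[n]) + product;
            }
            zero_mode = false;
        }
    }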
- jnz .LSkipAccumulateOutput4By4 - vpaddd xmm1,xmm1,XMMWORD PTR [rbx] - -.LSkipAccumulateOutput4By4: - vmovdqu XMMWORD PTR [rbx],xmm1 - and ebp,3 # (CountN & 3) != 0? - jz .LAdvanceRowLoop4 - add rdx,4 # advance matrix B by 4 bytes - add rbx,4*4 # advance matrix C by 4 columns - -// -// Process the remaining 1 to 3 columns from the 4 rows. -// - -.LProcessColumn4BySmallN: - mov DWORD PTR .LGemvU8S8KernelFrame_mask[rsp],ebp - vbroadcastss xmm2,DWORD PTR .LGemvU8S8KernelFrame_mask[rsp] - vpcmpgtd xmm2,xmm2,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - vpxor xmm1,xmm1,xmm1 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - cmp ebp,2 # (CountN & 2) != 0? - jb .LProcessColumn4By1 - vpinsrw xmm1,xmm1,WORD PTR [rdx],0 - vpinsrw xmm1,xmm1,WORD PTR [rdx+r9],2 - vpinsrw xmm1,xmm1,WORD PTR [rax],4 - vpinsrw xmm1,xmm1,WORD PTR [rax+r9],6 - je .LComputeOutput4BySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rdx+2],2 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+r9+2],6 - vpinsrb xmm1,xmm1,BYTE PTR [rax+2],10 - vpinsrb xmm1,xmm1,BYTE PTR [rax+r9+2],14 - jmp .LComputeOutput4BySmallN - -.LProcessColumn4By1: - vpinsrb xmm1,xmm1,BYTE PTR [rdx],0 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+r9],4 - vpinsrb xmm1,xmm1,BYTE PTR [rax],8 - vpinsrb xmm1,xmm1,BYTE PTR [rax+r9],12 - -.LComputeOutput4BySmallN: - vpshufb xmm1,xmm1,XMMWORD PTR C_UNDERSCORE(MlasTranspose4x4BytesAvx)[rip] - vpmaddubsw xmm1,xmm0,xmm1 # multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 # ZeroMode? - jnz .LStoreOutput4BySmallN - vpmaskmovd xmm3,xmm2,XMMWORD PTR [rbx] - vpaddd xmm1,xmm1,xmm3 - -.LStoreOutput4BySmallN: - vpmaskmovd XMMWORD PTR [rbx],xmm2,xmm1 - jmp .LAdvanceRowLoop4 - -// -// Broadcast the remaining 1 to 3 values from vector A. -// - -.LProcessRemainingSmallK: - vpxor xmm5,xmm5,xmm5 # keep zero vector for vpinsrb/vpinsrw - cmp ecx,2 - jb .LLoadVectorASingleRemainingByte - vpinsrw xmm0,xmm5,WORD PTR [rdi],0 - je .LBroadcastVectorARemainingBytes - vpinsrb xmm0,xmm0,BYTE PTR [rdi+2],2 - jmp .LBroadcastVectorARemainingBytes - -.LLoadVectorASingleRemainingByte: - vpinsrb xmm0,xmm5,BYTE PTR [rdi],0 - -.LBroadcastVectorARemainingBytes: - vpshufd xmm0,xmm0,0 # broadcast values - -// -// Process a set of 4 columns from the remaining rows. -// - -.LProcessColumnLoopSmallKBy4: - cmp r8d,4 - jb .LProcessColumnLoopSmallKBySmallN - vmovd xmm1,DWORD PTR [rsi] - cmp ecx,2 - jb .LComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+r9],1 - je .LComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+r9*2],2 - -.LComputeOutputSmallKBy4: - vpshufb xmm1,xmm1,XMMWORD PTR C_UNDERSCORE(MlasTranspose4x4BytesAvx)[rip] - vpmaddubsw xmm1,xmm0,xmm1 # multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutputSmallKBy4 - vpaddd xmm1,xmm1,XMMWORD PTR [r10] - -.LSkipAccumulateOutputSmallKBy4: - vmovdqu XMMWORD PTR [r10],xmm1 - add rsi,4 # advance matrix B by 4 bytes - add r10,4*4 # advance matrix C by 4 columns - sub r8d,4 # decrement CountN - jnz .LProcessColumnLoopSmallKBy4 - jmp .LExitKernel - -// -// Process the remaining 1 to 3 columns from the remaining rows. -// -// Single step through each of the columns to keep code size small for the -// uncommon path (typically the row count is a multiple of 4). 
-// - -.LProcessColumnLoopSmallKBySmallN: - vpinsrb xmm1,xmm5,BYTE PTR [rsi],0 - cmp ecx,2 - jb .LComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+r9],1 - je .LComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+r9*2],2 - -.LComputeOutputSmallKBySmallN: - vpmaddubsw xmm1,xmm0,xmm1 # multiply and reduce - vpmaddwd xmm1,xmm1,xmm6 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutputSmallKBySmallN - vmovd xmm3,DWORD PTR [r10] - vpaddd xmm1,xmm1,xmm3 - -.LSkipAccumulateOutputSmallKBySmallN: - vmovd DWORD PTR [r10],xmm1 - inc rsi # advance matrix B by 1 byte - add r10,4 # advance matrix C by 1 column - dec r8 - jnz .LProcessColumnLoopSmallKBySmallN - jmp .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Common.h b/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Common.h deleted file mode 100644 index c5a45c6cfef5e..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Common.h +++ /dev/null @@ -1,356 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemvU8S8KernelAvx512Common.h - -Abstract: - - This module contains common kernel macros and structures for the quantized - integer matrix/vector multiply operation (QGEMV) for the AVX512 core and - AVX512VNNI kernels. - ---*/ - -// -// Stack frame layout for the U8S8 kernel. -// - - .equ .LGemvU8S8KernelFrame_SavedRbx, 0 - .equ .LGemvU8S8KernelFrame_SavedRbp, 8 - .equ .LGemvU8S8KernelFrame_ReturnAddress, 16 - -/*++ - -Macro Description: - - This macro generates the common AVX512 code for the inner kernel to compute - matrix/vector multiplication. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro GemvU8S8KernelAvx512Function Isa - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix/vector multiplication. - -Arguments: - - A (rdi) - Supplies the address of vector A. - - B (rsi) - Supplies the address of matrix B. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from vector A and the number - of rows from matrix B to iterate over. - - CountN (r8) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb (r9) - Supplies the first dimension of matrix B. - -Return Value: - - None. - ---*/ - - .globl C_UNDERSCORE(MlasGemvU8S8Kernel\Isa\()) -C_UNDERSCORE(MlasGemvU8S8Kernel\Isa\()): - - push rbp - push rbx - - mov rbx,rcx - mov ecx,r8d - and ecx,15 # isolate unaligned count - mov eax,1 - shl eax,cl - dec eax - kmovw k1,eax # compute vector load/store mask - mov rcx,rbx - mov r10,rdx - mov r11,rsp # set ZeroMode to any non-zero value -.ifeqs "\Isa\()", "Avx512Core" - mov eax,1 - vpbroadcastw zmm29,eax -.endif - -// -// Process 4 rows of matrix B in a loop. -// - - sub rcx,4 - jb .LProcessRemainingRows - -.LProcessRowLoop4: - mov rdx,rsi # reload matrix B - lea rsi,[rsi+r9*4] # advance matrix B by 4 rows - mov rbx,r10 # reload matrix C - mov rbp,r8 # reload CountN - vpbroadcastd zmm28,DWORD PTR [rdi] - add rdi,4 # advance matrix A by 4 bytes - -// -// Process sets of 64 columns from the 4 rows in a loop. -// -// Some permute operations are deferred until the final store of the 4x64 block -// as these permutes are expensive. 
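The byte and word interleaves in the lines that follow only rearrange the four loaded rows of B so that the four bytes belonging to one output column become contiguous, which is the 4-byte grouping vpdpbusds (or vpmaddubsw) consumes per 32-bit lane against the broadcast four bytes of A. A scalar C++ sketch of that rearrangement (the in-register lane order is more involved, hence the deferred permutes, but the per-column grouping is the point; names are illustrative):

    #include <cstddef>
    #include <cstdint>

    // Gather one byte from each of four consecutive B rows per column so the
    // quad for column n sits at out[4*n .. 4*n+3].
    void InterleaveFourRows(const int8_t* row0, const int8_t* row1,
                            const int8_t* row2, const int8_t* row3,
                            int8_t* out, size_t columns) {
        for (size_t n = 0; n < columns; ++n) {
            out[4 * n + 0] = row0[n];
            out[4 * n + 1] = row1[n];
            out[4 * n + 2] = row2[n];
            out[4 * n + 3] = row3[n];
        }
    }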
-// - -.LProcessColumnLoop4By64: - cmp rbp,64 - jb .LProcessColumnLoop4By16 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovdqu32 zmm16,ZMMWORD PTR [rdx] - vmovdqu32 zmm17,ZMMWORD PTR [rdx+r9] - vmovdqu32 zmm18,ZMMWORD PTR [rax] - vmovdqu32 zmm19,ZMMWORD PTR [rax+r9] - vpunpcklbw zmm20,zmm16,zmm17 # interleave row data bytes - vpunpckhbw zmm21,zmm16,zmm17 - vpunpcklbw zmm22,zmm18,zmm19 - vpunpckhbw zmm23,zmm18,zmm19 - vpunpcklwd zmm16,zmm20,zmm22 # interleave row data words - vpunpckhwd zmm17,zmm20,zmm22 - vpunpcklwd zmm18,zmm21,zmm23 - vpunpckhwd zmm19,zmm21,zmm23 -.ifeqs "\Isa\()", "Avx512Core" - vpmaddubsw zmm16,zmm28,zmm16 - vpmaddwd zmm20,zmm16,zmm29 - vpmaddubsw zmm17,zmm28,zmm17 - vpmaddwd zmm21,zmm17,zmm29 - vpmaddubsw zmm18,zmm28,zmm18 - vpmaddwd zmm22,zmm18,zmm29 - vpmaddubsw zmm19,zmm28,zmm19 - vpmaddwd zmm23,zmm19,zmm29 -.else - vpxord zmm20,zmm20,zmm20 - vpxord zmm21,zmm21,zmm21 - vpxord zmm22,zmm22,zmm22 - vpxord zmm23,zmm23,zmm23 - VpdpbusdsZmmZmmZmm zmm20,zmm28,zmm16 - VpdpbusdsZmmZmmZmm zmm21,zmm28,zmm17 - VpdpbusdsZmmZmmZmm zmm22,zmm28,zmm18 - VpdpbusdsZmmZmmZmm zmm23,zmm28,zmm19 -.endif - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By64 - vpaddd zmm20,zmm20,ZMMWORD PTR [rbx] - vpaddd zmm21,zmm21,ZMMWORD PTR [rbx+16*4] - vpaddd zmm22,zmm22,ZMMWORD PTR [rbx+32*4] - vpaddd zmm23,zmm23,ZMMWORD PTR [rbx+48*4] - -.LSkipAccumulateOutput4By64: - cmp rcx,4 # final 4x64 block? - jae .LStoreOutput4By64 - vextracti32x4 XMMWORD PTR [rbx],zmm20,0 - vextracti32x4 XMMWORD PTR [rbx+4*4],zmm21,0 - vextracti32x4 XMMWORD PTR [rbx+8*4],zmm22,0 - vextracti32x4 XMMWORD PTR [rbx+12*4],zmm23,0 - vextracti32x4 XMMWORD PTR [rbx+16*4],zmm20,1 - vextracti32x4 XMMWORD PTR [rbx+20*4],zmm21,1 - vextracti32x4 XMMWORD PTR [rbx+24*4],zmm22,1 - vextracti32x4 XMMWORD PTR [rbx+28*4],zmm23,1 - vextracti32x4 XMMWORD PTR [rbx+32*4],zmm20,2 - vextracti32x4 XMMWORD PTR [rbx+36*4],zmm21,2 - vextracti32x4 XMMWORD PTR [rbx+40*4],zmm22,2 - vextracti32x4 XMMWORD PTR [rbx+44*4],zmm23,2 - vextracti32x4 XMMWORD PTR [rbx+48*4],zmm20,3 - vextracti32x4 XMMWORD PTR [rbx+52*4],zmm21,3 - vextracti32x4 XMMWORD PTR [rbx+56*4],zmm22,3 - vextracti32x4 XMMWORD PTR [rbx+60*4],zmm23,3 - jmp .LAdvanceColumnLoop64 - -.LStoreOutput4By64: - vmovdqu32 ZMMWORD PTR [rbx],zmm20 - vmovdqu32 ZMMWORD PTR [rbx+16*4],zmm21 - vmovdqu32 ZMMWORD PTR [rbx+32*4],zmm22 - vmovdqu32 ZMMWORD PTR [rbx+48*4],zmm23 - -.LAdvanceColumnLoop64: - add rdx,64 # advance matrix B by 64 bytes - add rbx,64*4 # advance matrix C by 64 columns - sub rbp,64 # decrement CountN - jnz .LProcessColumnLoop4By64 - -.LAdvanceRowLoop4: - xor r11,r11 # clear ZeroMode - sub rcx,4 # decrement CountK - jae .LProcessRowLoop4 - -.LProcessRemainingRows: - add rcx,4 # correct for over-subtract above - jnz .LProcessRemainingSmallK - -.LExitKernel: - vzeroupper - - pop rbx - pop rbp - ret - -// -// Process sets of 16 columns from the 4 rows in a loop or process the remaining -// 1 to 15 columns. 
-// - -.LProcessColumnLoop4By16: - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - cmp ebp,16 - jb .LLoadPartialVector4BySmallN - vmovdqu xmm2,XMMWORD PTR [rdx] - vmovdqu xmm3,XMMWORD PTR [rdx+r9] - vmovdqu xmm4,XMMWORD PTR [rax] - vmovdqu xmm5,XMMWORD PTR [rax+r9] - jmp .LComputeOutput4By16 - -.LLoadPartialVector4BySmallN: - vmovdqu8 zmm2{k1}{z},ZMMWORD PTR [rdx] - vmovdqu8 zmm3{k1}{z},ZMMWORD PTR [rdx+r9] - vmovdqu8 zmm4{k1}{z},ZMMWORD PTR [rax] - vmovdqu8 zmm5{k1}{z},ZMMWORD PTR [rax+r9] - -.LComputeOutput4By16: - vpunpcklbw xmm1,xmm2,xmm3 # interleave row data bytes - vpunpckhbw xmm2,xmm2,xmm3 - vpunpcklbw xmm3,xmm4,xmm5 - vpunpckhbw xmm4,xmm4,xmm5 - vpunpcklwd xmm5,xmm1,xmm3 # interleave row data words - vpunpckhwd xmm1,xmm1,xmm3 - vpunpcklwd xmm3,xmm2,xmm4 - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm5,ymm5,xmm1,1 # concatenate 256-bit vector - vinserti128 ymm3,ymm3,xmm2,1 - vshufi32x4 zmm16,zmm5,zmm3,0x44 # concatenate 512-bit vector -.ifeqs "\Isa\()", "Avx512Core" - vpmaddubsw zmm16,zmm28,zmm16 - vpmaddwd zmm20,zmm16,zmm29 -.else - vpxord zmm20,zmm20,zmm20 - VpdpbusdsZmmZmmZmm zmm20,zmm28,zmm16 -.endif - cmp ebp,16 - jb .LStorePartialVector4BySmallN - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By16 - vpaddd zmm20,zmm20,ZMMWORD PTR [rbx] - -.LSkipAccumulateOutput4By16: - vmovdqu32 ZMMWORD PTR [rbx],zmm20 - add rdx,16 # advance matrix B by 16 bytes - add rbx,16*4 # advance matrix C by 16 columns - sub ebp,16 # decrement CountN - jnz .LProcessColumnLoop4By16 - jmp .LAdvanceRowLoop4 - -.LStorePartialVector4BySmallN: - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4BySmallN - vpaddd zmm20{k1}{z},zmm20,ZMMWORD PTR [rbx] - -.LSkipAccumulateOutput4BySmallN: - vmovdqu32 ZMMWORD PTR [rbx]{k1},zmm20 - jmp .LAdvanceRowLoop4 - -// -// Broadcast the remaining 1 to 3 values from vector A. -// - -.LProcessRemainingSmallK: - vpxor xmm0,xmm0,xmm0 - cmp ecx,2 - jb .LLoadVectorASingleRemainingByte - vpinsrw xmm0,xmm0,WORD PTR [rdi],0 - je .LBroadcastVectorARemainingBytes - vpinsrb xmm0,xmm0,BYTE PTR [rdi+2],2 - jmp .LBroadcastVectorARemainingBytes - -.LLoadVectorASingleRemainingByte: - vpinsrb xmm0,xmm0,BYTE PTR [rdi],0 - -.LBroadcastVectorARemainingBytes: - vpbroadcastd zmm28,xmm0 # broadcast values - -// -// Process sets of 16 columns from the remaining rows in a loop or process the -// remaining 1 to 15 columns. 
-// - -.LProcessColumnLoopSmallKBy16: - vpxor xmm3,xmm3,xmm3 # clear optional row vectors - vpxor xmm4,xmm4,xmm4 - vpxor xmm5,xmm5,xmm5 - cmp r8d,16 - jb .LLoadPartialVectorSmallKBySmallN - vmovdqu xmm2,XMMWORD PTR [rsi] - cmp ecx,2 - jb .LComputeOutputSmallKBy16 - vmovdqu xmm3,XMMWORD PTR [rsi+r9] - je .LComputeOutputSmallKBy16 - vmovdqu xmm4,XMMWORD PTR [rsi+r9*2] - jmp .LComputeOutputSmallKBy16 - -.LLoadPartialVectorSmallKBySmallN: - vmovdqu8 zmm2{k1}{z},ZMMWORD PTR [rsi] - cmp ecx,2 - jb .LComputeOutputSmallKBy16 - vmovdqu8 zmm3{k1}{z},ZMMWORD PTR [rsi+r9] - je .LComputeOutputSmallKBy16 - vmovdqu8 zmm4{k1}{z},ZMMWORD PTR [rsi+r9*2] - jmp .LComputeOutputSmallKBy16 - -.LComputeOutputSmallKBy16: - vpunpcklbw xmm1,xmm2,xmm3 # interleave row data bytes - vpunpckhbw xmm2,xmm2,xmm3 - vpunpcklbw xmm3,xmm4,xmm5 - vpunpckhbw xmm4,xmm4,xmm5 - vpunpcklwd xmm5,xmm1,xmm3 # interleave row data words - vpunpckhwd xmm1,xmm1,xmm3 - vpunpcklwd xmm3,xmm2,xmm4 - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm5,ymm5,xmm1,1 # concatenate 256-bit vector - vinserti128 ymm3,ymm3,xmm2,1 - vshufi32x4 zmm16,zmm5,zmm3,0x44 # concatenate 512-bit vector -.ifeqs "\Isa\()", "Avx512Core" - vpmaddubsw zmm16,zmm28,zmm16 - vpmaddwd zmm20,zmm16,zmm29 -.else - vpxord zmm20,zmm20,zmm20 - VpdpbusdsZmmZmmZmm zmm20,zmm28,zmm16 -.endif - cmp r8d,16 - jb .LStorePartialVectorSmallKBySmallN - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutputSmallKBy16 - vpaddd zmm20,zmm20,ZMMWORD PTR [r10] - -.LSkipAccumulateOutputSmallKBy16: - vmovdqu32 ZMMWORD PTR [r10],zmm20 - add rsi,16 # advance matrix B by 16 bytes - add r10,16*4 # advance matrix C by 16 columns - sub r8d,16 # decrement CountN - jnz .LProcessColumnLoopSmallKBy16 - jmp .LExitKernel - -.LStorePartialVectorSmallKBySmallN: - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutputSmallKBySmallN - vpaddd zmm20{k1}{z},zmm20,ZMMWORD PTR [r10] - -.LSkipAccumulateOutputSmallKBySmallN: - vmovdqu32 ZMMWORD PTR [r10]{k1},zmm20 - jmp .LExitKernel - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S b/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S deleted file mode 100644 index 841eed0d53caa..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Core.S +++ /dev/null @@ -1,33 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemvU8S8KernelAvx512Core.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/vector - multiply operation (QGEMV). - - This implementation uses AVX512 core instructions (BW/DQ/VL). - ---*/ - -#include "asmmacro.h" -#include "QgemvU8S8KernelAvx512Common.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMV kernel. -// - -GemvU8S8KernelAvx512Function Avx512Core - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S deleted file mode 100644 index 6e0281b632cdd..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvx512Vnni.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemvU8S8KernelAvx512Vnni.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/vector - multiply operation (QGEMV). - - This implementation uses AVX512VNNI instructions. 
- ---*/ - -#include "asmmacro.h" -#include "QgemvU8S8KernelAvx512Common.h" -#include "AssembleAvx512Vnni.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMV kernel. -// - -GemvU8S8KernelAvx512Function Avx512Vnni - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvxVnni.S b/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvxVnni.S deleted file mode 100644 index 2d8f7d656e687..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/QgemvU8S8KernelAvxVnni.S +++ /dev/null @@ -1,344 +0,0 @@ -/*++ - -Copyright (c) 2020 Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - QgemvU8S8KernelAvxVnni.s - -Abstract: - - This module implements the kernels for the quantized integer matrix/vector - multiply operation (QGEMV). - - This implementation uses AVXVNNI instructions. - ---*/ - -#include "asmmacro.h" -#include "AssembleAvxVnni.h" - - .intel_syntax noprefix - -// -// Stack frame layout for the U8S8 kernel. -// - - .equ .LGemvU8S8KernelFrame_mask, -8 - .equ .LGemvU8S8KernelFrame_SavedRbx, 0 - .equ .LGemvU8S8KernelFrame_SavedRbp, 8 - .equ .LGemvU8S8KernelFrame_ReturnAddress, 16 - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix/vector multiplication. - -Arguments: - - A (rdi) - Supplies the address of vector A. - - B (rsi) - Supplies the address of matrix B. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from vector A and the number - of rows from matrix B to iterate over. - - CountN (r8) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb (r9) - Supplies the first dimension of matrix B. - -Return Value: - - None. - ---*/ - - .globl C_UNDERSCORE(MlasGemvU8S8KernelAvxVnni) -C_UNDERSCORE(MlasGemvU8S8KernelAvxVnni): - - push rbp - push rbx - - mov r10,rdx - mov r11,rsp # set ZeroMode to any non-zero value - -// -// Process 4 rows of matrix B in a loop. -// - - sub rcx,4 - jb .LProcessRemainingRows - -.LProcessRowLoop4: - mov rdx,rsi # reload matrix B - lea rsi,[rsi+r9*4] # advance matrix B by 4 rows - mov rbx,r10 # reload matrix C - mov rbp,r8 # reload CountN - vpbroadcastd ymm0,DWORD PTR [rdi] - add rdi,4 # advance matrix A by 4 bytes - -// -// Process sets of 32 columns from the 4 rows in a loop. -// -// Some permute operations are deferred until the final store of the 4x32 block -// as these permutes are expensive. -// - -.LProcessColumnLoop4By32: - cmp rbp,32 - jb .LProcessColumnLoop4By8 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovdqu ymm2,YMMWORD PTR [rdx] - vmovdqu ymm3,YMMWORD PTR [rdx+r9] - vmovdqu ymm4,YMMWORD PTR [rax] - vmovdqu ymm5,YMMWORD PTR [rax+r9] - vpunpcklbw ymm1,ymm2,ymm3 # interleave row data bytes - vpunpckhbw ymm2,ymm2,ymm3 - vpxor ymm7,ymm7,ymm7 - vpunpcklbw ymm3,ymm4,ymm5 - vpunpckhbw ymm4,ymm4,ymm5 - vpxor ymm8,ymm8,ymm8 - vpunpcklwd ymm5,ymm1,ymm3 # interleave row data words - vpunpckhwd ymm1,ymm1,ymm3 - vpxor ymm9,ymm9,ymm9 - vpunpcklwd ymm3,ymm2,ymm4 - vpunpckhwd ymm2,ymm2,ymm4 - vpxor ymm10,ymm10,ymm10 - VpdpbusdsYmmYmmYmm ymm7,ymm0,ymm5 - VpdpbusdsYmmYmmYmm ymm8,ymm0,ymm1 - VpdpbusdsYmmYmmYmm ymm9,ymm0,ymm3 - VpdpbusdsYmmYmmYmm ymm10,ymm0,ymm2 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By32 - vpaddd ymm7,ymm7,YMMWORD PTR [rbx] - vpaddd ymm8,ymm8,YMMWORD PTR [rbx+32] - vpaddd ymm9,ymm9,YMMWORD PTR [rbx+64] - vpaddd ymm10,ymm10,YMMWORD PTR [rbx+96] - -.LSkipAccumulateOutput4By32: - cmp rcx,4 # final 4x32 block? 
- jae .LStoreOutput4By32 - vperm2i128 ymm4,ymm7,ymm8,0x31 # interleave vector results - vperm2i128 ymm7,ymm7,ymm8,0x20 - vperm2i128 ymm8,ymm9,ymm10,0x20 - vperm2i128 ymm10,ymm9,ymm10,0x31 - vmovaps ymm9,ymm4 - -.LStoreOutput4By32: - vmovdqu YMMWORD PTR [rbx],ymm7 - vmovdqu YMMWORD PTR [rbx+32],ymm8 - vmovdqu YMMWORD PTR [rbx+64],ymm9 - vmovdqu YMMWORD PTR [rbx+96],ymm10 - add rdx,32 # advance matrix B by 32 bytes - add rbx,32*4 # advance matrix C by 32 columns - sub rbp,32 # decrement CountN - jnz .LProcessColumnLoop4By32 - -.LAdvanceRowLoop4: - xor r11,r11 # clear ZeroMode - sub rcx,4 # decrement CountK - jae .LProcessRowLoop4 - -.LProcessRemainingRows: - add rcx,4 # correct for over-subtract above - jnz .LProcessRemainingSmallK - -// -// Restore non-volatile registers and return. -// - -.LExitKernel: - vzeroupper - - pop rbx - pop rbp - ret - -// -// Process sets of 8 columns from the 4 rows in a loop. -// - -.LProcessColumnLoop4By8: - cmp ebp,8 - jb .LProcessColumn4By4 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovq xmm2,QWORD PTR [rdx] - vmovq xmm3,QWORD PTR [rdx+r9] - vmovq xmm4,QWORD PTR [rax] - vmovq xmm5,QWORD PTR [rax+r9] - vpunpcklbw xmm2,xmm2,xmm3 # interleave row data bytes - vpunpcklbw xmm4,xmm4,xmm5 - vpunpcklwd xmm1,xmm2,xmm4 # interleave row data words - vpunpckhwd xmm2,xmm2,xmm4 - vinserti128 ymm1,ymm1,xmm2,1 # concatenate vector - vpxor ymm8,ymm8,ymm8 - VpdpbusdsYmmYmmYmm ymm8,ymm0,ymm1 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By8 - vpaddd ymm8,ymm8,YMMWORD PTR [rbx] - -.LSkipAccumulateOutput4By8: - vmovdqu YMMWORD PTR [rbx],ymm8 - add rdx,8 # advance matrix B by 8 bytes - add rbx,8*4 # advance matrix C by 8 columns - sub ebp,8 # decrement CountN - jnz .LProcessColumnLoop4By8 - jmp .LAdvanceRowLoop4 - -// -// Process a set of 4 columns from the 4 rows. -// - -.LProcessColumn4By4: - test ebp,4 # (CountN & 4) != 0? - jz .LProcessColumn4BySmallN - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - vmovd xmm1,DWORD PTR [rdx] - vpinsrd xmm1,xmm1,DWORD PTR [rdx+r9],1 - vpinsrd xmm1,xmm1,DWORD PTR [rax],2 - vpinsrd xmm1,xmm1,DWORD PTR [rax+r9],3 - vpshufb xmm1,xmm1,XMMWORD PTR C_UNDERSCORE(MlasTranspose4x4BytesAvx)[rip] - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutput4By4 - vpaddd xmm8,xmm8,XMMWORD PTR [rbx] - -.LSkipAccumulateOutput4By4: - vmovdqu XMMWORD PTR [rbx],xmm8 - and ebp,3 # (CountN & 3) != 0? - jz .LAdvanceRowLoop4 - add rdx,4 # advance matrix B by 4 bytes - add rbx,4*4 # advance matrix C by 4 columns - -// -// Process the remaining 1 to 3 columns from the 4 rows. -// - -.LProcessColumn4BySmallN: - mov DWORD PTR .LGemvU8S8KernelFrame_mask[rsp],ebp - vbroadcastss xmm2,DWORD PTR .LGemvU8S8KernelFrame_mask[rsp] - vpcmpgtd xmm2,xmm2,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - vpxor xmm1,xmm1,xmm1 - lea rax,[rdx+r9*2] # compute matrix B plus 2 rows - cmp ebp,2 # (CountN & 2) != 0? 
- jb .LProcessColumn4By1 - vpinsrw xmm1,xmm1,WORD PTR [rdx],0 - vpinsrw xmm1,xmm1,WORD PTR [rdx+r9],2 - vpinsrw xmm1,xmm1,WORD PTR [rax],4 - vpinsrw xmm1,xmm1,WORD PTR [rax+r9],6 - je .LComputeOutput4BySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rdx+2],2 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+r9+2],6 - vpinsrb xmm1,xmm1,BYTE PTR [rax+2],10 - vpinsrb xmm1,xmm1,BYTE PTR [rax+r9+2],14 - jmp .LComputeOutput4BySmallN - -.LProcessColumn4By1: - vpinsrb xmm1,xmm1,BYTE PTR [rdx],0 - vpinsrb xmm1,xmm1,BYTE PTR [rdx+r9],4 - vpinsrb xmm1,xmm1,BYTE PTR [rax],8 - vpinsrb xmm1,xmm1,BYTE PTR [rax+r9],12 - -.LComputeOutput4BySmallN: - vpshufb xmm1,xmm1,XMMWORD PTR C_UNDERSCORE(MlasTranspose4x4BytesAvx)[rip] - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 # ZeroMode? - jnz .LStoreOutput4BySmallN - vpmaskmovd xmm3,xmm2,XMMWORD PTR [rbx] - vpaddd xmm8,xmm8,xmm3 - -.LStoreOutput4BySmallN: - vpmaskmovd XMMWORD PTR [rbx],xmm2,xmm8 - jmp .LAdvanceRowLoop4 - -// -// Broadcast the remaining 1 to 3 values from vector A. -// - -.LProcessRemainingSmallK: - vpxor xmm5,xmm5,xmm5 # keep zero vector for vpinsrb/vpinsrw - cmp ecx,2 - jb .LLoadVectorASingleRemainingByte - vpinsrw xmm0,xmm5,WORD PTR [rdi],0 - je .LBroadcastVectorARemainingBytes - vpinsrb xmm0,xmm0,BYTE PTR [rdi+2],2 - jmp .LBroadcastVectorARemainingBytes - -.LLoadVectorASingleRemainingByte: - vpinsrb xmm0,xmm5,BYTE PTR [rdi],0 - -.LBroadcastVectorARemainingBytes: - vpshufd xmm0,xmm0,0 # broadcast values - -// -// Process a set of 4 columns from the remaining rows. -// - -.LProcessColumnLoopSmallKBy4: - cmp r8d,4 - jb .LProcessColumnLoopSmallKBySmallN - vmovd xmm1,DWORD PTR [rsi] - cmp ecx,2 - jb .LComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+r9],1 - je .LComputeOutputSmallKBy4 - vpinsrd xmm1,xmm1,DWORD PTR [rsi+r9*2],2 - -.LComputeOutputSmallKBy4: - vpshufb xmm1,xmm1,XMMWORD PTR C_UNDERSCORE(MlasTranspose4x4BytesAvx)[rip] - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutputSmallKBy4 - vpaddd xmm8,xmm8,XMMWORD PTR [r10] - -.LSkipAccumulateOutputSmallKBy4: - vmovdqu XMMWORD PTR [r10],xmm8 - add rsi,4 # advance matrix B by 4 bytes - add r10,4*4 # advance matrix C by 4 columns - sub r8d,4 # decrement CountN - jnz .LProcessColumnLoopSmallKBy4 - jmp .LExitKernel - -// -// Process the remaining 1 to 3 columns from the remaining rows. -// -// Single step through each of the columns to keep code size small for the -// uncommon path (typically the row count is a multiple of 4). -// - -.LProcessColumnLoopSmallKBySmallN: - vpinsrb xmm1,xmm5,BYTE PTR [rsi],0 - cmp ecx,2 - jb .LComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+r9],1 - je .LComputeOutputSmallKBySmallN - vpinsrb xmm1,xmm1,BYTE PTR [rsi+r9*2],2 - -.LComputeOutputSmallKBySmallN: - vpxor xmm8,xmm8,xmm8 - VpdpbusdsXmmXmmXmm xmm8,xmm0,xmm1 - test r11,r11 # ZeroMode? - jnz .LSkipAccumulateOutputSmallKBySmallN - vmovd xmm3,DWORD PTR [r10] - vpaddd xmm8,xmm8,xmm3 - -.LSkipAccumulateOutputSmallKBySmallN: - vmovd DWORD PTR [r10],xmm8 - inc rsi # advance matrix B by 1 byte - add r10,4 # advance matrix C by 1 column - dec r8 - jnz .LProcessColumnLoopSmallKBySmallN - jmp .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S deleted file mode 100644 index 2163708dcb352..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx.S +++ /dev/null @@ -1,378 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
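For orientation, the quantized matrix/vector multiply implemented by the deleted GEMV kernels above boils down to an int32 accumulation of an unsigned 8-bit vector A against a signed 8-bit matrix B. A minimal scalar C++ sketch of that contract (GemvU8S8Reference and the accumulate flag are hypothetical stand-ins for the kernel's ZeroMode handling, not an MLAS API):

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for the deleted AVX-VNNI / AVX512-VNNI GEMV
// kernels: C[n] accumulates sum_k (uint8)A[k] * (int8)B[k*ldb + n] in int32.
// When accumulate is false (the kernels' initial "ZeroMode"), C is overwritten.
void GemvU8S8Reference(const uint8_t* A, const int8_t* B, int32_t* C,
                       size_t CountK, size_t CountN, size_t ldb,
                       bool accumulate) {
    for (size_t n = 0; n < CountN; n++) {
        int32_t sum = accumulate ? C[n] : 0;
        for (size_t k = 0; k < CountK; k++) {
            sum += static_cast<int32_t>(A[k]) * static_cast<int32_t>(B[k * ldb + n]);
        }
        C[n] = sum;
    }
}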
- -Licensed under the MIT License. - -Module Name: - - SconvKernelAvx.S - -Abstract: - - This module implements the kernels for the single precision convolution - operation. - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" -#include "SconvKernelAvxCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for FilterCount by OutputCount block - of the output buffer. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - - VectorOffset - Supplies the byte offset from the filter buffer to fetch - elements. - - BroadcastOffset - Supplies the byte offset from the input buffer to fetch - elements. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rdx - Supplies the address of the filter buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbx - Supplies the address of the filter buffer plus 2 * FilterStride. - - r9 - Supplies the StrideWidth parameter (see function description). - - ymm0-ymm7 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -.ifeqs "\KernelType\()","Depthwise" - vmovups ymm12,YMMWORD PTR [rdx] - EmitIfCountGE \OutputCount\(), 1, "vmulps ymm8,ymm12,YMMWORD PTR [rcx]" - EmitIfCountGE \OutputCount\(), 1, "vaddps ymm0,ymm0,ymm8" - EmitIfCountGE \OutputCount\(), 2, "vmulps ymm9,ymm12,YMMWORD PTR [rcx+r9]" - EmitIfCountGE \OutputCount\(), 2, "vaddps ymm4,ymm4,ymm9" -.else - EmitIfCountGE \OutputCount\(), 1, "vbroadcastss ymm13,DWORD PTR [rcx+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 2, "vbroadcastss ymm14,DWORD PTR [rcx+r9+\BroadcastOffset\()]" -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "vmulps ymm8,ymm13,YMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 1, "vaddps ymm0,ymm0,ymm8" - EmitIfCountGE \FilterCount\(), 2, "vmulps ymm9,ymm13,YMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 2, "vaddps ymm1,ymm1,ymm9" - EmitIfCountGE \FilterCount\(), 3, "vmulps ymm10,ymm13,YMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 3, "vaddps ymm2,ymm2,ymm10" - EmitIfCountGE \FilterCount\(), 4, "vmulps ymm11,ymm13,YMMWORD PTR [rbx+rsi+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 4, "vaddps ymm3,ymm3,ymm11" -.else - EmitIfCountGE \FilterCount\(), 1, "vmovups ymm12,YMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vmulps ymm8,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vaddps ymm0,ymm0,ymm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vmulps ymm9,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vaddps ymm4,ymm4,ymm9" - EmitIfCountGE \FilterCount\(), 2, "vmovups ymm12,YMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vmulps ymm10,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vaddps ymm1,ymm1,ymm10" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vmulps ymm11,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vaddps ymm5,ymm5,ymm11" - EmitIfCountGE \FilterCount\(), 3, "vmovups ymm12,YMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vmulps ymm8,ymm13,ymm12" - 
EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vaddps ymm2,ymm2,ymm8" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vmulps ymm9,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vaddps ymm6,ymm6,ymm9" - EmitIfCountGE \FilterCount\(), 4, "vmovups ymm12,YMMWORD PTR [rbx+rsi+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vmulps ymm10,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vaddps ymm3,ymm3,ymm10" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vmulps ymm11,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vaddps ymm7,ymm7,ymm11" -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - rbp - Supplies the DilationWidth parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r15 - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount - -// -// Process the output blocks that include left padding. -// - - mov r10,\KernelFrame\()_OutputCountLeftPad[rsp] - test r10,r10 - jz .L\KernelType\().\FilterCount\().ProcessOutputCount - call MlasConv\KernelType\()FloatSingleAvxFilter\FilterCount\() - -// -// Process the output blocks that do not include any padding. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCount: - mov r10,\KernelFrame\()_OutputCount[rsp] - sub r10,2 - jb .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount - -.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2: - ProcessOutputCountN Avx, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 2 - lea rdi,[rdi+r9*2] # advance input by 2 elements - sub r10,2 - jae .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy2 - -.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: - add r10,2 # correct for over-subtract above - -// -// Process the output blocks that include right padding plus any remaining output -// blocks from above. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: - add r10,\KernelFrame\()_OutputCountRightPad[rsp] - jz .L\KernelType\().ExitKernel - call MlasConv\KernelType\()FloatSingleAvxFilter\FilterCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows for a pointwise convolution. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbp - Supplies the InputStride parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). 
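The ComputeBlock macro above targets AVX without FMA, so each accumulator update is emitted as a vmulps/vaddps pair: every output column broadcasts one input scalar, every filter row supplies an 8-wide vector, and the products are summed into the per-(filter, output) accumulators. A rough C++ sketch of that update, with hypothetical names and a flat accumulator layout:

#include <cstddef>

constexpr size_t kBlock = 8;  // the AVX kernels operate on 8-float (nchw8c) blocks

// Hypothetical reference for one ComputeBlock step: every output column o
// broadcasts one input scalar, every filter row f contributes a kBlock-wide
// vector, and acc[f][o][*] += input[o] * filter[f][*].
void ComputeBlockReference(float* acc,              // [FilterCount][OutputCount][kBlock]
                           const float* input,      // [OutputCount] broadcast scalars
                           const float* filter,     // [FilterCount][kBlock]
                           size_t FilterCount, size_t OutputCount) {
    for (size_t f = 0; f < FilterCount; f++) {
        for (size_t o = 0; o < OutputCount; o++) {
            for (size_t i = 0; i < kBlock; i++) {
                acc[(f * OutputCount + o) * kBlock + i] += input[o] * filter[f * kBlock + i];
            }
        }
    }
}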
- - r10 - Supplies the OutputCount parameter (see function description). - - r12 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseFilterCountN FilterCount - - sub r10,2 - jb .LPointwise.\FilterCount\().ProcessRemainingOutputCount - -.LPointwise.\FilterCount\().ProcessNextOutputCountBy2: - ProcessPointwiseOutputCountN Avx, 8, \FilterCount\(), 2 - lea rdi,[rdi+r9*2] # advance input by 2 elements - sub r10,2 - jae .LPointwise.\FilterCount\().ProcessNextOutputCountBy2 - -.LPointwise.\FilterCount\().ProcessRemainingOutputCount: - add r10,2 # correct for over-subtract above - jz .LPointwise.ExitKernel - ProcessPointwiseOutputCountN Avx, 8, \FilterCount\(), 1 - - .endm - -// -// Generate the convolution kernels. -// - - SconvKernelFunction Nchw, 8, Avx - SconvKernelFunction Nchwc, 8, Avx, BiasFilter - SconvKernelDepthwiseFunction 8, Avx - SconvKernelPointwiseFunction Avx, BiasFilter - -/*++ - -Macro Description: - - This macro generates code to process an output block after the inner - convolution kernel has executed and then stores the output block to the - output buffer. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - ---*/ - - .macro PostProcessBlock FilterCount, OutputCount - - .globl MlasConvPostProcessFloatAvxFilter\FilterCount\()Output\OutputCount\() -#if !defined(__APPLE__) - .hidden MlasConvPostProcessFloatAvxFilter\FilterCount\()Output\OutputCount\() -#endif -MlasConvPostProcessFloatAvxFilter\FilterCount\()Output\OutputCount\(): - - .globl MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() -#if !defined(__APPLE__) - .hidden MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\() -#endif -MlasConvPostProcessFloatFma3Filter\FilterCount\()Output\OutputCount\(): - -.if \FilterCount\() > 2 - lea rbx,[r8+rax*2] # compute output plus 2 rows -.endif - -// -// Test if the existing contents of the output buffer should be accumulated -// with the output block. -// - - test dl,MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vaddps ymm0,ymm0,YMMWORD PTR [r8]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vaddps ymm4,ymm4,YMMWORD PTR [r8+32]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vaddps ymm8,ymm8,YMMWORD PTR [r8+64]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vaddps ymm1,ymm1,YMMWORD PTR [r8+rax]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vaddps ymm5,ymm5,YMMWORD PTR [r8+rax+32]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vaddps ymm9,ymm9,YMMWORD PTR [r8+rax+64]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vaddps ymm2,ymm2,YMMWORD PTR [rbx]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vaddps ymm6,ymm6,YMMWORD PTR [rbx+32]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vaddps ymm10,ymm10,YMMWORD PTR [rbx+64]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vaddps ymm3,ymm3,YMMWORD PTR [rbx+rax]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vaddps ymm7,ymm7,YMMWORD PTR [rbx+rax+32]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vaddps ymm11,ymm11,YMMWORD PTR [rbx+rax+64]" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: - -// -// Test if the bias buffer should be accumulated with the output block. 
-// - - test dl,MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "vaddps ymm0,ymm0,YMMWORD PTR [rcx]" - EmitIfCountGE \FilterCount\(), 2, "vaddps ymm1,ymm1,YMMWORD PTR [rcx+32]" - EmitIfCountGE \FilterCount\(), 3, "vaddps ymm2,ymm2,YMMWORD PTR [rcx+64]" - EmitIfCountGE \FilterCount\(), 4, "vaddps ymm3,ymm3,YMMWORD PTR [rcx+96]" -.else - EmitIfCountGE \FilterCount\(), 1, "vmovups ymm12,YMMWORD PTR [rcx]" - EmitIfCountGE \FilterCount\(), 2, "vmovups ymm13,YMMWORD PTR [rcx+32]" - EmitIfCountGE \FilterCount\(), 3, "vmovups ymm14,YMMWORD PTR [rcx+64]" - EmitIfCountGE \FilterCount\(), 4, "vmovups ymm15,YMMWORD PTR [rcx+96]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vaddps ymm0,ymm0,ymm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vaddps ymm4,ymm4,ymm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vaddps ymm8,ymm8,ymm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vaddps ymm1,ymm1,ymm13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vaddps ymm5,ymm5,ymm13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vaddps ymm9,ymm9,ymm13" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vaddps ymm2,ymm2,ymm14" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vaddps ymm6,ymm6,ymm14" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vaddps ymm10,ymm10,ymm14" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vaddps ymm3,ymm3,ymm15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vaddps ymm7,ymm7,ymm15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vaddps ymm11,ymm11,ymm15" -.endif - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: - -// -// Test for fused ReLU activation. -// - - test dl,MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation - vxorps xmm15,xmm15,xmm15 - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vmaxps ymm0,ymm15,ymm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vmaxps ymm4,ymm15,ymm4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vmaxps ymm8,ymm15,ymm8" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vmaxps ymm1,ymm15,ymm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vmaxps ymm5,ymm15,ymm5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vmaxps ymm9,ymm15,ymm9" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vmaxps ymm2,ymm15,ymm2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vmaxps ymm6,ymm15,ymm6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vmaxps ymm10,ymm15,ymm10" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vmaxps ymm3,ymm15,ymm3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vmaxps ymm7,ymm15,ymm7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vmaxps ymm11,ymm15,ymm11" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: - -// -// Store the output block in the output buffer. 
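The PostProcessBlock epilogue above is driven by the MLAS_CONV_KERNEL_FLAG_* bits tested in dl: optionally accumulate the existing output, optionally add the per-channel bias, optionally apply ReLU, then store. A compact C++ sketch of that sequence for a single block of channels (PostProcessBlockReference is a hypothetical name; the flag values mirror the defines in SconvKernelCommon.h further down in this diff):

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Flag bits mirroring MLAS_CONV_KERNEL_FLAG_* from SconvKernelCommon.h.
constexpr uint32_t kAccumulateOutput = 0x1;
constexpr uint32_t kBiasAddition     = 0x2;
constexpr uint32_t kReluActivation   = 0x4;

// Hypothetical reference for the per-block epilogue: optionally add the
// existing output, optionally add the per-channel bias, optionally clamp
// with ReLU, then write the block back to the output buffer.
void PostProcessBlockReference(float* output, const float* acc, const float* bias,
                               size_t blockSize, uint32_t flags) {
    for (size_t c = 0; c < blockSize; c++) {
        float value = acc[c];
        if (flags & kAccumulateOutput) {
            value += output[c];
        }
        if (flags & kBiasAddition) {
            value += bias[c];
        }
        if (flags & kReluActivation) {
            value = std::max(value, 0.0f);
        }
        output[c] = value;
    }
}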
-// - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vmovups YMMWORD PTR [r8],ymm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vmovups YMMWORD PTR [r8+32],ymm4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vmovups YMMWORD PTR [r8+64],ymm8" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vmovups YMMWORD PTR [r8+rax],ymm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vmovups YMMWORD PTR [r8+rax+32],ymm5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vmovups YMMWORD PTR [r8+rax+64],ymm9" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vmovups YMMWORD PTR [rbx],ymm2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vmovups YMMWORD PTR [rbx+32],ymm6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vmovups YMMWORD PTR [rbx+64],ymm10" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vmovups YMMWORD PTR [rbx+rax],ymm3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vmovups YMMWORD PTR [rbx+rax+32],ymm7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vmovups YMMWORD PTR [rbx+rax+64],ymm11" - add_immed r8,\OutputCount\()*8*4 # advance output by N nchw8c blocks - ret - - .endm - - .irp FilterCount, 1, 2, 3, 4 - .irp OutputCount, 1, 2, 3 - PostProcessBlock \FilterCount\(), \OutputCount\() - .endr - .endr - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S deleted file mode 100644 index 55d2aa613f212..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvx512F.S +++ /dev/null @@ -1,524 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelAvx512F.s - -Abstract: - - This module implements the kernels for the single precision convolution - operation. - - This implementation uses AVX512F instructions. - ---*/ - -#include "asmmacro.h" -#include "SconvKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to clear the block accumulators. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - zmm0-zmm23 - Supplies the block accumulators. 
- ---*/ - - .macro ClearBlock FilterCount, OutputCount - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vpxord zmm0,zmm0,zmm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vpxord zmm4,zmm4,zmm4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vpxord zmm8,zmm8,zmm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 4, "vpxord zmm12,zmm12,zmm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 5, "vpxord zmm16,zmm16,zmm16" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 6, "vpxord zmm20,zmm20,zmm20" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vpxord zmm1,zmm1,zmm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vpxord zmm5,zmm5,zmm5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vpxord zmm9,zmm9,zmm9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vpxord zmm13,zmm13,zmm13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vpxord zmm17,zmm17,zmm17" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vpxord zmm21,zmm21,zmm21" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vpxord zmm2,zmm2,zmm2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vpxord zmm6,zmm6,zmm6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vpxord zmm10,zmm10,zmm10" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 4, "vpxord zmm14,zmm14,zmm14" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 5, "vpxord zmm18,zmm18,zmm18" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 6, "vpxord zmm22,zmm22,zmm22" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vpxord zmm3,zmm3,zmm3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vpxord zmm7,zmm7,zmm7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vpxord zmm11,zmm11,zmm11" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 4, "vpxord zmm15,zmm15,zmm15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 5, "vpxord zmm19,zmm19,zmm19" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 6, "vpxord zmm23,zmm23,zmm23" - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for FilterCount by OutputCount block - of the output buffer. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - - VectorOffset - Supplies the byte offset from the filter buffer to fetch - elements. - - BroadcastOffset - Supplies the byte offset from the input buffer to fetch - elements. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rdx - Supplies the address of the filter buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbx - Supplies the address of the filter buffer plus 2 * FilterStride. - - r9 - Supplies the StrideWidth parameter (see function description). - - r14 - Supplies the address of the input buffer plus 3 * StrideWidth. - - zmm0-zmm23 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -.ifeqs "\KernelType\()","Depthwise" - vmovups zmm24,ZMMWORD PTR [rdx+\VectorOffset\()] - EmitIfCountGE \OutputCount\(), 1, "vfmadd231ps zmm0,zmm24,ZMMWORD PTR [rcx+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 2, "vfmadd231ps zmm4,zmm24,ZMMWORD PTR [rcx+r9+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 3, "vfmadd231ps zmm8,zmm24,ZMMWORD PTR [rcx+r9*2+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 4, "vfmadd231ps zmm12,zmm24,ZMMWORD PTR [r14+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 5, "vfmadd231ps zmm16,zmm24,ZMMWORD PTR [r14+r9+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 6, "vfmadd231ps zmm20,zmm24,ZMMWORD PTR [r14+r9*2+\BroadcastOffset\()]" -.else -.if \FilterCount\() == 1 - vmovups zmm24,ZMMWORD PTR [rdx+\VectorOffset\()] - EmitIfCountGE \OutputCount\(), 1, "vfmadd231ps zmm0,zmm24,DWORD PTR [rcx+\BroadcastOffset\()]{1to16}" - EmitIfCountGE \OutputCount\(), 2, "vfmadd231ps zmm4,zmm24,DWORD PTR [rcx+r9+\BroadcastOffset\()]{1to16}" - EmitIfCountGE \OutputCount\(), 3, "vfmadd231ps zmm8,zmm24,DWORD PTR [rcx+r9*2+\BroadcastOffset\()]{1to16}" - EmitIfCountGE \OutputCount\(), 4, "vfmadd231ps zmm12,zmm24,DWORD PTR [r14+\BroadcastOffset\()]{1to16}" - EmitIfCountGE \OutputCount\(), 5, "vfmadd231ps zmm16,zmm24,DWORD PTR [r14+r9+\BroadcastOffset\()]{1to16}" - EmitIfCountGE \OutputCount\(), 6, "vfmadd231ps zmm20,zmm24,DWORD PTR [r14+r9*2+\BroadcastOffset\()]{1to16}" -.else - EmitIfCountGE \OutputCount\(), 1, "vbroadcastss zmm26,DWORD PTR [rcx+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 2, "vbroadcastss zmm27,DWORD PTR [rcx+r9+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 3, "vbroadcastss zmm28,DWORD PTR [rcx+r9*2+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 4, "vbroadcastss zmm29,DWORD PTR [r14+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 5, "vbroadcastss zmm30,DWORD PTR [r14+r9+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 6, "vbroadcastss zmm31,DWORD PTR [r14+r9*2+\BroadcastOffset\()]" -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "vfmadd231ps zmm0,zmm26,ZMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 2, "vfmadd231ps zmm1,zmm26,ZMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 3, "vfmadd231ps zmm2,zmm26,ZMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 4, "vfmadd231ps zmm3,zmm26,ZMMWORD PTR [rbx+rsi+\VectorOffset\()]" -.else - EmitIfCountGE \FilterCount\(), 1, "vmovups zmm24,ZMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd231ps zmm0,zmm26,zmm24" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vfmadd231ps zmm4,zmm27,zmm24" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vfmadd231ps zmm8,zmm28,zmm24" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 4, "vfmadd231ps zmm12,zmm29,zmm24" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 5, "vfmadd231ps zmm16,zmm30,zmm24" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 6, "vfmadd231ps zmm20,zmm31,zmm24" - EmitIfCountGE \FilterCount\(), 2, "vmovups zmm24,ZMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd231ps zmm1,zmm26,zmm24" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vfmadd231ps zmm5,zmm27,zmm24" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vfmadd231ps zmm9,zmm28,zmm24" - EmitIfCount2GE \FilterCount\(), 2, 
\OutputCount\(), 4, "vfmadd231ps zmm13,zmm29,zmm24" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vfmadd231ps zmm17,zmm30,zmm24" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vfmadd231ps zmm21,zmm31,zmm24" - EmitIfCountGE \FilterCount\(), 3, "vmovups zmm24,ZMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd231ps zmm2,zmm26,zmm24" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vfmadd231ps zmm6,zmm27,zmm24" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vfmadd231ps zmm10,zmm28,zmm24" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 4, "vfmadd231ps zmm14,zmm29,zmm24" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 5, "vfmadd231ps zmm18,zmm30,zmm24" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 6, "vfmadd231ps zmm22,zmm31,zmm24" - EmitIfCountGE \FilterCount\(), 4, "vmovups zmm24,ZMMWORD PTR [rbx+rsi+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd231ps zmm3,zmm26,zmm24" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vfmadd231ps zmm7,zmm27,zmm24" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vfmadd231ps zmm11,zmm28,zmm24" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 4, "vfmadd231ps zmm15,zmm29,zmm24" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 5, "vfmadd231ps zmm19,zmm30,zmm24" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 6, "vfmadd231ps zmm23,zmm31,zmm24" -.endif -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - rbp - Supplies the DilationWidth parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r15 - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount - -// -// Process the output blocks that include left padding. -// - - mov r10,\KernelFrame\()_OutputCountLeftPad[rsp] - test r10,r10 - jz .L\KernelType\().\FilterCount\().ProcessOutputCount - call MlasConv\KernelType\()FloatSingleAvx512FFilter\FilterCount\() - -// -// Process the output blocks that do not include any padding. 
-// - -.L\KernelType\().\FilterCount\().ProcessOutputCount: - mov r10,\KernelFrame\()_OutputCount[rsp] - sub r10,6 - jb .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount - -.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy6: - ProcessOutputCountN Avx512F, \KernelFrame\(), \KernelType\(), 16, \FilterCount\(), 6 - lea rax,[r9*2+r9] - lea rdi,[rdi+rax*2] # advance input by 6 elements - sub r10,6 - jae .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy6 - -.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: - add r10,6 # correct for over-subtract above - jz .L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining - cmp r10,3 - jb .L\KernelType\().\FilterCount\().ProcessRemainingOutputCountLessThan3 - ProcessOutputCountN Avx512F, \KernelFrame\(), \KernelType\(), 16, \FilterCount\(), 3 - lea rax,[r9*2+r9] - add rdi,rax # advance input by 3 elements - sub r10,3 - jz .L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining - -.L\KernelType\().\FilterCount\().ProcessRemainingOutputCountLessThan3: - cmp r10,1 - je .L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining - ProcessOutputCountN Avx512F, \KernelFrame\(), \KernelType\(), 16, \FilterCount\(), 2 - lea rdi,[rdi+r9*2] # advance input by 2 elements - sub r10,2 - -// -// Process the output blocks that include right padding plus any remaining output -// blocks from above. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: - add r10,\KernelFrame\()_OutputCountRightPad[rsp] - jz .L\KernelType\().ExitKernel - call MlasConv\KernelType\()FloatSingleAvx512FFilter\FilterCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows for a pointwise convolution. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbp - Supplies the InputStride parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r10 - Supplies the OutputCount parameter (see function description). - - r12 - Supplies the address of the filter buffer. 
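The AVX-512 ProcessFilterCountN loop above (and the pointwise variant that follows) walks the unpadded output span six blocks at a time, then falls back to 3-, 2- and 1-block tails. A simplified C++ sketch of that dispatch, with the vector work abstracted behind a hypothetical process callback:

#include <cstddef>
#include <functional>

// Hypothetical sketch of the AVX-512 output loop structure: consume six
// output blocks per iteration, then handle the 3/2/1-block remainder,
// mirroring the ProcessNextOutputCountBy6 / RemainingOutputCountLessThan3
// labels above. In the real kernel the final single block (and any padded
// blocks) go through the one-output-at-a-time helper instead.
void ProcessOutputSpan(size_t outputCount,
                       const std::function<void(size_t blocks)>& process) {
    while (outputCount >= 6) {
        process(6);
        outputCount -= 6;
    }
    if (outputCount >= 3) {
        process(3);
        outputCount -= 3;
    }
    if (outputCount >= 2) {
        process(2);
        outputCount -= 2;
    }
    if (outputCount == 1) {
        process(1);
    }
}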
- ---*/ - - .macro ProcessPointwiseFilterCountN FilterCount - - sub r10,6 - jb .LPointwise.\FilterCount\().ProcessRemainingOutputCount - -.LPointwise.\FilterCount\().ProcessNextOutputCountBy6: - ProcessPointwiseOutputCountN Avx512F, 16, \FilterCount\(), 6 - lea rax,[r9*2+r9] - lea rdi,[rdi+rax*2] # advance input by 6 elements - sub r10,6 - jae .LPointwise.\FilterCount\().ProcessNextOutputCountBy6 - -.LPointwise.\FilterCount\().ProcessRemainingOutputCount: - add r10,6 # correct for over-subtract above - jz .LPointwise.ExitKernel - cmp r10,3 - jb .LPointwise.\FilterCount\().ProcessRemainingOutputCountLessThan3 - ProcessPointwiseOutputCountN Avx512F, 16, \FilterCount\(), 3 - lea rax,[r9*2+r9] - add rdi,rax # advance input by 3 elements - sub r10,3 - jz .LPointwise.ExitKernel - -.LPointwise.\FilterCount\().ProcessRemainingOutputCountLessThan3: - cmp r10,2 - jb .LPointwise.\FilterCount\().ProcessRemainingOutputCount1 - ProcessPointwiseOutputCountN Avx512F, 16, \FilterCount\(), 2 - jmp .LPointwise.ExitKernel - -.LPointwise.\FilterCount\().ProcessRemainingOutputCount1: - ProcessPointwiseOutputCountN Avx512F, 16, \FilterCount\(), 1 - - .endm - -// -// Generate the convolution kernels. -// -// N.B. BiasFilter is not used here as the AVX-512 EVEX instruction encoding -// efficiently compresses aligned relative byte offsets. -// - - SconvKernelFunction Nchw, 16, Avx512F - SconvKernelFunction Nchwc, 16, Avx512F - SconvKernelDepthwiseFunction 16, Avx512F - SconvKernelPointwiseFunction Avx512F - -/*++ - -Macro Description: - - This macro generates code to process an output block after the inner - convolution kernel has executed and then stores the output block to the - output buffer. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - ---*/ - - .macro PostProcessBlock FilterCount, OutputCount - - .globl MlasConvPostProcessFloatAvx512FFilter\FilterCount\()Output\OutputCount\() -#if !defined(__APPLE__) - .hidden MlasConvPostProcessFloatAvx512FFilter\FilterCount\()Output\OutputCount\() -#endif -MlasConvPostProcessFloatAvx512FFilter\FilterCount\()Output\OutputCount\(): - -.if \FilterCount\() > 2 - lea rbx,[r8+rax*2] # compute output plus 2 rows -.endif - -// -// Test if the existing contents of the output buffer should be accumulated -// with the output block. 
-// - - test dl,MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vaddps zmm0,zmm0,ZMMWORD PTR [r8]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vaddps zmm4,zmm4,ZMMWORD PTR [r8+16*4]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vaddps zmm8,zmm8,ZMMWORD PTR [r8+32*4]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 4, "vaddps zmm12,zmm12,ZMMWORD PTR [r8+48*4]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 5, "vaddps zmm16,zmm16,ZMMWORD PTR [r8+64*4]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 6, "vaddps zmm20,zmm20,ZMMWORD PTR [r8+80*4]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vaddps zmm1,zmm1,ZMMWORD PTR [r8+rax]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vaddps zmm5,zmm5,ZMMWORD PTR [r8+rax+16*4]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vaddps zmm9,zmm9,ZMMWORD PTR [r8+rax+32*4]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vaddps zmm13,zmm13,ZMMWORD PTR [r8+rax+48*4]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vaddps zmm17,zmm17,ZMMWORD PTR [r8+rax+64*4]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vaddps zmm21,zmm21,ZMMWORD PTR [r8+rax+80*4]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vaddps zmm2,zmm2,ZMMWORD PTR [rbx]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vaddps zmm6,zmm6,ZMMWORD PTR [rbx+16*4]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vaddps zmm10,zmm10,ZMMWORD PTR [rbx+32*4]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 4, "vaddps zmm14,zmm14,ZMMWORD PTR [rbx+48*4]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 5, "vaddps zmm18,zmm18,ZMMWORD PTR [rbx+64*4]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 6, "vaddps zmm22,zmm22,ZMMWORD PTR [rbx+80*4]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vaddps zmm3,zmm3,ZMMWORD PTR [rbx+rax]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vaddps zmm7,zmm7,ZMMWORD PTR [rbx+rax+16*4]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vaddps zmm11,zmm11,ZMMWORD PTR [rbx+rax+32*4]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 4, "vaddps zmm15,zmm15,ZMMWORD PTR [rbx+rax+48*4]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 5, "vaddps zmm19,zmm19,ZMMWORD PTR [rbx+rax+64*4]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 6, "vaddps zmm23,zmm23,ZMMWORD PTR [rbx+rax+80*4]" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: - -// -// Test if the bias buffer should be accumulated with the output block. 
-// - - test dl,MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "vaddps zmm0,zmm0,ZMMWORD PTR [rcx]" - EmitIfCountGE \FilterCount\(), 2, "vaddps zmm1,zmm1,ZMMWORD PTR [rcx+16*4]" - EmitIfCountGE \FilterCount\(), 3, "vaddps zmm2,zmm2,ZMMWORD PTR [rcx+32*4]" - EmitIfCountGE \FilterCount\(), 4, "vaddps zmm3,zmm3,ZMMWORD PTR [rcx+48*4]" -.else - EmitIfCountGE \FilterCount\(), 1, "vmovups zmm28,ZMMWORD PTR [rcx]" - EmitIfCountGE \FilterCount\(), 2, "vmovups zmm29,ZMMWORD PTR [rcx+16*4]" - EmitIfCountGE \FilterCount\(), 3, "vmovups zmm30,ZMMWORD PTR [rcx+32*4]" - EmitIfCountGE \FilterCount\(), 4, "vmovups zmm31,ZMMWORD PTR [rcx+48*4]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vaddps zmm0,zmm0,zmm28" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vaddps zmm4,zmm4,zmm28" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vaddps zmm8,zmm8,zmm28" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 4, "vaddps zmm12,zmm12,zmm28" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 5, "vaddps zmm16,zmm16,zmm28" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 6, "vaddps zmm20,zmm20,zmm28" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vaddps zmm1,zmm1,zmm29" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vaddps zmm5,zmm5,zmm29" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vaddps zmm9,zmm9,zmm29" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vaddps zmm13,zmm13,zmm29" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vaddps zmm17,zmm17,zmm29" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vaddps zmm21,zmm21,zmm29" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vaddps zmm2,zmm2,zmm30" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vaddps zmm6,zmm6,zmm30" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vaddps zmm10,zmm10,zmm30" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 4, "vaddps zmm14,zmm14,zmm30" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 5, "vaddps zmm18,zmm18,zmm30" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 6, "vaddps zmm22,zmm22,zmm30" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vaddps zmm3,zmm3,zmm31" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vaddps zmm7,zmm7,zmm31" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vaddps zmm11,zmm11,zmm31" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 4, "vaddps zmm15,zmm15,zmm31" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 5, "vaddps zmm19,zmm19,zmm31" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 6, "vaddps zmm23,zmm23,zmm31" -.endif - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: - -// -// Test for fused ReLU activation. 
-// - - test dl,MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation - vpxord zmm24,zmm24,zmm24 - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vmaxps zmm0,zmm24,zmm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vmaxps zmm4,zmm24,zmm4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vmaxps zmm8,zmm24,zmm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 4, "vmaxps zmm12,zmm24,zmm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 5, "vmaxps zmm16,zmm24,zmm16" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 6, "vmaxps zmm20,zmm24,zmm20" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vmaxps zmm1,zmm24,zmm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vmaxps zmm5,zmm24,zmm5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vmaxps zmm9,zmm24,zmm9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vmaxps zmm13,zmm24,zmm13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vmaxps zmm17,zmm24,zmm17" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vmaxps zmm21,zmm24,zmm21" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vmaxps zmm2,zmm24,zmm2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vmaxps zmm6,zmm24,zmm6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vmaxps zmm10,zmm24,zmm10" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vmaxps zmm14,zmm24,zmm14" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vmaxps zmm18,zmm24,zmm18" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vmaxps zmm22,zmm24,zmm22" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vmaxps zmm3,zmm24,zmm3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vmaxps zmm7,zmm24,zmm7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vmaxps zmm11,zmm24,zmm11" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vmaxps zmm15,zmm24,zmm15" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vmaxps zmm19,zmm24,zmm19" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vmaxps zmm23,zmm24,zmm23" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: - -// -// Store the output block in the output buffer. 
-// - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vmovups ZMMWORD PTR [r8],zmm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vmovups ZMMWORD PTR [r8+16*4],zmm4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vmovups ZMMWORD PTR [r8+32*4],zmm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 4, "vmovups ZMMWORD PTR [r8+48*4],zmm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 5, "vmovups ZMMWORD PTR [r8+64*4],zmm16" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 6, "vmovups ZMMWORD PTR [r8+80*4],zmm20" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vmovups ZMMWORD PTR [r8+rax],zmm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vmovups ZMMWORD PTR [r8+rax+16*4],zmm5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vmovups ZMMWORD PTR [r8+rax+32*4],zmm9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 4, "vmovups ZMMWORD PTR [r8+rax+48*4],zmm13" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 5, "vmovups ZMMWORD PTR [r8+rax+64*4],zmm17" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 6, "vmovups ZMMWORD PTR [r8+rax+80*4],zmm21" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vmovups ZMMWORD PTR [rbx],zmm2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vmovups ZMMWORD PTR [rbx+16*4],zmm6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vmovups ZMMWORD PTR [rbx+32*4],zmm10" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 4, "vmovups ZMMWORD PTR [rbx+48*4],zmm14" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 5, "vmovups ZMMWORD PTR [rbx+64*4],zmm18" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 6, "vmovups ZMMWORD PTR [rbx+80*4],zmm22" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vmovups ZMMWORD PTR [rbx+rax],zmm3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vmovups ZMMWORD PTR [rbx+rax+16*4],zmm7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vmovups ZMMWORD PTR [rbx+rax+32*4],zmm11" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 4, "vmovups ZMMWORD PTR [rbx+rax+48*4],zmm15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 5, "vmovups ZMMWORD PTR [rbx+rax+64*4],zmm19" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 6, "vmovups ZMMWORD PTR [rbx+rax+80*4],zmm23" - add_immed r8,\OutputCount\()*16*4 # advance output by N nchw16c blocks - ret - - .endm - - .irp FilterCount, 1, 2, 3, 4 - .irp OutputCount, 1, 2, 3, 6 - PostProcessBlock \FilterCount\(), \OutputCount\() - .endr - .endr - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvxCommon.h b/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvxCommon.h deleted file mode 100644 index 7562d688cd2d7..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelAvxCommon.h +++ /dev/null @@ -1,53 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelAvxCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision convolution operation for the AVX and FMA3 kernels. - ---*/ - -#include "SconvKernelCommon.h" - -/*++ - -Macro Description: - - This macro generates code to clear the block accumulators. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - ymm0-ymm11 - Supplies the block accumulators. 
- ---*/ - - .macro ClearBlock FilterCount, OutputCount - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vxorps xmm0,xmm0,xmm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vxorps xmm4,xmm4,xmm4" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vxorps xmm8,xmm8,xmm8" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vxorps xmm1,xmm1,xmm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vxorps xmm5,xmm5,xmm5" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vxorps xmm9,xmm9,xmm9" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vxorps xmm2,xmm2,xmm2" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vxorps xmm6,xmm6,xmm6" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vxorps xmm10,xmm10,xmm10" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vxorps xmm3,xmm3,xmm3" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vxorps xmm7,xmm7,xmm7" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vxorps xmm11,xmm11,xmm11" - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/SconvKernelCommon.h deleted file mode 100644 index a2a858d15de20..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelCommon.h +++ /dev/null @@ -1,765 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision convolution operation. - ---*/ - -// -// Define the convolution kernel flags. -// - -#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 -#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 -#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 -#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 - -// -// Stack frame layout for the convolution kernels. 
-// - - .equ .LSconvKernelFrame_Filter, 0 - .equ .LSconvKernelFrame_SavedR12, 8 - .equ .LSconvKernelFrame_SavedR13, 16 - .equ .LSconvKernelFrame_SavedR14, 24 - .equ .LSconvKernelFrame_SavedR15, 32 - .equ .LSconvKernelFrame_SavedRbx, 40 - .equ .LSconvKernelFrame_SavedRbp, 48 - .equ .LSconvKernelFrame_ReturnAddress, 56 - .equ .LSconvKernelFrame_InputStride, 64 - .equ .LSconvKernelFrame_FilterStride, 72 - .equ .LSconvKernelFrame_OutputStride, 80 - .equ .LSconvKernelFrame_KernelHeight, 88 - .equ .LSconvKernelFrame_KernelWidth, 96 - .equ .LSconvKernelFrame_InputBase, 104 - .equ .LSconvKernelFrame_InputWidth, 112 - .equ .LSconvKernelFrame_DilatedInputWidth, 120 - .equ .LSconvKernelFrame_OutputCountLeftPad, 128 - .equ .LSconvKernelFrame_OutputCount, 136 - .equ .LSconvKernelFrame_OutputCountRightPad, 144 - .equ .LSconvKernelFrame_Bias, 152 - .equ .LSconvKernelFrame_Flags, 160 - - .equ .LSconvKernelSingleFrame_ReturnAddress, 0 - .equ .LSconvKernelSingleFrame_Filter, 8 - .equ .LSconvKernelSingleFrame_SavedR12, 16 - .equ .LSconvKernelSingleFrame_SavedR13, 24 - .equ .LSconvKernelSingleFrame_SavedR14, 32 - .equ .LSconvKernelSingleFrame_SavedR15, 40 - .equ .LSconvKernelSingleFrame_SavedRbx, 48 - .equ .LSconvKernelSingleFrame_SavedRbp, 56 - .equ .LSconvKernelSingleFrame_ParentReturnAddress, 64 - .equ .LSconvKernelSingleFrame_InputStride, 72 - .equ .LSconvKernelSingleFrame_FilterStride, 80 - .equ .LSconvKernelSingleFrame_OutputStride, 88 - .equ .LSconvKernelSingleFrame_KernelHeight, 96 - .equ .LSconvKernelSingleFrame_KernelWidth, 104 - .equ .LSconvKernelSingleFrame_InputBase, 112 - .equ .LSconvKernelSingleFrame_InputWidth, 120 - .equ .LSconvKernelSingleFrame_DilatedInputWidth, 128 - .equ .LSconvKernelSingleFrame_OutputCountLeftPad, 136 - .equ .LSconvKernelSingleFrame_OutputCount, 144 - .equ .LSconvKernelSingleFrame_OutputCountRightPad, 152 - .equ .LSconvKernelSingleFrame_Bias, 160 - .equ .LSconvKernelSingleFrame_Flags, 168 - - .equ .LSconvKernelDepthwiseFrame_SavedR12, 0 - .equ .LSconvKernelDepthwiseFrame_SavedR13, 8 - .equ .LSconvKernelDepthwiseFrame_SavedR14, 16 - .equ .LSconvKernelDepthwiseFrame_SavedR15, 24 - .equ .LSconvKernelDepthwiseFrame_SavedRbx, 32 - .equ .LSconvKernelDepthwiseFrame_SavedRbp, 40 - .equ .LSconvKernelDepthwiseFrame_ReturnAddress, 48 - .equ .LSconvKernelDepthwiseFrame_KernelHeight, 56 - .equ .LSconvKernelDepthwiseFrame_KernelWidth, 64 - .equ .LSconvKernelDepthwiseFrame_InputBase, 72 - .equ .LSconvKernelDepthwiseFrame_InputWidth, 80 - .equ .LSconvKernelDepthwiseFrame_DilatedInputWidth, 88 - .equ .LSconvKernelDepthwiseFrame_OutputCountLeftPad, 96 - .equ .LSconvKernelDepthwiseFrame_OutputCount, 104 - .equ .LSconvKernelDepthwiseFrame_OutputCountRightPad, 112 - .equ .LSconvKernelDepthwiseFrame_Bias, 120 - .equ .LSconvKernelDepthwiseFrame_Flags, 128 - - .equ .LSconvKernelDepthwiseSingleFrame_ReturnAddress, 0 - .equ .LSconvKernelDepthwiseSingleFrame_SavedR12, 8 - .equ .LSconvKernelDepthwiseSingleFrame_SavedR13, 16 - .equ .LSconvKernelDepthwiseSingleFrame_SavedR14, 24 - .equ .LSconvKernelDepthwiseSingleFrame_SavedR15, 32 - .equ .LSconvKernelDepthwiseSingleFrame_SavedRbx, 40 - .equ .LSconvKernelDepthwiseSingleFrame_SavedRbp, 48 - .equ .LSconvKernelDepthwiseSingleFrame_ParentReturnAddress, 56 - .equ .LSconvKernelDepthwiseSingleFrame_KernelHeight, 64 - .equ .LSconvKernelDepthwiseSingleFrame_KernelWidth, 72 - .equ .LSconvKernelDepthwiseSingleFrame_InputBase, 80 - .equ .LSconvKernelDepthwiseSingleFrame_InputWidth, 88 - .equ .LSconvKernelDepthwiseSingleFrame_DilatedInputWidth, 96 - .equ 
.LSconvKernelDepthwiseSingleFrame_OutputCountLeftPad, 104 - .equ .LSconvKernelDepthwiseSingleFrame_OutputCount, 112 - .equ .LSconvKernelDepthwiseSingleFrame_OutputCountRightPad, 120 - .equ .LSconvKernelDepthwiseSingleFrame_Bias, 128 - .equ .LSconvKernelDepthwiseSingleFrame_Flags, 136 - - .equ .LSconvKernelPointwiseFrame_InputChannels, 0 - .equ .LSconvKernelPointwiseFrame_SavedR12, 8 - .equ .LSconvKernelPointwiseFrame_SavedR14, 16 - .equ .LSconvKernelPointwiseFrame_SavedRbx, 24 - .equ .LSconvKernelPointwiseFrame_SavedRbp, 32 - .equ .LSconvKernelPointwiseFrame_ReturnAddress, 40 - .equ .LSconvKernelPointwiseFrame_InputStride, 48 - .equ .LSconvKernelPointwiseFrame_FilterStride, 56 - .equ .LSconvKernelPointwiseFrame_OutputStride, 64 - .equ .LSconvKernelPointwiseFrame_OutputCount, 72 - .equ .LSconvKernelPointwiseFrame_Bias, 80 - .equ .LSconvKernelPointwiseFrame_Flags, 88 - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a vector of input - blocks and a vector of filter blocks to produce a matrix of output blocks. - - OutputCount=1 generates special case code to handle padding blocks. All - other output counts assume no padding. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - BlockSize - Supplies the number of elements per block. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - rbp - Supplies the DilationWidth parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r15 - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessOutputCountN Isa, KernelFrame, KernelType, BlockSize, FilterCount, OutputCount - - mov rcx,rdi -.ifeqs "\KernelType\()","Depthwise" - mov rdx,rsi -.else - mov rdx,\KernelFrame\()_Filter[rsp] -.endif - mov r11,\KernelFrame\()_KernelHeight[rsp] - mov r12,\KernelFrame\()_KernelWidth[rsp] -.if \OutputCount\() == 1 - mov r13,\KernelFrame\()_InputBase[rsp] - mov r14,\KernelFrame\()_InputWidth[rsp] - neg r13 # keep negative for lea usage below -.endif - ClearBlock \FilterCount\(), \OutputCount\() - test r11,r11 # zero sized kernel? - jz .L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing - -.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow: - mov rax,r12 # reload kernel width remaining - -.L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn: -.if \OutputCount\() == 1 - lea rbx,[rcx+r13] # compute (Input - InputBase) - cmp rbx,r14 # (Input - InputBase) >= InputWidth? 
- jae .L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding -.endif -.if \OutputCount\() > 3 - lea r14,[r9+r9*2] - add r14,rcx # compute input plus 3 blocks -.endif -.if \FilterCount\() > 2 - lea rbx,[rdx+rsi*2] # compute filter plus 2 rows -.endif -.ifeqs "\KernelType\()","Nchwc" -.if \BlockSize\() == 16 - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 - .endr -.else - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 - .endr -.endif -.else - ComputeBlock \KernelType\(), \FilterCount\(), \OutputCount\(), 0, 0 -.endif - -.L\KernelType\().\FilterCount\().\OutputCount\().SkipOverPadding: - add rcx,rbp # advance input by dilation width -.ifeqs "\KernelType\()","Nchwc" - add rdx,\BlockSize\()*\BlockSize\()*4 - # advance filter by 8i8o/16i16o block -.else - add rdx,\BlockSize\()*4 # advance filter by 8o/16o block -.endif - dec rax # decrement columns remaining - jnz .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextColumn - add rcx,r15 # advance input to next row -.if \OutputCount\() == 1 - sub r13,\KernelFrame\()_DilatedInputWidth[rsp] - # advance input base to next row -.endif - dec r11 # decrement rows remaining - jnz .L\KernelType\().\FilterCount\().\OutputCount\().ProcessNextRow - -// -// Handle post processing of the output block. -// - -.L\KernelType\().\FilterCount\().\OutputCount\().HandlePostProcessing: - mov edx,DWORD PTR \KernelFrame\()_Flags[rsp] -.if \FilterCount\() > 1 - mov rax,\KernelFrame\()_OutputStride[rsp] -.endif - mov rcx,\KernelFrame\()_Bias[rsp] - call MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - BlockSize - Supplies the number of elements per block. - - Isa - Supplies the instruction set architecture string for function tags. - - BiasFilter - Supplies a non-blank value if the address of the filter buffer - should be biased to point to the middle of a OIhw8i8o block in order to - reduce the code size from relative byte offsets. - ---*/ - - .macro SconvKernelFunction KernelType, BlockSize, Isa, BiasFilter - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - StrideWidth (rcx) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (r8) - Supplies the length in bytes of the blocked dilation - width. - - FilterCount (r9) - Supplies the number of filters to process in this - iteration. - - InputStride - Supplies the length in bytes to advance the input buffer to - the next input row. - - FilterStride - Supplies the length in bytes to advance the filter buffer - to the next set of filters. 
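The padding test inside ProcessOutputCountN above keeps InputBase negated so that a single unsigned comparison rejects both left and right padding: an address left of the valid region wraps to a huge value and fails the same "less than InputWidth" check. A small C++ sketch of that idea (hypothetical helper; byte-based, like the kernel):

#include <cstdint>

// Hypothetical sketch of the single-compare padding test: with unsigned
// arithmetic, an input address left of InputBase wraps to a huge value, so
// one comparison against InputWidth rejects both left and right padding
// (the kernel's jae skips the block when this returns false).
bool InputIsValid(uintptr_t input, uintptr_t inputBase, uintptr_t inputWidthBytes) {
    return (input - inputBase) < inputWidthBytes;
}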
- - OutputStride - Supplies the length in bytes to advance the output buffer - to the next output address associated with the next set of filters. - - KernelHeight - Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth - Supplies the width of the kernel to apply. - - InputBase - Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth - Supplies the length in bytes of the blocked input width. - - DilatedInputWidth - Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad - Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount - Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad - Supplies the number of output elements that include - one or more padding elements from the right edge. - - Bias - Supplies the address of the bias buffer. - - Flags - Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConv\KernelType\()FloatKernel\Isa\() - - push rbp - push rbx - push r15 - push r14 - push r13 - push r12 -.ifeqs "\BiasFilter\()","BiasFilter" - add_immed rsi,4*8*4 -.endif - push rsi - mov rsi,.LSconvKernelFrame_FilterStride[rsp] - mov r15,.LSconvKernelFrame_InputStride[rsp] - mov rbp,r8 # shuffle to Win64 register usage - mov r11,r9 - mov r8,rdx - mov r9,rcx - -// -// Process the specified number of filter rows. -// - - cmp r11,3 - je .L\KernelType\().ProcessFilterCount3 - jb .L\KernelType\().ProcessFilterCountLessThan3 - ProcessFilterCountN .LSconvKernelFrame, \KernelType\(), 4 - jmp .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCount3: - ProcessFilterCountN .LSconvKernelFrame, \KernelType\(), 3 - jmp .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCountLessThan3: - cmp r11,2 - jb .L\KernelType\().ProcessFilterCount1 - ProcessFilterCountN .LSconvKernelFrame, \KernelType\(), 2 - jmp .L\KernelType\().ExitKernel - -.L\KernelType\().ProcessFilterCount1: - ProcessFilterCountN .LSconvKernelFrame, \KernelType\(), 1 - -// -// Restore non-volatile registers and return. -// - -.L\KernelType\().ExitKernel: -.ifnes "\Isa\()","Sse" - vzeroupper -.endif - pop rsi # clear Filter local - pop r12 - pop r13 - pop r14 - pop r15 - pop rbx - pop rbp - ret - -.ifnes "\Isa\()","Sse" - -// -// Generate out-of-band helpers for handling output blocks involving padding. -// - - .irp FilterCount, 1, 2, 3, 4 - -MlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\(): - ProcessOutputCountN \Isa\(), .LSconvKernelSingleFrame, \KernelType\(), \BlockSize\(), \FilterCount\(), 1 - add rdi,r9 # advance input by 1 element - dec r10 # decrement output count remaining - jnz MlasConv\KernelType\()FloatSingle\Isa\()Filter\FilterCount\() - ret - - .endr - -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel for the special - case of a depthwise separable convolution. 
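Since the macro described here targets depthwise separable convolutions, a small scalar reference may help: each output channel is produced from the matching input channel only. The sketch below is purely illustrative, with hypothetical names, plain NCHW layout, stride 1, and no padding or dilation; the deleted kernel works on blocked NCHWc data and fuses bias/ReLU in its post-processing step.

#include <cstddef>

// Illustrative depthwise convolution reference (one filter per channel).
void DepthwiseConvReference(const float* input, const float* filter, float* output,
                            size_t channels, size_t inH, size_t inW,
                            size_t kH, size_t kW) {
    const size_t outH = inH - kH + 1;
    const size_t outW = inW - kW + 1;
    for (size_t c = 0; c < channels; ++c) {
        for (size_t oh = 0; oh < outH; ++oh) {
            for (size_t ow = 0; ow < outW; ++ow) {
                float acc = 0.0f;
                for (size_t kh = 0; kh < kH; ++kh) {
                    for (size_t kw = 0; kw < kW; ++kw) {
                        acc += input[(c * inH + oh + kh) * inW + (ow + kw)] *
                               filter[(c * kH + kh) * kW + kw];
                    }
                }
                output[(c * outH + oh) * outW + ow] = acc;
            }
        }
    }
}
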
- -Arguments: - - BlockSize - Supplies the number of elements per block. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SconvKernelDepthwiseFunction BlockSize, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - - Depthwise separable convolutions are a form of grouped convolution where - the number of input and output channels per group are one. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - StrideWidth (rcx) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (r8) - Supplies the length in bytes of the blocked dilation - width. - - InputStride (r9) - Supplies the length in bytes to advance the input buffer - to the next input row. - - KernelHeight - Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth - Supplies the width of the kernel to apply. - - InputBase - Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth - Supplies the length in bytes of the blocked input width. - - DilatedInputWidth - Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad - Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount - Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad - Supplies the number of output elements that include - one or more padding elements from the right edge. - - Bias - Supplies the address of the bias buffer. - - Flags - Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvDepthwiseFloatKernel\Isa\() - - push rbp - push rbx - push r15 - push r14 - push r13 - push r12 - mov rbp,r8 # shuffle to Win64 register usage - mov r15,r9 - mov r8,rdx - mov r9,rcx - -// -// Process the specified number of filter rows. -// - - ProcessFilterCountN .LSconvKernelDepthwiseFrame, Depthwise, 1 - -// -// Restore non-volatile registers and return. -// - -.LDepthwise.ExitKernel: -.ifnes "\Isa\()","Sse" - vzeroupper -.endif - pop r12 - pop r13 - pop r14 - pop r15 - pop rbx - pop rbp - ret - -.ifnes "\Isa\()","Sse" - -// -// Generate out-of-band helpers for handling output blocks involving padding. 
-// - -MlasConvDepthwiseFloatSingle\Isa\()Filter1: - ProcessOutputCountN \Isa\(), .LSconvKernelDepthwiseSingleFrame, Depthwise, \BlockSize\(), 1, 1 - add rdi,r9 # advance input by 1 element - dec r10 # decrement output count remaining - jnz MlasConvDepthwiseFloatSingle\Isa\()Filter1 - ret - -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a vector of input - blocks and a vector of filter blocks to produce a matrix of output blocks - for a pointwise convolution. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - BlockSize - Supplies the number of elements per block. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbp - Supplies the InputStride parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r12 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseOutputCountN Isa, BlockSize, FilterCount, OutputCount - - mov rcx,rdi - mov rdx,r12 - mov r11,.LSconvKernelPointwiseFrame_InputChannels[rsp] - ClearBlock \FilterCount\(), \OutputCount\() - -.LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock: -.if \OutputCount\() > 3 - lea r14,[r9+r9*2] - add r14,rcx # compute input plus 3 blocks -.endif -.if \FilterCount\() > 2 - lea rbx,[rdx+rsi*2] # compute filter plus 2 rows -.endif -.if \BlockSize\() == 16 - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), \Index\()*16*4, \Index\()*4 - .endr -.else - .irp Index, 0, 1, 2, 3, 4, 5, 6, 7 - ComputeBlock Pointwise, \FilterCount\(), \OutputCount\(), (\Index\()-4)*8*4, \Index\()*4 - .endr -.endif - add rcx,rbp # advance input to next channel block - add rdx,\BlockSize\()*\BlockSize\()*4 - # advance filter by 8i8o/16i16o block - dec r11 # decrement input blocks remaining - jnz .LPointwise.\FilterCount\().\OutputCount\().ProcessNextInputBlock - -// -// Handle post processing of the output block. -// - - mov edx,DWORD PTR .LSconvKernelPointwiseFrame_Flags[rsp] -.if \FilterCount\() > 1 - mov rax,.LSconvKernelPointwiseFrame_OutputStride[rsp] -.endif - mov rcx,.LSconvKernelPointwiseFrame_Bias[rsp] - call MlasConvPostProcessFloat\Isa\()Filter\FilterCount\()Output\OutputCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner convolution kernel for the special - case where the kernel dimensions are 1. - -Arguments: - - Isa - Supplies the instruction set architecture string for function tags. - - BiasFilter - Supplies a non-blank value if the address of the filter buffer - should be biased to point to the middle of a OIhw8i8o block in order to - reduce the code size from relative byte offsets. - ---*/ - - .macro SconvKernelPointwiseFunction Isa, BiasFilter - -/*++ - -Routine Description: - - This routine is the inner kernel to compute a convolution for the elements - of an output row for a set of filter rows. - - Pointwise convolutions have a kernel size of one. To simplify this - implementation, no input padding is allowed, which matches typical usage in - models. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. 
- - Filter (rsi) - Supplies the address of the filter buffer. - - Output (rdx) - Supplies the address of the output buffer. - - StrideWidth (rcx) - Supplies the length in bytes of the blocked stride width. - - InputChannels (r8) - Supplies the number of input channels to process. - - FilterCount (r9) - Supplies the number of rows from the filter to process. - - InputStride - Supplies the length in bytes to advance the input buffer to - the next input channel of the same input row. - - FilterStride - Supplies the length in bytes to advance the filter buffer - to the next set of filters. - - OutputStride - Supplies the length in bytes to advance the output buffer - to the next output address associated with the next set of filters. - - OutputCount - Supplies the number of output elements. - - Bias - Supplies the address of the bias buffer. - - Flags - Supplies additional flags controlling the convolution operation, - especially post calculation options. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasConvPointwiseFloatKernel\Isa\() - - push rbp - push rbx - push r14 - push r12 - push r8 -.ifeqs "\BiasFilter\()","BiasFilter" - lea r12,[rsi+4*8*4] -.else - mov r12,rsi -.endif - mov r10,.LSconvKernelPointwiseFrame_OutputCount[rsp] - mov rsi,.LSconvKernelPointwiseFrame_FilterStride[rsp] - mov rbp,.LSconvKernelPointwiseFrame_InputStride[rsp] - mov r11,r9 # shuffle to Win64 register usage - mov r8,rdx - mov r9,rcx - -// -// Process the specified number of filter rows. -// - - cmp r11,3 - je .LPointwise.ProcessFilterCount3 - jb .LPointwise.ProcessFilterCountLessThan3 - ProcessPointwiseFilterCountN 4 - jmp .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCount3: - ProcessPointwiseFilterCountN 3 - jmp .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCountLessThan3: - cmp r11,2 - jb .LPointwise.ProcessFilterCount1 - ProcessPointwiseFilterCountN 2 - jmp .LPointwise.ExitKernel - -.LPointwise.ProcessFilterCount1: - ProcessPointwiseFilterCountN 1 - -// -// Restore non-volatile registers and return. -// - -.LPointwise.ExitKernel: -.ifnes "\Isa\()","Sse" - vzeroupper -.endif - pop r8 # clear InputChannels local - pop r12 - pop r14 - pop rbx - pop rbp - ret - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelFma3.S deleted file mode 100644 index 0cf063626d2f5..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelFma3.S +++ /dev/null @@ -1,247 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelFma3.s - -Abstract: - - This module implements the kernels for the single precision convolution - operation. - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" -#include "SconvKernelAvxCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for FilterCount by OutputCount block - of the output buffer. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - - VectorOffset - Supplies the byte offset from the filter buffer to fetch - elements. - - BroadcastOffset - Supplies the byte offset from the input buffer to fetch - elements. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rdx - Supplies the address of the filter buffer. 
- - rsi - Supplies the FilterStride parameter (see function description). - - rbx - Supplies the address of the filter buffer plus 2 * FilterStride. - - r9 - Supplies the StrideWidth parameter (see function description). - - ymm0-ymm11 - Supplies the block accumulators. - ---*/ - - .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -.ifeqs "\KernelType\()","Depthwise" - vmovups ymm12,YMMWORD PTR [rdx] - EmitIfCountGE \OutputCount\(), 1, "vfmadd231ps ymm0,ymm12,YMMWORD PTR [rcx]" - EmitIfCountGE \OutputCount\(), 2, "vfmadd231ps ymm4,ymm12,YMMWORD PTR [rcx+r9]" - EmitIfCountGE \OutputCount\(), 3, "vfmadd231ps ymm8,ymm12,YMMWORD PTR [rcx+r9*2]" -.else - EmitIfCountGE \OutputCount\(), 1, "vbroadcastss ymm13,DWORD PTR [rcx+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 2, "vbroadcastss ymm14,DWORD PTR [rcx+r9+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 3, "vbroadcastss ymm15,DWORD PTR [rcx+r9*2+\BroadcastOffset\()]" -.if \OutputCount\() == 1 - EmitIfCountGE \FilterCount\(), 1, "vfmadd231ps ymm0,ymm13,YMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 2, "vfmadd231ps ymm1,ymm13,YMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 3, "vfmadd231ps ymm2,ymm13,YMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 4, "vfmadd231ps ymm3,ymm13,YMMWORD PTR [rbx+rsi+\VectorOffset\()]" -.else - EmitIfCountGE \FilterCount\(), 1, "vmovups ymm12,YMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "vfmadd231ps ymm0,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 2, "vfmadd231ps ymm4,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 3, "vfmadd231ps ymm8,ymm15,ymm12" - EmitIfCountGE \FilterCount\(), 2, "vmovups ymm12,YMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "vfmadd231ps ymm1,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 2, "vfmadd231ps ymm5,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 3, "vfmadd231ps ymm9,ymm15,ymm12" - EmitIfCountGE \FilterCount\(), 3, "vmovups ymm12,YMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "vfmadd231ps ymm2,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 2, "vfmadd231ps ymm6,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 3, "vfmadd231ps ymm10,ymm15,ymm12" - EmitIfCountGE \FilterCount\(), 4, "vmovups ymm12,YMMWORD PTR [rbx+rsi+\VectorOffset\()]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "vfmadd231ps ymm3,ymm13,ymm12" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 2, "vfmadd231ps ymm7,ymm14,ymm12" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 3, "vfmadd231ps ymm11,ymm15,ymm12" -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - rbp - Supplies the DilationWidth parameter (see function description). 
- - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r15 - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount - -// -// Process the output blocks that include left padding. -// - - mov r10,\KernelFrame\()_OutputCountLeftPad[rsp] - test r10,r10 - jz .L\KernelType\().\FilterCount\().ProcessOutputCount - call MlasConv\KernelType\()FloatSingleFma3Filter\FilterCount\() - -// -// Process the output blocks that do not include any padding. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCount: - mov r10,\KernelFrame\()_OutputCount[rsp] - sub r10,3 - jb .L\KernelType\().\FilterCount\().ProcessRemainingOutputCount - -.L\KernelType\().\FilterCount\().ProcessNextOutputCountBy3: - ProcessOutputCountN Fma3, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 3 - lea rax,[r9*2+r9] - add rdi,rax # advance input by 3 elements - sub r10,3 - jae .L\KernelType\().\FilterCount\().ProcessNextOutputCountBy3 - -.L\KernelType\().\FilterCount\().ProcessRemainingOutputCount: - add r10,3 # correct for over-subtract above - jz .L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining - cmp r10,2 - jb .L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining - ProcessOutputCountN Fma3, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 2 - lea rdi,[rdi+r9*2] # advance input by 2 elements - sub r10,2 - -// -// Process the output blocks that include right padding plus any remaining output -// blocks from above. -// - -.L\KernelType\().\FilterCount\().ProcessOutputCountRightPadAndRemaining: - add r10,\KernelFrame\()_OutputCountRightPad[rsp] - jz .L\KernelType\().ExitKernel - call MlasConv\KernelType\()FloatSingleFma3Filter\FilterCount\() - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows for a pointwise convolution. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbp - Supplies the InputStride parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r10 - Supplies the OutputCount parameter (see function description). - - r12 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseFilterCountN FilterCount - - sub r10,3 - jb .LPointwise.\FilterCount\().ProcessRemainingOutputCount - -.LPointwise.\FilterCount\().ProcessNextOutputCountBy3: - ProcessPointwiseOutputCountN Fma3, 8, \FilterCount\(), 3 - lea rax,[r9*2+r9] - add rdi,rax # advance input by 3 elements - sub r10,3 - jae .LPointwise.\FilterCount\().ProcessNextOutputCountBy3 - -.LPointwise.\FilterCount\().ProcessRemainingOutputCount: - add r10,3 # correct for over-subtract above - jz .LPointwise.ExitKernel - cmp r10,2 - jb .LPointwise.\FilterCount\().ProcessRemainingOutputCount1 - ProcessPointwiseOutputCountN Fma3, 8, \FilterCount\(), 2 - jmp .LPointwise.ExitKernel - -.LPointwise.\FilterCount\().ProcessRemainingOutputCount1: - ProcessPointwiseOutputCountN Fma3, 8, \FilterCount\(), 1 - - .endm - -// -// Generate the convolution kernels. 
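For context, the ComputeBlock and ProcessFilterCountN macros above keep a FilterCount-by-OutputCount grid of accumulators resident in ymm registers: one input element is broadcast per output column and fused-multiply-added against a filter vector per filter row. A scalar C++ sketch of that register blocking, with hypothetical names (each float stands in for one 8-wide vector):

#include <cstddef>

// Sketch of the FilterCount x OutputCount accumulator grid held in ymm0-ymm11.
template <size_t FilterCount, size_t OutputCount>
void ComputeBlockSketch(float (&acc)[FilterCount][OutputCount],
                        const float* input, size_t strideWidth,
                        const float* filter, size_t filterStride) {
    for (size_t oc = 0; oc < OutputCount; ++oc) {
        const float broadcast = input[oc * strideWidth];            // vbroadcastss
        for (size_t fc = 0; fc < FilterCount; ++fc) {
            acc[fc][oc] += broadcast * filter[fc * filterStride];   // vfmadd231ps
        }
    }
}
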
-// - - SconvKernelFunction Nchw, 8, Fma3 - SconvKernelFunction Nchwc, 8, Fma3, BiasFilter - SconvKernelDepthwiseFunction 8, Fma3 - SconvKernelPointwiseFunction Fma3, BiasFilter - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S b/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S deleted file mode 100644 index 4dbbf696e96f7..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SconvKernelSse2.S +++ /dev/null @@ -1,353 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SconvKernelSse2.s - -Abstract: - - This module implements the kernels for the single precision convolution - operation. - - This implementation uses SSE2 instructions. - ---*/ - -#include "asmmacro.h" -#include "SconvKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to clear the block accumulators. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - xmm0-xmm7 - Supplies the block accumulators. - ---*/ - - .macro ClearBlock FilterCount, OutputCount - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xorps xmm0,xmm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "xorps xmm1,xmm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xorps xmm2,xmm2" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "xorps xmm3,xmm3" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xorps xmm4,xmm4" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "xorps xmm5,xmm5" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xorps xmm6,xmm6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "xorps xmm7,xmm7" - - .endm - -/*++ - -Macro Description: - - This macro multiplies and accumulates for FilterCount by OutputCount block - of the output buffer. - -Arguments: - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - - VectorOffset - Supplies the byte offset from the filter buffer to fetch - elements. - - BroadcastOffset - Supplies the byte offset from the input buffer to fetch - elements. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rdx - Supplies the address of the filter buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbx - Supplies the address of the filter buffer plus 2 * FilterStride. - - r9 - Supplies the StrideWidth parameter (see function description). - - xmm0-xmm7 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlock KernelType, FilterCount, OutputCount, VectorOffset, BroadcastOffset - -.ifeqs "\KernelType\()","Depthwise" - movups xmm8,XMMWORD PTR [rdx] - movups xmm9,XMMWORD PTR [rdx+16] - movups xmm10,XMMWORD PTR [rcx] - movups xmm11,XMMWORD PTR [rcx+16] - mulps xmm8,xmm10 - addps xmm0,xmm8 - mulps xmm9,xmm11 - addps xmm1,xmm9 -.else - EmitIfCountGE \OutputCount\(), 1, "movss xmm12,DWORD PTR [rcx+\BroadcastOffset\()]" - EmitIfCountGE \OutputCount\(), 1, "shufps xmm12,xmm12,0" - EmitIfCountGE \FilterCount\(), 1, "movups xmm8,XMMWORD PTR [rdx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 1, "movups xmm9,XMMWORD PTR [rdx+\VectorOffset\()+16]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "mulps xmm8,xmm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "addps xmm0,xmm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "mulps xmm9,xmm12" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "addps xmm1,xmm9" - EmitIfCountGE \FilterCount\(), 2, "movups xmm8,XMMWORD PTR [rdx+rsi+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 2, "movups xmm9,XMMWORD PTR [rdx+rsi+\VectorOffset\()+16]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "mulps xmm8,xmm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addps xmm2,xmm8" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "mulps xmm9,xmm12" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addps xmm3,xmm9" - EmitIfCountGE \FilterCount\(), 3, "movups xmm8,XMMWORD PTR [rbx+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 3, "movups xmm9,XMMWORD PTR [rbx+\VectorOffset\()+16]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "mulps xmm8,xmm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "addps xmm4,xmm8" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "mulps xmm9,xmm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "addps xmm5,xmm9" - EmitIfCountGE \FilterCount\(), 4, "movups xmm8,XMMWORD PTR [rbx+rsi+\VectorOffset\()]" - EmitIfCountGE \FilterCount\(), 4, "movups xmm9,XMMWORD PTR [rbx+rsi+\VectorOffset\()+16]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "mulps xmm8,xmm12" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addps xmm6,xmm8" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "mulps xmm9,xmm12" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addps xmm7,xmm9" -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - KernelType - Supplies the type of kernel to be generated. - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description) when - KernelType!=Depthwise. Supplies the address of the filter buffer when - KernelType=Depthwise. - - rbp - Supplies the DilationWidth parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r15 - Supplies the InputStride parameter (see function description). 
- ---*/ - - .macro ProcessFilterCountN KernelFrame, KernelType, FilterCount - - mov r10,\KernelFrame\()_OutputCountLeftPad[rsp] - add r10,\KernelFrame\()_OutputCount[rsp] - add r10,\KernelFrame\()_OutputCountRightPad[rsp] - -.L\KernelType\().\FilterCount\().ProcessNextOutputCount: - ProcessOutputCountN Sse, \KernelFrame\(), \KernelType\(), 8, \FilterCount\(), 1 - add rdi,r9 # advance input by 1 element - dec r10 - jnz .L\KernelType\().\FilterCount\().ProcessNextOutputCount - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute the convolution for a specified number - of filter rows for a pointwise convolution. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rsi - Supplies the FilterStride parameter (see function description). - - rbp - Supplies the InputStride parameter (see function description). - - r8 - Supplies the address of the output buffer. - - r9 - Supplies the StrideWidth parameter (see function description). - - r10 - Supplies the OutputCount parameter (see function description). - - r12 - Supplies the address of the filter buffer. - ---*/ - - .macro ProcessPointwiseFilterCountN FilterCount - -.LPointwise.\FilterCount\().ProcessNextOutputCount: - ProcessPointwiseOutputCountN Sse, 8, \FilterCount\(), 1 - add rdi,r9 # advance input by 1 element - dec r10 - jnz .LPointwise.\FilterCount\().ProcessNextOutputCount - - .endm - -// -// Generate the convolution kernels. -// - - SconvKernelFunction Nchw, 8, Sse - SconvKernelFunction Nchwc, 8, Sse, BiasFilter - SconvKernelDepthwiseFunction 8, Sse - SconvKernelPointwiseFunction Sse, BiasFilter - -/*++ - -Macro Description: - - This macro generates code to process an output block after the inner - convolution kernel has executed and then stores the output block to the - output buffer. - -Arguments: - - FilterCount - Supplies the number of rows from the filter to process. - - OutputCount - Supplies the number of output blocks to produce. - ---*/ - - .macro PostProcessBlock FilterCount, OutputCount - - .globl MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() -#if !defined(__APPLE__) - .hidden MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\() -#endif -MlasConvPostProcessFloatSseFilter\FilterCount\()Output\OutputCount\(): - -.if \FilterCount\() > 2 - lea rbx,[r8+rax*2] # compute output plus 2 rows -.endif - -// -// Test if the existing contents of the output buffer should be accumulated -// with the output block. 
-// - - test dl,MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "movups xmm8,XMMWORD PTR [r8]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "movups xmm9,XMMWORD PTR [r8+16]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "movups xmm10,XMMWORD PTR [r8+rax]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "movups xmm11,XMMWORD PTR [r8+rax+16]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "movups xmm12,XMMWORD PTR [rbx]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "movups xmm13,XMMWORD PTR [rbx+16]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "movups xmm14,XMMWORD PTR [rbx+rax]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "movups xmm15,XMMWORD PTR [rbx+rax+16]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "addps xmm0,xmm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "addps xmm1,xmm9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addps xmm2,xmm10" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addps xmm3,xmm11" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "addps xmm4,xmm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "addps xmm5,xmm13" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addps xmm6,xmm14" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addps xmm7,xmm15" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipAccumulateOutput: - -// -// Test if the bias buffer should be accumulated with the output block. -// - - test dl,MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "movups xmm8,XMMWORD PTR [rcx]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "movups xmm9,XMMWORD PTR [rcx+16]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "movups xmm10,XMMWORD PTR [rcx+32]" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "movups xmm11,XMMWORD PTR [rcx+48]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "movups xmm12,XMMWORD PTR [rcx+64]" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "movups xmm13,XMMWORD PTR [rcx+80]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "movups xmm14,XMMWORD PTR [rcx+96]" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "movups xmm15,XMMWORD PTR [rcx+112]" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "addps xmm0,xmm8" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "addps xmm1,xmm9" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addps xmm2,xmm10" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "addps xmm3,xmm11" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "addps xmm4,xmm12" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "addps xmm5,xmm13" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addps xmm6,xmm14" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "addps xmm7,xmm15" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipBiasAddition: - -// -// Test for fused ReLU activation. 
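The three flag tests in this post-processing block apply, in order, output accumulation, bias addition, and a fused ReLU before the final store. Below is a per-element C++ sketch of the same sequence; the flag names appear in the deleted source, but the helper itself is hypothetical and the real code operates on whole SSE/AVX vectors.

#include <algorithm>

// Per-element view of the convolution post-processing flags.
float PostProcessElement(float accumulator, float previousOutput, float bias,
                         bool accumulateOutput,   // MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT
                         bool addBias,            // MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION
                         bool applyRelu) {        // MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION
    float value = accumulator;
    if (accumulateOutput) {
        value += previousOutput;
    }
    if (addBias) {
        value += bias;
    }
    if (applyRelu) {
        value = std::max(value, 0.0f);
    }
    return value;
}
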
-// - - test dl,MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION - jz .LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation - xorps xmm15,xmm15 - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "maxps xmm0,xmm15" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "maxps xmm1,xmm15" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "maxps xmm2,xmm15" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "maxps xmm3,xmm15" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "maxps xmm4,xmm15" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "maxps xmm5,xmm15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "maxps xmm6,xmm15" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "maxps xmm7,xmm15" - -.LPostProcessBlock.\FilterCount\().\OutputCount\().SkipReluActivation: - -// -// Store the output block in the output buffer. -// - - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "movups XMMWORD PTR [r8],xmm0" - EmitIfCount2GE \FilterCount\(), 1, \OutputCount\(), 1, "movups XMMWORD PTR [r8+16],xmm1" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "movups XMMWORD PTR [r8+rax],xmm2" - EmitIfCount2GE \FilterCount\(), 2, \OutputCount\(), 1, "movups XMMWORD PTR [r8+rax+16],xmm3" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "movups XMMWORD PTR [rbx],xmm4" - EmitIfCount2GE \FilterCount\(), 3, \OutputCount\(), 1, "movups XMMWORD PTR [rbx+16],xmm5" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "movups XMMWORD PTR [rbx+rax],xmm6" - EmitIfCount2GE \FilterCount\(), 4, \OutputCount\(), 1, "movups XMMWORD PTR [rbx+rax+16],xmm7" - add_immed r8,\OutputCount\()*8*4 # advance output by N nchw8c blocks - ret - - .endm - - .irp FilterCount, 1, 2, 3, 4 - .irp OutputCount, 1 - PostProcessBlock \FilterCount\(), \OutputCount\() - .endr - .endr - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S deleted file mode 100644 index a0a66f330ae71..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelAvx.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" -#include "SgemmKernelCommon.h" -#include "FgemmKernelAvxCommon.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMM kernel. -// - -FgemmKernelAvxFunction MlasGemmFloatKernelAvx - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx512F.S deleted file mode 100644 index c75df76030b6d..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelAvx512F.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelAvx512F.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses AVX512F instructions. - ---*/ - -#include "asmmacro.h" -#include "SgemmKernelCommon.h" -#include "FgemmKernelAvx512FCommon.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMM kernel. 
-// - -FgemmKernelAvx512FFunction MlasGemmFloatKernelAvx512F - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelCommon.h deleted file mode 100644 index 5802028788891..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelCommon.h +++ /dev/null @@ -1,50 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision matrix/matrix multiply operation (SGEMM). - ---*/ - -// -// Define the single precision parameters. -// - - .equ .LFgemmElementShift, 2 - .equ .LFgemmElementSize, 1 << .LFgemmElementShift - -#include "FgemmKernelCommon.h" - -// -// Define the typed instructions for single precision. -// - -FGEMM_TYPED_INSTRUCTION(addpf, addps) -FGEMM_TYPED_INSTRUCTION(movsf, movss) -FGEMM_TYPED_INSTRUCTION(movupf, movups) - -FGEMM_TYPED_INSTRUCTION(vaddpf, vaddps) -FGEMM_TYPED_INSTRUCTION(vbroadcastsf, vbroadcastss) -FGEMM_TYPED_INSTRUCTION(vfmadd213pf, vfmadd213ps) -FGEMM_TYPED_INSTRUCTION(vfmadd231pf, vfmadd231ps) -FGEMM_TYPED_INSTRUCTION(vmaskmovpf, vmaskmovps) -FGEMM_TYPED_INSTRUCTION(vmovapf, vmovaps) -FGEMM_TYPED_INSTRUCTION(vmovsf, vmovss) -FGEMM_TYPED_INSTRUCTION(vmovupf, vmovups) -FGEMM_TYPED_INSTRUCTION(vmulpf, vmulps) -FGEMM_TYPED_INSTRUCTION(vxorpf, vxorps) - - .macro vfmadd231pf_bcst DestReg, SrcReg, Address - - vfmadd231ps \DestReg\(), \SrcReg\(), \Address\(){1to16} - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S deleted file mode 100644 index 4725459323936..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelFma3.S +++ /dev/null @@ -1,34 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelFma3.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" -#include "SgemmKernelCommon.h" -#include "FgemmKernelFma3Common.h" - - .intel_syntax noprefix - - .text - -// -// Generate the GEMM kernel. -// - -FgemmKernelFma3Function MlasGemmFloatKernelFma3 - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S deleted file mode 100644 index 5c759847e2de2..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1Avx.S +++ /dev/null @@ -1,267 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelM1Avx.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). This handles the special case of M=1. - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. This handles the special case of M=1. - - The elements in matrix B are not transposed. - -Arguments: - - A (rdi) - Supplies the address of matrix A. - - B (rsi) - Supplies the address of matrix B. - - C (rdx) - Supplies the address of matrix C. 
- - CountK (rcx) - Supplies the number of columns from matrix A and the number - of rows from matrix B to iterate over. - - CountN (r8) - Supplies the number of columns from matrix B and matrix C to - iterate over. - - ldb (r9) - Supplies the first dimension of matrix B. - - Beta (xmm0) - Supplies the scalar beta multiplier (see SGEMM definition). - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasSgemmKernelM1Avx - - push rbx - shl r9,2 # convert ldb to bytes - mov r10,rdx - mov r11,rsi - -// -// Compute the initial results mask for zeroing or accumulate mode. -// - - vxorps xmm1,xmm1,xmm1 - vcmpeqss xmm0,xmm1,xmm0 - vshufps xmm0,xmm0,xmm0,0 - vinsertf128 ymm0,ymm0,xmm0,1 - -// -// Compute the conditional load/store mask for an unaligned CountN. -// - - mov eax,r8d - and eax,7 - vmovd xmm7,eax - vshufps xmm7,xmm7,xmm7,0 - vpcmpgtd xmm6,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] - vpcmpgtd xmm7,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - vinsertf128 ymm7,ymm7,xmm6,1 - -// -// Process 4 rows of the matrices in a loop. -// - - sub rcx,4 - jb .LProcessRemainingCountK - -.LProcessRowLoop4: - vbroadcastss ymm2,DWORD PTR [rdi] - mov rax,r8 # reload CountN - vbroadcastss ymm3,DWORD PTR [rdi+4] - mov rsi,r11 # reload matrix B - vbroadcastss ymm4,DWORD PTR [rdi+8] - mov rdx,r10 # reload matrix C - vbroadcastss ymm5,DWORD PTR [rdi+12] - add rdi,4*4 # advance matrix A by 4 columns - lea r11,[rsi+r9*4] # advance matrix B by 4 rows - sub rax,16 - jb .LProcessRemainingCountN4 - -.LProcessColumnLoop4: - lea rbx,[rsi+r9*2] # compute matrix B plus 2 rows - vmulps ymm1,ymm2,YMMWORD PTR [rsi] - vmulps ymm6,ymm2,YMMWORD PTR [rsi+32] - vmulps ymm8,ymm3,YMMWORD PTR [rsi+r9] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm3,YMMWORD PTR [rsi+r9+32] - vaddps ymm6,ymm6,ymm8 - vmulps ymm8,ymm4,YMMWORD PTR [rbx] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm4,YMMWORD PTR [rbx+32] - vaddps ymm6,ymm6,ymm8 - vmulps ymm8,ymm5,YMMWORD PTR [rbx+r9] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm5,YMMWORD PTR [rbx+r9+32] - vaddps ymm6,ymm6,ymm8 - vandnps ymm8,ymm0,YMMWORD PTR [rdx] - vaddps ymm1,ymm1,ymm8 - vandnps ymm8,ymm0,YMMWORD PTR [rdx+32] - vaddps ymm6,ymm6,ymm8 - vmovups YMMWORD PTR [rdx],ymm1 - vmovups YMMWORD PTR [rdx+32],ymm6 - add rsi,16*4 # advance matrix B by 16 columns - add rdx,16*4 # advance matrix C by 16 columns - sub rax,16 - jae .LProcessColumnLoop4 - -.LProcessRemainingCountN4: - test al,15 # test for unaligned columns - jz .LProcessedRemainingCountN4 - test al,8 # CountN >= 8? 
- jz .LProcessRemainingCountNSmall4 - lea rbx,[rsi+r9*2] # compute matrix B plus 2 rows - vmulps ymm1,ymm2,YMMWORD PTR [rsi] - vmulps ymm8,ymm3,YMMWORD PTR [rsi+r9] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm4,YMMWORD PTR [rbx] - vaddps ymm1,ymm1,ymm8 - vmulps ymm8,ymm5,YMMWORD PTR [rbx+r9] - vaddps ymm1,ymm1,ymm8 - vandnps ymm8,ymm0,YMMWORD PTR [rdx] - vaddps ymm1,ymm1,ymm8 - vmovups YMMWORD PTR [rdx],ymm1 - add rsi,8*4 # advance matrix B by 8 columns - add rdx,8*4 # advance matrix C by 8 columns - test al,7 - jz .LProcessedRemainingCountN4 - -.LProcessRemainingCountNSmall4: - lea rbx,[rsi+r9*2] # compute matrix B plus 2 rows - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm1,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi+r9] - vmulps ymm8,ymm3,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rbx] - vmulps ymm8,ymm4,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rbx+r9] - vmulps ymm8,ymm5,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vandnps ymm6,ymm0,ymm6 - vaddps ymm1,ymm1,ymm6 - vmaskmovps YMMWORD PTR [rdx],ymm7,ymm1 - -.LProcessedRemainingCountN4: - vxorps xmm0,xmm0,xmm0 # switch to accumulate mode - sub rcx,4 - jae .LProcessRowLoop4 - -.LProcessRemainingCountK: - test cl,2 - jnz .LProcessRowLoop2 - test cl,1 - jnz .LProcessRowLoop1 - -.LExitKernel: - vzeroupper - pop rbx - ret - -// -// Process 2 rows of the matrices. -// - -.LProcessRowLoop2: - vbroadcastss ymm2,DWORD PTR [rdi] - mov rax,r8 # reload CountN - vbroadcastss ymm3,DWORD PTR [rdi+4] - mov rsi,r11 # reload matrix B - mov rdx,r10 # reload matrix C - add rdi,2*4 # advance matrix A by 2 columns - lea r11,[rsi+r9*2] # advance matrix B by 2 rows - sub rax,8 - jb .LProcessRemainingCountN2 - -.LProcessColumnLoop2: - vmulps ymm1,ymm2,YMMWORD PTR [rsi] - vmulps ymm8,ymm3,YMMWORD PTR [rsi+r9] - vaddps ymm1,ymm1,ymm8 - vandnps ymm6,ymm0,YMMWORD PTR [rdx] - vaddps ymm1,ymm1,ymm6 - vmovups YMMWORD PTR [rdx],ymm1 - add rsi,8*4 # advance matrix B by 8 columns - add rdx,8*4 # advance matrix C by 8 columns - sub rax,8 - jae .LProcessColumnLoop2 - -.LProcessRemainingCountN2: - test al,7 # test for unaligned columns - jz .LProcessedRemainingCountN2 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm1,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi+r9] - vmulps ymm8,ymm3,ymm6 - vaddps ymm1,ymm1,ymm8 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vandnps ymm6,ymm0,ymm6 - vaddps ymm1,ymm1,ymm6 - vmaskmovps YMMWORD PTR [rdx],ymm7,ymm1 - -.LProcessedRemainingCountN2: - test cl,1 - jz .LExitKernel - vxorps xmm0,xmm0,xmm0 # switch to accumulate mode - -// -// Process 1 row of the matrices. 
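With M=1 the kernel collapses to a row-vector-times-matrix product, and the mask derived from Beta at entry selects between overwriting C and accumulating into it (Beta acts as a zero/non-zero switch here). A scalar C++ sketch with hypothetical names; the deleted kernel additionally processes four rows of B per pass and uses masked loads for the CountN remainder:

#include <cstddef>

// Sketch only: C[1xN] (+)= A[1xK] * B[KxN], B not transposed, ldb = row stride of B.
void SgemmKernelM1Sketch(const float* a, const float* b, float* c,
                         size_t countK, size_t countN, size_t ldb, float beta) {
    const bool accumulate = (beta != 0.0f);   // mirrors the vcmpeqss/vandnps mask
    for (size_t n = 0; n < countN; ++n) {
        float sum = 0.0f;
        for (size_t k = 0; k < countK; ++k) {
            sum += a[k] * b[k * ldb + n];
        }
        c[n] = accumulate ? (c[n] + sum) : sum;
    }
}
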
-// - -.LProcessRowLoop1: - vbroadcastss ymm2,DWORD PTR [rdi] - mov rax,r8 # reload CountN - mov rsi,r11 # reload matrix B - mov rdx,r10 # reload matrix C - sub rax,8 - jb .LProcessRemainingCountN1 - -.LProcessColumnLoop1: - vmulps ymm1,ymm2,YMMWORD PTR [rsi] - vandnps ymm6,ymm0,YMMWORD PTR [rdx] - vaddps ymm1,ymm1,ymm6 - vmovups YMMWORD PTR [rdx],ymm1 - add rsi,8*4 # advance matrix B by 8 columns - add rdx,8*4 # advance matrix C by 8 columns - sub rax,8 - jae .LProcessColumnLoop1 - -.LProcessRemainingCountN1: - test al,7 # test for unaligned columns - jz .LExitKernel - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm1,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rdx] - vandnps ymm6,ymm0,ymm6 - vaddps ymm1,ymm1,ymm6 - vmaskmovps YMMWORD PTR [rdx],ymm7,ymm1 - jmp .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S deleted file mode 100644 index b205c3d6d7def..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelM1TransposeBAvx.S +++ /dev/null @@ -1,275 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelM1TransposeBAvx.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). This handles the special case of M=1. - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine is an inner kernel to compute matrix multiplication for a - set of rows. This handles the special case of M=1. - - The elements in matrix B are transposed. - -Arguments: - - A (rdi) - Supplies the address of matrix A. - - B (rsi) - Supplies the address of matrix B. The elements are transposed. - - C (rdx) - Supplies the address of matrix C. - - CountK (rcx) - Supplies the number of columns from matrix A and the number - of columns from matrix B to iterate over. - - CountN (r8) - Supplies the number of rows from matrix B and the number of - columns from matrix C to iterate over. - - ldb (r9) - Supplies the first dimension of matrix B. - - Beta (xmm0) - Supplies the scalar beta multiplier (see SGEMM definition). - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasSgemmKernelM1TransposeBAvx - - push rbx - shl r9,2 # convert ldb to bytes - mov r10,rdi - mov r11,rsi - -// -// Compute the results mask for zeroing or accumulate mode. -// - - vxorps xmm1,xmm1,xmm1 - vcmpeqss xmm0,xmm1,xmm0 - vshufps xmm0,xmm0,xmm0,0 - -// -// Compute the conditional load/store mask for an unaligned CountK. -// - - mov eax,ecx - and eax,7 - vmovd xmm7,eax - vshufps xmm7,xmm7,xmm7,0 - vpcmpgtd xmm6,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip+16] - vpcmpgtd xmm7,xmm7,XMMWORD PTR C_UNDERSCORE(MlasMaskMoveAvx)[rip] - vinsertf128 ymm7,ymm7,xmm6,1 - -// -// Process 4 rows of the matrices in a loop. 
-// - - sub r8,4 - jb .LProcessRemainingCountN - -.LProcessRowLoop4: - vxorps xmm2,xmm2,xmm2 # clear row accumulators - vxorps xmm3,xmm3,xmm3 - vxorps xmm4,xmm4,xmm4 - vxorps xmm5,xmm5,xmm5 - mov rdi,r10 # reload matrix A - mov rsi,r11 # reload matrix B - mov rax,rcx # reload CountK - lea r11,[rsi+r9*4] # advance matrix B by 4 rows - sub rax,8 - jb .LProcessRemainingCountK4 - -.LProcessColumnLoop4: - lea rbx,[rsi+r9*2] # compute matrix B plus 2 rows - vmovups ymm1,YMMWORD PTR [rdi] - vmulps ymm6,ymm1,YMMWORD PTR [rsi] - vaddps ymm2,ymm2,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rsi+r9] - vaddps ymm3,ymm3,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rbx] - vaddps ymm4,ymm4,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rbx+r9] - vaddps ymm5,ymm5,ymm6 - add rdi,8*4 # advance matrix A by 8 columns - add rsi,8*4 # advance matrix B by 8 columns - sub rax,8 - jae .LProcessColumnLoop4 - -.LProcessRemainingCountK4: - test al,7 # test for unaligned columns - jz .LOutput4x1Block - lea rbx,[rsi+r9*2] # compute matrix B plus 2 rows - vmaskmovps ymm1,ymm7,YMMWORD PTR [rdi] - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm6,ymm1,ymm6 - vaddps ymm2,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi+r9] - vmulps ymm6,ymm1,ymm6 - vaddps ymm3,ymm3,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rbx] - vmulps ymm6,ymm1,ymm6 - vaddps ymm4,ymm4,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rbx+r9] - vmulps ymm6,ymm1,ymm6 - vaddps ymm5,ymm5,ymm6 - -// -// Reduce and output the row accumulators. -// - -.LOutput4x1Block: - vunpcklps ymm6,ymm2,ymm3 # transpose row accumulators - vunpckhps ymm1,ymm2,ymm3 - vunpcklps ymm2,ymm4,ymm5 - vunpckhps ymm3,ymm4,ymm5 - vunpcklpd ymm4,ymm6,ymm2 - vunpckhpd ymm5,ymm6,ymm2 - vaddps ymm4,ymm4,ymm5 - vunpcklpd ymm6,ymm1,ymm3 - vunpckhpd ymm2,ymm1,ymm3 - vaddps ymm4,ymm4,ymm6 - vaddps ymm4,ymm4,ymm2 - vextractf128 xmm5,ymm4,1 - vaddps xmm4,xmm4,xmm5 - vandnps xmm6,xmm0,XMMWORD PTR [rdx] - vaddps xmm4,xmm4,xmm6 - vmovups XMMWORD PTR [rdx],xmm4 - add rdx,4*4 # advance matrix C by 4 columns - sub r8,4 - jae .LProcessRowLoop4 - -.LProcessRemainingCountN: - test r8d,2 - jnz .LProcessRowLoop2 - test r8d,1 - jnz .LProcessRowLoop1 - -.LExitKernel: - vzeroupper - pop rbx - ret - -// -// Process 2 rows of the matrices. -// - -.LProcessRowLoop2: - vxorps xmm2,xmm2,xmm2 # clear row accumulators - vxorps xmm3,xmm3,xmm3 - mov rdi,r10 # reload matrix A - mov rsi,r11 # reload matrix B - mov rax,rcx # reload CountK - lea r11,[rsi+r9*2] # advance matrix B by 2 rows - sub rax,8 - jb .LProcessRemainingCountK2 - -.LProcessColumnLoop2: - vmovups ymm1,YMMWORD PTR [rdi] - vmulps ymm6,ymm1,YMMWORD PTR [rsi] - vaddps ymm2,ymm2,ymm6 - vmulps ymm6,ymm1,YMMWORD PTR [rsi+r9] - vaddps ymm3,ymm3,ymm6 - add rdi,8*4 # advance matrix A by 8 columns - add rsi,8*4 # advance matrix B by 8 columns - sub rax,8 - jae .LProcessColumnLoop2 - -.LProcessRemainingCountK2: - test al,7 # test for unaligned columns - jz .LOutput2x1Block - vmaskmovps ymm1,ymm7,YMMWORD PTR [rdi] - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm6,ymm1,ymm6 - vaddps ymm2,ymm2,ymm6 - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi+r9] - vmulps ymm6,ymm1,ymm6 - vaddps ymm3,ymm3,ymm6 - -// -// Reduce and output the row accumulators. 
-// - -.LOutput2x1Block: - vunpcklps ymm4,ymm2,ymm3 # reduce row accumulators - vunpckhps ymm2,ymm2,ymm3 - vaddps ymm2,ymm2,ymm4 - vextractf128 xmm4,ymm2,1 - vaddps xmm2,xmm2,xmm4 - vmovhlps xmm4,xmm2,xmm2 - vaddps xmm2,xmm2,xmm4 - vmovsd xmm3,QWORD PTR [rdx] - vandnps xmm3,xmm0,xmm3 - vaddps xmm2,xmm2,xmm3 - vmovsd QWORD PTR [rdx],xmm2 - add rdx,2*4 # advance matrix C by 2 columns - test r8d,1 - jz .LExitKernel - -// -// Process 1 row of the matrices. -// - -.LProcessRowLoop1: - vxorps xmm2,xmm2,xmm2 # clear row accumulators - mov rdi,r10 # reload matrix A - mov rsi,r11 # reload matrix B - mov rax,rcx # reload CountK - sub rax,8 - jb .LProcessRemainingCountK1 - -.LProcessColumnLoop1: - vmovups ymm1,YMMWORD PTR [rdi] - vmulps ymm6,ymm1,YMMWORD PTR [rsi] - vaddps ymm2,ymm2,ymm6 - add rdi,8*4 # advance matrix A by 8 columns - add rsi,8*4 # advance matrix B by 8 columns - sub rax,8 - jae .LProcessColumnLoop1 - -.LProcessRemainingCountK1: - test al,7 # test for unaligned columns - jz .LOutput1x1Block - vmaskmovps ymm1,ymm7,YMMWORD PTR [rdi] - vmaskmovps ymm6,ymm7,YMMWORD PTR [rsi] - vmulps ymm6,ymm1,ymm6 - vaddps ymm2,ymm2,ymm6 - -// -// Reduce and output the row accumulators. -// - -.LOutput1x1Block: - vhaddps ymm2,ymm2,ymm2 # reduce row accumulators - vhaddps ymm2,ymm2,ymm2 - vextractf128 xmm4,ymm2,1 - vaddss xmm2,xmm2,xmm4 - vmovss xmm3,DWORD PTR [rdx] - vandnps xmm3,xmm0,xmm3 - vaddss xmm2,xmm2,xmm3 - vmovss DWORD PTR [rdx],xmm2 - jmp .LExitKernel - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelSse2.S b/onnxruntime/core/mlas/lib/x86_64/SgemmKernelSse2.S deleted file mode 100644 index e605128537748..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmKernelSse2.S +++ /dev/null @@ -1,273 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmKernelSse2.s - -Abstract: - - This module implements the kernels for the single precision matrix/matrix - multiply operation (SGEMM). - - This implementation uses SSE2 instructions. - ---*/ - -#include "asmmacro.h" -#include "SgemmKernelCommon.h" -#include "FgemmKernelSse2Common.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro multiplies and accumulates for a 16xN block of the output matrix. - -Arguments: - - RowCount - Supplies the number of rows to process. - - VectorOffset - Supplies the byte offset from matrix B to fetch elements. - - Shuffle - Supplies the shuffle mask to extract the element from matrix A. - -Implicit Arguments: - - rsi - Supplies the address into the matrix B data. - - xmm0-xmm1 - Supplies up to four elements loaded from matrix A and matrix A - plus one row. - - xmm8-xmm15 - Supplies the block accumulators. 
- ---*/ - - .macro ComputeBlockSseBy16 RowCount, VectorOffset, Shuffle - - movaps xmm4,XMMWORD PTR [rsi+\VectorOffset\()] - movaps xmm5,XMMWORD PTR [rsi+\VectorOffset\()+16] - pshufd xmm2,xmm0,\Shuffle\() -.if \RowCount\() == 2 - pshufd xmm3,xmm1,\Shuffle\() - movaps xmm6,xmm4 - movaps xmm7,xmm5 -.endif - mulps xmm4,xmm2 - mulps xmm5,xmm2 - addps xmm8,xmm4 - addps xmm9,xmm5 -.if \RowCount\() == 2 - mulps xmm6,xmm3 - mulps xmm7,xmm3 - addps xmm12,xmm6 - addps xmm13,xmm7 -.endif - movaps xmm4,XMMWORD PTR [rsi+\VectorOffset\()+32] - movaps xmm5,XMMWORD PTR [rsi+\VectorOffset\()+48] -.if \RowCount\() == 2 - movaps xmm6,xmm4 - movaps xmm7,xmm5 -.endif - mulps xmm4,xmm2 - mulps xmm5,xmm2 - addps xmm10,xmm4 - addps xmm11,xmm5 -.if \RowCount\() == 2 - mulps xmm6,xmm3 - mulps xmm7,xmm3 - addps xmm14,xmm6 - addps xmm15,xmm7 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute matrix multiplication for a fixed set - of rows. - -Arguments: - - RowCount - Supplies the number of rows to process. - - Fallthrough - Supplies a non-blank value if the macro may fall through to - the ExitKernel label. - -Implicit Arguments: - - rdi - Supplies the address of matrix A. - - rsi - Supplies the address of matrix B. - - r11 - Supplies the address of matrix A. - - r9 - Supplies the number of columns from matrix B and matrix C to iterate - over. - - rdx - Supplies the address of matrix C. - - rcx - Supplies the number of columns from matrix A and the number of rows - from matrix B to iterate over. - - r10 - Supplies the length in bytes of a row from matrix A. - - rax - Supplies the length in bytes of a row from matrix C. - - r15 - Stores the ZeroMode argument from the stack frame. - ---*/ - - .macro ProcessCountM RowCount, Fallthrough - -.LProcessNextColumnLoop16xN\@: - EmitIfCountGE \RowCount\(), 1, "xorps xmm8,xmm8" - EmitIfCountGE \RowCount\(), 1, "xorps xmm9,xmm9" - EmitIfCountGE \RowCount\(), 1, "xorps xmm10,xmm10" - EmitIfCountGE \RowCount\(), 1, "xorps xmm11,xmm11" - EmitIfCountGE \RowCount\(), 2, "xorps xmm12,xmm12" - EmitIfCountGE \RowCount\(), 2, "xorps xmm13,xmm13" - EmitIfCountGE \RowCount\(), 2, "xorps xmm14,xmm14" - EmitIfCountGE \RowCount\(), 2, "xorps xmm15,xmm15" - mov rbp,rcx # reload CountK - sub rbp,4 - jb .LProcessRemaining16xNBlocks\@ - -.LCompute16xNBlockBy4Loop\@: - EmitIfCountGE \RowCount\(), 1, "movups xmm0,XMMWORD PTR [rdi]" - EmitIfCountGE \RowCount\(), 2, "movups xmm1,XMMWORD PTR [rdi+r10]" - ComputeBlockSseBy16 2, 0, 0x00 - ComputeBlockSseBy16 2, 16*4, 0x55 - sub rsi,-32*4 # advance matrix B by 32 columns - ComputeBlockSseBy16 2, 0, 0xAA - ComputeBlockSseBy16 2, 16*4, 0xFF - sub rsi,-32*4 # advance matrix B by 32 columns - add rdi,4*4 # advance matrix A by 4 columns - sub rbp,4 - jae .LCompute16xNBlockBy4Loop\@ - -.LProcessRemaining16xNBlocks\@: - add rbp,4 # correct for over-subtract above - jz .LOutput16xNBlock\@ - -.LCompute16xNBlockBy1Loop\@: - EmitIfCountGE \RowCount\(), 1, "movss xmm0,[rdi]" - EmitIfCountGE \RowCount\(), 2, "movss xmm1,[rdi+r10]" - ComputeBlockSseBy16 2, 0, 0x00 - add rsi,16*4 # advance matrix B by 16 columns - add rdi,4 # advance matrix A by 1 column - dec rbp - jne .LCompute16xNBlockBy1Loop\@ - -.LOutput16xNBlock\@: - movss xmm2,.LFgemmKernelFrame_alpha[rsp] - shufps xmm2,xmm2,0 - EmitIfCountGE \RowCount\(), 1, "mulps xmm8,xmm2" - # multiply by alpha - EmitIfCountGE \RowCount\(), 1, "mulps xmm9,xmm2" - EmitIfCountGE \RowCount\(), 1, "mulps xmm10,xmm2" - EmitIfCountGE \RowCount\(), 1, "mulps xmm11,xmm2" - EmitIfCountGE 
\RowCount\(), 2, "mulps xmm12,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulps xmm13,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulps xmm14,xmm2" - EmitIfCountGE \RowCount\(), 2, "mulps xmm15,xmm2" - sub r9,16 - jb .LOutputPartial16xNBlock\@ - AccumulateAndStoreBlock \RowCount\(), 4 - add rdx,16*4 # advance matrix C by 16 columns - mov rdi,r11 # reload matrix A - test r9,r9 - jnz .LProcessNextColumnLoop16xN\@ - jmp .LExitKernel - -// -// Output a partial 16xN block to the matrix. -// - -.LOutputPartial16xNBlock\@: - add r9,16 # correct for over-subtract above - cmp r9,4 - jb .LOutputPartialLessThan4xNBlock\@ - cmp r9,8 - jb .LOutputPartialLessThan8xNBlock\@ - cmp r9,12 - jb .LOutputPartialLessThan12xNBlock\@ - AccumulateAndStoreBlock \RowCount\(), 3 - and r9d,3 # check if remaining count is small - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movaps xmm8,xmm11" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "movaps xmm12,xmm15" - add rdx,12*4 # advance matrix C by 12 columns - jmp .LOutputPartialLessThan4xNBlock\@ - -.LOutputPartialLessThan12xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 2 - and r9d,3 # check if remaining count is small - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movaps xmm8,xmm10" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "movaps xmm12,xmm14" - add rdx,8*4 # advance matrix C by 8 columns - jmp .LOutputPartialLessThan4xNBlock\@ - -.LOutputPartialLessThan8xNBlock\@: - AccumulateAndStoreBlock \RowCount\(), 1 - and r9d,3 # check if remaining count is small - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movaps xmm8,xmm9" - # shift remaining elements down - EmitIfCountGE \RowCount\(), 2, "movaps xmm12,xmm13" - add rdx,4*4 # advance matrix C by 4 columns - -.LOutputPartialLessThan4xNBlock\@: - test r9d,2 - jz .LOutputPartial1xNBlock\@ - test r15b,r15b # ZeroMode? - jnz .LSkipAccumulateOutput2xN\@ - EmitIfCountGE \RowCount\(), 1, "movsd xmm0,QWORD PTR [rdx]" - EmitIfCountGE \RowCount\(), 2, "movsd xmm1,QWORD PTR [rdx+rax]" - EmitIfCountGE \RowCount\(), 1, "addps xmm8,xmm0" - EmitIfCountGE \RowCount\(), 2, "addps xmm12,xmm1" - -.LSkipAccumulateOutput2xN\@: - EmitIfCountGE \RowCount\(), 1, "movsd QWORD PTR [rdx],xmm8" - EmitIfCountGE \RowCount\(), 2, "movsd QWORD PTR [rdx+rax],xmm12" - test r9d,1 # check if remaining count is odd - jz .LExitKernel - EmitIfCountGE \RowCount\(), 1, "movhlps xmm8,xmm8" - # shift third element down - EmitIfCountGE \RowCount\(), 2, "movhlps xmm12,xmm12" - add rdx,2*4 # advance matrix C by 2 columns - -.LOutputPartial1xNBlock\@: - test r15b,r15b # ZeroMode? - jnz .LSkipAccumulateOutput1xN\@ - EmitIfCountGE \RowCount\(), 1, "addss xmm8,[rdx]" - EmitIfCountGE \RowCount\(), 2, "addss xmm12,[rdx+rax]" - -.LSkipAccumulateOutput1xN\@: - EmitIfCountGE \RowCount\(), 1, "movss [rdx],xmm8" - EmitIfCountGE \RowCount\(), 2, "movss [rdx+rax],xmm12" -.ifb \Fallthrough\() - jmp .LExitKernel -.endif - - .endm - -// -// Generate the GEMM kernel. -// - -FgemmKernelSse2Function MlasGemmFloatKernelSse - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmTransposePackB16x4Avx.S b/onnxruntime/core/mlas/lib/x86_64/SgemmTransposePackB16x4Avx.S deleted file mode 100644 index 644077838a56e..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmTransposePackB16x4Avx.S +++ /dev/null @@ -1,120 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
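Stepping back from the SSE2 kernel's output path above: every accumulator is scaled by alpha, and the existing contents of C are added in unless ZeroMode is set. A hedged scalar equivalent of that store step (the function and parameter names are illustrative only):

#include <cstddef>

// Output step of the SGEMM kernels: C = alpha * acc in ZeroMode,
// otherwise C += alpha * acc.
void StoreOutputRow(float* c, const float* acc, std::size_t n, float alpha, bool zero_mode) {
    for (std::size_t j = 0; j < n; ++j) {
        float value = alpha * acc[j];              // "multiply by alpha"
        c[j] = zero_mode ? value : c[j] + value;   // r15b carries ZeroMode in the kernel
    }
}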
- -Module Name: - - SgemmTransposePackB16x4Avx.s - -Abstract: - - This module implements routines for packing buffers for the single precision - matrix/matrix multiply operation (SGEMM). - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - 4 columns of 8 rows from the source matrix are transposed to 8 columns of 4 - rows in the destination packed buffer. - -Arguments: - - StoreOffset - Supplies the relative byte offset into the destination packed - buffer. - -Implicit Arguments: - - rdi - Supplies the address of the destination packed buffer. - - rsi - Supplies the address of the source matrix. - - rdx - Supplies the number of elements per row of the source matrix. - ---*/ - - .macro TransposePackB8x4BlockAvx StoreOffset - -// -// Load 4 columns from 8 rows of the source matrix into the lower and upper -// halves of 4 YMM registers. -// - - lea rax,[rsi+rdx*2] - vmovups xmm0,XMMWORD PTR [rsi] - vmovups xmm1,XMMWORD PTR [rsi+rdx] - lea rsi,[rax+rdx*2] - vmovups xmm2,XMMWORD PTR [rax] - vmovups xmm3,XMMWORD PTR [rax+rdx] - lea rax,[rsi+rdx*2] - vinsertf128 ymm0,ymm0,XMMWORD PTR [rsi],1 - vinsertf128 ymm1,ymm1,XMMWORD PTR [rsi+rdx],1 - vinsertf128 ymm2,ymm2,XMMWORD PTR [rax],1 - vinsertf128 ymm3,ymm3,XMMWORD PTR [rax+rdx],1 - -// -// Transpose the lower and upper halves of the 4 YMM registers as two 4x4 -// matrices and store the output to the destination packed buffer. -// - - vunpcklps ymm4,ymm0,ymm1 - vunpckhps ymm5,ymm0,ymm1 - vunpcklps ymm0,ymm2,ymm3 - vunpckhps ymm1,ymm2,ymm3 - vunpcklpd ymm2,ymm4,ymm0 - vunpckhpd ymm3,ymm4,ymm0 - vmovaps YMMWORD PTR [rdi+16*4*0+\StoreOffset\()],ymm2 - vmovaps YMMWORD PTR [rdi+16*4*1+\StoreOffset\()],ymm3 - vunpcklpd ymm0,ymm5,ymm1 - vunpckhpd ymm4,ymm5,ymm1 - vmovaps YMMWORD PTR [rdi+16*4*2+\StoreOffset\()],ymm0 - vmovaps YMMWORD PTR [rdi+16*4*3+\StoreOffset\()],ymm4 - - .endm - -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. - - 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 - rows in the destination packed buffer. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - B (rsi) - Supplies the address of the source matrix. - - ldb (rdx) - Supplies the number of elements per row of the source matrix. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasSgemmTransposePackB16x4Avx - - shl rdx,2 # convert ldb to bytes - TransposePackB8x4BlockAvx 0*4 - lea rsi,[rax+rdx*2] - TransposePackB8x4BlockAvx 8*4 - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SgemmTransposePackB16x4Sse2.S b/onnxruntime/core/mlas/lib/x86_64/SgemmTransposePackB16x4Sse2.S deleted file mode 100644 index d3ef6f32376dd..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SgemmTransposePackB16x4Sse2.S +++ /dev/null @@ -1,83 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SgemmTransposePackB16x4Sse2.s - -Abstract: - - This module implements routines for packing buffers for the single precision - matrix/matrix multiply operation (SGEMM). - - This implementation uses SSE2 instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine transposes elements from the source matrix to the destination - packed buffer. 
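The two SgemmTransposePackB16x4 kernels (the AVX one above and the SSE2 one that follows) implement the same layout transform. A scalar reference consistent with their routine descriptions, assuming ldb is given in elements as stated there; the actual kernels perform this with vector unpack/shuffle sequences:

#include <cstddef>

// Scalar reference for the transpose-pack kernels: a 16-row by 4-column tile of B
// (row-major, leading dimension ldb) is written as 4 rows of 16 columns in the
// packed buffer, matching the 16-wide column blocks the SGEMM kernels consume.
void TransposePackB16x4Reference(float* d, const float* b, std::size_t ldb) {
    for (std::size_t row = 0; row < 16; ++row) {
        for (std::size_t col = 0; col < 4; ++col) {
            d[col * 16 + row] = b[row * ldb + col];
        }
    }
}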
- - 4 columns of 16 rows from the source matrix are transposed to 16 columns of 4 - rows in the destination packed buffer. - -Arguments: - - D (rdi) - Supplies the address of the destination packed buffer. - - B (rsi) - Supplies the address of the source matrix. - - ldb (rdx) - Supplies the number of elements per row of the source matrix. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasSgemmTransposePackB16x4Sse - - shl rdx,2 # convert ldb to bytes - mov ecx,4 # transpose four 4x4 blocks - -.LTransposeBlockLoop: - lea rax,[rsi+rdx*2] - movups xmm0,XMMWORD PTR [rsi] - movups xmm1,XMMWORD PTR [rsi+rdx] - movups xmm2,XMMWORD PTR [rax] - movups xmm3,XMMWORD PTR [rax+rdx] - movaps xmm4,xmm0 - unpcklps xmm4,xmm1 - unpckhps xmm0,xmm1 - movaps xmm5,xmm2 - unpcklps xmm5,xmm3 - unpckhps xmm2,xmm3 - movaps xmm1,xmm4 - unpcklpd xmm1,xmm5 - unpckhpd xmm4,xmm5 - movaps xmm3,xmm0 - unpcklpd xmm3,xmm2 - unpckhpd xmm0,xmm2 - movaps XMMWORD PTR [rdi+16*4*0],xmm1 - movaps XMMWORD PTR [rdi+16*4*1],xmm4 - movaps XMMWORD PTR [rdi+16*4*2],xmm3 - movaps XMMWORD PTR [rdi+16*4*3],xmm0 - add rdi,4*4 - lea rsi,[rax+rdx*2] - dec ecx - jnz .LTransposeBlockLoop - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S deleted file mode 100644 index 76247ecf7c219..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S +++ /dev/null @@ -1,242 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SoftmaxKernelAvx.s - -Abstract: - - This module implements the kernels for the single precision softmax - operation. - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to find the maximum value of - the supplied buffer. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - N (rsi) - Supplies the number of elements to process. - -Return Value: - - Returns the maximum value of the supplied buffer. 
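A scalar reference for what this reduction kernel computes. The running maximum is seeded from the MlasMinimumF32Value constant that the kernel broadcasts; std::numeric_limits<float>::lowest() stands in for it here:

#include <cstddef>
#include <limits>

// Scalar equivalent of the ReduceMaximum kernels: return the largest element,
// starting from a seed that any finite input will replace.
float ReduceMaximumReference(const float* input, std::size_t n) {
    float maximum = std::numeric_limits<float>::lowest();
    for (std::size_t i = 0; i < n; ++i) {
        if (input[i] > maximum) {
            maximum = input[i];
        }
    }
    return maximum;
}

The vector kernels below unroll this over 32-element (AVX) or 64-element (AVX512F) blocks with four independent accumulators and fold them together once at the end.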
- ---*/ - - FUNCTION_ENTRY MlasReduceMaximumF32KernelAvx - - vbroadcastss ymm0,DWORD PTR C_UNDERSCORE(MlasMinimumF32Value)[rip] - test rsi,rsi - jz .LReduceMaximum.ExitKernel - cmp rsi,8 - jb .LReduceMaximum.ProcessRemainingCountBy1 - cmp rsi,32 - jb .LReduceMaximum.ProcessRemainingCountBy8 - vmovaps ymm1,ymm0 - vmovaps ymm2,ymm0 - vmovaps ymm3,ymm0 - -.LReduceMaximum.ProcessRemainingCountBy32: - vmaxps ymm0,ymm0,YMMWORD PTR [rdi] - vmaxps ymm1,ymm1,YMMWORD PTR [rdi+8*4] - sub rsi,32 - vmaxps ymm2,ymm2,YMMWORD PTR [rdi+16*4] - vmaxps ymm3,ymm3,YMMWORD PTR [rdi+24*4] - add rdi,32*4 # advance input by 32 elements - cmp rsi,32 - jae .LReduceMaximum.ProcessRemainingCountBy32 - vmaxps ymm0,ymm0,ymm1 # reduce to single vector - vmaxps ymm2,ymm2,ymm3 - vmaxps ymm0,ymm0,ymm2 - -.LReduceMaximum.ProcessRemainingCountBy8: - cmp rsi,8 - jb .LReduceMaximum.ProcessRemainingCountLessThan8 - vmaxps ymm0,ymm0,YMMWORD PTR [rdi] - sub rsi,8 - add rdi,8*4 # advance input by 8 elements - jmp .LReduceMaximum.ProcessRemainingCountBy8 - -.LReduceMaximum.ProcessRemainingCountLessThan8: - vextractf128 xmm1,ymm0,1 # reduce to single scalar - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,0xEE - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,0x55 - vmaxss xmm0,xmm0,xmm1 - test rsi,rsi - jz .LReduceMaximum.ExitKernel - -.LReduceMaximum.ProcessRemainingCountBy1: - vmaxss xmm0,xmm0,DWORD PTR [rdi] - add rdi,4 # advance input by 1 element - dec esi - jnz .LReduceMaximum.ProcessRemainingCountBy1 - -.LReduceMaximum.ExitKernel: - vzeroupper - ret - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to produce the final output for - the softmax operation. - -Arguments: - - Output (rdi) - Supplies the output buffer. - - N (rsi) - Supplies the number of elements to process. - - Parameters (rdx) - Supplies an array containing the scale value. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasComputeSoftmaxOutputF32KernelAvx - - vbroadcastss ymm4,DWORD PTR [rdx] # broadcast scale value - cmp rsi,32 - jb .LComputeSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeSoftmaxOutput.ProcessRemainingCountBy32: - vmulps ymm0,ymm4,YMMWORD PTR [rdi] - vmulps ymm1,ymm4,YMMWORD PTR [rdi+8*4] - sub rsi,32 - vmulps ymm2,ymm4,YMMWORD PTR [rdi+16*4] - vmulps ymm3,ymm4,YMMWORD PTR [rdi+24*4] - vmovups YMMWORD PTR [rdi],ymm0 - vmovups YMMWORD PTR [rdi+8*4],ymm1 - vmovups YMMWORD PTR [rdi+16*4],ymm2 - vmovups YMMWORD PTR [rdi+24*4],ymm3 - add rdi,32*4 # advance output by 32 elements - cmp rsi,32 - jae .LComputeSoftmaxOutput.ProcessRemainingCountBy32 - -.LComputeSoftmaxOutput.ProcessRemainingCountBy8: - cmp rsi,8 - jb .LComputeSoftmaxOutput.ProcessRemainingCountLessThan8 - vmulps ymm0,ymm4,YMMWORD PTR [rdi] - sub rsi,8 - vmovups YMMWORD PTR [rdi],ymm0 - add rdi,8*4 # advance output by 8 elements - jmp .LComputeSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeSoftmaxOutput.ProcessRemainingCountLessThan8: - test rsi,rsi - jz .LComputeSoftmaxOutput.ExitKernel - -.LComputeSoftmaxOutput.ProcessRemainingCountBy1: - vmulss xmm0,xmm4,DWORD PTR [rdi] - vmovss DWORD PTR [rdi],xmm0 - add rdi,4 # advance output by 1 element - dec esi - jnz .LComputeSoftmaxOutput.ProcessRemainingCountBy1 - -.LComputeSoftmaxOutput.ExitKernel: - vzeroupper - ret - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to produce the final output for - the log softmax operation. - -Arguments: - - Input (rdi) - Supplies the output buffer. - - Output (rsi) - Supplies the output buffer. 
- - N (rdx) - Supplies the number of elements to process. - - Parameters (rcx) - Supplies an array containing the negative maximum and - logarithm values. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasComputeLogSoftmaxOutputF32KernelAvx - - vbroadcastss ymm4,DWORD PTR [rcx] # broadcast negative minimum value - vbroadcastss ymm5,DWORD PTR [rcx+4] # broadcast log(SumExp) - cmp rdx,32 - jb .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeLogSoftmaxOutput.ProcessRemainingCountBy32: - vaddps ymm0,ymm4,YMMWORD PTR [rdi] - vaddps ymm1,ymm4,YMMWORD PTR [rdi+8*4] - sub rdx,32 - vaddps ymm2,ymm4,YMMWORD PTR [rdi+16*4] - vaddps ymm3,ymm4,YMMWORD PTR [rdi+24*4] - add rdi,32*4 # advance input by 32 elements - vsubps ymm0,ymm0,ymm5 # do as two steps for numeric stability - vsubps ymm1,ymm1,ymm5 - vsubps ymm2,ymm2,ymm5 - vsubps ymm3,ymm3,ymm5 - vmovups YMMWORD PTR [rsi],ymm0 - vmovups YMMWORD PTR [rsi+8*4],ymm1 - vmovups YMMWORD PTR [rsi+16*4],ymm2 - vmovups YMMWORD PTR [rsi+24*4],ymm3 - add rsi,32*4 # advance output by 32 elements - cmp rdx,32 - jae .LComputeLogSoftmaxOutput.ProcessRemainingCountBy32 - -.LComputeLogSoftmaxOutput.ProcessRemainingCountBy8: - cmp rdx,8 - jb .LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8 - vaddps ymm0,ymm4,YMMWORD PTR [rdi] - add rdi,8*4 # advance input by 8 elements - vsubps ymm0,ymm0,ymm5 # do as two steps for numeric stability - sub rdx,8 - vmovups YMMWORD PTR [rsi],ymm0 - add rsi,8*4 # advance output by 8 elements - jmp .LComputeLogSoftmaxOutput.ProcessRemainingCountBy8 - -.LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8: - test rdx,rdx - jz .LComputeLogSoftmaxOutput.ExitKernel - -.LComputeLogSoftmaxOutput.ProcessRemainingCountBy1: - vaddss xmm0,xmm4,DWORD PTR [rdi] - add rdi,4 # advance input by 1 element - vsubss xmm0,xmm0,xmm5 - vmovss DWORD PTR [rsi],xmm0 - add rsi,4 # advance output by 1 element - dec edx - jnz .LComputeLogSoftmaxOutput.ProcessRemainingCountBy1 - -.LComputeLogSoftmaxOutput.ExitKernel: - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S deleted file mode 100644 index db97286046567..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx512F.S +++ /dev/null @@ -1,101 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SoftmaxKernelAvx512F.s - -Abstract: - - This module implements the kernels for the single precision softmax - operation. - - This implementation uses AVX512F instructions. - ---*/ - -#include "asmmacro.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel to find the maximum value of - the supplied buffer. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - N (rsi) - Supplies the number of elements to process. - -Return Value: - - Returns the maximum value of the supplied buffer. 
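As a plain C++ reference for the two softmax output passes above: the softmax kernel multiplies each intermediate exp() value by the scale value from its parameter block, and the log-softmax kernel adds the negative maximum and subtracts log(SumExp) as two separate steps, which its comments note is done for numeric stability:

#include <cstddef>

// Final pass of Softmax: scale each intermediate exp() value by the supplied scale.
void SoftmaxOutputReference(float* output, std::size_t n, float scale) {
    for (std::size_t i = 0; i < n; ++i) {
        output[i] *= scale;
    }
}

// Final pass of LogSoftmax: add the negative maximum, then subtract log(SumExp)
// in a second step, mirroring the two-step form used by the kernel.
void LogSoftmaxOutputReference(const float* input, float* output, std::size_t n,
                               float negative_maximum, float log_sum_exp) {
    for (std::size_t i = 0; i < n; ++i) {
        output[i] = (input[i] + negative_maximum) - log_sum_exp;
    }
}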
- ---*/ - - FUNCTION_ENTRY MlasReduceMaximumF32KernelAvx512F - - vbroadcastss zmm0,DWORD PTR C_UNDERSCORE(MlasMinimumF32Value)[rip] - test rsi,rsi - jz .LReduceMaximum.ExitKernel - cmp rsi,16 - jb .LReduceMaximum.ProcessRemainingCountBy1 - cmp rsi,64 - jb .LReduceMaximum.ProcessRemainingCountBy16 - vmovaps zmm1,zmm0 - vmovaps zmm2,zmm0 - vmovaps zmm3,zmm0 - -.LReduceMaximum.ProcessRemainingCountBy64: - vmaxps zmm0,zmm0,ZMMWORD PTR [rdi] - vmaxps zmm1,zmm1,ZMMWORD PTR [rdi+16*4] - sub rsi,64 - vmaxps zmm2,zmm2,ZMMWORD PTR [rdi+32*4] - vmaxps zmm3,zmm3,ZMMWORD PTR [rdi+48*4] - add rdi,64*4 # advance input by 64 elements - cmp rsi,64 - jae .LReduceMaximum.ProcessRemainingCountBy64 - vmaxps zmm0,zmm0,zmm1 # reduce to single vector - vmaxps zmm2,zmm2,zmm3 - vmaxps zmm0,zmm0,zmm2 - -.LReduceMaximum.ProcessRemainingCountBy16: - cmp rsi,16 - jb .LReduceMaximum.ProcessRemainingCountLessThan16 - vmaxps zmm0,zmm0,ZMMWORD PTR [rdi] - sub rsi,16 - add rdi,16*4 # advance input by 16 elements - jmp .LReduceMaximum.ProcessRemainingCountBy16 - -.LReduceMaximum.ProcessRemainingCountLessThan16: - vextractf32x8 ymm1,zmm0,1 # reduce to single scalar - vmaxps ymm0,ymm0,ymm1 - vextractf128 xmm1,ymm0,1 - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,0xEE - vmaxps xmm0,xmm0,xmm1 - vshufps xmm1,xmm0,xmm0,0x55 - vmaxss xmm0,xmm0,xmm1 - test rsi,rsi - jz .LReduceMaximum.ExitKernel - -.LReduceMaximum.ProcessRemainingCountBy1: - vmaxss xmm0,xmm0,DWORD PTR [rdi] - add rdi,4 # advance input by 1 element - dec esi - jnz .LReduceMaximum.ProcessRemainingCountBy1 - -.LReduceMaximum.ExitKernel: - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvx.S deleted file mode 100644 index 8749090330975..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvx.S +++ /dev/null @@ -1,234 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelAvx.s - -Abstract: - - This module implements the kernels for the single precision pooling - operation. - - This implementation uses AVX instructions. - ---*/ - -#include "asmmacro.h" -#include "SpoolKernelAvxCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to initialize registers used across the kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - -Implicit Arguments: - - r9 - Supplies the ActualKernelSize parameter (see function description). - ---*/ - - .macro InitializeKernel PoolingType - -.ifeqs "\PoolingType\()","Maximum" - mov DWORD PTR .LSpoolKernelFrame_BroadcastValue[rsp],0xFF7FFFFF - vbroadcastss ymm5,DWORD PTR .LSpoolKernelFrame_BroadcastValue[rsp] -.else - vxorps xmm5,xmm5,xmm5 # initialize default divisor vector -.ifeqs "\PoolingType\()","AverageExcludePad" - mov rax,.LSpoolKernelFrame_KernelHeight[rsp] - imul rax,.LSpoolKernelFrame_KernelWidth[rsp] - vcvtsi2ss xmm5,xmm5,rax -.else - vcvtsi2ss xmm5,xmm5,r9 -.endif - vshufps xmm5,xmm5,xmm5,0 - vinsertf128 ymm5,ymm5,xmm5,1 # AVX lacks "vbroadcastss ymm5,xmm5" -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to clear the pooling intermediates. - - For PoolingType==Maximum, the pooling intermediates are set to the minimum - float value. Otherwise, the pooling intermediates are cleared to zero. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. 
- -Implicit Arguments: - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - ymm0-ymm2 - Supplies the pooling intermediates. - - ymm5 - Supplies a vector containing the minimum float value broadcasted, - if PoolingType==Maximum. - ---*/ - - .macro ClearBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - EmitIfCountGE \OutputCount\(), 1, "vmovaps ymm0,ymm5" - EmitIfCountGE \OutputCount\(), 2, "vmovaps ymm1,ymm5" - EmitIfCountGE \OutputCount\(), 3, "vmovaps ymm2,ymm5" -.else - EmitIfCountGE \OutputCount\(), 1, "vxorps xmm0,xmm0,xmm0" - EmitIfCountGE \OutputCount\(), 2, "vxorps xmm1,xmm1,xmm1" - EmitIfCountGE \OutputCount\(), 3, "vxorps xmm2,xmm2,xmm2" -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - xor rsi,rsi # reset valid block counter -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to sample the input buffer and update the pooling - intermediates as appropriate. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - r8 - Supplies the StrideWidth parameter (see function description). - - ymm0-ymm2 - Supplies the pooling intermediates. - ---*/ - - .macro ComputeBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - EmitIfCountGE \OutputCount\(), 1, "vmaxps ymm0,ymm0,YMMWORD PTR [rcx]" - EmitIfCountGE \OutputCount\(), 2, "vmaxps ymm1,ymm1,YMMWORD PTR [rcx+r8]" - EmitIfCountGE \OutputCount\(), 3, "vmaxps ymm2,ymm2,YMMWORD PTR [rcx+r8*2]" -.else - EmitIfCountGE \OutputCount\(), 1, "vaddps ymm0,ymm0,YMMWORD PTR [rcx]" - EmitIfCountGE \OutputCount\(), 2, "vaddps ymm1,ymm1,YMMWORD PTR [rcx+r8]" - EmitIfCountGE \OutputCount\(), 3, "vaddps ymm2,ymm2,YMMWORD PTR [rcx+r8*2]" -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - inc rsi # increment valid block counter -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to process and store the pooling intermediates. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rdx - Supplies the address of the output buffer. - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - ymm0-ymm2 - Supplies the pooling intermediates. - - ymm5 - Supplies the kernel size computed by InitializeKernel, if - PoolingType=AverageExcludePad, else the actual kernel size, if - PoolingType=AverageIncludePad. - ---*/ - - .macro PostProcessBlock PoolingType, OutputCount - -// -// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -// blocks. OutputCount=1 generates code to count the number of blocks accessed by -// ComputeBlock. Other cases use the kernel size computed by InitializeKernel. 
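For orientation, a scalar model of the three pooling flavors these macros implement: maximum, average excluding padding (divide by the count of inputs actually visited, the "valid block counter" kept in rsi), and average including padding (divide by the full kernel size). This is a simplified 1-D sketch, not the blocked NCHWc layout the kernels operate on:

#include <cstddef>
#include <limits>

enum class PoolingType { Maximum, AverageExcludePad, AverageIncludePad };

// One output element of a 1-D pooling window; positions outside the input
// count as padding and are skipped, matching the SkipOverPadding path.
float PoolOneOutput(const float* input, std::size_t input_len,
                    std::ptrdiff_t window_start, std::size_t kernel_size, PoolingType type) {
    float acc = (type == PoolingType::Maximum)
                    ? std::numeric_limits<float>::lowest()    // ClearBlock: minimum float value
                    : 0.0f;                                   // ClearBlock: zero
    std::size_t valid = 0;                                    // valid block counter
    for (std::size_t k = 0; k < kernel_size; ++k) {
        std::ptrdiff_t pos = window_start + static_cast<std::ptrdiff_t>(k);
        if (pos < 0 || pos >= static_cast<std::ptrdiff_t>(input_len)) {
            continue;                                         // padding element
        }
        float v = input[pos];
        acc = (type == PoolingType::Maximum) ? (acc > v ? acc : v) : acc + v;
        ++valid;                                              // ComputeBlock
    }
    if (type == PoolingType::AverageExcludePad && valid != 0) {
        acc /= static_cast<float>(valid);                     // divide by non-padding count
    } else if (type == PoolingType::AverageIncludePad) {
        acc /= static_cast<float>(kernel_size);               // divide by actual kernel size
    }
    return acc;
}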
-// - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - vxorps xmm4,xmm4,xmm4 - vcvtsi2ss xmm4,xmm4,rsi # convert valid block counter - vshufps xmm4,xmm4,xmm4,0 - vinsertf128 ymm4,ymm4,xmm4,1 # AVX lacks "vbroadcastss ymm4,xmm4" - vdivps ymm0,ymm0,ymm4 -.else - EmitIfCountGE \OutputCount\(), 1, "vdivps ymm0,ymm0,ymm5" - EmitIfCountGE \OutputCount\(), 2, "vdivps ymm1,ymm1,ymm5" - EmitIfCountGE \OutputCount\(), 3, "vdivps ymm2,ymm2,ymm5" -.endif -.endif - -// -// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -// - -.ifeqs "\PoolingType\()","AverageIncludePad" - EmitIfCountGE \OutputCount\(), 1, "vdivps ymm0,ymm0,ymm5" - EmitIfCountGE \OutputCount\(), 2, "vdivps ymm1,ymm1,ymm5" - EmitIfCountGE \OutputCount\(), 3, "vdivps ymm2,ymm2,ymm5" -.endif - -// -// Store the output block in the output buffer. -// - - EmitIfCountGE \OutputCount\(), 1, "vmovups YMMWORD PTR [rdx],ymm0" - EmitIfCountGE \OutputCount\(), 2, "vmovups YMMWORD PTR [rdx+8*4],ymm1" - EmitIfCountGE \OutputCount\(), 3, "vmovups YMMWORD PTR [rdx+16*4],ymm2" - add_immed rdx,\OutputCount\()*8*4 # advance output by N nchw8c blocks - - .endm - -// -// Generate the pooling kernels. -// - - SpoolKernelFunction Maximum, Avx - SpoolKernelFunction AverageExcludePad, Avx - SpoolKernelFunction AverageIncludePad, Avx - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvx512F.S deleted file mode 100644 index 9433ce85cb8cd..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvx512F.S +++ /dev/null @@ -1,228 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelAvx512F.s - -Abstract: - - This module implements the kernels for the single precision pooling - operation. - - This implementation uses AVX512F instructions. - ---*/ - -#include "asmmacro.h" -#include "SpoolKernelAvxCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to initialize registers used across the kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - -Implicit Arguments: - - r9 - Supplies the ActualKernelSize parameter (see function description). - ---*/ - - .macro InitializeKernel PoolingType - -.ifeqs "\PoolingType\()","Maximum" - mov DWORD PTR .LSpoolKernelFrame_BroadcastValue[rsp],0xFF7FFFFF - vbroadcastss zmm5,DWORD PTR .LSpoolKernelFrame_BroadcastValue[rsp] -.else - vxorps xmm5,xmm5,xmm5 # initialize default divisor vector -.ifeqs "\PoolingType\()","AverageExcludePad" - mov rax,.LSpoolKernelFrame_KernelHeight[rsp] - imul rax,.LSpoolKernelFrame_KernelWidth[rsp] - vcvtsi2ss xmm5,xmm5,rax -.else - vcvtsi2ss xmm5,xmm5,r9 -.endif - vbroadcastss zmm5,xmm5 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to clear the pooling intermediates. - - For PoolingType==Maximum, the pooling intermediates are set to the minimum - float value. Otherwise, the pooling intermediates are cleared to zero. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - zmm0-zmm2 - Supplies the pooling intermediates. - - zmm5 - Supplies a vector containing the minimum float value broadcasted, - if PoolingType==Maximum. 
- ---*/ - - .macro ClearBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - EmitIfCountGE \OutputCount\(), 1, "vmovaps zmm0,zmm5" - EmitIfCountGE \OutputCount\(), 2, "vmovaps zmm1,zmm5" - EmitIfCountGE \OutputCount\(), 3, "vmovaps zmm2,zmm5" -.else - EmitIfCountGE \OutputCount\(), 1, "vxorps xmm0,xmm0,xmm0" - EmitIfCountGE \OutputCount\(), 2, "vxorps xmm1,xmm1,xmm1" - EmitIfCountGE \OutputCount\(), 3, "vxorps xmm2,xmm2,xmm2" -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - xor rsi,rsi # reset valid block counter -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to sample the input buffer and update the pooling - intermediates as appropriate. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - r8 - Supplies the StrideWidth parameter (see function description). - - zmm0-zmm2 - Supplies the pooling intermediates. - ---*/ - - .macro ComputeBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - EmitIfCountGE \OutputCount\(), 1, "vmaxps zmm0,zmm0,ZMMWORD PTR [rcx]" - EmitIfCountGE \OutputCount\(), 2, "vmaxps zmm1,zmm1,ZMMWORD PTR [rcx+r8]" - EmitIfCountGE \OutputCount\(), 3, "vmaxps zmm2,zmm2,ZMMWORD PTR [rcx+r8*2]" -.else - EmitIfCountGE \OutputCount\(), 1, "vaddps zmm0,zmm0,ZMMWORD PTR [rcx]" - EmitIfCountGE \OutputCount\(), 2, "vaddps zmm1,zmm1,ZMMWORD PTR [rcx+r8]" - EmitIfCountGE \OutputCount\(), 3, "vaddps zmm2,zmm2,ZMMWORD PTR [rcx+r8*2]" -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - inc rsi # increment valid block counter -.endif -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to process and store the pooling intermediates. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rdx - Supplies the address of the output buffer. - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - zmm0-zmm2 - Supplies the pooling intermediates. - - zmm5 - Supplies the kernel size computed by InitializeKernel, if - PoolingType=AverageExcludePad, else the actual kernel size, if - PoolingType=AverageIncludePad. - ---*/ - - .macro PostProcessBlock PoolingType, OutputCount - -// -// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -// blocks. OutputCount=1 generates code to count the number of blocks accessed by -// ComputeBlock. Other cases use the kernel size computed by InitializeKernel. -// - -.ifeqs "\PoolingType\()","AverageExcludePad" -.if \OutputCount\() == 1 - vxorps xmm4,xmm4,xmm4 - vcvtsi2ss xmm4,xmm4,rsi # convert valid block counter - vbroadcastss zmm4,xmm4 - vdivps zmm0,zmm0,zmm4 -.else - EmitIfCountGE \OutputCount\(), 1, "vdivps zmm0,zmm0,zmm5" - EmitIfCountGE \OutputCount\(), 2, "vdivps zmm1,zmm1,zmm5" - EmitIfCountGE \OutputCount\(), 3, "vdivps zmm2,zmm2,zmm5" -.endif -.endif - -// -// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. 
-// - -.ifeqs "\PoolingType\()","AverageIncludePad" - EmitIfCountGE \OutputCount\(), 1, "vdivps zmm0,zmm0,zmm5" - EmitIfCountGE \OutputCount\(), 2, "vdivps zmm1,zmm1,zmm5" - EmitIfCountGE \OutputCount\(), 3, "vdivps zmm2,zmm2,zmm5" -.endif - - EmitIfCountGE \OutputCount\(), 1, "vmovups ZMMWORD PTR [rdx],zmm0" - EmitIfCountGE \OutputCount\(), 2, "vmovups ZMMWORD PTR [rdx+16*4],zmm1" - EmitIfCountGE \OutputCount\(), 3, "vmovups ZMMWORD PTR [rdx+32*4],zmm2" - add_immed rdx,\OutputCount\()*16*4 # advance output by N nchw16c blocks - - .endm - -// -// Generate the pooling kernels. -// - - SpoolKernelFunction Maximum, Avx512F - SpoolKernelFunction AverageExcludePad, Avx512F - SpoolKernelFunction AverageIncludePad, Avx512F - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvxCommon.h b/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvxCommon.h deleted file mode 100644 index 68de6acc7ad35..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelAvxCommon.h +++ /dev/null @@ -1,143 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelAvxCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision pooling operation for the AVX and AVX512F kernels. - ---*/ - -#include "SpoolKernelCommon.h" - -/*++ - -Macro Description: - - This macro generates code for the inner pooling kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SpoolKernelFunction PoolingType, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute pooling for the elements of an - output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Output (rsi) - Supplies the address of the output buffer. - - StrideWidth (rdx) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (rcx) - Supplies the length in bytes of the blocked dilation - width. - - InputStride (r8) - Supplies the length in bytes to advance the input buffer to - the next input row. - - ActualKernelSize (r9) - Supplies the size of the kernel based on the original - kernel dimensions, used for PoolingType=AverageIncludePad. - - KernelHeight - Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth - Supplies the width of the kernel to apply. - - InputBase - Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. - - InputWidth - Supplies the length in bytes of the blocked input width. - - DilatedInputWidth - Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad - Supplies the number of output elements that include - one or more padding elements from the left edge. 
- - OutputCount - Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad - Supplies the number of output elements that include - one or more padding elements from the right edge. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() - - SpoolKernelEntry \PoolingType\() - -.L\PoolingType\().ProcessOutputCountLeftPad: - mov r10,.LSpoolKernelFrame_OutputCountLeftPad[rsp] - test r10,r10 - jz .L\PoolingType\().ProcessOutputCount - call MlasPool\PoolingType\()FloatSingle\Isa\() - -.L\PoolingType\().ProcessOutputCount: - mov r10,.LSpoolKernelFrame_OutputCount[rsp] - sub r10,3 - jb .L\PoolingType\().ProcessRemainingOutputCount - -.L\PoolingType\().ProcessNextOutputCountBy3: - ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 3 - lea rax,[r8*2+r8] - add rdi,rax # advance input by 3 elements - sub r10,3 - jae .L\PoolingType\().ProcessNextOutputCountBy3 - -.L\PoolingType\().ProcessRemainingOutputCount: - add r10,3 # correct for over-subtract above - -.L\PoolingType\().ProcessOutputCountRightPad: - add r10,.LSpoolKernelFrame_OutputCountRightPad[rsp] - jz .L\PoolingType\().ExitKernel - call MlasPool\PoolingType\()FloatSingle\Isa\() - -.L\PoolingType\().ExitKernel: - vzeroupper - SpoolKernelExit - -// -// Generate out-of-band helpers for handling output blocks involving padding. -// - -MlasPool\PoolingType\()FloatSingle\Isa\(): - ProcessOutputCountN .LSpoolKernelSingleFrame, \PoolingType\(), 1 - add rdi,r8 # advance input by 1 element - dec r10 # decrement output count remaining - jnz MlasPool\PoolingType\()FloatSingle\Isa\() - ret - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/SpoolKernelCommon.h deleted file mode 100644 index 18ab55c9e35c8..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelCommon.h +++ /dev/null @@ -1,176 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the single - precision pooling operation. - ---*/ - -// -// Stack frame layout for the pooling kernels. 
-// - - .equ .LSpoolKernelFrame_BroadcastValue, -8 - .equ .LSpoolKernelFrame_SavedR12, 0 - .equ .LSpoolKernelFrame_SavedR13, 8 - .equ .LSpoolKernelFrame_SavedR14, 16 - .equ .LSpoolKernelFrame_SavedRbx, 24 - .equ .LSpoolKernelFrame_SavedRbp, 32 - .equ .LSpoolKernelFrame_ReturnAddress, 40 - .equ .LSpoolKernelFrame_KernelHeight, 48 - .equ .LSpoolKernelFrame_KernelWidth, 56 - .equ .LSpoolKernelFrame_InputBase, 64 - .equ .LSpoolKernelFrame_InputWidth, 72 - .equ .LSpoolKernelFrame_DilatedInputWidth, 80 - .equ .LSpoolKernelFrame_OutputCountLeftPad, 88 - .equ .LSpoolKernelFrame_OutputCount, 96 - .equ .LSpoolKernelFrame_OutputCountRightPad, 104 - - .equ .LSpoolKernelSingleFrame_ReturnAddress, 0 - .equ .LSpoolKernelSingleFrame_SavedR12, 8 - .equ .LSpoolKernelSingleFrame_SavedR13, 16 - .equ .LSpoolKernelSingleFrame_SavedR14, 24 - .equ .LSpoolKernelSingleFrame_SavedRbx, 32 - .equ .LSpoolKernelSingleFrame_SavedRbp, 40 - .equ .LSpoolKernelSingleFrame_ParentReturnAddress, 48 - .equ .LSpoolKernelSingleFrame_KernelHeight, 56 - .equ .LSpoolKernelSingleFrame_KernelWidth, 64 - .equ .LSpoolKernelSingleFrame_InputBase, 72 - .equ .LSpoolKernelSingleFrame_InputWidth, 80 - .equ .LSpoolKernelSingleFrame_DilatedInputWidth, 88 - .equ .LSpoolKernelSingleFrame_OutputCountLeftPad, 96 - .equ .LSpoolKernelSingleFrame_OutputCount, 104 - .equ .LSpoolKernelSingleFrame_OutputCountRightPad, 112 - -/*++ - -Macro Description: - - This macro generates the common prologue code for the pooling kernels. - -Arguments: - - PoolingType - Supplies the pooling type string. - ---*/ - - .macro SpoolKernelEntry PoolingType - - push rbp - push rbx - push r14 - push r13 - push r12 - - InitializeKernel \PoolingType\() - mov rbp,r8 # shuffle to Win64 register usage - mov r8,rdx - mov r9,rcx - mov rdx,rsi - - .endm - -/*++ - -Macro Description: - - This macro generates the common epilogue code for the pooling kernels. - -Arguments: - - None. - ---*/ - - .macro SpoolKernelExit - - pop r12 - pop r13 - pop r14 - pop rbx - pop rbp - ret - - .endm - -/*++ - -Macro Description: - - This macro generates code to compute pooling for a vector of input blocks - to produce a matrix of output blocks. - - OutputCount=1 generates special case code to handle padding blocks. All - other output counts assume no padding. - -Arguments: - - KernelFrame - Supplies the symbol name to access the convolution kernel - stack. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rdi - Supplies the address of the input buffer. - - rdx - Supplies the address of the output buffer. - - r8 - Supplies the StrideWidth parameter (see function description). - - r9 - Supplies the DilationWidth parameter (see function description). - - rbp - Supplies the InputStride parameter (see function description). - ---*/ - - .macro ProcessOutputCountN KernelFrame, PoolingType, OutputCount - - mov rcx,rdi - mov r11,\KernelFrame\()_KernelHeight[rsp] - mov r12,\KernelFrame\()_KernelWidth[rsp] -.if \OutputCount\() == 1 - mov r13,\KernelFrame\()_InputBase[rsp] - mov r14,\KernelFrame\()_InputWidth[rsp] - neg r13 # keep negative for lea usage below -.endif - ClearBlock \PoolingType\(), \OutputCount\() - test r11,r11 # zero sized kernel? - jz .L\PoolingType\().\OutputCount\().HandlePostProcessing - -.L\PoolingType\().\OutputCount\().ProcessNextRow: - mov rax,r12 - -.L\PoolingType\().\OutputCount\().ProcessNextColumn: -.if \OutputCount\() == 1 - lea rbx,[rcx+r13] # compute (Input - InputBase) - cmp rbx,r14 # (Input - InputBase) >= InputWidth? 
- jae .L\PoolingType\().\OutputCount\().SkipOverPadding -.endif - ComputeBlock \PoolingType\(), \OutputCount\() - -.L\PoolingType\().\OutputCount\().SkipOverPadding: - add rcx,r9 # advance input by dilation width - dec rax # decrement columns remaining - jnz .L\PoolingType\().\OutputCount\().ProcessNextColumn - add rcx,rbp # advance input to next row -.if \OutputCount\() == 1 - sub r13,\KernelFrame\()_DilatedInputWidth[rsp] - # advance input base to next row -.endif - dec r11 - jnz .L\PoolingType\().\OutputCount\().ProcessNextRow - -.L\PoolingType\().\OutputCount\().HandlePostProcessing: - PostProcessBlock \PoolingType\(), \OutputCount\() - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelSse2.S b/onnxruntime/core/mlas/lib/x86_64/SpoolKernelSse2.S deleted file mode 100644 index 285da3072c17a..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/SpoolKernelSse2.S +++ /dev/null @@ -1,306 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - SpoolKernelSse2.s - -Abstract: - - This module implements the kernels for the single precision pooling - operation. - - This implementation uses SSE2 instructions. - ---*/ - -#include "asmmacro.h" -#include "SpoolKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Macro Description: - - This macro generates code to initialize registers used across the kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - ---*/ - - .macro InitializeKernel PoolingType - -.ifeqs "\PoolingType\()","Maximum" - mov eax,0xFF7FFFFF - movd xmm5,eax - shufps xmm5,xmm5,0 -.endif - -.ifeqs "\PoolingType\()","AverageIncludePad" - cvtsi2ss xmm5,r9 - shufps xmm5,xmm5,0 -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to clear the pooling intermediates. - - For PoolingType==Maximum, the pooling intermediates are set to the minimum - float value. Otherwise, the pooling intermediates are cleared to zero. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - xmm0-xmm1 - Supplies the pooling intermediates. - - xmm5 - Supplies a vector containing the minimum float value broadcasted, - if PoolingType==Maximum. - ---*/ - - .macro ClearBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - movaps xmm0,xmm5 - movaps xmm1,xmm5 -.else - xorps xmm0,xmm0 - xorps xmm1,xmm1 -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" - xor rsi,rsi # reset valid block counter -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to sample the input buffer and update the pooling - intermediates as appropriate. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rcx - Supplies the address of the input buffer. - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - r8 - Supplies the StrideWidth parameter (see function description). - - xmm0-xmm1 - Supplies the pooling intermediates. 
- ---*/ - - .macro ComputeBlock PoolingType, OutputCount - -.ifeqs "\PoolingType\()","Maximum" - maxps xmm0,XMMWORD PTR [rcx] - maxps xmm1,XMMWORD PTR [rcx+16] -.else - addps xmm0,XMMWORD PTR [rcx] - addps xmm1,XMMWORD PTR [rcx+16] -.endif - -.ifeqs "\PoolingType\()","AverageExcludePad" - inc rsi # increment valid block counter -.endif - - .endm - -/*++ - -Macro Description: - - This macro generates code to process and store the pooling intermediates. - -Arguments: - - PoolingType - Supplies the pooling type string. - - OutputCount - Supplies the number of output blocks to produce. - -Implicit Arguments: - - rdx - Supplies the address of the output buffer. - - rsi - Supplies the number of blocks accessed by ComputeBlock, if - PoolingType=AverageExcludePad and OutputCount=1. - - xmm0-xmm1 - Supplies the pooling intermediates. - - xmm5 - Supplies the kernel size computed by InitializeKernel, if - PoolingType=AverageExcludePad, else the actual kernel size, if - PoolingType=AverageIncludePad. - ---*/ - - .macro PostProcessBlock PoolingType, OutputCount - -// -// If PoolingType=AverageExcludePad, divide the sum by the number of non-padding -// blocks. -// - -.ifeqs "\PoolingType\()","AverageExcludePad" - xorps xmm4,xmm4 - cvtsi2ss xmm4,rsi # convert valid block counter - shufps xmm4,xmm4,0 - divps xmm0,xmm4 - divps xmm1,xmm4 -.endif - -// -// If PoolingType=AverageIncludePad, divide the sum by the actual kernel size. -// - -.ifeqs "\PoolingType\()","AverageIncludePad" - divps xmm0,xmm5 - divps xmm1,xmm5 -.endif - -// -// Store the output block in the output buffer. -// - - movups XMMWORD PTR [rdx],xmm0 - movups XMMWORD PTR [rdx+16],xmm1 - add rdx,8*4 # advance output by 1 nchw8c block - - .endm - -/*++ - -Macro Description: - - This macro generates code for the inner pooling kernel. - -Arguments: - - PoolingType - Supplies the pooling type string. - - Isa - Supplies the instruction set architecture string for function tags. - ---*/ - - .macro SpoolKernelFunction PoolingType, Isa - -/*++ - -Routine Description: - - This routine is the inner kernel to compute pooling for the elements of an - output row for a set of filter rows. - -Arguments: - - Input (rdi) - Supplies the address of the input buffer. - - The address is biased to include padding blocks for the left width - dimension. The address is not biased to include padding rows for the - left height dimension these are accounted for in the outer kernel. - - Output (rsi) - Supplies the address of the output buffer. - - StrideWidth (rdx) - Supplies the length in bytes of the blocked stride width. - - DilationWidth (rcx) - Supplies the length in bytes of the blocked dilation - width. - - InputStride (r8) - Supplies the length in bytes to advance the input buffer to - the next input row. - - ActualKernelSize (r9) - Supplies the size of the kernel based on the original - kernel dimensions, used for PoolingType=AverageIncludePad. - - KernelHeight - Supplies the height of the kernel to apply. This height may - be less than the original kernel height after removing any padding - rows. - - KernelWidth - Supplies the width of the kernel to apply. - - InputBase - Supplies the address of the valid input buffer. - - This parameter is similar to the Input parameter, but does not include - the padding blocks for the left width dimension. This parameter is used - with the following InputWidth parameter in order to validate that the - current input buffer address in bounds and not in the left or right - width padding region. 
- - InputWidth - Supplies the length in bytes of the blocked input width. - - DilatedInputWidth - Supplies the length in bytes to advance the input base - buffer to the next input row including dilation. - - OutputCountLeftPad - Supplies the number of output elements that include - one or more padding elements from the left edge. - - OutputCount - Supplies the number of output elements that do not include - any padding elements. - - OutputCountRightPad - Supplies the number of output elements that include - one or more padding elements from the right edge. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasPool\PoolingType\()FloatKernel\Isa\() - - SpoolKernelEntry \PoolingType\() - - mov r10,.LSpoolKernelFrame_OutputCountLeftPad[rsp] - add r10,.LSpoolKernelFrame_OutputCount[rsp] - add r10,.LSpoolKernelFrame_OutputCountRightPad[rsp] - jz .L\PoolingType\().ExitKernel - -.L\PoolingType\().ProcessNextOutputCount: - ProcessOutputCountN .LSpoolKernelFrame, \PoolingType\(), 1 - add rdi,r8 # advance input by 1 element - dec r10 - jnz .L\PoolingType\().ProcessNextOutputCount - -.L\PoolingType\().ExitKernel: - SpoolKernelExit - - .endm - -// -// Generate the pooling kernels. -// - - SpoolKernelFunction Maximum, Sse - SpoolKernelFunction AverageExcludePad, Sse - SpoolKernelFunction AverageIncludePad, Sse - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S deleted file mode 100644 index d7c2fd1c6e1dd..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S +++ /dev/null @@ -1,118 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - TanhKernelFma3.s - -Abstract: - - This module implements a kernel for computing the hyperbolic tangent - function for a buffer of elements. - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" -#include "TransKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the hyperbolic tangent - function. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Supplies the output buffer. - - N (rdx) - Supplies the number of elements to process. - -Return Value: - - None. 
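Before the implementation that follows, a scalar sketch of the approximation it evaluates: the input is clamped to the kernel's valid range and tanh(x) is formed as the rational polynomial x·P(x²)/Q(x²), with the alpha/beta coefficients coming from MlasTanhConstants (their values are not reproduced here; only the evaluation order is shown):

#include <algorithm>

// Shape of the FMA3 tanh kernel's per-element math. The coefficient arrays stand in
// for the alpha_13..alpha_1 and beta_6..beta_0 entries of MlasTanhConstants.
float TanhApprox(float x, float lower, float upper,
                 const float alpha[7], const float beta[4]) {
    x = std::min(std::max(x, lower), upper);   // clamp to the polynomial's valid range
    float x2 = x * x;
    float p = alpha[0];                        // alpha_13
    for (int i = 1; i < 7; ++i) {
        p = p * x2 + alpha[i];                 // Horner steps down to alpha_1
    }
    float q = beta[0];                         // beta_6
    for (int i = 1; i < 4; ++i) {
        q = q * x2 + beta[i];                  // Horner steps down to beta_0
    }
    return (x * p) / q;                        // tanh(x) ~= x * P(x^2) / Q(x^2)
}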
- ---*/ - - FUNCTION_ENTRY MlasComputeTanhF32KernelFma3 - - lea rax,C_UNDERSCORE(MlasTanhConstants)[rip] - vbroadcastss ymm4,.LTanhConstants_LowerRange[rax] - vbroadcastss ymm5,.LTanhConstants_UpperRange[rax] - vbroadcastss ymm6,.LTanhConstants_alpha_13[rax] - vbroadcastss ymm7,.LTanhConstants_alpha_11[rax] - vbroadcastss ymm8,.LTanhConstants_alpha_9[rax] - vbroadcastss ymm9,.LTanhConstants_alpha_7[rax] - vbroadcastss ymm10,.LTanhConstants_alpha_5[rax] - vbroadcastss ymm11,.LTanhConstants_alpha_3[rax] - vbroadcastss ymm12,.LTanhConstants_alpha_1[rax] - vbroadcastss ymm13,.LTanhConstants_beta_6[rax] - vbroadcastss ymm14,.LTanhConstants_beta_2[rax] - vbroadcastss ymm15,.LTanhConstants_beta_0[rax] - - sub rdx,8 - jb .LProcessRemainingCount - -.LComputeTanhBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rdi] # clamp lower bound - vmovaps ymm2,ymm7 - vminps ymm0,ymm5,ymm0 # clamp upper bound - vmulps ymm1,ymm0,ymm0 # x2 - vbroadcastss ymm3,.LTanhConstants_beta_4[rax] - vfmadd231ps ymm2,ymm1,ymm6 # p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm2,ymm1,ymm8 # p = x2 * p + alpha_9 - vfmadd213ps ymm2,ymm1,ymm9 # p = x2 * p + alpha_7 - vfmadd213ps ymm2,ymm1,ymm10 # p = x2 * p + alpha_5 - vfmadd213ps ymm2,ymm1,ymm11 # p = x2 * p + alpha_3 - vfmadd213ps ymm2,ymm1,ymm12 # p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 # q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 # q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 # q = x2 * q + beta_0 - vmulps ymm2,ymm0,ymm2 # p = x * p - vdivps ymm0,ymm2,ymm3 # tanh = p / q - add rdi,8*4 # advance input by 8 elements - vmovups YMMWORD PTR [rsi],ymm0 - add rsi,8*4 # advance output by 8 elements - sub rdx,8 - jae .LComputeTanhBy8Loop - -.LProcessRemainingCount: - add rdx,8 # correct for over-subtract above - jz .LExitKernel - neg rdx - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovups ymm2,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rdi] - vmaxps ymm0,ymm4,ymm0 # clamp lower bound - vminps ymm0,ymm5,ymm0 # clamp upper bound - vmulps ymm1,ymm0,ymm0 # x2 - vbroadcastss ymm3,.LTanhConstants_beta_4[rax] - vfmadd231ps ymm7,ymm1,ymm6 # p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm7,ymm1,ymm8 # p = x2 * p + alpha_9 - vfmadd213ps ymm7,ymm1,ymm9 # p = x2 * p + alpha_7 - vfmadd213ps ymm7,ymm1,ymm10 # p = x2 * p + alpha_5 - vfmadd213ps ymm7,ymm1,ymm11 # p = x2 * p + alpha_3 - vfmadd213ps ymm7,ymm1,ymm12 # p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 # q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 # q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 # q = x2 * q + beta_0 - vmulps ymm7,ymm0,ymm7 # p = x * p - vdivps ymm0,ymm7,ymm3 # tanh = p / q - vmaskmovps YMMWORD PTR [rsi],ymm2,ymm0 - -.LExitKernel: - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/TransKernelAvx512F.S b/onnxruntime/core/mlas/lib/x86_64/TransKernelAvx512F.S deleted file mode 100644 index 64b6204249a35..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/TransKernelAvx512F.S +++ /dev/null @@ -1,267 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - TransKernelAvx512F.s - -Abstract: - - This module implements kernels for various transcendental functions. - - This implementation uses AVX512F instructions. - ---*/ - -#include "asmmacro.h" -#include "TransKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the exponential function. 
- -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Supplies the output buffer. - - N (rdx) - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasComputeExpF32KernelAvx512F - - lea rax,C_UNDERSCORE(MlasExpConstants)[rip] - vbroadcastss zmm21,.LExpConstants_LowerRange[rax] - vbroadcastss zmm22,.LExpConstants_RoundingBias[rax] - vbroadcastss zmm23,.LExpConstants_Log2Reciprocal[rax] - vbroadcastss zmm24,.LExpConstants_Log2High[rax] - vbroadcastss zmm25,.LExpConstants_Log2Low[rax] - vbroadcastss zmm26,.LExpConstants_poly_0[rax] - vbroadcastss zmm27,.LExpConstants_poly_1[rax] - vbroadcastss zmm28,.LExpConstants_poly_2[rax] - vbroadcastss zmm29,.LExpConstants_poly_3[rax] - vbroadcastss zmm30,.LExpConstants_poly_4[rax] - vbroadcastss zmm31,.LExpConstants_poly_56[rax] - - sub rdx,16 - jb .LComputeExp.ProcessRemainingCount - -.LComputeExp.ComputeExpBy16Loop: - vmaxps zmm16,zmm21,ZMMWORD PTR [rdi] # clamp lower bound - vmovaps zmm18,zmm23 - vfmadd213ps zmm18,zmm16,zmm22 # (input / ln2) plus rounding bias - vmovaps zmm17,zmm26 # p = poly_0 - vsubps zmm18,zmm18,zmm22 # m = round(input / ln2) - vfmadd231ps zmm16,zmm18,zmm24 # range reduce: x -= (m * ln2_high) - vfmadd231ps zmm16,zmm18,zmm25 # range reduce: x -= (m * ln2_low) - vmovaps zmm17,zmm26 # p = poly_0 - vfmadd213ps zmm17,zmm16,zmm27 # p = p * x + poly_1 - vfmadd213ps zmm17,zmm16,zmm28 # p = p * x + poly_2 - vfmadd213ps zmm17,zmm16,zmm29 # p = p * x + poly_3 - vfmadd213ps zmm17,zmm16,zmm30 # p = p * x + poly_4 - vfmadd213ps zmm17,zmm16,zmm31 # p = p * x + poly_5 - vfmadd213ps zmm17,zmm16,zmm31 # p = p * x + poly_6 - vscalefps zmm17,zmm17,zmm18 # scale p with exponent - add rdi,16*4 # advance input by 16 elements - vmovups ZMMWORD PTR [rsi],zmm17 - add rsi,16*4 # advance output by 16 elements - sub rdx,16 - jae .LComputeExp.ComputeExpBy16Loop - -.LComputeExp.ProcessRemainingCount: - add rdx,16 # correct for over-subtract above - jz .LComputeExp.ExitKernel - lea r10,C_UNDERSCORE(MlasOpmask16BitTableAvx512)[rip] - kmovw k1,WORD PTR [r10+rdx*2] - vmaxps zmm16{k1}{z},zmm21,ZMMWORD PTR [rdi] - # clamp lower bound - vfmadd213ps zmm23,zmm16,zmm22 # (input / ln2) plus rounding bias - vsubps zmm23,zmm23,zmm22 # round(input / ln2) - vfmadd231ps zmm16,zmm23,zmm24 # range reduce: x -= (m * ln2_high) - vfmadd231ps zmm16,zmm23,zmm25 # range reduce: x -= (m * ln2_low) - vfmadd213ps zmm26,zmm16,zmm27 # p = p * x + poly_1 - vfmadd213ps zmm26,zmm16,zmm28 # p = p * x + poly_2 - vfmadd213ps zmm26,zmm16,zmm29 # p = p * x + poly_3 - vfmadd213ps zmm26,zmm16,zmm30 # p = p * x + poly_4 - vfmadd213ps zmm26,zmm16,zmm31 # p = p * x + poly_5 - vfmadd213ps zmm26,zmm16,zmm31 # p = p * x + poly_6 - vscalefps zmm26,zmm26,zmm23 # scale p with exponent - vmovups ZMMWORD PTR [rsi]{k1},zmm26 - -.LComputeExp.ExitKernel: - ret - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the sum of exponential - functions. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Optionally supplies the output buffer. When used for Softmax, - the output buffer is used to store the intermediate exp() results. When - used for LogSoftmax, the intermediate exp() results are not required. - - N (rdx) - Supplies the number of elements to process. - - NegativeMaximum (rcx) - Supplies the address of the negative maximum that - is added to each element before computing the exponential function. - -Return Value: - - Returns the sum of the exponential functions. 
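A scalar model of the range reduction used by the exp kernel above and by the SumExp kernel whose implementation follows: round x/ln2 to an integer m, subtract m·ln2 in a high and a low part so the reduced argument stays accurate, evaluate the polynomial, and scale the result by 2^m, which AVX512F does with vscalefps. The constants are named after the MlasExpConstants fields but passed as parameters here, and std::nearbyint stands in for the rounding-bias trick:

#include <cmath>

// Per-element shape of the exp kernels' math. poly[] corresponds to poly_0..poly_4
// plus the shared poly_56 value applied twice, mirroring the kernels.
float ExpApprox(float x, float lower_range, float log2_reciprocal,
                float log2_high, float log2_low, const float poly[7]) {
    x = (x < lower_range) ? lower_range : x;        // clamp lower bound
    float m = std::nearbyint(x * log2_reciprocal);  // m = round(x / ln2)
    x -= m * log2_high;                             // range reduce: x -= m * ln2 (high part)
    x -= m * log2_low;                              //               x -= m * ln2 (low part)
    float p = poly[0];                              // p = poly_0
    for (int i = 1; i < 7; ++i) {
        p = p * x + poly[i];                        // p = p * x + poly_i
    }
    return std::ldexp(p, static_cast<int>(m));      // scale p by 2^m (vscalefps)
}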
- ---*/ - - FUNCTION_ENTRY MlasComputeSumExpF32KernelAvx512F - - lea rax,C_UNDERSCORE(MlasExpConstants)[rip] - vbroadcastss zmm21,.LExpConstants_LowerRange[rax] - vbroadcastss zmm22,.LExpConstants_RoundingBias[rax] - vbroadcastss zmm23,.LExpConstants_Log2Reciprocal[rax] - vbroadcastss zmm24,.LExpConstants_Log2High[rax] - vbroadcastss zmm25,.LExpConstants_Log2Low[rax] - vbroadcastss zmm26,.LExpConstants_poly_0[rax] - vbroadcastss zmm27,.LExpConstants_poly_1[rax] - vbroadcastss zmm28,.LExpConstants_poly_2[rax] - vbroadcastss zmm29,.LExpConstants_poly_3[rax] - vbroadcastss zmm30,.LExpConstants_poly_4[rax] - vbroadcastss zmm31,.LExpConstants_poly_56[rax] - - vbroadcastss zmm19,DWORD PTR [rcx] # broadcast negative maximum value - vpxord zmm20,zmm20,zmm20 # clear exp() accumulator - sub rdx,48 - jb .LComputeSumExp.ProcessRemainingCount - -.LComputeSumExp.ComputeExpBy48Loop: - vaddps zmm0,zmm19,ZMMWORD PTR [rdi] # bias by negative maximum value - vaddps zmm3,zmm19,ZMMWORD PTR [rdi+64] - vaddps zmm16,zmm19,ZMMWORD PTR [rdi+128] - vmaxps zmm0,zmm21,zmm0 # clamp lower bound - vmovaps zmm2,zmm23 - vmaxps zmm3,zmm21,zmm3 - vmovaps zmm5,zmm23 - vmaxps zmm16,zmm21,zmm16 - vmovaps zmm18,zmm23 - vfmadd213ps zmm2,zmm0,zmm22 # (input / ln2) plus rounding bias - vfmadd213ps zmm5,zmm3,zmm22 - vfmadd213ps zmm18,zmm16,zmm22 - vmovaps zmm1,zmm26 # p = poly_0 - vmovaps zmm4,zmm26 - vmovaps zmm17,zmm26 - vsubps zmm2,zmm2,zmm22 # m = round(input / ln2) - vsubps zmm5,zmm5,zmm22 - vsubps zmm18,zmm18,zmm22 - vfmadd231ps zmm0,zmm2,zmm24 # range reduce: x -= (m * ln2_high) - vfmadd231ps zmm3,zmm5,zmm24 - vfmadd231ps zmm16,zmm18,zmm24 - vfmadd231ps zmm0,zmm2,zmm25 # range reduce: x -= (m * ln2_low) - vfmadd231ps zmm3,zmm5,zmm25 - vfmadd231ps zmm16,zmm18,zmm25 - vfmadd213ps zmm1,zmm0,zmm27 # p = p * x + poly_1 - vfmadd213ps zmm4,zmm3,zmm27 - vfmadd213ps zmm17,zmm16,zmm27 - vfmadd213ps zmm1,zmm0,zmm28 # p = p * x + poly_2 - vfmadd213ps zmm4,zmm3,zmm28 - vfmadd213ps zmm17,zmm16,zmm28 - vfmadd213ps zmm1,zmm0,zmm29 # p = p * x + poly_3 - vfmadd213ps zmm4,zmm3,zmm29 - vfmadd213ps zmm17,zmm16,zmm29 - vfmadd213ps zmm1,zmm0,zmm30 # p = p * x + poly_4 - vfmadd213ps zmm4,zmm3,zmm30 - vfmadd213ps zmm17,zmm16,zmm30 - vfmadd213ps zmm1,zmm0,zmm31 # p = p * x + poly_5 - vfmadd213ps zmm4,zmm3,zmm31 - vfmadd213ps zmm17,zmm16,zmm31 - vfmadd213ps zmm1,zmm0,zmm31 # p = p * x + poly_6 - vfmadd213ps zmm4,zmm3,zmm31 - vfmadd213ps zmm17,zmm16,zmm31 - vscalefps zmm1,zmm1,zmm2 - vscalefps zmm4,zmm4,zmm5 - vscalefps zmm17,zmm17,zmm18 - vaddps zmm20,zmm20,zmm1 # accumulate exp() results - vaddps zmm20,zmm20,zmm4 - vaddps zmm20,zmm20,zmm17 - add rdi,48*4 # advance input by 48 elements - test rsi,rsi - jz .LComputeSumExp.SkipStoreResultsBy48 - vmovups ZMMWORD PTR [rsi],zmm1 - vmovups ZMMWORD PTR [rsi+64],zmm4 - vmovups ZMMWORD PTR [rsi+128],zmm17 - add rsi,48*4 # advance output by 48 elements - -.LComputeSumExp.SkipStoreResultsBy48: - sub rdx,48 - jae .LComputeSumExp.ComputeExpBy48Loop - -.LComputeSumExp.ProcessRemainingCount: - add rdx,48 # correct for over-subtract above - jz .LComputeSumExp.ReduceAccumulator - mov eax,-1 - kmovw k1,eax # update mask to access all elements - -.LComputeSumExp.ComputeExpBy16Loop: - cmp rdx,16 - jae .LComputeSumExp.ProcessSingleVector - lea r10,C_UNDERSCORE(MlasOpmask16BitTableAvx512)[rip] - kmovw k1,WORD PTR [r10+rdx*2] - -.LComputeSumExp.ProcessSingleVector: - vaddps zmm0{k1}{z},zmm19,ZMMWORD PTR [rdi] - # bias by negative maximum value - vmaxps zmm0,zmm21,zmm0 # clamp lower bound - vmovaps zmm2,zmm23 - vfmadd213ps 
zmm2,zmm0,zmm22 # (input / ln2) plus rounding bias - vmovaps zmm1,zmm26 # p = poly_0 - vsubps zmm2,zmm2,zmm22 # m = round(input / ln2) - vfmadd231ps zmm0,zmm2,zmm24 # range reduce: x -= (m * ln2_high) - vfmadd231ps zmm0,zmm2,zmm25 # range reduce: x -= (m * ln2_low) - vfmadd213ps zmm1,zmm0,zmm27 # p = p * x + poly_1 - vfmadd213ps zmm1,zmm0,zmm28 # p = p * x + poly_2 - vfmadd213ps zmm1,zmm0,zmm29 # p = p * x + poly_3 - vfmadd213ps zmm1,zmm0,zmm30 # p = p * x + poly_4 - vfmadd213ps zmm1,zmm0,zmm31 # p = p * x + poly_5 - vfmadd213ps zmm1,zmm0,zmm31 # p = p * x + poly_6 - vscalefps zmm1,zmm1,zmm2 - vaddps zmm20{k1},zmm20,zmm1 # accumulate exp() results - add rdi,16*4 # advance input by 16 elements - test rsi,rsi - jz .LComputeSumExp.SkipStoreResultsBy16 - vmovups ZMMWORD PTR [rsi]{k1},zmm1 - add rsi,16*4 # advance output by 16 elements - -.LComputeSumExp.SkipStoreResultsBy16: - sub rdx,16 - ja .LComputeSumExp.ComputeExpBy16Loop - -.LComputeSumExp.ReduceAccumulator: - vextractf64x4 ymm0,zmm20,1 - vaddps zmm0,zmm0,zmm20 - vhaddps ymm0,ymm0,ymm0 - vhaddps ymm0,ymm0,ymm0 - vextractf128 xmm1,ymm0,1 - vaddss xmm0,xmm0,xmm1 - - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/TransKernelCommon.h b/onnxruntime/core/mlas/lib/x86_64/TransKernelCommon.h deleted file mode 100644 index f8c76522c784d..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/TransKernelCommon.h +++ /dev/null @@ -1,74 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - TransKernelCommon.h - -Abstract: - - This module contains common kernel macros and structures for the - transcendental functions. - ---*/ - -// -// Structure layout for the exponential function constants block. -// - - .equ .LExpConstants_LowerRange, 0 - .equ .LExpConstants_UpperRange, 4 - .equ .LExpConstants_LowerRangeSumExp, 8 - .equ .LExpConstants_UpperRangeSumExp, 12 - .equ .LExpConstants_RoundingBias, 16 - .equ .LExpConstants_Log2Reciprocal, 20 - .equ .LExpConstants_Log2High, 24 - .equ .LExpConstants_Log2Low, 28 - .equ .LExpConstants_poly_0, 32 - .equ .LExpConstants_poly_1, 36 - .equ .LExpConstants_poly_2, 40 - .equ .LExpConstants_poly_3, 44 - .equ .LExpConstants_poly_4, 48 - .equ .LExpConstants_poly_56, 52 - .equ .LExpConstants_MinimumExponent, 56 - .equ .LExpConstants_MaximumExponent, 60 - -// -// Structure layout for the logistic constants block. -// - - .equ .LLogisticConstants_LowerRange, 0 - .equ .LLogisticConstants_UpperRange, 4 - .equ .LLogisticConstants_alpha_9, 8 - .equ .LLogisticConstants_alpha_7, 12 - .equ .LLogisticConstants_alpha_5, 16 - .equ .LLogisticConstants_alpha_3, 20 - .equ .LLogisticConstants_alpha_1, 24 - .equ .LLogisticConstants_beta_10, 28 - .equ .LLogisticConstants_beta_8, 32 - .equ .LLogisticConstants_beta_6, 36 - .equ .LLogisticConstants_beta_4, 40 - .equ .LLogisticConstants_beta_2, 44 - .equ .LLogisticConstants_beta_0, 48 - .equ .LLogisticConstants_one_half, 52 - -// -// Structure layout for the tanh constants block. 
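These offset blocks (the exponential and logistic layouts above, and the tanh layout that follows) are byte offsets into flat tables of 32-bit constants that the kernels address as label[rax] after loading the table base. Purely as an illustration, the exponential block corresponds to a struct whose fields sit 4 bytes apart; the field names below come from the offset labels, but the float/int32_t split is an assumption rather than a copy of the real declaration in mlasi.h.

    #include <cstdint>

    // Hypothetical struct view of the exponential constants block; each field is
    // 4 bytes apart, matching the .LExpConstants_* offsets (0, 4, ..., 60).
    struct MlasExpConstantsLayout {
        float LowerRange;          //  0
        float UpperRange;          //  4
        float LowerRangeSumExp;    //  8
        float UpperRangeSumExp;    // 12
        float RoundingBias;        // 16
        float Log2Reciprocal;      // 20
        float Log2High;            // 24
        float Log2Low;             // 28
        float poly_0;              // 32
        float poly_1;              // 36
        float poly_2;              // 40
        float poly_3;              // 44
        float poly_4;              // 48
        float poly_56;             // 52
        int32_t MinimumExponent;   // 56, integer exponent limit
        int32_t MaximumExponent;   // 60, integer exponent limit / bias
    };
    static_assert(sizeof(MlasExpConstantsLayout) == 64, "offsets above assume a packed 64-byte block");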
-// - - .equ .LTanhConstants_LowerRange, 0 - .equ .LTanhConstants_UpperRange, 4 - .equ .LTanhConstants_alpha_13, 8 - .equ .LTanhConstants_alpha_11, 12 - .equ .LTanhConstants_alpha_9, 16 - .equ .LTanhConstants_alpha_7, 20 - .equ .LTanhConstants_alpha_5, 24 - .equ .LTanhConstants_alpha_3, 28 - .equ .LTanhConstants_alpha_1, 32 - .equ .LTanhConstants_beta_6, 36 - .equ .LTanhConstants_beta_4, 40 - .equ .LTanhConstants_beta_2, 44 - .equ .LTanhConstants_beta_0, 48 diff --git a/onnxruntime/core/mlas/lib/x86_64/TransKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/TransKernelFma3.S deleted file mode 100644 index 829c735bac182..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/TransKernelFma3.S +++ /dev/null @@ -1,317 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - TransKernelFma3.s - -Abstract: - - This module implements kernels for various transcendental functions. - - This implementation uses AVX fused multiply/add instructions. - ---*/ - -#include "asmmacro.h" -#include "TransKernelCommon.h" - - .intel_syntax noprefix - - .text - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the exponential function. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Supplies the output buffer. - - N (rdx) - Supplies the number of elements to process. - -Return Value: - - None. - ---*/ - - FUNCTION_ENTRY MlasComputeExpF32KernelFma3 - - lea rax,C_UNDERSCORE(MlasExpConstants)[rip] - vbroadcastss ymm4,.LExpConstants_LowerRange[rax] - vbroadcastss ymm5,.LExpConstants_UpperRange[rax] - vbroadcastss ymm6,.LExpConstants_MinimumExponent[rax] - vbroadcastss ymm7,.LExpConstants_MaximumExponent[rax] - vbroadcastss ymm8,.LExpConstants_RoundingBias[rax] - vbroadcastss ymm9,.LExpConstants_Log2Low[rax] - vbroadcastss ymm10,.LExpConstants_poly_0[rax] - vbroadcastss ymm11,.LExpConstants_poly_1[rax] - vbroadcastss ymm12,.LExpConstants_poly_2[rax] - vbroadcastss ymm13,.LExpConstants_poly_3[rax] - vbroadcastss ymm14,.LExpConstants_poly_4[rax] - vbroadcastss ymm15,.LExpConstants_poly_56[rax] - - sub rdx,8 - jb .LComputeExp.ProcessRemainingCount - -.LComputeExp.ComputeExpBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rdi] # clamp lower bound - vbroadcastss ymm2,.LExpConstants_Log2Reciprocal[rax] - vminps ymm0,ymm5,ymm0 # clamp upper bound - vbroadcastss ymm3,.LExpConstants_Log2High[rax] - vfmadd213ps ymm2,ymm0,ymm8 # (x / ln2) plus rounding bias - vsubps ymm1,ymm2,ymm8 # m = round(x / ln2) - vfmadd231ps ymm0,ymm1,ymm3 # range reduce: x -= (m * ln2_high) - vfmadd231ps ymm0,ymm1,ymm9 # range reduce: x -= (m * ln2_low) - vmovaps ymm1,ymm10 # p = poly_0 - vfmadd213ps ymm1,ymm0,ymm11 # p = p * x + poly_1 - vpslld ymm2,ymm2,23 # shift m to exponent field - vfmadd213ps ymm1,ymm0,ymm12 # p = p * x + poly_2 - vpminsd ymm3,ymm2,ymm7 # clamp upper normal exponent to +127 - vfmadd213ps ymm1,ymm0,ymm13 # p = p * x + poly_3 - vpmaxsd ymm3,ymm3,ymm6 # clamp lower normal exponent to -126 - vfmadd213ps ymm1,ymm0,ymm14 # p = p * x + poly_4 - vpsubd ymm2,ymm2,ymm3 # compute overflow exponent - vpaddd ymm3,ymm3,ymm7 # add exponent bias to normal scale - vpaddd ymm2,ymm2,ymm7 # add exponent bias to overflow scale - vfmadd213ps ymm1,ymm0,ymm15 # p = p * x + poly_56 - vmulps ymm0,ymm0,ymm2 # scale x with overflow exponent - vfmadd213ps ymm1,ymm0,ymm2 # p = p * (x * overflow) + overflow - vmulps ymm1,ymm1,ymm3 # scale p with normal exponent - add rdi,8*4 # advance input by 8 elements - vmovups YMMWORD PTR [rsi],ymm1 - 
add rsi,8*4 # advance output by 8 elements - sub rdx,8 - jae .LComputeExp.ComputeExpBy8Loop - -.LComputeExp.ProcessRemainingCount: - add rdx,8 # correct for over-subtract above - jz .LComputeExp.ExitKernel - neg rdx - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovups ymm2,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rdi] - vmaxps ymm0,ymm4,ymm0 # clamp lower bound - vbroadcastss ymm4,.LExpConstants_Log2Reciprocal[rax] - vminps ymm0,ymm5,ymm0 # clamp upper bound - vbroadcastss ymm3,.LExpConstants_Log2High[rax] - vfmadd213ps ymm4,ymm0,ymm8 # (x / ln2) plus rounding bias - vsubps ymm1,ymm4,ymm8 # m = round(x / ln2) - vfmadd231ps ymm0,ymm1,ymm3 # range reduce: x -= (m * ln2_high) - vfmadd231ps ymm0,ymm1,ymm9 # range reduce: x -= (m * ln2_low) - vmovaps ymm1,ymm10 # p = poly_0 - vfmadd213ps ymm1,ymm0,ymm11 # p = p * x + poly_1 - vpslld ymm4,ymm4,23 # shift m to exponent field - vfmadd213ps ymm1,ymm0,ymm12 # p = p * x + poly_2 - vpminsd ymm3,ymm4,ymm7 # clamp upper normal exponent to +127 - vfmadd213ps ymm1,ymm0,ymm13 # p = p * x + poly_3 - vpmaxsd ymm3,ymm3,ymm6 # clamp lower normal exponent to -126 - vfmadd213ps ymm1,ymm0,ymm14 # p = p * x + poly_4 - vpsubd ymm4,ymm4,ymm3 # compute overflow exponent - vpaddd ymm3,ymm3,ymm7 # add exponent bias to normal scale - vpaddd ymm4,ymm4,ymm7 # add exponent bias to overflow scale - vfmadd213ps ymm1,ymm0,ymm15 # p = p * x + poly_5 - vmulps ymm0,ymm0,ymm4 # scale x with overflow exponent - vfmadd213ps ymm1,ymm0,ymm4 # p = p * (x * overflow) + overflow - vmulps ymm1,ymm1,ymm3 # scale p with normal exponent - vmaskmovps YMMWORD PTR [rsi],ymm2,ymm1 - -.LComputeExp.ExitKernel: - vzeroupper - ret - -/*++ - -Routine Description: - - This routine implements a vectorized kernel for the sum of exponential - functions. - -Arguments: - - Input (rdi) - Supplies the input buffer. - - Output (rsi) - Optionally supplies the output buffer. When used for Softmax, - the output buffer is used to store the intermediate exp() results. When - used for LogSoftmax, the intermediate exp() results are not required. - - N (rdx) - Supplies the number of elements to process. - - NegativeMaximum (rcx) - Supplies the address of the negative maximum value - that is added to each element before computing the exponential function. - -Return Value: - - Returns the sum of the exponential functions. 
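Unlike the AVX-512 kernels, the FMA3 kernels in this file have no vscalefps, so both the exp kernel above and the sum-of-exp kernel that follows synthesize the 2^m scale by shifting m into the IEEE-754 exponent field and adding the exponent bias, clamping to the normal-exponent range and folding anything beyond it into a second "overflow" scale. A minimal sketch of the core bit manipulation, with the clamping and overflow handling omitted:

    #include <cstdint>
    #include <cstring>

    // Mirrors the "shift m to exponent field" and "add exponent bias" steps:
    // placing m in the exponent field of an IEEE-754 single yields 2^m.
    float PowerOfTwo(int32_t m) {
        uint32_t bits = (static_cast<uint32_t>(m) << 23) + (127u << 23);  // shift to exponent, add bias
        float result;
        std::memcpy(&result, &bits, sizeof(result));                      // reinterpret the bit pattern
        return result;                                                    // equals 2^m for -126 <= m <= 127
    }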
- ---*/ - - FUNCTION_ENTRY MlasComputeSumExpF32KernelFma3 - - lea rax,C_UNDERSCORE(MlasExpConstants)[rip] - vbroadcastss ymm9,DWORD PTR [rcx] # broadcast negative maximum value - vxorps xmm10,xmm10,xmm10 # clear exp() accumulator - sub rdx,24 - jb .LComputeSumExp.ProcessRemainingCount - -.LComputeSumExp.ComputeExpBy24Loop: - vbroadcastss ymm11,.LExpConstants_LowerRangeSumExp[rax] - vbroadcastss ymm2,.LExpConstants_Log2Reciprocal[rax] - vaddps ymm0,ymm9,YMMWORD PTR [rdi] # bias by negative maximum value - vaddps ymm3,ymm9,YMMWORD PTR [rdi+32] - vaddps ymm6,ymm9,YMMWORD PTR [rdi+64] - vbroadcastss ymm15,.LExpConstants_RoundingBias[rax] - vmaxps ymm0,ymm11,ymm0 # clamp lower bound - vmovaps ymm5,ymm2 - vmaxps ymm3,ymm11,ymm3 - vmovaps ymm8,ymm2 - vmaxps ymm6,ymm11,ymm6 - vbroadcastss ymm13,.LExpConstants_Log2High[rax] - vfmadd213ps ymm2,ymm0,ymm15 # (x / ln2) plus rounding bias - vfmadd213ps ymm5,ymm3,ymm15 - vfmadd213ps ymm8,ymm6,ymm15 - vbroadcastss ymm14,.LExpConstants_Log2Low[rax] - vsubps ymm1,ymm2,ymm15 # m = round(x / ln2) - vsubps ymm4,ymm5,ymm15 - vsubps ymm7,ymm8,ymm15 - vfmadd231ps ymm0,ymm1,ymm13 # range reduce: x -= (m * ln2_high) - vfmadd231ps ymm3,ymm4,ymm13 - vfmadd231ps ymm6,ymm7,ymm13 - vfmadd231ps ymm0,ymm1,ymm14 # range reduce: x -= (m * ln2_low) - vfmadd231ps ymm3,ymm4,ymm14 - vfmadd231ps ymm6,ymm7,ymm14 - vbroadcastss ymm1,.LExpConstants_poly_0[rax] - vbroadcastss ymm13,.LExpConstants_poly_1[rax] - vmovaps ymm4,ymm1 - vmovaps ymm7,ymm1 - vfmadd213ps ymm1,ymm0,ymm13 # p = p * x + poly_1 - vfmadd213ps ymm4,ymm3,ymm13 - vfmadd213ps ymm7,ymm6,ymm13 - vbroadcastss ymm14,.LExpConstants_poly_2[rax] - vpslld ymm2,ymm2,23 # shift m to exponent field - vpslld ymm5,ymm5,23 - vpslld ymm8,ymm8,23 - vbroadcastss ymm15,.LExpConstants_MaximumExponent[rax] - vfmadd213ps ymm1,ymm0,ymm14 # p = p * x + poly_2 - vfmadd213ps ymm4,ymm3,ymm14 - vfmadd213ps ymm7,ymm6,ymm14 - vbroadcastss ymm13,.LExpConstants_poly_3[rax] - vpaddd ymm2,ymm2,ymm15 # add exponent bias to scale - vpaddd ymm5,ymm5,ymm15 - vpaddd ymm8,ymm8,ymm15 - vbroadcastss ymm14,.LExpConstants_poly_4[rax] - vfmadd213ps ymm1,ymm0,ymm13 # p = p * x + poly_3 - vfmadd213ps ymm4,ymm3,ymm13 - vfmadd213ps ymm7,ymm6,ymm13 - vbroadcastss ymm15,.LExpConstants_poly_56[rax] - vfmadd213ps ymm1,ymm0,ymm14 # p = p * x + poly_4 - vfmadd213ps ymm4,ymm3,ymm14 - vfmadd213ps ymm7,ymm6,ymm14 - vfmadd213ps ymm1,ymm0,ymm15 # p = p * x + poly_5 - vfmadd213ps ymm4,ymm3,ymm15 - vfmadd213ps ymm7,ymm6,ymm15 - vfmadd213ps ymm1,ymm0,ymm15 # p = p * x + poly_6 - vfmadd213ps ymm4,ymm3,ymm15 - vfmadd213ps ymm7,ymm6,ymm15 - vmulps ymm1,ymm1,ymm2 # scale p with exponent - vmulps ymm4,ymm4,ymm5 - vaddps ymm10,ymm10,ymm1 # accumulate exp() results - vmulps ymm7,ymm7,ymm8 - vaddps ymm10,ymm10,ymm4 - add rdi,24*4 # advance input by 24 elements - vaddps ymm10,ymm10,ymm7 - test rsi,rsi - jz .LComputeSumExp.SkipStoreResultsBy24 - vmovups YMMWORD PTR [rsi],ymm1 - vmovups YMMWORD PTR [rsi+32],ymm4 - vmovups YMMWORD PTR [rsi+64],ymm7 - add rsi,24*4 # advance output by 24 elements - -.LComputeSumExp.SkipStoreResultsBy24: - sub rdx,24 - jae .LComputeSumExp.ComputeExpBy24Loop - -.LComputeSumExp.ProcessRemainingCount: - add rdx,24 # correct for over-subtract above - jz .LComputeSumExp.ReduceAccumulator - vbroadcastss ymm11,.LExpConstants_LowerRangeSumExp[rax] - -.LComputeSumExp.ComputeExpBy8Loop: - cmp rdx,8 # remaining count < 8? 
- jb .LComputeSumExp.LoadPartialVector - vmovups ymm0,YMMWORD PTR [rdi] - jmp .LComputeSumExp.ProcessSingleVector - -.LComputeSumExp.LoadPartialVector: - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - neg rdx # carry flag unchanged - vmovups ymm3,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi] - vandps ymm9,ymm9,ymm3 # mask unused maximum value to 0.0 - -.LComputeSumExp.ProcessSingleVector: - vbroadcastss ymm2,.LExpConstants_Log2Reciprocal[rax] - vaddps ymm0,ymm9,ymm0 # bias by negative maximum value - vbroadcastss ymm15,.LExpConstants_RoundingBias[rax] - vmaxps ymm0,ymm11,ymm0 # clamp lower bound - vbroadcastss ymm13,.LExpConstants_Log2High[rax] - vfmadd213ps ymm2,ymm0,ymm15 # (input / ln2) plus rounding bias - vbroadcastss ymm14,.LExpConstants_Log2Low[rax] - vsubps ymm1,ymm2,ymm15 # round(input / ln2) - vfmadd231ps ymm0,ymm1,ymm13 # range reduce: x -= (m * ln2_high) - vfmadd231ps ymm0,ymm1,ymm14 # range reduce: x -= (m * ln2_low) - vbroadcastss ymm1,.LExpConstants_poly_0[rax] - vbroadcastss ymm13,.LExpConstants_poly_1[rax] - vfmadd213ps ymm1,ymm0,ymm13 # p = p * x + poly_1 - vbroadcastss ymm14,.LExpConstants_poly_2[rax] - vpslld ymm2,ymm2,23 # # shift m to exponent field - vbroadcastss ymm15,.LExpConstants_MaximumExponent[rax] - vfmadd213ps ymm1,ymm0,ymm14 # p = p * x + poly_2 - vbroadcastss ymm13,.LExpConstants_poly_3[rax] - vpaddd ymm2,ymm2,ymm15 # add exponent bias to scale - vbroadcastss ymm14,.LExpConstants_poly_4[rax] - vfmadd213ps ymm1,ymm0,ymm13 # p = p * x + poly_3 - vbroadcastss ymm15,.LExpConstants_poly_56[rax] - vfmadd213ps ymm1,ymm0,ymm14 # p = p * x + poly_4 - vfmadd213ps ymm1,ymm0,ymm15 # p = p * x + poly_5 - vfmadd213ps ymm1,ymm0,ymm15 # p = p * x + poly_6 - vmulps ymm1,ymm1,ymm2 - jb .LComputeSumExp.StorePartialVector - # remaining count < 8? - vaddps ymm10,ymm10,ymm1 # accumulate exp() results - test rsi,rsi # store exp() results? - jz .LComputeSumExp.SkipStoreResultsBy8 - vmovups YMMWORD PTR [rsi],ymm1 - add rsi,8*4 # advance output by 8 elements - -.LComputeSumExp.SkipStoreResultsBy8: - add rdi,8*4 # advance input by 8 elements - sub rdx,8 - jnz .LComputeSumExp.ComputeExpBy8Loop - jmp .LComputeSumExp.ReduceAccumulator - -.LComputeSumExp.StorePartialVector: - vandps ymm1,ymm1,ymm3 # mask unused exp() results to 0.0 - vaddps ymm10,ymm10,ymm1 # accumulate exp() results - test rsi,rsi # store exp() results? - jz .LComputeSumExp.ReduceAccumulator - vmaskmovps YMMWORD PTR [rsi],ymm3,ymm1 - -.LComputeSumExp.ReduceAccumulator: - vhaddps ymm10,ymm10,ymm10 - vhaddps ymm10,ymm10,ymm10 - vextractf128 xmm0,ymm10,1 - vaddss xmm0,xmm0,xmm10 - - vzeroupper - ret - - .end diff --git a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h deleted file mode 100644 index 7d7b3079a5132..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h +++ /dev/null @@ -1,150 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - asmmacro.h - -Abstract: - - This module implements common macros for the assembly modules. - ---*/ - -#if defined(__APPLE__) -#define C_UNDERSCORE(symbol) _##symbol -#else -#define C_UNDERSCORE(symbol) symbol -#endif - -/*++ - -Macro Description: - - This macro emits the assembler directives to annotate a new function. - -Arguments: - - FunctionName - Supplies the name of the function. 
- ---*/ - - .macro FUNCTION_ENTRY FunctionName - - .p2align 4 -#if defined(__APPLE__) - .globl _\FunctionName\() -_\FunctionName\(): -#else - .globl \FunctionName\() - .type \FunctionName\(),@function -\FunctionName\(): -#endif - - .endm - -/*++ - -Macro Description: - - This macro generates an optimization for "add reg,128" which can instead - be encoded as "sub reg,-128" to reduce code size by using a signed 8-bit - value. - -Arguments: - - Register - Supplies the register to be added to. - - Immediate - Supplies the immediate to add to the register. - ---*/ - - .macro add_immed Register, Immediate - -.if (\Immediate\() != 128) - add \Register\(),\Immediate\() -.else - sub \Register\(),-\Immediate\() # smaller encoding -.endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count is greater than or - equal to Value. - -Arguments: - - Count - Supplies the variable used in the comparison. - - Value - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCountGE Count1, Value1, Statement - -.if (\Count1\() >= \Value1\()) - \Statement\() -.endif - - .endm - -/*++ - -Macro Description: - - This macro conditionally emits the statement if Count1 is greater than or - equal to Value1 and Count2 is greater than or equal to Value2. - -Arguments: - - Count1 - Supplies the variable used in the comparison. - - Value1 - Supplies the static used in the comparison. - - Count2 - Supplies the variable used in the comparison. - - Value2 - Supplies the static used in the comparison. - - Statement - Supplies the statement to conditionally emit. - ---*/ - - .macro EmitIfCount2GE Count1, Value1, Count2, Value2, Statement - -.if (\Count1\() >= \Value1\()) && (\Count2\() >= \Value2\()) - \Statement\() -.endif - - .endm - -/*++ - -Macro Description: - - This macro emits the statement for each register listed in the register - list. The statement can use RegItem to access the current register. - -Arguments: - - RegList - Supplies the list of registers. - - Statement - Supplies the statement to emit. - ---*/ - - .macro EmitForEachRegister RegList, Statement - - .irp RegItem, \RegList\() - \Statement\() - .endr - - .endm diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S deleted file mode 100644 index a4d730fa513ab..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S +++ /dev/null @@ -1,143 +0,0 @@ -/*++ - -Copyright (c) Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - cvtfp16Avx2.asm - -Abstract: - - This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. - ---*/ - -#include "asmmacro.h" - -.data -.equ SINGLE_SIZE, 4 -.equ HALF_SIZE, 2 -.equ LOW_SELECTOR, 0b00100000 -.equ HIGH_SELECTOR, 0b00110001 - -.text -.intel_syntax noprefix - -/*++ Routine Description: - - This routine converts the source buffer of half-precision floats to the - destination buffer of single-precision floats. - - This implementation uses AVX2 instructions. - - Arguments: - - Source (rdi) - Supplies the address of the source buffer of half-precision - floats. - - Destination (rsi) - Supplies the address of the destination buffer of - single-precision floats. - - Count (rdx) - Supplies the number of elements to convert. - - Return Value: - - None. 
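The conversion body that follows loads even-indexed elements with vcvtneeph2ps and odd-indexed elements with vcvtneoph2ps, then restores source order: vunpcklps/vunpckhps interleave within each 128-bit lane, and the two vperm2f128 selectors (LOW_SELECTOR / HIGH_SELECTOR) stitch the lanes back together on the 256-bit path. The net effect per 128-bit chunk is the plain interleave sketched here (a scalar model, not the actual intrinsics):

    #include <array>

    // even[i] came from source index 2*i, odd[i] from source index 2*i + 1;
    // interleaving them puts the converted floats back in source order.
    void ReinterleaveEvenOdd(const std::array<float, 4>& even,
                             const std::array<float, 4>& odd,
                             std::array<float, 8>& out) {
        for (int i = 0; i < 4; ++i) {
            out[2 * i] = even[i];
            out[2 * i + 1] = odd[i];
        }
    }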
- ---*/ -FUNCTION_ENTRY MlasCastF16ToF32KernelAvx - - test rdx, rdx // Check if we have any elements to convert - jz ExitRoutine - cmp rdx, 8 - jb ConvertMaskedVectors - cmp rdx, 16 - jb Convert128Vectors - -Convert256Vectors: - vcvtneeph2ps ymm0, ymmword PTR [rdi] // Load even indexes - vcvtneoph2ps ymm1, ymmword PTR [rdi] // Load odd indexes - vunpcklps ymm2, ymm0, ymm1 // Interleave low part - vunpckhps ymm1, ymm0, ymm1 // Interleave high part - vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR // Fix the order - vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR // Fix the order - vmovups ymmword PTR [rsi], ymm0 // Store the low part - vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part - - add rdi, 16*HALF_SIZE // Advance src ptr by 16 elements - add rsi, 16*SINGLE_SIZE // Advance dest ptr by 16 elements - sub rdx, 16 // Reduce the counter by 16 elements - - jz ExitRoutine // If we are done, exit - cmp rdx, 16 // If the vector is big enough, we go again - jae Convert256Vectors - cmp rdx, 8 // Check if we have enough elements to convert - jb ConvertMaskedVectors - - - -Convert128Vectors: - vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes - vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes - vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order - vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order - vmovups xmmword PTR [rsi], xmm0 // Store the low part - vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part - - add rdi, 8*HALF_SIZE // Advance src ptr by 8 elements - add rsi, 8*SINGLE_SIZE // Advance dest ptr by 8 elements - sub rdx, 8 // Reduce the counter by 8 elements - - jz ExitRoutine // If we are done, exit - - - -ConvertMaskedVectors: - vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes - vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes - vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order - vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order - - cmp rdx, 4 // Check if we can store the complete lower vector - jae ConvertLowerVector - - vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones - cmp rdx, 2 // Check how many converts we need - jb ConvertLower1 - ja ConvertLower3 - vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values - jmp ConvertLowerMaskedVector -ConvertLower1: - vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value - jmp ConvertLowerMaskedVector -ConvertLower3: - vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values -ConvertLowerMaskedVector: - vmaskmovps xmmword PTR [rsi], xmm2, xmm0 // Store the masked data, the shift is done in 8bit multiples - jmp ExitRoutine // If we ran into any of the cases above, means we are done after storing -ConvertLowerVector: - vmovups xmmword PTR [rsi], xmm0 // Store the low part - sub rdx, 4 // Check if we still need to convert - jz ExitRoutine - - - add rsi, 4*SINGLE_SIZE // Advance dest ptr by 4 elements - vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones - cmp rdx, 2 // Check how many converts we need - jb ConvertUpper1 - ja ConvertUpper3 - vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values - jmp ConvertMaskedUpperVector -ConvertUpper1: - vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value - jmp ConvertMaskedUpperVector -ConvertUpper3: - vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values -ConvertMaskedUpperVector: - vmaskmovps xmmword PTR [rsi], xmm2, xmm1 // Store the masked data, the shift is done in 
8bit multiples - - jmp ExitRoutine -ExitRoutine: - ret diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S deleted file mode 100644 index f27114c183f44..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S +++ /dev/null @@ -1,129 +0,0 @@ -/*++ - -Copyright (c) Intel Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - cvtfp16a.S - -Abstract: - - This module implements routines to convert between FP16 and FP32 formats using SSE2 isntructions. - ---*/ - -#include "asmmacro.h" - -// We use RIP relative addressing to avoid reallication related errors -.section .rodata -MlasFp16MaskSign: .long 0x00007FFF -MlasFp16CompareInfinity: .long 0x00007C00 -MlasFp16CompareSmallest: .long 0x00000400 -MlasFp16AdjustExponent: .long 0x38000000 -MlasFp16MagicDenormal: .long 0x38800000 - -.text -.intel_syntax noprefix - -/*++ Routine Description: - - This routine converts the source buffer of half-precision floats to the - destination buffer of single-precision floats. - - This implementation uses SSE2 instructions. - - Arguments: - - Source (rdi) - Supplies the address of the source buffer of half-precision - floats. - - Destination (rsi) - Supplies the address of the destination buffer of - single-precision floats. - - Count (rdx) - Supplies the number of elements to convert. - - Return Value: - - None. - ---*/ - -FUNCTION_ENTRY MlasCastF16ToF32KernelSse - - test rdx,rdx - jz ExitRoutine - - // Load xmm constants - movd xmm5, DWORD PTR [rip + MlasFp16MaskSign] - pshufd xmm5, xmm5, 0x00 - movd xmm6, DWORD PTR [rip + MlasFp16AdjustExponent] - pshufd xmm6, xmm6, 0x00 - movd xmm7, DWORD PTR [rip + MlasFp16MagicDenormal] - pshufd xmm7, xmm7, 0x00 - - - cmp rdx,4 - jb LoadPartialVector - -LoadFullVector: - movq xmm0,QWORD PTR [rdi] - add rdi,4*2 // advance S by 4 elements - -ConvertHalfToFloat: - punpcklwd xmm0,xmm0 // duplicate 4 WORDs to 4 DWORDs - movaps xmm1,xmm0 // isolate exponent/mantissa - pand xmm1,xmm5 - pxor xmm0,xmm1 // isolate sign bit - movd xmm2, DWORD PTR [rip + MlasFp16CompareInfinity] - pshufd xmm2, xmm2, 0x00 - pcmpgtd xmm2,xmm1 // test for infinity/NaNs - movd xmm3, DWORD PTR [rip + MlasFp16CompareSmallest] - pshufd xmm3, xmm3, 0x00 - pcmpgtd xmm3,xmm1 // test for denormals - pandn xmm2,xmm6 - pslld xmm1,13 // shift exponent/mask into place - movaps xmm4,xmm1 - paddd xmm1,xmm6 - paddd xmm1,xmm2 // adjust exponent again for infinity/NaNs - paddd xmm4,xmm7 - pslld xmm0,16 // shift sign into place - subps xmm4,xmm7 - pand xmm4,xmm3 // select elements that are denormals - pandn xmm3,xmm1 // select elements that are not denormals - por xmm3,xmm4 // blend the selected values together - por xmm0,xmm3 // merge sign into exponent/mantissa - - cmp rdx,4 // storing full vector? 
- jb StorePartialVector - movups XMMWORD PTR [rsi],xmm0 - add rsi,4*4 // advance D by 4 elements - sub rdx,4 - jz ExitRoutine - cmp rdx,4 - jae LoadFullVector - -LoadPartialVector: - pxor xmm0,xmm0 - pinsrw xmm0,WORD PTR [rdi],0 - cmp rdx,2 - jb ConvertHalfToFloat - pinsrw xmm0,WORD PTR [rdi+2],1 - je ConvertHalfToFloat - pinsrw xmm0,WORD PTR [rdi+4],2 - jmp ConvertHalfToFloat - -StorePartialVector: - cmp rdx,2 - jb StoreLastElement - movsd QWORD PTR [rsi],xmm0 - je ExitRoutine - movhlps xmm0,xmm0 // shift third element down - add rsi,4*2 // advance D by 2 elements - -StoreLastElement: - movss DWORD PTR [rsi],xmm0 - -ExitRoutine: - ret diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index ea9d8605e2417..2e036670212f1 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -7,7 +7,7 @@ #include #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/graph/graph_utils.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/utils.h" diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 6f90eaf07ef4d..dd71df32b7652 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -5,7 +5,7 @@ #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/conv_add_act_fusion.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/utils.h" diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index f769d31092d19..41fff826160ee 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -18,7 +18,7 @@ #if !defined(ORT_MINIMAL_BUILD) -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/bias_dropout_fusion.h" #include "core/optimizer/bias_gelu_fusion.h" diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 46f306b92bed5..02c89483f9fb4 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -5,7 +5,7 @@ #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/nchwc_transformer.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index ee79fa620374e..893de438df437 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -3,7 +3,7 @@ // Licensed under the MIT License. 
#include -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/nhwc_transformer.h" diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 8f99b7409d4fe..3e93c4b0ab042 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -9,7 +9,7 @@ #include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" #include "core/framework/tensorprotoutils.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" namespace onnxruntime { namespace QDQ { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 2738c3ab02799..648828f277eef 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -7,7 +7,7 @@ #include #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 627ddd35b9919..a4d4fbc852e1a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -8,7 +8,7 @@ #include #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/activation/activations.cc b/onnxruntime/core/providers/cpu/activation/activations.cc index 049fee4b95308..78a94f70901f4 100644 --- a/onnxruntime/core/providers/cpu/activation/activations.cc +++ b/onnxruntime/core/providers/cpu/activation/activations.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/cpu/activation/activations.h" #include "core/providers/cpu/fp16/fp16_activations.h" #include "core/providers/cpu/math/element_wise_ops.h" diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d57c33ae965b1..77e4210e07aac 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -7,7 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/int4.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/cpu/cpu_contrib_kernels.h" diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_activations.h b/onnxruntime/core/providers/cpu/fp16/fp16_activations.h index 5404a1b180b64..8a98a56e84b70 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_activations.h +++ b/onnxruntime/core/providers/cpu/fp16/fp16_activations.h @@ -3,7 +3,7 @@ #pragma once -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/framework/float16.h" #include "core/providers/cpu/activation/activations.h" diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 37db095e92570..766f84bb3efd9 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -5,7 +5,7 @@ // This file contains implementation of a fp16 convolution operator. // -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc b/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc index 7c1e05f7ce277..ae00f20023f05 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/mlas/inc/mlas.h" +#include "mlas.h" #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index a78ff69e5c894..68a01834433ea 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -10,7 +10,7 @@ #include "core/providers/op_kernel_type_control.h" #include #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5406dd1a40446..a9e4f5584b092 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -8,7 +8,7 @@ #include "core/providers/cpu/math/gemm_matmul_common.h" #include "core/util/math_cpuonly.h" #include "gemm_helper.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b9bbe36583879..4dcf4f8552f79 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -4,7 +4,7 @@ #pragma once #include "core/framework/op_kernel.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 2817dda9d0085..6930734b06056 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -26,7 +26,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { template diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 2f4ebeabe043e..9ea51a3b3f404 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -7,7 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include "core/common/inlined_containers.h" diff --git a/onnxruntime/core/providers/cpu/nn/conv.h b/onnxruntime/core/providers/cpu/nn/conv.h index 5ed5d2ca91def..d294be7737f05 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.h +++ b/onnxruntime/core/providers/cpu/nn/conv.h @@ -5,7 +5,7 @@ #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index f0c1b0b409831..5a5061997e1ef 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -17,7 +17,7 @@ #include "core/providers/cpu/nn/conv_transpose.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/common/safeint.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 24a5dcab225c4..84cf7bb44f93d 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ 
b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -5,7 +5,7 @@ #include "core/common/safeint.h" #include "core/framework/tensor.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include "core/providers/common.h" #include "core/util/force_inline.h" diff --git a/onnxruntime/core/providers/cpu/nn/pool_base.h b/onnxruntime/core/providers/cpu/nn/pool_base.h index 00dd1b152026d..baae262261ec7 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_base.h +++ b/onnxruntime/core/providers/cpu/nn/pool_base.h @@ -10,7 +10,7 @@ #include "core/util/math.h" #endif #include "core/providers/cpu/nn/pool_attributes.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/quantization/dynamicquantizelinear.cc b/onnxruntime/core/providers/cpu/quantization/dynamicquantizelinear.cc index 185cd19357742..237dbd864bb1e 100644 --- a/onnxruntime/core/providers/cpu/quantization/dynamicquantizelinear.cc +++ b/onnxruntime/core/providers/cpu/quantization/dynamicquantizelinear.cc @@ -3,7 +3,7 @@ #include "dynamicquantizelinear.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/common.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index e26eae19b8fd4..53cc24761292d 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/framework/op_kernel.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/common.h" #include "core/common/safeint.h" #include "core/quantization/quantization.h" diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 7797cbe678bd4..d1bff3e96b70a 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -9,7 +9,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 3d3e831a12d13..80a5a37524f41 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -8,7 +8,7 @@ #include "core/framework/int4.h" #include "core/framework/op_kernel.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/util/qmath.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc index be448455194f6..24ede95f4175a 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc @@ -10,7 +10,7 @@ #include "core/providers/common.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { // uint8_t kernel supports weight being either uint8_t or int8_t diff --git 
a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc index 9e865671e047d..1a1fb5d381049 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc @@ -13,7 +13,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/cpu/rnn/rnn_activation_functors.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h index 6d54c24b3808b..9ab0e7bc834df 100644 --- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h +++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h @@ -14,7 +14,7 @@ #include "core/util/math.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 35f3b12aeba35..29450da72bc34 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -22,7 +22,7 @@ #include "Eigen/src/Core/arch/Default/BFloat16.h" #include "Eigen/src/Core/arch/Default/Half.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/common/cpuid_info.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc index d55973eda180f..7330e6f331a57 100644 --- a/onnxruntime/core/providers/cpu/tensor/gelu.cc +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -5,7 +5,7 @@ #include "core/common/narrow.h" #include "core/framework/op_kernel.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 5b904e85848d0..426af3a2e11cd 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -8,7 +8,7 @@ #include "core/framework/utils.h" #include "core/framework/transpose_helper.h" #include "core/framework/op_kernel_type_control_utils.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/op_kernel_type_control.h" #include "utils.h" diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch index 45de47f3e5128..8109d1c7226b2 100644 --- a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -11,10 +11,10 @@ index c02ac2096d..2bc51298f0 100644 set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S -diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h +diff --git a/onnxruntime/mlas.h b/onnxruntime/mlas.h index e46105324a..414c46a1ce 100644 ---- a/onnxruntime/core/mlas/inc/mlas.h -+++ b/onnxruntime/core/mlas/inc/mlas.h +--- a/onnxruntime/mlas.h ++++ b/onnxruntime/mlas.h @@ -82,6 +82,9 @@ Abstract: #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) diff --git a/onnxruntime/core/quantization/quantization.h b/onnxruntime/core/quantization/quantization.h index 9acdfa6d86ccf..20440aeea4737 100644 --- 
a/onnxruntime/core/quantization/quantization.h +++ b/onnxruntime/core/quantization/quantization.h @@ -8,7 +8,7 @@ #include #include "core/common/common.h" #include "core/framework/tensor.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" // This header contains utility functions for quantizing and dequantizing // values as outlined in the logic in diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 983321593a92b..cdc1c8976f0c0 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -22,7 +22,7 @@ #include #include #include "core/common/narrow.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #if defined(__GNUC__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 1b2180da95058..4a67797e42c4a 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -3,7 +3,7 @@ #pragma once -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/platform/threadpool.h" #include "core/common/narrow.h" #include "core/framework/element_type_lists.h" diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 51a52af1b151e..97af2ecc90512 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -5,7 +5,7 @@ #include #include -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" #include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" #include "core/util/thread_utils.h" diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index c641103a74465..69927e90784f8 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -7,7 +7,7 @@ #include #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 8138829b057f2..be78368178898 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -11,9 +11,9 @@ #include "core/common/narrow.h" #include "core/common/span_utils.h" #include "core/framework/tensor.h" -#include "core/mlas/inc/mlas_qnbit.h" -#include "core/mlas/inc/mlas_q4.h" -#include "core/mlas/inc/mlas.h" +#include "mlas_qnbit.h" +#include "mlas_q4.h" +#include "mlas.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc index e739b17d5885f..986cbad6e50e3 100644 --- a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc @@ -5,8 +5,8 @@ #include "core/common/span_utils.h" #include "core/framework/tensor.h" -#include "core/mlas/inc/mlas_q4.h" -#include "core/mlas/inc/mlas.h" +#include "mlas_q4.h" +#include "mlas.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" diff --git a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc index 
09ae5eddb122c..e5f4a16d57d4c 100644 --- a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc @@ -5,7 +5,7 @@ #include "core/common/span_utils.h" #include "core/framework/tensor.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 8d7629b5fda1c..171d9a6163e54 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -3,7 +3,7 @@ #include "core/common/span_utils.h" #include "core/framework/tensor.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" diff --git a/onnxruntime/test/contrib_ops/nhwc_maxpool_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_maxpool_op_test.cc index 0a74a68a80b12..dd0b12192f52f 100644 --- a/onnxruntime/test/contrib_ops/nhwc_maxpool_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_maxpool_op_test.cc @@ -5,7 +5,7 @@ #include #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc index e40e635c79a26..08e60495ee89c 100644 --- a/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc @@ -9,7 +9,7 @@ #include #include "core/util/math.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc index 71b6f27b5391f..51cb67b34b366 100644 --- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc @@ -5,7 +5,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/contrib_ops/qlinear_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_pool_test.cc index 78f7f431aa66e..91c7886f0ea46 100644 --- a/onnxruntime/test/contrib_ops/qlinear_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_pool_test.cc @@ -6,7 +6,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp deleted file mode 100644 index 65822eb294d7d..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
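This deleted benchmark allocates its buffers with a deliberate alignment quirk: restrict_aligned_alloc returns a pointer aligned to ByteAligned but intentionally not to 2 * ByteAligned, so the kernels are always measured at exactly the requested alignment rather than a luckier one. The arithmetic it uses, extracted into a standalone helper for clarity (the function name here is hypothetical):

    #include <cassert>
    #include <cstdint>

    // Same arithmetic as restrict_aligned_alloc: round up to a (2*A)-byte boundary,
    // then add A, so the result is A-aligned but never (2*A)-aligned.
    std::uintptr_t MisalignOnPurpose(std::uintptr_t address, std::uintptr_t a) {  // a must be a power of two
        std::uintptr_t aligned = ((address + 2 * a - 1) & ~(2 * a - 1)) + a;
        assert(aligned % a == 0 && aligned % (2 * a) != 0);
        return aligned;
    }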
- -#include "core/mlas/lib/mlasi.h" -#include "core/util/thread_utils.h" -#include "test/mlas/bench/bench_util.h" - -using onnxruntime::narrow; - -struct RestrictAlignedPtr { - float* ptr; // Aligned pointer within the underlying buffer - void* underlying_buffer; // Underlying buffer (including extra space for alignment) -}; - -// Return a RestrictAlignedPtr where the ptr is aligned to byte_aligned, but not to byte_aligned * 2 -RestrictAlignedPtr restrict_aligned_alloc(int D, int byte_aligned) { - if (byte_aligned <= 0 || (byte_aligned & (byte_aligned - 1)) != 0) { - throw std::invalid_argument("Alignment must be a power of 2"); - } - - const int byte_alignedx2 = byte_aligned << 1; - - void* buffer = malloc(D * sizeof(float) + byte_alignedx2 * 2); - if (buffer == nullptr) { - ORT_THROW_EX(std::bad_alloc); - } - - uintptr_t address = reinterpret_cast(buffer); - uintptr_t aligned_address = ((address + byte_alignedx2 - 1) & ~(byte_alignedx2 - 1)) + byte_aligned; - ORT_ENFORCE((aligned_address % byte_aligned == 0) && (aligned_address % byte_alignedx2 != 0)); - float* aligned_ptr = reinterpret_cast(aligned_address); - - return {aligned_ptr, buffer}; -} - -void COMPUTESOFTMAXINPLACE(benchmark::State& state) { - const auto byte_aligned = narrow(state.range(0)); - const auto N = narrow(state.range(1)); - const auto D = narrow(state.range(2)); - const auto threads = narrow(state.range(3)); - - if (N <= 0 || D <= 0 || threads <= 0) { - throw std::invalid_argument("N, D, and Threads must be greater than 0!"); - } - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = threads; - tpo.auto_set_affinity = true; - - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool( - &onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - auto data = RandomVectorUniform(static_cast(N * D), -1.0f, 1.0f); - RestrictAlignedPtr ptr = restrict_aligned_alloc(N * D, byte_aligned); - float* input = ptr.ptr; - float* output = input; - std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory - - // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); - - for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); - } - - free(ptr.underlying_buffer); -} - -#if defined(MLAS_TARGET_AMD64) - -void REDUCEMAXIMUMF32KERNELAVX(benchmark::State& state) { - const auto byte_aligned = narrow(state.range(0)); - const auto D = narrow(state.range(1)); - - if (D <= 0) { - throw std::invalid_argument("D must be greater than 0!"); - } - - auto data = RandomVectorUniform(static_cast(D), -1.0f, 1.0f); - RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned); - float* input = ptr.ptr; - std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory - - // warming up run - float Maximum = MlasReduceMaximumF32KernelAvx(input, D); - - for (auto _ : state) { - Maximum = MlasReduceMaximumF32KernelAvx(input, D); - } - - free(ptr.underlying_buffer); - (void)Maximum; -} - -void REDUCEMAXIMUMF32KERNELAVX512F(benchmark::State& state) { - const auto byte_aligned = narrow(state.range(0)); - const auto D = narrow(state.range(1)); - - if (D <= 0) { - throw std::invalid_argument("D must be greater than 0!"); - } - - auto data = RandomVectorUniform(static_cast(D), -1.0f, 1.0f); - RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned); - float* input = ptr.ptr; - std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory - - // warming up run - float Maximum = 
MlasReduceMaximumF32KernelAvx512F(input, D); - - for (auto _ : state) { - Maximum = MlasReduceMaximumF32KernelAvx512F(input, D); - } - - free(ptr.underlying_buffer); - (void)Maximum; -} - -void COMPUTESUMEXPF32KERNELAVX512F(benchmark::State& state) { - const auto byte_aligned = narrow(state.range(0)); - const auto D = narrow(state.range(1)); - - if (D <= 0) { - throw std::invalid_argument("D must be greater than 0!"); - } - - auto data = RandomVectorUniform(static_cast(D), -1.0f, 1.0f); - RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned); - float* input = ptr.ptr; - float* output = input; - std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory - - float Maximum = MlasReduceMaximumF32KernelAvx(input, D); - float NegativeMaximum = -Maximum; - - // warming up run - float Accumulation = MlasComputeSumExpF32KernelAvx512F(input, output, D, &NegativeMaximum); - - for (auto _ : state) { - Accumulation = MlasComputeSumExpF32KernelAvx512F(input, output, D, &NegativeMaximum); - } - - free(ptr.underlying_buffer); - (void)Accumulation; -} - -void COMPUTESOFTMAXOUTPUTF32KERNELAVX(benchmark::State& state) { - const auto byte_aligned = narrow(state.range(0)); - const auto D = narrow(state.range(1)); - - if (D <= 0) { - throw std::invalid_argument("D must be greater than 0!"); - } - - auto data = RandomVectorUniform(static_cast(D), -1.0f, 1.0f); - RestrictAlignedPtr ptr = restrict_aligned_alloc(D, byte_aligned); - float* input = ptr.ptr; - float* output = input; - std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory - - float Maximum = MlasReduceMaximumF32KernelAvx(input, D); - float NegativeMaximum = -Maximum; - - float Accumulation = MlasComputeSumExpF32KernelAvx512F(input, output, D, &NegativeMaximum); - - float Parameters[] = {1.0f / Accumulation}; - - // warming up run - MlasComputeSoftmaxOutputF32KernelAvx(output, D, Parameters); - - for (auto _ : state) { - MlasComputeSoftmaxOutputF32KernelAvx(output, D, Parameters); - } - - free(ptr.underlying_buffer); -} - -#endif // defined(MLAS_TARGET_AMD64) - -static void ComputeSoftmaxInplaceArgs(benchmark::internal::Benchmark* b) { - b->ArgNames({"ByteAligned", "N", "D", "Threads"}); - for (int threads : {1, 8}) { - for (int byte_aligned : {64}) { // MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT is 64 - b->Args({byte_aligned, 16000, 4, threads}); - b->Args({byte_aligned, 16000, 500, threads}); - b->Args({byte_aligned, 48000, 3, threads}); - b->Args({byte_aligned, 48000, 2000, threads}); - b->Args({byte_aligned, 80000, 5, threads}); - b->Args({byte_aligned, 80000, 2000, threads}); - b->Args({byte_aligned, 112000, 7, threads}); - b->Args({byte_aligned, 112000, 2000, threads}); - b->Args({byte_aligned, 144000, 9, threads}); - b->Args({byte_aligned, 144000, 2000, threads}); - b->Args({byte_aligned, 176000, 11, threads}); - b->Args({byte_aligned, 176000, 2000, threads}); - b->Args({byte_aligned, 208000, 13, threads}); - b->Args({byte_aligned, 208000, 2000, threads}); - b->Args({byte_aligned, 240000, 15, threads}); - b->Args({byte_aligned, 240000, 2000, threads}); - } - } -} - -BENCHMARK(COMPUTESOFTMAXINPLACE)->Apply(ComputeSoftmaxInplaceArgs)->UseRealTime(); - -#if defined(MLAS_TARGET_AMD64) - -BENCHMARK(REDUCEMAXIMUMF32KERNELAVX) - ->ArgNames({"ByteAligned", "D"}) - ->ArgsProduct({ - {4, 8, 16, 32, 64, 128}, // ByteAligned - {3, 4, 5, 7, 9, 11, 13, 15, 16, 500, 2000}, // D - }) - ->UseRealTime(); - -BENCHMARK(REDUCEMAXIMUMF32KERNELAVX512F) - ->ArgNames({"ByteAligned", "D"}) - ->ArgsProduct({ - 
{4, 8, 16, 32, 64, 128}, // ByteAligned - {3, 4, 5, 7, 9, 11, 13, 15, 16, 500, 2000}, // D - }) - ->UseRealTime(); - -BENCHMARK(COMPUTESUMEXPF32KERNELAVX512F) - ->ArgNames({"ByteAligned", "D"}) - ->ArgsProduct({ - {4, 8, 16, 32, 64, 128}, // ByteAligned - {3, 4, 5, 7, 9, 11, 13, 15, 500, 2000}, // D - }) - ->UseRealTime(); - -BENCHMARK(COMPUTESOFTMAXOUTPUTF32KERNELAVX) - ->ArgNames({"ByteAligned", "D"}) - ->ArgsProduct({ - {4, 8, 16, 32, 64, 128}, // ByteAligned - {3, 4, 5, 7, 9, 11, 13, 15, 16, 500, 2000}, // D - }) - ->UseRealTime(); - -#endif // defined(MLAS_TARGET_AMD64) diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp deleted file mode 100644 index 1dccbe44aafaf..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include "bench_util.h" -#include "core/mlas/lib/mlasi.h" - -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - -void BM_ConvertF16ToF32(benchmark::State& state) { - bool aligned = static_cast(state.range(0)); - const size_t count = 1 << 18; - auto src = RandomVectorUniform(count, 0, 60000); - auto dst = std::vector(count + 16); - auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); - float* dst_start = aligned ? reinterpret_cast(aligned_dst) - : reinterpret_cast(aligned_dst + 1); - - // Warm up - MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); - - for (auto _ : state) { - MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); - } -} - -void BM_ConvertF32ToF16(benchmark::State& state) { - bool aligned = static_cast(state.range(0)); - const size_t count = 1 << 18; - auto src = RandomVectorUniform(count, -30000.0f, 30000.0f); - auto dst = std::vector(count + 16); - auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); - unsigned short* dst_start = aligned ? reinterpret_cast(aligned_dst) - : reinterpret_cast(aligned_dst + 1); - - // Warm up - MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); - - for (auto _ : state) { - MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); - } -} - -BENCHMARK(BM_ConvertF16ToF32) - ->UseRealTime() - ->Apply([](benchmark::internal::Benchmark* b) { - b->ArgNames({"aligned"}); - b->ArgsProduct({{0, 1}}); - }); - -BENCHMARK(BM_ConvertF32ToF16) - ->UseRealTime() - ->Apply([](benchmark::internal::Benchmark* b) { - b->ArgNames({"aligned"}); - b->ArgsProduct({{0, 1}}); - }); - -#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/bench/bench_main.cpp b/onnxruntime/test/mlas/bench/bench_main.cpp deleted file mode 100644 index 5ef8fd2cc02a8..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_main.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -BENCHMARK_MAIN(); diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp deleted file mode 100644 index 6d21ed2eef864..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
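A note on the deleted bench_q4dq.cpp that follows: its buffer sizes fall out of the block-wise int4 layout. Weights are quantized column-wise in blocks of quant_block_size rows, so there is one float scale per block per column, and both the 4-bit weights and the 4-bit zero points are packed two to a byte. A small worked example of that arithmetic (numbers chosen purely for illustration, not taken from the benchmark):

    // M = 1024 rows, N = 4096 columns, quant_block_size = 64
    size_t blocks_per_col = (1024 + 64 - 1) / 64;    // 16 blocks per column
    size_t scale_count    = blocks_per_col * 4096;   // 65536 float scales
    size_t zp_bytes       = (scale_count + 1) / 2;   // 32768 bytes, two 4-bit zero points per byte
    size_t data_bytes     = (1024 * 4096 + 1) / 2;   // 2097152 bytes, two 4-bit weights per byte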
- -#include -#include - -#include "core/mlas/inc/mlas_q4.h" -#include "test/mlas/bench/bench_util.h" -#include "core/util/thread_utils.h" - -static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) { - int M = (int)state.range(0); - int N = (int)state.range(1); - int quant_block_size = (int)state.range(2); - int threads = (int)state.range(3); - size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; - - auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); - auto scales = std::vector(scale_size); - auto zero_points = std::vector((scale_size + 1) / 2); - auto dst = std::vector((M * N + 1) / 2); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - for (auto _ : state) { - benchmark::DoNotOptimize(dst.data()); - MlasQDQQuantizeBlockwise( - src.data(), scales.data(), zero_points.data(), dst.data(), - true, M, N, quant_block_size, tp.get()); - benchmark::ClobberMemory(); - } -} - -static void BM_MlasQuantizeBlockwise(benchmark::State& state) { - int M = (int)state.range(0); - int N = (int)state.range(1); - int quant_block_size = (int)state.range(2); - int threads = (int)state.range(3); - size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; - - auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); - auto scales = std::vector(scale_size); - auto zero_points = std::vector((scale_size + 1) / 2); - auto dst = std::vector((M * N + 1) / 2); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - for (auto _ : state) { - benchmark::DoNotOptimize(dst.data()); - MlasQuantizeBlockwise( - dst.data(), scales.data(), zero_points.data(), src.data(), - quant_block_size, true, M, N, N, tp.get()); - benchmark::ClobberMemory(); - } -} - -static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { - int M = (int)state.range(0); - int N = (int)state.range(1); - int quant_block_size = (int)state.range(2); - int threads = (int)state.range(3); - bool add8 = state.range(4) != 0; - int quant_num_M = (M + quant_block_size - 1) / quant_block_size; - int blob_size = (quant_block_size + 1) / 2; - size_t scale_size = quant_num_M * N; - - auto scales = RandomVectorUniform(scale_size, -16.0f, 14.0f); - auto zero_points = RandomVectorUniform(static_cast((scale_size + 1) / 2), 0, 255); - auto dst = RandomVectorUniform(static_cast((M * N + 1) / 2), 0, 255); - auto scales_T = std::vector(scale_size); - auto zero_points_T = std::vector(((quant_num_M + 1) / 2) * N); - auto dst_T = std::vector(quant_num_M * blob_size * N); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - if (add8) { - for (auto _ : state) { - benchmark::DoNotOptimize(dst.data()); - MlasQDQTransposeBlockwiseQuantized( - dst.data(), scales.data(), zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), - true, M, N, quant_block_size, tp.get()); - benchmark::ClobberMemory(); - } - } else { - for (auto _ : state) { - 
benchmark::DoNotOptimize(dst.data()); - MlasQDQTransposeBlockwiseQuantized( - dst.data(), scales.data(), zero_points.data(), dst_T.data(), scales_T.data(), zero_points_T.data(), - true, M, N, quant_block_size, tp.get()); - benchmark::ClobberMemory(); - } - } -} - -BENCHMARK(BM_QDQBlockwiseQuantizer_QuantizeColumnwise) - ->UseRealTime() - ->Apply([](benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "quant_block_size", "threads"}); - b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {8}}); - }); - -BENCHMARK(BM_MlasQuantizeBlockwise) - ->UseRealTime() - ->Apply([](benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "quant_block_size", "threads"}); - b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {8}}); - }); - -BENCHMARK(BM_QDQBlockwiseQuantizer_TransposeColumnwise) - ->UseRealTime() - ->Apply([](benchmark::internal::Benchmark* b) { - b->ArgNames({"M", "N", "quant_block_size", "threads", "add8"}); - b->ArgsProduct({{1024, 4096}, {4096, 4095}, {64, 128}, {2, 8, 16}, {0, 1}}); - }); diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp deleted file mode 100644 index 61b3f57d8daac..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "mlas_q4.h" -#include "bench_util.h" -#include "core/util/thread_utils.h" - -#include -#include - -static const std::vector q4gemm_bench_arg_names = {"M", "N", "K", "Threads"}; - -void Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) { - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); - - const size_t M = static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - const size_t threads = static_cast(state.range(3)); - const size_t pack_b_size = MlasQ4GemmPackBSize(qtype, N, K); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = int(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - auto A1 = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); - auto B1 = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); - std::vector C1(static_cast(M * N)); - - std::vector B1_packed(pack_b_size); - MlasQ4GemmPackB(qtype, B1_packed.data(), B1.data(), N, K, N); - - MLAS_Q4_GEMM_DATA_PARAMS params1; - params1.A = A1.data(); - params1.lda = K; - params1.Bias = nullptr; - params1.C = C1.data(); - params1.ldc = N; - params1.B = B1_packed.data(); - params1.OutputProcessor = nullptr; - - MlasQ4GemmBatch(qtype, M, N, K, 1, ¶ms1, tp.get()); - - for (auto _ : state) { - MlasQ4GemmBatch(qtype, M, N, K, 1, ¶ms1, tp.get()); - } -} - -void Q8Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) { - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); - - const size_t M = 
static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - const size_t threads = static_cast(state.range(3)); - const size_t pack_b_size = MlasQ4GemmPackBSize(qtype, N, K); - const size_t quant_a_size = MlasQ80BlkQuantSize(qtype, M, K); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = int(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - auto A1 = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); - auto B1 = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); - std::vector C1(static_cast(M * N)); - - std::vector B1_packed(pack_b_size); - MlasQ4GemmPackB(qtype, B1_packed.data(), B1.data(), N, K, N); - - std::vector A1_quant(quant_a_size); - - MlasQ80BlkQuant(BlkQ4Sym, A1_quant.data(), A1.data(), M, K, K, tp.get()); - - MLAS_Q8Q4_GEMM_DATA_PARAMS params1; - params1.A = A1.data(); - params1.B = B1_packed.data(); - params1.Bias = nullptr; - params1.C = C1.data(); - params1.ldc = N; - params1.OutputProcessor = nullptr; - - MlasQ8Q4GemmBatch(qtype, M, N, K, 1, ¶ms1, tp.get()); - - for (auto _ : state) { - MlasQ80BlkQuant(BlkQ4Sym, A1_quant.data(), A1.data(), M, K, K, tp.get()); - - MLAS_Q8Q4_GEMM_DATA_PARAMS params; - params.A = A1.data(); - params.B = B1_packed.data(); - params.Bias = nullptr; - params.C = C1.data(); - params.ldc = N; - params.OutputProcessor = nullptr; - MlasQ8Q4GemmBatch(qtype, M, N, K, 1, ¶ms, tp.get()); - } -} - -static void GemmSizeProducts(benchmark::internal::Benchmark* b) { - b->ArgNames(q4gemm_bench_arg_names); - b->ArgsProduct({{1, 1024, 2048}, {4096}, {4096}, {8}}); -} - -[[maybe_unused]] static const bool benchmarks_registered = []() { - const bool is_q4gemm_supported = MlasQ4GemmPackBSize(BlkQ4Sym, 1, 1) > 0; - if (is_q4gemm_supported) { - BENCHMARK_CAPTURE(Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); - BENCHMARK_CAPTURE(Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); - BENCHMARK_CAPTURE(Q4GEMM, Q4Sym128, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); - BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym, BlkQ4Sym)->Apply(GemmSizeProducts)->UseRealTime(); - BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Zp8, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); - BENCHMARK_CAPTURE(Q8Q4GEMM, Q4Sym128, BlkQ4Zp8)->Apply(GemmSizeProducts)->UseRealTime(); - return true; - } - return false; -}(); diff --git a/onnxruntime/test/mlas/bench/bench_qgemm.cpp b/onnxruntime/test/mlas/bench/bench_qgemm.cpp deleted file mode 100644 index 29a68f6aec6e6..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_qgemm.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
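The deleted bench_qgemm.cpp that follows exercises the asymmetric int8 GEMM path: A and B are stored as 8-bit integers with per-tensor zero points (ZeroPointA is passed by value, ZeroPointB through a pointer, which leaves room for a per-column variant), and the kernel accumulates into an int32 output buffer. Roughly,

    C[m][n] = sum_k (A[m][k] - ZeroPointA) * (B[k][n] - ZeroPointB)

with any rescaling to float left to an optional output processor, which is why the benchmark never supplies scales.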
-
-#include "mlas.h"
-#include "bench_util.h"
-#include "core/util/thread_utils.h"
-
-#include <memory>
-#include <stdexcept>
-#include <string>
-#include <vector>
-
-static const std::vector<std::string> qgemm_arg_names = {"M", "N", "K", "Batch", "Threads"};
-
-void QGEMM(benchmark::State& state, bool pack_b, bool a_is_signed) {
-  constexpr bool b_is_signed = true;
-  constexpr uint8_t a_zero_point = 29;
-  constexpr uint8_t b_zero_point = 179;
-
-  if (state.range(0) <= 0) throw std::invalid_argument("M must be greater than 0!");
-  if (state.range(1) <= 0) throw std::invalid_argument("N must be greater than 0!");
-  if (state.range(2) <= 0) throw std::invalid_argument("K must be greater than 0!");
-  if (state.range(3) <= 0) throw std::invalid_argument("Batch must be greater than 0!");
-  if (state.range(4) <= 0) throw std::invalid_argument("Threads must be greater than 0!");
-
-  const size_t M = static_cast<size_t>(state.range(0));
-  const size_t N = static_cast<size_t>(state.range(1));
-  const size_t K = static_cast<size_t>(state.range(2));
-
-  const size_t batch = static_cast<size_t>(state.range(3));
-  const size_t threads = static_cast<size_t>(state.range(4));
-
-  OrtThreadPoolParams tpo;
-  tpo.thread_pool_size = int(threads);
-  tpo.auto_set_affinity = true;
-  std::unique_ptr<onnxruntime::concurrency::ThreadPool> tp(
-      onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(),
-                                                 tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP));
-
-  auto A_holder = RandomVectorUniform(static_cast<size_t>(M * K * batch), uint8_t(-100), uint8_t(100));
-  auto B_holder = RandomVectorUniform(static_cast<size_t>(N * K * batch), uint8_t(-110), uint8_t(110));
-  std::vector<int32_t> C_holder(static_cast<size_t>(M * N * batch));
-  std::vector<uint8_t> pack_b_holder;
-
-  size_t packed_b_size = 0;
-  if (pack_b) {
-    packed_b_size = MlasGemmPackBSize(N, K, a_is_signed, b_is_signed);
-    pack_b_holder.resize(packed_b_size * batch);
-  }
-
-  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
-
-  gemm_shape.M = static_cast<size_t>(M);
-  gemm_shape.N = static_cast<size_t>(N);
-  gemm_shape.K = static_cast<size_t>(K);
-  gemm_shape.AIsSigned = a_is_signed;
-  gemm_shape.BIsSigned = b_is_signed;
-
-  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch);
-  for (size_t i = 0; i < batch; i++) {
-    auto& gemm_params = gemm_data_vec[i];
-    gemm_params.lda = gemm_shape.K;
-    gemm_params.ZeroPointA = a_zero_point;
-    gemm_params.ZeroPointB = &b_zero_point;
-    gemm_params.ldc = gemm_shape.N;
-    gemm_params.A = A_holder.data() + M * K * i;
-    gemm_params.B = B_holder.data() + N * K * i;
-    gemm_params.ldb = gemm_shape.N;
-    gemm_params.C = C_holder.data() + M * N * i;
-    if (pack_b) {
-      MlasGemmPackB(N, K, (const uint8_t*)gemm_params.B, N, a_is_signed, b_is_signed, (void*)(pack_b_holder.data() + packed_b_size * i));
-      gemm_params.BIsPacked = true;
-      gemm_params.B = (void*)(pack_b_holder.data() + packed_b_size * i);
-    }
-  }
-  for (auto _ : state) {
-    MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch, tp.get());
-  }
-}
-
-static void QGemmSize(benchmark::internal::Benchmark* b) {
-  b->ArgNames(qgemm_arg_names);
-  // Args for "M", "N", "K", "Batch", "Threads"
-
-  b->Args({384, 1024, 1024, 1, 4});
-  b->Args({384, 1024, 3072, 1, 4});
-  b->Args({384, 1024, 4096, 1, 4});
-  b->Args({384, 4096, 1024, 1, 4});
-  b->Args({384, 1024, 1024, 1, 16});
-  b->Args({384, 1024, 3072, 1, 16});
-  b->Args({384, 1024, 4096, 1, 16});
-  b->Args({384, 4096, 1024, 1, 16});
-  b->Args({1536, 1024, 1024, 1, 16});
-  b->Args({1536, 1024, 3072, 1, 16});
-  b->Args({1536, 1024, 4096, 1, 16});
-  b->Args({1536, 4096, 1024, 1, 16});
-  b->Args({3072, 1024, 1024, 1, 16});
-  b->Args({3072, 1024, 3072, 1, 16});
-  b->Args({3072, 1024, 4096, 1, 16});
-  b->Args({3072, 4096, 1024, 1, 16});
-}
-
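The registrations below use Google Benchmark's capture mechanism: BENCHMARK_CAPTURE binds the trailing (pack_b, a_is_signed) arguments of QGEMM under a new benchmark name, and Apply(QGemmSize) attaches every M/N/K/Batch/Threads tuple to each of them. A stripped-down, self-contained illustration of the same pattern, with no MLAS involved and with names invented for the example:

    #include <benchmark/benchmark.h>

    // Captured benchmark: the extra 'factor' argument is bound by BENCHMARK_CAPTURE below.
    static void SCALE(benchmark::State& state, int64_t factor) {
      for (auto _ : state) {
        benchmark::DoNotOptimize(state.range(0) * factor);  // trivial work sized by the Args below
      }
    }

    static void ScaleSizes(benchmark::internal::Benchmark* b) {
      b->ArgNames({"Size"});
      b->Args({64});
      b->Args({4096});
    }

    BENCHMARK_CAPTURE(SCALE, x2, 2)->Apply(ScaleSizes)->UseRealTime();
    BENCHMARK_MAIN();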
-BENCHMARK_CAPTURE(QGEMM, UnsignedAPackB, true, false)->Apply(QGemmSize)->UseRealTime(); -BENCHMARK_CAPTURE(QGEMM, UnsignedANoPackB, false, false)->Apply(QGemmSize)->UseRealTime(); -#if !defined(MLAS_TARGET_AMD64) -// QGEMM is not supported for signed A, signed B (Packed) on AMD64 CPU. The -// benchmark assumes MlasGemmPackBSize return non-zero is not true. -BENCHMARK_CAPTURE(QGEMM, SignedAPackB, true, true)->Apply(QGemmSize)->UseRealTime(); -#endif -BENCHMARK_CAPTURE(QGEMM, SignedANoPackB, false, true)->Apply(QGemmSize)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp deleted file mode 100644 index 39d135236b89c..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "mlas.h" -#include "bench_util.h" - -#include -#include - -static std::vector BuildArgNamesForConv(size_t rank) { - std::vector names = {"Rank", "N", "G", "Cpg", "Fpg"}; - - size_t arg_position = names.size(); - names.resize(arg_position + rank * 6, std::string("")); - - names[arg_position] = "I"; - arg_position += rank; - - names[arg_position] = "K"; - arg_position += rank; - - names[arg_position] = "P"; - arg_position += rank * 2; - - names[arg_position] = "S"; - arg_position += rank; - - names[arg_position] = "D"; - - return names; -} - -static const std::vector& ArgNamesForConv(size_t rank) { - static std::map> rank_to_args_name; - if (rank_to_args_name.find(rank) == rank_to_args_name.end()) { - rank_to_args_name.emplace(std::make_pair(rank, BuildArgNamesForConv(rank))); - } - return rank_to_args_name[rank]; -} - -// dummy for some strange build error when using Bench capture -void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) { - const int64_t rank = state.range(0); // Rank - const int64_t batch_size = state.range(1); // N - const int64_t groups = state.range(2); // G - const int64_t input_channels_per_group = state.range(3); // Cpg - const int64_t output_channels_per_group = state.range(4); // Fpg - - if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!"); - if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!"); - if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!"); - if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!"); - if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!"); - - size_t arg_position = 5; - const auto input_shape = BenchArgsVector(state, arg_position, rank); - const auto kernel_shape = BenchArgsVector(state, arg_position, rank); - const auto paddings = BenchArgsVector(state, arg_position, rank * 2); - const auto strides = BenchArgsVector(state, arg_position, rank); - const auto dilations = BenchArgsVector(state, arg_position, rank); - - // do not check the size of each vector as they are forced from args. 
- if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { - throw std::invalid_argument("all input image dim must > 0"); - } - - if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { - throw std::invalid_argument("all kernel dim must > 0"); - } - - if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) { - throw std::invalid_argument("all strides dim must > 0"); - } - - if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) { - throw std::invalid_argument("all dilations dim must > 0"); - } - - const int64_t GC = groups * input_channels_per_group; - const int64_t GF = groups * output_channels_per_group; - std::vector x_shape = {batch_size, GC}; - x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end()); - std::vector f_shape = {GF, input_channels_per_group}; - f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end()); - - std::vector output_shape((size_t)rank); - for (int64_t i = 0; i < rank; ++i) { - auto km = 1 + dilations[i] * (kernel_shape[i] - 1); - output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1; - } - std::vector y_shape = {batch_size, GF}; - y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end()); - - MLAS_ACTIVATION activation; - activation.ActivationKind = MlasIdentityActivation; - MLAS_CONV_PARAMETERS Parameters; - size_t WorkingBufferSize = 0; - MlasConvPrepare(&Parameters, - static_cast(rank), - static_cast(batch_size), - static_cast(groups), - static_cast(input_channels_per_group), - input_shape.data(), - kernel_shape.data(), - dilations.data(), - paddings.data(), - strides.data(), - output_shape.data(), - static_cast(output_channels_per_group), - &activation, - &WorkingBufferSize, - 0.0f, - nullptr); - - auto X = RandomVectorUniform(x_shape, -2.0, 2.0); - auto F = RandomVectorUniform(f_shape, -1.0, 1.0); - int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies()); - std::vector Y(static_cast(y_size)); - std::vector working_buffer(WorkingBufferSize); - - // warm up first round. 
- MlasConv(&Parameters, - X.data(), - F.data(), - nullptr, - working_buffer.data(), - Y.data(), - nullptr); - - for (auto _ : state) { - MlasConv(&Parameters, - X.data(), - F.data(), - nullptr, - working_buffer.data(), - Y.data(), - nullptr); - } -} - -static void ResNet50(benchmark::internal::Benchmark* b) { - b->ArgNames(ArgNamesForConv(2)); - - //************************* Conv 1 ************************* - // Rank, N, G,Cpg,Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 3, 64, 224, 224, 7, 7, 3, 3, 3, 3, 2, 2, 1, 1}); - - //************************ Conv 2.1 ************************ - // Rank, N, G,Cpg,Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 64, 64, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); - b->Args({2, 1, 1, 64, 256, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - // b->Args({2, 1, 1, 64,256, 56, 56, 1,1, 0,0,0,0, 1,1, 1,1}); - - //************************ Conv 2.X ************************ - // Rank, N, G,Cpg,Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 256, 64, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - // b->Args({2, 1, 1, 64, 64, 56, 56, 3,3, 1,1,1,1, 1,1, 1,1}); - // b->Args({2, 1, 1, 64,256, 56, 56, 1,1, 0,0,0,0, 1,1, 1,1}); - - /************************ Conv 3.1 ************************/ - // Rank, N, G,Cpg,Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 256, 128, 56, 56, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1}); - b->Args({2, 1, 1, 128, 512, 28, 28, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 256, 512, 56, 56, 1, 1, 0, 0, 0, 0, 2, 2, 1, 1}); - - /************************ Conv 3.X ************************/ - // Rank, N, G,Cpg,Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 512, 128, 28, 28, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); - // b->Args({2, 1, 1,128,512, 28, 28, 1,1, 0,0,0,0, 1,1, 1,1}); - - /************************ Conv 4.1 ************************/ - // Rank, N, G,Cpg,Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 512, 256, 28, 28, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1}); - b->Args({2, 1, 1, 256, 1024, 14, 14, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 512, 1024, 28, 28, 1, 1, 0, 0, 0, 0, 2, 2, 1, 1}); - - /************************ Conv 4.X ************************/ - // Rank, N, G, Cpg, Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 1024, 256, 14, 14, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); - // b->Args({2, 1, 1, 256, 1024, 14, 14, 1,1, 0,0,0,0, 1,1, 1,1}); - - /************************ Conv 5.1 ************************/ - // Rank, N, G, Cpg, Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 1024, 512, 14, 14, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1}); - b->Args({2, 1, 1, 512, 2048, 7, 7, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 1024, 2048, 14, 14, 1, 1, 0, 0, 0, 0, 2, 2, 1, 1}); - - /************************ Conv 5.X ************************/ - // Rank, N, G, Cpg, Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 2048, 512, 7, 7, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - b->Args({2, 1, 1, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); - // b->Args({2, 1, 1, 512,2048, 7, 7, 1,1, 0,0,0,0, 1,1, 1,1}); -} - -BENCHMARK_CAPTURE(SCONV_NCHW, ResNet50, "")->Apply(ResNet50)->UseRealTime(); - -static void 
TeamsModel(benchmark::internal::Benchmark* b) { - b->ArgNames(ArgNamesForConv(2)); - // Rank, N, G, Cpg, Fpg, I, , K, , P, , , , S, , D, , - b->Args({2, 1, 1, 40, 24, 24, 40, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // fused conv_349 => 24x40 - b->Args({2, 1, 1, 24, 24, 24, 40, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // fused Conv_367 => 24x40 - b->Args({2, 1, 1, 4, 24, 96, 160, 3, 3, 0, 0, 1, 1, 2, 2, 1, 1}); // fused Conv_15 => 48x80 - b->Args({2, 1, 1, 12, 72, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // fused Conv_38 => 48x80 - b->Args({2, 1, 1, 12, 8, 48, 80, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // fused Conv_395 => 48x80 - b->Args({2, 1, 24, 1, 1, 48, 80, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // fused Conv_33 => 48x80 - b->Args({2, 1, 1, 8, 8, 48, 80, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // fused Conv_413 => 48x80 - b->Args({2, 1, 72, 1, 1, 48, 80, 3, 3, 0, 0, 1, 1, 2, 2, 1, 1}); // fused Conv_56 => 24x40 - b->Args({2, 1, 72, 1, 1, 24, 40, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1}); // fused Conv_79 => 24x40 - b->Args({2, 1, 1, 24, 12, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // Conv_36 => 48x80 - b->Args({2, 1, 1, 12, 72, 24, 40, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // fused Conv_61/85 => 24x40 - b->Args({2, 1, 1, 24, 144, 12, 20, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // fused Conv_108/132 => 12x20 - - b->Args({2, 1, 1, 12, 12, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // fused Conv_376 => 48x80 - b->Args({2, 1, 1, 12, 72, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // Conv_59 => 24x40 -} - -BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); - -static void General_Conv2d(benchmark::internal::Benchmark* b) { - b->ArgNames(ArgNamesForConv(2)); - b->ArgsProduct( - {{2}, // Rank, - {1}, // N - {1, 2}, // Groups - {3, 12}, // Cpg - {6}, // Fpg - {24, 72}, // Input Image Shape - {36}, - {3}, // kernel shape - {3}, - {0}, // paddings - {0}, - {0}, - {0}, - {1}, // strides - {1}, - {1}, // dilations - {1}}); -} - -BENCHMARK_CAPTURE(SCONV_NCHW, 2d, "")->Apply(General_Conv2d)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp deleted file mode 100644 index a94d33cd77f63..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
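One note on the convolution benchmarks above (the deleted bench_sconv.cpp): the output spatial size computed there follows the standard convolution formula, output = (pad_begin + pad_end + input - effective_kernel) / stride + 1, with effective_kernel = 1 + dilation * (kernel - 1). For the first ResNet-50 entry (224x224 input, 7x7 kernel, padding 3 on each side, stride 2, dilation 1) this gives (3 + 3 + 224 - 7) / 2 + 1 = 112, i.e. a 112x112 output feature map.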
- -#include "mlas.h" -#include "bench_util.h" -#include "core/util/thread_utils.h" - -#include -#include - -static const std::vector sgemm_bench_arg_names = {"M", "N", "K"}; - -void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, float alpha = 1.0f, float beta = 0.0f) { - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - const size_t M = static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - - auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); - auto B = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); - std::vector C(static_cast(M * N)); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = 8; - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - if (pack_b) { - size_t pack_b_size = MlasGemmPackBSize(N, K); - std::vector B_packed(pack_b_size); - MlasGemmPackB(CblasNoTrans, N, K, B.data(), N, B_packed.data()); - - MlasGemm( - trans_a ? CblasTrans : CblasNoTrans, - static_cast(M), - static_cast(N), - static_cast(K), - alpha, - A.data(), - trans_a ? M : K, - B_packed.data(), - beta, - C.data(), - N, - tp.get()); - - for (auto _ : state) { - MlasGemm( - trans_a ? CblasTrans : CblasNoTrans, - static_cast(M), - static_cast(N), - static_cast(K), - alpha, - A.data(), - trans_a ? M : K, - B_packed.data(), - beta, - C.data(), - N, - tp.get()); - } - - } else { - MlasGemm( - trans_a ? CblasTrans : CblasNoTrans, - trans_b ? CblasTrans : CblasNoTrans, - static_cast(M), - static_cast(N), - static_cast(K), - alpha, - A.data(), - trans_a ? M : K, - B.data(), - trans_b ? K : N, - beta, - C.data(), - N, - tp.get()); - - for (auto _ : state) { - MlasGemm( - trans_a ? CblasTrans : CblasNoTrans, - trans_b ? CblasTrans : CblasNoTrans, - static_cast(M), - static_cast(N), - static_cast(K), - alpha, - A.data(), - trans_a ? M : K, - B.data(), - trans_b ? 
K : N, - beta, - C.data(), - N, - tp.get()); - } - } -} - -static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { - b->ArgNames(sgemm_bench_arg_names); - b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); - b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); - b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); -} - -static void GemmSizeProducts(benchmark::internal::Benchmark* b) { - b->ArgNames(sgemm_bench_arg_names); - b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); -} - -BENCHMARK_CAPTURE(SGEMM, NORMAL_NoTrans, false, false, false)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, NORMAL_TransA, false, true, false)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, NORMAL_TransB, false, false, true)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, NORMAL_ABTrans, false, true, true)->Apply(GemmSizeProducts)->UseRealTime(); - -BENCHMARK_CAPTURE(SGEMM, GEMV_NoTrans, false, false, false)->Apply(GemmSizeWithOne)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, GEMV_TransA, false, true, false)->Apply(GemmSizeWithOne)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, GEMV_TransB, false, false, true)->Apply(GemmSizeWithOne)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, GEMV_ABTrans, false, true, true)->Apply(GemmSizeWithOne)->UseRealTime(); - -BENCHMARK_CAPTURE(SGEMM, PACKB_NoTransA, true, false, false)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProducts)->UseRealTime(); - -static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { - b->ArgNames(sgemm_bench_arg_names); - b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); -} - -BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp deleted file mode 100644 index 71db7d81075b5..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "mlas_q4.h" -#include "mlas_qnbit.h" - -#include -#include -#include -#include - -#include "benchmark/benchmark.h" - -#include "bench_util.h" -#include "core/common/narrow.h" -#include "core/util/thread_utils.h" -#include "core/platform/env_var_utils.h" - -template -void RunSQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { - if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); - return; - } - - size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, - static_cast(K), static_cast(N), - QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(Threads); - tpo.auto_set_affinity = true; - - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - const auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); - const auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); - - const auto Bias = HasBias ? 
RandomVectorUniform(N, -1.0f, 1.0f) : std::vector(); - - std::vector C(static_cast(M * N)); - - std::vector QuantBData(QuantBDataSizeInBytes); - std::vector QuantBScale(QuantBScaleSize); - std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); - bool has_zp_input = !Symmetric; - - MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), - Symmetric ? nullptr : QuantBZeroPoint.data(), - B.data(), static_cast(BlkLen), /* columnwise */ true, - static_cast(K), static_cast(N), static_cast(N), - tp.get()); - - std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); - WorkspaceSize > 0) { - Workspace = std::make_unique(WorkspaceSize); - } - - std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); - PackedQuantBDataSize > 0) { - PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), - tp.get()); - } - - MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; - params.A = A.data(); - params.lda = K; - if (PackedQuantBData != nullptr) - params.QuantBDataWorkspace = PackedQuantBData.get(); - else - params.QuantBDataWorkspace = static_cast(QuantBData.data()); - - params.PackedQuantBData = PackedQuantBData.get(); - params.QuantBScale = QuantBScale.data(); - params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); - params.Bias = HasBias ? Bias.data() : nullptr; - params.C = C.data(); - params.ldc = N; - - // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); - - for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); - } -} - -template -void SQNBITGEMM(benchmark::State& state) { - using onnxruntime::narrow; - - const auto BlkLen = narrow(state.range(0)); - const auto M = narrow(state.range(1)); - const auto N = narrow(state.range(2)); - const auto K = narrow(state.range(3)); - const auto Threads = narrow(state.range(4)); - const auto Symmetric = narrow(state.range(5)); - const bool HasBias = narrow(state.range(6)); - const auto ComputeType = static_cast(state.range(7)); - - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); -} - -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { - b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); - - b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType - }); -} - -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); - -// This test gets benchmark arguments from environment variables. 
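As a usage illustration for the environment-variable driven benchmark below (the variable names come from the ParseEnvironmentVariableWithDefault calls; the binary name onnxruntime_mlas_benchmark and the filter are assumptions about how this target is typically built and run):

    ORT_SQNBITGEMM_BLKLEN=64 ORT_SQNBITGEMM_M=1024 ORT_SQNBITGEMM_N=4096 ORT_SQNBITGEMM_K=4096 \
      ORT_SQNBITGEMM_THREADS=8 ./onnxruntime_mlas_benchmark --benchmark_filter='SQNBITGEMM_ENV.*'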
-template -void SQNBITGEMM_ENV(benchmark::State& state) { - using onnxruntime::ParseEnvironmentVariableWithDefault; - - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", - static_cast(CompFp32)); - - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); - - std::ostringstream s; - s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen - << "/M:" << M << "/N:" << N << "/K:" << K - << "/Threads:" << Threads << "/Symmetric:" << Symmetric << "/HasBias:" << HasBias - << "/ComputeType:" << ComputeType; - state.SetLabel(s.str()); -} - -BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_symm_qgemm.cpp b/onnxruntime/test/mlas/bench/bench_symm_qgemm.cpp deleted file mode 100644 index fac9350b50914..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_symm_qgemm.cpp +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "mlas.h" -#include "bench_util.h" -#include "core/util/thread_utils.h" - -#include -#include -#include -#include - -static const std::vector qgemm_arg_names = {"M", "N", "K", "Batch", "Threads"}; - -void SYMMQGEMM(benchmark::State& state, bool a_signed) { - const int8_t a_zero_point = 29; - - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - if (state.range(3) <= 0) throw std::invalid_argument("Batch must greater than 0!"); - if (state.range(4) <= 0) throw std::invalid_argument("Threads must greater than 0!"); - - const size_t M = static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - - const size_t batch = static_cast(state.range(3)); - const size_t threads = static_cast(state.range(4)); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = int(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp( - onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), - tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - auto A_holder = RandomVectorUniform(static_cast(M * K * batch) + 16, int8_t(-120), int8_t(120)); - auto B_holder = RandomVectorUniform(static_cast(N * K * batch), int8_t(-122), int8_t(122)); - std::vector C_holder(static_cast(M * N * batch)); - std::vector pack_b_holder; - - size_t packed_b_size = MlasSymmQgemmPackBSize(N, K, a_signed); - pack_b_holder.resize(packed_b_size * batch); - - MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape; - - gemm_shape.M = static_cast(M); - gemm_shape.N = static_cast(N); - gemm_shape.K = static_cast(K); - gemm_shape.AIsSigned = true; - gemm_shape.BIsSigned = true; - - std::vector gemm_data_vec(batch); - for (size_t i = 0; i < batch; i++) { - 
auto& gemm_params = gemm_data_vec[i]; - gemm_params.lda = gemm_shape.K; - gemm_params.ldc = gemm_shape.N; - gemm_params.A = A_holder.data() + M * K * i; - gemm_params.C = C_holder.data() + M * N * i; - - MlasSymmQgemmPackB(N, K, (const int8_t*)gemm_params.B, N, a_signed, a_zero_point, (void*)(pack_b_holder.data() + packed_b_size * i)); - gemm_params.B = (void*)(pack_b_holder.data() + packed_b_size * i); - } - for (auto _ : state) { - MlasSymmQgemmBatch(gemm_shape, gemm_data_vec.data(), batch, tp.get()); - } -} - -#if defined(MLAS_TARGET_ARM64) -static void SymmQGemmSize(benchmark::internal::Benchmark* b) { - b->ArgNames(qgemm_arg_names); - // Args for "M", "N", "K", "Batch", - - b->Args({512, 32128, 768, 1, 1}); - b->Args({512, 32128, 768, 1, 4}); - b->Args({512, 32128, 768, 1, 6}); - - b->Args({512, 3072, 768, 1, 1}); - b->Args({512, 3072, 768, 1, 4}); - b->Args({512, 3072, 768, 1, 6}); - - b->Args({512, 768, 3072, 1, 1}); - b->Args({512, 768, 3072, 1, 4}); - b->Args({512, 768, 3072, 1, 6}); - - b->Args({512, 768, 768, 1, 1}); - b->Args({512, 768, 768, 1, 4}); - b->Args({512, 768, 768, 1, 6}); - - b->Args({512, 64, 512, 1, 1}); - b->Args({512, 64, 512, 1, 4}); - b->Args({512, 64, 512, 1, 6}); - - b->Args({512, 512, 64, 12, 1}); - b->Args({512, 512, 64, 12, 4}); - b->Args({512, 512, 64, 12, 6}); - - b->Args({512, 64, 512, 12, 1}); - b->Args({512, 64, 512, 12, 4}); - b->Args({512, 64, 512, 12, 6}); -} - -BENCHMARK_CAPTURE(SYMMQGEMM, SignedActivation, true)->Apply(SymmQGemmSize)->UseRealTime(); -#endif diff --git a/onnxruntime/test/mlas/bench/bench_util.cpp b/onnxruntime/test/mlas/bench/bench_util.cpp deleted file mode 100644 index 6b59b7e01b46f..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_util.cpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "bench_util.h" -#include -#include - -std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count) { - std::vector shape; - shape.reserve(count); - for (size_t axis = 0; axis < count; ++axis) { - shape.emplace_back(state.range(start + axis)); - } - start += count; - return shape; -} - -std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value) { - int64_t sz = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); - if (sz <= 0) { - throw std::invalid_argument("shape gives size must greater than 0!"); - } - return RandomVectorUniform(static_cast(sz), min_value, max_value); -} diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h deleted file mode 100644 index f96dd5c673b3d..0000000000000 --- a/onnxruntime/test/mlas/bench/bench_util.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
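One detail of the RandomVectorUniform helper declared in this header is easy to miss: the random engine is seeded with the element count N, so two buffers of the same size (and the same value range) always receive the same pseudo-random contents, which keeps repeated benchmark runs comparable. Typical calls, shown here purely as an illustration of the deduced element type:

    auto activations = RandomVectorUniform(static_cast<size_t>(1024) * 1024, -1.0f, 1.0f);               // std::vector<float>
    auto weights_u8  = RandomVectorUniform(static_cast<size_t>(1024) * 1024, uint8_t(0), uint8_t(255));  // std::vector<uint8_t>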
- -#pragma once - -#include - -#include -#include - -template -std::vector RandomVectorUniform( - size_t N, - ElementType min_value = std::numeric_limits::lowest(), - ElementType max_value = std::numeric_limits::max()) { - if (min_value >= max_value) { - return std::vector(N, min_value); - } - std::default_random_engine generator(static_cast(N)); - std::uniform_real_distribution distribution(static_cast(min_value), static_cast(max_value)); - - std::vector r(N); - for (size_t i = 0; i < N; i++) { - r[i] = static_cast(distribution(generator)); - } - return r; -} - -std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value); - -std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count); diff --git a/onnxruntime/test/mlas/unittest/test_activation.cpp b/onnxruntime/test/mlas/unittest/test_activation.cpp deleted file mode 100644 index a4334c6c80477..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_activation.cpp +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#include -#include "test_util.h" - -class MlasActivationTest : public MlasTestBase { - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("Activation"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - union AliasedValue { - unsigned u; - float f; - }; - - // N.B. The test data includes values at the edge of Tanh/Logistic boundaries. - // Identity, Relu, LeakyRelu, Tanh, Logistic, Clip, HardSigmoid - static const AliasedValue TestData[20][7] = { - { - {0x00000001}, - {0x00000001}, - {0x00000001}, - {0x00000000}, - {0x3f000000}, - {0x00000001}, - {0x3df5c28f}, - }, // positive denormal - { - {0x80000001}, - {0x00000000}, - {0x80000000}, - {0x80000000}, - {0x3f000000}, - {0x00000000}, - {0x3df5c28f}, - }, // negative denormal - { - {0x7ff00002}, - {0x7ff00002}, - {0x7ff00002}, - {0x7ff00002}, - {0x7ff00002}, - {0x7ff00002}, - {0x7ff00002}, - }, // positive NaN - { - {0xfff00002}, - {0xfff00002}, - {0xfff00002}, - {0xfff00002}, - {0xfff00002}, - {0xfff00002}, - {0xfff00002}, - }, // negative NaN - { - {0x00000000}, - {0x00000000}, - {0x00000000}, - {0x00000000}, - {0x3f000000}, - {0x00000000}, - {0x3df5c28f}, - }, // 0.0f - { - {0x80000000}, - {0x80000000}, - {0x80000000}, - {0x80000000}, - {0x3f000000}, - {0x80000000}, - {0x3df5c28f}, - }, // -0.0f - { - {0x3e800000}, - {0x3e800000}, - {0x3e800000}, - {0x3e7acbf5}, - {0x3f0feacc}, - {0x3e800000}, - {0x3e2e147b}, - }, // 0.25f - { - {0xbe800000}, - {0x00000000}, - {0xbd4ccccd}, - {0xbe7acbf5}, - {0x3ee02a67}, - {0x00000000}, - {0x3d8f5c28}, - }, // -0.25f - { - {0x40800000}, - {0x40800000}, - {0x40800000}, - {0x3f7fd40a}, - {0x3f7b6541}, - {0x40800000}, - {0x3f6b851f}, - }, // 4.0f - { - {0xc0800000}, - {0x00000000}, - {0xbf4ccccd}, - {0xbf7fd40a}, - {0x3c9357e0}, - {0x00000000}, - {0x00000000}, - }, // -4.0f - { - {0x41200000}, - {0x41200000}, - {0x41200000}, - {0x3f800000}, - {0x3f7ffd06}, - {0x40c00000}, - {0x3f800000}, - }, // 10.0f - { - {0xc1200000}, - {0x00000000}, - {0xc0000000}, - {0xbf800000}, - {0x383e6000}, - {0x00000000}, - {0x00000000}, - }, // -10.0f - { - {0xc18866eb}, - {0x00000000}, - {0xc05a3e45}, - {0xbf800000}, - {0x33000000}, - {0x00000000}, - {0x00000000}, - }, // -17.0502529144f - { - {0xc18869bb}, - {0x00000000}, - {0xc05a42c5}, - {0xbf800000}, - {0x33c00000}, - {0x00000000}, - {0x00000000}, - }, // -17.0516262054f - { - {0xc18852a8}, - {0x00000000}, - {0xc05a1dda}, - 
{0xbf800000}, - {0x00000000}, - {0x00000000}, - {0x00000000}, - }, // -17.0403594971f - { - {0xc18844aa}, - {0x00000000}, - {0xc05a0777}, - {0xbf800000}, - {0x00000000}, - {0x00000000}, - {0x00000000}, - }, // -17.0335273743f - { - {0x418866eb}, - {0x418866eb}, - {0x418866eb}, - {0x3f800000}, - {0x3f800000}, - {0x40c00000}, - {0x3f800000}, - }, // +17.0502529144f - { - {0x418869bb}, - {0x418869bb}, - {0x418869bb}, - {0x3f800000}, - {0x3f7ffffe}, - {0x40c00000}, - {0x3f800000}, - }, // +17.0516262054f - { - {0x418852a8}, - {0x418852a8}, - {0x418852a8}, - {0x3f800000}, - {0x3f800000}, - {0x40c00000}, - {0x3f800000}, - }, // +17.0403594971f - { - {0x418844aa}, - {0x418844aa}, - {0x418844aa}, - {0x3f800000}, - {0x3f800000}, - {0x40c00000}, - {0x3f800000}, - }, // +17.0335273743f - }; - - MLAS_ACTIVATION Activation; - AliasedValue Buffer[_countof(TestData)]; - - for (unsigned kind = 0; kind < unsigned(MlasActivationKindCount); kind++) { - Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind); - - if (Activation.ActivationKind == MlasLeakyReluActivation) { - Activation.Parameters.LeakyRelu.alpha = 0.2f; - } else if (Activation.ActivationKind == MlasClipActivation) { - Activation.Parameters.Clip.minimum = 0.0f; - Activation.Parameters.Clip.maximum = 6.0f; - } else if (Activation.ActivationKind == MlasHardSigmoidActivation) { - Activation.Parameters.HardSigmoid.alpha = 0.2f; - Activation.Parameters.HardSigmoid.beta = 0.12f; - } - - // - // Test the vectorized activations. - // - - for (unsigned i = 0; i < _countof(TestData); i++) { - Buffer[i].u = TestData[i][0].u; - } - - MlasActivation(&Activation, &Buffer[0].f, nullptr, 1, _countof(Buffer), _countof(Buffer)); - // TODO: Fix the test once centos has updated to almalinux - // for (unsigned i = 0; i < _countof(TestData); i++) { - // // Sensitive to comparing positive/negative zero and NaNs. - // EXPECT_TRUE(Buffer[i].u == TestData[i][kind].u || Buffer[i].f == TestData[i][kind].f) - // << ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:" - // << std::setw(8) << std::setfill('0') << std::hex << Buffer[i].u << ", expecting:" - // << std::setw(8) << std::setfill('0') << std::hex << TestData[i][kind].u; - // } - - // - // Test the scalar activations. - // - - for (unsigned i = 0; i < _countof(TestData); i++) { - Buffer[i].u = TestData[i][0].u; - MlasActivation(&Activation, &Buffer[i].f, nullptr, 1, 1, 1); - } - - for (unsigned i = 0; i < _countof(TestData); i++) { - // Sensitive to comparing positive/negative zero and NaNs. - float error = std::min(std::fabs((Buffer[i].f - TestData[i][kind].f) / TestData[i][kind].f), std::fabs(Buffer[i].f - TestData[i][kind].f)); - EXPECT_TRUE(Buffer[i].u == TestData[i][kind].u || Buffer[i].f == TestData[i][kind].f || error < 0.000001f) - << ", Scalar Activation Kind:" << (int)kind << ", i=" << i << ", value:" - << std::setw(8) << std::setfill('0') << std::hex << Buffer[i].u << ", expecting:" - << std::setw(8) << std::setfill('0') << std::hex << TestData[i][kind].u; - } - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; -}); diff --git a/onnxruntime/test/mlas/unittest/test_blkq8.cpp b/onnxruntime/test/mlas/unittest/test_blkq8.cpp deleted file mode 100644 index 5cff86d411ca9..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_blkq8.cpp +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// Licensed under the MIT License. - -#ifndef ORT_MINIMAL_BUILD - -#include "test_util.h" -#include "mlas_q4.h" - -#define QK8_0 64 -typedef struct { - float d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - -static void quantize_reference(const float* src, void* dst, size_t M, size_t k) { - const size_t nb = k / QK8_0; - block_q8_0* blob = reinterpret_cast(dst); - - for (size_t m = 0; m < M; m++) { - for (size_t i = 0; i < nb; i++, blob++, src += QK8_0) { - float amax = 0.0f; // absolute max - - for (size_t j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = std::max(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f / d : 0.0f; - - blob->d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j] * id; - - blob->qs[j] = (int8_t)roundf(x0); - } - } - - const size_t remain = k % QK8_0; - if (remain > 0) { - float amax = 0.0f; // absolute max - - for (size_t j = 0; j < remain; j++) { - const float v = src[j]; - amax = std::max(amax, fabsf(v)); - } - - const float d = amax / 127.f; - const float id = (amax != 0.0f) ? 127.f / amax : 0.0f; - - blob->d = d; - - for (size_t j = 0; j < remain; ++j) { - const float x0 = src[j] * id; - - blob->qs[j] = (int8_t)roundf(x0); - } - for (size_t j = remain; j < QK8_0; ++j) { - blob->qs[j] = 0; - } - blob++; - src += remain; - } - } -} - -template -class MlasBlkQ8Test : public MlasTestBase { - private: - MatrixGuardBuffer FpInputBuf; - MatrixGuardBuffer PackedBuf; - MatrixGuardBuffer ReferenceBuf; - MLAS_THREADPOOL* threadpool_; - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name = std::string("Q8DQ") + - (Threaded ? "_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void Test(size_t M, size_t K) { - float* Input = FpInputBuf.GetBuffer(M * K); - - const size_t qsize = MlasQ80BlkQuantSize(BlkQ4Sym64, M, K); - int8_t* Packed = PackedBuf.GetBuffer(qsize, true); - int8_t* Ref = ReferenceBuf.GetBuffer(qsize, true); - - MlasQ80BlkQuant(BlkQ4Sym64, Packed, Input, M, K, K, threadpool_); - quantize_reference(Input, Ref, M, K); - - for (size_t i = 0; i < qsize; i++) { - ASSERT_EQ(Packed[i], Ref[i]) << ", index=" << i << ", [" << M << "x" - << K << "]"; - } - } - - MlasBlkQ8Test() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} -}; - -template -class MlasBlkQ8ShortExeTest : public MlasTestFixture> { - public: - explicit MlasBlkQ8ShortExeTest(size_t M, size_t K) : M_(M), K_(K) {} - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test(M_, K_); - } - - static size_t RegisterSingleTest(size_t M, size_t K) { - std::stringstream ss; - ss << "/M" << M << "xK" << K; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasBlkQ8Test::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new MlasBlkQ8ShortExeTest( - M, K); - }); - - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - test_registered += RegisterSingleTest(1, 13); - test_registered += RegisterSingleTest(1, 20); - test_registered += RegisterSingleTest(1, 52); - test_registered += RegisterSingleTest(1, 70); - test_registered += RegisterSingleTest(3, 13); - test_registered += RegisterSingleTest(3, 20); - test_registered += RegisterSingleTest(3, 52); - test_registered += RegisterSingleTest(3, 70); - test_registered += RegisterSingleTest(41, 305); - test_registered += RegisterSingleTest(83, 497); - - return test_registered; - } - - private: - size_t M_, K_; -}; - -static size_t BlkQ8ReisterShortTests() { - size_t cnt = 0; - cnt += MlasBlkQ8ShortExeTest::RegisterShortExecuteTests(); - cnt += MlasBlkQ8ShortExeTest::RegisterShortExecuteTests(); - return cnt; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (MlasQ80BlkQuantSize(BlkQ4Sym, 32, 32) == 0) { - return false; // operation not yet supported on current hardware - } - if (is_short_execute) { - return BlkQ8ReisterShortTests() > 0; - } - return false; -}); - -#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp deleted file mode 100644 index f75002f715154..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ /dev/null @@ -1,271 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_blockq4.cpp - -Abstract: - - Tests for MLAS blockwise int4 quantization and dequantization code. - ---*/ - -#ifndef ORT_MINIMAL_BUILD - -#include "test_util.h" -#include "mlas_q4.h" - -class MlasBlockwiseQdqTest : public MlasTestBase { - private: - MatrixGuardBuffer FpBuf; - MatrixGuardBuffer FpBuf2; - MatrixGuardBuffer InputElements; - MatrixGuardBuffer InputScales; - MatrixGuardBuffer InputOffsets; - MatrixGuardBuffer OutputElements; - MatrixGuardBuffer OutputScales; - MatrixGuardBuffer OutputOffsets; - MatrixGuardBuffer QDQOutputElements; - MatrixGuardBuffer QDQOutputScales; - MatrixGuardBuffer QDQOutputOffsets; - MatrixGuardBuffer QDQTransposedOutputElements; - MatrixGuardBuffer QDQTransposedOutputScales; - MatrixGuardBuffer QDQTransposedOutputOffsets; - - void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { - float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); - float* transposed = FpBuf2.GetBuffer(rows * columns, true); - size_t scale_size = (rows + block_size - 1) / block_size * columns; - size_t zp_size = (scale_size + 1) / 2; - - MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool(); - - int meta_rows; - int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); - - int q_rows; - int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); - - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - - uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); - uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); - uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); - - int v = 7; - for (int c = 0; c < columns; 
c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - elements[idx] = (v1 << 4) | v0; - } - } - - float* scales = InputScales.GetBuffer(q_scale_size); - float* qdq_scales = QDQOutputScales.GetBuffer(scale_size); - float* qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size); - uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); - uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); - uint8_t* qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); - if (zp) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - zp[idx] = (v1 << 4) | v0; - } - } - } - - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, - columnwise, rows, columns, threadpool_ptr); - - MlasTranspose(dequant_buf, transposed, columns, rows); - - uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); - float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); - uint8_t* o_zp = symmetric ? 
nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, - columnwise, rows, columns, columns, threadpool_ptr); - - if (columnwise) { - bool signed_quant = MlasQDQQuantizeBlockwise( - transposed, qdq_scales, qdq_zp, qdq_weights, - true, rows, columns, block_size, threadpool_ptr); - - ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; - - if (symmetric) { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); - - } else { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); - } - } - - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - - if (r + 1 < rows) { - ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - } - } - } - - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r++) { - int idx = c * meta_rows + r; - ASSERT_EQ(o_scales[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - - if (columnwise) { - ASSERT_EQ(qdq_scales_T[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - } - } - - if (symmetric) return; - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 1 < meta_rows) { - ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " 
<< block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - } - } - } - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("BlockQ4"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - Test(20, 1, 32, true, false); - Test(20, 1, 32, true, true); - Test(1, 20, 32, false, false); - Test(1, 20, 32, false, true); - Test(52, 1, 32, true, false); - Test(52, 1, 32, true, true); - Test(1, 52, 32, false, false); - Test(1, 52, 32, false, true); - Test(20, 3, 32, true, false); - Test(20, 3, 32, true, true); - Test(3, 20, 32, false, false); - Test(3, 20, 32, false, true); - Test(52, 3, 32, true, false); - Test(52, 3, 32, true, true); - Test(3, 52, 32, false, false); - Test(3, 52, 32, false, true); - Test(52, 3, 64, true, false); - Test(52, 3, 64, true, true); - Test(3, 52, 64, false, false); - Test(3, 52, 64, false, true); - Test(32 * 9 + 17, 41, 32, true, false); - Test(32 * 9 + 17, 41, 32, true, true); - Test(41, 32 * 9 + 17, 32, false, false); - Test(41, 32 * 9 + 17, 32, false, true); - Test(32 * 9 + 17, 41, 64, true, false); - Test(32 * 9 + 17, 41, 64, true, true); - Test(41, 32 * 9 + 17, 64, false, false); - Test(41, 32 * 9 + 17, 64, false, true); - Test(32 * 15 + 17, 63, 128, true, false); - Test(32 * 15 + 17, 63, 128, true, true); - Test(63, 32 * 15 + 17, 128, false, false); - Test(63, 32 * 15 + 17, 128, false, true); - - Test(256, 256, 32, true, false); - Test(256, 256, 32, true, true); - Test(256, 256, 32, false, false); - Test(256, 256, 32, false, true); - } - - MlasBlockwiseQdqTest() = default; -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests::RegisterShortExecute(); - } - return count; -}); - -#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/mlas/unittest/test_conv2d.cpp b/onnxruntime/test/mlas/unittest/test_conv2d.cpp deleted file mode 100644 index 1700cd8f1800f..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_conv2d.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_conv2d.h" -#include "test_conv2d_fixture.h" - -static size_t Conv2dRegistLongExecute() { - size_t count = MlasLongExecuteTests>::RegisterLongExecute(); - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - return count; -} - -static size_t Conv2dRegistShortExecute() { - size_t count = Conv2dShortExecuteTest>::RegisterShortExecuteTests(); - if (GetMlasThreadPool() != nullptr) { - count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); - } - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? Conv2dRegistShortExecute() : Conv2dRegistLongExecute(); -}); diff --git a/onnxruntime/test/mlas/unittest/test_conv2d.h b/onnxruntime/test/mlas/unittest/test_conv2d.h deleted file mode 100644 index 20bf0ec84f5bf..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_conv2d.h +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "test_util.h" - -template -class MlasConv2DTest : public MlasTestBase { - protected: - virtual void MlasConv2D(size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t FilterCount, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t DilationHeight, - size_t DilationWidth, - size_t StrideHeight, - size_t StrideWidth, - size_t OutputHeight, - size_t OutputWidth, - const float* Input, - const float* Filter, - const float* Bias, - float* Output) { - int64_t InputShape[] = {int64_t(InputHeight), int64_t(InputWidth)}; - int64_t KernelShape[] = {int64_t(KernelHeight), int64_t(KernelWidth)}; - int64_t DilationShape[] = {int64_t(DilationHeight), int64_t(DilationWidth)}; - int64_t Padding[] = {int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)}; - int64_t StrideShape[] = {int64_t(StrideHeight), int64_t(StrideWidth)}; - int64_t OutputShape[] = {int64_t(OutputHeight), int64_t(OutputWidth)}; - - MLAS_ACTIVATION Activation; - Activation.ActivationKind = MlasIdentityActivation; - - MLAS_CONV_PARAMETERS Parameters; - size_t WorkingBufferSize; - - MlasConvPrepare(&Parameters, - 2, - BatchCount, - GroupCount, - InputChannels, - InputShape, - KernelShape, - DilationShape, - Padding, - StrideShape, - OutputShape, - FilterCount, - &Activation, - &WorkingBufferSize, - 0.0f, - threadpool_); - - MlasConv(&Parameters, - Input, - Filter, - Bias, - BufferWorking.GetBuffer(WorkingBufferSize), - Output, - threadpool_); - } - - void ReferenceConv2D( - size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t FilterCount, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t DilationHeight, - size_t DilationWidth, - size_t StrideHeight, - size_t StrideWidth, - size_t OutputHeight, - size_t OutputWidth, - const float* Input, - const float* Filter, - const float* Bias, - float* Output) { - size_t InputSize = InputHeight * InputWidth; - size_t OutputSize = OutputHeight * OutputWidth; - size_t KernelSize = KernelHeight * KernelWidth; - - size_t K = InputChannels * KernelSize; - size_t Im2ColElements = OutputSize * K; - - for (size_t b = 0; b < BatchCount; b++) { - const float* filter = Filter; - const float* bias = Bias; - - for (size_t g = 0; g < GroupCount; g++) { - // - // Transform the image using IM2COL and invoke the GEMM. - // - - float* Im2Col = BufferIm2Col.GetBuffer(Im2ColElements); - float* Im2ColOut = Im2Col; - - for (size_t c = 0; c < InputChannels; c++) { - for (size_t ky = 0; ky < KernelHeight; ky++) { - for (size_t kx = 0; kx < KernelWidth; kx++) { - for (size_t oh = 0; oh < OutputHeight; oh++) { - size_t ih = oh * StrideHeight + ky * DilationHeight - PaddingLeftHeight; - - for (size_t ow = 0; ow < OutputWidth; ow++) { - size_t iw = ow * StrideWidth + kx * DilationWidth - PaddingLeftWidth; - - *Im2ColOut++ = (ih < InputHeight && iw < InputWidth) ? Input[ih * InputWidth + iw] : 0; - } - } - } - } - - Input += InputSize; - } - - MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, OutputSize, K, 1.0f, - filter, K, Im2Col, OutputSize, 0.0f, Output, OutputSize, threadpool_); - - // - // Apply the bias. 
- // - - for (size_t f = 0; f < FilterCount; f++) { - float biasValue = *bias++; - - for (size_t o = 0; o < OutputSize; o++) { - *Output++ += biasValue; - } - } - - filter += FilterCount * InputChannels * KernelSize; - } - } - } - - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferFilter; - MatrixGuardBuffer BufferBias; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - MatrixGuardBuffer BufferWorking; - MatrixGuardBuffer BufferIm2Col; - - MLAS_THREADPOOL* threadpool_; - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name(Threaded ? "Conv2d_Threaded" : "Conv2d_SingleThread"); - return suite_name.c_str(); - } - - MlasConv2DTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void Test( - size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t FilterCount, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t DilationHeight, - size_t DilationWidth, - size_t StrideHeight, - size_t StrideWidth) { - int64_t OutputHeight64 = - ((int64_t(InputHeight) + int64_t(PaddingLeftHeight) + int64_t(PaddingRightHeight)) - - (int64_t(DilationHeight) * (int64_t(KernelHeight) - 1) + 1)) / - int64_t(StrideHeight) + - 1; - int64_t OutputWidth64 = - ((int64_t(InputWidth) + int64_t(PaddingLeftWidth) + int64_t(PaddingRightWidth)) - - (int64_t(DilationWidth) * (int64_t(KernelWidth) - 1) + 1)) / - int64_t(StrideWidth) + - 1; - - if (OutputHeight64 <= 0 || OutputWidth64 <= 0) { - return; - } - - size_t OutputHeight = size_t(OutputHeight64); - size_t OutputWidth = size_t(OutputWidth64); - - size_t InputSize = InputHeight * InputWidth; - size_t KernelSize = KernelHeight * KernelWidth; - size_t OutputSize = OutputHeight * OutputWidth; - - size_t InputElements = BatchCount * GroupCount * InputChannels * InputSize; - size_t FilterElements = GroupCount * FilterCount * InputChannels * KernelSize; - size_t BiasElements = GroupCount * FilterCount; - size_t OutputElements = BatchCount * GroupCount * FilterCount * OutputSize; - - const float* Input = BufferInput.GetBuffer(InputElements); - const float* Filter = BufferFilter.GetBuffer(FilterElements); - const float* Bias = BufferBias.GetBuffer(BiasElements); - float* Output = BufferOutput.GetBuffer(OutputElements); - float* OutputReference = BufferOutputReference.GetBuffer(OutputElements); - - MlasConv2D(BatchCount, - GroupCount, - InputChannels, - InputHeight, InputWidth, - FilterCount, - KernelHeight, KernelWidth, - PaddingLeftHeight, PaddingLeftWidth, - PaddingRightHeight, PaddingRightWidth, - DilationHeight, DilationWidth, - StrideHeight, StrideWidth, - OutputHeight, OutputWidth, - Input, - Filter, - Bias, - Output); - - ReferenceConv2D(BatchCount, - GroupCount, - InputChannels, - InputHeight, InputWidth, - FilterCount, - KernelHeight, KernelWidth, - PaddingLeftHeight, PaddingLeftWidth, - DilationHeight, DilationWidth, - StrideHeight, StrideWidth, - OutputHeight, OutputWidth, - Input, - Filter, - Bias, - OutputReference); - - ASSERT_EQ(memcmp(Output, OutputReference, OutputElements * sizeof(float)), 0) - << "B" << BatchCount << "/" - << "G" << GroupCount << "/" - << "Cpg" << InputChannels << "/" - << "Fpg" << FilterCount << "/" - << "H" << InputHeight << "/" - << "W" << InputWidth << "/" - << "KH" << KernelHeight << "/" - << "KW" << KernelWidth << "/" - << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," 
<< PaddingRightHeight << "," << PaddingRightWidth << "/" - << "Dilation" << DilationHeight << "," << DilationWidth << "/" - << "Stride" << StrideHeight << "," << StrideWidth; - } - - void ExecuteLong(void) override { - static const unsigned cs[] = {32, 14, 1}; - static const unsigned is[] = {53, 11, 5, 1}; - - for (unsigned i = 1; i <= 32; i++) { - Test(4, 18, 1, 32, 89, 48, i, 89, 0, 0, 0, 0, 1, 1, 1, 1); - Test(4, 18, 1, 32, 89, 48, i, 89, 1, 1, 1, 1, 1, 1, 1, 1); - Test(4, 18, 2, 32, 89, 48, i, 89, 0, 0, 0, 0, 1, 1, 1, 1); - } - - for (unsigned b = 1; b < 64; b++) { - Test(b, 1, 64, 11, 11, 128, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); - } - - for (unsigned gc = 0; gc < _countof(cs); gc++) { - for (unsigned ih = 0; ih < _countof(is); ih++) { - for (unsigned iw = 0; iw < _countof(is); iw++) { - fprintf(stderr, "Handling depthwise %ux%ux%u\n", cs[gc], is[ih], is[iw]); - for (unsigned p0 = 0; p0 < 2; p0++) { - for (unsigned p1 = 0; p1 < 2; p1++) { - for (unsigned p2 = 0; p2 < 2; p2++) { - for (unsigned p3 = 0; p3 < 2; p3++) { - for (unsigned dh = 1; dh <= 2; dh++) { - for (unsigned dw = 1; dw <= 2; dw++) { - for (unsigned sh = 1; sh <= 2; sh++) { - for (unsigned sw = 1; sw <= 2; sw++) { - Test(1, cs[gc], 1, is[ih], is[iw], 1, 3, 3, p0, p1, p2, p3, dh, dw, sh, sw); - } - } - } - } - } - } - } - } - } - } - } - - for (unsigned ic = 0; ic < _countof(cs); ic++) { - for (unsigned ih = 0; ih < _countof(is); ih++) { - for (unsigned iw = 0; iw < _countof(is); iw++) { - fprintf(stderr, "Handling %ux%ux%u\n", cs[ic], is[ih], is[iw]); - for (unsigned fc = 0; fc < _countof(cs); fc++) { - for (unsigned kh = 1; kh <= 5; kh++) { - if (kh == 4) continue; - for (unsigned kw = 1; kw <= 5; kw++) { - if (kw == 4) continue; - for (unsigned p0 = 0; p0 < 2; p0++) { - for (unsigned p1 = 0; p1 < 2; p1++) { - for (unsigned p2 = 0; p2 < 2; p2++) { - for (unsigned p3 = 0; p3 < 2; p3++) { - for (unsigned dh = 1; dh <= 2; dh++) { - for (unsigned dw = 1; dw <= 2; dw++) { - for (unsigned sh = 1; sh <= 2; sh++) { - for (unsigned sw = 1; sw <= 2; sw++) { - Test(1, 1, cs[ic], is[ih], is[iw], cs[fc], kh, kw, p0, p1, p2, p3, dh, dw, sh, sw); - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_fixture.h b/onnxruntime/test/mlas/unittest/test_conv2d_fixture.h deleted file mode 100644 index a9d5996f44d62..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_conv2d_fixture.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_util.h" - -// -// Short Execute that distinguish each test by all parameters. 
-// -template -class Conv2dShortExecuteTest : public MlasTestFixture { - public: - explicit Conv2dShortExecuteTest(size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t FilterCount, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t DilationHeight, - size_t DilationWidth, - size_t StrideHeight, - size_t StrideWidth) - : BatchCount_(BatchCount), - GroupCount_(GroupCount), - InputChannels_(InputChannels), - InputHeight_(InputHeight), - InputWidth_(InputWidth), - FilterCount_(FilterCount), - KernelHeight_(KernelHeight), - KernelWidth_(KernelWidth), - PaddingLeftHeight_(PaddingLeftHeight), - PaddingLeftWidth_(PaddingLeftWidth), - PaddingRightHeight_(PaddingRightHeight), - PaddingRightWidth_(PaddingRightWidth), - DilationHeight_(DilationHeight), - DilationWidth_(DilationWidth), - StrideHeight_(StrideHeight), - StrideWidth_(StrideWidth) { - } - - void TestBody() override { - MlasTestFixture::mlas_tester->Test( - BatchCount_, - GroupCount_, - InputChannels_, - InputHeight_, - InputWidth_, - FilterCount_, - KernelHeight_, - KernelWidth_, - PaddingLeftHeight_, - PaddingLeftWidth_, - PaddingRightHeight_, - PaddingRightWidth_, - DilationHeight_, - DilationWidth_, - StrideHeight_, - StrideWidth_); - } - - static size_t RegisterSingleTest( - size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t FilterCount, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t DilationHeight, - size_t DilationWidth, - size_t StrideHeight, - size_t StrideWidth) { - std::stringstream ss; - ss << "B" << BatchCount << "/" - << "G" << GroupCount << "/" - << "Cpg" << InputChannels << "/" - << "Fpg" << FilterCount << "/" - << "H" << InputHeight << "/" - << "W" << InputWidth << "/" - << "KH" << KernelHeight << "/" - << "KW" << KernelWidth << "/" - << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" - << "Dilation" << DilationHeight << "," << DilationWidth << "/" - << "Stride" << StrideHeight << "," << StrideWidth; - auto test_name = ss.str(); - - testing::RegisterTest( - Conv2dTester::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture* { - return new Conv2dShortExecuteTest(BatchCount, - GroupCount, - InputChannels, - InputHeight, - InputWidth, - FilterCount, - KernelHeight, - KernelWidth, - PaddingLeftHeight, - PaddingLeftWidth, - PaddingRightHeight, - PaddingRightWidth, - DilationHeight, - DilationWidth, - StrideHeight, - StrideWidth); - }); - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - for (unsigned i = 1; i < 256; i <<= 1) { - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, 3, 3, 0, 0, 0, 0, 1, 1, 2, 2); - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, 3, 3, 0, 0, 0, 0, 2, 2, 1, 1); - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, i, 1, 0, 0, 0, 0, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 1, 16, i, i, 32, 1, i, 0, 0, 0, 0, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, 1, i, i, 1, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, 1, i, i, 1, 3, 3, 0, 0, 0, 0, 1, 1, 2, 2); - test_registered += RegisterSingleTest(1, 16, 1, i, i, 1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, 1, i, i, 1, 3, 3, 1, 1, 1, 1, 1, 1, 2, 2); - } - return test_registered; - } - - private: - size_t BatchCount_; - size_t GroupCount_; - size_t InputChannels_; - size_t InputHeight_; - size_t InputWidth_; - size_t FilterCount_; - size_t KernelHeight_; - size_t KernelWidth_; - size_t PaddingLeftHeight_; - size_t PaddingLeftWidth_; - size_t PaddingRightHeight_; - size_t PaddingRightWidth_; - size_t DilationHeight_; - size_t DilationWidth_; - size_t StrideHeight_; - size_t StrideWidth_; -}; diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp deleted file mode 100644 index e5a536eb9e4f0..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_conv2d_nchwc.h" -#include "test_conv2d_fixture.h" - -static size_t Conv2dNchwcRegistLongExecute() { - size_t count = 0; - - if (MlasNchwcGetBlockSize() > 1) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - } - - return count; -} - -static size_t Conv2dNchwcRegistShortExecute() { - size_t count = 0; - - if (MlasNchwcGetBlockSize() > 1) { - count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); - if (GetMlasThreadPool() != nullptr) { - count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); - } - } - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? Conv2dNchwcRegistShortExecute() : Conv2dNchwcRegistLongExecute(); -}); diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h deleted file mode 100644 index c125720668381..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "test_conv2d.h" - -template -class MlasNchwcConv2DTest : public MlasConv2DTest { - protected: - void MlasConv2D( - size_t BatchCount, - size_t GroupCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t FilterCount, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t DilationHeight, - size_t DilationWidth, - size_t StrideHeight, - size_t StrideWidth, - size_t OutputHeight, - size_t OutputWidth, - const float* Input, - const float* Filter, - const float* Bias, - float* Output) override { - int64_t InputShape[] = {int64_t(BatchCount), int64_t(GroupCount) * int64_t(InputChannels), int64_t(InputHeight), int64_t(InputWidth)}; - int64_t FilterShape[] = {int64_t(GroupCount) * int64_t(FilterCount), int64_t(InputChannels), int64_t(KernelHeight), int64_t(KernelWidth)}; - int64_t OutputShape[] = {int64_t(BatchCount), int64_t(GroupCount) * int64_t(FilterCount), int64_t(OutputHeight), int64_t(OutputWidth)}; - - int64_t KernelShape[] = {int64_t(KernelHeight), int64_t(KernelWidth)}; - int64_t DilationShape[] = {int64_t(DilationHeight), int64_t(DilationWidth)}; - int64_t Padding[] = {int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)}; - int64_t StrideShape[] = {int64_t(StrideHeight), int64_t(StrideWidth)}; - - // - // Select the type of convolution that will be performed. - // - - bool DoReorderInput; - bool ReorderFilterOIHWBo; - - if (GroupCount > 1 && InputChannels == 1 && FilterCount == 1) { - // Depthwise convolution. - DoReorderInput = true; - ReorderFilterOIHWBo = true; - } else if (InputChannels >= BlockSize) { - // NCHWc or pointwise convolution; - DoReorderInput = true; - ReorderFilterOIHWBo = false; - } else { - // NCHW convolution. - DoReorderInput = false; - ReorderFilterOIHWBo = true; - } - - size_t NchwcInputChannels = (GroupCount * InputChannels + BlockSize - 1) & ~(BlockSize - 1); - size_t NchwcOutputChannels = (GroupCount * FilterCount + BlockSize - 1) & ~(BlockSize - 1); - - // - // Reorder the filter buffer as needed. - // - - float* ReorderedFilter; - - if (ReorderFilterOIHWBo) { - size_t NchwcFilterElements = NchwcOutputChannels * InputChannels * KernelHeight * KernelWidth; - ReorderedFilter = BufferNchwcFilter.GetBuffer(NchwcFilterElements); - MlasReorderFilterOIHWBo(FilterShape, Filter, ReorderedFilter); - } else { - size_t NchwcFilterElements = NchwcOutputChannels * NchwcInputChannels * KernelHeight * KernelWidth; - ReorderedFilter = BufferNchwcFilter.GetBuffer(NchwcFilterElements); - MlasReorderFilterOIHWBiBo(FilterShape, Filter, ReorderedFilter); - } - - // - // Align the bias buffer to the filter count if needed. - // - - if (Bias != nullptr && GroupCount * FilterCount < NchwcOutputChannels) { - float* AlignedBias = BufferNchwcBias.GetBuffer(NchwcOutputChannels); - - size_t i; - for (i = 0; i < GroupCount * FilterCount; i++) { - AlignedBias[i] = Bias[i]; - } - for (; i < NchwcOutputChannels; i++) { - AlignedBias[i] = 0.0f; - } - - Bias = AlignedBias; - } - - // - // Reorder the input buffer if needed. 
- // - - if (DoReorderInput) { - size_t NchwcInputElements = BatchCount * NchwcInputChannels * InputHeight * InputWidth; - float* NchwcInput = BufferNchwcInput.GetBuffer(NchwcInputElements); - ReorderInputNchw(InputShape, Input, NchwcInput); - Input = NchwcInput; - InputShape[1] = NchwcInputChannels; - } - - int64_t NchwcOutputShape[] = {int64_t(BatchCount), int64_t(NchwcOutputChannels), int64_t(OutputHeight), int64_t(OutputWidth)}; - - size_t NchwcOutputElements = BatchCount * NchwcOutputChannels * OutputHeight * OutputWidth; - float* NchwcOutput = BufferNchwcOutput.GetBuffer(NchwcOutputElements); - - MLAS_ACTIVATION Activation; - Activation.ActivationKind = MlasIdentityActivation; - - MlasNchwcConv(InputShape, - KernelShape, - DilationShape, - Padding, - StrideShape, - NchwcOutputShape, - GroupCount, - Input, - ReorderedFilter, - Bias, - NchwcOutput, - &Activation, - true, - MlasConv2DTest::threadpool_); - - // - // Reorder the output buffer. - // - - MlasReorderOutputNchw(OutputShape, NchwcOutput, Output, MlasConv2DTest::threadpool_); - } - - const size_t BlockSize = MlasNchwcGetBlockSize(); - - MatrixGuardBuffer BufferNchwcInput; - MatrixGuardBuffer BufferNchwcFilter; - MatrixGuardBuffer BufferNchwcBias; - MatrixGuardBuffer BufferNchwcOutput; - - public: - static const char* GetTestSuiteName(void) { - static const std::string suite_name(Threaded ? "Conv2dNchwc_Threaded" : "Conv2dNchwc_SingleThread"); - return suite_name.c_str(); - } - - MlasNchwcConv2DTest() : MlasConv2DTest() {} - - void ExecuteLong(void) override { - // N.B. InputChannels must be a multiple of 4 if the count is greater - // than the block size. - static const unsigned cis[] = {32, 20, 5, 1}; - static const unsigned cos[] = {64, 15, 1}; - static const unsigned is[] = {27, 11, 5, 1}; - - // Depthwise convolutions. - for (unsigned i = 16; i < 256; i <<= 1) { - MlasConv2DTest::Test(1, i, 1, 28, 28, 1, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1); - MlasConv2DTest::Test(1, i, 1, 28, 28, 1, 3, 3, 0, 0, 0, 0, 1, 1, 2, 2); - MlasConv2DTest::Test(1, i, 1, 28, 28, 1, 3, 3, 0, 0, 0, 0, 2, 2, 1, 1); - MlasConv2DTest::Test(1, i, 1, 28, 28, 1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); - MlasConv2DTest::Test(1, i, 1, 28, 28, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); - MlasConv2DTest::Test(1, i, 1, 28, 28, 1, i, 1, 0, 0, 0, 0, 1, 1, 1, 1); - MlasConv2DTest::Test(12, i, 1, 11, 11, 1, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1); - } - - // Test varying FilterCounts. 
- for (unsigned i = 1; i < 128; i++) { - MlasConv2DTest::Test(1, 1, 3, 34, 34, i, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1); - MlasConv2DTest::Test(1, 1, 16, 34, 34, i, 3, 3, 0, 0, 0, 0, 1, 1, 1, 1); - MlasConv2DTest::Test(1, 1, 16, 34, 34, i, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); - } - - for (unsigned i = 1; i <= 32; i++) { - MlasConv2DTest::Test(4, 18, 1, 32, 89, 48, i, 89, 0, 0, 0, 0, 1, 1, 1, 1); - MlasConv2DTest::Test(4, 18, 1, 32, 89, 48, i, 89, 1, 1, 1, 1, 1, 1, 1, 1); - MlasConv2DTest::Test(4, 18, 2, 32, 89, 48, i, 89, 0, 0, 0, 0, 1, 1, 1, 1); - } - - for (unsigned b = 1; b < 64; b++) { - MlasConv2DTest::Test(b, 1, 64, 11, 11, 128, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); - } - - for (unsigned ic = 0; ic < _countof(cis); ic++) { - for (unsigned ih = 0; ih < _countof(is); ih++) { - for (unsigned iw = 0; iw < _countof(is); iw++) { - fprintf(stderr, "Handling %ux%ux%u\n", cis[ic], is[ih], is[iw]); - for (unsigned fc = 0; fc < _countof(cos); fc++) { - for (unsigned kh = 1; kh <= 5; kh++) { - if (kh == 4) continue; - for (unsigned kw = 1; kw <= 5; kw++) { - if (kw == 4) continue; - for (unsigned p0 = 0; p0 <= 3; p0++) { - for (unsigned p1 = 0; p1 <= 3; p1++) { - for (unsigned p2 = 0; p2 <= 3; p2++) { - for (unsigned p3 = 0; p3 <= 3; p3++) { - for (unsigned dh = 1; dh <= 2; dh++) { - for (unsigned dw = 1; dw <= 2; dw++) { - for (unsigned sh = 1; sh <= 2; sh++) { - for (unsigned sw = 1; sw <= 2; sw++) { - MlasConv2DTest::Test(1, 1, cis[ic], is[ih], is[iw], cos[fc], kh, kw, p0, p1, p2, p3, dh, dw, sh, sw); - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_exp.cpp b/onnxruntime/test/mlas/unittest/test_exp.cpp deleted file mode 100644 index f9cdffef1947d..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_exp.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -class MlasComputeExpTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - void Test(size_t N, float MinimumValue, float MaximumValue) { - float* Input = BufferInput.GetBuffer(N); - float* Output = BufferOutput.GetBuffer(N); - float* OutputReference = BufferOutputReference.GetBuffer(N); - - std::default_random_engine generator(static_cast(N)); - std::uniform_real_distribution distribution(MinimumValue, MaximumValue); - - for (size_t n = 0; n < N; n++) { - Input[n] = distribution(generator); - } - - for (size_t n = 0; n < N; n++) { - OutputReference[n] = std::exp(Input[n]); - } - - MlasComputeExp(Input, Output, N); - - constexpr float AbsoluteTolerance = 1e-6f; - constexpr float RelativeTolerance = 1e-6f; - - for (size_t n = 0; n < N; n++) { - float diff = std::fabs(Output[n] - OutputReference[n]); - ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) - << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; - } - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("Exp"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - for (size_t n = 1; n < 128; n++) { - Test(n, -10.f, 10.f); - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - // no long execute needed - return is_short_execute ? 
MlasDirectShortExecuteTests::RegisterShortExecute() : 0;
-});
diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.cpp b/onnxruntime/test/mlas/unittest/test_fgemm.cpp
deleted file mode 100644
index e3f50baf3633d..0000000000000
--- a/onnxruntime/test/mlas/unittest/test_fgemm.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "test_fgemm.h"
-#include "test_fgemm_fixture.h"
-
-#include
-#include
-
-static size_t FGemmRegistLongExecute() {
-  size_t count = 0;
-
-  count += MlasLongExecuteTests>::RegisterLongExecute();
-  count += MlasLongExecuteTests>::RegisterLongExecute();
-
-  if (GetMlasThreadPool() != nullptr) {
-    count += MlasLongExecuteTests>::RegisterLongExecute();
-    count += MlasLongExecuteTests>::RegisterLongExecute();
-  }
-
-#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
-
-  count += MlasLongExecuteTests>::RegisterLongExecute();
-  if (GetMlasThreadPool() != nullptr) {
-    count += MlasLongExecuteTests>::RegisterLongExecute();
-  }
-
-#endif
-
-  return count;
-}
-
-static size_t FGemmRegistShortExecute() {
-  size_t count = 0;
-
-  count += FgemmShortExecuteTest::RegisterShortExecuteTests();
-  count += FgemmShortExecuteTest::RegisterShortExecuteTests();
-
-  if (GetMlasThreadPool() != nullptr) {
-    count += FgemmShortExecuteTest::RegisterShortExecuteTests();
-    count += FgemmShortExecuteTest::RegisterShortExecuteTests();
-  }
-
-#ifdef MLAS_SUPPORTS_GEMM_DOUBLE
-
-  count += FgemmShortExecuteTest::RegisterShortExecuteTests();
-  if (GetMlasThreadPool() != nullptr) {
-    count += FgemmShortExecuteTest::RegisterShortExecuteTests();
-  }
-
-#endif
-
-  return count;
-}
-
-static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
-  return is_short_execute ? FGemmRegistShortExecute() : FGemmRegistLongExecute();
-});
diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.h b/onnxruntime/test/mlas/unittest/test_fgemm.h
deleted file mode 100644
index 2bd094152d6f0..0000000000000
--- a/onnxruntime/test/mlas/unittest/test_fgemm.h
+++ /dev/null
@@ -1,398 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
- -#pragma once - -#include "test_util.h" - -template -const char* GetGemmTestSuitePrefix(); - -template <> -const char* GetGemmTestSuitePrefix() { - return "SGemm"; -} - -template <> -const char* GetGemmTestSuitePrefix() { - return "DGemm"; -} - -template -class FgemmPackedContext; - -template <> -class FgemmPackedContext { - public: - void - TestGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - size_t BatchSize, - const float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - const float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* threadpool) { - std::vector data(BatchSize); - for (size_t i = 0; i < BatchSize; i++) { - data[i].A = A + M * K * i; - data[i].lda = lda; - data[i].B = B + K * N * i; - data[i].ldb = ldb; - data[i].C = C + M * N * i; - data[i].ldc = ldc; - data[i].alpha = alpha; - data[i].beta = beta; - } - MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); - } -}; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) -template <> -class FgemmPackedContext { - public: - void TestGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - size_t BatchSize, - double alpha, - const double* A, - size_t lda, - const double* B, - size_t ldb, - double beta, - double* C, - size_t ldc, - MLAS_THREADPOOL* threadpool) { - std::vector data(BatchSize); - for (size_t i = 0; i < BatchSize; i++) { - data[i].A = A + M * K * i; - data[i].lda = lda; - data[i].B = B + K * N * i; - data[i].ldb = ldb; - data[i].C = C + M * N * i; - data[i].ldc = ldc; - data[i].alpha = alpha; - data[i].beta = beta; - } - MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); - } -}; -#endif - -template <> -class FgemmPackedContext { - public: - void - TestGemm( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - size_t BatchSize, - const float alpha, - const float* A, - size_t lda, - const float* B, - size_t ldb, - const float beta, - float* C, - size_t ldc, - MLAS_THREADPOOL* threadpool) { - size_t PackedBSize = MlasGemmPackBSize(N, K); - void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); - std::vector data(BatchSize); - for (size_t i = 0; i < BatchSize; i++) { - MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); - data[i].BIsPacked = true; - data[i].A = A + M * K * i; - data[i].lda = lda; - data[i].B = (float*)((uint8_t*)PackedB + PackedBSize * i); - data[i].ldb = ldb; - data[i].C = C + M * N * i; - data[i].ldc = ldc; - data[i].alpha = alpha; - data[i].beta = beta; - } - MlasGemmBatch(TransA, TransB, M, N, K, data.data(), BatchSize, threadpool); - } - - private: - MatrixGuardBuffer BufferBPacked; -}; - -template -class MlasFgemmTest : public MlasTestBase { - private: - MLAS_THREADPOOL* threadpool_; - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name = std::string(GetGemmTestSuitePrefix()) + - (Packed ? "_Packed" : "_NoPack") + - (Threaded ? "_Threaded" : "_SingleThread"); - - return suite_name.c_str(); - } - - MlasFgemmTest() : threadpool_(Threaded ? 
GetMlasThreadPool() : nullptr) {} - - void Test(size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) { - Test(false, false, M, N, K, BatchSize, alpha, beta); - Test(false, true, M, N, K, BatchSize, alpha, beta); - Test(true, false, M, N, K, BatchSize, alpha, beta); - Test(true, true, M, N, K, BatchSize, alpha, beta); - } - - void Test(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, T alpha, T beta) { - // - // Skip the test if the B buffer cannot be packed. - // - if constexpr (Packed) { - if (N == 0 || K == 0) - return; - } - - const T* A = BufferA.GetBuffer(K * M * BatchSize); - const T* B = BufferB.GetBuffer(N * K * BatchSize); - T* C = BufferC.GetBuffer(N * M * BatchSize); - T* CReference = BufferCReference.GetBuffer(N * M * BatchSize); - - Test(trans_a ? CblasTrans : CblasNoTrans, - trans_b ? CblasTrans : CblasNoTrans, - M, N, K, BatchSize, alpha, A, trans_a ? M : K, B, trans_b ? K : N, - beta, C, CReference, N); - } - - void Test(CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - size_t BatchSize, - T alpha, - const T* A, - size_t lda, - const T* B, - size_t ldb, - T beta, - T* C, - T* CReference, - size_t ldc) { - std::fill_n(C, M * N * BatchSize, -0.5f); - std::fill_n(CReference, M * N * BatchSize, -0.5f); - - PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); - ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, CReference, ldc); - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - // Sensitive to comparing positive/negative zero. - ASSERT_EQ(C[f], CReference[f]) - << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", " - << (Packed ? "Packed" : "NoPack") << "." - << (Threaded ? "SingleThread" : "Threaded") << "/" - << (TransA == CblasTrans ? "TransA" : "A") << "/" - << (TransB == CblasTrans ? 
"TransB" : "B") << "/" - << "M" << M << "xN" << N << "xK" << K << "/" - << "Alpha" << alpha << "/" - << "Beta" << beta; - } - } - } - } - - void ReferenceGemm(CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - size_t BatchSize, - T alpha, - const T* A, - size_t lda, - const T* B, - size_t ldb, - T beta, - T* C, - size_t ldc) { - for (size_t batch = 0; batch < BatchSize; batch++) { - if (TransA == CblasNoTrans) { - if (TransB == CblasNoTrans) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + (m * lda); - const T* b = B + n; - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += ldb; - a += 1; - } - - *c = (*c * beta) + (sum * alpha); - } - } - - } else { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + (m * lda); - const T* b = B + (n * ldb); - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += 1; - a += 1; - } - - *c = (*c * beta) + (sum * alpha); - } - } - } - - } else { - if (TransB == CblasNoTrans) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + m; - const T* b = B + n; - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += ldb; - a += lda; - } - - *c = (*c * beta) + (sum * alpha); - } - } - - } else { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const T* a = A + m; - const T* b = B + (n * ldb); - T* c = C + (m * ldc) + n; - T sum = 0.0f; - - for (size_t k = 0; k < K; k++) { - sum += (*b * *a); - b += 1; - a += lda; - } - - *c = (*c * beta) + (sum * alpha); - } - } - } - } - A += M * K; - B += K * N; - C += M * N; - } - } - - void ExecuteLong() override { - static const T multipliers[] = {0.0f, -0.0f, 0.25f, -0.5f, 1.0f, -1.0f}; - - for (size_t N = 1; N < 128; N++) { - for (size_t K = 1; K < 128; K++) { - for (size_t a = 0; a < _countof(multipliers); a++) { - for (size_t b = 0; b < _countof(multipliers); b++) { - Test(1, N, K, 1, multipliers[a], multipliers[b]); - Test(N, 1, K, 1, multipliers[a], multipliers[b]); - if (!Packed) { - Test(1, N, K, 3, multipliers[a], multipliers[b]); - } - } - } - } - } - - for (size_t a = 0; a < _countof(multipliers); a++) { - T alpha = multipliers[a]; - - for (size_t b = 0; b < _countof(multipliers); b++) { - T beta = multipliers[b]; - - for (size_t M = 16; M < 160; M += 32) { - for (size_t N = 16; N < 160; N += 32) { - static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; - for (size_t k = 0; k < _countof(ks); k++) { - size_t K = ks[k]; - - Test(M, N, K, 1, alpha, beta); - Test(M + 1, N, K, 1, alpha, beta); - Test(M, N + 1, K, 1, alpha, beta); - Test(M + 1, N + 1, K, 1, alpha, beta); - Test(M + 3, N + 2, K, 1, alpha, beta); - Test(M + 4, N, K, 1, alpha, beta); - Test(M, N + 4, K, 1, alpha, beta); - Test(M + 4, N + 4, K, 1, alpha, beta); - Test(M + 3, N + 7, K, 1, alpha, beta); - Test(M + 8, N, K, 1, alpha, beta); - Test(M, N + 8, K, 1, alpha, beta); - Test(M + 12, N + 12, K, 1, alpha, beta); - Test(M + 13, N, K, 1, alpha, beta); - Test(M, N + 15, K, 1, alpha, beta); - Test(M + 15, N + 15, K, 1, alpha, beta); - if (!Packed) { - Test(M + 3, N + 1, K, 7, multipliers[a], multipliers[b]); - Test(M + 13, N + 2, K, 9, multipliers[a], multipliers[b]); - } - } - } - printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(multipliers), b, _countof(multipliers), M); 
- } - } - } - - for (size_t M = 0; M < 160; M++) { - for (size_t N = 0; N < 160; N++) { - for (size_t K = 0; K < 160; K++) { - Test(M, N, K, 1, 1.0f, 0.0f); - } - } - printf("M %zd\n", M); - } - - for (size_t M = 160; M < 320; M += 24) { - for (size_t N = 112; N < 320; N += 24) { - for (size_t K = 0; K < 16; K++) { - Test(M, N, K, 1, 1.0f, 0.0f); - } - for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1, 1.0f, 0.0f); - } - } - printf("M %zd\n", M); - } - } - - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - FgemmPackedContext PackedContext; -}; diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h deleted file mode 100644 index 53b3edafdf84f..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ /dev/null @@ -1,80 +0,0 @@ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_fgemm.h" -#include -#include - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class FgemmShortExecuteTest : public MlasTestFixture> { - public: - explicit FgemmShortExecuteTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) - : trans_a_(trans_a), trans_b_(trans_b), M_(M), N_(N), K_(K), Batch_(BatchSize), alpha_(alpha), beta_(beta) { - } - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test( - trans_a_, trans_b_, M_, N_, K_, Batch_, alpha_, beta_); - } - - static size_t RegisterSingleTest(bool trans_a, bool trans_b, size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) { - std::stringstream ss; - ss << (trans_a ? "TransA" : "A") << "/" - << (trans_b ? "TransB" : "B") << "/" - << "BatchSize" << BatchSize << "/M" << M << "xN" << N << "xK" << K << "/" - << "Alpha" << alpha << "/" - << "Beta" << beta; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasFgemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new FgemmShortExecuteTest( - trans_a, trans_b, M, N, K, BatchSize, alpha, beta); - }); - return 1; - } - - static size_t RegisterTestTransposeABProduct(size_t M, size_t N, size_t K, size_t BatchSize, float alpha, float beta) { - return RegisterSingleTest(false, false, M, N, K, BatchSize, alpha, beta) + - RegisterSingleTest(false, true, M, N, K, BatchSize, alpha, beta) + - RegisterSingleTest(true, false, M, N, K, BatchSize, alpha, beta) + - RegisterSingleTest(true, true, M, N, K, BatchSize, alpha, beta); - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - for (size_t b = 0; b < 16; b++) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); - test_registered += RegisterTestTransposeABProduct(b, b, b, 3, 1.0f, 0.0f); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterTestTransposeABProduct(b, b, b, 1, 1.0f, 0.0f); - } - - test_registered += RegisterTestTransposeABProduct(128, 3072, 768, 1, 1.0f, 0.0f); - test_registered += RegisterTestTransposeABProduct(128, 768, 3072, 1, 1.0f, 0.0f); - test_registered += RegisterTestTransposeABProduct(25, 81, 79, 7, 1.0f, 0.0f); - return test_registered; - } - - private: - bool trans_a_, trans_b_; - const size_t M_, N_, K_, Batch_; - const T alpha_, beta_; -}; diff --git a/onnxruntime/test/mlas/unittest/test_fp16.h b/onnxruntime/test/mlas/unittest/test_fp16.h deleted file mode 100644 index 16636a54a0f66..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_fp16.h +++ /dev/null @@ -1,65 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_fp16.h - -Abstract: - - Define fp16 type before it is available in all compilers - ---*/ - -#pragma once - -#include "test_util.h" -#include "mlas_float16.h" - -// -// Define our own fp16 type to avoid dragging in big dependencies -// -struct MLFp16 { - uint16_t val{0}; - - MLFp16() = default; - explicit constexpr MLFp16(uint16_t x) : val(x) {} - explicit constexpr MLFp16(int32_t x) : val((uint16_t)x) {} - explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {} - - float ToFloat() const { - return MLAS_Half2Float(val); - } - - operator float() const { return ToFloat(); } - - MLFp16& operator=(float ff) { - val = MLAS_Float2Half(ff); - return *this; - } -}; - -inline bool -operator==(const MLFp16& left, const MLFp16& right) { - return left.val == right.val; -} - -inline bool -operator!=(const MLFp16& left, const MLFp16& right) { - return left.val != right.val; -} - -template -void SmallFloatFill(T* start, size_t size) { - constexpr float MinimumFillValue = -11.0f; - auto* FillAddress = start; - size_t offset = size % 23; - - for (size_t i = 0; i < size; i++) { - offset = (offset + 21) % 23; - *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); - } -} diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp deleted file mode 100644 index 969997d2b84ec..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "test_fp16.h" -#include - -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - -bool check_equal(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } else { - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; - } -} - -class MlasFp16ActivationTest : public MlasTestBase { - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("Fp16Activation"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - union AliasedValue { - unsigned u; - float f; - }; - - // N.B. The test data includes values at the edge of Tanh/Logistic boundaries. - static const AliasedValue TestData[] = { - {0x00000001}, // positive denormal - {0x80000001}, // negative denormal - {0x7fc00000}, // positive NaN - {0xffc00000}, // negative NaN - {0x00000000}, // 0.0f - {0x80000000}, // -0.0f - {0x3e800000}, // 0.25f - {0xbe800000}, // -0.25f - {0x40800000}, // 4.0f - {0xc0800000}, // -4.0f - {0x41200000}, // 10.0f - {0xc1200000}, // -10.0f - {0xc18866eb}, // -17.0502529144f - {0xc18869bb}, // -17.0516262054f - {0xc18852a8}, // -17.0403594971f - {0xc18844aa}, // -17.0335273743f - {0x418866eb}, // +17.0502529144f - {0x418869bb}, // +17.0516262054f - {0x418852a8}, // +17.0403594971f - {0x418844aa} // +17.0335273743f - }; - - constexpr size_t M = 5; - constexpr size_t N = 23; - constexpr float MinimumFillValue = -11.0f; - - MatrixGuardBuffer HalfBuffer1; - auto* testData1 = HalfBuffer1.GetBuffer(M * N, true); - MatrixGuardBuffer HalfBuffer2; - auto* testData2 = HalfBuffer2.GetBuffer(M * N, true); - MatrixGuardBuffer HalfBuffer3; - auto* testData3 = HalfBuffer3.GetBuffer(M * N, true); - MatrixGuardBuffer AddonBuffer; - auto addonData = AddonBuffer.GetBuffer(M * N, true); - MatrixGuardBuffer FloatBuffer; - auto* fpBuffer = FloatBuffer.GetBuffer(M * N, true); - MatrixGuardBuffer FloatBuffer1; - auto* fpAddBuffer = FloatBuffer1.GetBuffer(M * N, true); - - size_t o = 3; - for (size_t i = 0; i < M * N; i++) { - o = (o + 19) % 23; - addonData[i] = (MinimumFillValue + o) / 16.0f; - } - - MLAS_ACTIVATION_KIND acts[] = { - MlasIdentityActivation, - MlasReluActivation, - MlasLeakyReluActivation, - MlasTanhActivation, - MlasLogisticActivation, - MlasClipActivation, - MlasHardSigmoidActivation}; - - MLAS_ACTIVATION Activation; - MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation, nullptr); - MLAS_HALF_GEMM_ACTIVATION_PROCESSOR addon(Activation, reinterpret_cast(addonData)); - for (auto kind : acts) { - Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind); - - if (Activation.ActivationKind == MlasLeakyReluActivation) { - Activation.Parameters.LeakyRelu.alpha = 0.2f; - } else if (Activation.ActivationKind == MlasClipActivation) { - Activation.Parameters.Clip.minimum = 0.0f; - Activation.Parameters.Clip.maximum = 6.0f; - } else if (Activation.ActivationKind == MlasHardSigmoidActivation) { - Activation.Parameters.HardSigmoid.alpha = 0.2f; - Activation.Parameters.HardSigmoid.beta = 0.12f; - } - - // - // Test the vectorized activations. 
- // - - for (size_t i = 0; i < _countof(TestData); i++) { - testData1[i] = TestData[i].f; - testData2[i] = TestData[i].f; - testData3[i] = TestData[i].f; - fpBuffer[i] = TestData[i].f; - fpAddBuffer[i] = TestData[i].f + addonData[i].ToFloat(); - } - size_t offset = 7; - for (size_t i = _countof(TestData); i < M * N; i++) { - offset = (offset + 19) % 23; - float f = (MinimumFillValue + offset) / 16.0f; - testData1[i] = f; - testData2[i] = testData1[i]; - testData3[i] = testData1[i]; - fpBuffer[i] = f; - fpAddBuffer[i] = f + addonData[i].ToFloat(); - } - - proc.Process(reinterpret_cast(testData1), 0, 0, M, N, N); - MlasActivation(&Activation, fpBuffer, nullptr, M, N, N); - MlasActivation(&Activation, fpAddBuffer, nullptr, M, N, N); - addon.Process(reinterpret_cast(testData3), 0, 0, M, N, N); - - for (size_t i = 0; i < M * N; i++) { - float actual = testData1[i].ToFloat(); - EXPECT_TRUE(check_equal(actual, fpBuffer[i])) - << ", Vector Activation Kind:" << (int)kind << ", i=" << i << ", value:" - << std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:" - << std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i]; - - float addonActual = testData3[i].ToFloat(); - EXPECT_TRUE(check_equal(addonActual, fpAddBuffer[i])) - << ", Vector + Activation Kind:" << (int)kind << ", i=" << i << ", value:" - << std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:" - << std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i]; - } - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; -}); - -#endif // fp16 vector intrinsic supported diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp deleted file mode 100644 index aafdcc14c0028..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ /dev/null @@ -1,168 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_halfgemm.cpp - -Abstract: - - Tests for MLAS half precision GEMM. - ---*/ - -#include "test_halfgemm.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class HalfGemmShortExecuteTest : public MlasTestFixture> { - public: - explicit HalfGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) - : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { - std::stringstream ss; - ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" - << "hasBias" << hasBias; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasHalfGemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new HalfGemmShortExecuteTest( - M, N, K, Batch, hasBias); - }); - - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 1, false); - test_registered += RegisterSingleTest(b, b, b, 1, true); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, 1, false); - test_registered += RegisterSingleTest(b, b, b, 1, true); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, 1, true); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, 1, false); - test_registered += RegisterSingleTest(1, 32, b, 1, true); - test_registered += RegisterSingleTest(1, b, b, 1, false); - if (!Packed) { - test_registered += RegisterSingleTest(1, b, 32, 3, true); - test_registered += RegisterSingleTest(1, 32, b, 5, false); - } - } - test_registered += RegisterSingleTest(43, 500, 401, 1, true); - // test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); - if (!Packed) { - test_registered += RegisterSingleTest(43, 500, 401, 5, true); - // test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); - } - - return test_registered; - } - - private: - size_t M_, N_, K_, Batch_; - bool hasBias_; -}; - -static size_t HalfGemmRegistLongExecute() { - size_t count = 0; - - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasHalfGemmPackBSize(128, 128, false) > 0) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - if (MlasHalfGemmPackBSize(128, 128, true) > 0) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasHalfGemmPackBSize(128, 128, false) > 0) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - if (MlasHalfGemmPackBSize(128, 128, true) > 0) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - } - - return count; -} - -static size_t HalfGemmRegistShortExecute() { - size_t count = 0; - - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasHalfGemmPackBSize(128, 128, false) > 0) { - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasHalfGemmPackBSize(128, 128, true) > 0) { - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - } - - if (GetMlasThreadPool() != nullptr) { - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += 
HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasHalfGemmPackBSize(128, 128, false) > 0) { - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasHalfGemmPackBSize(128, 128, true) > 0) { - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); - } - } - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (!MlasFp16AccelerationSupported()) { - return false; - } - if (is_short_execute) { - return HalfGemmRegistShortExecute() > 0; - } - return HalfGemmRegistLongExecute() > 0; -}); diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h deleted file mode 100644 index 4db5c2bebca40..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ /dev/null @@ -1,271 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_halfgemm.h - -Abstract: - - Tests for MLAS half precision GEMM. - ---*/ - -#pragma once - -#include "test_fp16.h" - -/** - * @brief Test class for half precision GEMM - * @tparam AType Data type of A matrix, can be either float or MLFp16 - * @tparam BType Data type of b matrix, can be either float or MLFp16 - */ -template -class MlasHalfGemmTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferBPacked; - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferBias; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - MatrixGuardBuffer BufferFloatC; - MLAS_THREADPOOL* threadpool_; - - void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { - size_t PackedBSize = MlasHalfGemmPackBSize(N, K, std::is_same::value); - if (PackedBSize == 0) { - return nullptr; - } - void* PackedB = BufferBPacked.GetBuffer(PackedBSize); - if (std::is_same::value) { - MlasHalfGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); - } else { - MlasHalfGemmPackB(N, K, (const MLAS_FP16*)B, ldb, PackedB); - } - return PackedB; - } - - void CallGemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const AType* A, - size_t lda, - const BType* B, - size_t ldb, - const MLFp16* Bias, - MLFp16* C, - size_t ldc, - float* Cfloat) { - MLAS_ACTIVATION act; - act.ActivationKind = MlasIdentityActivation; - std::vector Converters; - Converters.reserve(BatchSize); - - std::vector GemmParameters(BatchSize); - - for (size_t i = 0; i < GemmParameters.size(); i++) { - auto& params = GemmParameters[i]; - params.A = A + (M * lda * i); - params.lda = lda; - if (nullptr != Bias) { - params.Bias = reinterpret_cast(Bias + N * i); - } else { - params.Bias = nullptr; - } - params.C = reinterpret_cast(C + (M * ldc * i)); - params.ldc = ldc; - - if (Packed) { - ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; - params.B = PackB(N, K, B, ldb); - params.ldb = 0; - } else { - params.B = B + (K * N * i); - params.ldb = ldb; - } - params.AIsfp32 = std::is_same::value; - params.BIsfp32 = std::is_same::value; - Converters.emplace_back(act, Cfloat + (M * N * i), N); - params.OutputProcessor = &(Converters[i]); - } - - MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); - } - - void ReferenceQgemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const AType* A, - const BType* B, 
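// A hedged sketch of the capability-gating idiom used by the registration helpers
// above: packed and threaded variants are registered only when the current
// build/CPU supports them, probed by asking MLAS for a pack size (0 means
// "unsupported") and by checking for a thread pool. The Register*Tests helpers
// here are hypothetical stand-ins for the real per-suite registration functions.
#include "mlas.h"
#include <cstddef>

static size_t RegisterPlainTests() { return 1; }
static size_t RegisterPackedTests() { return 1; }
static size_t RegisterThreadedTests() { return 1; }

static size_t RegisterHalfGemmSuites(MLAS_THREADPOOL* pool) {
  size_t count = RegisterPlainTests();
  if (MlasHalfGemmPackBSize(128, 128, /*float2half=*/false) > 0) {
    count += RegisterPackedTests();     // packing is supported on this build/CPU
  }
  if (pool != nullptr) {
    count += RegisterThreadedTests();   // a thread pool is available
  }
  return count;
}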
- const MLFp16* Bias, - float* C) { - // TODO!! deal with half precision accumulation error - // Most CPUs does not support mixed precision accumulation, - // only mul & add fuse. As a result, different striding - // on the K dimension may lead to rounding error. - // Accumulation of these rounding error maybe very significant. - // So setting a approximation ratio does NOT work. - // - // Currently this test require a manual efforts: - // 1. Change the K stride of the kernel under test to be 16; - // 2. Force the K stride of the fp16 kernel to 16 - // 3. Change the test oracle to be exact match. - // 4. Pass this test and then change it back :-(. - // - constexpr size_t KStride = 512; - - for (size_t batch = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const AType* a = A + M * K * batch + m * K; - const BType* b = B + K * N * batch + n; - float* c = C + (M * N * batch) + (m * N) + n; - - for (size_t k = 0; k < K; k += KStride) { - float sum = 0.0f; - if (k == 0 && Bias != nullptr) { - sum = float(Bias[n]); - } - for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { - MLFp16 down(float(*b) * float(*a) + sum); - sum = float(down); - b += N; - a += 1; - } - if (k == 0) { - *c = sum; - } else { - MLFp16 d(sum + *c); - *c = float(d); - } - } - } - } - if (Bias) { - Bias += N; - } - } - } - - public: - MlasHalfGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { - const AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); - AType Atail[16]; - std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); - - const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); - BType Btail[16]; - std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); - - MLFp16 BiasTail[16]; - const MLFp16* Bias = nullptr; - if (withBias) { - Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); - std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16)); - } - - MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); - float* Cfloat = BufferFloatC.GetBuffer(N * M * BatchSize, true); - float* CReference = BufferCReference.GetFilledBuffer( - N * M * BatchSize, - [](float* start, size_t size) { - std::fill_n(start, size, -1.0f); - }); - - this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, Cfloat); - ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference); - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_TRUE(CloseEnough(float(C[f]), CReference[f])) << "@[" << batch << "x" << m << "x" << n << "], " - << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K; - ASSERT_TRUE(CloseEnough(Cfloat[f], CReference[f])) << "Converted@[" << batch << "x" << m << "x" << n << "], " - << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K; - } - } - } - ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; - ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; - if (withBias) { - ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16)), 0) << "Bias buffer overwritten!"; - } - } - - private: - public: - static const char* GetTestSuiteName() { - static std::string suite_name = std::string("HalfGemmFP") + - (std::is_same::value ? 
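// A portable illustration (float standing in for fp16) of the accumulation-order
// issue described in the TODO inside ReferenceQgemm above: when every partial sum
// is rounded to the narrow type, accumulating with different K strides produces
// different low-order bits, so the kernel and the reference only match exactly
// when they use the same stride. The numbers below are purely illustrative.
#include <algorithm>
#include <cstdio>
#include <vector>

static float ChunkedSum(const std::vector<float>& v, size_t stride) {
  float total = 0.0f;
  for (size_t k = 0; k < v.size(); k += stride) {
    float partial = 0.0f;  // rounded to float after every add, like the fp16 path
    for (size_t kk = k; kk < std::min(v.size(), k + stride); ++kk) {
      partial += v[kk];
    }
    total += partial;
  }
  return total;
}

int main() {
  std::vector<float> v(10000, 0.101f);  // values whose sum is not exactly representable
  std::printf("stride 16 : %.8f\n", ChunkedSum(v, 16));
  std::printf("stride 512: %.8f\n", ChunkedSum(v, 512));  // typically differs in the low bits
  return 0;
}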
"32" : "16") + - (std::is_same::value ? "32" : "16") + - (Packed ? "_Packed" : "_NoPack") + - (Threaded ? "_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void ExecuteLong(void) override { - for (size_t M = 16; M < 160; M += 32) { - for (size_t N = 16; N < 160; N += 32) { - static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; - for (size_t k = 0; k < _countof(ks); k++) { - size_t K = ks[k]; - - Test(M, N, K, 1, false); - Test(M, N, K, 1, true); - Test(M + 1, N, K, 1, false); - Test(M, N + 1, K, 1, true); - Test(M + 1, N + 1, K, 1, false); - Test(M + 3, N + 2, K, 1, true); - Test(M + 4, N, K, 1, false); - Test(M, N + 4, K, 1, true); - Test(M + 4, N + 4, K, 1, false); - Test(M + 3, N + 7, K, 1, true); - Test(M + 8, N, K, 1, false); - Test(M, N + 8, K, 1, true); - Test(M + 12, N + 12, K, 1, false); - Test(M + 13, N, K, 1, true); - Test(M, N + 15, K, 1, false); - Test(M + 15, N + 15, K, 1, false); - if (!Packed) { - Test(M, N, K, 7, false); - Test(M + 3, N, K, 8, true); - Test(M, N + 1, K, 9, false); - Test(M + 12, N, K, 10, true); - Test(M, N + 15, K, 11, false); - Test(M + 15, N + 15, K, 12, true); - } - } - } - printf("M %zd\n", M); - } - - for (size_t M = 1; M < 160; M++) { - for (size_t N = 1; N < 160; N++) { - for (size_t K = 1; K < 160; K++) { - Test(M, N, K, 1, true); - } - } - printf("M %zd\n", M); - } - - for (size_t M = 160; M < 320; M += 24) { - for (size_t N = 112; N < 320; N += 24) { - for (size_t K = 1; K < 16; K++) { - Test(M, N, K, 1, true); - } - for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1, false); - } - } - printf("M %zd\n", M); - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_main.cpp b/onnxruntime/test/mlas/unittest/test_main.cpp deleted file mode 100644 index 505c0c01dfa90..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_main.cpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "test_util.h" - -#if !defined(BUILD_MLAS_NO_ONNXRUNTIME) - -MLAS_THREADPOOL* GetMlasThreadPool(void) { - static auto threadpool = std::make_unique( - &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 2, true); - return threadpool.get(); -} - -#else - -MLAS_THREADPOOL* GetMlasThreadPool(void) { - return nullptr; -} - -#endif - -// Singleton to avoid initialization order impact. 
-class LongShortExecuteManager { - public: - static LongShortExecuteManager& instance(void) { - static LongShortExecuteManager s_instance; - return s_instance; - }; - - void AddTestRegister(TestRegister test_register) { - test_registers_.push_back(test_register); - } - - size_t RegisterAll(bool is_short_execute) { - size_t count = 0; - for (const auto& r : instance().test_registers_) { - count += r(is_short_execute); - } - return count; - } - - private: - LongShortExecuteManager() : test_registers_() {} - LongShortExecuteManager(const LongShortExecuteManager&) = delete; - LongShortExecuteManager& operator=(const LongShortExecuteManager&) = delete; - - std::list test_registers_; -}; - -bool AddTestRegister(TestRegister test_register) { - LongShortExecuteManager::instance().AddTestRegister(test_register); - return true; -} - -int main(int argc, char** argv) { - bool is_short_execute = (argc <= 1 || strcmp("--long", argv[1]) != 0); - std::cout << "-------------------------------------------------------" << std::endl; - if (is_short_execute) { - std::cout << "----Running normal quick check mode. To enable more complete test," << std::endl; - std::cout << "---- run with '--long' as first argument!" << std::endl; - } - auto test_count = LongShortExecuteManager::instance().RegisterAll(is_short_execute); - std::cout << "----Total " << test_count << " tests registered programmably!" << std::endl; - std::cout << "-------------------------------------------------------" << std::endl; - - ::testing::InitGoogleTest(&argc, argv); - - return RUN_ALL_TESTS(); -} diff --git a/onnxruntime/test/mlas/unittest/test_minmax.cpp b/onnxruntime/test/mlas/unittest/test_minmax.cpp deleted file mode 100644 index 245879deccffd..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_minmax.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -class MlasFindMinMaxElementsTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - - void Test(size_t N, float MinimumValue, float MaximumValue) { - float* Input = BufferInput.GetBuffer(N); - - std::default_random_engine generator(static_cast(N)); - std::uniform_real_distribution distribution(MinimumValue, MaximumValue); - - for (size_t n = 0; n < N; n++) { - Input[n] = distribution(generator); - } - - auto min_max_pair = std::minmax_element(Input, Input + N); - float min_ref = *min_max_pair.first; - float max_ref = *min_max_pair.second; - - float min, max; - MlasFindMinMaxElement(Input, &min, &max, N); - - constexpr float epsilon = 1e-6f; - - float diff_min = std::fabs(min - min_ref); - ASSERT_LE(diff_min, epsilon) << " for minimum with parameter (" << N << "," << MinimumValue << "," << MaximumValue << ")"; - - float diff_max = std::fabs(max - max_ref); - ASSERT_LE(diff_max, epsilon) << " for maximum with parameter (" << N << "," << MinimumValue << "," << MaximumValue << ")"; - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("MinMaxElement"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - for (size_t n = 1; n < 128; n++) { - Test(n, -10.f, 10.f); - } - } -}; - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? 
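// A standalone sketch of the self-registration idiom that test_main.cpp implements
// above: each test file initializes a file-scope bool by calling AddTestRegister(),
// which stores a callback in a function-local ("Meyers") singleton, so nothing
// depends on global construction order; main() later replays every callback.
// Names and the std::function signature are illustrative, not the exact ORT types.
#include <cstdio>
#include <functional>
#include <list>

using TestRegisterFn = std::function<size_t(bool /*is_short_execute*/)>;

class Registry {
 public:
  static Registry& Instance() {  // constructed on first use
    static Registry instance;
    return instance;
  }
  void Add(TestRegisterFn fn) { fns_.push_back(std::move(fn)); }
  size_t RunAll(bool is_short) {
    size_t count = 0;
    for (auto& fn : fns_) count += fn(is_short);
    return count;
  }

 private:
  Registry() = default;
  std::list<TestRegisterFn> fns_;
};

static bool AddTestRegister(TestRegisterFn fn) {
  Registry::Instance().Add(std::move(fn));
  return true;
}

// In each test translation unit: runs during static initialization, before main().
[[maybe_unused]] static bool registered = AddTestRegister([](bool is_short) -> size_t {
  return is_short ? 1 : 0;  // the real callbacks register gtest cases here
});

int main() {
  std::printf("registered %zu tests\n", Registry::Instance().RunAll(true));
  return 0;
}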
MlasDirectShortExecuteTests::RegisterShortExecute() : 0; -}); diff --git a/onnxruntime/test/mlas/unittest/test_pool2d.cpp b/onnxruntime/test/mlas/unittest/test_pool2d.cpp deleted file mode 100644 index 8cefb8332ec32..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool2d.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_pool2d.h" -#include "test_pool2d_fixture.h" - -static size_t Pool2dRegistLongExecute() { - size_t count = 0; - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - return count; -} - -static size_t Pool2dRegistShortExecute() { - size_t count = 0; - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - if (GetMlasThreadPool() != nullptr) { - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - } - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? Pool2dRegistShortExecute() : Pool2dRegistLongExecute(); -}); diff --git a/onnxruntime/test/mlas/unittest/test_pool2d.h b/onnxruntime/test/mlas/unittest/test_pool2d.h deleted file mode 100644 index ebb1f256ae507..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool2d.h +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_util.h" - -template -class MlasPool2DTest : public MlasTestBase { - public: - MlasPool2DTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - static const char* GetTestSuiteName() { - static std::string suite_name = - std::string(PoolingKind == MlasMaximumPooling - ? "Pool2dMax" - : (PoolingKind == MlasAveragePoolingExcludePad ? "Pool2dAverageExcludePad" : "Pool2dAverageIncludePad")) + - (Threaded ? 
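// A brief sketch of the oracle pattern the min/max test above relies on: the
// vectorized MLAS routine is checked against std::minmax_element with a small
// absolute tolerance. Assumes mlas.h declares MlasFindMinMaxElement with the
// signature used above; data must be non-empty.
#include "mlas.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

void CheckMinMax(const std::vector<float>& data) {
  auto ref = std::minmax_element(data.begin(), data.end());
  float min_val = 0.0f, max_val = 0.0f;
  MlasFindMinMaxElement(data.data(), &min_val, &max_val, data.size());
  constexpr float kEps = 1e-6f;
  assert(std::fabs(min_val - *ref.first) <= kEps);
  assert(std::fabs(max_val - *ref.second) <= kEps);
}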
"_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void Test(size_t BatchCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t StrideHeight, - size_t StrideWidth) { - constexpr size_t DilationHeight = 1; - constexpr size_t DilationWidth = 1; - - int64_t OutputHeight64 = - ((int64_t(InputHeight) + int64_t(PaddingLeftHeight) + int64_t(PaddingRightHeight)) - - (int64_t(DilationHeight) * (int64_t(KernelHeight) - 1) + 1)) / - int64_t(StrideHeight) + - 1; - int64_t OutputWidth64 = - ((int64_t(InputWidth) + int64_t(PaddingLeftWidth) + int64_t(PaddingRightWidth)) - - (int64_t(DilationWidth) * (int64_t(KernelWidth) - 1) + 1)) / - int64_t(StrideWidth) + - 1; - - if (OutputHeight64 <= 0 || OutputWidth64 <= 0) { - return; - } - - int64_t InputShape[] = {int64_t(BatchCount), int64_t(InputChannels), int64_t(InputHeight), int64_t(InputWidth)}; - int64_t KernelShape[] = {int64_t(KernelHeight), int64_t(KernelWidth)}; - int64_t Padding[] = {int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)}; - int64_t StrideShape[] = {int64_t(StrideHeight), int64_t(StrideWidth)}; - int64_t OutputShape[] = {int64_t(BatchCount), int64_t(InputChannels), OutputHeight64, OutputWidth64}; - - size_t InputBufferElements = size_t(InputShape[0] * InputShape[1] * InputShape[2] * InputShape[3]); - size_t OutputBufferElements = size_t(OutputShape[0] * OutputShape[1] * OutputShape[2] * OutputShape[3]); - - const float* Input = BufferInput.GetBuffer(InputBufferElements); - float* Output = BufferOutput.GetBuffer(OutputBufferElements); - float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements); - - MlasPool2D(InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output); - if constexpr (PoolingKind == MlasMaximumPooling) { - ReferenceMaximumPool2D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference); - } else if constexpr (PoolingKind == MlasAveragePoolingExcludePad) { - ReferenceAveragePool2D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, false); - } else if constexpr (PoolingKind == MlasAveragePoolingIncludePad) { - ReferenceAveragePool2D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, true); - } - - ASSERT_EQ(memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)), 0) - << "PoolingKind:" << int(PoolingKind) << " " - << "input(" << InputChannels << "," << InputHeight << ", " << InputWidth << "), " - << "Kernel(" << KernelHeight << "," << KernelWidth << ")"; - } - - protected: - virtual void MlasPool2D(const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - const float* Input, - float* Output) { - MlasPool(PoolingKind, 2, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output, threadpool_); - } - - void ReferenceMaximumPool2D(const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const float* Input, - float* Output) { - int64_t ChannelCount = InputShape[0] * InputShape[1]; - - int64_t InputHeight = InputShape[2]; - int64_t InputWidth = InputShape[3]; - - int64_t KernelHeight = KernelShape[0]; - int64_t KernelWidth = KernelShape[1]; - - int64_t PaddingLeftY = Padding[0]; - int64_t PaddingLeftX = 
Padding[1]; - int64_t PaddingRightY = Padding[2]; - int64_t PaddingRightX = Padding[3]; - - int64_t StrideHeight = StrideShape[0]; - int64_t StrideWidth = StrideShape[1]; - - int64_t OutputHeight = (InputHeight + PaddingLeftY + PaddingRightY - KernelHeight) / StrideHeight + 1; - int64_t OutputWidth = (InputWidth + PaddingLeftX + PaddingRightX - KernelWidth) / StrideWidth + 1; - - for (int64_t c = 0; c < ChannelCount; c++) { - for (int64_t ph = 0; ph < OutputHeight; ph++) { - int64_t ihStart = ph * StrideHeight - PaddingLeftY; - int64_t ihEnd = ihStart + KernelHeight; - - ihStart = (std::max)(ihStart, int64_t(0)); - ihEnd = (std::min)(ihEnd, InputHeight); - - for (int64_t pw = 0; pw < OutputWidth; pw++) { - int64_t iwStart = pw * StrideWidth - PaddingLeftX; - int64_t iwEnd = iwStart + KernelWidth; - - iwStart = (std::max)(iwStart, int64_t(0)); - iwEnd = (std::min)(iwEnd, InputWidth); - - float m = std::numeric_limits::lowest(); - - for (int64_t ih = ihStart; ih < ihEnd; ih++) { - for (int64_t iw = iwStart; iw < iwEnd; iw++) { - m = (std::max)(m, Input[ih * InputWidth + iw]); - } - } - - Output[ph * OutputWidth + pw] = m; - } - } - - Input += InputHeight * InputWidth; - Output += OutputHeight * OutputWidth; - } - } - - void ReferenceAveragePool2D(const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const float* Input, - float* Output, - bool CountIncludePad) { - int64_t ChannelCount = InputShape[0] * InputShape[1]; - - int64_t InputHeight = InputShape[2]; - int64_t InputWidth = InputShape[3]; - - int64_t KernelHeight = KernelShape[0]; - int64_t KernelWidth = KernelShape[1]; - - int64_t PaddingLeftY = Padding[0]; - int64_t PaddingLeftX = Padding[1]; - int64_t PaddingRightY = Padding[2]; - int64_t PaddingRightX = Padding[3]; - - int64_t StrideHeight = StrideShape[0]; - int64_t StrideWidth = StrideShape[1]; - - int64_t OutputHeight = (InputHeight + PaddingLeftY + PaddingRightY - KernelHeight) / StrideHeight + 1; - int64_t OutputWidth = (InputWidth + PaddingLeftX + PaddingRightX - KernelWidth) / StrideWidth + 1; - - for (int64_t c = 0; c < ChannelCount; c++) { - for (int64_t ph = 0; ph < OutputHeight; ph++) { - int64_t ihStart = ph * StrideHeight - PaddingLeftY; - int64_t ihEnd = ihStart + KernelHeight; - - ihStart = (std::max)(ihStart, int64_t(0)); - ihEnd = (std::min)(ihEnd, InputHeight); - - for (int64_t pw = 0; pw < OutputWidth; pw++) { - int64_t iwStart = pw * StrideWidth - PaddingLeftX; - int64_t iwEnd = iwStart + KernelWidth; - - iwStart = (std::max)(iwStart, int64_t(0)); - iwEnd = (std::min)(iwEnd, InputWidth); - - float m = 0.0f; - - for (int64_t ih = ihStart; ih < ihEnd; ih++) { - for (int64_t iw = iwStart; iw < iwEnd; iw++) { - m += Input[ih * InputWidth + iw]; - } - } - - if (CountIncludePad) { - m /= (KernelHeight * KernelWidth); - } else { - m /= (ihEnd - ihStart) * (iwEnd - iwStart); - } - - Output[ph * OutputWidth + pw] = m; - } - } - - Input += InputHeight * InputWidth; - Output += OutputHeight * OutputWidth; - } - } - - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - MLAS_THREADPOOL* threadpool_; - - public: - void ExecuteLong(void) override { - static const unsigned is[] = {53, 17, 11, 5, 4, 3, 2, 1}; - - for (unsigned i = 1; i < 2058; i++) { - Test(1, 1, 4, i, 2, 4, 0, 2, 0, 1, 1, 1); - } - - for (unsigned ih = 0; ih < _countof(is); ih++) { - for (unsigned iw = 0; iw < _countof(is); iw++) { - fprintf(stderr, "Handling %ux%u\n", is[ih], is[iw]); - Test(1, 
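// A small numeric illustration of the CountIncludePad branch in
// ReferenceAveragePool2D above. 1-D example: input {1, 2, 3}, kernel 3, pad 1,
// stride 1. The first window covers {pad, 1, 2}, whose real elements sum to 3:
//   include-pad average : 3 / 3 = 1.0  (divide by the full kernel area)
//   exclude-pad average : 3 / 2 = 1.5  (divide by (ihEnd - ihStart) * (iwEnd - iwStart))
constexpr float kWindowSum = 1.0f + 2.0f;
constexpr float kIncludePad = kWindowSum / 3.0f;
constexpr float kExcludePad = kWindowSum / 2.0f;
static_assert(kIncludePad == 1.0f && kExcludePad == 1.5f,
              "the divisor choice alone changes the border outputs");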
1, is[ih], is[iw], is[ih], is[iw], 0, 0, 0, 0, 1, 1); - Test(1, 1, is[ih], is[iw], is[ih], 1, 0, 0, 0, 0, 1, 1); - Test(1, 1, is[ih], is[iw], 1, is[iw], 0, 0, 0, 0, 1, 1); - for (unsigned kh = 1; kh <= 5; kh++) { - if (kh > is[ih]) break; - for (unsigned kw = 1; kw <= 5; kw++) { - if (kw > is[iw]) break; - for (unsigned sh = 1; sh <= 3; sh++) { - for (unsigned sw = 1; sw <= 3; sw++) { - for (unsigned p0 = 0; p0 < kh; p0++) { - for (unsigned p1 = 0; p1 < kw; p1++) { - for (unsigned p2 = 0; p2 < kh; p2++) { - for (unsigned p3 = 0; p3 < kw; p3++) { - Test(5, 3, is[ih], is[iw], kh, kw, p0, p1, p2, p3, sh, sw); - } - } - } - } - } - } - } - } - } - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_pool2d_fixture.h b/onnxruntime/test/mlas/unittest/test_pool2d_fixture.h deleted file mode 100644 index cb748bbaccce0..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool2d_fixture.h +++ /dev/null @@ -1,135 +0,0 @@ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_pool2d.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class Pooling2dShortExecuteTest : public MlasTestFixture { - public: - explicit Pooling2dShortExecuteTest(size_t BatchCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t StrideHeight, - size_t StrideWidth) - : BatchCount_(BatchCount), - InputChannels_(InputChannels), - InputHeight_(InputHeight), - InputWidth_(InputWidth), - KernelHeight_(KernelHeight), - KernelWidth_(KernelWidth), - PaddingLeftHeight_(PaddingLeftHeight), - PaddingLeftWidth_(PaddingLeftWidth), - PaddingRightHeight_(PaddingRightHeight), - PaddingRightWidth_(PaddingRightWidth), - StrideHeight_(StrideHeight), - StrideWidth_(StrideWidth) { - } - - void TestBody() override { - MlasTestFixture::mlas_tester->Test( - BatchCount_, - InputChannels_, - InputHeight_, - InputWidth_, - KernelHeight_, - KernelWidth_, - PaddingLeftHeight_, - PaddingLeftWidth_, - PaddingRightHeight_, - PaddingRightWidth_, - StrideHeight_, - StrideWidth_); - } - - static size_t RegisterSingleTest(size_t BatchCount, - size_t InputChannels, - size_t InputHeight, - size_t InputWidth, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t StrideHeight, - size_t StrideWidth) { - std::stringstream ss; - ss << "B" << BatchCount << "/" - << "C" << InputChannels << "/" - << "H" << InputHeight << "/" - << "W" << InputWidth << "/" - << "KH" << KernelHeight << "/" - << "KW" << KernelWidth << "/" - << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" - << "Stride" << StrideHeight << "," << StrideWidth; - auto test_name = ss.str(); - - testing::RegisterTest( - Pool2DTester::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture* { - return new Pooling2dShortExecuteTest( - BatchCount, - InputChannels, - InputHeight, - InputWidth, - KernelHeight, - KernelWidth, - PaddingLeftHeight, - PaddingLeftWidth, - PaddingRightHeight, - PaddingRightWidth, - StrideHeight, - StrideWidth); - }); - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (unsigned i = 1; i < 256; i <<= 1) { - test_registered += RegisterSingleTest(1, 16, i, i, 3, 3, 0, 0, 0, 0, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, 3, 3, 0, 0, 0, 0, 2, 2); - test_registered += RegisterSingleTest(1, 16, i, i, 3, 3, 0, 0, 0, 0, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, 3, 3, 1, 1, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, 1, 1, 0, 0, 0, 0, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, i, 1, 0, 0, 0, 0, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, 1, i, 0, 0, 0, 0, 1, 1); - } - - return test_registered; - } - - private: - size_t BatchCount_; - size_t InputChannels_; - size_t InputHeight_; - size_t InputWidth_; - size_t KernelHeight_; - size_t KernelWidth_; - size_t PaddingLeftHeight_; - size_t PaddingLeftWidth_; - size_t PaddingRightHeight_; - size_t PaddingRightWidth_; - size_t StrideHeight_; - size_t StrideWidth_; -}; diff --git a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp deleted file mode 100644 index bee690b10b737..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_pool2d_nchwc.h" -#include "test_pool2d_fixture.h" - -static size_t Pool2dNchwcRegistLongExecute() { - size_t count = 0; - if (MlasNchwcGetBlockSize() > 1) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - } - return count; -} - -static size_t Pool2dNchwcRegistShortExecute() { - size_t count = 0; - if (MlasNchwcGetBlockSize() > 1) { - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - if (GetMlasThreadPool() != nullptr) { - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - count += Pooling2dShortExecuteTest>::RegisterShortExecuteTests(); - } - } - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? Pool2dNchwcRegistShortExecute() : Pool2dNchwcRegistLongExecute(); -}); diff --git a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h deleted file mode 100644 index 38ac63a68c843..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "test_pool2d.h" - -template -class MlasNchwcPool2DTest : public MlasPool2DTest { - public: - static const char* GetTestSuiteName() { - static std::string suite_name = - std::string(PoolingKind == MlasMaximumPooling - ? "Pool2dNchwcMax" - : (PoolingKind == MlasAveragePoolingExcludePad ? "Pool2dNchwcAverageExcludePad" : "Pool2dNchwcAverageIncludePad")) + - (Threaded ? "_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - protected: - void MlasPool2D( - const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const int64_t* OutputShape, - const float* Input, - float* Output) override { - size_t NchwcChannels = (size_t(InputShape[1]) + BlockSize - 1) & ~(BlockSize - 1); - - int64_t NchwcInputShape[] = {InputShape[0], int64_t(NchwcChannels), InputShape[2], InputShape[3]}; - size_t NchwcInputElements = size_t(NchwcInputShape[0]) * size_t(NchwcInputShape[1]) * size_t(NchwcInputShape[2]) * size_t(NchwcInputShape[3]); - float* NchwcInput = BufferNchwcInput.GetBuffer(NchwcInputElements); - - int64_t NchwcOutputShape[] = {OutputShape[0], int64_t(NchwcChannels), OutputShape[2], OutputShape[3]}; - size_t NchwcOutputElements = size_t(NchwcOutputShape[0]) * size_t(NchwcOutputShape[1]) * size_t(NchwcOutputShape[2]) * size_t(NchwcOutputShape[3]); - float* NchwcOutput = BufferNchwcOutput.GetBuffer(NchwcOutputElements); - - ReorderInputNchw(InputShape, Input, NchwcInput); - - MlasNchwcPool(PoolingKind, - NchwcInputShape, - KernelShape, - nullptr, - Padding, - StrideShape, - NchwcOutputShape, - NchwcInput, - NchwcOutput, - nullptr); - - MlasReorderOutputNchw(OutputShape, NchwcOutput, Output, nullptr); - } - - MatrixGuardBuffer BufferNchwcInput; - MatrixGuardBuffer BufferNchwcOutput; - - const size_t BlockSize = MlasNchwcGetBlockSize(); - - public: - void ExecuteLong(void) override { - static const unsigned is[] = {53, 11, 1}; - - for (unsigned ih = 0; ih < _countof(is); ih++) { - for (unsigned iw = 0; iw < _countof(is); iw++) { - fprintf(stderr, "Handling %ux%u\n", is[ih], is[iw]); - MlasPool2DTest::Test(1, 12, is[ih], is[iw], is[ih], is[iw], 0, 0, 0, 0, 1, 1); - MlasPool2DTest::Test(1, 32, is[ih], is[iw], is[ih], 1, 0, 0, 0, 0, 1, 1); - MlasPool2DTest::Test(1, 68, is[ih], is[iw], 1, is[iw], 0, 0, 0, 0, 1, 1); - for (unsigned kh = 1; kh <= 5; kh++) { - if (kh > is[ih]) break; - for (unsigned kw = 1; kw <= 5; kw++) { - if (kw > is[iw]) break; - for (unsigned sh = 1; sh <= 3; sh++) { - for (unsigned sw = 1; sw <= 3; sw++) { - for (unsigned p0 = 0; p0 < kh; p0++) { - for (unsigned p1 = 0; p1 < kw; p1++) { - for (unsigned p2 = 0; p2 < kh; p2++) { - for (unsigned p3 = 0; p3 < kw; p3++) { - MlasPool2DTest::Test(1, 32, is[ih], is[iw], kh, kw, p0, p1, p2, p3, sh, sw); - } - } - } - } - } - } - } - } - } - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_pool3d.cpp b/onnxruntime/test/mlas/unittest/test_pool3d.cpp deleted file mode 100644 index e0ce4c240be80..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool3d.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
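// The NCHWc wrapper above first rounds the channel count up to a multiple of the
// hardware block size with the usual power-of-two trick (C + B - 1) & ~(B - 1),
// then reorders NCHW -> NCHWc, pools, and reorders the result back to NCHW.
// A quick check of the rounding identity (valid only when the block size is a
// power of two, which the code above assumes):
#include <cstddef>

constexpr std::size_t RoundUpToBlock(std::size_t channels, std::size_t block) {
  return (channels + block - 1) & ~(block - 1);
}

static_assert(RoundUpToBlock(12, 8) == 16, "12 channels pad out to two 8-wide blocks");
static_assert(RoundUpToBlock(16, 8) == 16, "already a multiple of the block size");
static_assert(RoundUpToBlock(1, 16) == 16, "a single channel still occupies one block");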
- -#include "test_pool3d.h" -#include "test_pool3d_fixture.h" - -static size_t Pool3dRegistLongExecute() { - size_t count = 0; - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - return count; -} - -static size_t Pool3dRegistShortExecute() { - size_t count = 0; - count += Pooling3dShortExecuteTest::RegisterShortExecuteTests(); - count += Pooling3dShortExecuteTest::RegisterShortExecuteTests(); - count += Pooling3dShortExecuteTest::RegisterShortExecuteTests(); - if (GetMlasThreadPool() != nullptr) { - count += Pooling3dShortExecuteTest::RegisterShortExecuteTests(); - count += Pooling3dShortExecuteTest::RegisterShortExecuteTests(); - count += Pooling3dShortExecuteTest::RegisterShortExecuteTests(); - } - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? Pool3dRegistShortExecute() : Pool3dRegistLongExecute(); -}); diff --git a/onnxruntime/test/mlas/unittest/test_pool3d.h b/onnxruntime/test/mlas/unittest/test_pool3d.h deleted file mode 100644 index bdab16fc28c57..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_pool3d.h +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_util.h" - -template -class MlasPool3DTest : public MlasTestBase { - public: - MlasPool3DTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - static const char* GetTestSuiteName() { - static std::string suite_name = - std::string(PoolingKind == MlasMaximumPooling - ? "Pool3dMax" - : (PoolingKind == MlasAveragePoolingExcludePad - ? "Pool3dAverageExcludePad" - : "Pool3dAverageIncludePad")) + - (Threaded ? 
"_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void Test(size_t BatchCount, - size_t InputChannels, - size_t InputDepth, - size_t InputHeight, - size_t InputWidth, - size_t KernelDepth, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftDepth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightDepth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t StrideDepth, - size_t StrideHeight, - size_t StrideWidth) { - constexpr size_t DilationDepth = 1; - constexpr size_t DilationHeight = 1; - constexpr size_t DilationWidth = 1; - - int64_t OutputDepth64 = - ((int64_t(InputDepth) + int64_t(PaddingLeftDepth) + int64_t(PaddingRightDepth)) - - (int64_t(DilationDepth) * (int64_t(KernelDepth) - 1) + 1)) / - int64_t(StrideDepth) + - 1; - int64_t OutputHeight64 = - ((int64_t(InputHeight) + int64_t(PaddingLeftHeight) + int64_t(PaddingRightHeight)) - - (int64_t(DilationHeight) * (int64_t(KernelHeight) - 1) + 1)) / - int64_t(StrideHeight) + - 1; - int64_t OutputWidth64 = - ((int64_t(InputWidth) + int64_t(PaddingLeftWidth) + int64_t(PaddingRightWidth)) - - (int64_t(DilationWidth) * (int64_t(KernelWidth) - 1) + 1)) / - int64_t(StrideWidth) + - 1; - - if (OutputDepth64 <= 0 || OutputHeight64 <= 0 || OutputWidth64 <= 0) { - return; - } - - int64_t InputShape[] = {int64_t(BatchCount), int64_t(InputChannels), int64_t(InputDepth), int64_t(InputHeight), int64_t(InputWidth)}; - int64_t KernelShape[] = {int64_t(KernelDepth), int64_t(KernelHeight), int64_t(KernelWidth)}; - int64_t Padding[] = {int64_t(PaddingLeftDepth), int64_t(PaddingLeftHeight), int64_t(PaddingLeftWidth), int64_t(PaddingRightDepth), int64_t(PaddingRightHeight), int64_t(PaddingRightWidth)}; - int64_t StrideShape[] = {int64_t(StrideDepth), int64_t(StrideHeight), int64_t(StrideWidth)}; - int64_t OutputShape[] = {int64_t(BatchCount), int64_t(InputChannels), OutputDepth64, OutputHeight64, OutputWidth64}; - - OutputShape[2] = (InputShape[2] + Padding[0] + Padding[3] - KernelShape[0]) / StrideShape[0] + 1; - OutputShape[3] = (InputShape[3] + Padding[1] + Padding[4] - KernelShape[1]) / StrideShape[1] + 1; - OutputShape[4] = (InputShape[4] + Padding[2] + Padding[5] - KernelShape[2]) / StrideShape[2] + 1; - - size_t InputBufferElements = size_t(InputShape[0] * InputShape[1] * InputShape[2] * InputShape[3] * InputShape[4]); - size_t OutputBufferElements = size_t(OutputShape[0] * OutputShape[1] * OutputShape[2] * OutputShape[3] * OutputShape[4]); - - const float* Input = BufferInput.GetBuffer(InputBufferElements); - float* Output = BufferOutput.GetBuffer(OutputBufferElements); - float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements); - - MlasPool(PoolingKind, 3, InputShape, KernelShape, Padding, StrideShape, OutputShape, Input, Output, threadpool_); - if constexpr (PoolingKind == MlasMaximumPooling) { - ReferenceMaximumPool3D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference); - } else if constexpr (PoolingKind == MlasAveragePoolingExcludePad) { - ReferenceAveragePool3D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, false); - } else if constexpr (PoolingKind == MlasAveragePoolingIncludePad) { - ReferenceAveragePool3D(InputShape, KernelShape, Padding, StrideShape, Input, OutputReference, true); - } - - ASSERT_EQ(memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)), 0) - << "PoolingKind:" << int(PoolingKind) << " " - << "input(" << InputChannels << "," << InputDepth << "," << InputHeight << ", " << 
InputWidth << "), " - << "Kernel(" << KernelDepth << "," << KernelHeight << "," << KernelWidth << ")"; - } - - protected: - void ReferenceMaximumPool3D(const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const float* Input, - float* Output) { - int64_t ChannelCount = InputShape[0] * InputShape[1]; - - int64_t InputDepth = InputShape[2]; - int64_t InputHeight = InputShape[3]; - int64_t InputWidth = InputShape[4]; - - int64_t KernelDepth = KernelShape[0]; - int64_t KernelHeight = KernelShape[1]; - int64_t KernelWidth = KernelShape[2]; - - int64_t PaddingLeftZ = Padding[0]; - int64_t PaddingLeftY = Padding[1]; - int64_t PaddingLeftX = Padding[2]; - int64_t PaddingRightZ = Padding[3]; - int64_t PaddingRightY = Padding[4]; - int64_t PaddingRightX = Padding[5]; - - int64_t StrideDepth = StrideShape[0]; - int64_t StrideHeight = StrideShape[1]; - int64_t StrideWidth = StrideShape[2]; - - int64_t OutputDepth = (InputDepth + PaddingLeftZ + PaddingRightZ - KernelDepth) / StrideDepth + 1; - int64_t OutputHeight = (InputHeight + PaddingLeftY + PaddingRightY - KernelHeight) / StrideHeight + 1; - int64_t OutputWidth = (InputWidth + PaddingLeftX + PaddingRightX - KernelWidth) / StrideWidth + 1; - - for (int64_t c = 0; c < ChannelCount; c++) { - for (int64_t pd = 0; pd < OutputDepth; pd++) { - int64_t idStart = pd * StrideDepth - PaddingLeftZ; - int64_t idEnd = idStart + KernelDepth; - - idStart = (std::max)(idStart, int64_t(0)); - idEnd = (std::min)(idEnd, InputDepth); - - for (int64_t ph = 0; ph < OutputHeight; ph++) { - int64_t ihStart = ph * StrideHeight - PaddingLeftY; - int64_t ihEnd = ihStart + KernelHeight; - - ihStart = (std::max)(ihStart, int64_t(0)); - ihEnd = (std::min)(ihEnd, InputHeight); - - for (int64_t pw = 0; pw < OutputWidth; pw++) { - int64_t iwStart = pw * StrideWidth - PaddingLeftX; - int64_t iwEnd = iwStart + KernelWidth; - - iwStart = (std::max)(iwStart, int64_t(0)); - iwEnd = (std::min)(iwEnd, InputWidth); - - float m = std::numeric_limits::lowest(); - - for (int64_t id = idStart; id < idEnd; id++) { - for (int64_t ih = ihStart; ih < ihEnd; ih++) { - for (int64_t iw = iwStart; iw < iwEnd; iw++) { - m = (std::max)(m, Input[id * InputHeight * InputWidth + ih * InputWidth + iw]); - } - } - } - - Output[pd * OutputHeight * OutputWidth + ph * OutputWidth + pw] = m; - } - } - } - - Input += InputDepth * InputHeight * InputWidth; - Output += OutputDepth * OutputHeight * OutputWidth; - } - } - - void ReferenceAveragePool3D(const int64_t* InputShape, - const int64_t* KernelShape, - const int64_t* Padding, - const int64_t* StrideShape, - const float* Input, - float* Output, - bool CountIncludePad) { - int64_t ChannelCount = InputShape[0] * InputShape[1]; - - int64_t InputDepth = InputShape[2]; - int64_t InputHeight = InputShape[3]; - int64_t InputWidth = InputShape[4]; - - int64_t KernelDepth = KernelShape[0]; - int64_t KernelHeight = KernelShape[1]; - int64_t KernelWidth = KernelShape[2]; - - int64_t PaddingLeftZ = Padding[0]; - int64_t PaddingLeftY = Padding[1]; - int64_t PaddingLeftX = Padding[2]; - int64_t PaddingRightZ = Padding[3]; - int64_t PaddingRightY = Padding[4]; - int64_t PaddingRightX = Padding[5]; - - int64_t StrideDepth = StrideShape[0]; - int64_t StrideHeight = StrideShape[1]; - int64_t StrideWidth = StrideShape[2]; - - int64_t OutputDepth = (InputDepth + PaddingLeftZ + PaddingRightZ - KernelDepth) / StrideDepth + 1; - int64_t OutputHeight = (InputHeight + PaddingLeftY + PaddingRightY - KernelHeight) / 
StrideHeight + 1; - int64_t OutputWidth = (InputWidth + PaddingLeftX + PaddingRightX - KernelWidth) / StrideWidth + 1; - - for (int64_t c = 0; c < ChannelCount; c++) { - for (int64_t pd = 0; pd < OutputDepth; pd++) { - int64_t idStart = pd * StrideDepth - PaddingLeftZ; - int64_t idEnd = idStart + KernelDepth; - - idStart = (std::max)(idStart, int64_t(0)); - idEnd = (std::min)(idEnd, InputDepth); - - for (int64_t ph = 0; ph < OutputHeight; ph++) { - int64_t ihStart = ph * StrideHeight - PaddingLeftY; - int64_t ihEnd = ihStart + KernelHeight; - - ihStart = (std::max)(ihStart, int64_t(0)); - ihEnd = (std::min)(ihEnd, InputHeight); - - for (int64_t pw = 0; pw < OutputWidth; pw++) { - int64_t iwStart = pw * StrideWidth - PaddingLeftX; - int64_t iwEnd = iwStart + KernelWidth; - - iwStart = (std::max)(iwStart, int64_t(0)); - iwEnd = (std::min)(iwEnd, InputWidth); - - float m = 0.0f; - - for (int64_t id = idStart; id < idEnd; id++) { - for (int64_t ih = ihStart; ih < ihEnd; ih++) { - for (int64_t iw = iwStart; iw < iwEnd; iw++) { - m += Input[id * InputHeight * InputWidth + ih * InputWidth + iw]; - } - } - } - - if (CountIncludePad) { - m /= (KernelDepth * KernelHeight * KernelWidth); - } else { - m /= (idEnd - idStart) * (ihEnd - ihStart) * (iwEnd - iwStart); - } - - Output[pd * OutputHeight * OutputWidth + ph * OutputWidth + pw] = m; - } - } - } - - Input += InputDepth * InputHeight * InputWidth; - Output += OutputDepth * OutputHeight * OutputWidth; - } - } - - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - MLAS_THREADPOOL* threadpool_; - - public: - // void - // ExecuteShort( - // void - // ) override - // { - // for (unsigned i = 1; i < 64; i <<= 1) { - // Test(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1); - // Test(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 2, 2, 2); - // Test(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1); - // Test(1, 16, i, i, i, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1); - // Test(1, 16, i, i, i, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1); - // Test(1, 16, i, i, i, 1, i, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1); - // Test(1, 16, i, i, i, 1, 1, i, 0, 0, 0, 0, 0, 0, 1, 1, 1); - // } - // } - - void ExecuteLong(void) override { - static const unsigned is[] = {11, 5, 4, 3, 2, 1}; - - for (unsigned id = 0; id < _countof(is); id++) { - for (unsigned ih = 0; ih < _countof(is); ih++) { - for (unsigned iw = 0; iw < _countof(is); iw++) { - fprintf(stderr, "Handling %ux%ux%u\n", is[id], is[ih], is[iw]); - Test(1, 1, is[id], is[ih], is[iw], is[id], is[ih], is[iw], 0, 0, 0, 0, 0, 0, 1, 1, 1); - for (unsigned kd = 1; kd <= 4; kd++) { - if (kd > is[id]) break; - for (unsigned kh = 1; kh <= 4; kh++) { - if (kh > is[ih]) break; - for (unsigned kw = 1; kw <= 4; kw++) { - if (kw > is[iw]) break; - for (unsigned sd = 1; sd <= 3; sd++) { - for (unsigned sh = 1; sh <= 3; sh++) { - for (unsigned sw = 1; sw <= 3; sw++) { - for (unsigned p0 = 0; p0 < kd; p0++) { - for (unsigned p1 = 0; p1 < kh; p1++) { - for (unsigned p2 = 0; p2 < kw; p2++) { - for (unsigned p3 = 0; p3 < kd; p3++) { - for (unsigned p4 = 0; p4 < kh; p4++) { - for (unsigned p5 = 0; p5 < kw; p5++) { - Test(1, 1, is[id], is[ih], is[iw], kd, kh, kw, p0, p1, p2, p3, p4, p5, sd, sh, sw); - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_pool3d_fixture.h b/onnxruntime/test/mlas/unittest/test_pool3d_fixture.h deleted file mode 100644 index e3d2aebc39cec..0000000000000 --- 
a/onnxruntime/test/mlas/unittest/test_pool3d_fixture.h +++ /dev/null @@ -1,163 +0,0 @@ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_pool3d.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class Pooling3dShortExecuteTest : public MlasTestFixture> { - public: - explicit Pooling3dShortExecuteTest(size_t BatchCount, - size_t InputChannels, - size_t InputDepth, - size_t InputHeight, - size_t InputWidth, - size_t KernelDepth, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftDepth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightDepth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t StrideDepth, - size_t StrideHeight, - size_t StrideWidth) - : BatchCount_(BatchCount), - InputChannels_(InputChannels), - InputDepth_(InputDepth), - InputHeight_(InputHeight), - InputWidth_(InputWidth), - KernelDepth_(KernelDepth), - KernelHeight_(KernelHeight), - KernelWidth_(KernelWidth), - PaddingLeftDepth_(PaddingLeftDepth), - PaddingLeftHeight_(PaddingLeftHeight), - PaddingLeftWidth_(PaddingLeftWidth), - PaddingRightDepth_(PaddingRightDepth), - PaddingRightHeight_(PaddingRightHeight), - PaddingRightWidth_(PaddingRightWidth), - StrideDepth_(StrideDepth), - StrideHeight_(StrideHeight), - StrideWidth_(StrideWidth) { - } - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test( - BatchCount_, - InputChannels_, - InputDepth_, - InputHeight_, - InputWidth_, - KernelDepth_, - KernelHeight_, - KernelWidth_, - PaddingLeftDepth_, - PaddingLeftHeight_, - PaddingLeftWidth_, - PaddingRightDepth_, - PaddingRightHeight_, - PaddingRightWidth_, - StrideDepth_, - StrideHeight_, - StrideWidth_); - } - - static size_t RegisterSingleTest(size_t BatchCount, - size_t InputChannels, - size_t InputDepth, - size_t InputHeight, - size_t InputWidth, - size_t KernelDepth, - size_t KernelHeight, - size_t KernelWidth, - size_t PaddingLeftDepth, - size_t PaddingLeftHeight, - size_t PaddingLeftWidth, - size_t PaddingRightDepth, - size_t PaddingRightHeight, - size_t PaddingRightWidth, - size_t StrideDepth, - size_t StrideHeight, - size_t StrideWidth) { - std::stringstream ss; - ss << "B" << BatchCount << "/" - << "C" << InputChannels << "/" - << "Input_" << InputDepth << "x" << InputHeight << "x" << InputWidth << "/" - << "Kernel" << KernelDepth << "x" << KernelHeight << "x" << KernelWidth << "/" - << "Pad" << PaddingLeftDepth << "," << PaddingLeftHeight << "," << PaddingLeftWidth - << "," << PaddingRightDepth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" - << "Stride" << StrideDepth << "," << StrideHeight << "," << StrideWidth; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasPool3DTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new Pooling3dShortExecuteTest(BatchCount, - InputChannels, - InputDepth, - InputHeight, - InputWidth, - KernelDepth, - KernelHeight, - KernelWidth, - PaddingLeftDepth, - PaddingLeftHeight, - PaddingLeftWidth, - PaddingRightDepth, - PaddingRightHeight, - PaddingRightWidth, - StrideDepth, - StrideHeight, - StrideWidth); - }); - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (unsigned i = 1; i < 64; i <<= 1) { - test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 2, 2, 2); - test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, i, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, i, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, i, 1, i, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1); - test_registered += RegisterSingleTest(1, 16, i, i, i, 1, 1, i, 0, 0, 0, 0, 0, 0, 1, 1, 1); - } - - return test_registered; - } - - private: - size_t BatchCount_; - size_t InputChannels_; - size_t InputDepth_; - size_t InputHeight_; - size_t InputWidth_; - size_t KernelDepth_; - size_t KernelHeight_; - size_t KernelWidth_; - size_t PaddingLeftDepth_; - size_t PaddingLeftHeight_; - size_t PaddingLeftWidth_; - size_t PaddingRightDepth_; - size_t PaddingRightHeight_; - size_t PaddingRightWidth_; - size_t StrideDepth_; - size_t StrideHeight_; - size_t StrideWidth_; -}; diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q4gemm.cpp deleted file mode 100644 index dccd7d00b6d3f..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_q4gemm.cpp +++ /dev/null @@ -1,109 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_q4gemm.cpp - -Abstract: - - Tests for MLAS GEMM for blockwise int4 quantization. - ---*/ - -#ifndef ORT_MINIMAL_BUILD - -#include "test_q4gemm.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class Q4GemmShortExecuteTest : public MlasTestFixture> { - public: - explicit Q4GemmShortExecuteTest(size_t M, size_t N, size_t K, bool hasBias) - : M_(M), N_(N), K_(K), hasBias_(hasBias) {} - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, hasBias_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool hasBias) { - std::stringstream ss; - ss << "/M" << M << "xN" << N << "xK" << K << "/" - << "hasBias" << hasBias; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasQ4GemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new Q4GemmShortExecuteTest( - M, N, K, hasBias); - }); - - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, false); - test_registered += RegisterSingleTest(b, b, b, true); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, false); - test_registered += RegisterSingleTest(b, b, b, true); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, true); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, false); - test_registered += RegisterSingleTest(1, 32, b, true); - test_registered += RegisterSingleTest(1, b, b, false); - } - test_registered += RegisterSingleTest(43, 500, 401, true); - // test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); - - return test_registered; - } - - private: - size_t M_, N_, K_; - bool hasBias_; -}; - -static size_t Q4GemmRegistShortExecute() { - size_t count = 0; - - count += Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (MlasQ4GemmPackBSize(BlkQ4Sym, 32, 32) == 0) { - return false; - } - if (is_short_execute) { - return Q4GemmRegistShortExecute() > 0; - } - return false; -}); - -#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.h b/onnxruntime/test/mlas/unittest/test_q4gemm.h deleted file mode 100644 index 97c6969b5bf91..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_q4gemm.h +++ /dev/null @@ -1,144 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_q4gemm.h - -Abstract: - - Tests for MLAS int4 block quantized GEMM. 
- ---*/ - -#pragma once - -#include "test_util.h" -#include "mlas_q4.h" - -/** - * @brief Test class for int4 block quantized GEMM - * Note: only 2-D matmul supported for now - */ -template -class MlasQ4GemmTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferBPacked; - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferBias; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - MatrixGuardBuffer BufferUnpack; - MLAS_THREADPOOL* threadpool_; - - void* PackB(size_t N, size_t K, const float* B, size_t ldb) { - size_t PackedBSize = MlasQ4GemmPackBSize(QType, N, K); - if (PackedBSize == 0) { - return nullptr; - } - void* PackedB = BufferBPacked.GetBuffer(PackedBSize); - MlasQ4GemmPackB(QType, PackedB, B, N, K, ldb); - return PackedB; - } - - void CallGemm(size_t M, - size_t N, - size_t K, - const float* A, - size_t lda, - const uint8_t* PackedB, - const float* Bias, - float* C, - size_t ldc) { - MLAS_Q4_GEMM_DATA_PARAMS params; - params.A = A; - params.lda = lda; - params.Bias = Bias; - params.C = C; - params.ldc = ldc; - params.B = PackedB; - params.OutputProcessor = nullptr; - - MlasQ4GemmBatch(QType, M, N, K, 1, ¶ms, threadpool_); - } - - void ReferenceQgemm(size_t M, - size_t N, - size_t K, - const float* A, - const uint8_t* PackedB, - const float* Bias, - float* C) { - // std::vector B(K * N); - // MlasQ4GemmUnPackB(QType, B.data(), PackedB, N, K, N); - float* bdata = BufferUnpack.GetBuffer(K * N); - MlasQ4GemmUnPackB(QType, bdata, PackedB, N, K, N); - - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const float* a = A + m * K; - const float* b = bdata + n; - float* c = C + (m * N) + n; - - float sum = Bias == nullptr ? 0.0f : Bias[n]; - for (size_t k = 0; k < K; k++) { - sum += (*a) * (*b); - b += N; - a += 1; - } - *c = sum; - } - } - } - - public: - MlasQ4GemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void Test(size_t M, size_t N, size_t K, bool withBias) { - const float* A = BufferA.GetBuffer(K * M); - - const float* B = BufferB.GetBuffer(N * K); - - const float* Bias = nullptr; - if (withBias) { - Bias = BufferBias.GetBuffer(N); - } - - float* C = BufferC.GetBuffer(N * M, true); - float* CReference = BufferCReference.GetFilledBuffer( - N * M, - [](float* start, size_t size) { - std::fill_n(start, size, -1.0f); - }); - const uint8_t* PackedB = (uint8_t*)PackB(N, K, B, N); - this->CallGemm(M, N, K, A, K, PackedB, Bias, C, N); - ReferenceQgemm(M, N, K, A, PackedB, Bias, CReference); - size_t f = 0; - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_TRUE(CloseEnough(C[f], CReference[f])) - << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], " - << "M=" << M << ", N=" << N << ", K=" << K; - } - } - } - - public: - static const char* GetTestSuiteName() { - /* - BlkQ4Sym = 0, - BlkQ4Zp8 = 1, - BlkQ4Sym64 = 2, - BlkQ4Sym128 = 4 - */ - static const std::vector qtype_names = {"BlkQ4Sym", "BlkQ4Zp8", "BlkQ4Sym64", "", "BlkQ4Sym128"}; - static std::string suite_name = std::string("Q4GemmFP") + - qtype_names[QType] + - (Threaded ? "_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_q4qdq.cpp b/onnxruntime/test/mlas/unittest/test_q4qdq.cpp deleted file mode 100644 index c317395bee970..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_q4qdq.cpp +++ /dev/null @@ -1,155 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. 
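// The Q4 GEMM test above builds its oracle by unpacking the quantized, packed B
// back to float (MlasQ4GemmUnPackB) and running a plain reference GEMM, so the
// kernel and the oracle both consume the same already-quantized weights. A hedged
// sketch of such a reference; the name is illustrative.
#include <cstddef>

// C[M x N] = A[M x K] * B[K x N] + bias[N], all row-major; bias may be null.
void ReferenceGemm(std::size_t M, std::size_t N, std::size_t K,
                   const float* A, const float* B, const float* bias, float* C) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float sum = (bias != nullptr) ? bias[n] : 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        sum += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = sum;
    }
  }
}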
- -Licensed under the MIT License. - -Module Name: - - test_q4qdq.cpp - -Abstract: - - Tests for MLAS int4 quantization and dequantization code. - ---*/ - -#ifndef ORT_MINIMAL_BUILD - -#include "test_util.h" -#include "mlas_q4.h" - -#if ((defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)) - -/** - * @brief For testing purpose, - * Dequantize the data intp fp32, and then pack them for use - * in sgemm kernel. equivalent to MlasQ4GemmUnPackB and then - * MlasSgemmCopyPackB - * @param QType - * @param FpData - * @param PackedB - * @param CountN - * @param CountK - * @param ldb - */ -void MlasBlkQ4DequantSgemmPackB( - MLAS_BLK_QUANT_TYPE QType, - float* FpData, - const uint8_t* PackedB, - size_t CountN, - size_t CountK, - size_t ldb); - -void MlasSgemmCopyPackB( - float* D, - const float* B, - size_t ldb, - size_t CountX, - size_t CountY); - -#endif // x64 - -class MlasQ4dqTest : public MlasTestBase { - private: - MatrixGuardBuffer FpInputBuf; - MatrixGuardBuffer PackedBuf; - MatrixGuardBuffer FpOutBuf; - MatrixGuardBuffer SgemmPackBuf; - MatrixGuardBuffer SgemmPackRefBuf; - - void Test(size_t N, size_t K, MLAS_BLK_QUANT_TYPE qtype) { - float* Input = FpInputBuf.GetBuffer(N * K, true); - if (qtype != BlkQ4Zp8) { - int v = -7; - for (size_t i = 0; i < N * K; i++) { - if (v == 0 || v == -3 || v == 3) { - v++; - } - Input[i] = (float)v; - if (++v >= 8) { - v = -8; - } - } - } else { - int v = 0; - for (size_t i = 0; i < N * K; i++) { - Input[i] = (float)v; - if (++v >= 16) { - v = -0; - } - } - } - - size_t qsize = MlasQ4GemmPackBSize(qtype, N, K); - uint8_t* Packed = PackedBuf.GetBuffer(qsize, true); - float* Output = FpOutBuf.GetBuffer(N * K, true); - - MlasQ4GemmPackB(qtype, Packed, Input, N, K, N); - MlasQ4GemmUnPackB(qtype, Output, Packed, N, K, N); - - for (size_t i = 0; i < N * K; i++) { - ASSERT_EQ(Output[i], Input[i]) << ", index=" << i << ", [" << N << "x" - << K << "] QType: " << qtype; - } - -#if ((defined(_M_AMD64) && !defined(_M_ARM64EC)) || defined(__x86_64__)) - - /* Test MlasBlkQ4DequantSgemmPackB, make sure we can reuse SGEMM kernel as it rearrange B the same way as sgemm pack B*/ - const size_t AlignedN = (N + 15) & ~15; - const size_t AlignedK = (K + 15) & ~15; - float* gemmpack = SgemmPackBuf.GetBuffer(AlignedK * AlignedN, true); - float* gemmpack_ref = SgemmPackRefBuf.GetBuffer(AlignedK * AlignedN, true); - MlasSgemmCopyPackB(gemmpack_ref, Input, N, N, K); - - const size_t blkq_ldb = MlasQ4GemmPackBSize(qtype, 1, K); - MlasBlkQ4DequantSgemmPackB(qtype, gemmpack, Packed, N, K, blkq_ldb); - for (size_t i = 0; i < AlignedN * K; i++) { - ASSERT_EQ(gemmpack[i], gemmpack_ref[i]) << ", sgemm pack index=" << i << ", [" << N << "x" - << K << "] QType: " << qtype; - } -#endif // x64 - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("Q4DQ"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - Test(1, 20, BlkQ4Sym); - Test(1, 20, BlkQ4Zp8); - Test(1, 52, BlkQ4Sym); - Test(1, 52, BlkQ4Zp8); - Test(1, 52, BlkQ4Sym64); - Test(3, 20, BlkQ4Sym); - Test(3, 20, BlkQ4Zp8); - Test(3, 52, BlkQ4Sym); - Test(3, 52, BlkQ4Zp8); - Test(3, 52, BlkQ4Sym64); - Test(static_cast(4 * 10) + 1, static_cast(32 * 9) + 17, BlkQ4Zp8); - Test(static_cast(4 * 10) + 1, static_cast(32 * 9) + 17, BlkQ4Sym); - Test(static_cast(4 * 10) + 1, static_cast(32 * 9) + 17, BlkQ4Sym64); - Test(static_cast(4 * 10) + 1, static_cast(32 * 9) + 17, BlkQ4Sym128); - Test(static_cast(4 * 20) + 3, static_cast(32 * 15) + 17, BlkQ4Zp8); - 
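// The Q4DQ test above fills the input with small integer values so blockwise int4
// quantization is exact, allowing the pack/unpack round trip to be verified with
// ASSERT_EQ instead of a tolerance. A toy single-value illustration with a
// symmetric int4 quantizer (the real MLAS formats additionally store per-block
// scales, and a zero point for BlkQ4Zp8, which this sketch ignores):
#include <algorithm>
#include <cassert>
#include <cmath>

float RoundTripInt4(float x, float scale) {
  int q = static_cast<int>(std::lround(x / scale));
  q = std::max(-8, std::min(7, q));  // representable range of a signed 4-bit value
  return static_cast<float>(q) * scale;
}

int main() {
  // Integer inputs in [-8, 7] with scale 1.0 sit exactly on the grid: lossless.
  for (int v = -8; v <= 7; ++v) {
    assert(RoundTripInt4(static_cast<float>(v), 1.0f) == static_cast<float>(v));
  }
  // A value off the grid is altered by quantization, so an exact check would fail.
  assert(RoundTripInt4(2.3f, 1.0f) != 2.3f);
  return 0;
}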
Test(static_cast(4 * 20) + 3, static_cast(32 * 15) + 17, BlkQ4Sym); - Test(static_cast(4 * 20) + 3, static_cast(32 * 15) + 17, BlkQ4Sym64); - Test(static_cast(4 * 20) + 3, static_cast(32 * 15) + 17, BlkQ4Sym128); - } - - MlasQ4dqTest() = default; -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (MlasQ4GemmPackBSize(BlkQ4Sym, 32, 32) == 0) { - return (size_t)0; - } - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests::RegisterShortExecute(); - } - return count; -}); - -#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp deleted file mode 100644 index d3f601793a970..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp +++ /dev/null @@ -1,283 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_q8q4gemm.cpp - -Abstract: - - Tests for MLAS int8 x int4 block quantized GEMM. - ---*/ - -#ifndef ORT_MINIMAL_BUILD - -#include "test_util.h" -#include "mlas_q4.h" - -template -static void blkq8_dequant_reference(const int8_t* src, float* dst, size_t M, size_t K) { - const size_t num_blks = K / QBlkLen; - const size_t remain = K % QBlkLen; - const auto* blob = reinterpret_cast(src); - - for (size_t m = 0; m < M; m++) { - for (size_t i = 0; i < num_blks; i++, dst += QBlkLen) { - const float scale = *reinterpret_cast(blob); - blob += sizeof(float); - for (size_t j = 0; j < QBlkLen; ++j) { - dst[j] = *(blob++) * scale; - } - } - - if (remain > 0) { - const float scale = *reinterpret_cast(blob); - blob += sizeof(float); - for (size_t j = 0; j < remain; ++j) { - dst[j] = blob[j] * scale; - } - blob += QBlkLen; - dst += remain; - } - } -} - -/** - * @brief Test class for int8 x int4 block quantized GEMM - * Note: only 2-D matmul supported for now - */ -template -class MlasQ8Q4GemmTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferAQuant; - MatrixGuardBuffer BufferDequantA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferBPacked; - MatrixGuardBuffer BufferUnpack; - MatrixGuardBuffer BufferBias; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - MLAS_THREADPOOL* threadpool_; - - void* PackB(size_t N, size_t K, const float* B, size_t ldb) { - size_t PackedBSize = MlasQ4GemmPackBSize(QType, N, K); - if (PackedBSize == 0) { - return nullptr; - } - void* PackedB = BufferBPacked.GetBuffer(PackedBSize); - MlasQ4GemmPackB(QType, PackedB, B, N, K, ldb); - return PackedB; - } - - int8_t* QuantizeA(size_t M, size_t K, const float* A, size_t lda) { - size_t bufsize = MlasQ80BlkQuantSize(QType, M, K); - if (bufsize == 0) { - return nullptr; - } - auto* QuantA = BufferAQuant.GetBuffer(bufsize); - MlasQ80BlkQuant(QType, QuantA, A, M, K, lda, threadpool_); - return QuantA; - } - - void CallGemm(size_t M, - size_t N, - size_t K, - const int8_t* QuantA, - const uint8_t* PackedB, - const float* Bias, - float* C, - size_t ldc) { - MLAS_Q8Q4_GEMM_DATA_PARAMS params; - params.A = QuantA; - params.B = PackedB; - params.Bias = Bias; - params.C = C; - params.ldc = ldc; - params.OutputProcessor = nullptr; - - MlasQ8Q4GemmBatch(QType, M, N, K, 1, ¶ms, threadpool_); - } - - void ReferenceQgemm(size_t M, - size_t N, - size_t K, - const int8_t* QuantA, - const uint8_t* PackedB, - const float* Bias, - float* C) { - // std::vector B(K * N); - // MlasQ4GemmUnPackB(QType, B.data(), PackedB, N, K, N); - float* 
bdata = BufferUnpack.GetBuffer(K * N); - MlasQ4GemmUnPackB(QType, bdata, PackedB, N, K, N); - - float* adata = BufferDequantA.GetBuffer(M * K); - switch (QType) { - case BlkQ4Sym64: - blkq8_dequant_reference<64>(QuantA, adata, M, K); - break; - case BlkQ4Sym128: - blkq8_dequant_reference<128>(QuantA, adata, M, K); - break; - default: - blkq8_dequant_reference<32>(QuantA, adata, M, K); - break; - } - - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const float* a = adata + m * K; - const float* b = bdata + n; - float* c = C + (m * N) + n; - - float sum = Bias == nullptr ? 0.0f : Bias[n]; - for (size_t k = 0; k < K; k++) { - sum += (*a) * (*b); - b += N; - a += 1; - } - *c = sum; - } - } - } - - public: - MlasQ8Q4GemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void Test(size_t M, size_t N, size_t K, bool withBias) { - const float* A = BufferA.GetBuffer(K * M); - - const float* B = BufferB.GetBuffer(N * K); - - const float* Bias = nullptr; - if (withBias) { - Bias = BufferBias.GetBuffer(N); - } - - float* C = BufferC.GetBuffer(N * M, true); - float* CReference = BufferCReference.GetFilledBuffer( - N * M, - [](float* start, size_t size) { - std::fill_n(start, size, -1.0f); - }); - const uint8_t* PackedB = (uint8_t*)PackB(N, K, B, N); - const int8_t* QuantA = QuantizeA(M, K, A, K); - this->CallGemm(M, N, K, QuantA, PackedB, Bias, C, N); - ReferenceQgemm(M, N, K, QuantA, PackedB, Bias, CReference); - size_t f = 0; - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_TRUE(CloseEnough(C[f], CReference[f])) - << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], " - << "M=" << M << ", N=" << N << ", K=" << K; - } - } - } - - public: - static const char* GetTestSuiteName() { - /* - BlkQ4Sym = 0, - BlkQ4Zp8 = 1, - BlkQ4Sym64 = 2, - BlkQ4Sym128 = 4 - */ - static const std::vector qtype_names = {"BlkQ4Sym", "BlkQ4Zp8", "BlkQ4Sym64", "", "BlkQ4Sym128"}; - static std::string suite_name = std::string("Q8Q4GemmFP") + - qtype_names[QType] + - (Threaded ? "_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } -}; - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class Q8Q4GemmShortExecuteTest : public MlasTestFixture> { - public: - explicit Q8Q4GemmShortExecuteTest(size_t M, size_t N, size_t K, bool hasBias) - : M_(M), N_(N), K_(K), hasBias_(hasBias) {} - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, hasBias_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool hasBias) { - std::stringstream ss; - ss << "/M" << M << "xN" << N << "xK" << K << "/" - << "hasBias" << hasBias; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasQ8Q4GemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new Q8Q4GemmShortExecuteTest( - M, N, K, hasBias); - }); - - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, true); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, false); - test_registered += RegisterSingleTest(b, b, b, true); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, true); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, false); - test_registered += RegisterSingleTest(1, 32, b, true); - test_registered += RegisterSingleTest(1, b, b, false); - } - test_registered += RegisterSingleTest(43, 500, 401, true); - - return test_registered; - } - - private: - size_t M_, N_, K_; - bool hasBias_; -}; - -static size_t Q8Q4GemmRegistShortExecute() { - size_t count = 0; - - count += Q8Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q8Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q8Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q8Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q8Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - count += Q8Q4GemmShortExecuteTest::RegisterShortExecuteTests(); - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (MlasQ80BlkQuantSize(BlkQ4Sym, 32, 32) == 0) { - return false; // operation not yet supported on current hardware - } - if (is_short_execute) { - return Q8Q4GemmRegistShortExecute() > 0; - } - return false; -}); - -#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_qgemm.cpp deleted file mode 100644 index 12955e6f04688..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_qgemm.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "test_qgemm.h" -#include "test_qgemm_fixture.h" - -static size_t QGemmRegistLongExecute() { - size_t count = 0; - - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - - return count; -} - -static size_t QGemmRegistShortExecute() { - size_t count = 0; - - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += 
QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) { - // QGEMM U8U8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM U8U8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, true /*BIsSigned*/) > 0) { - // QGEMM U8S8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM U8S8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - // QGEMM S8S8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM S8S8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { - // QGEMM S8U8=float packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - // QGEMM S8U8=int32_t packed tests - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - - if (GetMlasThreadPool() != nullptr) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, false /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, false /*AIsSigned*/, true /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, true /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - if (MlasGemmPackBSize(128, 128, true /*AIsSigned*/, false /*BIsSigned*/) > 0) { - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - count += QgemmShortExecuteTest::RegisterShortExecuteTests(); - } - } - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? QGemmRegistShortExecute() : QGemmRegistLongExecute(); -}); \ No newline at end of file diff --git a/onnxruntime/test/mlas/unittest/test_qgemm.h b/onnxruntime/test/mlas/unittest/test_qgemm.h deleted file mode 100644 index 4097b3588a04e..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_qgemm.h +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "test_util.h" - -template -class MlasQgemmTestBase : public MlasTestBase { - private: - void* PackB(size_t N, size_t K, const uint8_t* B, size_t ldb, bool AIsSigned, bool BIsSigned) { - size_t PackedBSize = MlasGemmPackBSize(N, K, AIsSigned, BIsSigned); - void* PackedB = BufferBPacked.GetBuffer(PackedBSize); - MlasGemmPackB(N, K, B, ldb, AIsSigned, BIsSigned, PackedB); - return PackedB; - } - - protected: - MLAS_THREADPOOL* threadpool_; - - MlasQgemmTestBase() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void TestGemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - uint8_t offa, - bool AIsSigned, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - int32_t* C, - size_t ldc) { - MLAS_GEMM_QUANT_SHAPE_PARAMS GemmShape; - GemmShape.M = M; - GemmShape.N = N; - GemmShape.K = K; - GemmShape.AIsSigned = AIsSigned; - GemmShape.BIsSigned = BIsSigned; - - std::vector GemmParameters(BatchSize); - - for (size_t i = 0; i < GemmParameters.size(); i++) { - auto& params = GemmParameters[i]; - params.A = A + (M * K * i); - params.lda = lda; - params.ZeroPointA = offa; - params.ZeroPointB = &offb; - params.C = C + (M * N * i); - params.ldc = ldc; - - if (Packed) { - ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; - params.B = PackB(N, K, B, ldb, AIsSigned, BIsSigned); - params.BIsPacked = true; - } else { - params.B = B + (K * N * i); - params.ldb = ldb; - } - } - - MlasGemmBatch(GemmShape, GemmParameters.data(), BatchSize, threadpool_); - } - - void TestGemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - uint8_t offa, - bool AIsSigned, - const uint8_t* B, - size_t ldb, - const uint8_t* offb, - bool BIsSigned, - int32_t* C, - size_t ldc) { - MLAS_GEMM_QUANT_SHAPE_PARAMS GemmShape; - GemmShape.M = M; - GemmShape.N = N; - GemmShape.K = K; - GemmShape.AIsSigned = AIsSigned; - GemmShape.BIsSigned = BIsSigned; - - std::vector GemmParameters(BatchSize); - - for (size_t i = 0; i < GemmParameters.size(); i++) { - auto& params = GemmParameters[i]; - params.A = A + M * K * i; - params.lda = lda; - params.ZeroPointA = offa; - params.ZeroPointB = offb; - params.PerColumnZeroPoints = true; - params.C = C + M * N * i; - params.ldc = ldc; - - if (Packed) { - ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; - params.B = PackB(N, K, B, ldb, AIsSigned, BIsSigned); - params.BIsPacked = true; - } else { - params.B = B + K * N * i; - params.ldb = ldb; - } - } - - MlasGemmBatch(GemmShape, GemmParameters.data(), BatchSize, threadpool_); - } - - void TestGemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - uint8_t offa, - bool AIsSigned, - const uint8_t* B, - size_t ldb, - uint8_t offb, - bool BIsSigned, - float* C, - size_t ldc, - float CScale, - const float* Bias) { - MLAS_GEMM_QUANT_SHAPE_PARAMS GemmShape; - GemmShape.M = M; - GemmShape.N = N; - GemmShape.K = K; - GemmShape.AIsSigned = AIsSigned; - GemmShape.BIsSigned = BIsSigned; - - std::vector ScaleBiasProcessors; - ScaleBiasProcessors.reserve(BatchSize); - - std::vector GemmParameters(BatchSize); - - for (size_t i = 0; i < BatchSize; i++) { - auto& params = GemmParameters[i]; - params.A = A + M * K * i; - params.lda = lda; - params.ZeroPointA = offa; - params.ZeroPointB = &offb; - params.C = reinterpret_cast(C + M * N * i); - params.ldc = ldc; - - if (Packed) { - ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in 
batching yet!"; - params.B = PackB(N, K, B, ldb, AIsSigned, BIsSigned); - params.BIsPacked = true; - } else { - params.B = B + K * N * i; - params.ldb = ldb; - } - ScaleBiasProcessors.emplace_back(C + M * N * i, ldc, &CScale, Bias); - params.OutputProcessor = &(ScaleBiasProcessors[i]); - } - - MlasGemmBatch(GemmShape, GemmParameters.data(), BatchSize, threadpool_); - } - - private: - MatrixGuardBuffer BufferBPacked; -}; - -template -class MlasQgemmTest; - -template -class MlasQgemmTest : public MlasQgemmTestBase { - public: - void Test(size_t M, size_t N, size_t K, size_t BatchSize, uint8_t offa, uint8_t offb) { - const uint8_t* A = BufferA.GetBuffer(K * M * BatchSize); - const uint8_t* B = BufferB.GetBuffer(N * K * BatchSize); - int32_t* C = BufferC.GetBuffer(N * M * BatchSize); - int32_t* CReference = BufferCReference.GetBuffer(N * M * BatchSize); - - Test(M, N, K, BatchSize, A, K, offa, B, N, offb, C, CReference, N); - } - - void Test(size_t M, size_t N, size_t K, size_t BatchSize, uint8_t offa) { - const uint8_t* A = BufferA.GetBuffer(K * M * BatchSize); - const uint8_t* B = BufferB.GetBuffer(N * K * BatchSize); - const uint8_t* ZeroPointB = BufferZeroPointB.GetBuffer(N); - int32_t* C = BufferC.GetBuffer(N * M * BatchSize); - int32_t* CReference = BufferCReference.GetBuffer(N * M * BatchSize); - - Test(M, N, K, BatchSize, A, K, offa, B, N, ZeroPointB, C, CReference, N); - } - - void Test(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - uint8_t offb, - int32_t* C, - int32_t* CReference, - size_t ldc) { - std::fill_n(C, M * N * BatchSize, -1); - std::fill_n(CReference, M * N * BatchSize, -1); - - this->TestGemm(M, N, K, BatchSize, A, lda, offa, AIsSigned, B, ldb, offb, BIsSigned, C, ldc); - ReferenceQgemm(M, N, K, BatchSize, (const AType*)A, lda, (AType)offa, (const BType*)B, ldb, (BType)offb, CReference, ldc); - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_EQ(C[f], CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], " - << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K - << ", offa=" << int(offa) << ", offb=" << int(offb); - } - } - } - } - - void Test(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - uint8_t offa, - const uint8_t* B, - size_t ldb, - const uint8_t* offb, - int32_t* C, - int32_t* CReference, - size_t ldc) { - std::fill_n(C, M * N * BatchSize, -1); - std::fill_n(CReference, M * N * BatchSize, -1); - - this->TestGemm(M, N, K, BatchSize, A, lda, offa, AIsSigned, B, ldb, offb, BIsSigned, C, ldc); - ReferenceQgemm(M, N, K, BatchSize, (const AType*)A, lda, (AType)offa, (const BType*)B, ldb, (const BType*)offb, CReference, ldc); - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_EQ(C[f], CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], " - << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K - << ", offa=" << int(offa) << ", offb=--"; - } - } - } - } - - private: - void ReferenceQgemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const AType* A, - size_t lda, - AType offa, - const BType* B, - size_t ldb, - BType offb, - int32_t* C, - size_t ldc) { - for (size_t batch = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const AType* a = 
A + (M * K * batch) + (m * lda); - const BType* b = B + (K * N * batch) + n; - int32_t* c = C + (M * N * batch) + (m * ldc) + n; - int32_t sum = 0; - - for (size_t k = 0; k < K; k++) { - sum += ((int32_t(*b) - offb) * (int32_t(*a) - offa)); - b += ldb; - a += 1; - } - - *c = sum; - } - } - } - } - - void ReferenceQgemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const AType* A, - size_t lda, - AType offa, - const BType* B, - size_t ldb, - const BType* offb, - int32_t* C, - size_t ldc) { - for (size_t batch = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const AType* a = A + (M * K * batch) + (m * lda); - const BType* b = B + (K * N * batch) + n; - int32_t* c = C + (M * N * batch) + (m * ldc) + n; - int32_t sum = 0; - - for (size_t k = 0; k < K; k++) { - sum += ((int32_t(*b) - offb[n]) * (int32_t(*a) - offa)); - b += ldb; - a += 1; - } - - *c = sum; - } - } - } - } - - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferZeroPointB; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - const bool AIsSigned = std::is_signed::value; - const bool BIsSigned = std::is_signed::value; - - public: - static const char* GetTestSuiteName() { - static std::string suite_name = std::string("QGemm") + - (std::is_signed::value ? "S8" : "U8") + - (std::is_signed::value ? "S8" : "U8") + - (Packed ? "_Int32_Packed" : "_Int32_NoPack") + - (Threaded ? "_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void ExecuteLong(void) override { - static const uint8_t zero_points[] = {0, 18, 75, 128, 157, 231, 255}; - - for (size_t a = 0; a < _countof(zero_points); a++) { - uint8_t offa = zero_points[a]; - - for (size_t b = 0; b < _countof(zero_points); b++) { - uint8_t offb = zero_points[b]; - - for (size_t M = 16; M < 160; M += 32) { - for (size_t N = 16; N < 160; N += 32) { - static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; - for (size_t k = 0; k < _countof(ks); k++) { - size_t K = ks[k]; - - Test(M, N, K, 1, offa, offb); - Test(M + 1, N, K, 1, offa, offb); - Test(M, N + 1, K, 1, offa, offb); - Test(M + 1, N + 1, K, 1, offa, offb); - Test(M + 3, N + 2, K, 1, offa, offb); - Test(M + 4, N, K, 1, offa, offb); - Test(M, N + 4, K, 1, offa, offb); - Test(M + 4, N + 4, K, 1, offa, offb); - Test(M + 3, N + 7, K, 1, offa, offb); - Test(M + 8, N, K, 1, offa, offb); - Test(M, N + 8, K, 1, offa, offb); - Test(M + 12, N + 12, K, 1, offa, offb); - Test(M + 13, N, K, 1, offa, offb); - Test(M, N + 15, K, 1, offa, offb); - Test(M + 15, N + 15, K, 1, offa, offb); - if (!Packed) { - Test(M, N, K, 7 + a, offa, offb); - Test(M + 3, N, K, 7 + a, offa, offb); - Test(M, N + 1, K, 7 + a, offa, offb); - Test(M + 12, N, K, 7 + a, offa, offb); - Test(M, N + 15, K, 7 + a, offa, offb); - Test(M + 15, N + 15, K, 7 + a, offa, offb); - } - } - } - printf("a %zd/%zd b %zd/%zd M %zd\n", a, _countof(zero_points), b, _countof(zero_points), M); - } - } - } - - for (size_t M = 1; M < 160; M++) { - for (size_t N = 1; N < 160; N++) { - for (size_t K = 1; K < 160; K++) { - Test(M, N, K, 1, 18, 24); - } - } - printf("M %zd\n", M); - } - - for (size_t M = 160; M < 320; M += 24) { - for (size_t N = 112; N < 320; N += 24) { - for (size_t K = 1; K < 16; K++) { - Test(M, N, K, 1, 1, 3); - } - for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1, 5, 7); - } - } - printf("M %zd\n", M); - } - } -}; - -template -class MlasQgemmTest : public MlasQgemmTestBase { - 
public: - void Test(size_t M, size_t N, size_t K, size_t BatchSize, uint8_t offa, uint8_t offb) { - const uint8_t* A = BufferA.GetBuffer(K * M * BatchSize); - const uint8_t* B = BufferB.GetBuffer(N * K * BatchSize); - float* C = BufferC.GetBuffer(N * M * BatchSize); - float* CReference = BufferCReference.GetBuffer(N * M * BatchSize); - const float* Bias = BufferBias.GetBuffer(N); - - constexpr float AScale = 0.5f; - float* AFloat = BufferAFloat.GetBuffer(K * M * BatchSize); - for (size_t b = 0; b < BatchSize; b++) { - DequantizeLinear((AType*)(A + K * M * b), AFloat + K * M * b, K * M, AScale, (AType)offa); - } - - constexpr float BScale = 0.25f; - float* BFloat = BufferBFloat.GetBuffer(N * K * BatchSize); - for (size_t b = 0; b < BatchSize; b++) { - DequantizeLinear((BType*)(B + N * K * b), BFloat + N * K * b, N * K, BScale, BType(offb)); - } - - constexpr float CScale = AScale * BScale; - - Test(M, N, K, BatchSize, A, AFloat, K, offa, B, BFloat, N, offb, C, CReference, N, CScale, nullptr); - Test(M, N, K, BatchSize, A, AFloat, K, offa, B, BFloat, N, offb, C, CReference, N, CScale, Bias); - } - - void Test(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - const float* AFloat, - size_t lda, - uint8_t offa, - const uint8_t* B, - const float* BFloat, - size_t ldb, - uint8_t offb, - float* C, - float* CReference, - size_t ldc, - float CScale, - const float* Bias) { - for (size_t b = 0; b < BatchSize; b++) { - MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, - AFloat + K * M * b, lda, - BFloat + N * K * b, ldb, 0.0f, - CReference + N * M * b, ldc, - MlasQgemmTestBase::threadpool_); - } - - if (Bias != nullptr) { - for (size_t b = 0; b < BatchSize; b++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - CReference[N * M * b + m * ldc + n] += Bias[n]; - } - } - } - } - - this->TestGemm(M, N, K, BatchSize, A, lda, offa, AIsSigned, B, ldb, offb, BIsSigned, C, ldc, CScale, Bias); - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - // Sensitive to comparing positive/negative zero. - ASSERT_EQ(C[f], CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], " - << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K - << ", offa=" << int(offa) << ", offb=" << offb; - } - } - } - } - - private: - template - void DequantizeLinear(const qint8_t* Input, - float* Output, - size_t N, - float scale, - qint8_t offset) { - for (size_t n = 0; n < N; n++) { - Output[n] = float((int32_t(Input[n]) - offset)) * scale; - } - } - - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferAFloat; - MatrixGuardBuffer BufferBFloat; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - MatrixGuardBuffer BufferBias; - const bool AIsSigned = std::is_signed::value; - const bool BIsSigned = std::is_signed::value; - - public: - static const char* GetTestSuiteName() { - static std::string suite_name = std::string("QGemm") + - (std::is_signed::value ? "S8" : "U8") + - (std::is_signed::value ? "S8" : "U8") + - (Packed ? "_Fp32_Packed" : "_Fp32_NoPack") + - (Threaded ? 
"_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h deleted file mode 100644 index 40f688a16ecca..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h +++ /dev/null @@ -1,182 +0,0 @@ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_qgemm.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class QgemmShortExecuteTest; - -template -class QgemmShortExecuteTest : public MlasTestFixture> { - public: - explicit QgemmShortExecuteTest(bool use_offb, size_t M, size_t N, size_t K, size_t Batch, uint8_t offa, uint8_t offb) - : use_offb_(use_offb), M_(M), N_(N), K_(K), Batch_(Batch), offa_(offa), offb_(offb) { - } - - void TestBody() override { - if (use_offb_) { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, offa_, offb_); - } else { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, offa_); - } - } - - static size_t RegisterSingleTest(bool use_offb, size_t M, size_t N, size_t K, size_t Batch, uint8_t offa, uint8_t offb) { - std::stringstream ss; - ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" - << "offa" << (unsigned)offa << "/" - << "offb"; - if (use_offb) { - ss << (unsigned)offb; - } else { - ss << "--"; - } - auto test_name = ss.str(); - - testing::RegisterTest( - MlasQgemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture>* { - return new QgemmShortExecuteTest( - use_offb, M, N, K, Batch, offa, offb); - }); - - return 1; - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, uint8_t offa, uint8_t offb) { - return RegisterSingleTest(true, M, N, K, Batch, offa, offb); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, uint8_t offa) { - return RegisterSingleTest(false, M, N, K, Batch, offa, 0); - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 1, 14, 211); - test_registered += RegisterSingleTest(b, b, b, 1, 21); - if (!Packed) { - test_registered += RegisterSingleTest(b, b, b, 3 + b / 4, 14, 211); - test_registered += RegisterSingleTest(b, b, b, 2 + b / 4, 21); - } - } - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 1, 14, 211); - test_registered += RegisterSingleTest(b, b, b, 1, 17); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, 1, 34, 1); - test_registered += RegisterSingleTest(b, b, b, 1, 1); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, 1, 85, 173); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, 1, 0, 0); - test_registered += RegisterSingleTest(1, 32, b, 1, 0, 0); - test_registered += RegisterSingleTest(1, b, b, 1, 0, 0); - if (!Packed) { - test_registered += RegisterSingleTest(1, b, 32, 3, 0, 0); - test_registered += RegisterSingleTest(1, 32, b, 5, 0, 0); - } - } - test_registered += RegisterSingleTest(43, 500, 401, 1, 183, 223); - test_registered += RegisterSingleTest(1023, 1023, 1023, 1, 5, 8); - test_registered += RegisterSingleTest(1023, 1023, 1023, 1, 7); - if 
(!Packed) { - test_registered += RegisterSingleTest(43, 500, 401, 7, 183, 223); - test_registered += RegisterSingleTest(1023, 1023, 1023, 3, 5, 8); - } - size_t dims[] = {400, 500, 1024}; - size_t kdims[] = {1003, 2048 + 50, 4096 - 100}; - for (size_t m : dims) { - for (size_t n : dims) { - for (size_t k : kdims) { - test_registered += RegisterSingleTest(m - 3, n + 5, k, 1, 14, 211); - test_registered += RegisterSingleTest(m + 5, n - 3, k, 1, 17); - } - } - } - - return test_registered; - } - - private: - bool use_offb_; - size_t M_, N_, K_, Batch_; - uint8_t offa_, offb_; -}; - -template -class QgemmShortExecuteTest : public MlasTestFixture> { - public: - explicit QgemmShortExecuteTest(size_t M, size_t N, size_t K, uint8_t offa, uint8_t offb) - : M_(M), N_(N), K_(K), offa_(offa), offb_(offb) { - } - - void TestBody() override { - // Batching code is agnostic to result type. Only cover batches above, not here. - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, 1, offa_, offb_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, uint8_t offa, uint8_t offb) { - std::stringstream ss; - ss << "M" << M << "xN" << N << "xK" << K << "/" - << "offa" << (unsigned)offa << "/" - << "offb" << (unsigned)offb; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasQgemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture>* { - return new QgemmShortExecuteTest(M, N, K, offa, offb); - }); - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 34, 46); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, 15, 191); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, 223, 73); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, 0, 0); - } - test_registered += RegisterSingleTest(43, 503, 401, 183, 223); - test_registered += RegisterSingleTest(1024, 1024, 256, 13, 15); - - return test_registered; - } - - private: - bool use_offb_; - size_t M_, N_, K_; - uint8_t offa_, offb_; -}; diff --git a/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp b/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp deleted file mode 100644 index 5876f186eaa0d..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_qlinear_binaryop.cpp +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "test_util.h" - -class MlasQLinearBinaryOpTest : public MlasTestBase { - public: - typedef void(MLASCALL* QLinearBinaryOpS8)( - const int8_t* InputA, float ScaleA, int32_t ZeroPointA, - const int8_t* InputB, float ScaleB, int32_t ZeroPointB, - float ScaleC, int32_t ZeroPointC, int8_t* OutputC, - size_t N, bool IsScalarB); - typedef void(MLASCALL* QLinearBinaryOpU8)( - const uint8_t* InputA, float ScaleA, int32_t ZeroPointA, - const uint8_t* InputB, float ScaleB, int32_t ZeroPointB, - float ScaleC, int32_t ZeroPointC, uint8_t* OutputC, - size_t N, bool IsScalarB); - - private: - std::function ScalarOp; - std::string ScalarOpName; - QLinearBinaryOpS8 QLinearS8Op; - QLinearBinaryOpU8 QLinearU8Op; - MatrixGuardBuffer BufferInputA; - MatrixGuardBuffer BufferInputB; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - template - T QLinearBinaryScalar(T a, - float ScaleA, - int32_t ZeroPointA, - T b, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC) { - constexpr int qmax = std::numeric_limits::max(); - constexpr int qmin = std::numeric_limits::min(); - - float ValueA = ScaleA * (static_cast(a) - ZeroPointA); - float ValueB = ScaleB * (static_cast(b) - ZeroPointB); - float ValueC = std::nearbyintf(ScalarOp(ValueA, ValueB) / ScaleC) + ZeroPointC; - int qc = static_cast(ValueC); - qc = std::min(qc, qmax); - qc = std::max(qc, qmin); - return static_cast(qc); - } - - template - void Test(void(MLASCALL* QLinearBinaryOp)( - const T* InputA, float ScaleA, int32_t ZeroPointA, - const T* InputB, float ScaleB, int32_t ZeroPointB, - float ScaleC, int32_t ZeroPointC, T* OutputC, - size_t N, bool IsScalarB), - size_t N, - bool IsScalarB, - float ScaleA, - int32_t ZeroPointA, - float ScaleB, - int32_t ZeroPointB, - float ScaleC, - int32_t ZeroPointC) { - T* InputA = (T*)BufferInputA.GetBuffer(N); - T* InputB = (T*)BufferInputB.GetBuffer(IsScalarB ? 1 : N); - T* OutputC = (T*)BufferOutput.GetBuffer(N); - T* OutputReference = (T*)BufferOutputReference.GetBuffer(N); - - constexpr int MinimumValue = (int)std::numeric_limits::min(); - constexpr int MaximumValue = (int)std::numeric_limits::max(); - std::default_random_engine generator(static_cast(N)); - std::uniform_int_distribution distribution(MinimumValue, MaximumValue); - - if (IsScalarB) { - InputB[0] = static_cast(distribution(generator)); - } - for (size_t n = 0; n < N; n++) { - InputA[n] = static_cast(distribution(generator)); - if (!IsScalarB) { - InputB[n] = static_cast(distribution(generator)); - } - OutputReference[n] = QLinearBinaryScalar(InputA[n], ScaleA, ZeroPointA, InputB[IsScalarB ? 0 : n], ScaleB, ZeroPointB, ScaleC, ZeroPointC); - } - - QLinearBinaryOp(InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N, IsScalarB); - - for (size_t n = 0; n < N; n++) { - int diff = (int)OutputC[n] - (int)OutputReference[n]; - ASSERT_TRUE(diff >= -1 && diff <= 1) - << ", IsScalarB=" << static_cast(IsScalarB) << ", @" << n << " of " << N << ", " - << static_cast(InputA[n]) << "(" << ScaleA << "," << ZeroPointA << "), " - << static_cast(InputB[IsScalarB ? 
0 : n]) << "(" << ScaleB << "," << ZeroPointB << ") ==> " - << static_cast(OutputC[n]) << "(" << ScaleC << "," << ZeroPointC << "), " - << " expecting:" << static_cast(OutputReference[n]); - } - } - - public: - explicit MlasQLinearBinaryOpTest(std::function P_ScalarOp, - const std::string& P_ScalarOpName, - QLinearBinaryOpS8 P_QLinearS8Op, - QLinearBinaryOpU8 P_QLinearU8Op) - : ScalarOp(P_ScalarOp), - ScalarOpName(P_ScalarOpName), - QLinearS8Op(P_QLinearS8Op), - QLinearU8Op(P_QLinearU8Op) { - } - - void ExecuteShort(void) override { - static const uint8_t zero_points[] = {0, 18, 75, 128, 157, 231, 255}; - static const float c_scales[] = {18.0f, 90.0f}; - - const int8_t* s_zero_points = (const int8_t*)(&zero_points[0]); - for (size_t a = 0; a < _countof(zero_points); a++) { - for (size_t b = 0; b < _countof(zero_points); b++) { - for (size_t c = 0; c < _countof(zero_points); c++) { - for (size_t s = 0; s < _countof(c_scales); s++) { - for (size_t n = 1; n < 128; n++) { - // u8, vector + vector - Test(QLinearU8Op, n, false, 10.f, zero_points[a], 10.f, zero_points[b], c_scales[s], zero_points[c]); - - // u8, vector + scalar - Test(QLinearU8Op, n, true, 10.f, zero_points[a], 10.f, zero_points[b], c_scales[s], zero_points[c]); - - // s8, vector + vector - Test(QLinearS8Op, n, false, 10.f, s_zero_points[a], 10.f, s_zero_points[b], c_scales[s], s_zero_points[c]); - - // s8, vector + scalar - Test(QLinearS8Op, n, true, 10.f, s_zero_points[a], 10.f, s_zero_points[b], c_scales[s], s_zero_points[c]); - } - } - } - } - } - } -}; - -class MlasQLinearAddTest : public MlasQLinearBinaryOpTest { - public: - MlasQLinearAddTest() : MlasQLinearBinaryOpTest( - [](float a, float b) { return a + b; }, - "+", - MlasQLinearAdd, - MlasQLinearAdd) {} - - static const char* GetTestSuiteName() { - static const std::string suite_name("QLinearAdd"); - return suite_name.c_str(); - } -}; - -class MlasQLinearMulTest : public MlasQLinearBinaryOpTest { - public: - MlasQLinearMulTest() : MlasQLinearBinaryOpTest( - [](float a, float b) { return a * b; }, - "*", - MlasQLinearMul, - MlasQLinearMul) {} - - static const char* GetTestSuiteName() { - static const std::string suite_name("QLinearMul"); - return suite_name.c_str(); - } -}; - -static bool UNUSED_VARIABLE added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests::RegisterShortExecute(); - count += MlasDirectShortExecuteTests::RegisterShortExecute(); - } - return count; -}); diff --git a/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp b/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp deleted file mode 100644 index e6c230df57fbc..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_qlinear_gavgpool.cpp +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -#include - -template -class MlasQLinearGlobalAveragePoolTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - static const std::vector ZeroPoints; - - static void CalculateGlobalAvgPool( - const T8Bits* x, int64_t batch, int64_t channel, int64_t hw, bool channel_last, - T8Bits* y, int32_t x_zero_point, float x_scale, int32_t y_zero_point, float y_scale) { - int32_t bias = -x_zero_point * static_cast(hw); - int64_t stride_image = channel_last ? channel : 1; - int64_t stride_channel = channel_last ? 
1 : hw; - - for (int64_t b = 0; b < batch; ++b) { - const T8Bits* bx = x + b * hw * channel; - T8Bits* by = y + b * channel; - for (int64_t c = 0; c < channel; ++c) { - const T8Bits* ix = bx + c * stride_channel; - int32_t sum = 0; - for (int64_t i = 0; i < hw; ++i) { - sum += static_cast(*ix); - ix += stride_image; - } - sum += bias; - int32_t r = static_cast(std::nearbyintf(x_scale * sum / static_cast(hw) / y_scale)); - r += y_zero_point; - r = std::min((int32_t)(std::numeric_limits::max()), r); - r = std::max((int32_t)(std::numeric_limits::lowest()), r); - by[c] = static_cast(r); - } - } - } - - static void CompareResultWithGold(size_t Batch, size_t Channel, - T8Bits* Output, T8Bits* OutputReference, std::string& info) { - size_t n = 0; - for (size_t b = 0; b < Batch; ++b) { - for (size_t c = 0; c < Channel; c++) { - int diff = abs((int)Output[n] - (int)OutputReference[n]); - ASSERT_LE(diff, 1) << " got:" << int(Output[n]) << " expecting:" << int(OutputReference[n]) << " @[" << b << "," << c << "], " << info.c_str(); - } - } - } - - static std::string GetTestInfo(bool channel_last, - size_t Batch, - size_t Stride, - size_t Channel, - size_t ImageSize, - float InputScale, - T8Bits InputZeroPoint, - float OutputScale, - T8Bits OutputZeroPoint) { - std::stringstream ss; - ss << (channel_last ? "Nhwc_" : "Nchw_"); - ss << Batch << "x [C=" << Stride << "-" << Channel << "] x" << ImageSize << "-"; - ss << "(" << (int)InputZeroPoint << "," << InputScale << "," << (int)OutputZeroPoint << "," << OutputScale << ")"; - return ss.str(); - } - - void Test(bool channel_last, - size_t Batch, - size_t Stride, - size_t Channel, - size_t ImageSize, - float InputScale, - T8Bits InputZeroPoint, - float OutputScale, - T8Bits OutputZeroPoint, - int32_t UnalignedOffset = 0) { - size_t N = Batch * Stride * ImageSize; - size_t ResultLen = Batch * Stride; - T8Bits* Input = BufferInput.GetBuffer(N); - T8Bits* Output = BufferOutput.GetBuffer(ResultLen); - T8Bits* Gold = BufferOutputReference.GetBuffer(ResultLen); - std::string test_info = GetTestInfo( - channel_last, Batch, Stride, Channel, ImageSize, - InputScale, InputZeroPoint, OutputScale, OutputZeroPoint); - - std::default_random_engine generator(static_cast(N)); - std::uniform_int_distribution distribution(std::numeric_limits::lowest(), std::numeric_limits::max()); - for (size_t n = 0; n < N; n++) { - Input[n] = static_cast(distribution(generator)); - } - CalculateGlobalAvgPool( - Input, Batch, Stride, ImageSize, channel_last, - Gold, InputZeroPoint, InputScale, OutputZeroPoint, OutputScale); - - if (!channel_last) { - std::vector acc(MlasQLinearSafePaddingElementCount(sizeof(int32_t), ResultLen + UnalignedOffset)); - MlasQLinearGlobalAveragePoolNchw( - Input, InputScale, InputZeroPoint, Output, - OutputScale, OutputZeroPoint, ResultLen, ImageSize, acc.data() + UnalignedOffset); - } else { - std::vector acc(MlasQLinearSafePaddingElementCount(sizeof(int32_t), Channel + UnalignedOffset)); - std::vector zero(MlasQLinearSafePaddingElementCount(sizeof(T8Bits), Channel + UnalignedOffset)); - if (Stride == Channel) { - MlasQLinearGlobalAveragePoolNhwc( - Input, InputScale, InputZeroPoint, Output, - OutputScale, OutputZeroPoint, Batch, ImageSize, Stride, Channel, - acc.data() + UnalignedOffset, zero.data() + UnalignedOffset); - } else { - for (size_t tc = 0; tc < Stride; tc += Channel) { - size_t cg = ((tc + Channel <= Stride) ? 
Channel : (Stride - tc)); - MlasQLinearGlobalAveragePoolNhwc( - Input + tc, InputScale, InputZeroPoint, Output + tc, - OutputScale, OutputZeroPoint, Batch, ImageSize, Stride, cg, - acc.data() + UnalignedOffset, zero.data() + UnalignedOffset); - } - } - } - - CompareResultWithGold(Batch, Channel, Output, Gold, test_info); - } - - public: - static const char* GetTestSuiteName() { - constexpr bool is_signed = std::is_signed::value; - static const std::string suite_name(is_signed ? "QLinearGlobalAvgPoolS8" : "QLinearGlobalAvgPoolU8"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - static const float scales[] = {18.0f, 90.0f}; - static const size_t Batch[] = {1, 3}; - static const size_t Stride[] = {7, 8, 63, 256}; - static const size_t ImageSize[] = {7, 8, 64}; - static int unalign_offset = 0; - - for (int channel_last = 0; channel_last <= 1; ++channel_last) { - for (size_t b = 0; b < _countof(Batch); b++) { - for (size_t xzp = 0; xzp < ZeroPoints.size(); xzp++) { - for (size_t yzp = 0; yzp < ZeroPoints.size(); yzp++) { - for (size_t xs = 0; xs < _countof(scales); ++xs) { - for (size_t ys = 0; ys < _countof(scales); ++ys) { - for (size_t i = 0; i < _countof(ImageSize); i++) { - for (size_t s = 0; s < _countof(Stride); s++) { - Test(channel_last != 0, Batch[b], Stride[s], Stride[s], ImageSize[i], - scales[xs], ZeroPoints[xzp], scales[ys], ZeroPoints[yzp], unalign_offset); - if (channel_last == 1 && Stride[s] > 32) { - Test(channel_last != 0, Batch[b], Stride[s], 32, ImageSize[i], - scales[xs], ZeroPoints[xzp], scales[ys], ZeroPoints[yzp], unalign_offset); - } - unalign_offset = (unalign_offset + 1) & 3; - } - } - } - } - } - } - } - } - } -}; - -template <> -const std::vector MlasQLinearGlobalAveragePoolTest::ZeroPoints = {-128, -110, 1, 103, 127}; - -template <> -const std::vector MlasQLinearGlobalAveragePoolTest::ZeroPoints = {0, 18, 128, 231, 255}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (is_short_execute) { - return MlasDirectShortExecuteTests>::RegisterShortExecute() + - MlasDirectShortExecuteTests>::RegisterShortExecute(); - } - return (size_t)0; -}); diff --git a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp deleted file mode 100644 index 7c160b6696265..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_quantizelinear.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "test_util.h" - -template -class MlasQuantizeLinearTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - void GenerateReference(const float* Input, QuantInt* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { - for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); - FloatValue = std::max(FloatValue, static_cast(std::numeric_limits::min())); - FloatValue = std::min(FloatValue, static_cast(std::numeric_limits::max())); - OutputReference[n] = static_cast(FloatValue); - } - } - - void Test(size_t N) { - float* Input = BufferInput.GetBuffer(N); - QuantInt* Output = BufferOutput.GetBuffer(N); - QuantInt* OutputReference = BufferOutputReference.GetBuffer(N); - - std::default_random_engine generator(static_cast(N)); - - std::uniform_real_distribution min_gen(-10.f, -10e-3f); - float MinimumValue = min_gen(generator); - - std::uniform_real_distribution max_gen(10e-3f, 10.f); - float MaximumValue = max_gen(generator); - - float Scale = (MaximumValue - MinimumValue) / 512.f; - - std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), - std::numeric_limits::max()); - QuantInt ZeroPoint = static_cast(zp_distribution(generator)); - - std::uniform_real_distribution distribution(MinimumValue, MaximumValue); - for (size_t n = 0; n < N; n++) { - Input[n] = distribution(generator); - } - - GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); - MlasQuantizeLinear(Input, Output, N, Scale, ZeroPoint); - - for (size_t n = 0; n < N; n++) { - ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; - } - } - - public: - static const char* GetTestSuiteName() { - if constexpr (std::is_same_v) { - return "QuantizeLinearS8"; - } else if (std::is_same_v) { - return "QuantizeLinearU8"; - } else if (std::is_same_v) { - return "QuantizeLinearS16"; - } else { - return "QuantizeLinearU16"; - } - } - - void ExecuteShort(void) override { - for (size_t n = 1; n <= 512; n++) { - Test(n); - } - } -}; - -template -class MlasQuantizeLinear4BitTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - int32_t MinVal() const { - if constexpr (Signed) { - return -8; - } else { - return 0; - } - } - - int32_t MaxVal() const { - if constexpr (Signed) { - return 7; - } else { - return 15; - } - } - - void GenerateReference(const float* Input, uint8_t* OutputReference, size_t N, float Scale, - int8_t ZeroPoint) { - for (size_t n = 0; n < N; n++) { - float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); - FloatValue = std::max(FloatValue, static_cast(MinVal())); - FloatValue = std::min(FloatValue, static_cast(MaxVal())); - - int8_t IntValue = static_cast(FloatValue); - - size_t i = n >> 1; - size_t j = n & 0x1; - uint8_t Shift = 4 * static_cast(j); - uint8_t Mask = 0xF << Shift; - - OutputReference[i] &= ~Mask; // Clear 4-bit lane - OutputReference[i] |= static_cast((IntValue & 0xF) << Shift); // Set 4-bit lane - } - } - - void Test(size_t N) { - size_t OutBufLen = (N + 1) / 2; - float* Input = BufferInput.GetBuffer(N); - uint8_t* Output = BufferOutput.GetBuffer(OutBufLen); - uint8_t* OutputReference = BufferOutputReference.GetBuffer(OutBufLen); - - std::default_random_engine generator(static_cast(N)); - - std::uniform_real_distribution min_gen(-10.f, -10e-3f); - float MinimumValue = 
min_gen(generator); - - std::uniform_real_distribution max_gen(10e-3f, 10.f); - float MaximumValue = max_gen(generator); - - float Scale = (MaximumValue - MinimumValue) / 32.f; - - std::uniform_int_distribution zp_distribution(MinVal(), MaxVal()); - int8_t ZeroPoint = static_cast(zp_distribution(generator)); - - std::uniform_real_distribution distribution(MinimumValue, MaximumValue); - for (size_t n = 0; n < N; n++) { - Input[n] = distribution(generator); - } - - GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); - - if constexpr (Signed) { - MlasQuantizeLinearS4(Input, Output, N, Scale, ZeroPoint); - } else { - MlasQuantizeLinearU4(Input, Output, N, Scale, ZeroPoint); - } - - for (size_t n = 0; n < N; n++) { - size_t i = n >> 1; - size_t j = n & 0x1; - const uint8_t Shift = 4 * static_cast(j); - - int32_t actual_val = (Output[i] >> Shift) & 0xF; - int32_t expected_val = (OutputReference[i] >> Shift) & 0xF; - - if constexpr (Signed) { - constexpr uint8_t SignExtShift = (sizeof(int32_t) * 8) - 4; - actual_val = (actual_val << SignExtShift) >> SignExtShift; - expected_val = (expected_val << SignExtShift) >> SignExtShift; - } - - ASSERT_EQ(actual_val, expected_val) << ", size=" << N - << ", index=" << n - << ", nibble=" << j; - } - } - - public: - static const char* GetTestSuiteName() { - if constexpr (Signed) { - return "QuantizeLinearS4"; - } else { - return "QuantizeLinearU4"; - } - } - - void ExecuteShort(void) override { - for (size_t n = 1; n <= 512; n++) { - Test(n); - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - } - return count; -}); diff --git a/onnxruntime/test/mlas/unittest/test_reorder_output.cpp b/onnxruntime/test/mlas/unittest/test_reorder_output.cpp deleted file mode 100644 index e39abd8578da4..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_reorder_output.cpp +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "test_util.h" - -class MlasReorderOutputTest : public MlasTestBase { - private: - const size_t BlockSize = MlasNchwcGetBlockSize(); - - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutput2; - MatrixGuardBuffer BufferOutputReference; - - void Test(size_t BatchCount, size_t Channels, size_t Height, size_t Width) { - size_t NchwcChannels = (Channels + BlockSize - 1) & ~(BlockSize - 1); - - size_t InputBufferElements = BatchCount * NchwcChannels * Height * Width; - size_t OutputBufferElements = BatchCount * Channels * Height * Width; - - const float* Input = BufferInput.GetBuffer(InputBufferElements); - float* Output = BufferOutput.GetBuffer(OutputBufferElements); - float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements); - - int64_t NchwOutputShape[] = {int64_t(BatchCount), int64_t(Channels), int64_t(Height), int64_t(Width)}; - - std::fill_n(Output, OutputBufferElements, -0.5f); - std::fill_n(OutputReference, OutputBufferElements, -0.5f); - - MlasReorderOutputNchw(NchwOutputShape, Input, Output, GetMlasThreadPool()); - ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, false); - ASSERT_EQ(memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)), 0) - << " [Nchw] batch=" << BatchCount << ", channels=" << Channels - << ", height=" << Height << ", width=" << Width; - - int64_t NhwcOutputShape[] = {int64_t(BatchCount), int64_t(Height), int64_t(Width), int64_t(Channels)}; - - std::fill_n(Output, OutputBufferElements, -0.5f); - std::fill_n(OutputReference, OutputBufferElements, -0.5f); - - MlasReorderOutputNhwc(NhwcOutputShape, Input, Output); - ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, true); - ASSERT_EQ(memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)), 0) - << " [Nhwc] batch=" << BatchCount << ", channels=" << Channels - << ", height=" << Height << ", width=" << Width; - } - - void ReferenceReorderOutput(size_t BatchCount, - size_t Channels, - size_t Height, - size_t Width, - const float* Input, - float* Output, - bool NhwcFormat) { - size_t NchwcChannels = (Channels + (BlockSize - 1)) & ~(BlockSize - 1); - size_t SpatialSize = Height * Width; - - size_t ChannelStride = NhwcFormat ? 1 : SpatialSize; - size_t SpatialStride = NhwcFormat ? Channels : 1; - - for (size_t n = 0; n < BatchCount; n++) { - for (size_t c = 0; c < Channels; c++) { - const float* input = Input + ((c & ~(BlockSize - 1)) * SpatialSize) + (c & (BlockSize - 1)); - float* output = Output + (c * ChannelStride); - - for (size_t hw = 0; hw < SpatialSize; hw++) { - output[hw * SpatialStride] = input[hw * BlockSize]; - } - } - - Input += NchwcChannels * SpatialSize; - Output += Channels * SpatialSize; - } - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("ReorderOutput"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - for (size_t c = 1; c < 48; c++) { - Test(1, c, 112, 112); - Test(4, c, 15, 21); - Test(16, c, 11, 11); - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return (MlasNchwcGetBlockSize() > 1 && is_short_execute) - ? 
MlasDirectShortExecuteTests::RegisterShortExecute() - : 0; -}); diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp deleted file mode 100644 index f85fe97776dc1..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - test_sbgemm.cpp - -Abstract: - - Tests for MLAS bf16 precision GEMM. - ---*/ - -#if defined(__aarch64__) && defined(__linux__) - -#include "test_sbgemm.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class SBGemmShortExecuteTest : public MlasTestFixture> { - public: - explicit SBGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) - : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { - std::stringstream ss; - ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" - << "hasBias" << hasBias; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasSBGemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture>* { - return new SBGemmShortExecuteTest( - M, N, K, Batch, hasBias); - }); - - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 1, false); - test_registered += RegisterSingleTest(b, b, b, 1, true); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, 1, false); - test_registered += RegisterSingleTest(b, b, b, 1, true); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, 1, true); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, 1, false); - test_registered += RegisterSingleTest(1, 32, b, 1, true); - test_registered += RegisterSingleTest(1, b, b, 1, false); - if (!Packed) { - test_registered += RegisterSingleTest(1, b, 32, 3, true); - test_registered += RegisterSingleTest(1, 32, b, 5, false); - } - } - // TODO: check why the cosine similarly is < 0.99 for this shape alone - // test_registered += RegisterSingleTest(43, 500, 401, 1, true); - test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); - if (!Packed) { - test_registered += RegisterSingleTest(43, 500, 401, 5, true); - test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); - } - - return test_registered; - } - - private: - size_t M_, N_, K_, Batch_; - bool hasBias_; -}; - -static size_t SBGemmRegistLongExecute() { - size_t count = 0; - - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasSBGemmPackBSize(128, 128) > 0) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasSBGemmPackBSize(128, 128) > 0) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - } - - return count; -} - -static size_t SBGemmRegistShortExecute() { - size_t count = 0; - - count += 
SBGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasSBGemmPackBSize(128, 128) > 0) { - count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - } - - if (GetMlasThreadPool() != nullptr) { - count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasSBGemmPackBSize(128, 128) > 0) { - count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - } - } - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - if (!MlasBf16AccelerationSupported()) { - return false; - } - - if (is_short_execute) { - return SBGemmRegistShortExecute() > 0; - } - return SBGemmRegistLongExecute() > 0; -}); -#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h deleted file mode 100644 index 13701e2e3de46..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.h +++ /dev/null @@ -1,281 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the MIT License. - -Module Name: - - test_sbgemm.h - -Abstract: - - Tests for MLAS bf16 precision GEMM. - ---*/ - -#if defined(__aarch64__) && defined(__linux__) - -#pragma once - -#include "test_util.h" - -template -void SmallFloatFill(T* start, size_t size) { - constexpr float MinimumFillValue = -11.0f; - auto FillAddress = start; - size_t offset = size % 23; - - for (size_t i = 0; i < size; i++) { - offset = (offset + 21) % 23; - *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); - } -} - -float cosine_similarity(const float* A, const float* B, size_t Vector_Length) { - float dot = 0.0, denom_a = 0.0, denom_b = 0.0; - for (size_t i = 0u; i < Vector_Length; ++i) { - dot += A[i] * B[i]; - denom_a += A[i] * A[i]; - denom_b += B[i] * B[i]; - } - return dot / (sqrt(denom_a) * sqrt(denom_b)); -} - -/** - * @brief Test class for bf16 precision GEMM - * @tparam AType Data type of A matrix, need to be float - * @tparam BType Data type of b matrix, can be either float or prepacked bf16 - */ -template -class MlasSBGemmTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferBPacked; - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferBias; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - MatrixGuardBuffer BufferFloatC; - MLAS_THREADPOOL* threadpool_; - - void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { - size_t PackedBSize = MlasSBGemmPackBSize(N, K); - if (PackedBSize == 0) { - return nullptr; - } - void* PackedB = BufferBPacked.GetBuffer(PackedBSize); - if (std::is_same::value) { - MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); - } else { - } - return PackedB; - } - - void CallSBGemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const float* A, - size_t lda, - const BType* B, - size_t ldb, - const float* Bias, - float* C, - size_t ldc) { - std::vector GemmParameters(BatchSize); - - for (size_t i = 0; i < GemmParameters.size(); i++) { - auto& params = GemmParameters[i]; - params.A = A + (M * lda * i); - params.lda = lda; - if (nullptr != Bias) { - params.Bias = reinterpret_cast(Bias + N * i); - } else { - params.Bias = nullptr; - } - params.C = reinterpret_cast(C + (M * ldc * i)); - params.ldc = ldc; - params.AIsfp32 = true; - params.BIsfp32 = true; - - if (Packed) { - ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; - params.B = PackB(N, K, B, ldb); 
- params.ldb = 0; - params.BIsfp32 = false; - } else { - params.B = B + (K * N * i); - params.ldb = ldb; - } - } - - MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); - } - - void ReferenceSgemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const AType* A, - const BType* B, - const float* Bias, - float* C) { - constexpr size_t KStride = 256; - - for (size_t batch = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const AType* a = A + M * K * batch + m * K; - const BType* b = B + K * N * batch + n; - float* c = C + (M * N * batch) + (m * N) + n; - - for (size_t k = 0; k < K; k += KStride) { - float sum = 0.0f; - if (k == 0 && Bias != nullptr) { - sum = float(Bias[n]); - } - for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { - float down(float(*b) * float(*a) + sum); - sum = float(down); - b += N; - a += 1; - } - if (k == 0) { - *c = sum; - } else { - float d(sum + *c); - *c = float(d); - } - } - } - } - if (Bias) { - Bias += N; - } - } - } - - public: - MlasSBGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { - AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); - AType Atail[16]; - std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); - - BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); - BType Btail[16]; - std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); - - float BiasTail[16]; - const float* Bias = nullptr; - if (withBias) { - Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); - std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); - } - - float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); - float* CReference = BufferCReference.GetFilledBuffer( - N * M * BatchSize, - [](float* start, size_t size) { - std::fill_n(start, size, -1.0f); - }); - this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N); - ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); - const float cosine_similarity_threshold = 0.98; - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - if (!(CloseEnough(float(C[f]), CReference[f]))) { - float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); - if (abs(cos_sim) < cosine_similarity_threshold) { - ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; - } else { - break; - } - } - } - } - } - - ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; - ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; - if (withBias) { - ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; - } - } - - private: - public: - static const char* GetTestSuiteName() { - static std::string suite_name = std::string("SBGemmFP") + - (std::is_same::value ? "32" : "16") + - (std::is_same::value ? "32" : "16") + - (Packed ? "_Packed" : "_NoPack") + - (Threaded ? 
"_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void ExecuteLong(void) override { - for (size_t M = 16; M < 160; M += 32) { - for (size_t N = 16; N < 160; N += 32) { - static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; - for (size_t k = 0; k < _countof(ks); k++) { - size_t K = ks[k]; - - Test(M, N, K, 1, false); - Test(M, N, K, 1, true); - Test(M + 1, N, K, 1, false); - Test(M, N + 1, K, 1, true); - Test(M + 1, N + 1, K, 1, false); - Test(M + 3, N + 2, K, 1, true); - Test(M + 4, N, K, 1, false); - Test(M, N + 4, K, 1, true); - Test(M + 4, N + 4, K, 1, false); - Test(M + 3, N + 7, K, 1, true); - Test(M + 8, N, K, 1, false); - Test(M, N + 8, K, 1, true); - Test(M + 12, N + 12, K, 1, false); - Test(M + 13, N, K, 1, true); - Test(M, N + 15, K, 1, false); - Test(M + 15, N + 15, K, 1, false); - if (!Packed) { - Test(M, N, K, 7, false); - Test(M + 3, N, K, 8, true); - Test(M, N + 1, K, 9, false); - Test(M + 12, N, K, 10, true); - Test(M, N + 15, K, 11, false); - Test(M + 15, N + 15, K, 12, true); - } - } - } - printf("M %zd\n", M); - } - - for (size_t M = 1; M < 160; M++) { - for (size_t N = 1; N < 160; N++) { - for (size_t K = 1; K < 160; K++) { - Test(M, N, K, 1, true); - } - } - printf("M %zd\n", M); - } - - for (size_t M = 160; M < 320; M += 24) { - for (size_t N = 112; N < 320; N += 24) { - for (size_t K = 1; K < 16; K++) { - Test(M, N, K, 1, true); - } - for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1, false); - } - } - printf("M %zd\n", M); - } - } -}; - -#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp b/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp deleted file mode 100644 index 34f17843b0726..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_scaleoutput.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -class MlasScaleOutputTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputRef; - MatrixGuardBuffer BufferScale; - - void Test(size_t M, size_t N, bool PerColumn, bool AccumulateMode) { - int32_t* Input = BufferInput.GetBuffer(M * N); - float* Output = BufferOutput.GetBuffer(M * N); - float* OutputRef = BufferOutputRef.GetBuffer(M * N); - float* Scale = BufferScale.GetBuffer(PerColumn ? N : 1); - - std::default_random_engine generator(static_cast(M * N)); - std::uniform_real_distribution real_distribution(-1.0f, 1.0f); - std::uniform_int_distribution int_distribution(std::numeric_limits::min(), - std::numeric_limits::max()); - - for (size_t s = 0; s < M * N; s++) { - Input[s] = int_distribution(generator); - Output[s] = OutputRef[s] = real_distribution(generator); - } - - for (size_t s = 0; s < (PerColumn ? N : 1); s++) { - Scale[s] = real_distribution(generator); - } - - // Compute Reference Value - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - float current_scale = PerColumn ? Scale[n] : Scale[0]; - if (AccumulateMode) { - OutputRef[m * N + n] += Input[m * N + n] * current_scale; - } else { - OutputRef[m * N + n] = Input[m * N + n] * current_scale; - } - } - } - - // Compute Output with MLAS - MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR OutputProcessor( - Output, N, Scale, nullptr, - AccumulateMode ? MLAS_QGEMM_OUTPUT_MODE::AccumulateMode : MLAS_QGEMM_OUTPUT_MODE::ZeroMode, - PerColumn ? 
MLAS_QUANTIZATION_GRANULARITY::PerColumn : MLAS_QUANTIZATION_GRANULARITY::PerMatrix); - OutputProcessor.Process(Input, 0, 0, M, N, N); - - constexpr float epsilon = 1e-6f; - - for (size_t n = 0; n < M * N; n++) { - float diff = std::fabs((Output[n] - OutputRef[n]) / OutputRef[n]); - ASSERT_LE(diff, epsilon) - << " @[" << n / N << "," << n % N << "], total:[" << M << "," << N << "], got:" - << Output[n] << ", expecting:" << OutputRef[n]; - } - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name("ScaleOutput"); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - for (size_t m = 1; m < 18; m++) { - for (size_t n = 1; n < 18; n++) { - Test(m, n, true, true); - Test(m, n, true, false); - Test(m, n, false, true); - Test(m, n, false, false); - } - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? MlasDirectShortExecuteTests::RegisterShortExecute() : 0; -}); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp deleted file mode 100644 index fb4ebbee77faf..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -template -class MlasSoftmaxTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - MLAS_THREADPOOL* threadpool_; - - void Test(size_t N, size_t D, float MinimumValue, float MaximumValue) { - float* Input = BufferInput.GetBuffer(N * D); - float* Output = BufferOutput.GetBuffer(N * D); - float* OutputReference = BufferOutputReference.GetBuffer(N * D); - - std::default_random_engine generator(static_cast(N * D)); - std::uniform_real_distribution distribution(MinimumValue, MaximumValue); - - for (size_t nd = 0; nd < N * D; nd++) { - Input[nd] = distribution(generator); - } - - Test(Input, Output, OutputReference, N, D, false, true); - Test(Input, Output, OutputReference, N, D, true, true); - Test(Input, Output, OutputReference, N, D, false, false); - Test(Input, Output, OutputReference, N, D, true, false); - } - - void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); - ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); - - constexpr float AbsoluteTolerance = 1e-6f; - constexpr float RelativeTolerance = 1e-6f; - - for (size_t nd = 0; nd < N * D; nd++) { - float diff = std::fabs(Output[nd] - OutputReference[nd]); - ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[nd]) * RelativeTolerance) - << "LogSoftmax:" << (int)LogSoftmax << " difference " << N << "/" << D - << ", got: " << Output[nd] << ", expecting: " << OutputReference[nd]; - } - } - - void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - for (size_t n = 0; n < N; n++) { - float MaximumValue = std::numeric_limits::lowest(); - - for (size_t d = 0; d < D; d++) { - MaximumValue = (std::max)(MaximumValue, Input[d]); - } - - if (SmoothSoftmax && MaximumValue < 0.0f) { - MaximumValue = 0.0f; - } - - double Sum = 0.0; - - for (size_t d = 0; d < D; d++) { - double e = std::exp(double(Input[d]) - double(MaximumValue)); - Sum 
+= e; - Output[d] = float(e); - } - - if (SmoothSoftmax) { - Sum += expf(-MaximumValue); - } - - if (LogSoftmax) { - float Scale = float(std::log(Sum)); - - for (size_t d = 0; d < D; d++) { - Output[d] = Input[d] - MaximumValue - Scale; - } - - } else { - float Scale = float(Sum); - - for (size_t d = 0; d < D; d++) { - Output[d] /= Scale; - } - } - - Input += D; - Output += D; - } - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name(Threaded ? "Softmax_Threaded" : "Softmax_SingleThread"); - return suite_name.c_str(); - } - - MlasSoftmaxTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} - - void ExecuteShort(void) override { - for (size_t d = 1; d < 128; d++) { - Test(1, d, -10.f, 10.f); - } - - Test(3, 128, 20.f, 30.f); - Test(63, 95, -150.f, 190.f); - Test(16, 211, 20.f, 30.f); - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - if (GetMlasThreadPool() != nullptr) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - } - } - return count; -}); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp deleted file mode 100644 index 0710981fa17c6..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ /dev/null @@ -1,441 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_sqnbitgemm.h - -Abstract: - - Tests for MLAS n-bit int block quantized GEMM. - ---*/ - -#include "test_util.h" -#include "mlas_q4.h" -#include "mlas_qnbit.h" - -static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { - switch (ComputeType) { - case CompFp32: - return "Fp32"; - case CompInt8: - return "Int8"; - default: - return "unknown"; - } -} - -/** - * @brief Test class for n-bit int block quantized GEMM - * Note: only 2-D matmul supported for now - */ -template -class MlasSQNBitGemmTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferQuantAData; - MatrixGuardBuffer BufferQuantAScale; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferQuantBData; - MatrixGuardBuffer BufferPackedQuantBData; - MatrixGuardBuffer BufferQuantBZeroPoint; - MatrixGuardBuffer BufferQuantBScale; - MatrixGuardBuffer BufferDequantizedB; - MatrixGuardBuffer BufferBias; - MatrixGuardBuffer BufferWorkspace; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - - void CallGemm(size_t M, - size_t N, - size_t K, - const float* A, - size_t lda, - const void* /*QuantBData*/, - const void* PackedQuantBDataWorkspace, - const float* QuantBScale, - const void* QuantBZeroPoint, - const float* Bias, - float* C, - size_t ldc, - void* Workspace, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - MLAS_THREADPOOL* Threadpool) { - MLAS_SQNBIT_GEMM_DATA_PARAMS params; - params.A = A; - params.lda = lda; - params.Bias = Bias; - params.C = C; - params.ldc = ldc; -#ifdef MLAS_TARGET_AMD64_IX86 - if (ComputeType == CompInt8) { - params.QuantBDataWorkspace = PackedQuantBDataWorkspace; - } -#endif - params.PackedQuantBData = static_cast(PackedQuantBDataWorkspace); - params.QuantBScale = QuantBScale; - params.QuantBZeroPoint = QuantBZeroPoint; - params.PostProcessor = nullptr; - - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); - } - - void 
QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { - const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; - const size_t lda = K; - for (size_t m = 0; m < M; ++m) { - for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { - const size_t local_blk_len = std::min(K - k, BlkLen); - float blk_a[BlkLen]{}; - std::copy_n(A + m * lda + k, local_blk_len, blk_a); - - float amax = 0.0f; // max of absolute values of A block - for (size_t kk = 0; kk < local_blk_len; ++kk) { - float a = blk_a[kk]; - amax = std::max(amax, fabsf(a)); - } - - constexpr float range_max = (1 << 7) - 1; - const float scale = amax / range_max; - const float scale_reciprocal = scale != 0.0f ? 1.0f / scale : 0.0f; - - QuantAScale[m * BlockCountK + k_blk] = scale; - - for (size_t kk = 0; kk < BlkLen; ++kk) { - const float q = roundf(blk_a[kk] * scale_reciprocal); - QuantAData[m * BlockCountK * BlkLen + k + kk] = - static_cast( - std::clamp(q, - static_cast(std::numeric_limits::min()), - static_cast(std::numeric_limits::max()))); - } - } - } - } - - void CallReferenceGemm_CompInt8(size_t M, - size_t N, - size_t K, - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - const float* Bias, - float* C) { - const size_t BlockCountK = (K + BlkLen - 1) / BlkLen; - - int8_t* QuantAData = BufferQuantAData.GetBuffer(M * BlockCountK * BlkLen); - float* QuantAScale = BufferQuantAScale.GetBuffer(M * BlockCountK); - QuantizeA(M, K, A, QuantAData, QuantAScale); - - for (size_t m = 0; m < M; ++m) { - for (size_t n = 0; n < N; ++n) { - float sum = Bias == nullptr ? 0.0f : Bias[n]; - for (size_t k = 0, k_blk = 0; k < K; k += BlkLen, ++k_blk) { - const size_t k_blk_len = std::min(K - k, BlkLen); - - const float a_scale = QuantAScale[m * BlockCountK + k_blk]; - - const float b_scale = QuantBScale[n * BlockCountK + k_blk]; - - static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B"); - - uint8_t b_zp = 8; - if (QuantBZeroPoint != nullptr) { - const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2]; - b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); - } - - int32_t qsum = 0; - - for (size_t kk = 0; kk < k_blk_len; ++kk) { - const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk]; - const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2]; - const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp; - qsum += qa * qb; - } - - sum += static_cast(qsum) * a_scale * b_scale; - } - - C[m * N + n] = sum; - } - } - } - - void CallReferenceGemm_CompFp32(size_t M, - size_t N, - size_t K, - const float* A, - const uint8_t* QuantBData, - const float* QuantBScale, - const uint8_t* QuantBZeroPoint, - const float* Bias, - float* C) { - float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N); - MlasDequantizeBlockwise( - DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true, - static_cast(K), static_cast(N), GetMlasThreadPool()); - // Note: DequantizedBData is in column major layout. - - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const float* a = A + m * K; - const float* b = DequantizedBData + n * K; - float* c = C + (m * N) + n; - - float sum = Bias == nullptr ? 
0.0f : Bias[n]; - for (size_t k = 0; k < K; k++) { - sum += (*a) * (*b); - b += 1; - a += 1; - } - *c = sum; - } - } - } - - public: - void Test(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - bool WithThreadpool, bool Symmetric, bool WithBias) { - MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; - - const float* A = BufferA.GetBuffer(K * M); - - const float* B = BufferB.GetBuffer(N * K); - - const float* Bias = nullptr; - if (WithBias) { - Bias = BufferBias.GetBuffer(N); - } - -#if 0 - auto print_matrix = [](size_t nrows, size_t ncols, const float* data) { - for (size_t row = 0; row < nrows; ++row) { - for (size_t col = 0; col < ncols; ++col) { - std::cout << data[row * ncols + col] << ", "; - } - std::cout << "\n"; - } - }; - - auto print_matrix_col = [](size_t nrows, size_t ncols, size_t col, const float* data) { - for (size_t row = 0; row < nrows; ++row) { - std::cout << data[row * ncols + col] << ", "; - } - std::cout << "\n"; - }; - - std::cout << "A:\n"; - print_matrix(M, K, A); - std::cout << "B:\n"; - print_matrix(K, N, B); -#endif - - float* C = BufferC.GetBuffer(N * M, true); - float* CReference = BufferCReference.GetBuffer(N * M, true); - - // quantize B - uint8_t* QuantBData = nullptr; - float* QuantBScale = nullptr; - uint8_t* QuantBZeroPoint = nullptr; - { - size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true, - static_cast(K), static_cast(N), - QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); - - QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); - QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); - if (!Symmetric) { - QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); - } - - MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, - B, BlkLen, - /* columnwise */ true, - static_cast(K), static_cast(N), - static_cast(N), - GetMlasThreadPool()); - } - - void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); - WorkspaceSize > 0) { - Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); - } - - void* PackedQuantBDataWorkspace = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); - PackedQuantBDataSize > 0) { - PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - bool has_zp_input = QuantBZeroPoint != nullptr; - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, - QuantBScale, has_zp_input, QuantBZeroPoint, - GetMlasThreadPool()); - } - - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); - - if (ComputeType == CompFp32) { - CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == CompInt8) { - CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else { - FAIL() << "Test is not implemented for compute type " - << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; - } - - size_t f = 0; - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_TRUE(CloseEnough(C[f], CReference[f])) - << "Expected: " << CReference[f] << " 
Actual: " << C[f] << "@[" << m << "x" << n << "], " - << "M=" << M << ", N=" << N << ", K=" << K; - } - } - } - - public: - static const char* GetTestSuiteName() { - static std::string suite_name = std::string("SQNBitGemm") + - "BlkBitWidth" + std::to_string(BlkBitWidth) + - "BlkLen" + std::to_string(BlkLen); - return suite_name.c_str(); - } -}; - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class SQNBitGemmShortExecuteTest : public MlasTestFixture> { - public: - explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - bool WithThreadpool, bool Symmetric, bool WithBias) - : M_(M), - N_(N), - K_(K), - ComputeType_(ComputeType), - WithThreadpool_(WithThreadpool), - Symmetric_(Symmetric), - WithBias_(WithBias) { - } - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test( - M_, N_, K_, ComputeType_, WithThreadpool_, Symmetric_, WithBias_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - bool WithThreadpool, bool Symmetric, bool WithBias) { - size_t tests_registered = 0; - - if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - std::stringstream ss; - ss << (WithThreadpool ? "SingleThread" : "Threaded") - << "/isSymmetric" << Symmetric - << "/M" << M << "xN" << N << "xK" << K - << "/hasBias" << WithBias - << "/computeType" << ComputeTypeName(ComputeType); - auto test_name = ss.str(); - - testing::RegisterTest( - MlasSQNBitGemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture>* { - return new SQNBitGemmShortExecuteTest( - M, N, K, ComputeType, WithThreadpool, Symmetric, WithBias); - }); - - tests_registered += 1; - } - - return tests_registered; - } - - static size_t RegisterShortExecuteTests() { - size_t tests_registered = 0; - - for (MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType : {CompFp32, CompInt8}) { - for (bool WithThreadpool : {false, true}) { - for (bool Symmetric : {false, true}) { - for (size_t b = 1; b < 16; b++) { - tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); - } - for (size_t b = 16; b <= 256; b <<= 1) { - tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); - } - for (size_t b = 256; b < 320; b += 32) { - tests_registered += RegisterSingleTest(b, b, b, ComputeType, WithThreadpool, Symmetric, true); - } - for (size_t b = 1; b < 96; b++) { - tests_registered += RegisterSingleTest(1, b, 32, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(1, 32, b, ComputeType, WithThreadpool, Symmetric, true); - tests_registered += RegisterSingleTest(1, b, b, ComputeType, WithThreadpool, Symmetric, false); - } - tests_registered += RegisterSingleTest(43, 500, 401, ComputeType, WithThreadpool, Symmetric, true); - tests_registered += RegisterSingleTest(1, 2, 16, ComputeType, WithThreadpool, Symmetric, true); - tests_registered += RegisterSingleTest(1, 2, 16, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(1, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += 
RegisterSingleTest(11, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(1, 1027, 1031, ComputeType, WithThreadpool, Symmetric, true); - tests_registered += RegisterSingleTest(11, 1027, 1031, ComputeType, WithThreadpool, Symmetric, true); - tests_registered += RegisterSingleTest(1, 527, 2131, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(11, 527, 2131, ComputeType, WithThreadpool, Symmetric, false); - tests_registered += RegisterSingleTest(1, 527, 2131, ComputeType, WithThreadpool, Symmetric, true); - tests_registered += RegisterSingleTest(11, 527, 2131, ComputeType, WithThreadpool, Symmetric, true); - // tests_registered += RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false); - } - } - } - - return tests_registered; - } - - private: - size_t M_, N_, K_; - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType_; - bool WithThreadpool_, Symmetric_, WithBias_; -}; - -static size_t SQNBitGemmRegisterAllShortExecuteTests() { - size_t count = 0; - - count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests(); - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister( - [](bool is_short_execute) -> size_t { - if (is_short_execute) { - return SQNBitGemmRegisterAllShortExecuteTests(); - } - return 0; - }); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp deleted file mode 100644 index 243752bbea24e..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp +++ /dev/null @@ -1,82 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - test_sqnbitgemm_neon_fp16.cpp - -Abstract: - - Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. 
- ---*/ - -#include - -#include "test_util.h" -#include "core/mlas/lib/mlasi.h" - -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - -class MlasNeonFp16CastTest : public MlasTestBase { - private: - void TestFp16ToFp32(size_t count) { - std::vector src(count); - std::vector dest(count); - - for (size_t i = 0; i < count; i++) { - src[i] = static_cast(i); - } - - MlasCastF16ToF32KernelNeon(src.data(), dest.data(), count); - - for (size_t i = 0; i < count; i++) { - if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan - ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); - } - } - - void TestFp32ToFp16(size_t count) { - std::vector src(count); - std::vector dest(count); - - for (size_t i = 0; i < count; i++) { - src[i] = static_cast(i) + 0.125f; - } - - MlasCastF32ToF16KernelNeon(src.data(), dest.data(), count); - - for (size_t i = 0; i < count; i++) { - ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); - } - } - - public: - static const char* GetTestSuiteName() { - return "NeonFp16Cast"; - } - - void ExecuteShort(void) override { - TestFp16ToFp32(1 << 16); - TestFp16ToFp32(1); - TestFp16ToFp32(4); - TestFp16ToFp32(7); - TestFp32ToFp16(1 << 16); - TestFp32ToFp16(3); - TestFp32ToFp16(4); - TestFp32ToFp16(6); - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests::RegisterShortExecute(); - } - return count; -}); - -#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp deleted file mode 100644 index bb3aea02cc011..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_symm_qgemm.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "test_symm_qgemm_fixture.h" - -static size_t SymmQgemmRegistLongExecute() { - if (MlasSymmQgemmPackBSize(16, 16, true) == 0) { - return 0; - } - - size_t count = MlasLongExecuteTests>::RegisterLongExecute(); - - if (GetMlasThreadPool() != nullptr) { - count += MlasLongExecuteTests>::RegisterLongExecute(); - } - - return count; -} - -static size_t SymmQgemmRegistShortExecute() { - if (MlasSymmQgemmPackBSize(16, 16, true) == 0) { - return 0; - } - - size_t count = SymmQgemmShortExecuteTest::RegisterShortExecuteTests(); - - if (GetMlasThreadPool() != nullptr) { - count += SymmQgemmShortExecuteTest::RegisterShortExecuteTests(); - } - - return count; -} - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return is_short_execute ? SymmQgemmRegistShortExecute() : SymmQgemmRegistLongExecute(); -}); \ No newline at end of file diff --git a/onnxruntime/test/mlas/unittest/test_symm_qgemm.h b/onnxruntime/test/mlas/unittest/test_symm_qgemm.h deleted file mode 100644 index da49aadf0ea03..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_symm_qgemm.h +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_util.h" - -template -class MlasSymmQgemmTestBase : public MlasTestBase { - protected: - MLAS_THREADPOOL* threadpool_; - - MlasSymmQgemmTestBase() : threadpool_(Threaded ? 
GetMlasThreadPool() : nullptr) {} - - void TestGemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - int32_t offa, - bool AIsSigned, - const int8_t* B, - size_t ldb, - int32_t* C, - size_t ldc) { - MLAS_GEMM_QUANT_SHAPE_PARAMS GemmShape; - GemmShape.M = M; - GemmShape.N = N; - GemmShape.K = K; - GemmShape.AIsSigned = AIsSigned; - GemmShape.BIsSigned = true; - - size_t PackedBSize = MlasSymmQgemmPackBSize(N, K, AIsSigned); - int8_t* PackedB = (int8_t*)BufferBPacked.GetBuffer(PackedBSize * BatchSize); - - std::vector GemmParameters(BatchSize); - - for (size_t i = 0; i < GemmParameters.size(); i++) { - auto& params = GemmParameters[i]; - params.A = A + (M * K * i); - params.lda = lda; - params.C = C + (M * N * i); - params.ldc = ldc; - - MlasSymmQgemmPackB(N, K, B + (K * N * i), ldb, AIsSigned, offa, PackedB + PackedBSize * i); - params.B = PackedB + PackedBSize * i; - } - - MlasSymmQgemmBatch(GemmShape, GemmParameters.data(), BatchSize, threadpool_); - } - - private: - MatrixGuardBuffer BufferBPacked; -}; - -template -class MlasSymmQgemmTest; - -template -class MlasSymmQgemmTest : public MlasSymmQgemmTestBase { - public: - void Test(size_t M, size_t N, size_t K, size_t BatchSize, int32_t offa) { - // Symmetric kernel will have limited buffer overrun when reading the input buffer - constexpr size_t OVERRUN = 15; - const uint8_t* A = BufferA.GetBuffer(K * M * BatchSize + OVERRUN); - const int8_t* B = BufferB.GetBuffer(N * K * BatchSize); - int32_t* C = BufferC.GetBuffer(N * M * BatchSize); - int32_t* CReference = BufferCReference.GetBuffer(N * M * BatchSize); - - Test(M, N, K, BatchSize, A, K, offa, B, N, C, CReference, N); - } - - void Test(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const uint8_t* A, - size_t lda, - int32_t offa, - const int8_t* B, - size_t ldb, - int32_t* C, - int32_t* CReference, - size_t ldc) { - std::fill_n(C, M * N * BatchSize, -1); - std::fill_n(CReference, M * N * BatchSize, -1); - - this->TestGemm(M, N, K, BatchSize, A, lda, offa, std::is_signed::value, B, ldb, C, ldc); - ReferenceQgemm(M, N, K, BatchSize, (const AType*)A, lda, (AType)offa, B, ldb, (const int8_t)0, CReference, ldc); - - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { - ASSERT_EQ(C[f], CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], " - << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K - << ", offa=" << offa << ", offb=--"; - } - } - } - } - - private: - void ReferenceQgemm(size_t M, - size_t N, - size_t K, - size_t BatchSize, - const AType* A, - size_t lda, - AType offa, - const int8_t* B, - size_t ldb, - int8_t offb, - int32_t* C, - size_t ldc) { - for (size_t batch = 0; batch < BatchSize; batch++) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - const AType* a = A + (M * K * batch) + (m * lda); - const int8_t* b = B + (K * N * batch) + n; - int32_t* c = C + (M * N * batch) + (m * ldc) + n; - int32_t sum = 0; - - for (size_t k = 0; k < K; k++) { - sum += ((int32_t(*b) - offb) * (int32_t(*a) - offa)); - b += ldb; - a += 1; - } - - *c = sum; - } - } - } - } - - MatrixGuardBuffer BufferA; - MatrixGuardBuffer BufferB; - MatrixGuardBuffer BufferC; - MatrixGuardBuffer BufferCReference; - - public: - static const char* GetTestSuiteName() { - static std::string suite_name = std::string("SymmQgemm") + - (std::is_signed::value ? "S8" : "U8") + - "_Int32" + - (Threaded ? 
"_Threaded" : "_SingleThread"); - return suite_name.c_str(); - } - - void ExecuteLong(void) override { - static const int32_t zero_points[] = {-18, 124}; - - for (size_t a = 0; a < _countof(zero_points); a++) { - int32_t offa = zero_points[a]; - - for (size_t M = 16; M < 160; M += 32) { - for (size_t N = 16; N < 160; N += 32) { - static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; - for (size_t k = 0; k < _countof(ks); k++) { - size_t K = ks[k]; - - Test(M, N, K, 1, offa); - Test(M + 1, N, K, 1, offa); - Test(M, N + 1, K, 1, offa); - Test(M + 1, N + 1, K, 1, offa); - Test(M + 3, N + 2, K, 1, offa); - Test(M + 4, N, K, 1, offa); - Test(M, N + 4, K, 1, offa); - Test(M + 4, N + 4, K, 1, offa); - Test(M + 3, N + 7, K, 1, offa); - Test(M + 8, N, K, 1, offa); - Test(M, N + 8, K, 1, offa); - Test(M + 12, N + 12, K, 1, offa); - Test(M + 13, N, K, 1, offa); - Test(M, N + 15, K, 1, offa); - Test(M + 15, N + 15, K, 1, offa); - Test(M, N, K, 7 + a, offa); - Test(M + 3, N, K, 7 + a, offa); - Test(M, N + 1, K, 7 + a, offa); - Test(M + 12, N, K, 7 + a, offa); - Test(M, N + 15, K, 7 + a, offa); - Test(M + 15, N + 15, K, 7 + a, offa); - } - } - printf("a %zd/%zd b %zd M %zd\n", a, _countof(zero_points), _countof(zero_points), M); - } - } - - for (size_t M = 1; M < 160; M++) { - for (size_t N = 1; N < 160; N++) { - for (size_t K = 1; K < 160; K++) { - Test(M, N, K, 1, 18); - } - } - printf("M %zd\n", M); - } - - for (size_t M = 160; M < 320; M += 24) { - for (size_t N = 112; N < 320; N += 24) { - for (size_t K = 1; K < 16; K++) { - Test(M, N, K, 1, 1); - } - for (size_t K = 16; K < 160; K += 32) { - Test(M, N, K, 1, -5); - } - } - printf("M %zd\n", M); - } - } -}; diff --git a/onnxruntime/test/mlas/unittest/test_symm_qgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_symm_qgemm_fixture.h deleted file mode 100644 index 71c022211d5d4..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_symm_qgemm_fixture.h +++ /dev/null @@ -1,81 +0,0 @@ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "test_symm_qgemm.h" - -// -// Short Execute() test helper to register each test separately by all parameters. -// -template -class SymmQgemmShortExecuteTest; - -template -class SymmQgemmShortExecuteTest : public MlasTestFixture> { - public: - explicit SymmQgemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, int32_t offa) - : M_(M), N_(N), K_(K), Batch_(Batch), offa_(offa) { - } - - void TestBody() override { - MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, offa_); - } - - static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, int32_t offa) { - std::stringstream ss; - ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" - << "offa" << offa; - auto test_name = ss.str(); - - testing::RegisterTest( - MlasSymmQgemmTest::GetTestSuiteName(), - test_name.c_str(), - nullptr, - test_name.c_str(), - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. 
- [=]() -> MlasTestFixture>* { - return new SymmQgemmShortExecuteTest( - M, N, K, Batch, offa); - }); - - return 1; - } - - static size_t RegisterShortExecuteTests() { - size_t test_registered = 0; - - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 1, 21); - test_registered += RegisterSingleTest(b, b, b, 2 + b / 4, -21); - } - for (size_t b = 1; b < 16; b++) { - test_registered += RegisterSingleTest(b, b, b, 1, 17); - } - for (size_t b = 16; b <= 256; b <<= 1) { - test_registered += RegisterSingleTest(b, b, b, 1, -1); - } - for (size_t b = 256; b < 320; b += 32) { - test_registered += RegisterSingleTest(b, b, b, 1, 85); - } - for (size_t b = 1; b < 96; b++) { - test_registered += RegisterSingleTest(1, b, 32, 1, 0); - test_registered += RegisterSingleTest(1, 32, b, 1, 0); - test_registered += RegisterSingleTest(1, b, b, 1, 0); - test_registered += RegisterSingleTest(1, b, 32, 3, 0); - test_registered += RegisterSingleTest(1, 32, b, 5, 0); - } - test_registered += RegisterSingleTest(43, 500, 401, 7, 113); - test_registered += RegisterSingleTest(2003, 212, 1020, 3, -5); - test_registered += RegisterSingleTest(202, 2003, 1023, 3, 15); - - return test_registered; - } - - private: - size_t M_, N_, K_, Batch_; - int32_t offa_; -}; diff --git a/onnxruntime/test/mlas/unittest/test_transpose.cpp b/onnxruntime/test/mlas/unittest/test_transpose.cpp deleted file mode 100644 index 8fa98411a21ab..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_transpose.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -template -class MlasTransposeTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - void - Test(size_t M, size_t N) { - ElementType* Input = BufferInput.GetBuffer(M * N); - ElementType* Output = BufferOutput.GetBuffer(M * N); - ElementType* OutputReference = BufferOutputReference.GetBuffer(M * N); - - MlasTranspose(Input, Output, M, N); - ReferenceTranspose(Input, OutputReference, M, N); - - ASSERT_EQ(memcmp(Output, OutputReference, M * N * sizeof(ElementType)), 0) << " [" << M << "," << N << "]"; - } - - void ReferenceTranspose(const ElementType* Input, ElementType* Output, size_t M, size_t N) { - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - Output[n * M + m] = Input[m * N + n]; - } - } - } - - public: - static const char* GetTestSuiteName() { - static const std::string suite_name = std::string("Transpose_Size") + std::to_string(int(sizeof(ElementType))); - return suite_name.c_str(); - } - - void ExecuteShort(void) override { - for (size_t m = 1; m <= 32; m++) { - for (size_t n = 1; n <= 32; n++) { - Test(m, n); - } - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - } - return count; -}); diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h deleted file mode 100644 index 8eecda900ff27..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "mlas.h" -#include "gtest/gtest.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#if defined(_WIN32) -#include -#else -#include -#endif -#if !defined(BUILD_MLAS_NO_ONNXRUNTIME) -#include "core/platform/threadpool.h" -#endif - -#if !defined(UNUSED_VARIABLE) -#if defined(__GNUC__) -#define UNUSED_VARIABLE __attribute__((unused)) -#else -#define UNUSED_VARIABLE -#endif -#endif - -#if !defined(_countof) -#define _countof(_Array) (sizeof(_Array) / sizeof(_Array[0])) -#endif - -MLAS_THREADPOOL* GetMlasThreadPool(void); - -template -class MatrixGuardBuffer { - public: - MatrixGuardBuffer() { - _BaseBuffer = nullptr; - _BaseBufferSize = 0; - _ElementsAllocated = 0; - } - - ~MatrixGuardBuffer(void) { - ReleaseBuffer(); - } - - T* GetFilledBuffer(size_t Elements, std::function const& fillFunc) { - // - // Check if the internal buffer needs to be reallocated. - // - - if (Elements > _ElementsAllocated) { - ReleaseBuffer(); - - // - // Reserve a virtual address range for the allocation plus an unmapped - // guard region. - // - - constexpr size_t BufferAlignment = 64 * 1024; - constexpr size_t GuardPadding = 256 * 1024; - - size_t BytesToAllocate = ((Elements * sizeof(T)) + BufferAlignment - 1) & ~(BufferAlignment - 1); - - _BaseBufferSize = BytesToAllocate + GuardPadding; - -#if defined(_WIN32) - _BaseBuffer = VirtualAlloc(NULL, _BaseBufferSize, MEM_RESERVE, PAGE_NOACCESS); -#else - _BaseBuffer = mmap(0, _BaseBufferSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); -#endif - - if (_BaseBuffer == nullptr) { - abort(); - } - - // - // Commit the number of bytes for the allocation leaving the upper - // guard region as unmapped. - // - -#if defined(_WIN32) - if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) { - ORT_THROW_EX(std::bad_alloc); - } -#else - if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) { - abort(); - } -#endif - - _ElementsAllocated = BytesToAllocate / sizeof(T); - _GuardAddress = (T*)((unsigned char*)_BaseBuffer + BytesToAllocate); - } - - // - // - // - - T* GuardAddress = _GuardAddress; - T* buffer = GuardAddress - Elements; - fillFunc(buffer, Elements); - - return buffer; - } - - T* GetBuffer(size_t Elements, bool ZeroFill = false) { - if (ZeroFill) { - return GetFilledBuffer( - Elements, - [](T* start, size_t size) { - std::fill_n(start, size, T(0)); - }); - } - - return GetFilledBuffer( - Elements, - [](T* start, size_t size) { - constexpr int offset = -21; - constexpr int range = 43; - - int FillValue = 11; - T* FillAddress = start; - for (size_t i = 0; i < size; i++) { - auto itemv = FillValue - offset; - *FillAddress++ = (T)(itemv); - - FillValue += 7; - FillValue %= range; - } - }); - } - - void ReleaseBuffer(void) { - if (_BaseBuffer != nullptr) { -#if defined(_WIN32) - VirtualFree(_BaseBuffer, 0, MEM_RELEASE); -#else - munmap(_BaseBuffer, _BaseBufferSize); -#endif - - _BaseBuffer = nullptr; - _BaseBufferSize = 0; - } - - _ElementsAllocated = 0; - } - - private: - size_t _ElementsAllocated; - void* _BaseBuffer; - size_t _BaseBufferSize; - T* _GuardAddress; -}; - -class MlasTestBase { - public: - virtual ~MlasTestBase(void) {} - - virtual void ExecuteShort(void) {} - - virtual void ExecuteLong(void) {} -}; - -typedef std::function TestRegister; - -bool AddTestRegister(TestRegister test_register); - -// -// Base Test Fixture which setup/teardown MlasTest in one test suite. 
-// -template -class MlasTestFixture : public testing::Test { - public: - static void SetUpTestSuite() { - mlas_tester = new TMlasTester(); - }; - - static void TearDownTestSuite() { - if (nullptr != mlas_tester) { - delete mlas_tester; - } - mlas_tester = nullptr; - }; - - static inline TMlasTester* mlas_tester = nullptr; -}; - -// Long Execute test. It is too heavy to register each single test, treat long execute big groups. -template -class MlasLongExecuteTests : public MlasTestFixture { - public: - void TestBody() override { - MlasTestFixture::mlas_tester->ExecuteLong(); - } - - static size_t RegisterLongExecute() { - testing::RegisterTest( - TMlasTester::GetTestSuiteName(), - "LongExecute", - nullptr, - "LongExecute", - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture* { - return new MlasLongExecuteTests(); - }); - return 1; - } -}; - -// Some short Execute may not need to distinguish each parameters, -// because they finish quickly, and may disturb others by inject too many small tests. -// Register it as whole using following helper. -template -class MlasDirectShortExecuteTests : public MlasTestFixture { - public: - void TestBody() override { - MlasTestFixture::mlas_tester->ExecuteShort(); - } - - static size_t RegisterShortExecute() { - testing::RegisterTest( - TMlasTester::GetTestSuiteName(), - "ShortExecute", - nullptr, - nullptr, - __FILE__, - __LINE__, - // Important to use the fixture type as the return type here. - [=]() -> MlasTestFixture* { - return new MlasDirectShortExecuteTests(); - }); - return 1; - } -}; - -inline void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D) { - const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); - int64_t batch_count = input_shape[0]; - int64_t channel_count = input_shape[1]; - int64_t nchwc_channel_count = (channel_count + nchwc_block_size - 1) & ~(nchwc_block_size - 1); - int64_t spatial_count = input_shape[2] * input_shape[3]; - for (int64_t n = 0; n < batch_count; n++) { - MlasReorderInputNchw(S, D, static_cast(channel_count), static_cast(spatial_count)); - S += spatial_count * channel_count; - D += spatial_count * nchwc_channel_count; - } -} - -inline bool CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} diff --git a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc deleted file mode 100644 index d866045ba4962..0000000000000 --- a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc +++ /dev/null @@ -1,121 +0,0 @@ -#include "common.h" - -#include -#include "core/mlas/lib/mlasi.h" -#include "core/util/math_cpuonly.h" -#include "core/util/qmath.h" - -// vanilla implementation of FindMinMax -static void BM_FindMinMaxPlainLoop(benchmark::State& state) { - const size_t batch_size = static_cast(state.range(0)); - float* data = GenerateArrayWithRandomValue(batch_size, -1, 1); - - float min = std::numeric_limits::max(); - float max = std::numeric_limits::lowest(); - for (auto _ : state) { - for (size_t i = 0; i != batch_size; ++i) { - if (min > data[i]) { - min = data[i]; - } - if (max < data[i]) { - max = data[i]; - } - } - } - - // To prevent to optimize out min and max - data[0] = min * max; - aligned_free(data); -} - 
-BENCHMARK(BM_FindMinMaxPlainLoop) - ->UseRealTime() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kNanosecond) - ->Arg(100) - ->Arg(1000) - ->Arg(10000) - ->Arg(20000) - ->Arg(40000) - ->Arg(80000) - ->Arg(98304) - ->Arg(160000); - -// Eigen implementation of FindMinMax -static void BM_FindMinMaxEigen(benchmark::State& state) { - const size_t batch_size = static_cast(state.range(0)); - float* data = GenerateArrayWithRandomValue(batch_size, -1, 1); - - for (auto _ : state) { - onnxruntime::ConstEigenVectorMap(data, batch_size).minCoeff(); - onnxruntime::ConstEigenVectorMap(data, batch_size).maxCoeff(); - } - aligned_free(data); -} - -BENCHMARK(BM_FindMinMaxEigen) - ->UseRealTime() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kNanosecond) - ->Arg(100) - ->Arg(1000) - ->Arg(10000) - ->Arg(20000) - ->Arg(40000) - ->Arg(80000) - ->Arg(98304) - ->Arg(160000); - -// MLAS sse2 implementation -static void BM_FindMinMaxMlasSSE2(benchmark::State& state) { - const size_t batch_size = static_cast(state.range(0)); - float* data = GenerateArrayWithRandomValue(batch_size, -1, 1); - float min = std::numeric_limits::max(); - float max = std::numeric_limits::lowest(); - for (auto _ : state) { - MlasReduceMinimumMaximumF32Kernel(data, &min, &max, batch_size); - } - aligned_free(data); -} - -BENCHMARK(BM_FindMinMaxMlasSSE2) - ->UseRealTime() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kNanosecond) - ->Arg(100) - ->Arg(1000) - ->Arg(10000) - ->Arg(20000) - ->Arg(40000) - ->Arg(80000) - ->Arg(98304) - ->Arg(160000); - -#ifdef MLAS_TARGET_AMD64 - -// MLAS avx implementation -static void BM_FindMinMaxMlasAvx(benchmark::State& state) { - const size_t batch_size = static_cast(state.range(0)); - float* data = GenerateArrayWithRandomValue(batch_size, -1, 1); - float min = std::numeric_limits::max(); - float max = std::numeric_limits::lowest(); - for (auto _ : state) { - MlasReduceMinimumMaximumF32KernelAvx(data, &min, &max, batch_size); - } - aligned_free(data); -} - -BENCHMARK(BM_FindMinMaxMlasAvx) - ->UseRealTime() - ->UseRealTime() - ->Unit(benchmark::TimeUnit::kNanosecond) - ->Arg(100) - ->Arg(1000) - ->Arg(10000) - ->Arg(20000) - ->Arg(40000) - ->Arg(80000) - ->Arg(98304) - ->Arg(160000); - -#endif // MLAS_TARGET_AMD64 diff --git a/onnxruntime/test/onnx/microbenchmark/resize.cc b/onnxruntime/test/onnx/microbenchmark/resize.cc index 020680c12b8f5..65cca6d52056a 100644 --- a/onnxruntime/test/onnx/microbenchmark/resize.cc +++ b/onnxruntime/test/onnx/microbenchmark/resize.cc @@ -5,7 +5,6 @@ #include #include "core/common/safeint.h" #include "core/framework/allocator.h" -#include "core/mlas/lib/mlasi.h" #include "core/providers/cpu/tensor/upsample.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 3aec0d5a67e94..916a40bf3ca18 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -19,7 +19,7 @@ #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/bias_dropout_fusion.h" #include "core/optimizer/bias_gelu_fusion.h" diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 538f60040418c..b70e14a17d725 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc 
+++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -3,7 +3,7 @@ #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/session/environment.h" #include "core/session/inference_session.h" #include "core/framework/tensorprotoutils.h" diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index a247fea7e5f53..ba902458aaffa 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -6,7 +6,7 @@ #include "gtest/gtest.h" #include "graph_transform_test_builder.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/graph/graph.h" namespace onnxruntime { diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc index ccfa1f1159937..e932a58f96aff 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -5,7 +5,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d07977d4b97b8..1aacf6070a989 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -9,7 +9,7 @@ #include "core/framework/int4.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" diff --git a/onnxruntime/test/optimizer/resnet50_fusion_test.cc b/onnxruntime/test/optimizer/resnet50_fusion_test.cc index 5cb0206156a84..47d8dac1dca79 100644 --- a/onnxruntime/test/optimizer/resnet50_fusion_test.cc +++ b/onnxruntime/test/optimizer/resnet50_fusion_test.cc @@ -5,7 +5,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/conv_add_act_fusion.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "gtest/gtest.h" #include "graph_transform_test_builder.h" #include "test/test_environment.h" diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 8ca0f6d845a09..716bf11c4920b 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -7,7 +7,7 @@ #include #include #include -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/graph/constants.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index d0069a0069646..07a40ac77b3c5 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#include "gtest/gtest.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/framework/run_options.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc index aa57451e7892a..fa115931d71b6 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc @@ -8,7 +8,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" diff --git a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc index 096263792727a..ad26f72409f37 100644 --- a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/quantization/quantize_linear_matmul.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 4253e36e02548..8b910dc7fe405 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index d4e0af5011525..824f26aaf1882 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/mlas/inc/mlas.h" +#include "mlas.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK) diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc index 2bc0df5e3635f..4d678f82bf994 100644 --- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc @@ -4,7 +4,7 @@ #include #include -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/util/math.h" #include "default_providers.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index 4954b82690e0f..3011263f5cc18 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "core/providers/cpu/tensor/space_depth_ops.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 3fcb9045ee7e6..b74d0ad6cb301 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -18,7 +18,7 @@ #include #include "core/framework/float16.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" namespace onnxruntime { namespace test { diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 29a309920c74b..671034e0d0ca5 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -4,7 +4,7 @@ #include #include "orttraining/core/optimizer/graph_transformer_utils.h" -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/optimizer/bias_dropout_fusion.h" #include "core/optimizer/bias_gelu_fusion.h" #include "core/optimizer/bias_softmax_fusion.h" diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index f4b9c08bd90cd..e43563bd671fd 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -4,7 +4,7 @@ #include "orttraining/training_ops/cpu/op_gradients.h" #include -#include "core/mlas/inc/mlas.h" +#include "mlas.h" #include "core/providers/common.h" #include "core/providers/cpu/math/element_wise_ops.h" #include "core/providers/cpu/math/matmul_helper.h" diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 9624f9112c49f..6373283f78633 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2066,7 +2066,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): dll_path = os.pathsep.join(dll_path_list) if not ctest_path and not is_windows(): - executables = ["onnxruntime_test_all", "onnxruntime_mlas_test"] + executables = ["onnxruntime_test_all"] if args.build_shared_lib: executables.append("onnxruntime_shared_lib_test") executables.append("onnxruntime_global_thread_pools_test") diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml 
diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml
index 7f131590c900b..61ea555d8b48d 100644
--- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml
+++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml
@@ -335,13 +335,12 @@ stages:
         LLVM_PROFILE_FILE="%p.profraw" CFLAGS="-g -fprofile-instr-generate -fcoverage-mapping" CXXFLAGS="-g -fprofile-instr-generate -fcoverage-mapping" CC=clang CXX=clang++ python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --build_dir=$(Build.BinariesDirectory) --config Debug --parallel --skip_submodule_sync --build_shared_lib --enable_onnx_tests --cmake_extra_defines RUN_MODELTEST_IN_DEBUG_MODE=ON
         cd Debug
-        ./onnxruntime_mlas_test
         #Merge the multiple prof data into a single indexed profile data file
         llvm-profdata merge -sparse -o ort.profdata *.profraw
         #Create coverage report, output the result to 'report.json'
-        llvm-cov export -summary-only -instr-profile=ort.profdata onnxruntime_test_all -object onnxruntime_mlas_test -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops > $(Build.BinariesDirectory)/report.json
+        llvm-cov export -summary-only -instr-profile=ort.profdata onnxruntime_test_all -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops > $(Build.BinariesDirectory)/report.json
-        llvm-cov show -instr-profile=ort.profdata onnxruntime_test_all -object onnxruntime_mlas_test -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops --format=html -output-dir=$(Build.ArtifactStagingDirectory)
+        llvm-cov show -instr-profile=ort.profdata onnxruntime_test_all -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops --format=html -output-dir=$(Build.ArtifactStagingDirectory)
       workingDirectory: $(Build.BinariesDirectory)
 - ${{ if or(startsWith(variables['System.CollectionUri'], 'https://dev.azure.com/aiinfra/'),startsWith(variables['System.CollectionUri'], 'https://aiinfra.visualstudio.com/')) }}:
diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
index b3a47039005a9..5e6c658903d67 100644
--- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
@@ -11,7 +11,7 @@ steps:
     packageType: upack
     feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
     definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
-    version: 1.0.193
+    version: 1.0.198
     downloadPath: $(Build.BinariesDirectory)/deps

 # The private ADO project
@@ -22,7 +22,7 @@ steps:
     packageType: upack
     feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
    definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
-    version: 1.0.193
+    version: 1.0.198
     downloadPath: $(Build.BinariesDirectory)/deps

 # You can add more ADO accounts at here.
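Taken together, the churn above is mechanical: MLAS public headers are now found by bare name on the include path supplied by the external dependency, and the standalone onnxruntime_mlas_test binary drops out of the test and coverage runs. A minimal consumer translation unit under the new layout might look like the sketch below; MlasNchwcGetBlockSize is used only as an example of a public query from mlas.h, and nothing here is part of the change itself.

    // Hypothetical standalone consumer of the externally fetched MLAS package.
    #include <cstddef>
    #include <cstdio>

    #include "mlas.h"  // resolved via the external MLAS include directory, not core/mlas/inc

    int main() {
      // A block size greater than 1 indicates the NCHWc-blocked kernels are
      // available on this CPU; otherwise NCHWc reordering is not used.
      const size_t block_size = MlasNchwcGetBlockSize();
      std::printf("NCHWc block size: %zu\n", block_size);
      return 0;
    }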