diff --git a/CMakeLists.txt b/CMakeLists.txt index 0bba5d6aa88ab..d29d5410362c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,6 @@ project(triton CXX C) include(CTest) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -set(DNNLROOT "/home/jovyan/intel/oneapi/dnnl/2024.1") # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index a1c37efb52f1c..5e66d62f0bcff 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -304,7 +304,7 @@ struct SharedMemoryObject { } Value getCSwizzleOffset(int order) const { - assert(order >= 0 && order < strides.size()); + assert(order >= 0 && order < static_cast<int>(strides.size())); return offsets[order]; } @@ -512,7 +512,6 @@ inline SmallVector<Value> emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, const BlockedEncodingAttr &blockedLayout, RankedTensorType type) { - MLIRContext *ctx = rewriter.getContext(); auto shape = type.getShape(); Value threadId = getThreadId(rewriter, loc); Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout)); @@ -557,7 +556,6 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, inline SmallVector<SmallVector<unsigned>> emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, RankedTensorType type) { - auto ctx = type.getContext(); auto shape = type.getShape(); auto sizePerThread = blockedLayout.getSizePerThread(); auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); @@ -1205,9 +1203,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs( // Order auto inOrder = triton::gpu::getOrder(srcEncoding); auto outOrder = triton::gpu::getOrder(resSharedLayout); - assert(maxPhase == 1 || - outVec * maxPhase <= srcShape[outOrder[0]] && - "Swizzling would generate out of bounds memory accesses"); + assert((maxPhase == 1 || outVec * maxPhase <= srcShape[outOrder[0]]) && + "Swizzling would generate out of bounds memory accesses"); // Tensor indices held by the current thread, as LLVM values auto srcIndices = emitIndices(loc, rewriter, target, srcEncoding, srcTy, /*withCTAOffset=*/false); @@ -1424,7 +1421,7 @@ inline Value packLLVector(Location loc, ValueRange vals, assert(vals.size() > 0); auto vecType = vec_ty(vals[0].getType(), vals.size()); Value vec = undef(vecType); - for (int i = 0; i < vals.size(); i++) { + for (int i = 0; i < static_cast<int>(vals.size()); i++) { vec = insert_element(vec, vals[i], i32_val(i)); } return vec; diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index f8e2a62ee129b..ae0c4c6f351de 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -205,18 +205,13 @@ def TTC_BrgemmCreate : TTC_Op<"brgemm_create", [NoMemoryEffect]> { AnyTypeOf<[AnyInteger, Index]>:$lda, AnyTypeOf<[AnyInteger, Index]>:$ldb, AnyTypeOf<[AnyInteger, Index]>:$ldc, - I64:$dtypeA, - I64:$dtypeB, - I64:$dtypeC + // TODO: Maybe use properties. + TypeAttr:$dtypeA, + TypeAttr:$dtypeB, + TypeAttr:$dtypeC ); let results = (outs Index:$result); - - // let assemblyFormat = [{ - // $prefix attr-dict (`:` $val^ `:` type($val))?
- // }]; - - // let hasVerifier = 1; } def TTC_BrgemmCall : TTC_Op<"brgemm_call", @@ -231,7 +226,6 @@ def TTC_BrgemmCall : TTC_Op<"brgemm_call", Index:$kernel_hash, Arg:$A_ptr, Arg:$B_ptr, - // AnyTypeOf<[TT_Float, TT_Int, TT_Ptr, TTC_Vector]>:$offsets, Arg:$C_ptr, Arg:$scratchpad, @@ -239,12 +233,6 @@ def TTC_BrgemmCall : TTC_Op<"brgemm_call", Index:$stepB, Index:$numBatches ); - - // let assemblyFormat = [{ - // $prefix attr-dict (`:` $val^ `:` type($val))? - // }]; - - // let hasVerifier = 1; } def TTC_BrgemmNeedsPacking : TTC_Op<"brgemm_needs_packing", @@ -268,9 +256,8 @@ def TTC_CallBrgemmWithTransform : TTC_Op<"pack_and_brgemm", Index:$brgemm_kernel_hash, Arg:$A_ptr, Arg:$B_ptr, - // AnyTypeOf<[TT_Float, TT_Int, TT_Ptr, TTC_Vector]>:$offsets, Arg:$C_ptr, - Arg:$scratchpad, + // Arg:$scratchpad, AnyTypeOf<[AnyInteger, Index]>:$stepA, AnyTypeOf<[AnyInteger, Index]>:$stepB, @@ -288,20 +275,14 @@ def TTC_TransformCreate : TTC_Op<"transform_create", [NoMemoryEffect]> { let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$K, AnyTypeOf<[AnyInteger, Index]>:$N, - // I64:$in_pack_type, AnyTypeOf<[AnyInteger, Index]>:$in_ld, AnyTypeOf<[AnyInteger, Index]>:$out_ld, - I64:$in_dt, - I64:$out_dt + // TODO: Maybe Use properties + TypeAttr:$in_dt, + TypeAttr:$out_dt ); let results = (outs Index:$result); - - // let assemblyFormat = [{ - // $prefix attr-dict (`:` $val^ `:` type($val))? - // }]; - - // let hasVerifier = 1; } def TTC_TransformCall : TTC_Op<"transform_call", @@ -326,12 +307,6 @@ def TTC_ConfigureHW : TTC_Op<"configure_hw", [MemoryEffects<[MemWrite]>]> { @@ -341,12 +316,6 @@ def TTC_ReleaseHW : TTC_Op<"release_hw", [MemoryEffects<[MemWrite] // M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC let arguments = (ins); - - // let assemblyFormat = [{ - // $prefix attr-dict (`:` $val^ `:` type($val))? - // }]; - - // let hasVerifier = 1; } #endif diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index f8137ede1c4cd..58a52fb8f82f2 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -56,8 +56,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + libraries += ["gcc"] - cc_cmd += ["-g"] # Use dynamic lookup to load Python library on Mac if system == "Darwin": cc_cmd += ["-undefined", "dynamic_lookup"] diff --git a/test/TritonCPU/dot-to-onednn.mlir b/test/TritonCPU/dot-to-onednn.mlir deleted file mode 100644 index d6c8c53cc6699..0000000000000 --- a/test/TritonCPU/dot-to-onednn.mlir +++ /dev/null @@ -1,106 +0,0 @@ -// RUN: triton-opt %s -split-input-file -triton-cpu-convert-dot-to-onednn -canonicalize | FileCheck %s - -// Replacement of a contraction operation with a single tile_mulf operation. 
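The dtype arguments of brgemm_create and transform_create above now carry MLIR types as TypeAttr instead of raw I64 oneDNN enum codes, which keeps producers inside MLIR's type system and defers the enum mapping to the LLVM lowering (getDnnlDataTypeVal later in this diff). A minimal sketch of building such an op, assuming the default TableGen-generated builder; the value names (m, n, kK, batchSize, lda, ldb, ldc) are hypothetical and none of this code is part of the PR:

    // Hypothetical builder call: bf16 operands accumulating into f32.
    auto dtBf16 = TypeAttr::get(rewriter.getBF16Type());
    auto dtF32 = TypeAttr::get(rewriter.getF32Type());
    Value kernel = rewriter.create<triton::cpu::BrgemmCreate>(
        loc, rewriter.getIndexType(), m, n, kK, batchSize, lda, ldb, ldc,
        /*dtypeA=*/dtBf16, /*dtypeB=*/dtBf16, /*dtypeC=*/dtF32);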
- -// CHECK-LABEL: @test_single_mulf -// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x32xbf16> -// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> -// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index -// CHECK: %[[ACC:.+]] = amx.tile_zero : vector<16x16xf32> -// CHECK-NEXT: %[[LHS:.+]] = amx.tile_load %3[%4#0, %4#1] -// CHECK-NEXT: %[[RHS:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] -// CHECK-NEXT: %[[RES:.+]] = amx.tile_mulf %[[LHS]], %[[RHS]], %[[ACC]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES]] : memref<16x16xf32, strided<[16, 1]>>, vector<16x16xf32> - -// #loc = loc(unknown) -// #map = affine_map<(d0, d1, d2) -> (d0, d2)> -// #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -// #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -// module { -// tt.func public @test_single_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { -// %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) -// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf32> loc(#loc) -// %c16_i64 = arith.constant 16 : i64 loc(#loc) -// %c32_i64 = arith.constant 32 : i64 loc(#loc) -// %c1_i64 = arith.constant 1 : i64 loc(#loc) -// %c0_i32 = arith.constant 0 : i32 loc(#loc) -// %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) -// %1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) -// %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) -// %3 = triton_cpu.extract_memref %0 : > -> memref<16x32xbf16, strided<[32, 1]>> loc(#loc) -// %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) -// %5 = vector.transfer_read %3[%4#0, %4#1], %cst {in_bounds = [true, true]} : memref<16x32xbf16, strided<[32, 1]>>, vector<16x32xbf16> loc(#loc) -// %6 = triton_cpu.extract_memref %1 : > -> memref<32x16xbf16, strided<[16, 1]>> loc(#loc) -// %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) -// %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<32x16xbf16, strided<[16, 1]>>, vector<32x16xbf16> loc(#loc) -// %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst_0 : vector<16x32xbf16>, vector<32x16xbf16> into vector<16x16xf32> loc(#loc) -// %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> loc(#loc) -// %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) -// vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[16, 1]>> loc(#loc) -// tt.return loc(#loc) -// } loc(#loc) -// } loc(#loc) - -// -----// IR Dump Before OneDNNOpsToLLVM (triton-cpu-onednn-ops-to-llvm) ('builtin.module' operation) //----- // -#loc = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0) -module { - tt.func public @matmul_blocked_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg1: !tt.ptr 
{tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5906:0)) attributes {noinline = false} { - %cst = arith.constant 0.000000e+00 : f32 loc(#loc1) - %cst_0 = arith.constant dense<0.000000e+00> : vector<32x16xf32> loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %c1_i64 = arith.constant 1 : i64 loc(#loc1) - %c16_i32 = arith.constant 16 : i32 loc(#loc1) - %c32_i32 = arith.constant 32 : i32 loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = tt.get_program_id y : i32 loc(#loc3) - %2 = arith.muli %0, %c32_i32 : i32 loc(#loc4) - %3 = arith.muli %1, %c16_i32 : i32 loc(#loc5) - %4 = arith.extsi %arg3 : i32 to i64 loc(#loc6) - %5 = arith.extsi %arg5 : i32 to i64 loc(#loc6) - %6 = arith.extsi %arg6 : i32 to i64 loc(#loc6) - %7 = tt.make_tensor_ptr %arg0, [%4, %5], [%6, %c1_i64], [%2, %c0_i32] {order = array} : > loc(#loc6) - %8 = arith.extsi %arg4 : i32 to i64 loc(#loc7) - %9 = arith.extsi %arg7 : i32 to i64 loc(#loc7) - %10 = tt.make_tensor_ptr %arg1, [%5, %8], [%9, %c1_i64], [%c0_i32, %3] {order = array} : > loc(#loc7) - %11 = arith.divsi %arg5, %c16_i32 : i32 loc(#loc20) - %12:3 = scf.for %arg9 = %c0_i32 to %11 step %c1_i32 iter_args(%arg10 = %cst_0, %arg11 = %7, %arg12 = %10) -> (vector<32x16xf32>, !tt.ptr>, !tt.ptr>) : i32 { - %17 = triton_cpu.extract_memref %arg11 : > -> memref> loc(#loc11) - %18:2 = triton_cpu.extract_indices %arg11 : > -> index, index loc(#loc11) - %19 = vector.transfer_read %17[%18#0, %18#1], %cst : memref>, vector<32x16xf32> loc(#loc11) - %20 = triton_cpu.extract_memref %arg12 : > -> memref> loc(#loc12) - %21:2 = triton_cpu.extract_indices %arg12 : > -> index, index loc(#loc12) - %22 = vector.transfer_read %20[%21#0, %21#1], %cst : memref>, vector<16x16xf32> loc(#loc12) - %23 = triton_cpu.dot %19, %22, %arg10, inputPrecision = tf32 : vector<32x16xf32> * vector<16x16xf32> -> vector<32x16xf32> loc(#loc13) - %24 = tt.advance %arg11, [%c0_i32, %c16_i32] : > loc(#loc14) - %25 = tt.advance %arg12, [%c16_i32, %c0_i32] : > loc(#loc15) - scf.yield %23, %24, %25 : vector<32x16xf32>, !tt.ptr>, !tt.ptr> loc(#loc16) - } loc(#loc10) - %13 = arith.extsi %arg8 : i32 to i64 loc(#loc17) - %14 = tt.make_tensor_ptr %arg2, [%4, %8], [%13, %c1_i64], [%2, %3] {order = array} : > loc(#loc17) - %15 = triton_cpu.extract_memref %14 : > -> memref> loc(#loc18) - %16:2 = triton_cpu.extract_indices %14 : > -> index, index loc(#loc18) - vector.transfer_write %12#0, %15[%16#0, %16#1] : vector<32x16xf32>, memref> loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) -#loc1 = loc(unknown) -#loc2 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5916:26) 
-#loc3 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5917:26) -#loc4 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5927:29) -#loc5 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5928:29) -#loc6 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5931:106) -#loc7 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5933:106) -#loc8 = loc("/home/jovyan/triton-cpu/python/triton/language/standard.py":40:28) -#loc9 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5935:36) -#loc10 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5935:51) -#loc11 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5936:20) -#loc12 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5937:20) -#loc13 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5938:32) -#loc14 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5939:44) -#loc15 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5940:44) -#loc16 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5940:8) -#loc17 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5951:36) -#loc18 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5952:26) -#loc19 = loc("/home/jovyan/triton-cpu/python/test/unit/language/test_core.py":5952:4) -#loc20 = loc(callsite(#loc8 at #loc9)) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 76b3162bcd38b..e24d25a65019b 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,16 +1,30 @@ +find_package(dnnl CONFIG) +if (dnnl_FOUND) + message(STATUS "Found OneDNN/DNNL") + add_compile_definitions(ONEDNN_AVAILABLE) + get_target_property(dnnl_include DNNL::dnnl INTERFACE_INCLUDE_DIRECTORIES) + # currently used only in triton_cpu.cc and in ConvertDotToOneDNN + include_directories(${dnnl_include}) +else () + message(STATUS "Could NOT find OneDNN/DNNL") +endif() + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) -include_directories(/usr/local/include) -link_directories(/usr/local/lib) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) - target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms dnnl PRIVATE Python3::Module pybind11::headers) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms PRIVATE Python3::Module pybind11::headers) endif() -add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_onednn.cpp) -target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport dnnl) +if (dnnl_FOUND) + add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_onednn.cpp) + target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport DNNL::dnnl) +else () + add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) + target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport) +endif() # Build and link sleef set(SLEEF_BUILD_SHARED_LIBS ON CACHE 
BOOL "Build sleef shared lib" FORCE) diff --git a/third_party/cpu/TestOneDNNukernel.cpp b/third_party/cpu/TestOneDNNukernel.cpp deleted file mode 100644 index 12b8e696f52b5..0000000000000 --- a/third_party/cpu/TestOneDNNukernel.cpp +++ /dev/null @@ -1,634 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include - -using namespace dnnl; -using namespace dnnl::ukernel; - -using tag = memory::format_tag; -using dt = memory::data_type; - -using read_lock_guard_t = std::shared_lock; -using write_lock_guard_t = std::unique_lock; -static std::shared_mutex g_brgemm_lock; - -#ifdef Nope -static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter, - Attribute attr) { - auto context = rewriter.getContext(); - auto tattr = dyn_cast_or_null(attr); - assert(tattr); - if (tattr == TypeAttr::get(FloatType::getF32(context))) { - return static_cast(dnnl_f32); - } else if (tattr == TypeAttr::get(FloatType::getF64(context))) { - return static_cast(dnnl_f64); - } else if (tattr == TypeAttr::get(FloatType::getBF16(context))) { - return static_cast(dnnl_bf16); - } else if (tattr == TypeAttr::get(FloatType::getF16(context))) { - return static_cast(dnnl_f16); - } else if (tattr == TypeAttr::get( - IntegerType::get(context, 32, IntegerType::Signed))) { - return static_cast(dnnl_s32); - } else if (tattr == - TypeAttr::get(IntegerType::get(context, 8, IntegerType::Signed))) { - return static_cast(dnnl_s8); - } else if (tattr == TypeAttr::get(IntegerType::get(context, 8, - IntegerType::Unsigned))) { - return static_cast(dnnl_u8); - } - return static_cast(dnnl_data_type_undef); -} - -static constexpr int PALETTE_SIZE = 64; -static constexpr int DEFAULT_KERNEL_SIZE = 1024; -static constexpr int MAX_KERNEL_SIZE = 2048; - -using read_lock_guard_t = std::shared_lock; -using write_lock_guard_t = std::unique_lock; -static std::shared_mutex g_brgemm_lock; - -struct brgemm_cache_info_t { - brgemm_desc_t desc; - brgemm_kernel_t *kernel; - std::unique_ptr palette; -}; - -static std::vector g_cache(DEFAULT_KERNEL_SIZE); -static int64_t g_kernel_id = -1; - -// TODO(haixin): use syscall to determine page size? -static constexpr size_t SCRATCH_SIZE = 2 * 4096; -// TODO(haixin): need to use custom thread management for scratch in the future? 
-static thread_local char scratch[SCRATCH_SIZE] = {0}; - -int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, - int64_t LDB, int64_t LDC, int64_t stride_a, - int64_t stride_b, float beta, int64_t dtypeA, - int64_t dtypeB) { - auto dnnl_dtypeA = static_cast(dtypeA); - auto dnnl_dtypeB = static_cast(dtypeB); - int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA); - int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB); - brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size}; - - write_lock_guard_t g(g_brgemm_lock); - g_kernel_id++; - assert(g_kernel_id < MAX_KERNEL_SIZE && - "Too many brgemm kernels are created"); - if (g_kernel_id >= DEFAULT_KERNEL_SIZE) { - if (g_kernel_id >= (int64_t)g_cache.size()) { - g_cache.resize(g_kernel_id + 1); - } - } - - dnnl::impl::status_t status = brgemm_desc_init( - &g_cache[g_kernel_id].desc, cpu_isa_t::isa_undef, - brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB, - /*transA=*/false, /*transB=*/false, brgemm_layout_t::brgemm_row_major, - 1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info); - assert(status == dnnl::impl::status::success && - "Failed to initialize BRGEMM descriptor"); - - status = brgemm_kernel_create(&g_cache[g_kernel_id].kernel, - g_cache[g_kernel_id].desc); - assert(status == dnnl::impl::status::success && - "Failed to JIT BRGEMM kernel"); - - brgemm_attr_t dnnl_attrs; - brgemm_desc_set_attr(&g_cache[g_kernel_id].desc, dnnl_attrs); - - if (g_cache[g_kernel_id].desc.is_tmm) { - g_cache[g_kernel_id].palette.reset(new char[PALETTE_SIZE]); - status = brgemm_init_tiles(g_cache[g_kernel_id].desc, - g_cache[g_kernel_id].palette.get()); - assert(status == dnnl::impl::status::success && - "Failed to initialize palette for BRGEMM"); - } - - return g_kernel_id; -} - -void dnnl_brgemm_tileconfig(int64_t kernel_idx) { - std::unique_ptr lock_guard; - if (kernel_idx >= DEFAULT_KERNEL_SIZE) { - lock_guard = std::make_unique(g_brgemm_lock); - } - assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() && - "Invalid kernel handler"); - brgemm_desc_t &desc = g_cache[kernel_idx].desc; - if (!desc.is_tmm) { - return; - } - char *palette_buffer = g_cache[kernel_idx].palette.get(); - assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel"); - amx_tile_configure(palette_buffer); -} - -void dnnl_brgemm_tilerelease() { - if (!mayiuse(avx512_core_amx)) { - return; - } - - amx_tile_release(); -} - -void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, - void *B, uint64_t B_offset, void *C, uint64_t C_offset, - int num) { - std::unique_ptr lock_guard; - if (kernel_idx >= DEFAULT_KERNEL_SIZE) { - lock_guard = std::make_unique(g_brgemm_lock); - } - assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() && - "Invalid kernel handler"); - brgemm_desc_t &desc = g_cache[kernel_idx].desc; - brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel; - assert(kernel && "Invalid brgemm kernel pointer"); - size_t A_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; - size_t B_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_b) * B_offset; - size_t C_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; - char *A_arith = static_cast(A) + A_offset_in_bytes; - char *B_arith = static_cast(B) + B_offset_in_bytes; - char *C_arith = static_cast(C) + C_offset_in_bytes; - brgemm_kernel_execute(kernel, num, A_arith, B_arith, nullptr, C_arith, - scratch); -} -#endif - -inline dnnl::memory::dim 
product(const dnnl::memory::dims &dims) { - return std::accumulate(dims.begin(), dims.end(), (dnnl::memory::dim)1, - std::multiplies()); -} - -void *create_brgemm_ukernel(int64_t M, int64_t N, int64_t K_k, - int64_t batch_size, int64_t lda, int64_t ldb, - int64_t ldc, int64_t dtypeA, int64_t dtypeB, - int64_t dtypeC) { - using K = std::array; - std::cout << "Args: M - " << M << ", N - " << N << ", K - " << K_k - << ", bath - " << batch_size << ", lda - " << lda << ", ldb - " - << ldb << ", ldc - " << ldc << ", dtype a - " << dtypeA - << ", dtype b - " << dtypeB << ", dtype c - " << dtypeC << "\n"; - K key{M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC}; - - static std::map savedUkernels; - { - read_lock_guard_t r_g(g_brgemm_lock); - if (savedUkernels.count(key) != 0) { - return &savedUkernels.find(key)->second; - } - } - - write_lock_guard_t w_g(g_brgemm_lock); - - if (savedUkernels.count(key) != 0) { - return &savedUkernels.find(key)->second; - } - - auto dnnl_dtypeA = static_cast(dtypeA); - auto dnnl_dtypeB = static_cast(dtypeB); - auto dnnl_dtypeC = static_cast(dtypeC); - - std::cout << std::boolalpha; - std::cout << "Is fp32? " << (dnnl_dtypeA == dt::f32) << "\n"; - - dnnl::ukernel::brgemm brg; - brg = dnnl::ukernel::brgemm(M, N, K_k, batch_size, lda, ldb, ldc, dnnl_dtypeA, - dnnl_dtypeB, dnnl_dtypeC); - - std::cout << "Brg: " << &brg << "\n"; - // Instruct the kernel to append the result to C tensor. - brg.set_add_C(true); - // Finalize the initialization. - brg.finalize(); - // Generate the executable JIT code for the objects. - brg.generate(); - - auto it = savedUkernels.insert({key, brg}); - std::cout << "Ptr: " << &it.first->second << "\n"; - return &it.first->second; -} - -void *create_transform_ukernel(int64_t K, int64_t N, int64_t in_ld, - int64_t out_ld, int64_t inDtype, - int64_t outDtype) { - using K_t = std::array; - K_t key{K, N, in_ld, out_ld, inDtype, outDtype}; - - static std::map savedUkernels; - { - read_lock_guard_t r_g(g_brgemm_lock); - if (savedUkernels.count(key) != 0) { - return &savedUkernels.find(key)->second; - } - } - - write_lock_guard_t w_g(g_brgemm_lock); - - if (savedUkernels.count(key) != 0) { - return &savedUkernels.find(key)->second; - } - - // Packing B tensor routine. The BRGeMM ukernel expects B passed in a - // special VNNI format for low precision data types, e.g., bfloat16_t. - // Note: the routine doesn't provide a `batch_size` argument in the - // constructor as it can be either incorporated into `K` dimension, or - // manually iterated over in a for-loop on the user side. - dnnl::ukernel::transform pack_B( - /* K = */ K, /* N = */ N, - /* in_pack_type = */ pack_type::no_trans, - /* in_ld = */ in_ld, - /* out_ld = */ out_ld, - /* in_dt = */ static_cast(inDtype), - /* out_dt = */ static_cast(outDtype)); - - // Pack B routine execution. - // Note: usually should be split to process only that part of B that the - // ukernel will execute. 
- pack_B.generate(); - - auto it = savedUkernels.insert({key, pack_B}); - return &it.first->second; -} - -void call_all(const void *transform_k, const void *brg_k, void *A_ptr, - void *original_B_ptr, void *C_ptr, void *scratchpad, - int64_t A_step_in_bytes, int64_t B_step_in_bytes, - int64_t B_block_size_in_bytes, int64_t num_batches, - bool skip_packing = false) { - - void *blocked_data = original_B_ptr; - std::cout << "Call Transform: " << transform_k << " Brg: " << brg_k - << ", a: " << A_ptr << ", b: " << original_B_ptr << ", c: " << C_ptr - << ", scr: " << scratchpad << "\n"; - std::cout << "steps: " << A_step_in_bytes << " " << B_step_in_bytes << " " - << B_block_size_in_bytes << " n: " << num_batches << "\n"; - - auto pack_B = reinterpret_cast(transform_k); - auto brg = reinterpret_cast(brg_k); - std::cout << " vanilla check pack: " - << ((brg->get_B_pack_type() == pack_type::pack32) ? "true" - : "false") - << "\n"; - bool need_packing = - brg->get_B_pack_type() == pack_type::pack32 && !skip_packing; - if (need_packing) { - std::cout << "Will be packed. \n"; - // output - - // blocked_B_size = block_K * block_n * dtype; // ldb * K_k * - // memory::data_type_size(b_dt); - - blocked_data = new uint8_t[B_block_size_in_bytes * num_batches]; - - pack_B->execute(original_B_ptr, blocked_data); - } - - brg->set_hw_context(); - - std::vector> A_B_offsets(num_batches); - for (memory::dim i = 0; i < num_batches; i++) { - const memory::dim A_offset_i = - i * A_step_in_bytes; // * a_dt_size; // K_k * a_dt_size; - const memory::dim B_offset_i = - need_packing ? i * B_block_size_in_bytes : i * B_step_in_bytes; - A_B_offsets[i] = std::make_pair(A_offset_i, B_offset_i); - } - - size_t scratchpad_size = brg->get_scratchpad_size(); - std::vector scratchpad_sm(scratchpad_size); - - // An execute call. `A_B` is a vector of pointers to A and packed B - // tensors. `acc_ptr` is a pointer to an accumulator buffer. - brg->execute(A_ptr, blocked_data, A_B_offsets, C_ptr, scratchpad_sm.data()); - - dnnl::ukernel::brgemm::release_hw_context(); - - if (need_packing) { - delete blocked_data; - }; -} - -int main() { - // brgemm_example(); - // return 0; - - // dnnl_brgemm_dispatch(16, 16, 16, 0, 0, 0 ); - // dnnl_brgemm_tileconfig(); - - // dnnl_brgemm_execute(); - - // dnnl_brgemm_execute(); - - // ukernel dimensions. - // K is for a whole tensor, K_k is for a single ukernel. - const memory::dim M = 32, K = 32, K_k = 16, N = 32; - if (K % K_k != 0) { - printf("K_k must divide K.\n"); - return 0; - } - const memory::dim n_calls = K / K_k; - std::cout << "n_cals: " << n_calls << "\n"; - - const memory::dim lda = K; - const memory::dim ldb = N; - const memory::dim ldc = N; // Leading dimension for accumulator. - const memory::dim ldd = N; // Leading dimension for an actual output. - const memory::dim batch_size = n_calls; - - memory::data_type a_dt = dt::f32; // dt::bf16; - memory::data_type b_dt = dt::f32; // dt::bf16; - memory::data_type c_dt = dt::f32; // Accumulator data type. - memory::data_type d_dt = dt::f32; // Output data type. - - // A, B, and C tensors dimensions. - memory::dims A_dims = {M, K}; - memory::dims B_dims = {K, N}; - memory::dims C_dims = {M, N}; - memory::dims D_dims = {M, N}; - - // Allocate buffers with user data. - std::vector A_user_data(product(A_dims)); - std::vector B_user_data(product(B_dims)); - std::vector D_data(product(D_dims)); // For reference comparison - std::vector D_user_data(product(D_dims)); // For reference comparison - - // Initialize A. 
- std::generate(A_user_data.begin(), A_user_data.end(), []() { - static int i = 1; - return i++ % 4; - }); - // Initialize B. - std::generate(B_user_data.begin(), B_user_data.end(), []() { - static int i = 6; - static int sign_gen = 0; - int sign = (sign_gen++ % 2) ? -1 : 1; - float val = sign * (i++ % 5); - return val; - }); - - // Create execution dnnl::engine. Needed for reorders to operate over input - // data. - dnnl::engine engine(engine::kind::cpu, 0); - - // Create dnnl::stream. Needed for reorders for the same reason. - dnnl::stream engine_stream(engine); - - // Create f32 memories. They are used as data holders and reorder into - // memories passed to the ukernel. - auto A_f32_md = memory::desc(A_dims, dt::f32, tag::ab); - auto B_f32_md = memory::desc(B_dims, dt::f32, tag::ab); - auto D_f32_md = memory::desc(D_dims, dt::f32, tag::ab); - - auto A_f32_mem = memory(A_f32_md, engine, A_user_data.data()); - auto B_f32_mem = memory(B_f32_md, engine, B_user_data.data()); - auto D_f32_mem = memory(D_f32_md, engine, D_data.data()); - - // Create ukernel memories in requested data types. - // Note that all formats are `ab`. - auto A_md = memory::desc(A_dims, a_dt, tag::ab); - auto B_md = memory::desc(B_dims, b_dt, tag::ab); - - auto C_md = memory::desc(C_dims, c_dt, tag::ab); - auto D_md = memory::desc(D_dims, d_dt, tag::ab); - - auto A_mem = memory(A_md, engine); - auto B_mem = memory(B_md, engine); - - auto C_mem = memory(C_md, engine); - auto D_mem = memory(D_md, engine); - - const auto *A_ptr = reinterpret_cast(A_mem.get_data_handle()); - auto *B_ptr = reinterpret_cast(B_mem.get_data_handle()); - - const size_t a_dt_size = - memory::data_type_size(A_mem.get_desc().get_data_type()); - const size_t b_dt_size = - memory::data_type_size(B_mem.get_desc().get_data_type()); - - reorder(A_f32_mem, A_mem).execute(engine_stream, A_f32_mem, A_mem); - reorder(B_f32_mem, B_mem).execute(engine_stream, B_f32_mem, B_mem); - reorder(D_f32_mem, D_mem).execute(engine_stream, D_f32_mem, D_mem); - - float *C_ptr = reinterpret_cast(C_mem.get_data_handle()); - for (memory::dim i = 0; i < M * N; i++) { - C_ptr[i] = 0; - } - - auto brg_k = - create_brgemm_ukernel(M, N, K_k, batch_size, lda, ldb, ldc, 3, 3, 3); - auto tfrm = create_transform_ukernel(K_k * n_calls, N, N, ldb, 3, 3); - - // void *B_base_ptr = B_ptr; - - // blocked_B_size = ldb * K_k * memory::data_type_size(b_dt); - - call_all(tfrm, brg_k, (void *)A_ptr, (void *)B_ptr, (void *)C_ptr, nullptr, - K_k * a_dt_size, N * K_k * b_dt_size, - ldb * K_k * memory::data_type_size(b_dt), batch_size); - - printf("( m, n) val after \"brg\" call\n"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - printf("Acc buffer(C_ptr) res: (%2d, %2d) Ref:%12f\n", m, n, - C_ptr[m * N + n]); - // if (scratchpad.size() != 0) { - // printf("Out buffer res: (%2d, %2d) Ref:%12f\n", m, n, - // scratchpad[m * N + n]); - // } - } - } - - bool to_throw = false; - printf("( m, n, k) val\n"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - D_user_data[m * N + n] = 0; - for (int k = 0; k < K; k++) { - printf("(%2d, %2d, %2d) A: %12f B: %12f\n", m, n, k, - A_user_data[m * K + k], B_user_data[k * N + n]); - D_user_data[m * N + n] += - A_user_data[m * K + k] * B_user_data[k * N + n]; - } - const float diff = fabsf(D_user_data[m * N + n] - C_ptr[m * N + n]); - if (diff > 1.19e-7) { - to_throw = true; - if (true) { - printf("Error: (%2d, %2d) Ref:%12g Got:%12g Diff:%12g\n", m, n, - D_user_data[m * N + n], C_ptr[m * N + n], diff); - } - } - // else { - // 
printf("Matched res: (%2d, %2d) Ref:%12f\n", m, n, - // D_user_data[m * N + n]); - // } - } - } - if (to_throw) { - throw status::runtime_error; - } - - return 0; - - // Create BRGeMM ukernel objects. - // There are two objects: - // * `brg` is the main one which operates over partitioned K dimension. It - // utilizes `beta = 1.f` to accumulate into the same buffer. It also uses - // `batch_size` to process as much as `n_calls - 1` iterations. - // * `brg_po` is the ukernel that would be called the last in the chain - // since it has attributes attached to the object and those will execute - // after all accumulation over K dimension is done. - // Note: `beta = 1.f` makes a ukernel reusable over K but will require - // zeroing the correspondent piece of accumulation buffer. - brgemm brg; - if (batch_size > 0) { - try { - // Construct a basic brgemm object. - brg = brgemm(M, N, K_k, batch_size, lda, ldb, ldc, a_dt, b_dt, c_dt); - // Instruct the kernel to append the result to C tensor. - brg.set_add_C(true); - // Finalize the initialization. - brg.finalize(); - // Generate the executable JIT code for the objects. - brg.generate(); - } catch (error &e) { - // on any other error just re-throw - throw; - } - } - - // Query a scratchpad size and initialize a scratchpad buffer if the ukernel - // is expecting it. This is a service space needed, has nothing in common - // with accumulation buffer. - size_t scratchpad_size = brg.get_scratchpad_size(); - std::vector scratchpad(scratchpad_size); - - uint16_t *B_blocked = nullptr; - size_t blocked_B_size = 0; - - void *B_base_ptr = B_ptr; - - // Query the packing requirement from the kernel. It's enough to query - // packing requirements from a single object as long as only dimension - // settings change between objects. - // Note: example uses the one that always present regardless of dimensions. - const bool need_pack = brg.get_B_pack_type() == pack_type::pack32; - - // If packing is needed, create a dedicated object for data transformation. - if (need_pack) { - // Packing B tensor routine. The BRGeMM ukernel expects B passed in a - // special VNNI format for low precision data types, e.g., bfloat16_t. - // Note: the routine doesn't provide a `batch_size` argument in the - // constructor as it can be either incorporated into `K` dimension, or - // manually iterated over in a for-loop on the user side. - transform pack_B(/* K = */ K_k, /* N = */ N, - /* in_pack_type = */ pack_type::no_trans, /* in_ld = */ N, - /* out_ld = */ ldb, /* in_dt = */ b_dt, - /* out_dt = */ b_dt); - - // Size of the packed tensor. - blocked_B_size = ldb * K_k * memory::data_type_size(b_dt); - - B_blocked = new uint16_t[blocked_B_size * n_calls]; - B_base_ptr = B_blocked; - - // Pack B routine execution. - // Note: usually should be split to process only that part of B that the - // ukernel will execute. - pack_B.generate(); - - pack_B.execute(B_ptr, B_blocked); - } - - // BRGeMM ukernel execute section. - // Prepare buffers for execution. - std::vector> A_B_offsets(batch_size); - for (memory::dim i = 0; i < batch_size; i++) { - const memory::dim A_offset_i = i * K_k * a_dt_size; - const memory::dim B_offset_i = - need_pack ? i * blocked_B_size : i * N * K_k * b_dt_size; - A_B_offsets[i] = std::make_pair(A_offset_i, B_offset_i); - } - - if (brg) { - std::cout << "brg with bs: " << batch_size << " should be called.\n"; - // Make an object to call HW specialized routines. For example, prepare - // AMX unit. - brg.set_hw_context(); - - // An execute call. 
`A_B` is a vector of pointers to A and packed B - // tensors. `acc_ptr` is a pointer to an accumulator buffer. - brg.execute(A_ptr, B_base_ptr, A_B_offsets, C_ptr, scratchpad.data()); - - printf("( m, n) val after \"brg\" call\n"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - printf("Acc buffer(C_ptr) res: (%2d, %2d) Ref:%12f\n", m, n, - C_ptr[m * N + n]); - if (scratchpad.size() != 0) { - printf("Out buffer res: (%2d, %2d) Ref:%12f\n", m, n, - scratchpad[m * N + n]); - } - } - } - } - - // Once all computations are done, need to release HW context. - brgemm::release_hw_context(); - - // Clean up an extra buffer. - delete B_blocked; - - // Used for verification results, need unconditional reorder. - auto user_D_mem = memory(D_f32_md, engine, D_data.data()); - reorder(C_mem, user_D_mem).execute(engine_stream, C_mem, user_D_mem); - - // A simplified fast verification that ukernel returned expected results. - // Note: potential off-by-1 or 2 errors may pop up. This could be solved - // with more sparse filling. - // bool to_throw = false; - printf("( m, n, k) val\n"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - D_user_data[m * N + n] = 0; - for (int k = 0; k < K; k++) { - printf("(%2d, %2d, %2d) A: %12f B: %12f\n", m, n, k, - A_user_data[m * K + k], B_user_data[k * N + n]); - D_user_data[m * N + n] += - A_user_data[m * K + k] * B_user_data[k * N + n]; - } - const float diff = fabsf(D_user_data[m * N + n] - D_data[m * N + n]); - if (diff > 1.19e-7) { - to_throw = true; - if (true) { - printf("Error: (%2d, %2d) Ref:%12g Got:%12g Diff:%12g\n", m, n, - D_user_data[m * N + n], D_data[m * N + n], diff); - } - } else { - printf("Matched res: (%2d, %2d) Ref:%12f\n", m, n, - D_user_data[m * N + n]); - } - } - } - if (to_throw) { - throw status::runtime_error; - } - - return 0; -} diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index b5a7b7f8c09dc..471899d689d33 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -164,7 +164,8 @@ def make_tttcir(self, mod, metadata, opt): cpu.passes.ttcpuir.add_optimize_masks(pm) passes.common.add_canonicalizer(pm) cpu.passes.ttcpuir.add_loop_invariant_code_motion(pm) - cpu.passes.ttcpuir.add_convert_dot_to_onednn(pm) + if cpu.onednn_available(): + cpu.passes.ttcpuir.add_convert_dot_to_onednn(pm, True) convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8") and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features) if convert_bf16_dot_product: @@ -206,7 +207,8 @@ def make_llir(self, src, metadata, options): # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_onednn_ops_to_llvmir(pm) + if cpu.onednn_available(): + cpu.passes.ttcpuir.add_onednn_ops_to_llvmir(pm, True) cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) cpu.passes.ttcpuir.add_expand_strided_metadata(pm) cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) @@ -272,7 +274,7 @@ def make_so(src, metadata, options): asm_path = os.path.join(tmpdir, "kernel.s") Path(asm_path).write_text(src) lib_dirs = cpu_driver.library_dirs - libs = ["m", "dnnl", "TritonCPURuntime", "sleef"] + libs = ["m", "TritonCPURuntime", "sleef"] so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) with open(so, "rb") as f: return f.read() diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 2eb002f7c95f6..3308fd23c6803 100644 --- 
a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -27,10 +27,6 @@ library_dirs = [_triton_C_dir] libraries = ["stdc++"] -# include_dirs = [os.path.join(_dirname, "include")] -# library_dirs = [os.path.join(_dirname, "lib"), _triton_C_dir] -# libraries = ["stdc++", "TritonCPURuntime"] - # Skip non-existent paths sys_include_dir = os.path.join(_dirname, "include") if os.path.exists(sys_include_dir): @@ -76,13 +72,10 @@ def __init__(self): pass def load_binary(self, name, kernel, shared_mem, device): - # lib_path = "/home/jovyan/triton-cpu/libkernel.so" with tempfile.NamedTemporaryFile(mode="wb", suffix=".so") as f: - # with open(lib_path, "wb") as f: f.write(kernel) f.flush() import ctypes - #print("load library: ", f.name) lib = ctypes.cdll.LoadLibrary(f.name) fn_ptr = getattr(lib, name) fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index 7c1f13b738181..ed47934d244c0 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -33,6 +33,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass(); std::unique_ptr<OperationPass<ModuleOp>> createDebugOpsToLLVMPass(); std::unique_ptr<OperationPass<ModuleOp>> createOneDNNOpsToLLVMPass(); std::unique_ptr<OperationPass<ModuleOp>> +createOneDNNOpsToLLVMPass(bool canReplace); +std::unique_ptr<OperationPass<ModuleOp>> createMathToVecLibPass(VecLib lib = VecLib::Sleef, std::set<std::string> cpu_features = {}); diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 96449bdbf2990..2361812cae908 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -69,6 +69,12 @@ def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { def OneDNNOpsToLLVM : Pass<"triton-cpu-onednn-ops-to-llvm", "mlir::ModuleOp"> { let summary = "Convert Triton OneDNN operations to LLVM."; let description = [{}]; + let options = [ + Option<"canReplace", "replace", + "bool", /*default*/"false", + "Use OneDNN uKernel for matmul."> + ]; + let constructor = "mlir::triton::cpu::createOneDNNOpsToLLVMPass()"; let dependentDialects = ["mlir::arith::ArithDialect", diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 4daec83d193f5..c75133c7c0e22 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -41,6 +41,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertDotGeneric(); std::unique_ptr<OperationPass<ModuleOp>> createCanonicalize(); std::unique_ptr<OperationPass<ModuleOp>> createConvertDotToOneDNN(); +std::unique_ptr<OperationPass<ModuleOp>> +createConvertDotToOneDNN(bool canReplace); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index fd58d3aef2388..b75caf1c9adfa 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -143,15 +143,9 @@ def ConvertDotToOneDNN : Pass<"triton-cpu-convert-dot-to-onednn", "mlir::ModuleOp"> { }]; let options = [ - Option<"convertInt8", "convert-i8", - "bool", /*default*/"true", - "Use AMX extensions for int8 type.">, - Option<"convertFp16", "convert-fp16", - "bool", /*default*/"true", - "Use AMX extensions for fp16 type.">, - Option<"convertBf16", "convert-bf16", - "bool", /*default*/"true", - "Use AMX extensions for bf16 type.">, +
Option<"canReplace", "replace", + "bool", /*default*/"false", + "Use OneDNN uKernel for matmul."> ]; let constructor = "mlir::triton::cpu::createConvertDotToOneDNN()"; diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index 9e37c855ee612..cec27a02b4039 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -16,3 +16,5 @@ add_triton_library(TritonCPUToLLVM LINK_LIBS PUBLIC MLIRVectorToLLVMPass ) + +set_source_files_properties(OneDNNOpsToLLVM.cpp PROPERTIES COMPILE_FLAGS "-Wall") diff --git a/third_party/cpu/lib/TritonCPUToLLVM/OneDNNOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/OneDNNOpsToLLVM.cpp index 16a6ff4f0eb6f..2a1428b3abc10 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/OneDNNOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/OneDNNOpsToLLVM.cpp @@ -13,10 +13,16 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" +#if defined(ONEDNN_AVAILABLE) +#include "oneapi/dnnl/dnnl_types.h" +#endif + namespace mlir { namespace triton { +namespace cpu { #define GEN_PASS_DEF_ONEDNNOPSTOLLVM #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace cpu } // namespace triton } // namespace mlir @@ -26,10 +32,43 @@ using namespace mlir::triton::cpu; namespace { +#if defined(ONEDNN_AVAILABLE) +#include "oneapi/dnnl/dnnl_config.h" +#endif +void assert_on_onednn_missing() { +#if !defined(DNNL_EXPERIMENTAL_UKERNEL) + assert(false && "No OneDNN with uKernels available. Pass will be redundant."); +#endif +} + +inline Value intLLVMConst(Location loc, Type ty, int64_t val, + PatternRewriter &rewriter) { + return rewriter.create<LLVM::ConstantOp>( + loc, IntegerAttr::get(getElementTypeOrSelf(ty), val)); +} + +static inline int64_t getDnnlDataTypeVal(Type ty) { +#if defined(DNNL_EXPERIMENTAL_UKERNEL) + ty = getElementTypeOrSelf(ty); + if (ty.isF32()) + return static_cast<int64_t>(dnnl_f32); + if (ty.isF64()) + return static_cast<int64_t>(dnnl_f64); + if (ty.isBF16()) + return static_cast<int64_t>(dnnl_bf16); + if (ty.isF16()) + return static_cast<int64_t>(dnnl_f16); +#endif + llvm_unreachable("Unexpected type for conversion to DNNL type."); +} + class TritonLLVMConversionTarget : public ConversionTarget { public: explicit TritonLLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + // addIllegalOp(); addLegalDialect<LLVM::LLVMDialect>(); addLegalOp<mlir::UnrealizedConversionCastOp>(); } @@ -63,30 +102,19 @@ struct TransformCreateConversion matchAndRewrite(TransformCreate transformOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = transformOp.getLoc(); - auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); std::string dispatchName = "create_transform_ukernel"; - // Value k_int = transformOp.getK(); - // auto val_ty = k_int.getType(); - // if (val_ty.isIndex()) { - // Value k_int = - // rewriter.create(loc, i64_ty, transformOp.getK()) - // .getResult(); - // } - - // llvm::errs() << "k int: " << k_int << "\n"; + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + auto inDnnType = intLLVMConst( + loc, integer64, getDnnlDataTypeVal(transformOp.getInDt()), rewriter); + auto outDnnType = intLLVMConst( + loc, integer64, getDnnlDataTypeVal(transformOp.getOutDt()), rewriter); auto transformArgs = SmallVector<Value>{ + adaptor.getK(), adaptor.getN(), adaptor.getInLd(), + adaptor.getOutLd(), inDnnType, outDnnType}; auto
transformArgTypes = SmallVector<Type>{i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, i64_ty}; - // auto transformArgTypes = SmallVector{ - // transformOp.getK().getType(), transformOp.getN().getType(), - // transformOp.getInLd().getType(), transformOp.getOutLd().getType(), - // transformOp.getInDt().getType(), transformOp.getOutDt().getType()}; auto dispatched = LLVM::createLLVMCallOp( rewriter, loc, @@ -95,16 +123,6 @@ struct TransformCreateConversion getTypeConverter()->convertType(transformOp.getResult().getType())), transformArgs); - // transformOp.getResult().replaceAllUsesWith(dispatched.getResult()); - // llvm::errs() << "dispatched llvm call: " << dispatched << "\n"; - - auto mod = transformOp->getParentOfType<ModuleOp>(); - - // llvm::errs() << "[Fail] Mod op: " - // << "=============================\n" - // << mod << "\n =============================\n"; - - // rewriter.replaceAllOpUsesWith(transformOp, dispatched.getResult()); rewriter.replaceOp(transformOp, dispatched.getResult()); return success(); }; @@ -116,13 +134,10 @@ struct TransformCallConversion : public ConvertOpToLLVMPattern<TransformCall> { LogicalResult matchAndRewrite(TransformCall transformOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // llvm::errs() << "invoke orig op call: " << transformOp << "\n"; auto loc = transformOp.getLoc(); auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); - // std::string dispatchName = "create_transform_ukernel"; std::string invokeName = "call_transform"; auto transformArgs = @@ -136,7 +151,6 @@ struct TransformCallConversion : public ConvertOpToLLVMPattern<TransformCall> { rewriter, loc, getFuncDecl(rewriter, invokeName, transformArgTypes, void_ty(ctx)), transformArgs); - // llvm::errs() << "invoked llvm call: " << dispatched << "\n"; rewriter.replaceOp(transformOp, dispatched); return success(); @@ -150,47 +164,23 @@ struct BrgemmCreateConversion : public ConvertOpToLLVMPattern<BrgemmCreate> { matchAndRewrite(BrgemmCreate brgemmOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = brgemmOp.getLoc(); - auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); - - if (brgemmOp.getResult().getUses().empty()) { - // llvm::errs() << "!!!!!!!!!uses empty!!!!!!!!!\n"; - auto mod = brgemmOp->getParentOfType<ModuleOp>(); - - // llvm::errs() << "[Fail] Brgemm op: " - // << "=============================\n" - // << mod << "\n =============================\n"; - } + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); std::string dispatchName = "create_brgemm_ukernel"; - // Value batch_size_int = brgemmOp.getBatchSize(); - // auto val_ty = batch_size_int.getType(); - // if (val_ty.isIndex()) { - // Value batch_size_int = // brgemmOp.getOperand(3); - // rewriter.create(loc, i64_ty, - // brgemmOp.getOperand(3)) - // .getResult(); - // } - - // llvm::errs() << "bs size: " << batch_size_int << "\n"; - - // Value ldc_int = brgemmOp.getLdc(); - // val_ty = ldc_int.getType(); - // if (val_ty.isIndex()) { - // Value ldc_int = // brgemmOp.getOperand(6); - // rewriter.create(loc, i64_ty, - // brgemmOp.getOperand(6)) - // .getResult(); - // } - - // llvm::errs() << "ldc: " << ldc_int << "\n"; - - auto brgemmArgs = SmallVector<Value>{ - adaptor.getM(), adaptor.getN(), adaptor.getKK(), - adaptor.getBatchSize(), adaptor.getLda(), adaptor.getLdb(), - adaptor.getLdc(), adaptor.getDtypeA(), adaptor.getDtypeB(), - adaptor.getDtypeC()}; + auto lhsDnnType = intLLVMConst( + loc, integer64, getDnnlDataTypeVal(adaptor.getDtypeA()), rewriter); + auto rhsDnnType =
intLLVMConst( + loc, integer64, getDnnlDataTypeVal(adaptor.getDtypeB()), rewriter); + auto accDnnType = intLLVMConst( + loc, integer64, getDnnlDataTypeVal(adaptor.getDtypeC()), rewriter); + + auto brgemmArgs = + SmallVector<Value>{adaptor.getM(), adaptor.getN(), + adaptor.getKK(), adaptor.getBatchSize(), + adaptor.getLda(), adaptor.getLdb(), + adaptor.getLdc(), lhsDnnType, + rhsDnnType, accDnnType}; SmallVector<Type> brgemmArgTypes{i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, i64_ty}; @@ -201,17 +191,6 @@ struct BrgemmCreateConversion : public ConvertOpToLLVMPattern<BrgemmCreate> { getTypeConverter()->convertType(brgemmOp.getResult().getType())), brgemmArgs); - // brgemmOp.getResult().replaceAllUsesWith(dispatched.getResult()); - - // llvm::errs() << "brgem res: " << brgemmOp.getResult() << "\n"; - // llvm::errs() << "brgemm result uses: \n"; - // for (auto &use : brgemmOp.getResult().getUses()) { - // llvm::errs() << "\t" << use.get() << "\n"; - //} - // llvm::errs() << "uses done ------ \n"; - // llvm::errs() << "brgemm dispatched llvm call: " << dispatched << "\n"; - // rewriter.replaceAllOpUsesWith(brgemmOp, dispatched.getResult()); - rewriter.replaceOp(brgemmOp, dispatched.getResult()); return success(); }; @@ -223,20 +202,10 @@ struct BrgemmCallConversion : public ConvertOpToLLVMPattern<BrgemmCall> { LogicalResult matchAndRewrite(BrgemmCall brgemmOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // llvm::errs() << "invoke orig op call: " << brgemmOp << "\n"; auto loc = brgemmOp.getLoc(); auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); - - // auto brgem_kernel_params_op = - // adaptor.getKernelHash() - // .getDefiningOp(); - // if (brgem_kernel_params_op == nullptr) { - // return failure(); - // } - - // std::string dispatchName = "create_Brgemm_ukernel"; std::string invokeName = "call_brgemm"; auto kernel_hash_ptr = rewriter.create( @@ -251,11 +220,7 @@ struct BrgemmCallConversion : public ConvertOpToLLVMPattern<BrgemmCall> { adaptor.getStepA(), adaptor.getStepB(), adaptor.getNumBatches()}; - // auto unranked = - // getTypeConverter()->convertType(brgemmOp.getOperand(0).getType()); - // auto brgemmArgTypes = SmallVector{ - // unranked, unranked, unranked, unranked, unranked, - // }; + + auto brgemmArgTypes = SmallVector<Type>{ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), i64_ty, i64_ty, i64_ty}; auto dispatched = LLVM::createLLVMCallOp( rewriter, loc, getFuncDecl(rewriter, invokeName, brgemmArgTypes, void_ty(ctx)), brgemmArgs); - // llvm::errs() << "invoked llvm call: " << dispatched << "\n"; rewriter.replaceOp(brgemmOp, dispatched); return success(); @@ -278,20 +242,10 @@ struct CallBrgemmWithTransformConversion LogicalResult matchAndRewrite(CallBrgemmWithTransform brgemmOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // llvm::errs() << "invoke orig op call: " << brgemmOp << "\n"; auto loc = brgemmOp.getLoc(); auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); - - // auto brgem_kernel_params_op = - // adaptor.getKernelHash() - // .getDefiningOp(); - // if (brgem_kernel_params_op == nullptr) { - // return failure(); - // } - // std::string dispatchName = "create_Brgemm_ukernel"; std::string invokeName = "call_all"; auto tf_kernel_hash_ptr = rewriter.create( @@ -305,25 +259,19 @@ MemRefDescriptor(adaptor.getAPtr()).alignedPtr(rewriter, loc), MemRefDescriptor(adaptor.getBPtr()).alignedPtr(rewriter, loc),
MemRefDescriptor(adaptor.getCPtr()).alignedPtr(rewriter, loc), - MemRefDescriptor(adaptor.getScratchpad()).alignedPtr(rewriter, loc), adaptor.getStepA(), adaptor.getStepB(), adaptor.getBlockedBsize(), adaptor.getNumBatches()}; - // auto unranked = - // getTypeConverter()->convertType(brgemmOp.getOperand(0).getType()); - // auto brgemmArgTypes = SmallVector{ - // unranked, unranked, unranked, unranked, unranked, - // }; + + auto brgemmArgTypes = SmallVector<Type>{ ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), - ptr_ty(ctx), i64_ty, i64_ty, i64_ty, i64_ty}; + i64_ty, i64_ty, i64_ty, i64_ty}; auto dispatched = LLVM::createLLVMCallOp( rewriter, loc, getFuncDecl(rewriter, invokeName, brgemmArgTypes, void_ty(ctx)), brgemmArgs); - // llvm::errs() << "invoked llvm call: " << dispatched << "\n"; rewriter.replaceOp(brgemmOp, dispatched); return success(); @@ -336,13 +284,9 @@ struct ConfigureHWConversion : public ConvertOpToLLVMPattern<ConfigureHW> { LogicalResult matchAndRewrite(ConfigureHW configureHwOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // llvm::errs() << "invoke orig op call: " << configureHwOp << "\n"; - auto loc = configureHwOp.getLoc(); auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); - // std::string dispatchName = "create_Brgemm_ukernel"; std::string invokeName = "prepare_hw_context"; auto configureArgs = SmallVector<Value>{adaptor.getBrgemmKernelHash()}; @@ -353,7 +297,6 @@ struct ConfigureHWConversion : public ConvertOpToLLVMPattern<ConfigureHW> { rewriter, loc, getFuncDecl(rewriter, invokeName, configureArgTypes, void_ty(ctx)), configureArgs); - // llvm::errs() << "invoked llvm call: " << dispatched << "\n"; rewriter.replaceOp(configureHwOp, dispatched); return success(); @@ -366,13 +309,9 @@ struct ReleaseHWConversion : public ConvertOpToLLVMPattern<ReleaseHW> { LogicalResult matchAndRewrite(ReleaseHW releaseHwOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // llvm::errs() << "invoke orig op call: " << releaseHwOp << "\n"; - auto loc = releaseHwOp.getLoc(); auto ctx = rewriter.getContext(); - auto typeConverter = getTypeConverter(); - // std::string dispatchName = "create_Brgemm_ukernel"; std::string invokeName = "release_hw_context"; SmallVector<Value> releaseArgs{}; @@ -382,7 +321,6 @@ rewriter, loc, getFuncDecl(rewriter, invokeName, releaseArgTypes, void_ty(ctx)), releaseArgs); - // llvm::errs() << "invoked llvm call: " << dispatched << "\n"; rewriter.replaceOp(releaseHwOp, dispatched); return success(); @@ -390,12 +328,16 @@ }; struct OneDNNOpsToLLVM - : public triton::impl::OneDNNOpsToLLVMBase<OneDNNOpsToLLVM> { - using OneDNNOpsToLLVMBase::OneDNNOpsToLLVMBase; - - OneDNNOpsToLLVM() : OneDNNOpsToLLVMBase() {} + : public triton::cpu::impl::OneDNNOpsToLLVMBase<OneDNNOpsToLLVM> { + OneDNNOpsToLLVM() = default; + OneDNNOpsToLLVM(bool canReplace) { this->canReplace = canReplace; } void runOnOperation() override { + if (!canReplace) { + LDBG("Pass disabled."); + return; + } + MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -404,27 +346,15 @@ TritonLLVMConversionTarget conversionTarget(*context); RewritePatternSet patterns(context); - // mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, - // patterns); patterns.add<TransformCreateConversion, TransformCallConversion, BrgemmCreateConversion, BrgemmCallConversion, CallBrgemmWithTransformConversion, ConfigureHWConversion, ReleaseHWConversion>(typeConverter); - // patterns.add(typeConverter); - // patterns.add(typeConverter); if (failed(applyPartialConversion(mod, conversionTarget, std::move(patterns)))) { - -
// llvm::errs() << "[Fail] Mod op: " - // << "=============================\n" - // << mod << "\n =============================\n"; return signalPassFailure(); - } else { - // llvm::errs() << "[Succ] Mod op: " - // << "=============================\n" - // << mod << "\n =============================\n"; } } }; @@ -437,4 +367,12 @@ std::unique_ptr> createOneDNNOpsToLLVMPass() { return std::make_unique(); } +std::unique_ptr> +createOneDNNOpsToLLVMPass(bool isReplacementToOneDnnPossible) { + if (isReplacementToOneDnnPossible) { + assert_on_onednn_missing(); + } + return std::make_unique(isReplacementToOneDnnPossible); +} + } // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index 94077246f862b..91627a9edb817 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -4,9 +4,9 @@ add_triton_library(TritonCPUTransforms ConvertDotOp/ConvertDotGeneric.cpp ConvertDotOp/ConvertDotToAMX.cpp ConvertDotOp/ConvertDotToFMA.cpp + ConvertDotOp/ConvertDotToOneDNN.cpp Canonicalize.cpp ConvertDotProduct.cpp - ConvertDotToOneDNN.cpp ConvertUnsupportedOps.cpp DecomposeFpConversions.cpp OptimizeMasks.cpp @@ -15,4 +15,4 @@ add_triton_library(TritonCPUTransforms TritonCPUTransformsPassIncGen ) -set_source_files_properties(ConvertDotToOneDNN.cpp PROPERTIES COMPILE_FLAGS "-Wall") +set_source_files_properties(ConvertDotOp/ConvertDotToOneDNN.cpp PROPERTIES COMPILE_FLAGS "-Wall") diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp index 5de96fea9417e..1425aca3aa938 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp @@ -29,9 +29,6 @@ bool isLoopCarriedAcc(Value acc) { return false; } - // We don't need this I guess - // blockArg.getArgNumber(); - Value updAcc = acc.getUsers().begin()->getResult(0); if (!updAcc.hasOneUse()) { LDBG(" No. 
Has multiple uses."); @@ -158,8 +155,9 @@ Value maybeCast(Location loc, Value val, Type dstElemTy, return rewriter.create(loc, dstTy, val); } -MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, - Operation *allocaPoint, PatternRewriter &rewriter) { +MemBuffer allocateTmpBufferStack(Location loc, VectorType vecTy, + Operation *allocaPoint, + PatternRewriter &rewriter) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(allocaPoint); auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); @@ -170,6 +168,19 @@ MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, return {memRef, indices}; } +MemBuffer allocateTmpBufferHeap(Location loc, VectorType vecTy, + Operation *allocaPoint, + PatternRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(allocaPoint); + auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); + Value memRef = rewriter.create( + loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(2, zeroIdx); + return {memRef, indices}; +} + Value shiftIndex(Location loc, Value index, int64_t offs, PatternRewriter &rewriter) { if (!offs) @@ -191,8 +202,8 @@ MemBuffer storeToTmpBuffer(Location loc, Value val, Operation *allocaPoint, PatternRewriter &rewriter) { LDBG("Storing vector to a temporary buffer: " << val); auto vecTy = cast(val.getType()); - MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); - rewriter.create(loc, val, buf.memRef, buf.indices); + MemBuffer buf = allocateTmpBufferStack(loc, vecTy, allocaPoint, rewriter); + op_write(val, buf.memRef, buf.indices); return buf; }
diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h index a88a00f83af64..b84470351f479 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h @@ -63,8 +63,14 @@ Value maybeCast(Location loc, Value val, Type dstElemTy, PatternRewriter &rewriter); // Allocate temporary buffer on stack for specified vector type. -MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, - Operation *allocaPoint, PatternRewriter &rewriter); +MemBuffer allocateTmpBufferStack(Location loc, VectorType vecTy, + Operation *allocaPoint, + PatternRewriter &rewriter); + +// Allocate temporary buffer on heap for specified vector type. +MemBuffer allocateTmpBufferHeap(Location loc, VectorType vecTy, + Operation *allocaPoint, + PatternRewriter &rewriter); // Move index by specified offset. Do constant folding if possible. Value shiftIndex(Location loc, Value index, int64_t offs,
diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 1b6dd9269ac1e..6a0ee661a81b0 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -213,18 +213,6 @@ void setupBlockAndTileSizes(ArrayRef lhsShape, candidate.tilesInBlockN = accBlocksN; } -// Check if vector transfer read/write operation uses a mask -// or involves a bounds check. 
-template bool hasMaskOrBoundsCheck(T op) { - auto inBounds = op.getInBounds(); - Value mask = op.getMask(); - bool hasBoundsCheck = - std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { - return !cast(attr).getValue(); - }); - return hasBoundsCheck || mask; -} - // Check if a value is used only for a store and that this store can be // replaced with tile stores. In this case fill appropriate fields in the // candidate structure. @@ -407,8 +395,8 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, if (interleave) { LDBG(" Copying from the original memref with interleave: " << inputBuf.memRef); - auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), - allocaPoint, rewriter); + auto tmpBuf = allocateTmpBufferStack(loc, getSwizzledRhsTileType(vecTy), + allocaPoint, rewriter); copyWithInterleave(loc, vecTy, inputBuf, tmpBuf, rewriter); return tmpBuf; } @@ -423,7 +411,7 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, if (interleave) vecTy = getSwizzledRhsTileType(vecTy); - MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + MemBuffer buf = allocateTmpBufferStack(loc, vecTy, allocaPoint, rewriter); if (interleave) { interleaveAndStore(loc, val, buf.memRef, rewriter); @@ -451,8 +439,8 @@ MemBuffer prepareResultBuffer(Location loc, Value val, const MemBuffer &accBuf, } LDBG("Allocating buffer for the result."); - return allocateTmpBuffer(loc, cast(val.getType()), allocaPoint, - rewriter); + return allocateTmpBufferStack(loc, cast(val.getType()), + allocaPoint, rewriter); } SmallVector shiftIndices(Location loc, ArrayRef indices, diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToOneDNN.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToOneDNN.cpp similarity index 58% rename from third_party/cpu/lib/TritonCPUTransforms/ConvertDotToOneDNN.cpp rename to third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToOneDNN.cpp index f52ec1b867d05..0025a181faacd 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToOneDNN.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToOneDNN.cpp @@ -1,32 +1,12 @@ -// save diagnostic state -// #pragma GCC diagnostic push +#include "ConvertDotCommon.h" -// turn off the specific warning. 
Can also use "-Wall" -// #pragma GCC diagnostic ignored "-Wall" -// #pragma GCC diagnostic ignored "-Weffc++" -// #pragma GCC diagnostic ignored "-pedantic" - -#include "cpu/include/TritonCPUTransforms/OptCommon.h" - -#include "cpu/include/Analysis/TensorPtrShapeInfo.h" #include "cpu/include/TritonCPUTransforms/Passes.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" - -#include "include/triton/Analysis/Utility.h" -#include "oneapi/dnnl/dnnl_types.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -#include "ConvertDotOp/ConvertDotCommon.h" - #include #include #include @@ -39,9 +19,7 @@ namespace cpu { } // namespace cpu } // namespace triton } // namespace mlir -// #pragma GCC diagnostic pop -// #define DEBUG_TYPE "triton-cpu-dot-to-onednn" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -51,19 +29,6 @@ using namespace mlir::triton::cpu; namespace { -static inline int64_t getDnnlDataTypeVal(Type ty) { - ty = getElementTypeOrSelf(ty); - if (ty.isF32()) - return static_cast(dnnl_f32); - if (ty.isF64()) - return static_cast(dnnl_f64); - if (ty.isBF16()) - return static_cast(dnnl_bf16); - if (ty.isF16()) - return static_cast(dnnl_f16); - llvm_unreachable("Unexpected type for conversion to DNNL type."); -} - // This structure is used to hold candidates for conversion to ukernel calls. struct DotOpCandidate { // Operation to convert. @@ -79,31 +44,11 @@ struct DotOpCandidate { bool isAccLoopCarried = false; bool canFuseLoop = false; - // If output buffer is used then keep the original vector store here. - Operation *origStore = nullptr; - // If input data is available in memory then input buffers hold it. MemBuffer lhsBuf; MemBuffer rhsBuf; - // If result is written to a memory, then we can use it directly for - // ukernel calls. - MemBuffer outBuf; }; -// Check if vector transfer read/write operation uses a mask -// or involves a bounds check. -template bool hasMaskOrBoundsCheck(T op) { - auto inBounds = op.getInBounds(); - Value mask = op.getMask(); - bool hasBoundsCheck = - std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { - return !cast(attr).getValue(); - }); - llvm::errs() << "mask: " << mask << " bounds check: " << hasBoundsCheck - << "\n"; - return hasBoundsCheck || mask; -} - bool isLoopInvariant(SmallVector vals, LoopLikeOpInterface loopLike) { for (Value val : vals) { LDBG("Checking value for invariance: " << val); @@ -152,13 +97,10 @@ bool checkInputShapes(VectorType lhsTy, VectorType resTy, return true; } -// Check if specified ContractionOp can be lowered to AMX operations. +// Check if specified ContractionOp can be lowered to OneDNN ukernel operations. // If conversion is possible, then true is returned and candidate // structure is filled with detailed transformation info. 
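With the BRGEMM and transform ops now carrying TypeAttr operands instead of raw I64 type codes, the MLIR-type-to-oneDNN mapping that getDnnlDataTypeVal performed has to happen during LLVM lowering instead. A minimal sketch of such a mapping, mirroring the deleted helper (the static_cast template arguments were elided in the listing above; int64_t follows from the return type):

    #include "oneapi/dnnl/dnnl_types.h"
    #include "mlir/IR/TypeUtilities.h"
    #include "llvm/Support/ErrorHandling.h"

    // Map an MLIR element type to a dnnl_data_type_t code; a sketch that
    // mirrors the helper removed by this patch.
    static int64_t getDnnlDataTypeVal(mlir::Type ty) {
      ty = mlir::getElementTypeOrSelf(ty);
      if (ty.isF32())
        return static_cast<int64_t>(dnnl_f32);
      if (ty.isF64())
        return static_cast<int64_t>(dnnl_f64);
      if (ty.isBF16())
        return static_cast<int64_t>(dnnl_bf16);
      if (ty.isF16())
        return static_cast<int64_t>(dnnl_f16);
      llvm_unreachable("Unexpected type for conversion to DNNL type.");
    }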
-bool isOneDNNCandidate(triton::cpu::DotOp op, bool supportInt8, - bool supportFp16, bool supportBf16, - DotOpCandidate &candidate) { - // MLIRContext *ctx = op.getContext(); +bool isOneDNNCandidate(triton::cpu::DotOp op, DotOpCandidate &candidate) { VectorType lhsTy = cast(op.getA().getType()); VectorType rhsTy = cast(op.getB().getType()); VectorType accTy = cast(op.getC().getType()); @@ -188,129 +130,19 @@ bool isOneDNNCandidate(triton::cpu::DotOp op, bool supportInt8, // Check if we can fuse dot op loop into a single brgemm call. if (candidate.isAccLoopCarried && !candidate.lhsBuf.step.empty() && !candidate.rhsBuf.step.empty()) { - SmallVector valsToCheck; - valsToCheck.append(candidate.lhsBuf.step); - valsToCheck.append(candidate.rhsBuf.step); + SmallVector valsToCheckInvariance; + valsToCheckInvariance.append(candidate.lhsBuf.step); + valsToCheckInvariance.append(candidate.rhsBuf.step); auto forOp = dyn_cast(op->getParentOp()); - candidate.canFuseLoop = isLoopInvariant(valsToCheck, forOp); + candidate.canFuseLoop = isLoopInvariant(valsToCheckInvariance, forOp); } - - // We don't need this I guess - // findOutputBuffer(op.getResult(), candidate); - return true; } -MemBuffer allocateTmpBufferLocal(Location loc, VectorType vecTy, - Operation *allocaPoint, - PatternRewriter &rewriter) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(allocaPoint); - auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); - Value memRef = rewriter.create( - loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); - Value zeroIdx = rewriter.create(loc, 0); - SmallVector indices(2, zeroIdx); - return {memRef, indices}; -} - -// Prepare temporary buffers to be used for tile loads. If the original -// value can be directly loaded to tiles from its original memory, then -// use it instead. Return empty buffer if source value is all zeros and -// skipForZeros is set. -// -// If interleave flag is set, then pre-pack RHS before sotring. See -// interleaveAndStore for more details. 
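The canFuseLoop decision above reduces to checking that the A/B step values do not change across loop iterations. A minimal sketch of that invariance test, assuming MLIR's LoopLikeOpInterface (the isLoopInvariant helper in this file logs each value but follows the same idea):

    #include "mlir/IR/Value.h"
    #include "mlir/Interfaces/LoopLikeInterface.h"
    #include "llvm/ADT/ArrayRef.h"

    // True if every value is defined outside the loop, i.e. the pointer steps
    // are the same on every iteration and the whole loop can be collapsed into
    // a single batched brgemm call.
    static bool allLoopInvariant(llvm::ArrayRef<mlir::Value> vals,
                                 mlir::LoopLikeOpInterface loopLike) {
      for (mlir::Value val : vals)
        if (!loopLike.isDefinedOutsideOfLoop(val))
          return false;
      return true;
    }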
-MemBuffer prepareTensorBuffer(PatternRewriter &rewriter, Location loc, - Value val, - memref::ExtractStridedMetadataOp metadata, - bool readOnly, Operation *allocaPoint, - Value transform = nullptr) { - LDBG("Preparing buffer (interleave=" << (transform == nullptr) - << ") for a vector: " << val - << " readOnly: " << readOnly); - auto valLoad = val.getDefiningOp(); - // some extra conditions required - if (valLoad) { - LDBG("Lhs should take src memref!\n"); - Value memRef = valLoad.getSource(); - ValueRange indices = valLoad.getIndices(); - if (!transform) { - LDBG(" Reusing the original memref for a buffer: " << memRef); - auto vecTy = cast(val.getType()); - // auto memRefTy = MemRefType::get(vecTy.getShape(), - // vecTy.getElementType()); - auto ctx = rewriter.getContext(); - // Value memRef_view = rewriter.create( - // loc, memRefTy, memRef, {0, 0}, indices, - // metadata.getStrides()); - SmallVector strides(vecTy.getRank(), 1); - - Value memRef_view = rewriter.create( - loc, memRef, getAsOpFoldResult(indices), - getAsIndexOpFoldResult(ctx, vecTy.getShape()), - getAsIndexOpFoldResult(ctx, strides)); - return {memRef_view, indices}; - } - - // Just a stub, dont know why we should use Vector TYpe here - // auto vecTy = cast(val.getType()); - // auto transf_op = transform.getDefiningOp(); - - assert("Used( " && false); - // MemBuffer buf = allocateTmpBuffer( - // loc, - // {transf_op.getBlockK(), transf_op.getOutLd(), transf_op.getNCalls()}, - // vecTy.getElementType(), allocaPoint, rewriter); - - // LDBG(" Reusing the original memref for a buffer: " << memRef); - // transformIntoBuf(loc, transform, memRef, buf.memRef, rewriter); - // return buf; - } - - auto vecTy = cast(val.getType()); - MemBuffer buf = allocateTmpBufferLocal(loc, vecTy, allocaPoint, rewriter); - - if (transform) { - LDBG("Unhandled case for transform: " << val); - assert(false); - return {}; - } - auto rank = dyn_cast(buf.memRef.getType()).getRank(); - SmallVector inBounds(rank, false); - rewriter.create(loc, val, buf.memRef, buf.indices, - inBounds); - - return buf; -} - -// Return a buffer where the final result should be stored. If result can -// be directly stored to the output memory, then it is used as an output -// buffer. Otherwise, re-use accumulator buffer or create a new one. -MemBuffer prepareResultBuffer(Location loc, Value val, const MemBuffer &accBuf, - const MemBuffer &outBuf, Operation *allocaPoint, - PatternRewriter &rewriter) { - if (!outBuf.empty()) { - LDBG("Output memory will be used for direct tile stores. 
Outbuf in " - "candidate: " - << outBuf.memRef.getType()); - return outBuf; - } - - if (!accBuf.empty()) { - LDBG("Result will be stored to accumulator buffer."); - return accBuf; - } - - LDBG("Allocating buffer for the result."); - return allocateTmpBufferLocal(loc, cast(val.getType()), - allocaPoint, rewriter); -} - -void replaceLoop(DotOpCandidate &candidate, - ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis, - PatternRewriter &rewriter) { +void replaceWholeLoop(DotOpCandidate &candidate, + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis, + PatternRewriter &rewriter) { triton::cpu::DotOp op = candidate.op; Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -332,13 +164,9 @@ void replaceLoop(DotOpCandidate &candidate, return rewriter.create(loc, memRefTy, ptr); }; - // VectorType lhsTy = cast(candidate.op.getA().getType()); - // VectorType rhsTy = cast(candidate.op.getB().getType()); auto addSubView = [&](Value vecVal, ValueRange indices, Value memRef) { LDBG(" Reusing the original memref for a buffer: " << memRef); auto vecTy = cast(vecVal.getType()); - // auto memRefTy = MemRefType::get(vecTy.getShape(), - // vecTy.getElementType()); auto ctx = rewriter.getContext(); SmallVector strides(vecTy.getRank(), 1); @@ -351,7 +179,6 @@ void replaceLoop(DotOpCandidate &candidate, auto lhsTritonPtr = candidate.lhsBuf.origBlockPtr; auto lhsMemRef = extractMemref(lhsTritonPtr); - // auto lhsRank = dyn_cast(memRef.getType()).getRank(); auto lhsIndices = rewriter.create(loc, lhsTritonPtr).getResults(); auto lhsSubView = addSubView(candidate.op.getA(), lhsIndices, lhsMemRef); @@ -360,7 +187,6 @@ void replaceLoop(DotOpCandidate &candidate, auto rhsTritonPtr = candidate.rhsBuf.origBlockPtr; auto rhsMemRef = extractMemref(rhsTritonPtr); - // auto rhsRank = dyn_cast(memRef.getType()).getRank(); auto rhsIndices = rewriter.create(loc, rhsTritonPtr).getResults(); auto rhsSubView = addSubView(candidate.op.getB(), rhsIndices, rhsMemRef); @@ -422,13 +248,8 @@ convertCandidate(DotOpCandidate &candidate, PatternRewriter &rewriter) { triton::cpu::DotOp op = candidate.op; Location loc = op.getLoc(); - // MLIRContext *ctx = op.getContext(); - - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); VectorType resTy = cast(op.getResult().getType()); Type resElemTy = resTy.getElementType(); - // FloatType float32 = FloatType::getF32(rewriter.getContext()); - // IndexType indexType = rewriter.getIndexType(); scf::ForOp forOp = dyn_cast(op->getParentOp()); Value numBatches = index_cst(1); @@ -440,7 +261,7 @@ convertCandidate(DotOpCandidate &candidate, // add loop carried-dependencies for accumulator tiles and accInitTiles // will be used as initializers for them. rewriter.setInsertionPoint(forOp); - replaceLoop(candidate, shapeInfoAnalysis, rewriter); + replaceWholeLoop(candidate, shapeInfoAnalysis, rewriter); LDBG("Loading accumulator to tiles before the loop."); numBatches = op_divui(op_subi(forOp.getUpperBound(), forOp.getLowerBound()), @@ -448,18 +269,11 @@ convertCandidate(DotOpCandidate &candidate, numBatches = op_index_cast(rewriter.getIndexType(), numBatches); } - // If we don't work with a loop and want to directly store tiles into output - // memory, then use the original store as insertion point to have its buffer - // values available for generated code. 
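replaceWholeLoop reuses the operands' original memory by wrapping the memref extracted from each block pointer in a subview sized to the operand vector. A sketch of the addSubView lambda above, assuming the elided op class is memref::SubViewOp (the template arguments were lost in the listing):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/Dialect/Utils/StaticValueUtils.h"
    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;

    // Wrap `memRef` in a subview that starts at `indices` and covers exactly
    // the shape of the vector operand, with unit strides.
    static Value makeOperandSubView(OpBuilder &b, Location loc,
                                    VectorType vecTy, ValueRange indices,
                                    Value memRef) {
      MLIRContext *ctx = b.getContext();
      SmallVector<int64_t> strides(vecTy.getRank(), 1);
      return b.create<memref::SubViewOp>(
          loc, memRef, getAsOpFoldResult(indices),
          getAsIndexOpFoldResult(ctx, vecTy.getShape()),
          getAsIndexOpFoldResult(ctx, strides));
    }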
- if (!candidate.isAccLoopCarried && !candidate.outBuf.empty()) - rewriter.setInsertionPoint(candidate.origStore); - Operation *allocaPoint = op; while (!isa(allocaPoint->getParentOp())) allocaPoint = allocaPoint->getParentOp(); - ModuleOp module = op->template getParentOfType(); - + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); auto blockM = int_cst(integer64, candidate.blockM); auto blockN = int_cst(integer64, candidate.blockN); auto blockK = int_cst(integer64, candidate.blockK); @@ -492,13 +306,12 @@ convertCandidate(DotOpCandidate &candidate, } // Currently, acc always needs to be FP32. accToStore = maybeCast(loc, accToStore, rewriter.getF32Type(), rewriter); - accBuf = - prepareTensorBuffer(rewriter, loc, accToStore, {}, false, allocaPoint); + auto vecTy = cast(accToStore.getType()); + // Using heap as resBuffer can be pretty big. + accBuf = allocateTmpBufferHeap(loc, vecTy, allocaPoint, rewriter); + op_write(accToStore, accBuf.memRef, accBuf.indices); } - MemBuffer resBuf = prepareResultBuffer( - loc, op.getResult(), accBuf, candidate.outBuf, allocaPoint, rewriter); - auto metadataA = rewriter.create( loc, candidate.lhsBuf.memRef); auto metadataB = rewriter.create( @@ -510,25 +323,20 @@ convertCandidate(DotOpCandidate &candidate, Value ldb = metadataB.getStrides()[metadataB.getStrides().size() - 2]; Value ldc = metadataAcc.getStrides()[metadataAcc.getStrides().size() - 2]; - auto lhsDnnType = int_cst(integer64, getDnnlDataTypeVal(op.getA().getType())); - auto rhsDnnType = int_cst(integer64, getDnnlDataTypeVal(op.getB().getType())); - auto accDnnType = - int_cst(integer64, getDnnlDataTypeVal(rewriter.getF32Type())); - Value brgemm = rewriter.create( loc, rewriter.getIndexType(), blockM, blockN, blockK, numBatches, lda, - ldb, ldc, lhsDnnType, rhsDnnType, accDnnType); + ldb, ldc, op.getA().getType(), op.getB().getType(), + rewriter.getF32Type()); auto rhsTypeSize = int_cst(integer64, op.getB().getType().getElementTypeBitWidth() / 8); Value rhsBlockSizeInBytes = op_muli(op_muli(blockN, blockK), rhsTypeSize); Value transform = rewriter.create( - loc, rewriter.getIndexType(), blockK, blockN, ldb, blockN, rhsDnnType, - rhsDnnType); + loc, rewriter.getIndexType(), blockK, blockN, ldb, blockN, + op.getB().getType(), op.getB().getType()); LDBG("[prepareResultBuffer] prepared acc buf: " << accBuf.memRef); - LDBG("[prepareResultBuffer] prepared res buf: " << resBuf.memRef); LDBG("lhsBuf: { memref " << candidate.lhsBuf.memRef << "\n " << " indices " << candidate.lhsBuf.indices.size() << "\n" @@ -549,24 +357,18 @@ convertCandidate(DotOpCandidate &candidate, rewriter.create( loc, transform, brgemm, candidate.lhsBuf.memRef, candidate.rhsBuf.memRef, - resBuf.memRef, accBuf.memRef, lhsStepInBytes, rhsStepInBytes, - rhsBlockSizeInBytes, numBatches); + accBuf.memRef, lhsStepInBytes, rhsStepInBytes, rhsBlockSizeInBytes, + numBatches); if (candidate.isAccLoopCarried && candidate.canFuseLoop) { LDBG("Loading the result to a vector to replace orig op result."); - auto rank = dyn_cast(resBuf.memRef.getType()).getRank(); + auto rank = dyn_cast(accBuf.memRef.getType()).getRank(); SmallVector inBounds(rank, false); Value newVal = rewriter.create( - loc, cast(toFp32(resTy)), resBuf.memRef, resBuf.indices, + loc, cast(toFp32(resTy)), accBuf.memRef, accBuf.indices, inBounds); // We might need to cast back to the original type. 
newVal = maybeCast(loc, newVal, resElemTy, rewriter); - // rewriter.replaceOp(op, newVal); - // rewriter.eraseOp(*op.getResult().user_begin()); - // rewriter.eraseOp(op); - LDBG("Printing module before replace:"); - LDBG(module); - rewriter.replaceOp(forOp, ValueRange{newVal, candidate.lhsBuf.memRef, candidate.rhsBuf.memRef}); return success(); @@ -574,10 +376,10 @@ convertCandidate(DotOpCandidate &candidate, if (candidate.isAccLoopCarried) { rewriter.setInsertionPointAfter(forOp); - auto rank = dyn_cast(resBuf.memRef.getType()).getRank(); + auto rank = dyn_cast(accBuf.memRef.getType()).getRank(); SmallVector inBounds(rank, false); Value newVal = rewriter.create( - loc, cast(toFp32(resTy)), resBuf.memRef, resBuf.indices, + loc, cast(toFp32(resTy)), accBuf.memRef, accBuf.indices, inBounds); // We might need to cast back to the original type. newVal = maybeCast(loc, newVal, resElemTy, rewriter); @@ -587,24 +389,13 @@ convertCandidate(DotOpCandidate &candidate, rewriter.replaceOp(op, op.getC()); return success(); } - if (candidate.outBuf.empty()) { - LDBG("Loading the result to a vector to replace orig op result."); - Value newVal = rewriter.create( - loc, cast(toFp32(resTy)), resBuf.memRef, resBuf.indices); - // We might need to cast back to the original type. - newVal = maybeCast(loc, newVal, resElemTy, rewriter); - op.getResult().replaceAllUsesWith(newVal); - rewriter.eraseOp(op); - // rewriter.replaceOp(op, newVal); - // rewriter.eraseOp(*op.getResult().user_begin()); - // rewriter.eraseOp(op); - LDBG("Printing module before replace:"); - LDBG(module); - } else { - LDBG("Removing original operation and its use."); - rewriter.eraseOp(*op.getResult().user_begin()); - rewriter.eraseOp(op); - } + LDBG("Loading the result to a vector to replace orig op result."); + Value newVal = rewriter.create( + loc, cast(toFp32(resTy)), accBuf.memRef, accBuf.indices); + // We might need to cast back to the original type. 
+ newVal = maybeCast(loc, newVal, resElemTy, rewriter); + op.getResult().replaceAllUsesWith(newVal); + rewriter.eraseOp(op); return success(); } @@ -612,8 +403,13 @@ convertCandidate(DotOpCandidate &candidate, struct ConvertDotToOneDNN : public triton::cpu::impl::ConvertDotToOneDNNBase { ConvertDotToOneDNN() = default; + ConvertDotToOneDNN(bool canReplace) { this->canReplace = canReplace; } void runOnOperation() override { + if (!canReplace) { + LDBG("Pass disabled."); + return; + } MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -621,10 +417,9 @@ struct ConvertDotToOneDNN ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); SmallVector candidates; - mod->walk([this, &candidates](triton::cpu::DotOp op) { + mod->walk([&candidates](triton::cpu::DotOp op) { DotOpCandidate candidate; - if (isOneDNNCandidate(op, convertInt8, convertFp16, convertBf16, - candidate)) { + if (isOneDNNCandidate(op, candidate)) { LLVM_DEBUG({ LDBG("Found OneDNN candidate"); LDBG(" Op: " << candidate.op); @@ -633,7 +428,6 @@ struct ConvertDotToOneDNN LDBG(" blockK: " << candidate.blockK); LDBG(" isAccLoopCarried: " << candidate.isAccLoopCarried); LDBG(" canFuseLoop: " << candidate.canFuseLoop); - LDBG(" Has output buffer: " << (bool)!candidate.outBuf.empty()); }); candidates.push_back(candidate); } @@ -655,14 +449,15 @@ struct ConvertDotToOneDNN } // namespace -namespace mlir { -namespace triton { -namespace cpu { +namespace mlir::triton::cpu { std::unique_ptr> createConvertDotToOneDNN() { return std::make_unique(); } -} // namespace cpu -} // namespace triton -} // namespace mlir +std::unique_ptr> +createConvertDotToOneDNN(bool isReplacementToOneDnnPossible) { + return std::make_unique(isReplacementToOneDnnPossible); +} + +} // namespace mlir::triton::cpu
diff --git a/third_party/cpu/runtime/runtime_onednn.cpp b/third_party/cpu/runtime/runtime_onednn.cpp index d09b878329205..cbaeb272aec6f 100644 --- a/third_party/cpu/runtime/runtime_onednn.cpp +++ b/third_party/cpu/runtime/runtime_onednn.cpp @@ -1,16 +1,18 @@ -// #include -// #include -// #include -// #include - -#include -#include -#include +#if defined(ONEDNN_AVAILABLE) +#include "oneapi/dnnl/dnnl_types.h" +#include "oneapi/dnnl/dnnl_ukernel.hpp" +#include "oneapi/dnnl/dnnl_ukernel_types.h" +#if !defined(DNNL_EXPERIMENTAL_UKERNEL) +#error "DNNL ukernel is missing" +#endif +#endif +#include #include #include #include #include +#include #if defined(_MSC_VER) #define EXPORT __declspec(dllexport) @@ -19,11 +21,11 @@ #else #define EXPORT #endif -#include -// using namespace dnnl::impl::cpu::x64; +#if defined(ONEDNN_AVAILABLE) using namespace dnnl; using namespace dnnl::ukernel; +#endif namespace dnnl { namespace impl { @@ -44,20 +46,6 @@ using read_lock_guard_t = std::shared_lock; using write_lock_guard_t = std::unique_lock; static std::shared_mutex g_brgemm_lock; -// TODO(haixin): use syscall to determine page size? -static constexpr size_t SCRATCH_SIZE = 2 * 4096; -// TODO(haixin): need to use custom thread management for scratch in the future? 
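create_brgemm_ukernel below caches compiled kernels keyed by the full shape/stride/type tuple, taking a shared lock on the fast lookup path and re-checking under the exclusive lock before inserting. A self-contained sketch of that double-checked pattern, with KernelT standing in for dnnl::ukernel::brgemm:

    #include <array>
    #include <cstdint>
    #include <map>
    #include <mutex>
    #include <shared_mutex>

    struct KernelT {}; // placeholder for the cached ukernel object

    using KeyT = std::array<int64_t, 10>; // M, N, K, batch, lda/ldb/ldc, dtypes

    static std::shared_mutex g_lock;
    static std::map<KeyT, KernelT> g_cache;

    KernelT *getOrCreate(const KeyT &key) {
      {
        std::shared_lock<std::shared_mutex> r(g_lock); // concurrent readers
        auto it = g_cache.find(key);
        if (it != g_cache.end())
          return &it->second;
      }
      std::unique_lock<std::shared_mutex> w(g_lock); // exclusive writer
      auto it = g_cache.find(key); // another thread may have won the race
      if (it != g_cache.end())
        return &it->second;
      return &g_cache.emplace(key, KernelT{}).first->second;
    }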
-static thread_local char scratch[SCRATCH_SIZE] = {0}; - -namespace { -template struct RawMemRefDescriptor { - const T *allocated; - const T *aligned; - intptr_t offset; - intptr_t sizesAndStrides[]; -}; -} // namespace - extern "C" { EXPORT void *create_brgemm_ukernel(int64_t M, int64_t N, int64_t K_k, @@ -65,18 +53,16 @@ EXPORT void *create_brgemm_ukernel(int64_t M, int64_t N, int64_t K_k, int64_t ldc, int64_t dtypeA, int64_t dtypeB, int64_t dtypeC) { using KeyT = std::array; - std::cout << "Args: M - " << M << ", N - " << N << ", K - " << K_k - << ", batch - " << batch_size << ", lda - " << lda << ", ldb - " - << ldb << ", ldc - " << ldc << ", dtype a - " << dtypeA - << ", dtype b - " << dtypeB << ", dtype c - " << dtypeC << "\n"; + // std::cout << "Args: M - " << M << ", N - " << N << ", K - " << K_k + // << ", batch - " << batch_size << ", lda - " << lda << ", ldb - " + // << ldb << ", ldc - " << ldc << ", dtype a - " << dtypeA + // << ", dtype b - " << dtypeB << ", dtype c - " << dtypeC << "\n"; KeyT key{M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC}; static std::map savedUkernels; { read_lock_guard_t r_g(g_brgemm_lock); if (savedUkernels.count(key) != 0) { - // std::cout << "reused kernel: " << &savedUkernels.find(key)->second - // << "\n"; return &savedUkernels.find(key)->second; } } @@ -84,9 +70,6 @@ EXPORT void *create_brgemm_ukernel(int64_t M, int64_t N, int64_t K_k, write_lock_guard_t w_g(g_brgemm_lock); if (savedUkernels.count(key) != 0) { - // std::cout << "reused kernel (second): " << - // &savedUkernels.find(key)->second - // << "\n"; return &savedUkernels.find(key)->second; } @@ -107,7 +90,6 @@ EXPORT void *create_brgemm_ukernel(int64_t M, int64_t N, int64_t K_k, auto it = savedUkernels.insert({key, brg}); auto ret = &it.first->second; - // std::cout << "[create] brg: " << ret << std::endl; return ret; } @@ -154,54 +136,29 @@ EXPORT void *create_transform_ukernel(int64_t K, int64_t N, int64_t in_ld, } EXPORT void call_all(const void *transform_k, const void *brg_k, void *A_ptr, - void *original_B_ptr, void *C_ptr, void *scratchpad, + void *original_B_ptr, void *C_ptr, // void *scratchpad, int64_t A_step_in_bytes, int64_t B_step_in_bytes, int64_t B_block_size_in_bytes, int64_t num_batches, bool skip_packing = false) { uint8_t *blocked_data = (uint8_t *)original_B_ptr; uint8_t *B_ptr_calc = (uint8_t *)original_B_ptr; - // std::cout << "Call Transform: " << transform_k << " Brg: " << brg_k - // << ", a: " << A_ptr << ", b: " << original_B_ptr << ", c: " << - // C_ptr - // << ", scr: " << scratchpad << "\n"; - // std::cout << "steps: " << A_step_in_bytes << " " << B_step_in_bytes << " " - // << B_block_size_in_bytes << " n: " << num_batches << "\n"; auto pack_B = reinterpret_cast(transform_k); auto brg = reinterpret_cast(brg_k); - // std::cout << " vanilla check pack: " - // << ((brg->get_B_pack_type() == pack_type::pack32) ? "true" - // : "false") - // << "\n"; + bool need_packing = brg->get_B_pack_type() == pack_type::pack32 && !skip_packing; if (need_packing) { - // std::cout << "Will be packed. 
\n"; - // output - - // blocked_B_size = block_K * block_n * dtype; // ldb * K_k * - // memory::data_type_size(b_dt); - blocked_data = new uint8_t[B_block_size_in_bytes * num_batches]; - pack_B->execute(original_B_ptr, blocked_data); } brg->set_hw_context(); - if (need_packing) { - // std::cout << "[packed] b_ptr_calc: " << (void *)B_ptr_calc - // << " blocked_ptr: " << (void *)blocked_data << "\n"; - } else { - // std::cout << "[unpacked] b_ptr_calc: " << (void *)B_ptr_calc - // << " blocked_ptr: " << (void *)blocked_data << "\n"; - } - std::vector> A_B_offsets(num_batches); for (memory::dim i = 0; i < num_batches; i++) { - const memory::dim A_offset_i = - i * A_step_in_bytes; // * a_dt_size; // K_k * a_dt_size; + const memory::dim A_offset_i = i * A_step_in_bytes; memory::dim B_offset_i; if (need_packing) { @@ -215,45 +172,10 @@ EXPORT void call_all(const void *transform_k, const void *brg_k, void *A_ptr, } size_t scratchpad_size = brg->get_scratchpad_size(); - // std::cout << "scratchpad size: " << scratchpad_size << "\n"; std::vector scratchpad_sm(scratchpad_size); - std::stringstream ss; - ss << "A values:\n"; - for (uint m = 0; m < 16; m++) { - for (uint k = 0; k < 32; k++) { - ss << ((float *)A_ptr)[32 * m + k] << " "; - } - ss << "\n"; - } - ss << "B values:\n"; - for (uint k = 0; k < 32; k++) { - for (uint n = 0; n < 16; n++) { - ss << ((float *)blocked_data)[16 * k + n] << " "; - } - ss << "\n"; - } - ss << "ACC input:\n"; - for (uint m = 0; m < 16; m++) { - for (uint n = 0; n < 16; n++) { - ss << ((float *)C_ptr)[16 * m + n] << " "; - } - ss << "\n"; - } - ss << "Offsets:\n"; - for (int i = 0; i < num_batches; ++i) - ss << " " << i << ": " << A_B_offsets[i].first << ", " - << A_B_offsets[i].second << "\n"; // An execute call. `A_B` is a vector of pointers to A and packed B // tensors. `acc_ptr` is a pointer to an accumulator buffer. brg->execute(A_ptr, blocked_data, A_B_offsets, C_ptr, scratchpad_sm.data()); - ss << "ACC output:\n"; - for (uint m = 0; m < 16; m++) { - for (uint n = 0; n < 16; n++) { - ss << ((float *)C_ptr)[16 * m + n] << " "; - } - ss << "\n"; - } - // std::cout << ss.str(); dnnl::ukernel::brgemm::release_hw_context(); @@ -264,26 +186,13 @@ EXPORT void call_all(const void *transform_k, const void *brg_k, void *A_ptr, EXPORT void call_transform(const void *transform_k, const void *original_data, void *blocked_data) { + assert(false && "Not tested"); auto pack_B = reinterpret_cast(transform_k); pack_B->execute(original_data, blocked_data); } -// Most questionable function - no shure where to leave forming of offsets lists -// maybe too difficult in client code -// void prepare_buffers(int64_t batch_size) { -// // BRGeMM ukernel execute section. -// // Prepare buffers for execution. -// std::vector> A_B_offsets(batch_size); -// for (memory::dim i = 0; i < batch_size; i++) { -// const memory::dim A_offset_i = i * K_k * a_dt_size; -// const memory::dim B_offset_i = -// need_pack ? 
i * blocked_B_size : i * N * K_k * b_dt_size; -// A_B_offsets[i] = std::make_pair(A_offset_i, B_offset_i); -// } -// } - -// for perf targets EXPORT void prepare_hw_context(const void *brg_k) { + assert(false && "Not tested"); auto brg = reinterpret_cast(brg_k); brg->set_hw_context(); } @@ -291,23 +200,7 @@ EXPORT void prepare_hw_context(const void *brg_k) { EXPORT void call_brgemm(const void *brg_k, void *A_ptr, void *B_ptr, void *C_ptr, void *scratchpad, int64_t A_step_in_bytes, int64_t B_step_in_bytes, int64_t num_batches) { - // std::cout << "Call Brg: " << brg_k << ", " << A_ptr << ", " << B_ptr << ", - // " - // << C_ptr << ", " << scratchpad << "\n"; - // std::cout << "steps: " << A_step_in_bytes << " " << B_step_in_bytes - // << " n: " << num_batches << "\n"; - - if (A_ptr == nullptr || B_ptr == nullptr || C_ptr == nullptr) { - std::cout << "----------------FAIL----------------\n"; - return; - } - - // auto dnnl_dtypeA = static_cast(dtypeA); - // auto dnnl_dtypeB = static_cast(dtypeB); - - // const size_t a_dt_size = memory::data_type_size(dnnl_dtypeA); - // const size_t b_dt_size = memory::data_type_size(dnnl_dtypeB); - + assert(false && "Not tested"); std::vector> A_B_offsets(num_batches); for (memory::dim i = 0; i < num_batches; i++) { const memory::dim A_offset_i = @@ -321,15 +214,12 @@ EXPORT void call_brgemm(const void *brg_k, void *A_ptr, void *B_ptr, size_t scratchpad_size = brg->get_scratchpad_size(); std::vector scratchpad_sm(scratchpad_size); - // std::vector> A_B_offsets(3); - // An execute call. `A_B` is a vector of pointers to A and packed B - // tensors. `acc_ptr` is a pointer to an accumulator buffer. brg->execute(A_ptr, B_ptr, A_B_offsets, C_ptr, scratchpad_sm.data()); } // at the end of execution EXPORT void release_hw_context() { - // Once all computations are done, need to release HW context. 
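Stripped of the debug dumps, the live path through call_all is the standard ukernel sequence: set the HW context, batch up byte offsets into A and (packed) B, execute with a scratchpad, and release the context. A simplified sketch, assuming brg points at a configured dnnl::ukernel::brgemm and B is already packed:

    #include <cstdint>
    #include <utility>
    #include <vector>
    #include "oneapi/dnnl/dnnl_ukernel.hpp"

    using dnnl::memory;
    using dnnl::ukernel::brgemm;

    void runBrgemm(const brgemm *brg, const void *A, const void *B, void *C,
                   int64_t aStepInBytes, int64_t bStepInBytes,
                   int64_t numBatches) {
      brg->set_hw_context();
      // Byte offsets of each batch element within A and packed B.
      std::vector<std::pair<memory::dim, memory::dim>> offsets(numBatches);
      for (memory::dim i = 0; i < numBatches; i++)
        offsets[i] = {i * aStepInBytes, i * bStepInBytes};
      std::vector<uint8_t> scratchpad(brg->get_scratchpad_size());
      brg->execute(A, B, offsets, C, scratchpad.data());
      brgemm::release_hw_context();
    }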
+ assert(false && "Not tested"); dnnl::ukernel::brgemm::release_hw_context(); } diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index e76fc23b6b9e1..e5a8fb37faa7b 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -28,6 +28,17 @@ #include #include +#ifdef ONEDNN_AVAILABLE +#include "oneapi/dnnl/dnnl_config.h" +#endif +bool is_onednn_available() { +#ifdef DNNL_EXPERIMENTAL_UKERNEL + return true; +#else + return false; +#endif +} + namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { @@ -91,9 +102,11 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_loop_invariant_code_motion", [](mlir::PassManager &pm) { pm.addPass(mlir::createLoopInvariantCodeMotionPass()); }); - m.def("add_convert_dot_to_onednn", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertDotToOneDNN()); - }); + m.def("add_convert_dot_to_onednn", + [](mlir::PassManager &pm, bool isReplacementToOneDnnPossible) { + pm.addPass(mlir::triton::cpu::createConvertDotToOneDNN( + isReplacementToOneDnnPossible)); + }); m.def("add_convert_dot_to_amx", [](mlir::PassManager &pm, bool convertInt8, bool convertFp16, bool convertBf16) { pm.addPass(mlir::triton::cpu::createConvertDotToAMX( @@ -145,9 +158,11 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_debug_ops_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createDebugOpsToLLVMPass()); }); - m.def("add_onednn_ops_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createOneDNNOpsToLLVMPass()); - }); + m.def("add_onednn_ops_to_llvmir", + [](mlir::PassManager &pm, bool isReplacementToOneDnnPossible) { + pm.addPass(mlir::triton::cpu::createOneDNNOpsToLLVMPass( + isReplacementToOneDnnPossible)); + }); m.def("add_expand_strided_metadata", [](mlir::PassManager &pm) { pm.addPass(mlir::memref::createExpandStridedMetadataPass()); }); @@ -200,6 +215,8 @@ void init_triton_cpu(py::module &&m) { #endif // __linux__ && ARCH_REQ_XCOMP_PERM }); + m.def("onednn_available", is_onednn_available); + m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; registry.insert
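End to end, the availability check added in triton_cpu.cc feeds a single boolean through the Python bindings into both new pass factories, and each pass degrades to a no-op when it is false. A sketch of the equivalent pipeline wiring in C++ (the header path for the factories is assumed; it is not shown in this diff):

    #include "mlir/Pass/PassManager.h"
    // Assumed location of the pass factory declarations.
    #include "cpu/include/TritonCPUTransforms/Passes.h"

    // When oneDNN ukernels are unavailable both passes simply return early,
    // so the pipeline can be built unconditionally and gated by this flag.
    void buildOneDnnPipeline(mlir::PassManager &pm, bool onednnAvailable) {
      pm.addPass(mlir::triton::cpu::createConvertDotToOneDNN(onednnAvailable));
      pm.addPass(
          mlir::triton::cpu::createOneDNNOpsToLLVMPass(onednnAvailable));
    }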