Skip to content

Commit

Permalink
optional onednn/dnnl
Browse files Browse the repository at this point in the history
cmake fix

add options
  • Loading branch information
Devjiu committed Jan 13, 2025
1 parent b5ac6b3 commit 330d674
Show file tree
Hide file tree
Showing 22 changed files with 259 additions and 1,374 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ project(triton CXX C)
include(CTest)

list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(DNNLROOT "/home/jovyan/intel/oneapi/dnnl/2024.1")

# Options
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
Expand Down
11 changes: 4 additions & 7 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ struct SharedMemoryObject {
}

Value getCSwizzleOffset(int order) const {
assert(order >= 0 && order < strides.size());
assert(order >= 0 && order < static_cast<int>(strides.size()));
return offsets[order];
}

Expand Down Expand Up @@ -512,7 +512,6 @@ inline SmallVector<Value>
emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
const BlockedEncodingAttr &blockedLayout,
RankedTensorType type) {
MLIRContext *ctx = rewriter.getContext();
auto shape = type.getShape();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout));
Expand Down Expand Up @@ -557,7 +556,6 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
inline SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
RankedTensorType type) {
auto ctx = type.getContext();
auto shape = type.getShape();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
Expand Down Expand Up @@ -1205,9 +1203,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
// Order
auto inOrder = triton::gpu::getOrder(srcEncoding);
auto outOrder = triton::gpu::getOrder(resSharedLayout);
assert(maxPhase == 1 ||
outVec * maxPhase <= srcShape[outOrder[0]] &&
"Swizzling would generate out of bounds memory accesses");
assert((maxPhase == 1 || outVec * maxPhase <= srcShape[outOrder[0]]) &&
"Swizzling would generate out of bounds memory accesses");
// Tensor indices held by the current thread, as LLVM values
auto srcIndices = emitIndices(loc, rewriter, target, srcEncoding, srcTy,
/*withCTAOffset=*/false);
Expand Down Expand Up @@ -1424,7 +1421,7 @@ inline Value packLLVector(Location loc, ValueRange vals,
assert(vals.size() > 0);
auto vecType = vec_ty(vals[0].getType(), vals.size());
Value vec = undef(vecType);
for (int i = 0; i < vals.size(); i++) {
for (int i = 0; i < static_cast<int>(vals.size()); i++) {
vec = insert_element(vec, vals[i], i32_val(i));
}
return vec;
Expand Down
47 changes: 8 additions & 39 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,13 @@ def TTC_BrgemmCreate : TTC_Op<"brgemm_create", [NoMemoryEffect]> {
AnyTypeOf<[AnyInteger, Index]>:$lda,
AnyTypeOf<[AnyInteger, Index]>:$ldb,
AnyTypeOf<[AnyInteger, Index]>:$ldc,
I64:$dtypeA,
I64:$dtypeB,
I64:$dtypeC
// TODO: Maybe Use properties
TypeAttr:$dtypeA,
TypeAttr:$dtypeB,
TypeAttr:$dtypeC
);

let results = (outs Index:$result);

// let assemblyFormat = [{
// $prefix attr-dict (`:` $val^ `:` type($val))?
// }];

// let hasVerifier = 1;
}

def TTC_BrgemmCall : TTC_Op<"brgemm_call",
Expand All @@ -231,20 +226,13 @@ def TTC_BrgemmCall : TTC_Op<"brgemm_call",
Index:$kernel_hash,
Arg<AnyMemRef, "", [MemRead]>:$A_ptr,
Arg<AnyMemRef, "", [MemRead]>:$B_ptr,
// AnyTypeOf<[TT_Float, TT_Int, TT_Ptr, TTC_Vector]>:$offsets,
Arg<AnyMemRef, "", [MemWrite]>:$C_ptr,
Arg<AnyMemRef, "", [MemWrite]>:$scratchpad,

Index:$stepA,
Index:$stepB,
Index:$numBatches
);

// let assemblyFormat = [{
// $prefix attr-dict (`:` $val^ `:` type($val))?
// }];

// let hasVerifier = 1;
}

def TTC_BrgemmNeedsPacking : TTC_Op<"brgemm_needs_packing",
Expand All @@ -268,9 +256,8 @@ def TTC_CallBrgemmWithTransform : TTC_Op<"pack_and_brgemm",
Index:$brgemm_kernel_hash,
Arg<AnyMemRef, "", [MemRead]>:$A_ptr,
Arg<AnyMemRef, "", [MemRead]>:$B_ptr,
// AnyTypeOf<[TT_Float, TT_Int, TT_Ptr, TTC_Vector]>:$offsets,
Arg<AnyMemRef, "", [MemWrite]>:$C_ptr,
Arg<AnyMemRef, "", [MemWrite]>:$scratchpad,
// Arg<AnyMemRef, "", [MemWrite]>:$scratchpad,

AnyTypeOf<[AnyInteger, Index]>:$stepA,
AnyTypeOf<[AnyInteger, Index]>:$stepB,
Expand All @@ -288,20 +275,14 @@ def TTC_TransformCreate : TTC_Op<"transform_create", [NoMemoryEffect]> {
let arguments = (ins
AnyTypeOf<[AnyInteger, Index]>:$K,
AnyTypeOf<[AnyInteger, Index]>:$N,
// I64:$in_pack_type,
AnyTypeOf<[AnyInteger, Index]>:$in_ld,
AnyTypeOf<[AnyInteger, Index]>:$out_ld,
I64:$in_dt,
I64:$out_dt
// TODO: Maybe Use properties
TypeAttr:$in_dt,
TypeAttr:$out_dt
);

let results = (outs Index:$result);

// let assemblyFormat = [{
// $prefix attr-dict (`:` $val^ `:` type($val))?
// }];

// let hasVerifier = 1;
}

def TTC_TransformCall : TTC_Op<"transform_call",
Expand All @@ -326,12 +307,6 @@ def TTC_ConfigureHW : TTC_Op<"configure_hw", [MemoryEffects<[MemWrite<GlobalMemo

// M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC
let arguments = (ins Index:$brgemm_kernel_hash);

// let assemblyFormat = [{
// $prefix attr-dict (`:` $val^ `:` type($val))?
// }];

// let hasVerifier = 1;
}

def TTC_ReleaseHW : TTC_Op<"release_hw", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
Expand All @@ -341,12 +316,6 @@ def TTC_ReleaseHW : TTC_Op<"release_hw", [MemoryEffects<[MemWrite<GlobalMemory>]

// M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC
let arguments = (ins);

// let assemblyFormat = [{
// $prefix attr-dict (`:` $val^ `:` type($val))?
// }];

// let hasVerifier = 1;
}

#endif
2 changes: 1 addition & 1 deletion python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so]

libraries += ["gcc"]
cc_cmd += ["-g"]
# Use dynamic lookup to load Python library on Mac
if system == "Darwin":
cc_cmd += ["-undefined", "dynamic_lookup"]
Expand Down
106 changes: 0 additions & 106 deletions test/TritonCPU/dot-to-onednn.mlir

This file was deleted.

24 changes: 19 additions & 5 deletions third_party/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
find_package(dnnl CONFIG)
if (dnnl_FOUND)
message(STATUS "Found OneDNN/DNNL")
add_compile_definitions(ONEDNN_AVAILABLE)
get_target_property(dnnl_include DNNL::dnnl INTERFACE_INCLUDE_DIRECTORIES)
# currently used only in triton_cpu.cc and in ConvertDotToOneDNN
include_directories(${dnnl_include})
else ()
message(STATUS "Could NOT find OneDNN/DNNL")
endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(/usr/local/include)
link_directories(/usr/local/lib)
add_subdirectory(include)
add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms)
target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms dnnl PRIVATE Python3::Module pybind11::headers)
target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms PRIVATE Python3::Module pybind11::headers)
endif()

add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_onednn.cpp)
target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport dnnl)
if (dnnl_FOUND)
add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_onednn.cpp)
target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport DNNL::dnnl)
else ()
add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp)
target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport)
endif()

# Build and link sleef
set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE)
Expand Down
Loading

0 comments on commit 330d674

Please sign in to comment.