diff --git a/cmake/onemkl.cmake b/cmake/onemkl.cmake index a73d1189d..046082c2c 100644 --- a/cmake/onemkl.cmake +++ b/cmake/onemkl.cmake @@ -30,6 +30,8 @@ include_guard() include(ExternalProject) +option(CUTLASS_SYCL_DISCONNECT_ONEMKL_UPDATE "Stop onemkl from updating when the git tag is not changed" YES) + set(ONEMKL_INSTALL_DIR ${CMAKE_BINARY_DIR}/deps/oneMKL) set(ONEMKL_INCLUDE_DIR ${ONEMKL_INSTALL_DIR}/include) set(ONEMKL_LIB_DIR ${ONEMKL_INSTALL_DIR}/lib) @@ -40,7 +42,7 @@ ExternalProject_Add( PREFIX ${ONEMKL_INSTALL_DIR} GIT_REPOSITORY "https://github.com/oneapi-src/oneMKL.git" - GIT_TAG "v0.5" + GIT_TAG "v0.6" CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} @@ -57,6 +59,7 @@ ExternalProject_Add( -DTARGET_DOMAINS=rng INSTALL_DIR ${ONEMKL_INSTALL_DIR} BUILD_BYPRODUCTS ${ONEMKL_LIB} + UPDATE_DISCONNECTED ${CUTLASS_SYCL_DISCONNECT_ONEMKL_UPDATE} ) add_library(oneMKL SHARED IMPORTED) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 8c64fdb04..c96fdff2a 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -227,7 +227,10 @@ struct ExampleRunner { size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - gemm_op.can_implement(arguments); + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } gemm_op.initialize(arguments, workspace.get()); @@ -338,7 +341,7 @@ int main(int argc, const char** argv) XE_2D_U32x8x16_ST_N, void, void>; -// Mainloop + // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< GEMMDispatchPolicy, TileShape, diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 1a014ba16..2b41f3d31 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -153,9 +153,18 @@ class GemmUniversal< static bool can_implement(Arguments const& args) { - bool mode_implementable = args.mode == GemmUniversalMode::kGemm or + auto m = get<0>(args.problem_shape); + auto n = get<1>(args.problem_shape); + auto k = get<2>(args.problem_shape); + // TODO(codeplay): base *_valid on the atom shapes + bool m_valid = m > 0; + bool n_valid = n > 0 && n % 4 == 0; + bool k_valid = k > 0 && k % get<2>(TileShape{}) == 0; + bool shape_implementable = (m_valid && n_valid && k_valid); + + bool mode_implementable = args.mode == GemmUniversalMode::kGemm || (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); - return mode_implementable && TileScheduler::can_implement(args.scheduler); + return shape_implementable && mode_implementable && TileScheduler::can_implement(args.scheduler); } static int @@ -219,7 +228,7 @@ class GemmUniversal< int sub_group_id = thread_idx / SubgroupSize; constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) constexpr auto subgroup_shape = SubgroupTileShape{}; - + Tensor mA_mkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(M,K,L), StrideA{}); //(m,k,l) Tensor mB_nkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(N,K,L), StrideB{}); //(n,k,l) Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)