Skip to content

Commit

Permalink
Formatting, pvc gemm args validation, onemkl cmake (#159)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Alejandro Acosta <[email protected]>
  • Loading branch information
FMarno and aacostadiaz authored Dec 5, 2024
1 parent 744b891 commit e6466a9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
5 changes: 4 additions & 1 deletion cmake/onemkl.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ include_guard()

include(ExternalProject)

option(CUTLASS_SYCL_DISCONNECT_ONEMKL_UPDATE "Stop onemkl from updating when the git tag is not changed" YES)

set(ONEMKL_INSTALL_DIR ${CMAKE_BINARY_DIR}/deps/oneMKL)
set(ONEMKL_INCLUDE_DIR ${ONEMKL_INSTALL_DIR}/include)
set(ONEMKL_LIB_DIR ${ONEMKL_INSTALL_DIR}/lib)
Expand All @@ -40,7 +42,7 @@ ExternalProject_Add(

PREFIX ${ONEMKL_INSTALL_DIR}
GIT_REPOSITORY "https://github.com/oneapi-src/oneMKL.git"
GIT_TAG "v0.5"
GIT_TAG "v0.6"

CMAKE_ARGS
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
Expand All @@ -57,6 +59,7 @@ ExternalProject_Add(
-DTARGET_DOMAINS=rng
INSTALL_DIR ${ONEMKL_INSTALL_DIR}
BUILD_BYPRODUCTS ${ONEMKL_LIB}
UPDATE_DISCONNECTED ${CUTLASS_SYCL_DISCONNECT_ONEMKL_UPDATE}
)

add_library(oneMKL SHARED IMPORTED)
7 changes: 5 additions & 2 deletions examples/sycl/pvc/pvc_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ struct ExampleRunner {
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

gemm_op.can_implement(arguments);
if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){
std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::exit(1);
}

gemm_op.initialize(arguments, workspace.get());

Expand Down Expand Up @@ -338,7 +341,7 @@ int main(int argc, const char** argv)
XE_2D_U32x8x16_ST_N,
void, void>;

// Mainloop
// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
Expand Down
15 changes: 12 additions & 3 deletions include/cutlass/gemm/kernel/xe_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,18 @@ class GemmUniversal<
static bool
can_implement(Arguments const& args) {
bool mode_implementable = args.mode == GemmUniversalMode::kGemm or
auto m = get<0>(args.problem_shape);
auto n = get<1>(args.problem_shape);
auto k = get<2>(args.problem_shape);
// TODO(codeplay): base *_valid on the atom shapes
bool m_valid = m > 0;
bool n_valid = n > 0 && n % 4 == 0;
bool k_valid = k > 0 && k % get<2>(TileShape{}) == 0;
bool shape_implementable = (m_valid && n_valid && k_valid);
bool mode_implementable = args.mode == GemmUniversalMode::kGemm ||
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return mode_implementable && TileScheduler::can_implement(args.scheduler);
return shape_implementable && mode_implementable && TileScheduler::can_implement(args.scheduler);
}
static int
Expand Down Expand Up @@ -219,7 +228,7 @@ class GemmUniversal<
int sub_group_id = thread_idx / SubgroupSize;
constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
constexpr auto subgroup_shape = SubgroupTileShape{};
Tensor mA_mkl = make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(M,K,L), StrideA{}); //(m,k,l)
Tensor mB_nkl = make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(N,K,L), StrideB{}); //(n,k,l)
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Expand Down

0 comments on commit e6466a9

Please sign in to comment.