diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 035ef1d1a..c96fdff2a 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -83,16 +83,6 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations, 100); - - // validate the arguments - bool m_valid = m > 0 && m % 16 == 0; - bool n_valid = n > 0 && n % 32 == 0; - bool k_valid = k > 0 && k % 32 == 0; - bool l_valid = l > 0; - if (!(m_valid && n_valid && k_valid && l_valid)) { - std::cout << "invalid arguments. Must be a multiple of (16, 32, 32)\n"; - std::exit(1); - } } /// Prints the usage statement. @@ -237,7 +227,10 @@ struct ExampleRunner { size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - gemm_op.can_implement(arguments); + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } gemm_op.initialize(arguments, workspace.get()); @@ -348,7 +341,7 @@ int main(int argc, const char** argv) XE_2D_U32x8x16_ST_N, void, void>; -// Mainloop + // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< GEMMDispatchPolicy, TileShape, diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index b27f2a3cf..b359821b8 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -153,6 +153,14 @@ class GemmUniversal< static bool can_implement(Arguments const& args) { + auto m = get<0>(args.problem_shape); + auto n = get<1>(args.problem_shape); + auto k = get<2>(args.problem_shape); + bool m_valid = m > 0; + bool n_valid = n > 0 && n % 4 == 0; + bool k_valid = k > 0 && k % get<2>(TileShape{}) == 0; + if (!(m_valid and n_valid and k_valid)) return false; + bool mode_implementable = args.mode == GemmUniversalMode::kGemm or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); return mode_implementable && TileScheduler::can_implement(args.scheduler);