Skip to content

Commit

Permalink
use can implement
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Dec 2, 2024
1 parent 439808c commit 83e3c85
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
17 changes: 5 additions & 12 deletions examples/sycl/pvc/pvc_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,6 @@ struct Options {
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations, 100);

// validate the arguments
bool m_valid = m > 0 && m % 16 == 0;
bool n_valid = n > 0 && n % 32 == 0;
bool k_valid = k > 0 && k % 32 == 0;
bool l_valid = l > 0;
if (!(m_valid && n_valid && k_valid && l_valid)) {
std::cout << "invalid arguments. Must be a multiple of (16, 32, 32)\n";
std::exit(1);
}
}

/// Prints the usage statement.
Expand Down Expand Up @@ -237,7 +227,10 @@ struct ExampleRunner {
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

gemm_op.can_implement(arguments);
if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){
std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::exit(1);
}

gemm_op.initialize(arguments, workspace.get());

Expand Down Expand Up @@ -348,7 +341,7 @@ int main(int argc, const char** argv)
XE_2D_U32x8x16_ST_N,
void, void>;

// Mainloop
// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
Expand Down
8 changes: 8 additions & 0 deletions include/cutlass/gemm/kernel/xe_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ class GemmUniversal<
static bool
can_implement(Arguments const& args) {
auto m = get<0>(args.problem_shape);
auto n = get<1>(args.problem_shape);
auto k = get<2>(args.problem_shape);
bool m_valid = m > 0;
bool n_valid = n > 0 && n % 4 == 0;
bool k_valid = k > 0 && k % get<2>(TileShape{}) == 0;
if (!(m_valid and n_valid and k_valid)) return false;
bool mode_implementable = args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return mode_implementable && TileScheduler::can_implement(args.scheduler);
Expand Down

0 comments on commit 83e3c85

Please sign in to comment.