From 66ae951eb2fd00c1c3c7bf5928c349442211c9f7 Mon Sep 17 00:00:00 2001 From: Kris Rowe Date: Thu, 14 Dec 2023 21:49:51 -0600 Subject: [PATCH] Set SYCL subgroup size via kernel property or `@simd_length` attribute. (#726) --- src/occa/internal/lang/attribute.cpp | 5 + src/occa/internal/lang/attribute.hpp | 2 + .../internal/lang/builtins/attributes.hpp | 1 + .../lang/builtins/attributes/simdLength.cpp | 50 ++++++++++ .../lang/builtins/attributes/simdLength.hpp | 24 +++++ src/occa/internal/lang/modes/dpcpp.cpp | 94 ++++++++++++++++--- src/occa/internal/lang/modes/dpcpp.hpp | 3 +- src/occa/internal/lang/modes/okl.cpp | 1 + src/occa/internal/lang/modes/withLauncher.cpp | 4 + tests/src/internal/lang/modes/dpcpp.cpp | 56 +++++++++++ 10 files changed, 227 insertions(+), 13 deletions(-) create mode 100644 src/occa/internal/lang/builtins/attributes/simdLength.cpp create mode 100644 src/occa/internal/lang/builtins/attributes/simdLength.hpp diff --git a/src/occa/internal/lang/attribute.cpp b/src/occa/internal/lang/attribute.cpp index 85e6750ab..4ce0ee659 100644 --- a/src/occa/internal/lang/attribute.cpp +++ b/src/occa/internal/lang/attribute.cpp @@ -54,6 +54,11 @@ namespace occa { bool attributeArg_t::exists() const { return expr; } + + bool attributeArg_t::canEvaluate() const { + if (!expr) return false; + return expr->canEvaluate(); + } //================================== //---[ Attribute ]------------------ diff --git a/src/occa/internal/lang/attribute.hpp b/src/occa/internal/lang/attribute.hpp index 0bdb0bb1d..9a8066a9e 100644 --- a/src/occa/internal/lang/attribute.hpp +++ b/src/occa/internal/lang/attribute.hpp @@ -63,6 +63,8 @@ namespace occa { void clear(); bool exists() const; + + bool canEvaluate() const; }; //================================== diff --git a/src/occa/internal/lang/builtins/attributes.hpp b/src/occa/internal/lang/builtins/attributes.hpp index fb4307225..251459b9c 100644 --- a/src/occa/internal/lang/builtins/attributes.hpp +++ b/src/occa/internal/lang/builtins/attributes.hpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #endif diff --git a/src/occa/internal/lang/builtins/attributes/simdLength.cpp b/src/occa/internal/lang/builtins/attributes/simdLength.cpp new file mode 100644 index 000000000..f1303eb4a --- /dev/null +++ b/src/occa/internal/lang/builtins/attributes/simdLength.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include + +namespace occa { +namespace lang { +namespace attributes { + + const std::string& simdLength::name() const { return name_;} + + bool simdLength::forStatementType(const int sType) const { + return (sType & statementType::for_); + } + + bool simdLength::isValid(const attributeToken_t &attr) const { + if (attr.kwargs.size()) { + attr.printError(name_ + " does not take kwargs"); + return false; + } + + if (1 != attr.args.size()) { + attr.printError(name_ + " takes one argument"); + return false; + } + + const auto& attr_arg = attr.args[0]; + if (!attr_arg.canEvaluate()) { + attr.printError(name_ + " cannot evaluate argument"); + return false; + } + + primitive value = attr_arg.expr->evaluate(); + if (!value.isInteger()) { + attr.printError(name_ + " take an integer argument"); + return false; + } + + if(0 > value.to()) { + attr.printError(name_ + " arguments must be postive!"); + return false; + } + + return true; + } + +} +} +} diff --git a/src/occa/internal/lang/builtins/attributes/simdLength.hpp b/src/occa/internal/lang/builtins/attributes/simdLength.hpp new file mode 100644 index 000000000..da5ebf906 --- /dev/null +++ b/src/occa/internal/lang/builtins/attributes/simdLength.hpp @@ -0,0 +1,24 @@ +#ifndef OCCA_INTERNAL_LANG_BUILTINS_ATTRIBUTES_SIMD_LENGTH_HEADER +#define OCCA_INTERNAL_LANG_BUILTINS_ATTRIBUTES_SIMD_LENGTH_HEADER + +#include + +namespace occa { +namespace lang { +namespace attributes { + +class simdLength : public attribute_t { +public: + simdLength() = default; + const std::string& name() const override; + bool forStatementType(const int sType) const override; + bool isValid(const attributeToken_t &attr) const override; +private: + static const inline std::string name_{"simd_length"}; +}; + +} +} +} + +#endif diff --git a/src/occa/internal/lang/modes/dpcpp.cpp b/src/occa/internal/lang/modes/dpcpp.cpp index 54a675fb4..eea56cd79 100644 --- a/src/occa/internal/lang/modes/dpcpp.cpp +++ b/src/occa/internal/lang/modes/dpcpp.cpp @@ -5,7 +5,78 @@ #include #include #include -// #include +#include + +namespace { + +class dpcppLambda_t : public occa::lang::lambda_t { +public: + int simd_length{-1}; + + dpcppLambda_t(occa::lang::capture_t capture_, int simd_length_) + : lambda_t(capture_), simd_length(simd_length_) {} + + dpcppLambda_t(const dpcppLambda_t& other) + : lambda_t(other), simd_length(other.simd_length) {} + + ~dpcppLambda_t() = default; + + bool equals(const type_t &other) const override { + const dpcppLambda_t &other_ = other.to(); + if (simd_length != other_.simd_length) return false; + return lambda_t::equals(other); + } + + void printDeclaration(occa::lang::printer &pout) const override { + pout << "["; + + switch (this->capture) { + case occa::lang::capture_t::byValue: + pout << "="; + break; + case occa::lang::capture_t::byReference: + pout << "&"; + break; + default: + pout << "???"; + break; + } + + pout << "]("; + + if (!args.empty()) { + const std::string argIndent = pout.indentFromNewline(); + args[0]->printDeclaration(pout); + for (std::size_t i = 1; i < args.size(); ++i) { + pout << ",\n" << argIndent; + args[i]->printDeclaration(pout); + } + } + pout << ") "; + + if (0 < simd_length) { + pout << "[[intel::reqd_sub_group_size("; + pout.print(simd_length); + pout << ")]]"; + } + + pout << " {"; + + pout.printNewline(); + pout.pushInlined(false); + pout.addIndentation(); + + body->print(pout); + + pout.removeIndentation(); + pout.popInlined(); + pout.printNewline(); + pout.printIndentation(); + pout << "}\n"; + } +}; + +} namespace occa { @@ -20,6 +91,7 @@ namespace occa shared("auto", qualifierType::custom) { okl::addOklAttributes(*this); + simd_length_default = settings_.get("simd_length",-1); } void dpcppParser::onClear() @@ -79,15 +151,7 @@ namespace occa std::string dpcppParser::launchBoundsAttribute(const int innerDims[3]) { - std::stringstream ss; - ss << "[[sycl::reqd_work_group_size(" - << innerDims[2] - << "," - << innerDims[1] - << "," - << innerDims[0] - << ")]]\n"; - return ss.str(); + return ""; } // @note: As of SYCL 2020 this will need to change from `CL/sycl.hpp` to `sycl.hpp` @@ -188,9 +252,15 @@ namespace occa lambda_t &cg_function = *(new lambda_t(capture_t::byReference)); cg_function.addArgument(sycl_handler); - lambda_t &sycl_kernel = *(new lambda_t(capture_t::byValue)); - sycl_kernel.addArgument(sycl_nditem); + int simd_length = simd_length_default; + if (k.hasAttribute("simd_length")) { + const attributeToken_t& attr = k.attributes["simd_length"]; + simd_length = attr.args[0].expr->evaluate(); + } + + dpcppLambda_t& sycl_kernel = *(new dpcppLambda_t(capture_t::byValue, simd_length)); + sycl_kernel.addArgument(sycl_nditem); sycl_kernel.body->swap(k); lambdaNode sycl_kernel_node(sycl_kernel.source, sycl_kernel); diff --git a/src/occa/internal/lang/modes/dpcpp.hpp b/src/occa/internal/lang/modes/dpcpp.hpp index 2affd6954..6186e1a68 100644 --- a/src/occa/internal/lang/modes/dpcpp.hpp +++ b/src/occa/internal/lang/modes/dpcpp.hpp @@ -42,13 +42,14 @@ namespace occa void setSharedQualifiers(); void setKernelQualifiers(function_t &function); void migrateLocalDecls(functionDeclStatement &kernelSmnt); - void setLaunchBounds(); void setupAtomics(); static bool transformAtomicBlockStatement(blockStatement &blockSmnt); static bool transformAtomicBasicExpressionStatement(expressionStatement &exprSmnt); private: + int simd_length_default; + inline int dpcppDimensionOrder(const int index) { return 2 - index; } }; } // namespace okl diff --git a/src/occa/internal/lang/modes/okl.cpp b/src/occa/internal/lang/modes/okl.cpp index d65bf37d0..241b0e6e8 100644 --- a/src/occa/internal/lang/modes/okl.cpp +++ b/src/occa/internal/lang/modes/okl.cpp @@ -392,6 +392,7 @@ namespace occa { parser.addAttribute(); parser.addAttribute(); parser.addAttribute(); + parser.addAttribute(); } void setOklLoopIndices(functionDeclStatement &kernelSmnt) { diff --git a/src/occa/internal/lang/modes/withLauncher.cpp b/src/occa/internal/lang/modes/withLauncher.cpp index 4214c431b..0d6e78d6f 100644 --- a/src/occa/internal/lang/modes/withLauncher.cpp +++ b/src/occa/internal/lang/modes/withLauncher.cpp @@ -437,6 +437,10 @@ namespace occa { forStatement &newForSmnt = (forStatement&) forSmnt.clone(); newKernelSmnt.set(newForSmnt); + if (newForSmnt.hasAttribute("simd_length")) { + newKernelSmnt.addAttribute(newForSmnt.attributes["simd_length"]); + } + bool addLaunchBoundsAttribute{true}; int kernelInnerDims[3] = {1,1,1}; if (newForSmnt.hasAttribute("max_inner_dims")) { diff --git a/tests/src/internal/lang/modes/dpcpp.cpp b/tests/src/internal/lang/modes/dpcpp.cpp index 51c8dea62..bd8a86e78 100644 --- a/tests/src/internal/lang/modes/dpcpp.cpp +++ b/tests/src/internal/lang/modes/dpcpp.cpp @@ -28,6 +28,7 @@ void testSharedAnnotation(); void testBarriers(); void testAtomic(); void testSource(); +void testSimdLength(); int main(const int argc, const char **argv) { parser.settings["okl/validate"] = true; @@ -38,6 +39,7 @@ int main(const int argc, const char **argv) { testSharedAnnotation(); testBarriers(); testSource(); + testSimdLength(); return 0; } @@ -163,3 +165,57 @@ void testSource() { "}\n" ); } + +void testSimdLengthAttribute() { + const std::string kernel_source = R"( + @kernel void f() { + @outer @simd_length(16) + for (int o = 0; o < 1; ++o) { + @inner for (int i = 0; i < 32; ++i) { + int j = i + o; + } + } + } + )"; + + parser.parseSource(kernel_source); + ASSERT_TRUE(parser.success); + + printer pout; + parser.root.print(pout); + const std::string translated_source = pout.str(); + + auto pos = translated_source.find("[[intel::reqd_sub_group_size(16)]]"); + ASSERT_TRUE(std::string::npos != pos); +} + +void testSimdLengthProperty() { + const std::string kernel_source = R"( + @kernel void f() { + @outer for (int o = 0; o < 1; ++o) { + @inner for (int i = 0; i < 32; ++i) { + int j = i + o; + } + } + } + )"; + + occa::json properties; + properties["simd_length"] = 16; + occa::lang::okl::dpcppParser dpcpp_parser(properties); + + dpcpp_parser.parseSource(kernel_source); + ASSERT_TRUE(parser.success); + + printer pout; + dpcpp_parser.root.print(pout); + const std::string translated_source = pout.str(); + + auto pos = translated_source.find("[[intel::reqd_sub_group_size(16)]]"); + ASSERT_TRUE(std::string::npos != pos); +} + +void testSimdLength() { + testSimdLengthAttribute(); + testSimdLengthProperty(); +}