1- From e0189210ee8e532bea15f0592a801f2264b62834 Mon Sep 17 00:00:00 2001
1+ From 887ae8599b205921bc4fd34da3c1de767f5568ae Mon Sep 17 00:00:00 2001
22From: Garra1980 <
[email protected] >
3- Date: Wed, 13 Aug 2025 17:08:31 +0200
3+ Date: Tue, 23 Sep 2025 21:22:18 +0200
44Subject: [PATCH] Add support for VectorAnyINTEL capability
55
66---
77 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +-
88 mlir/include/mlir/IR/CommonTypeConstraints.td | 86 +++++++++++
99 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +-
10- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 23 ++-
10+ mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 26 + ++-
1111 .../SPIRV/Transforms/SPIRVConversion.cpp | 135 +++++++++++++++---
1212 .../arith-to-spirv-unsupported.mlir | 4 +-
1313 .../ArithToSPIRV/arith-to-spirv.mlir | 34 +++++
@@ -21,13 +21,13 @@ Subject: [PATCH] Add support for VectorAnyINTEL capability
2121 mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 34 ++---
2222 mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +-
2323 mlir/test/Target/SPIRV/ocl-ops.mlir | 6 +
24- 17 files changed, 322 insertions(+), 67 deletions(-)
24+ 17 files changed, 324 insertions(+), 68 deletions(-)
2525
2626diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
27- index bdfd728d1d0b..31e8bc288d5b 100644
27+ index 0e42d08cdb1f..f821b0d2e59b 100644
2828--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
2929+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
30- @@ -4233 ,7 +4233 ,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
30+ @@ -4240 ,7 +4240 ,14 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
3131 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
3232 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
3333 def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
@@ -43,7 +43,7 @@ index bdfd728d1d0b..31e8bc288d5b 100644
4343 [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
4444 // Component type check is done in the type parser for the following SPIR-V
4545 // dialect-specific types so we use "Any" here.
46- @@ -4286 ,7 +4293 ,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
46+ @@ -4293 ,7 +4300 ,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
4747 "Matrix">;
4848
4949 class SPIRV_VectorOf<Type type> :
@@ -53,10 +53,10 @@ index bdfd728d1d0b..31e8bc288d5b 100644
5353 class SPIRV_ScalarOrVectorOf<Type type> :
5454 AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
5555diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
56- index b682f4c025a4..298553c83947 100644
56+ index 6b4e3dd60319..987b33c055e9 100644
5757--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
5858+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
59- @@ -648 ,6 +648 ,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
59+ @@ -654 ,6 +654 ,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
6060 ScalableVectorOfLength<allowedLengths>.summary,
6161 "::mlir::VectorType">;
6262
@@ -150,7 +150,7 @@ index b682f4c025a4..298553c83947 100644
150150 // Negative values for `n` index in reverse.
151151 class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
152152diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
153- index fcf152649197..bbc538a19840 100644
153+ index c8efdf009422..5236dc299f81 100644
154154--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
155155+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
156156@@ -186,9 +186,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -169,10 +169,10 @@ index fcf152649197..bbc538a19840 100644
169169 return Type();
170170 }
171171diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
172- index ddb342621f37..952c474fd34d 100644
172+ index 7e9a80e7d73a..1db6233cf73f 100644
173173--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
174174+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
175- @@ -98 ,9 +98 ,10 @@ bool CompositeType::classof(Type type) {
175+ @@ -186 ,9 +186 ,10 @@ bool CompositeType::classof(Type type) {
176176 }
177177
178178 bool CompositeType::isValid(VectorType type) {
@@ -186,31 +186,34 @@ index ddb342621f37..952c474fd34d 100644
186186 }
187187
188188 Type CompositeType::getElementType(unsigned index) const {
189- @@ -171,7 +172,21 @@ void CompositeType::getCapabilities(
190- .Case<VectorType>([&](VectorType type) {
191- auto vecSize = getNumElements();
192- if (vecSize == 8 || vecSize == 16) {
193- - static const Capability caps[] = {Capability::Vector16};
194- + static const Capability caps[] = {Capability::Vector16,
195- + Capability::VectorAnyINTEL};
196- + ArrayRef<Capability> ref(caps, std::size(caps));
197- + capabilities.push_back(ref);
198- + }
199- + // If the vector size is between [2 to 2^32 - 1]
200- + // and not of any size 2, 3, 4, 8, and 16
201- + // VectorAnyIntel Capability must be present
202- + // for the SPIR-V to be valid
203- + llvm::SmallVector<uint32_t, 5> allowedVecRange = {2, 3, 4, 8, 16};
204- + if (vecSize >= 2 &&
205- + (llvm::none_of(allowedVecRange, [&](uint32_t allowedVecSize) {
206- + return vecSize == allowedVecSize;
207- + }))) {
208- + static const Capability caps[] = {Capability::VectorAnyINTEL};
209- ArrayRef<Capability> ref(caps, std::size(caps));
210- capabilities.push_back(ref);
211- }
189+ @@ -221,8 +222,23 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
190+
191+ int64_t vecSize = type.getNumElements();
192+ if (vecSize == 8 || vecSize == 16) {
193+ - static constexpr auto cap = Capability::Vector16;
194+ - capabilities.push_back(cap);
195+ + static const Capability caps[] = {Capability::Vector16,
196+ + Capability::VectorAnyINTEL};
197+ + ArrayRef<Capability> ref(caps, std::size(caps));
198+ + capabilities.push_back(ref);
199+ + }
200+ + // If the vector size is between [2 to 2^32 - 1]
201+ + // and not of any size 2, 3, 4, 8, and 16
202+ + // VectorAnyIntel Capability must be present
203+ + // for the SPIR-V to be valid
204+ + llvm::SmallVector<uint32_t, 5> allowedVecRange = {2, 3, 4, 8, 16};
205+ + if (vecSize >= 2 &&
206+ + (llvm::none_of(allowedVecRange, [&](uint32_t allowedVecSize) {
207+ + return vecSize == allowedVecSize;
208+ + }))) {
209+ + static const Capability caps[] = {Capability::VectorAnyINTEL};
210+ + ArrayRef<Capability> ref(caps, std::size(caps));
211+ + capabilities.push_back(ref);
212+ }
213+ }
214+
212215diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
213- index 49f4ce8de7c7..eef55a427486 100644
216+ index 122f61e0a66a..c6f37e9345ed 100644
214217--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
215218+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
216219@@ -84,9 +84,13 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
@@ -426,10 +429,10 @@ index 9d7ab2be096e..3aa22e261f7c 100644
426429 }
427430
428431diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
429- index 6e2352e706ac..4c9d2e147bc6 100644
432+ index 3cb529459899..e881d512bf2e 100644
430433--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
431434+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
432- @@ -1479 ,6 +1479 ,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
435+ @@ -1532 ,6 +1532 ,40 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
433436 %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
434437 // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
435438 %3 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
@@ -720,10 +723,10 @@ index 6aaaa6012fef..60ef7afeeeed 100644
720723 // expected-error @+1 {{expected ':'}}
721724 %2 = spirv.CL.s_abs %arg0, %arg1 : i32
722725diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
723- index b80e17f979da..32103f7b9c57 100644
726+ index ec47035d088b..44ce439a3f53 100644
724727--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
725728+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
726- @@ -6 ,9 +6 ,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
729+ @@ -11 ,9 +11 ,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, BFloat1
727730 %0 = spirv.FMul %arg0, %arg1 : f32
728731 spirv.Return
729732 }
@@ -737,10 +740,10 @@ index b80e17f979da..32103f7b9c57 100644
737740 }
738741 spirv.func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" {
739742diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
740- index 9a2e4cf62e37..31a7f616d648 100644
743+ index 17accd93e824..ed9a9976e89b 100644
741744--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
742745+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
743- @@ -39 ,6 +39 ,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
746+ @@ -44 ,6 +44 ,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Vec
744747 spirv.Return
745748 }
746749
0 commit comments