From baae5e3f28af7e1d1d0e22d441953ce9c8dbe7a9 Mon Sep 17 00:00:00 2001 From: EstebanNunez Date: Thu, 11 Jan 2024 20:33:24 +0100 Subject: [PATCH 1/6] Add Bitwise operator on vectors --- include/NZSL/Math/Vector.hpp | 13 ++++ include/NZSL/Math/Vector.inl | 137 +++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/include/NZSL/Math/Vector.hpp b/include/NZSL/Math/Vector.hpp index bc3c087..30ebfbb 100644 --- a/include/NZSL/Math/Vector.hpp +++ b/include/NZSL/Math/Vector.hpp @@ -70,6 +70,19 @@ namespace nzsl constexpr bool operator==(const Vector& vec) const; constexpr bool operator!=(const Vector& vec) const; + constexpr Vector operator~() const; + constexpr Vector operator&(const Vector& vec) const; + constexpr Vector operator|(const Vector& vec) const; + constexpr Vector operator^(const Vector& vec) const; + constexpr Vector operator<<(const Vector& vec) const; + constexpr Vector operator>>(const Vector& vec) const; + + constexpr Vector operator&=(const Vector& vec); + constexpr Vector operator|=(const Vector& vec); + constexpr Vector operator^=(const Vector& vec); + constexpr Vector operator<<=(const Vector& vec); + constexpr Vector operator>>=(const Vector& vec); + static constexpr bool ApproxEqual(const Vector& lhs, const Vector& rhs, T maxDifference = std::numeric_limits::epsilon()); static constexpr Vector CrossProduct(const Vector& lhs, const Vector& rhs); static T Distance(const Vector& lhs, const Vector& rhs); diff --git a/include/NZSL/Math/Vector.inl b/include/NZSL/Math/Vector.inl index 99c193f..9815712 100644 --- a/include/NZSL/Math/Vector.inl +++ b/include/NZSL/Math/Vector.inl @@ -302,6 +302,143 @@ namespace nzsl return !operator==(vec); } + template + constexpr Vector Vector::operator~() const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = ~values[i]; + + return result; + } + + template + constexpr Vector Vector::operator&(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] & vec.values[i]; + + return result; + } + + template + constexpr Vector Vector::operator|(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] | vec.values[i]; + + return result; + } + + template + inline constexpr Vector Vector::operator^(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] ^ vec.values[i]; + + return result; + } + + template + inline constexpr Vector Vector::operator<<(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] << vec.values[i]; + + return result; + } + + template + inline constexpr Vector Vector::operator>>(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + { +#if NAZARA_CHECK_CPP_VER(NAZARA_CPP20) + // C++20 ensures that right shift performs an arthmetic shift on signed integers + result.values[i] = values[i] >> vec.values[i]; +#else + // Implement arithmetic shift on C++ <=17 + if constexpr (std::is_signed_v) + { + if (values[i] < 0 && vec.values[i] > 0) + result.values[i] = (values[i] >> vec.values[i]) | ~(~static_cast>(0u) >> vec.values[i]); + else + result.values[i] = values[i] >> vec.values[i]; + } + else + result.values[i] = values[i] >> vec.values[i]; +#endif + } + + return result; + } + + template + inline constexpr Vector Vector::operator&=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] &= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator|=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] |= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator^=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] ^= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator<<=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] <<= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator>>=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + { +#if NAZARA_CHECK_CPP_VER(NAZARA_CPP20) + // C++20 ensures that right shift performs an arthmetic shift on signed integers + values[i] = values[i] >> vec.values[i]; +#else + // Implement arithmetic shift on C++ <=17 + if constexpr (std::is_signed_v) + { + if (values[i] < 0 && vec.values[i] > 0) + values[i] = (values[i] >> vec.values[i]) | ~(~static_cast>(0u) >> vec.values[i]); + else + values[i] = values[i] >> vec.values[i]; + } + else + values[i] = values[i] >> vec.values[i]; +#endif + } + + return this; + } + template constexpr bool Vector::ApproxEqual(const Vector& lhs, const Vector& rhs, T maxDifference) { From d0f05be06da2f6f5ed2647d9b92712393731246e Mon Sep 17 00:00:00 2001 From: EstebanNunez Date: Wed, 17 Jan 2024 20:16:07 +0100 Subject: [PATCH 2/6] Vector: Make use of ArithmeticRightShift function --- include/NZSL/Math/Vector.inl | 37 +++--------------------------------- 1 file changed, 3 insertions(+), 34 deletions(-) diff --git a/include/NZSL/Math/Vector.inl b/include/NZSL/Math/Vector.inl index 9815712..57392a3 100644 --- a/include/NZSL/Math/Vector.inl +++ b/include/NZSL/Math/Vector.inl @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include namespace nzsl { @@ -357,23 +358,7 @@ namespace nzsl { Vector result; for (std::size_t i = 0; i < N; ++i) - { -#if NAZARA_CHECK_CPP_VER(NAZARA_CPP20) - // C++20 ensures that right shift performs an arthmetic shift on signed integers - result.values[i] = values[i] >> vec.values[i]; -#else - // Implement arithmetic shift on C++ <=17 - if constexpr (std::is_signed_v) - { - if (values[i] < 0 && vec.values[i] > 0) - result.values[i] = (values[i] >> vec.values[i]) | ~(~static_cast>(0u) >> vec.values[i]); - else - result.values[i] = values[i] >> vec.values[i]; - } - else - result.values[i] = values[i] >> vec.values[i]; -#endif - } + result.values[i] = Nz::ArithmeticRightShift(values[i], vec.values[i]); return result; } @@ -418,23 +403,7 @@ namespace nzsl inline constexpr Vector Vector::operator>>=(const Vector& vec) { for (std::size_t i = 0; i < N; ++i) - { -#if NAZARA_CHECK_CPP_VER(NAZARA_CPP20) - // C++20 ensures that right shift performs an arthmetic shift on signed integers - values[i] = values[i] >> vec.values[i]; -#else - // Implement arithmetic shift on C++ <=17 - if constexpr (std::is_signed_v) - { - if (values[i] < 0 && vec.values[i] > 0) - values[i] = (values[i] >> vec.values[i]) | ~(~static_cast>(0u) >> vec.values[i]); - else - values[i] = values[i] >> vec.values[i]; - } - else - values[i] = values[i] >> vec.values[i]; -#endif - } + values[i] = Nz::ArithmeticRightShift(values[i], vec.values[i]); return this; } From 610abe8a50da8ff3949fe1790eb7b26d4efcbf71 Mon Sep 17 00:00:00 2001 From: EstebanNunez Date: Wed, 17 Jan 2024 20:24:36 +0100 Subject: [PATCH 3/6] add missing TypeMustMatch check --- src/NZSL/Ast/SanitizeVisitor.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index dfad18d..d3aaeda 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -5712,6 +5712,7 @@ namespace nzsl::Ast if (leftType != PrimitiveType::Int32 && leftType != PrimitiveType::UInt32) throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; } From 9b982838af5949f2f547bbd42320b850a2694e49 Mon Sep 17 00:00:00 2001 From: EstebanNunez Date: Wed, 17 Jan 2024 20:44:36 +0100 Subject: [PATCH 4/6] Extend sanitization rules to allow bitwise ops - Only for Vector of integers --- src/NZSL/Ast/SanitizeVisitor.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index d3aaeda..0a02f77 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -5823,11 +5823,25 @@ namespace nzsl::Ast case BinaryType::BitwiseAnd: case BinaryType::BitwiseOr: case BinaryType::BitwiseXor: - case BinaryType::LogicalAnd: - case BinaryType::LogicalOr: case BinaryType::ShiftLeft: case BinaryType::ShiftRight: - throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + { + if (leftType.type != PrimitiveType::Int32 && leftType.type != PrimitiveType::UInt32) + throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + + TypeMustMatch(leftExprType, rightExprType, sourceLocation); + return leftExprType; + } + + case BinaryType::LogicalAnd: + case BinaryType::LogicalOr: + { + if (leftType.type != PrimitiveType::Boolean) + throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + + TypeMustMatch(leftExprType, rightExprType, sourceLocation); + return PrimitiveType::Boolean; + } } } From f6d5d0b14b96f497a806a323c4e615f99da64750 Mon Sep 17 00:00:00 2001 From: EstebanNunez Date: Fri, 26 Jan 2024 11:47:01 +0100 Subject: [PATCH 5/6] Fix OptimizationTests of Bitwise operation --- tests/src/Tests/OptimizationTests.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/src/Tests/OptimizationTests.cpp b/tests/src/Tests/OptimizationTests.cpp index 37539ab..8df5eb5 100644 --- a/tests/src/Tests/OptimizationTests.cpp +++ b/tests/src/Tests/OptimizationTests.cpp @@ -79,8 +79,8 @@ fn main() let output7 = -42 << 10; let output8 = -42 >> 10; - let output9 = u32(1) << 10; - let output10 = u32(1024) >> 10; + let output9 = u32(1) << u32(10); + let output10 = u32(1024) >> u32(10); } )", R"( [entry(frag)] From ee6d2254e289bedcb41b1be16845333d79d451e9 Mon Sep 17 00:00:00 2001 From: EstebanNunez Date: Mon, 11 Mar 2024 13:47:53 +0100 Subject: [PATCH 6/6] Add test case for Vector bitwise operator --- tests/src/Tests/ArithmeticTests.cpp | 48 +++++++++++++++++++++++++++ tests/src/Tests/OptimizationTests.cpp | 1 - 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/src/Tests/ArithmeticTests.cpp b/tests/src/Tests/ArithmeticTests.cpp index a9515d7..05d51c6 100644 --- a/tests/src/Tests/ArithmeticTests.cpp +++ b/tests/src/Tests/ArithmeticTests.cpp @@ -296,6 +296,24 @@ fn main() let r = x ^ y; let r = x << y; let r = x >> y; + + let x = vec3[i32](0, 1, 2); + let y = vec3[i32](2, 1, 0); + + let r = x & y; + let r = x | y; + let r = x ^ y; + let r = x << y; + let r = x >> y; + + let x = vec3[u32](u32(0), u32(1), u32(2)); + let y = vec3[u32](u32(2), u32(1), u32(0)); + + let r = x & y; + let r = x | y; + let r = x ^ y; + let r = x << y; + let r = x >> y; } )"; @@ -319,6 +337,20 @@ void main() uint r_8 = x_2 ^ y_2; uint r_9 = x_2 << y_2; uint r_10 = x_2 >> y_2; + ivec3 x_3 = ivec3(0, 1, 2); + ivec3 y_3 = ivec3(2, 1, 0); + ivec3 r_11 = x_3 & y_3; + ivec3 r_12 = x_3 | y_3; + ivec3 r_13 = x_3 ^ y_3; + ivec3 r_14 = x_3 << y_3; + ivec3 r_15 = x_3 >> y_3; + uvec3 x_4 = uvec3(uint(0), uint(1), uint(2)); + uvec3 y_4 = uvec3(uint(2), uint(1), uint(0)); + uvec3 r_16 = x_4 & y_4; + uvec3 r_17 = x_4 | y_4; + uvec3 r_18 = x_4 ^ y_4; + uvec3 r_19 = x_4 << y_4; + uvec3 r_20 = x_4 >> y_4; } )"); @@ -340,6 +372,20 @@ fn main() let r: u32 = x ^ y; let r: u32 = x << y; let r: u32 = x >> y; + let x: vec3[i32] = vec3[i32](0, 1, 2); + let y: vec3[i32] = vec3[i32](2, 1, 0); + let r: vec3[i32] = x & y; + let r: vec3[i32] = x | y; + let r: vec3[i32] = x ^ y; + let r: vec3[i32] = x << y; + let r: vec3[i32] = x >> y; + let x: vec3[u32] = vec3[u32](u32(0), u32(1), u32(2)); + let y: vec3[u32] = vec3[u32](u32(2), u32(1), u32(0)); + let r: vec3[u32] = x & y; + let r: vec3[u32] = x | y; + let r: vec3[u32] = x ^ y; + let r: vec3[u32] = x << y; + let r: vec3[u32] = x >> y; } )"); @@ -1028,6 +1074,8 @@ fn main() let r = -r * +r; let r = ~42; let r = ~u32(42); + let r = ~vec3[i32](1, 2, 3); + let r = ~vec3[u32](u32(1), u32(2), u32(3)); } )"; diff --git a/tests/src/Tests/OptimizationTests.cpp b/tests/src/Tests/OptimizationTests.cpp index 8df5eb5..51d3954 100644 --- a/tests/src/Tests/OptimizationTests.cpp +++ b/tests/src/Tests/OptimizationTests.cpp @@ -193,7 +193,6 @@ fn main() )"); } - WHEN("eliminating multiple split branches") { PropagateConstantAndExpect(R"(