diff --git a/include/NZSL/Math/Vector.hpp b/include/NZSL/Math/Vector.hpp index bc3c087..30ebfbb 100644 --- a/include/NZSL/Math/Vector.hpp +++ b/include/NZSL/Math/Vector.hpp @@ -70,6 +70,19 @@ namespace nzsl constexpr bool operator==(const Vector& vec) const; constexpr bool operator!=(const Vector& vec) const; + constexpr Vector operator~() const; + constexpr Vector operator&(const Vector& vec) const; + constexpr Vector operator|(const Vector& vec) const; + constexpr Vector operator^(const Vector& vec) const; + constexpr Vector operator<<(const Vector& vec) const; + constexpr Vector operator>>(const Vector& vec) const; + + constexpr Vector operator&=(const Vector& vec); + constexpr Vector operator|=(const Vector& vec); + constexpr Vector operator^=(const Vector& vec); + constexpr Vector operator<<=(const Vector& vec); + constexpr Vector operator>>=(const Vector& vec); + static constexpr bool ApproxEqual(const Vector& lhs, const Vector& rhs, T maxDifference = std::numeric_limits::epsilon()); static constexpr Vector CrossProduct(const Vector& lhs, const Vector& rhs); static T Distance(const Vector& lhs, const Vector& rhs); diff --git a/include/NZSL/Math/Vector.inl b/include/NZSL/Math/Vector.inl index 99c193f..57392a3 100644 --- a/include/NZSL/Math/Vector.inl +++ b/include/NZSL/Math/Vector.inl @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include namespace nzsl { @@ -302,6 +303,111 @@ namespace nzsl return !operator==(vec); } + template + constexpr Vector Vector::operator~() const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = ~values[i]; + + return result; + } + + template + constexpr Vector Vector::operator&(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] & vec.values[i]; + + return result; + } + + template + constexpr Vector Vector::operator|(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] | vec.values[i]; + + return result; + } + + template + inline constexpr Vector Vector::operator^(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] ^ vec.values[i]; + + return result; + } + + template + inline constexpr Vector Vector::operator<<(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = values[i] << vec.values[i]; + + return result; + } + + template + inline constexpr Vector Vector::operator>>(const Vector& vec) const + { + Vector result; + for (std::size_t i = 0; i < N; ++i) + result.values[i] = Nz::ArithmeticRightShift(values[i], vec.values[i]); + + return result; + } + + template + inline constexpr Vector Vector::operator&=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] &= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator|=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] |= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator^=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] ^= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator<<=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] <<= vec.values[i]; + + return this; + } + + template + inline constexpr Vector Vector::operator>>=(const Vector& vec) + { + for (std::size_t i = 0; i < N; ++i) + values[i] = Nz::ArithmeticRightShift(values[i], vec.values[i]); + + return this; + } + template constexpr bool Vector::ApproxEqual(const Vector& lhs, const Vector& rhs, T maxDifference) { diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index c0892a6..ad504d8 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -5727,6 +5727,7 @@ namespace nzsl::Ast if (leftType != PrimitiveType::Int32 && leftType != PrimitiveType::UInt32) throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + TypeMustMatch(leftExprType, rightExprType, sourceLocation); return leftExprType; } @@ -5837,11 +5838,25 @@ namespace nzsl::Ast case BinaryType::BitwiseAnd: case BinaryType::BitwiseOr: case BinaryType::BitwiseXor: - case BinaryType::LogicalAnd: - case BinaryType::LogicalOr: case BinaryType::ShiftLeft: case BinaryType::ShiftRight: - throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + { + if (leftType.type != PrimitiveType::Int32 && leftType.type != PrimitiveType::UInt32) + throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + + TypeMustMatch(leftExprType, rightExprType, sourceLocation); + return leftExprType; + } + + case BinaryType::LogicalAnd: + case BinaryType::LogicalOr: + { + if (leftType.type != PrimitiveType::Boolean) + throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) }; + + TypeMustMatch(leftExprType, rightExprType, sourceLocation); + return PrimitiveType::Boolean; + } } } diff --git a/tests/src/Tests/ArithmeticTests.cpp b/tests/src/Tests/ArithmeticTests.cpp index a9515d7..05d51c6 100644 --- a/tests/src/Tests/ArithmeticTests.cpp +++ b/tests/src/Tests/ArithmeticTests.cpp @@ -296,6 +296,24 @@ fn main() let r = x ^ y; let r = x << y; let r = x >> y; + + let x = vec3[i32](0, 1, 2); + let y = vec3[i32](2, 1, 0); + + let r = x & y; + let r = x | y; + let r = x ^ y; + let r = x << y; + let r = x >> y; + + let x = vec3[u32](u32(0), u32(1), u32(2)); + let y = vec3[u32](u32(2), u32(1), u32(0)); + + let r = x & y; + let r = x | y; + let r = x ^ y; + let r = x << y; + let r = x >> y; } )"; @@ -319,6 +337,20 @@ void main() uint r_8 = x_2 ^ y_2; uint r_9 = x_2 << y_2; uint r_10 = x_2 >> y_2; + ivec3 x_3 = ivec3(0, 1, 2); + ivec3 y_3 = ivec3(2, 1, 0); + ivec3 r_11 = x_3 & y_3; + ivec3 r_12 = x_3 | y_3; + ivec3 r_13 = x_3 ^ y_3; + ivec3 r_14 = x_3 << y_3; + ivec3 r_15 = x_3 >> y_3; + uvec3 x_4 = uvec3(uint(0), uint(1), uint(2)); + uvec3 y_4 = uvec3(uint(2), uint(1), uint(0)); + uvec3 r_16 = x_4 & y_4; + uvec3 r_17 = x_4 | y_4; + uvec3 r_18 = x_4 ^ y_4; + uvec3 r_19 = x_4 << y_4; + uvec3 r_20 = x_4 >> y_4; } )"); @@ -340,6 +372,20 @@ fn main() let r: u32 = x ^ y; let r: u32 = x << y; let r: u32 = x >> y; + let x: vec3[i32] = vec3[i32](0, 1, 2); + let y: vec3[i32] = vec3[i32](2, 1, 0); + let r: vec3[i32] = x & y; + let r: vec3[i32] = x | y; + let r: vec3[i32] = x ^ y; + let r: vec3[i32] = x << y; + let r: vec3[i32] = x >> y; + let x: vec3[u32] = vec3[u32](u32(0), u32(1), u32(2)); + let y: vec3[u32] = vec3[u32](u32(2), u32(1), u32(0)); + let r: vec3[u32] = x & y; + let r: vec3[u32] = x | y; + let r: vec3[u32] = x ^ y; + let r: vec3[u32] = x << y; + let r: vec3[u32] = x >> y; } )"); @@ -1028,6 +1074,8 @@ fn main() let r = -r * +r; let r = ~42; let r = ~u32(42); + let r = ~vec3[i32](1, 2, 3); + let r = ~vec3[u32](u32(1), u32(2), u32(3)); } )"; diff --git a/tests/src/Tests/OptimizationTests.cpp b/tests/src/Tests/OptimizationTests.cpp index 37539ab..51d3954 100644 --- a/tests/src/Tests/OptimizationTests.cpp +++ b/tests/src/Tests/OptimizationTests.cpp @@ -79,8 +79,8 @@ fn main() let output7 = -42 << 10; let output8 = -42 >> 10; - let output9 = u32(1) << 10; - let output10 = u32(1024) >> 10; + let output9 = u32(1) << u32(10); + let output10 = u32(1024) >> u32(10); } )", R"( [entry(frag)] @@ -193,7 +193,6 @@ fn main() )"); } - WHEN("eliminating multiple split branches") { PropagateConstantAndExpect(R"(