Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for bit operators on vectors of integers (#34) #56

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
13 changes: 13 additions & 0 deletions include/NZSL/Math/Vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ namespace nzsl
constexpr bool operator==(const Vector& vec) const;
constexpr bool operator!=(const Vector& vec) const;

constexpr Vector operator~() const;
constexpr Vector operator&(const Vector& vec) const;
constexpr Vector operator|(const Vector& vec) const;
constexpr Vector operator^(const Vector& vec) const;
constexpr Vector operator<<(const Vector& vec) const;
constexpr Vector operator>>(const Vector& vec) const;

constexpr Vector operator&=(const Vector& vec);
constexpr Vector operator|=(const Vector& vec);
constexpr Vector operator^=(const Vector& vec);
constexpr Vector operator<<=(const Vector& vec);
constexpr Vector operator>>=(const Vector& vec);

static constexpr bool ApproxEqual(const Vector& lhs, const Vector& rhs, T maxDifference = std::numeric_limits<T>::epsilon());
static constexpr Vector CrossProduct(const Vector& lhs, const Vector& rhs);
static T Distance(const Vector& lhs, const Vector& rhs);
Expand Down
106 changes: 106 additions & 0 deletions include/NZSL/Math/Vector.inl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp

#include <NazaraUtils/Hash.hpp>
#include <NazaraUtils/MathUtils.hpp>

namespace nzsl
{
Expand Down Expand Up @@ -302,6 +303,111 @@ namespace nzsl
return !operator==(vec);
}

template<typename T, std::size_t N>
constexpr Vector<T, N> Vector<T, N>::operator~() const
{
Vector<T, N> result;
for (std::size_t i = 0; i < N; ++i)
result.values[i] = ~values[i];

return result;
}

template<typename T, std::size_t N>
constexpr Vector<T, N> Vector<T, N>::operator&(const Vector<T, N>& vec) const
{
Vector<T, N> result;
for (std::size_t i = 0; i < N; ++i)
result.values[i] = values[i] & vec.values[i];

return result;
}

template<typename T, std::size_t N>
constexpr Vector<T, N> Vector<T, N>::operator|(const Vector& vec) const
{
Vector<T, N> result;
for (std::size_t i = 0; i < N; ++i)
result.values[i] = values[i] | vec.values[i];

return result;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator^(const Vector& vec) const
{
Vector<T, N> result;
for (std::size_t i = 0; i < N; ++i)
result.values[i] = values[i] ^ vec.values[i];

return result;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator<<(const Vector& vec) const
{
Vector<T, N> result;
for (std::size_t i = 0; i < N; ++i)
result.values[i] = values[i] << vec.values[i];

return result;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator>>(const Vector& vec) const
{
Vector<T, N> result;
for (std::size_t i = 0; i < N; ++i)
result.values[i] = Nz::ArithmeticRightShift(values[i], vec.values[i]);

return result;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator&=(const Vector& vec)
{
for (std::size_t i = 0; i < N; ++i)
values[i] &= vec.values[i];

return this;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator|=(const Vector& vec)
{
for (std::size_t i = 0; i < N; ++i)
values[i] |= vec.values[i];

return this;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator^=(const Vector& vec)
{
for (std::size_t i = 0; i < N; ++i)
values[i] ^= vec.values[i];

return this;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator<<=(const Vector& vec)
{
for (std::size_t i = 0; i < N; ++i)
values[i] <<= vec.values[i];

return this;
}

template<typename T, std::size_t N>
inline constexpr Vector<T, N> Vector<T, N>::operator>>=(const Vector& vec)
{
for (std::size_t i = 0; i < N; ++i)
values[i] = Nz::ArithmeticRightShift(values[i], vec.values[i]);

return this;
}

template<typename T, std::size_t N>
constexpr bool Vector<T, N>::ApproxEqual(const Vector& lhs, const Vector& rhs, T maxDifference)
{
Expand Down
21 changes: 18 additions & 3 deletions src/NZSL/Ast/SanitizeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5727,6 +5727,7 @@ namespace nzsl::Ast
if (leftType != PrimitiveType::Int32 && leftType != PrimitiveType::UInt32)
throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) };

TypeMustMatch(leftExprType, rightExprType, sourceLocation);
return leftExprType;
}

Expand Down Expand Up @@ -5837,11 +5838,25 @@ namespace nzsl::Ast
case BinaryType::BitwiseAnd:
case BinaryType::BitwiseOr:
case BinaryType::BitwiseXor:
case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
case BinaryType::ShiftLeft:
case BinaryType::ShiftRight:
throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) };
{
if (leftType.type != PrimitiveType::Int32 && leftType.type != PrimitiveType::UInt32)
throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) };

TypeMustMatch(leftExprType, rightExprType, sourceLocation);
return leftExprType;
}

case BinaryType::LogicalAnd:
case BinaryType::LogicalOr:
{
if (leftType.type != PrimitiveType::Boolean)
throw CompilerBinaryUnsupportedError{ sourceLocation, "left", ToString(leftExprType, sourceLocation) };

TypeMustMatch(leftExprType, rightExprType, sourceLocation);
return PrimitiveType::Boolean;
}
}
}

Expand Down
48 changes: 48 additions & 0 deletions tests/src/Tests/ArithmeticTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,24 @@ fn main()
let r = x ^ y;
let r = x << y;
let r = x >> y;

let x = vec3[i32](0, 1, 2);
let y = vec3[i32](2, 1, 0);

let r = x & y;
let r = x | y;
let r = x ^ y;
let r = x << y;
let r = x >> y;

let x = vec3[u32](u32(0), u32(1), u32(2));
let y = vec3[u32](u32(2), u32(1), u32(0));

let r = x & y;
let r = x | y;
let r = x ^ y;
let r = x << y;
let r = x >> y;
}
)";

Expand All @@ -319,6 +337,20 @@ void main()
uint r_8 = x_2 ^ y_2;
uint r_9 = x_2 << y_2;
uint r_10 = x_2 >> y_2;
ivec3 x_3 = ivec3(0, 1, 2);
ivec3 y_3 = ivec3(2, 1, 0);
ivec3 r_11 = x_3 & y_3;
ivec3 r_12 = x_3 | y_3;
ivec3 r_13 = x_3 ^ y_3;
ivec3 r_14 = x_3 << y_3;
ivec3 r_15 = x_3 >> y_3;
uvec3 x_4 = uvec3(uint(0), uint(1), uint(2));
uvec3 y_4 = uvec3(uint(2), uint(1), uint(0));
uvec3 r_16 = x_4 & y_4;
uvec3 r_17 = x_4 | y_4;
uvec3 r_18 = x_4 ^ y_4;
uvec3 r_19 = x_4 << y_4;
uvec3 r_20 = x_4 >> y_4;
}
)");

Expand All @@ -340,6 +372,20 @@ fn main()
let r: u32 = x ^ y;
let r: u32 = x << y;
let r: u32 = x >> y;
let x: vec3[i32] = vec3[i32](0, 1, 2);
let y: vec3[i32] = vec3[i32](2, 1, 0);
let r: vec3[i32] = x & y;
let r: vec3[i32] = x | y;
let r: vec3[i32] = x ^ y;
let r: vec3[i32] = x << y;
let r: vec3[i32] = x >> y;
let x: vec3[u32] = vec3[u32](u32(0), u32(1), u32(2));
let y: vec3[u32] = vec3[u32](u32(2), u32(1), u32(0));
let r: vec3[u32] = x & y;
let r: vec3[u32] = x | y;
let r: vec3[u32] = x ^ y;
let r: vec3[u32] = x << y;
let r: vec3[u32] = x >> y;
}
)");

Expand Down Expand Up @@ -1028,6 +1074,8 @@ fn main()
let r = -r * +r;
let r = ~42;
let r = ~u32(42);
let r = ~vec3[i32](1, 2, 3);
let r = ~vec3[u32](u32(1), u32(2), u32(3));
}
)";

Expand Down
5 changes: 2 additions & 3 deletions tests/src/Tests/OptimizationTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ fn main()

let output7 = -42 << 10;
let output8 = -42 >> 10;
let output9 = u32(1) << 10;
let output10 = u32(1024) >> 10;
let output9 = u32(1) << u32(10);
let output10 = u32(1024) >> u32(10);
}
)", R"(
[entry(frag)]
Expand Down Expand Up @@ -193,7 +193,6 @@ fn main()
)");
}


WHEN("eliminating multiple split branches")
{
PropagateConstantAndExpect(R"(
Expand Down
Loading