Skip to content

Commit

Permalink
Merge pull request #181 from eschnett/eschnett/vect
Browse files Browse the repository at this point in the history
Arith: Correct type conversions for vector/scalar binary operations
  • Loading branch information
eschnett authored Jul 26, 2023
2 parents 5abb23e + 5c6face commit 2604a4b
Showing 1 changed file with 102 additions and 56 deletions.
158 changes: 102 additions & 56 deletions Arith/src/vect.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ construct_array(const F &f) {
// arithmetic operations, which is most useful for multi-dimensional
// array indices.

template <typename T, int D> struct vect;

template <typename T> struct is_vect : std::false_type {};
template <typename T, int D> struct is_vect<vect<T, D> > : std::true_type {};
template <typename T> constexpr bool is_vect_v = is_vect<T>::value;

template <typename T, int D> struct vect {
array<T, D> elts;

Expand Down Expand Up @@ -322,112 +328,152 @@ template <typename T, int D> struct vect {
y);
}

friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator+(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() + std::declval<T>()), D>
operator+(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a + b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator-(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() - std::declval<T>()), D>
operator-(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a - b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator*(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() - std::declval<T>()), D>
operator*(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a * b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator/(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() / std::declval<T>()), D>
operator/(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a / b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator%(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() % std::declval<T>()), D>
operator%(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a % b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
div_floor(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(div_floor(std::declval<U>(), std::declval<T>())), D>
div_floor(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return div_floor(a, b); }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
mod_floor(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(mod_floor(std::declval<U>(), std::declval<T>())), D>
mod_floor(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return mod_floor(a, b); }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator&(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() & std::declval<T>()), D>
operator&(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a & b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator|(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() | std::declval<T>()), D>
operator|(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a | b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator^(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() ^ std::declval<T>()), D>
operator^(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a ^ b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator<<(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() << std::declval<T>()), D>
operator<<(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a << b; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator>>(const T &a, const vect &x) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<U>() >> std::declval<T>()), D>
operator>>(const U &a, const vect &x) {
return fmap([&](const T &b) ARITH_INLINE { return a >> b; }, x);
}

friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator+(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() + std::declval<U>()), D>
operator+(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b + a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator-(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() - std::declval<U>()), D>
operator-(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b - a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator*(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() * std::declval<U>()), D>
operator*(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b * a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator/(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() / std::declval<U>()), D>
operator/(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b / a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator%(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() % std::declval<U>()), D>
operator%(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b % a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
div_floor(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(div_floor(std::declval<T>(), std::declval<U>())), D>
div_floor(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return div_floor(b, a); }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
mod_floor(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(mod_floor(std::declval<T>(), std::declval<U>())), D>
mod_floor(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return mod_floor(b, a); }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator&(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() & std::declval<U>()), D>
operator&(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b & a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator|(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() | std::declval<U>()), D>
operator|(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b | a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator^(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() ^ std::declval<U>()), D>
operator^(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b ^ a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator<<(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() << std::declval<U>()), D>
operator<<(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b << a; }, x);
}
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator>>(const vect &x, const T &a) {
template <typename U, std::enable_if_t<!is_vect_v<U> > * = nullptr>
friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST
vect<decltype(std::declval<T>() >> std::declval<U>()), D>
operator>>(const vect &x, const U &a) {
return fmap([&](const T &b) ARITH_INLINE { return b >> a; }, x);
}

constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator+=(const vect &x) {
return *this = *this + x;
}
constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator-=(const vect &x) {
return *this = *this - x;
}
template <typename U>
constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect
operator+=(const vect<U, D> &x) {
Expand Down

0 comments on commit 2604a4b

Please sign in to comment.