Skip to content

Commit

Permalink
adding xt::detail::get_fixed_size.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mykola Vankovych committed Jul 26, 2022
1 parent a159751 commit 1dd2a56
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 21 deletions.
18 changes: 9 additions & 9 deletions include/xtensor/xadapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,11 @@ namespace xt
template <layout_type L = XTENSOR_DEFAULT_LAYOUT, class C, class SC,
XTL_REQUIRES(detail::has_fixed_size<std::decay_t<SC>>,
detail::not_a_pointer<std::remove_reference_t<C>>)>
inline xtensor_adaptor<C, std::tuple_size<std::decay_t<SC>>::value, L>
inline xtensor_adaptor<C, detail::get_fixed_size<std::decay_t<SC>>::value, L>
adapt(C&& container, const SC& shape, layout_type l = L)
{
static_assert(!xtl::is_integral<SC>::value, "shape cannot be a integer");
constexpr std::size_t N = std::tuple_size<std::decay_t<SC>>::value;
constexpr std::size_t N = detail::get_fixed_size<std::decay_t<SC>>::value;
using return_type = xtensor_adaptor<xtl::closure_type_t<C>, N, L>;
return return_type(std::forward<C>(container), shape, l);
}
Expand All @@ -252,7 +252,7 @@ namespace xt
{
static_assert(!xtl::is_integral<SC>::value, "shape cannot be a integer");
using buffer_type = xbuffer_adaptor<C, xt::no_ownership, detail::default_allocator_for_ptr_t<C>>;
constexpr std::size_t N = std::tuple_size<std::decay_t<SC>>::value;
constexpr std::size_t N = detail::get_fixed_size<std::decay_t<SC>>::value;
using return_type = xtensor_adaptor<buffer_type, N, L>;
return return_type(buffer_type(pointer, compute_size(shape)), shape, l);
}
Expand All @@ -267,11 +267,11 @@ namespace xt
template <class C, class SC, class SS,
XTL_REQUIRES(detail::has_fixed_size<std::decay_t<SC>>,
detail::not_a_layout<std::decay_t<SS>>)>
inline xtensor_adaptor<C, std::tuple_size<std::decay_t<SC>>::value, layout_type::dynamic>
inline xtensor_adaptor<C, detail::get_fixed_size<std::decay_t<SC>>::value, layout_type::dynamic>
adapt(C&& container, SC&& shape, SS&& strides)
{
static_assert(!xtl::is_integral<std::decay_t<SC>>::value, "shape cannot be a integer");
constexpr std::size_t N = std::tuple_size<std::decay_t<SC>>::value;
constexpr std::size_t N = detail::get_fixed_size<std::decay_t<SC>>::value;
using return_type = xtensor_adaptor<xtl::closure_type_t<C>, N, layout_type::dynamic>;
return return_type(std::forward<C>(container),
xtl::forward_sequence<typename return_type::inner_shape_type, SC>(shape),
Expand Down Expand Up @@ -313,13 +313,13 @@ namespace xt
*/
template <layout_type L = XTENSOR_DEFAULT_LAYOUT, class P, class O, class SC, class A = detail::default_allocator_for_ptr_t<P>,
XTL_REQUIRES(detail::has_fixed_size<std::decay_t<SC>>)>
inline xtensor_adaptor<xbuffer_adaptor<xtl::closure_type_t<P>, O, A>, std::tuple_size<std::decay_t<SC>>::value, L>
inline xtensor_adaptor<xbuffer_adaptor<xtl::closure_type_t<P>, O, A>, detail::get_fixed_size<std::decay_t<SC>>::value, L>
adapt(P&& pointer, typename A::size_type size, O ownership, const SC& shape, layout_type l = L, const A& alloc = A())
{
static_assert(!xtl::is_integral<SC>::value, "shape cannot be a integer");
(void)ownership;
using buffer_type = xbuffer_adaptor<xtl::closure_type_t<P>, O, A>;
constexpr std::size_t N = std::tuple_size<std::decay_t<SC>>::value;
constexpr std::size_t N = detail::get_fixed_size<std::decay_t<SC>>::value;
using return_type = xtensor_adaptor<buffer_type, N, L>;
buffer_type buf(std::forward<P>(pointer), size, alloc);
return return_type(std::move(buf), shape, l);
Expand All @@ -339,13 +339,13 @@ namespace xt
template <class P, class O, class SC, class SS, class A = detail::default_allocator_for_ptr_t<P>,
XTL_REQUIRES(detail::has_fixed_size<std::decay_t<SC>>,
detail::not_a_layout<std::decay_t<SS>>)>
inline xtensor_adaptor<xbuffer_adaptor<xtl::closure_type_t<P>, O, A>, std::tuple_size<std::decay_t<SC>>::value, layout_type::dynamic>
inline xtensor_adaptor<xbuffer_adaptor<xtl::closure_type_t<P>, O, A>, detail::get_fixed_size<std::decay_t<SC>>::value, layout_type::dynamic>
adapt(P&& pointer, typename A::size_type size, O ownership, SC&& shape, SS&& strides, const A& alloc = A())
{
static_assert(!xtl::is_integral<std::decay_t<SC>>::value, "shape cannot be a integer");
(void)ownership;
using buffer_type = xbuffer_adaptor<xtl::closure_type_t<P>, O, A>;
constexpr std::size_t N = std::tuple_size<std::decay_t<SC>>::value;
constexpr std::size_t N = detail::get_fixed_size<std::decay_t<SC>>::value;
using return_type = xtensor_adaptor<buffer_type, N, layout_type::dynamic>;
buffer_type buf(std::forward<P>(pointer), size, alloc);
return return_type(std::move(buf),
Expand Down
2 changes: 1 addition & 1 deletion include/xtensor/xeval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ namespace xt

template <class E, layout_type L>
using as_xtensor_container_t = xtensor<typename std::decay_t<E>::value_type,
std::tuple_size<typename std::decay_t<E>::shape_type>::value,
detail::get_fixed_size<typename std::decay_t<E>::shape_type>::value,
layout_remove_any(L)>;
}

Expand Down
4 changes: 2 additions & 2 deletions include/xtensor/xfixed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ namespace xt
constexpr T get_backstrides(const S& shape, const T& strides) noexcept
{
return detail::get_backstrides_impl(shape, strides,
std::make_index_sequence<std::tuple_size<T>::value>{});
std::make_index_sequence<detail::get_fixed_size<T>::value>{});
}

template <class V, class S>
Expand Down Expand Up @@ -314,7 +314,7 @@ namespace xt
using temporary_type = typename semantic_base::temporary_type;
using expression_tag = Tag;

constexpr static std::size_t N = std::tuple_size<shape_type>::value;
constexpr static std::size_t N = detail::get_fixed_size<shape_type>::value;
constexpr static std::size_t rank = N;

xfixed_container() = default;
Expand Down
20 changes: 18 additions & 2 deletions include/xtensor/xfunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <xtl/xtype_traits.hpp>

#include "xaccessible.hpp"
#include "xaccumulator.hpp"
#include "xexpression_traits.hpp"
#include "xiterable.hpp"
#include "xlayout.hpp"
Expand Down Expand Up @@ -913,7 +914,7 @@ namespace xt
// Optimization: no need to compare each subiterator since they all
// are incremented decremented together.
constexpr std::size_t temp = xtl::mpl::find_if<is_not_xdummy_iterator, data_type>::value;
constexpr std::size_t index = (temp == std::tuple_size<data_type>::value) ? 0 : temp;
constexpr std::size_t index = (temp == detail::get_fixed_size<data_type>::value) ? 0 : temp;
return std::get<index>(m_it) == std::get<index>(rhs.m_it);
}

Expand All @@ -923,7 +924,7 @@ namespace xt
// Optimization: no need to compare each subiterator since they all
// are incremented decremented together.
constexpr std::size_t temp = xtl::mpl::find_if<is_not_xdummy_iterator, data_type>::value;
constexpr std::size_t index = (temp == std::tuple_size<data_type>::value) ? 0 : temp;
constexpr std::size_t index = (temp == detail::get_fixed_size<data_type>::value) ? 0 : temp;
return std::get<index>(m_it) < std::get<index>(rhs.m_it);
}

Expand Down Expand Up @@ -1059,6 +1060,21 @@ namespace xt
auto step_leading_lambda = [](auto&& st) { st.step_leading(); };
for_each(step_leading_lambda, m_st);
}

namespace detail
{
template<class F, class... CT>
struct has_fixed_size<xfunction<F, CT...>, std::enable_if_t<xt::detail::is_fixed<typename xt::xfunction<F, CT...>::shape_type>::value>>
: std::true_type
{
};

template<class F, class... CT>
struct get_fixed_size<xfunction<F, CT...>, std::enable_if_t<xt::detail::is_fixed<typename xt::xfunction<F, CT...>::shape_type>::value>>
: std::integral_constant<std::size_t, fixed_compute_size<typename xt::xfunction<F, CT...>::shape_type>::value>
{
};
}
}

#endif
6 changes: 3 additions & 3 deletions include/xtensor/xreducer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ namespace xt

reducer_options(const T& tpl)
{
xtl::mpl::static_if<initial_val_idx != std::tuple_size<T>::value>([this, &tpl](auto no_compile) {
xtl::mpl::static_if<initial_val_idx != detail::get_fixed_size<T>::value>([this, &tpl](auto no_compile) {
// use no_compile to prevent compilation if initial_val_idx is out of bounds!
this->initial_value = no_compile(std::get<initial_val_idx != std::tuple_size<T>::value ? initial_val_idx : 0>(tpl)).value();
this->initial_value = no_compile(std::get<initial_val_idx != detail::get_fixed_size<T>::value ? initial_val_idx : 0>(tpl)).value();
},
[](auto /*np_compile*/){}
);
Expand All @@ -133,7 +133,7 @@ namespace xt
std::true_type,
std::false_type>;

constexpr static bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
constexpr static bool has_initial_value = initial_val_idx != detail::get_fixed_size<d_t>::value;

R initial_value;

Expand Down
22 changes: 18 additions & 4 deletions include/xtensor/xshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@

namespace xt
{
namespace detail
{
template <class E, class Enable = void>
struct get_fixed_size;
}

template <class T>
using dynamic_shape = svector<T, 4>;

Expand Down Expand Up @@ -231,9 +237,9 @@ namespace xt
};

template <class T>
struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>>
struct static_dimension_impl<T, void_t<decltype(detail::get_fixed_size<T>::value)>>
{
static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(std::tuple_size<T>::value);
static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(detail::get_fixed_size<T>::value);
};
}

Expand Down Expand Up @@ -281,7 +287,7 @@ namespace xt
};

template <class T, class... Ts>
struct max_array_size<T, Ts...> : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)>
struct max_array_size<T, Ts...> : std::integral_constant<std::size_t, imax(detail::get_fixed_size<T>::value, max_array_size<Ts...>::value)>
{
};

Expand Down Expand Up @@ -375,7 +381,7 @@ namespace xt
static constexpr bool value = true;
};

template <class E, class = void>
template <class E, class Enable = void>
struct has_fixed_size : std::false_type
{
};
Expand All @@ -386,6 +392,14 @@ namespace xt
{
};

template <class E, class Enable>
struct get_fixed_size;

template <class E>
struct get_fixed_size<E, std::void_t<decltype(std::tuple_size<E>::value)>> : std::integral_constant<std::size_t, std::tuple_size<E>::value>
{
};

template <class... S>
using only_array = xtl::conjunction<xtl::disjunction<is_array<S>, is_fixed<S>, has_fixed_size<S>>...>;

Expand Down

0 comments on commit 1dd2a56

Please sign in to comment.