From fb3c0fa14e9c73f04527d33462e65a09050ef318 Mon Sep 17 00:00:00 2001 From: Cem Bassoy Date: Fri, 1 Nov 2024 12:51:23 +0100 Subject: [PATCH] refactor(ttm): add namespace ttm --- example/interface1.cpp | 7 +- example/interface2.cpp | 15 +- example/interface3.cpp | 19 +- example/measure.cpp | 33 ++-- include/tlib/detail/cases.h | 22 +-- include/tlib/detail/index.h | 4 +- include/tlib/detail/layout.h | 12 +- include/tlib/detail/mtm.h | 4 +- include/tlib/detail/shape.h | 4 +- include/tlib/detail/strides.h | 4 +- include/tlib/detail/tags.h | 23 +-- include/tlib/detail/tensor.h | 8 +- include/tlib/detail/ttm.h | 88 +++++---- include/tlib/ttm.h | 16 +- test/src/gtest_tlib_layout.cpp | 309 ++++++++++++++++---------------- test/src/gtest_tlib_mtm.cpp | 213 +++++++++++----------- test/src/gtest_tlib_shape.cpp | 104 +++++------ test/src/gtest_tlib_strides.cpp | 74 ++++---- test/src/gtest_tlib_ttm.cpp | 235 ++++++++++++------------ ttmpy/src/wrapped_ttm.cpp | 43 +++-- 20 files changed, 600 insertions(+), 637 deletions(-) diff --git a/example/interface1.cpp b/example/interface1.cpp index c5c73e5..cad641d 100644 --- a/example/interface1.cpp +++ b/example/interface1.cpp @@ -4,11 +4,12 @@ #include #include +using namespace tlib::ttm; int main() { using value_t = float; - using tensor_t = tlib::tensor; // or std::array + using tensor_t = tensor; // or std::array using shape_t = typename tensor_t::shape_t; // shape tuple for A @@ -25,8 +26,8 @@ int main() auto pb = nb.size(); // layout tuple for A and C - auto pia = tlib::detail::generate_k_order_layout(pa,1ul); - auto pib = tlib::detail::generate_k_order_layout(pb,1ul); + auto pia = detail::generate_k_order_layout(pa,1ul); + auto pib = detail::generate_k_order_layout(pb,1ul); auto A = tensor_t( na, pia ); auto B = tensor_t( nb, pib ); diff --git a/example/interface2.cpp b/example/interface2.cpp index 1c73ef3..8b97463 100644 --- a/example/interface2.cpp +++ b/example/interface2.cpp @@ -4,11 +4,12 @@ #include #include +using namespace tlib::ttm; int main() { using value_t = float; - using tensor_t = tlib::tensor; // or std::array + using tensor_t = tensor; // or std::array using shape_t = typename tensor_t::shape_t; // shape tuple for A @@ -25,8 +26,8 @@ int main() auto pb = nb.size(); // layout tuple for A and C - auto pia = tlib::detail::generate_k_order_layout(pa,1ul); - auto pib = tlib::detail::generate_k_order_layout(pb,1ul); + auto pia = detail::generate_k_order_layout(pa,1ul); + auto pib = detail::generate_k_order_layout(pb,1ul); auto A = tensor_t( na, pia ); auto B = tensor_t( nb, pib ); @@ -55,10 +56,10 @@ int main() // correct shape, layout and strides of the output tensors C1,C2 are automatically computed and returned by the functions. - auto C1 = tlib::ttm(q, A,B, tlib::parallel_policy::parallel_blas , tlib::slicing_policy::slice, tlib::fusion_policy::none ); - auto C2 = tlib::ttm(q, A,B, tlib::parallel_policy::parallel_loop , tlib::slicing_policy::slice, tlib::fusion_policy::all ); - auto C3 = tlib::ttm(q, A,B, tlib::parallel_policy::parallel_loop , tlib::slicing_policy::subtensor, tlib::fusion_policy::all ); - auto C4 = tlib::ttm(q, A,B, tlib::parallel_policy::batched_gemm , tlib::slicing_policy::subtensor, tlib::fusion_policy::all ); + auto C1 = ttm(q, A,B, parallel_policy::parallel_blas , slicing_policy::slice, fusion_policy::none ); + auto C2 = ttm(q, A,B, parallel_policy::parallel_loop , slicing_policy::slice, fusion_policy::all ); + auto C3 = ttm(q, A,B, parallel_policy::parallel_loop , slicing_policy::subtensor, fusion_policy::all ); + auto C4 = ttm(q, A,B, parallel_policy::batched_gemm , slicing_policy::subtensor, fusion_policy::all ); std::cout << "C1 = " << C1 << std::endl; diff --git a/example/interface3.cpp b/example/interface3.cpp index c39688c..84dae66 100644 --- a/example/interface3.cpp +++ b/example/interface3.cpp @@ -4,13 +4,14 @@ #include #include +using namespace tlib::ttm; int main() { using value_t = float; using size_t = std::size_t; using tensor_t = std::vector; // or std::array - using shape_t = std::vector; + using shape_t = std::vector; using iterator_t = std::ostream_iterator; auto na = shape_t{4,3,2}; // input shape tuple @@ -18,9 +19,9 @@ int main() auto k = 1ul; // k-order of input tensor auto q = 2ul; - auto pia = tlib::detail::generate_k_order_layout(p,k); // layout tuple of input tensor - here {1,2,3}; - auto wa = tlib::detail::generate_strides(na,pia); // stride tuple of input tensor - here {1,4,12}; - auto nna = std::accumulate(na.begin(),na.end(),1ul,std::multiplies<>()); // number of elements of input tensor + auto pia = detail::generate_k_order_layout(p,k); // layout tuple of input tensor - here {1,2,3}; + auto wa = detail::generate_strides(na,pia); // stride tuple of input tensor - here {1,4,12}; + auto nna = std::accumulate(na.begin(),na.end(),1ul,std::multiplies<>()); // number of elements of input tensor auto pib = shape_t{1,2}; auto nb = shape_t{na[q-1]+1,na[q-1]}; @@ -29,7 +30,7 @@ int main() auto nc = na; nc[q-1] = nb[0]; auto pic = pia; - auto wc = tlib::detail::generate_strides(nc,pic); + auto wc = detail::generate_strides(nc,pic); auto nnc = std::accumulate(nc.begin(),nc.end(),1ul,std::multiplies<>()); // number of elements of input tensor @@ -43,15 +44,15 @@ int main() std::cout << "A = [ "; std::copy(A.begin(), A.end(), iterator_t(std::cout, " ")); std::cout << " ];" << std::endl; std::cout << "B = [ "; std::copy(B.begin(), B.end(), iterator_t(std::cout, " ")); std::cout << " ];" << std::endl; - tlib::ttm( - tlib::parallel_policy::parallel_blas , tlib::slicing_policy::slice, tlib::fusion_policy::none, + ttm( + parallel_policy::parallel_blas , slicing_policy::slice, fusion_policy::none, q, p, A.data(), na.data(), wa.data(), pia.data(), B.data(), nb.data(), pib.data(), C1.data(), nc.data(), wc.data()); - tlib::ttm( - tlib::parallel_policy::parallel_loop, tlib::slicing_policy::subtensor, tlib::fusion_policy::all, + ttm( + parallel_policy::parallel_loop, slicing_policy::subtensor, fusion_policy::all, q, p, A.data(), na.data(), wa.data(), pia.data(), B.data(), nb.data(), pib.data(), diff --git a/example/measure.cpp b/example/measure.cpp index 0b973ec..37bb6b5 100644 --- a/example/measure.cpp +++ b/example/measure.cpp @@ -6,6 +6,8 @@ #include #include // for high precision timing +using namespace tlib::ttm; + static const auto gdims = std::string("abcdefghij"); inline @@ -72,9 +74,9 @@ get_gflops(double nn, double cdimc, double cdima) template inline void measure(unsigned q, - tlib::tensor const& A, - tlib::tensor const& B, - tlib::tensor& C, + tensor const& A, + tensor const& B, + tensor& C, parallel_policy pp, slicing_policy sp, fusion_policy fp) @@ -87,8 +89,7 @@ inline void measure(unsigned q, for(auto i = 0u; i < iters; ++i){ std::fill(cache.begin(), cache.end(),char{}); auto start = std::chrono::high_resolution_clock::now(); - tlib::ttm( - pp, sp, fp, + ttm(pp, sp, fp, q, A.order(), A.data().data(), A.shape().data(), A.strides().data(), A.layout().data(), B.data().data(), B.shape().data(), B.layout().data(), @@ -109,7 +110,7 @@ inline void measure(unsigned q, std::cout << "Time : " << avg_time_s << " [s]" << std::endl; std::cout << "Gflops : " << gflops << " [gflops]" << std::endl; std::cout << "Performance : " << gflops/avg_time_s << " [gflops/s]" << std::endl; - std::cout << "Performance : " << gflops/avg_time_s/tlib::detail::cores << " [gflops/s/core]" << std::endl; + std::cout << "Performance : " << gflops/avg_time_s/detail::cores << " [gflops/s/core]" << std::endl; } @@ -122,7 +123,7 @@ int main(int argc, char* argv[]) { using value = double; - using tensor = tlib::tensor; // or std::array + using tensor = tensor; // or std::array using shape = typename tensor::shape_t; assert(argc > 4); @@ -159,9 +160,9 @@ int main(int argc, char* argv[]) const auto pc = pa; // layout tuple for A and C - const auto pia = tlib::detail::generate_k_order_layout(pa,1ul); - const auto pib = tlib::detail::generate_k_order_layout(pb,1ul); - const auto pic = tlib::detail::generate_k_order_layout(pc,1ul); + const auto pia = detail::generate_k_order_layout(pa,1ul); + const auto pib = detail::generate_k_order_layout(pb,1ul); + const auto pic = detail::generate_k_order_layout(pc,1ul); auto A = tensor( na, pia ); auto B = tensor( nb, pib ); @@ -172,37 +173,37 @@ int main(int argc, char* argv[]) if(method == 1 || method == 7){ std::cout << "Algorithm: " << std::endl; - measure(q, A, B, C, tlib::parallel_policy::parallel_loop, tlib::slicing_policy::slice, tlib::fusion_policy::all ); + measure(q, A, B, C, parallel_policy::parallel_loop, slicing_policy::slice, fusion_policy::all ); std::cout << "---------" << std::endl << std::endl; } if(method == 2 || method == 7){ std::cout << "Algorithm: " << std::endl; - measure(q, A, B, C, tlib::parallel_policy::parallel_loop, tlib::slicing_policy::subtensor, tlib::fusion_policy::all ); + measure(q, A, B, C, parallel_policy::parallel_loop, slicing_policy::subtensor, fusion_policy::all ); std::cout << "---------" << std::endl << std::endl; } if(method == 3 || method == 7){ std::cout << "Algorithm: " << std::endl; - measure(q, A, B, C, tlib::parallel_policy::parallel_blas, tlib::slicing_policy::slice, tlib::fusion_policy::none ); + measure(q, A, B, C, parallel_policy::parallel_blas, slicing_policy::slice, fusion_policy::none ); std::cout << "---------" << std::endl << std::endl; } if(method == 4 || method == 7){ std::cout << "Algorithm: " << std::endl; - measure(q, A, B, C, tlib::parallel_policy::parallel_blas, tlib::slicing_policy::slice, tlib::fusion_policy::all ); + measure(q, A, B, C, parallel_policy::parallel_blas, slicing_policy::slice, fusion_policy::all ); std::cout << "---------" << std::endl << std::endl; } if(method == 5 || method == 7){ std::cout << "Algorithm: " << std::endl; - measure(q, A, B, C, tlib::parallel_policy::parallel_blas, tlib::slicing_policy::subtensor, tlib::fusion_policy::none ); + measure(q, A, B, C, parallel_policy::parallel_blas, slicing_policy::subtensor, fusion_policy::none ); std::cout << "---------" << std::endl << std::endl; } if(method == 6 || method == 7){ std::cout << "Algorithm: " << std::endl; - measure(q, A, B, C, tlib::parallel_policy::parallel_blas, tlib::slicing_policy::subtensor, tlib::fusion_policy::all ); + measure(q, A, B, C, parallel_policy::parallel_blas, slicing_policy::subtensor, fusion_policy::all ); std::cout << "---------" << std::endl << std::endl; } diff --git a/include/tlib/detail/cases.h b/include/tlib/detail/cases.h index 877c5fa..6da9942 100644 --- a/include/tlib/detail/cases.h +++ b/include/tlib/detail/cases.h @@ -19,13 +19,13 @@ #include -namespace tlib::detail{ +namespace tlib::ttm::detail{ template inline constexpr bool is_case(unsigned p, std::size_t q, std::size_t const*const pi) { - static_assert(case_nr > 0u || case_nr < 9u, "tlib::detail::is_case: only 8 cases from 1 to 8 are covered."); + static_assert(case_nr > 0u || case_nr < 9u, "tlib::ttm::detail::is_case: only 8 cases from 1 to 8 are covered."); if constexpr (case_nr == 1u) return p==1u; if constexpr (case_nr == 2u) return p==2u && q == 1u && pi[0] == 1u; if constexpr (case_nr == 3u) return p==2u && q == 2u && pi[0] == 1u; @@ -36,20 +36,4 @@ inline constexpr bool is_case(unsigned p, std::size_t q, std::size_t const*const if constexpr (case_nr == 8u) return p>=3u && !(is_case<6u>(p,q,pi)||is_case<7u>(p,q,pi)); } - -//// assume that the input matrix (2nd argument) with a column-major format -//template -//inline constexpr bool is_case(unsigned p, std::size_t q, std::size_t const*const pi) -//{ -// static_assert(case_nr > 0u || case_nr < 9u, "tlib::detail::is_case: only 8 cases from 1 to 8 are covered."); -// if constexpr (case_nr == 1u) return p==1u; -// if constexpr (case_nr == 2u) return p==2u && q == 1u && pi[0] == 1u; -// if constexpr (case_nr == 3u) return p==2u && q == 2u && pi[0] == 1u; -// if constexpr (case_nr == 4u) return p==2u && q == 1u && pi[0] == 2u; -// if constexpr (case_nr == 5u) return p==2u && q == 2u && pi[0] == 2u; -// if constexpr (case_nr == 6u) return p>=3u && pi[0] == q; -// if constexpr (case_nr == 7u) return p>=3u && pi[p-1] == q; -// if constexpr (case_nr == 8u) return p>=3u && !(is_case<6u>(p,q,pi)||is_case<7u>(p,q,pi)); -//} - -} // namespace tlib::detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/detail/index.h b/include/tlib/detail/index.h index 30d84b5..4bae519 100644 --- a/include/tlib/detail/index.h +++ b/include/tlib/detail/index.h @@ -17,7 +17,7 @@ #pragma once -namespace tlib::detail +namespace tlib::ttm::detail { @@ -126,4 +126,4 @@ constexpr auto at_at_1(size_type const j_view, container_type const& w_view, con -} // namespace detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/detail/layout.h b/include/tlib/detail/layout.h index 0b63d41..4c00afd 100644 --- a/include/tlib/detail/layout.h +++ b/include/tlib/detail/layout.h @@ -23,7 +23,7 @@ -namespace tlib::detail +namespace tlib::ttm::detail { template @@ -60,7 +60,7 @@ inline void compute_k_order_layout(OutputIt begin, OutputIt end, size_t k) auto const n_signed = std::distance(begin,end); if(n_signed <= 0) - throw std::runtime_error("Error in tlib::detail::compute_k_order: range provided by begin and end not correct!"); + throw std::runtime_error("Error in tlib::ttm::detail::compute_k_order: range provided by begin and end not correct!"); auto const n = static_cast>(n_signed); assert(n > 0); @@ -122,16 +122,16 @@ inline auto inverse_mode(InputIt layout_begin, InputIt layout_end, SizeType mode { using value_type = typename std::iterator_traits::value_type; if(!is_valid_layout(layout_begin,layout_end)) - throw std::runtime_error("Error in tlib::detail::inverse_mode(): input layout is not valid."); + throw std::runtime_error("Error in tlib::ttm::detail::inverse_mode(): input layout is not valid."); auto const p_ = std::distance(layout_begin,layout_end); if(p_<= 0) - throw std::runtime_error("Error in tlib::detail::inverse_mode(): input layout is invalid."); + throw std::runtime_error("Error in tlib::ttm::detail::inverse_mode(): input layout is invalid."); auto const p = static_cast(p_); if(mode==0u || mode > SizeType(p)) - throw std::runtime_error("Error in tlib::detail::inverse_mode(): mode should be one-based and equal to or less than layout size."); + throw std::runtime_error("Error in tlib::ttm::detail::inverse_mode(): mode should be one-based and equal to or less than layout size."); auto inverse_mode = value_type{0u}; for(; inverse_mode < p; ++inverse_mode) @@ -146,4 +146,4 @@ inline auto inverse_mode(InputIt layout_begin, InputIt layout_end, SizeType mode -} // namespace tlib::detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/detail/mtm.h b/include/tlib/detail/mtm.h index 5b8694d..7fbaca3 100644 --- a/include/tlib/detail/mtm.h +++ b/include/tlib/detail/mtm.h @@ -46,7 +46,7 @@ -namespace tlib::detail { +namespace tlib::ttm::detail { struct cblas_layout {}; @@ -223,4 +223,4 @@ inline void mtm_cm(unsigned const q, unsigned const p, } -} // namespace tlib::detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/detail/shape.h b/include/tlib/detail/shape.h index 479d7ef..7cf56d3 100644 --- a/include/tlib/detail/shape.h +++ b/include/tlib/detail/shape.h @@ -25,7 +25,7 @@ #include -namespace tlib::detail +namespace tlib::ttm::detail { template @@ -93,4 +93,4 @@ inline bool is_tensor(InputIt begin, InputIt end) -} // namespace tlib::detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/detail/strides.h b/include/tlib/detail/strides.h index e36c373..447643a 100644 --- a/include/tlib/detail/strides.h +++ b/include/tlib/detail/strides.h @@ -25,7 +25,7 @@ #include "layout.h" -namespace tlib::detail +namespace tlib::ttm::detail { template @@ -100,4 +100,4 @@ inline bool is_valid_strides(InputIt1 layout_begin, InputIt1 layout_end, InputIt // [stride_begin]( auto l ) {return stride_begin[l-2] > stride_begin[l-1];} ); } -} // namespace tlib::detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/detail/tags.h b/include/tlib/detail/tags.h index 7eb7d5b..5954623 100644 --- a/include/tlib/detail/tags.h +++ b/include/tlib/detail/tags.h @@ -19,24 +19,9 @@ // ttm -namespace tlib::parallel_policy +namespace tlib::ttm::parallel_policy { -/* -struct sequential_t {}; // multithreaded gemm -struct threaded_gemm_t {}; // multithreaded gemm -struct omp_taskloop_t {}; // omp_taskloops with single threaded gemm -struct omp_forloop_t {}; // omp_for with single threaded gemm -struct omp_forloop_and_threaded_gemm_t {}; // omp_for with multi-threaded gemm -struct batched_gemm_t {}; // multithreaded batched gemm with collapsed loops -struct combined_t {}; - -inline constexpr sequential_t sequential; -inline constexpr threaded_gemm_t threaded_gemm; -inline constexpr omp_taskloop_t omp_taskloop; -inline constexpr omp_forloop_t omp_forloop; -inline constexpr batched_gemm_t batched_gemm; -inline constexpr combined_t combined; -*/ + struct sequential_t {}; // sequential loops and sequential gemm struct parallel_blas_t {}; // multithreaded gemm struct parallel_loop_t {}; // omp_for with single threaded gemm @@ -54,7 +39,7 @@ inline constexpr combined_t combined; } -namespace tlib::slicing_policy +namespace tlib::ttm::slicing_policy { struct slice_t {}; struct subtensor_t {}; @@ -68,7 +53,7 @@ inline constexpr subtensor_t subtensor; } -namespace tlib::fusion_policy +namespace tlib::ttm::fusion_policy { struct none_t {}; struct outer_t {}; diff --git a/include/tlib/detail/tensor.h b/include/tlib/detail/tensor.h index 292b0dd..ffc38ca 100644 --- a/include/tlib/detail/tensor.h +++ b/include/tlib/detail/tensor.h @@ -28,7 +28,7 @@ -namespace tlib +namespace tlib::ttm { template @@ -120,14 +120,14 @@ class tensor vector_t _data; }; -} +} // namespace tlib::ttm template -void stream_out(std::ostream& out, tlib::tensor const& a, std::size_t j, unsigned r) +void stream_out(std::ostream& out, tlib::ttm::tensor const& a, std::size_t j, unsigned r) { const auto& w = a.strides(); const auto& n = a.shape(); @@ -157,7 +157,7 @@ void stream_out(std::ostream& out, tlib::tensor const& a, std::size_ template -std::ostream& operator<< (std::ostream& out, tlib::tensor const& a) +std::ostream& operator<< (std::ostream& out, tlib::ttm::tensor const& a) { const auto& w = a.strides(); const auto& n = a.shape(); diff --git a/include/tlib/detail/ttm.h b/include/tlib/detail/ttm.h index 8c2232f..6ceeada 100644 --- a/include/tlib/detail/ttm.h +++ b/include/tlib/detail/ttm.h @@ -55,7 +55,7 @@ #include #include -namespace tlib::detail{ +namespace tlib::ttm::detail{ static inline unsigned get_number_cores() { @@ -145,8 +145,6 @@ inline void set_omp_threads_max() { #ifdef _OPENMP omp_set_num_threads(cores); -#else - return 1; #endif } @@ -311,7 +309,7 @@ inline void ttm( mtm_rm(q, p, a, na, pia, b, nb, c, nc ); } else { - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = inverse_mode(pia, pia+p, q); using namespace std::placeholders; @@ -320,8 +318,8 @@ inline void ttm( auto nq = na[q-1]; auto wq = wa[q-1]; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c + auto gemm_row = std::bind(gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c if(is_cm) loops_over_gemm_with_slices(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc); else loops_over_gemm_with_slices(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc); @@ -351,7 +349,7 @@ inline void ttm( mtm_rm(q, p, a, na, pia, b, nb, c, nc ); } else { - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); using namespace std::placeholders; @@ -360,8 +358,8 @@ inline void ttm( auto nq = na[q-1]; auto wq = wa[q-1]; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c if(is_cm) loops_over_gemm_with_slices(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc); else loops_over_gemm_with_slices(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc); @@ -394,7 +392,7 @@ inline void ttm( assert(p>2); assert(q>0); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // outer = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const outer = product(na, pia, qh+1,p+1); @@ -417,8 +415,8 @@ inline void ttm( //std::cout << "m=" << m << ", n=" << n1 << ", k=" << nq << std::endl; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c //std::cout << "Get blas threads: " << get_blas_threads() << std::endl; @@ -463,7 +461,7 @@ inline void ttm( assert(p>2); assert(q>0); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // num = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const num = product(na, pia, qh+1,p+1); @@ -479,8 +477,8 @@ inline void ttm( auto nq = na[q-1]; auto wq = wa[q-1]; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq,m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq,m,wq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c #pragma omp parallel for schedule(static) num_threads(cores) proc_bind(spread) for(size_t k = 0u; k < num; ++k){ @@ -526,7 +524,7 @@ inline void ttm( assert(p>2); assert(q>0); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // outer = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const outer = product(na, pia, qh+1,p+1); @@ -547,8 +545,8 @@ inline void ttm( auto wq = wa[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c #pragma omp parallel for schedule(static) collapse(2) num_threads(cores) proc_bind(spread) for(size_t k = 0u; k < outer; ++k){ @@ -598,7 +596,7 @@ inline void ttm( assert(p>2); assert(q>0); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // outer = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const outer = product(na, pia, qh+1,p+1); @@ -619,8 +617,8 @@ inline void ttm( auto wq = wa[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c #pragma omp parallel for schedule(static) collapse(2) num_threads(cores) proc_bind(spread) for(size_t k = 0u; k < outer; ++k){ @@ -664,7 +662,7 @@ inline void ttm( assert(p>2); assert(q>0); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // outer = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const outer = product(na, pia, qh+1,p+1); @@ -685,8 +683,8 @@ inline void ttm( auto wq = wa[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c const auto ompthreads = unsigned (double(cores)*ratio); @@ -731,15 +729,15 @@ inline void ttm( mtm_rm(q, p, a, na, pia, b, nb, c, nc ); } else { - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); auto const nnq = product(na, pia, 1, qh); auto const m = nc[q-1]; auto const nq = na[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c if(is_cm) loops_over_gemm_with_subtensors(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc); else loops_over_gemm_with_subtensors(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc); @@ -770,15 +768,15 @@ inline void ttm( mtm_rm(q, p, a, na, pia, b, nb, c, nc ); } else { - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); auto const nnq = product(na, pia, 1, qh); auto const m = nc[q-1]; auto const nq = na[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c if(is_cm) loops_over_gemm_with_subtensors(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc); else loops_over_gemm_with_subtensors(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc); @@ -812,7 +810,7 @@ inline void ttm( assert(q>0); assert(p>2); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // num = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const num = product(na, pia, qh+1,p+1); @@ -828,8 +826,8 @@ inline void ttm( auto nq = na[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c for(size_t k = 0u; k < num; ++k){ auto aa = a+k*waq; @@ -868,7 +866,7 @@ inline void ttm( assert(q>0); assert(p>2); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // num = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const num = product(na, pia, qh+1,p+1); @@ -884,8 +882,8 @@ inline void ttm( auto nq = na[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c #pragma omp parallel for schedule(static) num_threads(cores) proc_bind(spread) for(size_t k = 0u; k < num; ++k){ @@ -931,7 +929,7 @@ inline void ttm( assert(q>0); assert(p>2); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // num = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const num = product(na, pia, qh+1,p+1); @@ -947,8 +945,8 @@ inline void ttm( auto nq = na[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c #pragma omp parallel for schedule(dynamic) num_threads(cores) proc_bind(spread) for(size_t k = 0u; k < num; ++k){ @@ -990,7 +988,7 @@ inline void ttm( assert(q>0); assert(p>2); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // num = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto const num = product(na, pia, qh+1,p+1); @@ -1006,8 +1004,8 @@ inline void ttm( auto nq = na[q-1]; using namespace std::placeholders; - auto gemm_col = std::bind(tlib::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c - auto gemm_row = std::bind(tlib::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c + auto gemm_col = std::bind(tlib::ttm::detail::gemm_col_tr2::run,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c + auto gemm_row = std::bind(tlib::ttm::detail::gemm_row:: run,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c const auto ompthreads = unsigned (double(cores)*ratio); const auto blasthreads = unsigned (double(cores)*(1.0-ratio)); @@ -1053,7 +1051,7 @@ inline void ttm( assert(q>0); assert(p>2); #ifdef USE_MKL - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = inverse_mode(pia, pia+p, q); // num = n[pi[qh+1]] * n[pi[qh+2]] * ... * n[pi[p]] auto pp = product(na, pia, qh+1,p+1); @@ -1150,7 +1148,7 @@ inline void ttm( assert(q>0); assert(p>2); - auto const qh = tlib::detail::inverse_mode(pia, pia+p, q); + auto const qh = tlib::ttm::detail::inverse_mode(pia, pia+p, q); // inner = n[pi[2]] * ... * n[pi[qh-1]] with pi[qh] = q auto const inner = product(na, pia, 2, qh); @@ -1184,4 +1182,4 @@ inline void ttm( } -} // namespace tlib::detail +} // namespace tlib::ttm::detail diff --git a/include/tlib/ttm.h b/include/tlib/ttm.h index 7619059..18d30fe 100644 --- a/include/tlib/ttm.h +++ b/include/tlib/ttm.h @@ -21,7 +21,7 @@ #include "detail/tensor.h" #include "detail/tags.h" -namespace tlib +namespace tlib::ttm { @@ -60,6 +60,7 @@ inline void ttm( value_t *c, size_t const*const nc, size_t const*const wc ) { + using namespace tlib::ttm; if(p==0) throw std::runtime_error("Error in tlib::tensor_times_matrix: input tensor order should be greater zero."); if(q==0 || q>p) throw std::runtime_error("Error in tlib::tensor_times_matrix: contraction mode should be greater zero or less than or equal to p."); @@ -97,11 +98,10 @@ inline void ttm( * */ template -inline auto ttm( - std::size_t q, - tensor const& a, - tensor const& b, - execution_policy ep, slicing_policy sp, fusion_policy fp) +inline auto ttm(std::size_t q, + ttm::tensor const& a, + ttm::tensor const& b, + execution_policy ep, slicing_policy sp, fusion_policy fp) { auto const p = a.order(); @@ -131,8 +131,8 @@ inline auto ttm( * */ template -inline auto operator*(tlib::tensor_view const& a, tlib::tensor const& b) +inline auto operator*(tlib::ttm::tensor_view const& a, tlib::ttm::tensor const& b) { return ttm(a.contraction_mode(), a.get_tensor(), b, - tlib::parallel_policy::combined, tlib::slicing_policy::combined, tlib::fusion_policy::all) ; + tlib::ttm::parallel_policy::combined, tlib::ttm::slicing_policy::combined, tlib::ttm::fusion_policy::all) ; } diff --git a/test/src/gtest_tlib_layout.cpp b/test/src/gtest_tlib_layout.cpp index 57fa3d8..8e3f14e 100644 --- a/test/src/gtest_tlib_layout.cpp +++ b/test/src/gtest_tlib_layout.cpp @@ -23,173 +23,168 @@ #include +using namespace tlib::ttm; + class LayoutTest : public ::testing::Test { protected: - using layout_t = std::vector; - - void SetUp() override - { - layouts = - { - layout_t(1), // 1 - layout_t(2), // 2 - layout_t(3), // 3 - layout_t(4), // 4 - }; + using layout_t = std::vector; + + void SetUp() override + { + layouts = { layout_t(1), layout_t(2), layout_t(3), layout_t(4), }; } std::vector layouts; }; TEST_F(LayoutTest, generate_1_order) { - auto ref_layouts = std::vector - { - layout_t{1}, - layout_t{1,2}, - layout_t{1,2,3}, - layout_t{1,2,3,4} - }; - - ASSERT_TRUE(ref_layouts.size() == layouts.size()); - - for(auto i = 0u; i < layouts.size(); ++i){ - tlib::detail::compute_first_order_layout(layouts[i].begin(), layouts[i].end()); - ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); - EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); - - tlib::detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),1u); - ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); - EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); - EXPECT_TRUE (tlib::detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); - - EXPECT_TRUE (tlib::detail::is_first_order(layouts[i].begin(), layouts[i].end())); - if(i>0){ - EXPECT_TRUE (!tlib::detail::is_last_order(layouts[i].begin(), layouts[i].end())); - } - } + auto ref_layouts = std::vector + { + layout_t{1}, + layout_t{1,2}, + layout_t{1,2,3}, + layout_t{1,2,3,4} + }; + + ASSERT_TRUE(ref_layouts.size() == layouts.size()); + + for(auto i = 0u; i < layouts.size(); ++i){ + detail::compute_first_order_layout(layouts[i].begin(), layouts[i].end()); + ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); + EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); + + detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),1u); + ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); + EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); + EXPECT_TRUE (detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); + + EXPECT_TRUE (detail::is_first_order(layouts[i].begin(), layouts[i].end())); + if(i>0){ + EXPECT_TRUE (!detail::is_last_order(layouts[i].begin(), layouts[i].end())); + } + } } TEST_F(LayoutTest, generate_2_order) { - auto ref_layouts = std::vector - { - layout_t{1}, - layout_t{2,1}, - layout_t{2,1,3}, - layout_t{2,1,3,4} - }; - - ASSERT_TRUE(ref_layouts.size() == layouts.size()); - - for(auto i = 0u; i < layouts.size(); ++i){ - tlib::detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),2u); - ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); - EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); - EXPECT_TRUE (tlib::detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); - if(i>0){ - EXPECT_TRUE (!tlib::detail::is_first_order(layouts[i].begin(), layouts[i].end())); - } - if(i==1){ - EXPECT_TRUE (tlib::detail::is_last_order(layouts[i].begin(), layouts[i].end())); - } - } + auto ref_layouts = std::vector + { + layout_t{1}, + layout_t{2,1}, + layout_t{2,1,3}, + layout_t{2,1,3,4} + }; + + ASSERT_TRUE(ref_layouts.size() == layouts.size()); + + for(auto i = 0u; i < layouts.size(); ++i){ + detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),2u); + ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); + EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); + EXPECT_TRUE (detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); + if(i>0){ + EXPECT_TRUE (!detail::is_first_order(layouts[i].begin(), layouts[i].end())); + } + if(i==1){ + EXPECT_TRUE (detail::is_last_order(layouts[i].begin(), layouts[i].end())); + } + } } TEST_F(LayoutTest, generate_3_order) { - auto ref_layouts = std::vector - { - layout_t{1}, - layout_t{2,1}, - layout_t{3,2,1}, - layout_t{3,2,1,4} - }; - - ASSERT_TRUE(ref_layouts.size() == layouts.size()); - - for(auto i = 0u; i < layouts.size(); ++i){ - tlib::detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),3u); - ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); - EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); - EXPECT_TRUE (tlib::detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); - if(i>0){ - EXPECT_TRUE (!tlib::detail::is_first_order(layouts[i].begin(), layouts[i].end())); - } - if(i==1 || i == 2){ - EXPECT_TRUE (tlib::detail::is_last_order(layouts[i].begin(), layouts[i].end())); - } - if(i==3) { - EXPECT_TRUE (!tlib::detail::is_last_order(layouts[i].begin(), layouts[i].end())); - } - - } + auto ref_layouts = std::vector + { + layout_t{1}, + layout_t{2,1}, + layout_t{3,2,1}, + layout_t{3,2,1,4} + }; + + ASSERT_TRUE(ref_layouts.size() == layouts.size()); + + for(auto i = 0u; i < layouts.size(); ++i){ + detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),3u); + ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); + EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); + EXPECT_TRUE (detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); + if(i>0){ + EXPECT_TRUE (!detail::is_first_order(layouts[i].begin(), layouts[i].end())); + } + if(i==1 || i == 2){ + EXPECT_TRUE (detail::is_last_order(layouts[i].begin(), layouts[i].end())); + } + if(i==3) { + EXPECT_TRUE (!detail::is_last_order(layouts[i].begin(), layouts[i].end())); + } + } } TEST_F(LayoutTest, generate_4_order) { - auto ref_layouts = std::vector - { - layout_t{1}, - layout_t{2,1}, - layout_t{3,2,1}, - layout_t{4,3,2,1} - }; - - ASSERT_TRUE(ref_layouts.size() == layouts.size()); - - for(auto i = 0u; i < layouts.size(); ++i){ - tlib::detail::compute_last_order_layout(layouts[i].begin(), layouts[i].end()); - ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); - EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); - EXPECT_TRUE (tlib::detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); - - tlib::detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),4u); - ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); - EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); - EXPECT_TRUE (tlib::detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); - if(i>0) { - EXPECT_TRUE (!tlib::detail::is_first_order(layouts[i].begin(), layouts[i].end())); - } - EXPECT_TRUE (tlib::detail::is_last_order(layouts[i].begin(), layouts[i].end())); - } + auto ref_layouts = std::vector + { + layout_t{1}, + layout_t{2,1}, + layout_t{3,2,1}, + layout_t{4,3,2,1} + }; + + ASSERT_TRUE(ref_layouts.size() == layouts.size()); + + for(auto i = 0u; i < layouts.size(); ++i){ + detail::compute_last_order_layout(layouts[i].begin(), layouts[i].end()); + ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); + EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); + EXPECT_TRUE (detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); + + detail::compute_k_order_layout(layouts[i].begin(), layouts[i].end(),4u); + ASSERT_TRUE (layouts[i].size() == ref_layouts[i].size()); + EXPECT_TRUE (std::equal(layouts[i].begin(),layouts[i].end(),ref_layouts[i].begin())); + EXPECT_TRUE (detail::is_valid_layout(layouts[i].begin(), layouts[i].end())); + if(i>0) { + EXPECT_TRUE (!detail::is_first_order(layouts[i].begin(), layouts[i].end())); + } + EXPECT_TRUE (detail::is_last_order(layouts[i].begin(), layouts[i].end())); + } } TEST_F(LayoutTest, is_valid_layout) { - using layout_t = std::vector; - auto invalid_layouts = std::vector - { - {}, - {0}, - {0,1}, - {1,0}, - {0,2}, - {2,0}, - {2,1,0}, - {3,0,2}, - {3,1,4}, - {1,3,4}, - {1,3,5}, - }; - - for(auto const& invalid_layout : invalid_layouts) - { - EXPECT_FALSE ( tlib::detail::is_valid_layout(invalid_layout.begin(), invalid_layout.end()) ); - } - - auto valid_layouts = std::vector(); - - for(auto order = 1u; order <= 10u; ++order) - { - auto layout = layout_t(order,0); - for(auto format = 1u; format <= order; ++format) - { - tlib::detail::compute_k_order_layout(layout.begin(), layout.end(),format); - EXPECT_TRUE(tlib::detail::is_valid_layout(layout.begin(), layout.end())); - } - } + using layout_t = std::vector; + auto invalid_layouts = std::vector + { + {}, + {0}, + {0,1}, + {1,0}, + {0,2}, + {2,0}, + {2,1,0}, + {3,0,2}, + {3,1,4}, + {1,3,4}, + {1,3,5}, + }; + + for(auto const& invalid_layout : invalid_layouts) + { + EXPECT_FALSE ( detail::is_valid_layout(invalid_layout.begin(), invalid_layout.end()) ); + } + + auto valid_layouts = std::vector(); + + for(auto order = 1u; order <= 10u; ++order) + { + auto layout = layout_t(order,0); + for(auto format = 1u; format <= order; ++format) + { + detail::compute_k_order_layout(layout.begin(), layout.end(),format); + EXPECT_TRUE(detail::is_valid_layout(layout.begin(), layout.end())); + } + } } @@ -197,20 +192,20 @@ TEST_F(LayoutTest, is_valid_layout) TEST_F(LayoutTest, inverse_mode) { - for(auto order = 1u; order <= 10u; ++order) - { - auto layout = layout_t(order,0); - for(auto format = 1u; format <= order; ++format) - { - tlib::detail::compute_k_order_layout(layout.begin(), layout.end(),format); - ASSERT_TRUE(tlib::detail::is_valid_layout(layout.begin(), layout.end())); - for(auto mode = 1u; mode <= order; ++mode) - { - auto r = tlib::detail::inverse_mode(layout.begin(), layout.end(), mode); - ASSERT_TRUE(r>=1); - ASSERT_TRUE(r<=order); - EXPECT_TRUE(layout[r-1]==mode); - } - } - } + for(auto order = 1u; order <= 10u; ++order) + { + auto layout = layout_t(order,0); + for(auto format = 1u; format <= order; ++format) + { + detail::compute_k_order_layout(layout.begin(), layout.end(),format); + ASSERT_TRUE(detail::is_valid_layout(layout.begin(), layout.end())); + for(auto mode = 1u; mode <= order; ++mode) + { + auto r = detail::inverse_mode(layout.begin(), layout.end(), mode); + ASSERT_TRUE(r>=1); + ASSERT_TRUE(r<=order); + EXPECT_TRUE(layout[r-1]==mode); + } + } + } } diff --git a/test/src/gtest_tlib_mtm.cpp b/test/src/gtest_tlib_mtm.cpp index 9025f9d..12442ae 100644 --- a/test/src/gtest_tlib_mtm.cpp +++ b/test/src/gtest_tlib_mtm.cpp @@ -26,6 +26,7 @@ #include "gtest_aux.h" +using namespace tlib::ttm; template [[nodiscard]] matrix_type mtm(matrix_type const& a, matrix_type const& b) @@ -290,7 +291,7 @@ TEST(MatrixTimesVector, Ref) auto c = mtv(a,b); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(a.p(),q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -305,7 +306,7 @@ TEST(MatrixTimesVector, Ref) auto c = vtm(a,b); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(a.p(),q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -348,7 +349,7 @@ TEST(MatrixTimesMatrix, Ref) auto c = mtm(a,b); - auto qh = tlib::detail::inverse_mode(f.begin(),f.end(),q); + auto qh = detail::inverse_mode(f.begin(),f.end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -391,11 +392,10 @@ TEST(MatrixTimesMatrix, Case1) tlib::gtest::init(b,2u); - tlib::detail::mtm_rm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_rm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); for(auto i = 0ul; i < m; ++i) EXPECT_FLOAT_EQ(c[i], refc(b,i,2u) ); @@ -422,11 +422,10 @@ TEST(MatrixTimesMatrix, Case1) tlib::gtest::init(b,2u); - tlib::detail::mtm_cm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_cm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); for(auto i = 0ul; i < m; ++i) EXPECT_FLOAT_EQ(c[i], refc(b,i,2u) ); @@ -461,7 +460,7 @@ TEST(MatrixTimesMatrix, Case2) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<2>(p,q,cm.data())); + ASSERT_TRUE(detail::is_case<2>(p,q,cm.data())); auto a = matrix({m,n}, 1.0, cm); @@ -473,13 +472,12 @@ TEST(MatrixTimesMatrix, Case2) tlib::gtest::init(a,q); - tlib::detail::mtm_rm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_rm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(pia.begin(),pia.end(),q); + auto qh = detail::inverse_mode(pia.begin(),pia.end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -491,7 +489,7 @@ TEST(MatrixTimesMatrix, Case2) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<2>(p,q,cm.data())); + ASSERT_TRUE(detail::is_case<2>(p,q,cm.data())); auto a = matrix({m,n}, 1.0, cm); @@ -503,13 +501,12 @@ TEST(MatrixTimesMatrix, Case2) tlib::gtest::init(a,q); - tlib::detail::mtm_cm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_cm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(pia.begin(),pia.end(),q); + auto qh = detail::inverse_mode(pia.begin(),pia.end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -542,7 +539,7 @@ TEST(MatrixTimesMatrix, Case3) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<3>(p,q,cm.data())); + ASSERT_TRUE(detail::is_case<3>(p,q,cm.data())); auto a = matrix({m,n}, 1.0, cm); @@ -554,13 +551,12 @@ TEST(MatrixTimesMatrix, Case3) tlib::gtest::init(a,q); - tlib::detail::mtm_rm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_rm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -572,7 +568,7 @@ TEST(MatrixTimesMatrix, Case3) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<3>(p,q,cm.data())); + ASSERT_TRUE(detail::is_case<3>(p,q,cm.data())); auto a = matrix({m,n}, 1.0, cm); @@ -584,13 +580,12 @@ TEST(MatrixTimesMatrix, Case3) tlib::gtest::init(a,q); - tlib::detail::mtm_cm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_cm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } } @@ -623,7 +618,7 @@ TEST(MatrixTimesMatrix, Case4) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<4>(p,q,rm.data())); + ASSERT_TRUE(detail::is_case<4>(p,q,rm.data())); auto a = matrix({m,n}, 1.0, rm); @@ -635,13 +630,12 @@ TEST(MatrixTimesMatrix, Case4) tlib::gtest::init(a,q); - tlib::detail::mtm_rm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_rm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -653,7 +647,7 @@ TEST(MatrixTimesMatrix, Case4) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<4>(p,q,rm.data())); + ASSERT_TRUE(detail::is_case<4>(p,q,rm.data())); auto a = matrix({m,n}, 1.0, rm); @@ -665,13 +659,12 @@ TEST(MatrixTimesMatrix, Case4) tlib::gtest::init(a,q); - tlib::detail::mtm_cm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_cm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } } @@ -703,7 +696,7 @@ TEST(MatrixTimesMatrix, Case5) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<5>(p,q,rm.data())); + ASSERT_TRUE(detail::is_case<5>(p,q,rm.data())); auto a = matrix({m,n}, 1.0, rm); @@ -715,13 +708,12 @@ TEST(MatrixTimesMatrix, Case5) tlib::gtest::init(a,q); - tlib::detail::mtm_rm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_rm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -733,7 +725,7 @@ TEST(MatrixTimesMatrix, Case5) auto n = na[1]; auto u = m*2; - ASSERT_TRUE(tlib::detail::is_case<5>(p,q,rm.data())); + ASSERT_TRUE(detail::is_case<5>(p,q,rm.data())); auto a = matrix({m,n}, 1.0, rm); @@ -745,13 +737,12 @@ TEST(MatrixTimesMatrix, Case5) tlib::gtest::init(a,q); - tlib::detail::mtm_rm( - q,p, - a.data(), na.data(), pia.data(), - b.data(), nb.data(), - c.data(), nc.data()); + detail::mtm_rm(q,p, + a.data(), na.data(), pia.data(), + b.data(), nb.data(), + c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } @@ -798,13 +789,13 @@ TEST(MatrixTimesMatrix, Case6) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<6>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<6>(p,q,pia.data())); tlib::gtest::init(a,q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -825,13 +816,13 @@ TEST(MatrixTimesMatrix, Case6) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<6>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<6>(p,q,pia.data())); tlib::gtest::init(a,q); - tlib::detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -859,13 +850,13 @@ TEST(MatrixTimesMatrix, Case6) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<6>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<6>(p,q,pia.data())); tlib::gtest::init(a, q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); @@ -886,13 +877,13 @@ TEST(MatrixTimesMatrix, Case6) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<6>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<6>(p,q,pia.data())); tlib::gtest::init(a, q); - tlib::detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); @@ -922,13 +913,13 @@ TEST(MatrixTimesMatrix, Case6) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<6>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<6>(p,q,pia.data())); init(a, q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -949,13 +940,13 @@ TEST(MatrixTimesMatrix, Case6) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<6>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<6>(p,q,pia.data())); init(a, q); - tlib::detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1005,13 +996,13 @@ TEST(MatrixTimesMatrix, Case7) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); tlib::gtest::init(a,q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1032,13 +1023,13 @@ TEST(MatrixTimesMatrix, Case7) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); tlib::gtest::init(a,q); - tlib::detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1066,13 +1057,13 @@ TEST(MatrixTimesMatrix, Case7) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); tlib::gtest::init(a, q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1091,13 +1082,13 @@ TEST(MatrixTimesMatrix, Case7) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); tlib::gtest::init(a, q); - tlib::detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1125,13 +1116,13 @@ TEST(MatrixTimesMatrix, Case7) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); tlib::gtest::init(a, q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1152,13 +1143,13 @@ TEST(MatrixTimesMatrix, Case7) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); tlib::gtest::init(a, q); - tlib::detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_cm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); - auto qh = tlib::detail::inverse_mode(a.pi().begin(),a.pi().end(),q); + auto qh = detail::inverse_mode(a.pi().begin(),a.pi().end(),q); ttm_check(p,q,qh, 0ul,0ul, a.container(),a.n(),a.w(),a.pi(), c.container(),c.n(),c.w(),c.pi()); } // layouts @@ -1212,11 +1203,11 @@ TEST(MatrixTimesMatrix, Case8) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); init(a,q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); //std::cout << "q = " << q << std::endl; //std::cout << "a = " << a << std::endl; @@ -1254,11 +1245,11 @@ TEST(MatrixTimesMatrix, Case8) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); init(a, q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); //std::cout << "q = " << q << std::endl; //std::cout << "a = " << a << std::endl; @@ -1299,11 +1290,11 @@ TEST(MatrixTimesMatrix, Case8) const auto nc = c.n(); const auto pia = a.pi(); - ASSERT_TRUE(tlib::detail::is_case<7>(p,q,pia.data())); + ASSERT_TRUE(detail::is_case<7>(p,q,pia.data())); init(a, q); - tlib::detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); + detail::mtm_rm( q,p, a.data(), na.data(), pia.data(), b.data(), nb.data(), c.data(), nc.data()); //std::cout << "q = " << q << std::endl; //std::cout << "a = " << a << std::endl; diff --git a/test/src/gtest_tlib_shape.cpp b/test/src/gtest_tlib_shape.cpp index df69fa9..c5300f8 100644 --- a/test/src/gtest_tlib_shape.cpp +++ b/test/src/gtest_tlib_shape.cpp @@ -24,7 +24,7 @@ #include - +using namespace tlib::ttm; class ShapeTest : public ::testing::Test { protected: @@ -64,76 +64,76 @@ class ShapeTest : public ::testing::Test { TEST_F(ShapeTest, is_scalar) { - auto ints = std::vector{2,6,15}; - - for(auto i : ints){ - EXPECT_TRUE (tlib::detail::is_scalar(shapes[i].begin(), shapes[i].end())); - } - - for(auto i = 0u; i < shapes.size(); ++i){ - if(std::find(ints.begin(), ints.end(),i)==ints.end()){ - EXPECT_FALSE(tlib::detail::is_scalar(shapes[i].begin(), shapes[i].end())); - } - } + auto ints = std::vector{2,6,15}; + + for(auto i : ints){ + EXPECT_TRUE (detail::is_scalar(shapes[i].begin(), shapes[i].end())); + } + + for(auto i = 0u; i < shapes.size(); ++i){ + if(std::find(ints.begin(), ints.end(),i)==ints.end()){ + EXPECT_FALSE(detail::is_scalar(shapes[i].begin(), shapes[i].end())); + } + } } TEST_F(ShapeTest, is_vector) { - auto ints = std::vector{3,7,8,9,10,17,19}; - - for(auto i : ints ){ - EXPECT_TRUE (tlib::detail::is_vector(shapes[i].begin(), shapes[i].end())); - } - for(auto i = 0u; i < shapes.size(); ++i ){ - if(std::find(ints.begin(), ints.end(),i)==ints.end()){ - EXPECT_FALSE(tlib::detail::is_vector(shapes[i].begin(), shapes[i].end())); - } - } + auto ints = std::vector{3,7,8,9,10,17,19}; + + for(auto i : ints){ + EXPECT_TRUE (detail::is_vector(shapes[i].begin(), shapes[i].end())); + } + for(auto i = 0u; i < shapes.size(); ++i ){ + if(std::find(ints.begin(), ints.end(),i)==ints.end()){ + EXPECT_FALSE(detail::is_vector(shapes[i].begin(), shapes[i].end())); + } + } } TEST_F(ShapeTest, is_matrix) { - auto ints = std::vector{11,12,21}; - - for(auto i : ints ){ - EXPECT_TRUE (tlib::detail::is_matrix(shapes[i].begin(), shapes[i].end())); - } - for(auto i = 0u; i < shapes.size(); ++i ){ - if(std::find(ints.begin(), ints.end(),i)==ints.end()){ - EXPECT_FALSE(tlib::detail::is_matrix(shapes[i].begin(), shapes[i].end())); - } - } + auto ints = std::vector{11,12,21}; + + for(auto i : ints ){ + EXPECT_TRUE (detail::is_matrix(shapes[i].begin(), shapes[i].end())); + } + for(auto i = 0u; i < shapes.size(); ++i ){ + if(std::find(ints.begin(), ints.end(),i)==ints.end()){ + EXPECT_FALSE(detail::is_matrix(shapes[i].begin(), shapes[i].end())); + } + } } TEST_F(ShapeTest, is_tensor) { - auto ints = std::vector {16,18,20,22}; - - for(auto i : ints ){ - EXPECT_TRUE (tlib::detail::is_tensor(shapes[i].begin(), shapes[i].end())); - } - for(auto i = 0u; i < shapes.size(); ++i ){ - if(std::find(ints.begin(), ints.end(),i)==ints.end()){ - EXPECT_FALSE(tlib::detail::is_tensor(shapes[i].begin(), shapes[i].end())); - } - } + auto ints = std::vector {16,18,20,22}; + + for(auto i : ints ){ + EXPECT_TRUE (detail::is_tensor(shapes[i].begin(), shapes[i].end())); + } + for(auto i = 0u; i < shapes.size(); ++i ){ + if(std::find(ints.begin(), ints.end(),i)==ints.end()){ + EXPECT_FALSE(detail::is_tensor(shapes[i].begin(), shapes[i].end())); + } + } } TEST_F(ShapeTest, is_valid) { - auto ints = std::vector {0,1,4,5,13,14}; - - for(auto i : ints ){ - EXPECT_FALSE(tlib::detail::is_valid_shape(shapes[i].begin(), shapes[i].end())); - } - for(auto i = 0u; i < shapes.size(); ++i ){ - if(std::find(ints.begin(), ints.end(),i)==ints.end()){ - EXPECT_TRUE(tlib::detail::is_valid_shape(shapes[i].begin(), shapes[i].end())); - } - } + auto ints = std::vector {0,1,4,5,13,14}; + + for(auto i : ints ){ + EXPECT_FALSE(detail::is_valid_shape(shapes[i].begin(), shapes[i].end())); + } + for(auto i = 0u; i < shapes.size(); ++i ){ + if(std::find(ints.begin(), ints.end(),i)==ints.end()){ + EXPECT_TRUE(detail::is_valid_shape(shapes[i].begin(), shapes[i].end())); + } + } } diff --git a/test/src/gtest_tlib_strides.cpp b/test/src/gtest_tlib_strides.cpp index 8ac5012..f468ddd 100644 --- a/test/src/gtest_tlib_strides.cpp +++ b/test/src/gtest_tlib_strides.cpp @@ -29,38 +29,40 @@ using extents_t = std::vector; using layout_t = std::vector; using strides_t = std::vector; +using namespace tlib::ttm; TEST(StridesTest, ScalarShape) { - auto l1 = layout_t{1}; - auto v1 = tlib::detail::generate_strides(extents_t{1},l1); - ASSERT_EQ(v1.size(),1u); - EXPECT_EQ(1u, v1[0u]); - ASSERT_TRUE(tlib::detail::is_valid_strides(l1.begin(), l1.end(), v1.begin())); + + auto l1 = layout_t{1}; + auto v1 = detail::generate_strides(extents_t{1},l1); + ASSERT_EQ(v1.size(),1u); + EXPECT_EQ(1u, v1[0u]); + ASSERT_TRUE(detail::is_valid_strides(l1.begin(), l1.end(), v1.begin())); } TEST(StridesTest, VectorShape) { - auto l0 = layout_t{1,2}; - auto v0 = tlib::detail::generate_strides(extents_t{1,1},l0); - ASSERT_EQ(v0.size(),2u); - EXPECT_EQ(1u, v0[0u]); - EXPECT_EQ(1u, v0[1u]); - ASSERT_TRUE(tlib::detail::is_valid_strides(l0.begin(), l0.end(), v0.begin())); - - auto extents = std::vector{ extents_t{1,2}, {1,4}, {1,8}, {2,1}, {4,1}, {8,1} }; - auto layouts = std::vector{ layout_t {1,2}, {2,1} }; - - for(auto const& extent : extents){ - ASSERT_TRUE(tlib::detail::is_vector(extent.begin(), extent.end())); - for(auto const& layout : layouts){ - auto v = tlib::detail::generate_strides(extent,layout); - ASSERT_EQ(v.size(),2u); - EXPECT_EQ(1u, v[0u]); - EXPECT_EQ(1u, v[1u]); - ASSERT_TRUE(tlib::detail::is_valid_strides(layout.begin(), layout.end(),v.begin())); - } - } + auto l0 = layout_t{1,2}; + auto v0 = detail::generate_strides(extents_t{1,1},l0); + ASSERT_EQ(v0.size(),2u); + EXPECT_EQ(1u, v0[0u]); + EXPECT_EQ(1u, v0[1u]); + ASSERT_TRUE(detail::is_valid_strides(l0.begin(), l0.end(), v0.begin())); + + auto extents = std::vector{ extents_t{1,2}, {1,4}, {1,8}, {2,1}, {4,1}, {8,1} }; + auto layouts = std::vector{ layout_t {1,2}, {2,1} }; + + for(auto const& extent : extents){ + ASSERT_TRUE(detail::is_vector(extent.begin(), extent.end())); + for(auto const& layout : layouts){ + auto v = detail::generate_strides(extent,layout); + ASSERT_EQ(v.size(),2u); + EXPECT_EQ(1u, v[0u]); + EXPECT_EQ(1u, v[1u]); + ASSERT_TRUE(detail::is_valid_strides(layout.begin(), layout.end(),v.begin())); + } + } } @@ -70,21 +72,21 @@ TEST(StridesTest, MatrixShape) auto layouts = std::vector{ layout_t {1,2}, {2,1} }; for(auto const& extent : extents){ - auto strides = tlib::detail::generate_strides(extent,layouts[0]); + auto strides = detail::generate_strides(extent,layouts[0]); ASSERT_EQ(strides.size(),2u); - ASSERT_TRUE(tlib::detail::is_matrix(extent.begin(), extent.end())); + ASSERT_TRUE(detail::is_matrix(extent.begin(), extent.end())); EXPECT_EQ(1u, strides[0u]); EXPECT_EQ(extent[0], strides[1u]); - ASSERT_TRUE(tlib::detail::is_valid_strides(layouts[0].begin(), layouts[0].end(),strides.begin())); + ASSERT_TRUE(detail::is_valid_strides(layouts[0].begin(), layouts[0].end(),strides.begin())); } for(auto const& extent : extents){ - auto strides = tlib::detail::generate_strides(extent,layouts[1]); + auto strides = detail::generate_strides(extent,layouts[1]); ASSERT_EQ(strides.size(),2u); - ASSERT_TRUE(tlib::detail::is_matrix(extent.begin(), extent.end())); + ASSERT_TRUE(detail::is_matrix(extent.begin(), extent.end())); EXPECT_EQ(extent[1], strides[0u]); EXPECT_EQ(1u, strides[1u]); - ASSERT_TRUE(tlib::detail::is_valid_strides(layouts[1].begin(), layouts[1].end(), strides.begin())); + ASSERT_TRUE(detail::is_valid_strides(layouts[1].begin(), layouts[1].end(), strides.begin())); } } @@ -96,11 +98,11 @@ TEST(StridesTest, TensorShape) ASSERT_EQ(strides.size(),p); ASSERT_EQ(ref_strides.size(),p); ASSERT_EQ(extents.size(),p); - ASSERT_TRUE(tlib::detail::is_tensor(extents.begin(), extents.end())); + ASSERT_TRUE(detail::is_tensor(extents.begin(), extents.end())); for(auto i = 0u; i < p; ++i) EXPECT_EQ(ref_strides[i], strides[i]); - ASSERT_TRUE(tlib::detail::is_valid_strides(layout.begin(), layout.end(), strides.begin())); + ASSERT_TRUE(detail::is_valid_strides(layout.begin(), layout.end(), strides.begin())); }; // first-order @@ -110,7 +112,7 @@ TEST(StridesTest, TensorShape) auto layout = layout_t {1,2,3}; for(auto i = 0u; i < extents.size(); ++i){ - auto strides = tlib::detail::generate_strides(extents[i],layout); + auto strides = detail::generate_strides(extents[i],layout); test_strides(extents[i].size(),extents[i],strides,ref_strides[i],layout); } } @@ -122,7 +124,7 @@ TEST(StridesTest, TensorShape) auto layout = layout_t {3,2,1}; for(auto i = 0u; i < extents.size(); ++i){ - auto strides = tlib::detail::generate_strides(extents[i],layout); + auto strides = detail::generate_strides(extents[i],layout); test_strides(extents[i].size(),extents[i],strides,ref_strides[i],layout); } } @@ -134,7 +136,7 @@ TEST(StridesTest, TensorShape) auto layout = layout_t {2,1,3}; for(auto i = 0u; i < extents.size(); ++i){ - auto strides = tlib::detail::generate_strides(extents[i],layout); + auto strides = detail::generate_strides(extents[i],layout); test_strides(extents[i].size(),extents[i],strides,ref_strides[i],layout); } } diff --git a/test/src/gtest_tlib_ttm.cpp b/test/src/gtest_tlib_ttm.cpp index 8b995bb..c89f71f 100644 --- a/test/src/gtest_tlib_ttm.cpp +++ b/test/src/gtest_tlib_ttm.cpp @@ -22,6 +22,7 @@ #include #include +using namespace tlib::ttm; template inline void @@ -60,7 +61,7 @@ inline void ttm_init( assert(p>=2); assert(1<=q && q <= p); - const size_type qh = tlib::detail::inverse_mode(pia.begin(), pia.end(), q ); + const size_type qh = detail::inverse_mode(pia.begin(), pia.end(), q ); assert(1<=qh && qh <= p); size_type k = 0ul; @@ -155,7 +156,7 @@ inline void check_tensor_times_matrix(const size_type init, const size_type step for(auto const& na : shapes) { - assert(tlib::detail::is_valid_shape(na.begin(), na.end())); + assert(detail::is_valid_shape(na.begin(), na.end())); auto nna = std::accumulate(na.begin(),na.end(), 1ul, std::multiplies()); auto a = std::vector(nna,value_type{}); @@ -166,11 +167,11 @@ inline void check_tensor_times_matrix(const size_type init, const size_type step for(auto const& pia : layouts) { - assert(tlib::detail::is_valid_layout(pia.begin(), pia.end())); + assert(detail::is_valid_layout(pia.begin(), pia.end())); - auto wa = tlib::detail::generate_strides (na ,pia ); + auto wa = detail::generate_strides (na ,pia ); - assert(tlib::detail::is_valid_strides(pia.begin(), pia.end(),wa.begin())); + assert(detail::is_valid_strides(pia.begin(), pia.end(),wa.begin())); // std::cout <<"pia = [ "; std::copy(pia.begin(), pia.end(), std::ostream_iterator(std::cout, " ")); std::cout <<"];" << std::endl; @@ -197,24 +198,26 @@ inline void check_tensor_times_matrix(const size_type init, const size_type step assert(nb.at(1) == nq); - auto pic = pia; // tlib::detail::generate_output_layout(pia,q); - auto nc = na; // tlib::detail::generate_output_shape (na ,q); + auto pic = pia; // detail::generate_output_layout(pia,q); + auto nc = na; // detail::generate_output_shape (na ,q); nc.at(q-1) = m; - auto wa = tlib::detail::generate_strides(na,pia); - auto wc = tlib::detail::generate_strides(nc,pic); + auto wa = detail::generate_strides(na,pia); + auto wc = detail::generate_strides(nc,pic); auto nnc = std::accumulate(nc.begin(),nc.end(), 1ul, std::multiplies()); auto c = std::vector(nnc,value_type{}); - auto qh = tlib::detail::inverse_mode(pia.begin(), pia.end(), q); + auto qh = detail::inverse_mode(pia.begin(), pia.end(), q); { - tlib::ttm(ep,sp,fp, q,p, - a.data(), na.data(), wa.data(), pia.data(), - b.data(), nb.data(), rm.data(), - c.data(), nc.data(), wc.data()); + ttm(ep,sp,fp, + q,p, + a.data(), na.data(), wa.data(), pia.data(), + b.data(), nb.data(), rm.data(), + c.data(), nc.data(), wc.data()); + bool test = ttm_check(p,q,qh, 0u,0u, a, na, wa, pia, c, nc, wc, pic); @@ -223,10 +226,12 @@ inline void check_tensor_times_matrix(const size_type init, const size_type step { - tlib::ttm(ep,sp,fp, q,p, - a.data(), na.data(), wa.data(), pia.data(), - b.data(), nb.data(), cm.data(), - c.data(), nc.data(), wc.data()); + ttm(ep,sp,fp, + q,p, + a.data(), na.data(), wa.data(), pia.data(), + b.data(), nb.data(), cm.data(), + c.data(), nc.data(), wc.data()); + bool test = ttm_check(p,q,qh, 0u,0u, a, na, wa, pia, c, nc, wc, pic); @@ -256,90 +261,90 @@ inline void check_tensor_times_matrix(const size_type init, const size_type step TEST(TensorTimesMatrix, SequentialSliceNoFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::sequential_t; - using slicing_policy = tlib::slicing_policy::slice_t; - using fusion_policy = tlib::fusion_policy::none_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::sequential_t; + using sp = slicing_policy::slice_t; + using fp = fusion_policy::none_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelGemmSliceNoFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_blas_t; - using slicing_policy = tlib::slicing_policy::slice_t; - using fusion_policy = tlib::fusion_policy::none_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_blas_t; + using sp = slicing_policy::slice_t; + using fp = fusion_policy::none_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelGemmSubtensorNoFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_blas_t; - using slicing_policy = tlib::slicing_policy::subtensor_t; - using fusion_policy = tlib::fusion_policy::none_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_blas_t; + using sp = slicing_policy::subtensor_t; + using fp = fusion_policy::none_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelLoopSliceOuterFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_loop_t; - using slicing_policy = tlib::slicing_policy::slice_t; - using fusion_policy = tlib::fusion_policy::outer_t; + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_loop_t; + using sp = slicing_policy::slice_t; + using fp = fusion_policy::outer_t; - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelLoopSliceAllFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_loop_t; - using slicing_policy = tlib::slicing_policy::slice_t; - using fusion_policy = tlib::fusion_policy::all_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_loop_t; + using sp = slicing_policy::slice_t; + using fp = fusion_policy::all_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelLoopParallelGemmSliceAllFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_loop_blas_t; - using slicing_policy = tlib::slicing_policy::slice_t; - using fusion_policy = tlib::fusion_policy::all_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_loop_blas_t; + using sp = slicing_policy::slice_t; + using fp = fusion_policy::all_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } @@ -347,45 +352,45 @@ TEST(TensorTimesMatrix, ParallelLoopParallelGemmSliceAllFusion) TEST(TensorTimesMatrix, SequentialSubtensorNoFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::sequential_t; - using slicing_policy = tlib::slicing_policy::subtensor_t; - using fusion_policy = tlib::fusion_policy::none_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::sequential_t; + using sp = slicing_policy::subtensor_t; + using fp = fusion_policy::none_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelLoopSubtensorOuterFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_loop_t; - using slicing_policy = tlib::slicing_policy::subtensor_t; - using fusion_policy = tlib::fusion_policy::all_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_loop_t; + using sp = slicing_policy::subtensor_t; + using fp = fusion_policy::all_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } TEST(TensorTimesMatrix, ParallelLoopParallelGemmSubtensorOuterFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::parallel_loop_blas_t; - using slicing_policy = tlib::slicing_policy::subtensor_t; - using fusion_policy = tlib::fusion_policy::all_t; - - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + using vt = double; + using st = std::size_t; + using ep = parallel_policy::parallel_loop_blas_t; + using sp = slicing_policy::subtensor_t; + using fp = fusion_policy::all_t; + + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } @@ -393,17 +398,17 @@ TEST(TensorTimesMatrix, ParallelLoopParallelGemmSubtensorOuterFusion) TEST(TensorTimesMatrix, BatchedGemmSubtensorOuterFusion) { - using value_type = double; - using size_type = std::size_t; - using execution_policy = tlib::parallel_policy::batched_gemm_t; - using slicing_policy = tlib::slicing_policy::subtensor_t; - using fusion_policy = tlib::fusion_policy::all_t; + using vt = double; + using st = std::size_t; + using ep = parallel_policy::batched_gemm_t; + using sp = slicing_policy::subtensor_t; + using fp = fusion_policy::all_t; - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); - check_tensor_times_matrix(2u,3); -// check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); + check_tensor_times_matrix(2u,3); +// check_tensor_times_matrix(2u,3); } #endif diff --git a/ttmpy/src/wrapped_ttm.cpp b/ttmpy/src/wrapped_ttm.cpp index 0f6aca4..69d9e03 100644 --- a/ttmpy/src/wrapped_ttm.cpp +++ b/ttmpy/src/wrapped_ttm.cpp @@ -34,16 +34,16 @@ ttm(std::size_t const contraction_mode, auto const& ainfo = a.request(); // request a buffer descriptor from Python of type py::buffer_info auto const p = std::size_t(ainfo.ndim); //py::ssize_t - if(p==0) throw std::invalid_argument("Error calling ttmpy::ttm: first input should be a tensor with order greater than zero."); - if(q==0 || q>p) throw std::invalid_argument("Error calling ttmpy::ttm: contraction mode should be greater than zero and at most equal to p."); - + if(p==0) throw std::invalid_argument("Error calling ttmpy::ttm: first input should be a tensor with order greater than zero."); + if(q==0 || q>p) throw std::invalid_argument("Error calling ttmpy::ttm: contraction mode should be greater than zero and at most equal to p."); + auto const*const aptr = static_cast(ainfo.ptr); // extract data an shape of input array auto na = std::vector(ainfo.shape .begin(), ainfo.shape .end()); auto wa = std::vector(ainfo.strides.begin(), ainfo.strides.end()); std::for_each(wa.begin(), wa.end(), [sizeofT](auto& w){w/=sizeofT;}); - auto pia = tlib::detail::generate_k_order_layout(p, p); - auto pib = tlib::detail::generate_k_order_layout(2ul, 2ul); + auto pia = tlib::ttm::detail::generate_k_order_layout(p, p); + auto pib = tlib::ttm::detail::generate_k_order_layout(2ul, 2ul); auto const& binfo = b.request(); // request a buffer descriptor from Python of type py::buffer_info auto const*const bptr = static_cast(binfo.ptr); // extract data an shape of input array @@ -52,26 +52,25 @@ ttm(std::size_t const contraction_mode, if(pb!=2) throw std::invalid_argument("Error calling ttmpy::ttm: second input should be a mtrix with order equal to 2."); + auto nc = na; // tlib::detail::generate_output_shape (na ,q); + nc[q-1] = nb[0]; + auto const& pic = pia; + //auto const pic = tlib::detail::generate_output_layout(pia,q); + auto wc = tlib::ttm::detail::generate_strides(nc,pic); + auto nc_ = std::vector(nc.begin(),nc.end()); + auto wc_ = std::vector(wc.begin(),wc.end()); + std::for_each(wc_.begin(), wc_.end(), [sizeofT](auto& w){w*=sizeofT;}); - auto nc = na; // tlib::detail::generate_output_shape (na ,q); - nc[q-1] = nb[0]; - auto const& pic = pia; - //auto const pic = tlib::detail::generate_output_layout(pia,q); - auto wc = tlib::detail::generate_strides(nc,pic); - auto nc_ = std::vector(nc.begin(),nc.end()); - auto wc_ = std::vector(wc.begin(),wc.end()); - std::for_each(wc_.begin(), wc_.end(), [sizeofT](auto& w){w*=sizeofT;}); - - auto c = py::array_t(nc_,wc_); - auto const& cinfo = c.request(); // request a buffer descriptor from Python of type py::buffer_info - auto* cptr = static_cast(cinfo.ptr); // extract data an shape of input array + auto c = py::array_t(nc_,wc_); + auto const& cinfo = c.request(); // request a buffer descriptor from Python of type py::buffer_info + auto* cptr = static_cast(cinfo.ptr); // extract data an shape of input array // auto nnc = std::size_t(cinfo.size); - tlib::ttm(tlib::parallel_policy::combined, tlib::slicing_policy::combined, tlib::fusion_policy::all, - q, p, - aptr, na.data(), wa.data(), pia.data(), - bptr, nb.data(), pib.data(), - cptr, nc.data(), wc.data()); + tlib::ttm::ttm(tlib::ttm::parallel_policy::combined, tlib::ttm::slicing_policy::combined, tlib::ttm::fusion_policy::all, + q, p, + aptr, na.data(), wa.data(), pia.data(), + bptr, nb.data(), pib.data(), + cptr, nc.data(), wc.data()); return c; }