diff --git a/btas/generic/axpy_impl.h b/btas/generic/axpy_impl.h index 8eb991e8..34f4a897 100644 --- a/btas/generic/axpy_impl.h +++ b/btas/generic/axpy_impl.h @@ -52,8 +52,12 @@ template<> struct axpy_impl _IteratorY itrY, const iterator_difference_t<_IteratorY>& incY, blas_lapack_impl_tag) { - blas::axpy( Nsize, alpha, static_cast(&(*itrX)), incX, - static_cast<_T*>(&(*itrY)), incY ); + static_assert(std::is_same_v,iterator_value_t<_IteratorY>>, + "mismatching iterator value types"); + using T = iterator_value_t<_IteratorX>; + + blas::axpy( Nsize, static_cast(alpha), static_cast(&(*itrX)), incX, + static_cast(&(*itrY)), incY ); } #endif diff --git a/btas/generic/dot_impl.h b/btas/generic/dot_impl.h index e23186f9..8d84b6e0 100644 --- a/btas/generic/dot_impl.h +++ b/btas/generic/dot_impl.h @@ -110,19 +110,13 @@ struct dotc_impl { _IteratorY itrY, const iterator_difference_t<_IteratorY>& incY, blas_lapack_impl_tag) { - - using x_traits = std::iterator_traits<_IteratorX>; - using y_traits = std::iterator_traits<_IteratorY>; - - using x_value_type = typename x_traits::value_type; - using y_value_type = typename y_traits::value_type; - - using x_ptr_type = const x_value_type*; - using y_ptr_type = const y_value_type*; + static_assert(std::is_same_v,iterator_value_t<_IteratorY>>, + "mismatching iterator value types"); + using T = iterator_value_t<_IteratorX>; // XXX: DOTC == DOT in BLASPP - return blas::dot( Nsize, static_cast(&(*itrX)), incX, - static_cast(&(*itrY)), incY ); + return blas::dot( Nsize, static_cast(&(*itrX)), incX, + static_cast(&(*itrY)), incY ); } #endif @@ -172,17 +166,12 @@ struct dotu_impl { blas_lapack_impl_tag) { - using x_traits = std::iterator_traits<_IteratorX>; - using y_traits = std::iterator_traits<_IteratorY>; - - using x_value_type = typename x_traits::value_type; - using y_value_type = typename y_traits::value_type; - - using x_ptr_type = const x_value_type*; - using y_ptr_type = const y_value_type*; + static_assert(std::is_same_v,iterator_value_t<_IteratorY>>, + "mismatching iterator value types"); + using T = iterator_value_t<_IteratorX>; - return blas::dotu( Nsize, static_cast(&(*itrX)), incX, - static_cast(&(*itrY)), incY ); + return blas::dotu( Nsize, static_cast(&(*itrX)), incX, + static_cast(&(*itrY)), incY ); } #endif diff --git a/btas/generic/gemm_impl.h b/btas/generic/gemm_impl.h index d939083c..da7f10fc 100644 --- a/btas/generic/gemm_impl.h +++ b/btas/generic/gemm_impl.h @@ -252,24 +252,16 @@ template<> struct gemm_impl const unsigned long& LDC, blas_lapack_impl_tag) { - - using a_traits = std::iterator_traits<_IteratorA>; - using b_traits = std::iterator_traits<_IteratorB>; - using c_traits = std::iterator_traits<_IteratorC>; - - using a_value_type = typename a_traits::value_type; - using b_value_type = typename b_traits::value_type; - using c_value_type = typename c_traits::value_type; - - using a_ptr_type = const a_value_type*; - using b_ptr_type = const b_value_type*; - using c_ptr_type = c_value_type*; - - blas::gemm( order, transA, transB, Msize, Nsize, Ksize, alpha, - static_cast(&(*itrA)), LDA, - static_cast(&(*itrB)), LDB, - beta, - static_cast(&(*itrC)), LDC ); + static_assert(std::is_same_v,iterator_value_t<_IteratorB>> && + std::is_same_v,iterator_value_t<_IteratorC>>, + "mismatching iterator value types"); + using T = iterator_value_t<_IteratorA>; + + blas::gemm( order, transA, transB, Msize, Nsize, Ksize, static_cast(alpha), + static_cast(&(*itrA)), LDA, + static_cast(&(*itrB)), LDB, + static_cast(beta), + static_cast(&(*itrC)), LDC ); } #endif diff --git a/btas/generic/gemv_impl.h b/btas/generic/gemv_impl.h index aeadf415..0414d4e6 100644 --- a/btas/generic/gemv_impl.h +++ b/btas/generic/gemv_impl.h @@ -151,11 +151,16 @@ template<> struct gemv_impl blas_lapack_impl_tag) { - blas::gemv( order, transA, Msize, Nsize, alpha, - static_cast(&(*itrA)), LDA, - static_cast(&(*itrX)), incX, - beta, - static_cast< _T*>(&(*itrY)), incY ); + static_assert(std::is_same_v,iterator_value_t<_IteratorY>> && + std::is_same_v,iterator_value_t<_IteratorA>>, + "mismatching iterator value types"); + using T = iterator_value_t<_IteratorX>; + + blas::gemv( order, transA, Msize, Nsize, static_cast(alpha), + static_cast(&(*itrA)), LDA, + static_cast(&(*itrX)), incX, + static_cast(beta), + static_cast< T*>(&(*itrY)), incY ); } #endif diff --git a/btas/generic/ger_impl.h b/btas/generic/ger_impl.h index 88f9cda1..3a747c67 100644 --- a/btas/generic/ger_impl.h +++ b/btas/generic/ger_impl.h @@ -79,11 +79,15 @@ template<> struct ger_impl const unsigned long& LDA, blas_lapack_impl_tag) { - - blas::geru( order, Msize, Nsize, alpha, - static_cast(&(*itrX)), incX, - static_cast(&(*itrY)), incY, - static_cast< _T*>(&*(itrA)), LDA ); + static_assert(std::is_same_v,iterator_value_t<_IteratorY>> && + std::is_same_v,iterator_value_t<_IteratorA>>, + "mismatching iterator value types"); + using T = iterator_value_t<_IteratorX>; + + blas::geru( order, Msize, Nsize, static_cast(alpha), + static_cast(&(*itrX)), incX, + static_cast(&(*itrY)), incY, + static_cast< T*>(&*(itrA)), LDA ); } #endif diff --git a/btas/generic/gesvd_impl.h b/btas/generic/gesvd_impl.h index 5f026cce..0e821956 100644 --- a/btas/generic/gesvd_impl.h +++ b/btas/generic/gesvd_impl.h @@ -73,8 +73,6 @@ template<> struct gesvd_impl if( inplaceU and inplaceVt ) BTAS_EXCEPTION("SVD cannot return both vectors inplace"); - - value_type dummy; value_type* A = static_cast(&(*itrA)); value_type* U = (needU and not inplaceU) ? diff --git a/btas/generic/scal_impl.h b/btas/generic/scal_impl.h index 0f1cbc73..f184f7a9 100644 --- a/btas/generic/scal_impl.h +++ b/btas/generic/scal_impl.h @@ -44,7 +44,8 @@ template<> struct scal_impl _IteratorX itrX, const iterator_difference_t<_IteratorX>& incX, blas_lapack_impl_tag) { - blas::scal( Nsize, alpha, static_cast<_T*>(&(*itrX)), incX ); + using T = iterator_value_t<_IteratorX>; + blas::scal( Nsize, static_cast(alpha), static_cast(&(*itrX)), incX ); } #endif diff --git a/btas/type_traits.h b/btas/type_traits.h index af073bd8..af466612 100644 --- a/btas/type_traits.h +++ b/btas/type_traits.h @@ -275,6 +275,9 @@ namespace btas { // Convienience traits template + using iterator_value_t = + typename std::iterator_traits<_Iterator>::value_type; + template using iterator_difference_t = typename std::iterator_traits<_Iterator>::difference_type; diff --git a/unittest/tensor_blas_test.cc b/unittest/tensor_blas_test.cc index 2e7d5fff..74113f6f 100644 --- a/unittest/tensor_blas_test.cc +++ b/unittest/tensor_blas_test.cc @@ -112,7 +112,7 @@ TEST_CASE("Tensor Scal") Tensor T(4,2,6,5); T.generate([](){ return randomReal(); }); Tensor Tbak=T; - double d = randomReal(); + const auto d = randomReal(); // N.B. use different types for scalar and tensor scal(d,T); double res=0; for(auto i : T.range()) res+=std::abs(T(i)-Tbak(i)*d); @@ -136,10 +136,11 @@ TEST_CASE("Tensor Scal") Tensor> T(4,2,6,5); T.generate([](){ return randomCplx(); }); Tensor> Tbak=T; - std::complex d = randomCplx(); + const auto d = randomCplx(); // N.B. use different types for scalar and tensor scal(d,T); double res=0; - for(auto i : T.range()) res+=std::abs(T(i)-Tbak(i)*d); + std::complex d_double(d); + for(auto i : T.range()) res+=std::abs(T(i)-Tbak(i)*d_double); CHECK(res < eps_double); } @@ -166,7 +167,7 @@ TEST_CASE("Tensor Axpy") X.generate([](){ return randomReal(); }); Y.generate([](){ return randomReal(); }); Tensor Ybak=Y; - double alpha = randomReal(); + const auto alpha = randomReal(); // N.B. use different types for scalar and tensor axpy(alpha,X,Y); double res=0; for(auto i : Y.range()) res+=std::abs(Ybak(i)+X(i)*alpha-Y(i)); @@ -228,7 +229,7 @@ TEST_CASE("Tensor Ger") X.generate([](){ return randomReal(); }); Y.generate([](){ return randomReal(); }); Tensor Abak=A; - double a = randomReal(); + const auto a = randomReal(); // N.B. use different types for scalar and tensor ger(a,X,Y,A); double res=0; for(auto i : A.range()) res+=std::abs(a*X(i[0],i[1])*Y(i[2],i[3])+Abak(i)-A(i)); @@ -296,8 +297,8 @@ TEST_CASE("Tensor Gemv") A.generate([](){ return randomReal(); }); X.generate([](){ return randomReal(); }); Y.generate([](){ return randomReal(); }); - double alpha = randomReal(); - double beta = randomReal(); + const auto alpha = randomReal(); // N.B. use different types for scalar and tensor + const auto beta = randomReal(); // N.B. use different types for scalar and tensor Tensor Ytest=Y; scal(beta,Ytest); for(long i=0;i(); }); B.generate([](){ return randomReal(); }); C.generate([](){ return randomReal(); }); - double alpha = randomReal(); - double beta = randomReal(); + const auto alpha = randomReal(); // N.B. use different types for scalar and tensor + const auto beta = randomReal(); // N.B. use different types for scalar and tensor Ctest=C; scal(beta,Ctest); contract(alpha,A,{'i','j','k'},B,{'k','j','l','m'},beta,C,{'i','m','l'}); @@ -700,7 +701,7 @@ TEST_CASE("Contraction") Tensor D; Tensor Dtest(2,4,6); Dtest.fill(0.0); - contract(alpha,A,{'i','j','k'},B,{'k','j','l','m'},0.0,D,{'i','m','l'}); + contract(alpha,A,{'i','j','k'},B,{'k','j','l','m'},static_cast(0.0),D,{'i','m','l'}); for(long i=0;i