Skip to content

Commit

Permalink
support for mixed typing of scalars/data in blas/lapack-like function…
Browse files Browse the repository at this point in the history
…s + better type checking
  • Loading branch information
evaleev committed Oct 10, 2023
1 parent 4321ef3 commit a02be0d
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 64 deletions.
8 changes: 6 additions & 2 deletions btas/generic/axpy_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ template<> struct axpy_impl<true>
_IteratorY itrY, const iterator_difference_t<_IteratorY>& incY,
blas_lapack_impl_tag)
{
blas::axpy( Nsize, alpha, static_cast<const _T*>(&(*itrX)), incX,
static_cast<_T*>(&(*itrY)), incY );
static_assert(std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorY>>,
"mismatching iterator value types");
using T = iterator_value_t<_IteratorX>;

blas::axpy( Nsize, static_cast<T>(alpha), static_cast<const T*>(&(*itrX)), incX,
static_cast<T*>(&(*itrY)), incY );
}
#endif

Expand Down
31 changes: 10 additions & 21 deletions btas/generic/dot_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,13 @@ struct dotc_impl<true> {
_IteratorY itrY, const iterator_difference_t<_IteratorY>& incY,
blas_lapack_impl_tag)
{

using x_traits = std::iterator_traits<_IteratorX>;
using y_traits = std::iterator_traits<_IteratorY>;

using x_value_type = typename x_traits::value_type;
using y_value_type = typename y_traits::value_type;

using x_ptr_type = const x_value_type*;
using y_ptr_type = const y_value_type*;
static_assert(std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorY>>,
"mismatching iterator value types");
using T = iterator_value_t<_IteratorX>;

// XXX: DOTC == DOT in BLASPP
return blas::dot( Nsize, static_cast<x_ptr_type>(&(*itrX)), incX,
static_cast<y_ptr_type>(&(*itrY)), incY );
return blas::dot( Nsize, static_cast<const T*>(&(*itrX)), incX,
static_cast<const T*>(&(*itrY)), incY );

}
#endif
Expand Down Expand Up @@ -172,17 +166,12 @@ struct dotu_impl<true> {
blas_lapack_impl_tag)
{

using x_traits = std::iterator_traits<_IteratorX>;
using y_traits = std::iterator_traits<_IteratorY>;

using x_value_type = typename x_traits::value_type;
using y_value_type = typename y_traits::value_type;

using x_ptr_type = const x_value_type*;
using y_ptr_type = const y_value_type*;
static_assert(std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorY>>,
"mismatching iterator value types");
using T = iterator_value_t<_IteratorX>;

return blas::dotu( Nsize, static_cast<x_ptr_type>(&(*itrX)), incX,
static_cast<y_ptr_type>(&(*itrY)), incY );
return blas::dotu( Nsize, static_cast<const T*>(&(*itrX)), incX,
static_cast<const T*>(&(*itrY)), incY );

}
#endif
Expand Down
28 changes: 10 additions & 18 deletions btas/generic/gemm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,24 +252,16 @@ template<> struct gemm_impl<true>
const unsigned long& LDC,
blas_lapack_impl_tag)
{

using a_traits = std::iterator_traits<_IteratorA>;
using b_traits = std::iterator_traits<_IteratorB>;
using c_traits = std::iterator_traits<_IteratorC>;

using a_value_type = typename a_traits::value_type;
using b_value_type = typename b_traits::value_type;
using c_value_type = typename c_traits::value_type;

using a_ptr_type = const a_value_type*;
using b_ptr_type = const b_value_type*;
using c_ptr_type = c_value_type*;

blas::gemm( order, transA, transB, Msize, Nsize, Ksize, alpha,
static_cast<a_ptr_type>(&(*itrA)), LDA,
static_cast<b_ptr_type>(&(*itrB)), LDB,
beta,
static_cast<c_ptr_type>(&(*itrC)), LDC );
static_assert(std::is_same_v<iterator_value_t<_IteratorA>,iterator_value_t<_IteratorB>> &&
std::is_same_v<iterator_value_t<_IteratorA>,iterator_value_t<_IteratorC>>,
"mismatching iterator value types");
using T = iterator_value_t<_IteratorA>;

blas::gemm( order, transA, transB, Msize, Nsize, Ksize, static_cast<T>(alpha),
static_cast<const T*>(&(*itrA)), LDA,
static_cast<const T*>(&(*itrB)), LDB,
static_cast<T>(beta),
static_cast<T*>(&(*itrC)), LDC );
}
#endif

Expand Down
15 changes: 10 additions & 5 deletions btas/generic/gemv_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ template<> struct gemv_impl<true>
blas_lapack_impl_tag)
{

blas::gemv( order, transA, Msize, Nsize, alpha,
static_cast<const _T*>(&(*itrA)), LDA,
static_cast<const _T*>(&(*itrX)), incX,
beta,
static_cast< _T*>(&(*itrY)), incY );
static_assert(std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorY>> &&
std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorA>>,
"mismatching iterator value types");
using T = iterator_value_t<_IteratorX>;

blas::gemv( order, transA, Msize, Nsize, static_cast<T>(alpha),
static_cast<const T*>(&(*itrA)), LDA,
static_cast<const T*>(&(*itrX)), incX,
static_cast<T>(beta),
static_cast< T*>(&(*itrY)), incY );

}
#endif
Expand Down
14 changes: 9 additions & 5 deletions btas/generic/ger_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,15 @@ template<> struct ger_impl<true>
const unsigned long& LDA,
blas_lapack_impl_tag)
{

blas::geru( order, Msize, Nsize, alpha,
static_cast<const _T*>(&(*itrX)), incX,
static_cast<const _T*>(&(*itrY)), incY,
static_cast< _T*>(&*(itrA)), LDA );
static_assert(std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorY>> &&
std::is_same_v<iterator_value_t<_IteratorX>,iterator_value_t<_IteratorA>>,
"mismatching iterator value types");
using T = iterator_value_t<_IteratorX>;

blas::geru( order, Msize, Nsize, static_cast<T>(alpha),
static_cast<const T*>(&(*itrX)), incX,
static_cast<const T*>(&(*itrY)), incY,
static_cast< T*>(&*(itrA)), LDA );
}
#endif

Expand Down
2 changes: 0 additions & 2 deletions btas/generic/gesvd_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ template<> struct gesvd_impl<true>
if( inplaceU and inplaceVt )
BTAS_EXCEPTION("SVD cannot return both vectors inplace");



value_type dummy;
value_type* A = static_cast<value_type*>(&(*itrA));
value_type* U = (needU and not inplaceU) ?
Expand Down
3 changes: 2 additions & 1 deletion btas/generic/scal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ template<> struct scal_impl<true>
_IteratorX itrX, const iterator_difference_t<_IteratorX>& incX,
blas_lapack_impl_tag)
{
blas::scal( Nsize, alpha, static_cast<_T*>(&(*itrX)), incX );
using T = iterator_value_t<_IteratorX>;
blas::scal( Nsize, static_cast<T>(alpha), static_cast<T*>(&(*itrX)), incX );
}
#endif

Expand Down
3 changes: 3 additions & 0 deletions btas/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ namespace btas {

// Convienience traits
template <typename _Iterator>
using iterator_value_t =
typename std::iterator_traits<_Iterator>::value_type;
template <typename _Iterator>
using iterator_difference_t =
typename std::iterator_traits<_Iterator>::difference_type;

Expand Down
21 changes: 11 additions & 10 deletions unittest/tensor_blas_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ TEST_CASE("Tensor Scal")
Tensor<double> T(4,2,6,5);
T.generate([](){ return randomReal<double>(); });
Tensor<double> Tbak=T;
double d = randomReal<double>();
const auto d = randomReal<float>(); // N.B. use different types for scalar and tensor
scal(d,T);
double res=0;
for(auto i : T.range()) res+=std::abs(T(i)-Tbak(i)*d);
Expand All @@ -136,10 +136,11 @@ TEST_CASE("Tensor Scal")
Tensor<std::complex<double>> T(4,2,6,5);
T.generate([](){ return randomCplx<double>(); });
Tensor<std::complex<double>> Tbak=T;
std::complex<double> d = randomCplx<double>();
const auto d = randomCplx<float>(); // N.B. use different types for scalar and tensor
scal(d,T);
double res=0;
for(auto i : T.range()) res+=std::abs(T(i)-Tbak(i)*d);
std::complex<double> d_double(d);
for(auto i : T.range()) res+=std::abs(T(i)-Tbak(i)*d_double);
CHECK(res < eps_double);
}

Expand All @@ -166,7 +167,7 @@ TEST_CASE("Tensor Axpy")
X.generate([](){ return randomReal<double>(); });
Y.generate([](){ return randomReal<double>(); });
Tensor<double> Ybak=Y;
double alpha = randomReal<double>();
const auto alpha = randomReal<float>(); // N.B. use different types for scalar and tensor
axpy(alpha,X,Y);
double res=0;
for(auto i : Y.range()) res+=std::abs(Ybak(i)+X(i)*alpha-Y(i));
Expand Down Expand Up @@ -228,7 +229,7 @@ TEST_CASE("Tensor Ger")
X.generate([](){ return randomReal<double>(); });
Y.generate([](){ return randomReal<double>(); });
Tensor<double> Abak=A;
double a = randomReal<double>();
const auto a = randomReal<float>(); // N.B. use different types for scalar and tensor
ger(a,X,Y,A);
double res=0;
for(auto i : A.range()) res+=std::abs(a*X(i[0],i[1])*Y(i[2],i[3])+Abak(i)-A(i));
Expand Down Expand Up @@ -296,8 +297,8 @@ TEST_CASE("Tensor Gemv")
A.generate([](){ return randomReal<double>(); });
X.generate([](){ return randomReal<double>(); });
Y.generate([](){ return randomReal<double>(); });
double alpha = randomReal<double>();
double beta = randomReal<double>();
const auto alpha = randomReal<float>(); // N.B. use different types for scalar and tensor
const auto beta = randomReal<float>(); // N.B. use different types for scalar and tensor
Tensor<double> Ytest=Y;
scal(beta,Ytest);
for(long i=0;i<A.extent(0);i++)
Expand Down Expand Up @@ -681,8 +682,8 @@ TEST_CASE("Contraction")
A.generate([](){ return randomReal<double>(); });
B.generate([](){ return randomReal<double>(); });
C.generate([](){ return randomReal<double>(); });
double alpha = randomReal<double>();
double beta = randomReal<double>();
const auto alpha = randomReal<float>(); // N.B. use different types for scalar and tensor
const auto beta = randomReal<float>(); // N.B. use different types for scalar and tensor
Ctest=C;
scal(beta,Ctest);
contract(alpha,A,{'i','j','k'},B,{'k','j','l','m'},beta,C,{'i','m','l'});
Expand All @@ -700,7 +701,7 @@ TEST_CASE("Contraction")
Tensor<double> D;
Tensor<double> Dtest(2,4,6);
Dtest.fill(0.0);
contract(alpha,A,{'i','j','k'},B,{'k','j','l','m'},0.0,D,{'i','m','l'});
contract(alpha,A,{'i','j','k'},B,{'k','j','l','m'},static_cast<decltype(alpha)>(0.0),D,{'i','m','l'});
for(long i=0;i<A.extent(0);i++)
for(long j=0;j<A.extent(1);j++)
for(long k=0;k<A.extent(2);k++)
Expand Down

0 comments on commit a02be0d

Please sign in to comment.