Skip to content

Commit

Permalink
Initial commit (#5790)
Browse files Browse the repository at this point in the history
  • Loading branch information
Critsium-xy authored Jan 1, 2025
1 parent d7b76fc commit 3a2eb06
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 5 deletions.
143 changes: 142 additions & 1 deletion source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
}

// C = a * A.? * B.? + b * C
// Row-Major part
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
Expand Down Expand Up @@ -154,6 +155,147 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
#endif
}

// Col-Major part
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sgemm_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dgemm_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
cgemm_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zgemm_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}

// Symm and Hemm part. Only col-major is supported.

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
ssymm_(&side, &uplo, &m, &n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
dsymm_(&side, &uplo, &m, &n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
csymm_(&side, &uplo, &m, &n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
}

void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zsymm_(&side, &uplo, &m, &n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
}

void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
chemm_(&side, &uplo, &m, &n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
}

void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
zhemm_(&side, &uplo, &m, &n,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
Expand Down Expand Up @@ -190,7 +332,6 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
}
}


// out = ||x||_2
float BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
{
Expand Down
74 changes: 72 additions & 2 deletions source/module_base/blas_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,23 @@ extern "C"
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);

//a is symmetric
// A is symmetric. C = a * A.? * B.? + b * C
void ssymm_(const char *side, const char *uplo, const int *m, const int *n,
const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
const float *beta, float *c, const int *ldc);
void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
const double *beta, double *c, const int *ldc);
//a is hermitian
void csymm_(const char *side, const char *uplo, const int *m, const int *n,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda, const std::complex<float> *b, const int *ldb,
const std::complex<float> *beta, std::complex<float> *c, const int *ldc);
void zsymm_(const char *side, const char *uplo, const int *m, const int *n,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);

// A is hermitian. C = a * A.? * B.? + b * C
void chemm_(char *side, char *uplo, int *m, int *n,std::complex<float> *alpha,
std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, std::complex<float> *beta, std::complex<float> *c, int *ldc);
void zhemm_(char *side, char *uplo, int *m, int *n,std::complex<double> *alpha,
std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, std::complex<double> *beta, std::complex<double> *c, int *ldc);

Expand Down Expand Up @@ -175,6 +187,7 @@ class BlasConnector

// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
// C = a * A.? * B.? + b * C
// Row Major by default
static
void gemm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
Expand All @@ -195,6 +208,61 @@ class BlasConnector
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// Col-Major if you need to use it

static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// Because you cannot pack symm or hemm into a row-major kernel by exchanging parameters, so only col-major functions are provided.
static
void symm_cm(const char side, const char uplo, const int m, const int n,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void symm_cm(const char side, const char uplo, const int m, const int n,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void symm_cm(const char side, const char uplo, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void symm_cm(const char side, const char uplo, const int m, const int n,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void hemm_cm(char side, char uplo, int m, int n,
std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void hemm_cm(char side, char uplo, int m, int n,
std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// y = A*x + beta*y

static
void gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
Expand Down Expand Up @@ -234,6 +302,8 @@ class BlasConnector

static
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// A is symmetric
};

// If GATHER_INFO is defined, the original function is replaced with a "i" suffix,
Expand Down
5 changes: 3 additions & 2 deletions source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "gint_tools.h"
#include "module_base/timer.h"
#include "module_base/ylm.h"
#include "module_base/blas_connector.h"

namespace Gint_Tools{

Expand Down Expand Up @@ -60,8 +61,8 @@ void mult_psi_DMR(

const auto tmp_matrix_ptr = tmp_matrix->get_pointer();
const int idx1 = block_index[ia1];
dsymm_(&side, &uplo, &block_size[ia1], &ib_len, &alpha, tmp_matrix_ptr, &block_size[ia1],
&psi[ib_start][idx1], &LD_pool, &beta, &psi_DMR[ib_start][idx1], &LD_pool);
BlasConnector::symm_cm(side, uplo, block_size[ia1], ib_len, alpha, tmp_matrix_ptr, block_size[ia1],
&psi[ib_start][idx1], LD_pool, beta, &psi_DMR[ib_start][idx1], LD_pool);
}

//! get (j,beta,R2)
Expand Down

0 comments on commit 3a2eb06

Please sign in to comment.