Skip to content

Commit

Permalink
最后调整一个逻辑
Browse files — browse the repository at this point in the history
  • Loading branch information
wu-kan committed Apr 23, 2021
1 parent 70650ee commit f9c4a4c
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/blas/HPLAI_blas.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -1002,14 +1002,34 @@ void blas::gemm<HPLAI_T_AFLOAT, HPLAI_T_AFLOAT, HPLAI_T_AFLOAT>(
hBsize = (K1 * N1 * sizeof(aclFloat16) + 63) / 32 * 32,
hCsize = (M1 * N1 * sizeof(aclFloat16) + 63) / 32 * 32;

if (HPLAI_ACL_BLASPP_HOST_BUFFER_SIZE < sAsize + sBsize)
HPLAI_ACL_BLASPP_HOST_BUFFER_RESIZE(sAsize + sBsize);
int64_t device_buffer_size = hCsize + hBsize + hAsize + sAsize + sBsize;
if (device_buffer_size < hCsize + sCsize)
device_buffer_size = hCsize + sCsize;
if (HPLAI_ACL_BLASPP_DEVICE_BUFFER_SIZE < device_buffer_size)
HPLAI_ACL_BLASPP_DEVICE_BUFFER_RESIZE(device_buffer_size);

if (HPLAI_ACL_BLASPP_HOST_BUFFER_SIZE < sCsize)
HPLAI_ACL_BLASPP_HOST_BUFFER_RESIZE(sCsize);
char *hCdevice = reinterpret_cast<char *>(HPLAI_ACL_BLASPP_DEVICE_BUFFER);
char *hBdevice = hCdevice + hCsize;
char *hAdevice = hBdevice + hBsize;
char *sCdevice = hCdevice + hCsize;
char *sAdevice = hAdevice + hAsize;
char *sBdevice = sAdevice + sAsize;
char *sChost = sCdevice;
char *sAhost = sAdevice;
char *sBhost = sBdevice;

char *sChost = reinterpret_cast<char *>(HPLAI_ACL_BLASPP_HOST_BUFFER),
*sAhost = sChost, *sBhost = sAhost + sAsize;
if (HPLAI_ACL_BLASPP_RUNMODE == ACL_HOST)
{
int64_t host_buffer_size = sAsize + sBsize;
if (host_buffer_size < sCsize)
host_buffer_size = sCsize;

if (HPLAI_ACL_BLASPP_HOST_BUFFER_SIZE < device_buffer_size)
HPLAI_ACL_BLASPP_HOST_BUFFER_RESIZE(device_buffer_size);
sChost = reinterpret_cast<char *>(HPLAI_ACL_BLASPP_HOST_BUFFER),
sAhost = sChost;
sBhost = sAhost + sAsize;
}

if (TRANSA == blas::Op::NoTrans)
{
Expand All @@ -1032,18 +1052,6 @@ void blas::gemm<HPLAI_T_AFLOAT, HPLAI_T_AFLOAT, HPLAI_T_AFLOAT>(
K1);
}

int64_t device_buffer_size = hCsize + hBsize + hAsize + sAsize + sBsize;
if (device_buffer_size < hCsize + sCsize)
device_buffer_size = hCsize + sCsize;
if (HPLAI_ACL_BLASPP_DEVICE_BUFFER_SIZE < device_buffer_size)
HPLAI_ACL_BLASPP_DEVICE_BUFFER_RESIZE(device_buffer_size);
char *hCdevice = reinterpret_cast<char *>(HPLAI_ACL_BLASPP_DEVICE_BUFFER);
char *hBdevice = hCdevice + hCsize;
char *hAdevice = hBdevice + hBsize;
char *sCdevice = hCdevice + hCsize;
char *sAdevice = hAdevice + hAsize;
char *sBdevice = sAdevice + sAsize;

if (HPLAI_ACL_BLASPP_RUNMODE == ACL_HOST)
ACLCHECK(aclrtMemcpyAsync(
reinterpret_cast<void *>(sAdevice),
Expand Down

0 comments on commit f9c4a4c

Please sign in to comment.