diff --git a/src/blas/HPLAI_blas.cc b/src/blas/HPLAI_blas.cc index 9c8d1c2..0669440 100644 --- a/src/blas/HPLAI_blas.cc +++ b/src/blas/HPLAI_blas.cc @@ -1002,14 +1002,34 @@ void blas::gemm( hBsize = (K1 * N1 * sizeof(aclFloat16) + 63) / 32 * 32, hCsize = (M1 * N1 * sizeof(aclFloat16) + 63) / 32 * 32; - if (HPLAI_ACL_BLASPP_HOST_BUFFER_SIZE < sAsize + sBsize) - HPLAI_ACL_BLASPP_HOST_BUFFER_RESIZE(sAsize + sBsize); + int64_t device_buffer_size = hCsize + hBsize + hAsize + sAsize + sBsize; + if (device_buffer_size < hCsize + sCsize) + device_buffer_size = hCsize + sCsize; + if (HPLAI_ACL_BLASPP_DEVICE_BUFFER_SIZE < device_buffer_size) + HPLAI_ACL_BLASPP_DEVICE_BUFFER_RESIZE(device_buffer_size); - if (HPLAI_ACL_BLASPP_HOST_BUFFER_SIZE < sCsize) - HPLAI_ACL_BLASPP_HOST_BUFFER_RESIZE(sCsize); + char *hCdevice = reinterpret_cast(HPLAI_ACL_BLASPP_DEVICE_BUFFER); + char *hBdevice = hCdevice + hCsize; + char *hAdevice = hBdevice + hBsize; + char *sCdevice = hCdevice + hCsize; + char *sAdevice = hAdevice + hAsize; + char *sBdevice = sAdevice + sAsize; + char *sChost = sCdevice; + char *sAhost = sAdevice; + char *sBhost = sBdevice; - char *sChost = reinterpret_cast(HPLAI_ACL_BLASPP_HOST_BUFFER), - *sAhost = sChost, *sBhost = sAhost + sAsize; + if (HPLAI_ACL_BLASPP_RUNMODE == ACL_HOST) + { + int64_t host_buffer_size = sAsize + sBsize; + if (host_buffer_size < sCsize) + host_buffer_size = sCsize; + + if (HPLAI_ACL_BLASPP_HOST_BUFFER_SIZE < device_buffer_size) + HPLAI_ACL_BLASPP_HOST_BUFFER_RESIZE(device_buffer_size); + sChost = reinterpret_cast(HPLAI_ACL_BLASPP_HOST_BUFFER), + sAhost = sChost; + sBhost = sAhost + sAsize; + } if (TRANSA == blas::Op::NoTrans) { @@ -1032,18 +1052,6 @@ void blas::gemm( K1); } - int64_t device_buffer_size = hCsize + hBsize + hAsize + sAsize + sBsize; - if (device_buffer_size < hCsize + sCsize) - device_buffer_size = hCsize + sCsize; - if (HPLAI_ACL_BLASPP_DEVICE_BUFFER_SIZE < device_buffer_size) - HPLAI_ACL_BLASPP_DEVICE_BUFFER_RESIZE(device_buffer_size); - char *hCdevice = reinterpret_cast(HPLAI_ACL_BLASPP_DEVICE_BUFFER); - char *hBdevice = hCdevice + hCsize; - char *hAdevice = hBdevice + hBsize; - char *sCdevice = hCdevice + hCsize; - char *sAdevice = hAdevice + hAsize; - char *sBdevice = sAdevice + sAsize; - if (HPLAI_ACL_BLASPP_RUNMODE == ACL_HOST) ACLCHECK(aclrtMemcpyAsync( reinterpret_cast(sAdevice),