Skip to content

Commit

Permalink
Fix incorrect type casting for alpha and beta in f16 compute type
Browse files (browse the repository at this point in the history)
  • Loading branch information
cliffxzx committed Dec 17, 2024
1 parent 4d40e36 commit 3600099
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 46 deletions.
79 changes: 54 additions & 25 deletions clients/common/cblas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ template <typename TcCast, typename Tc, typename TiA>
void cast_mul(customVector<TcCast>& dst,
const TiA* A,
bool isScaleAVec,
const Tc* scaleAVec,
const Tc* AlphaVec,
const TcCast* scaleAVec,
const TcCast* AlphaVec,
bool transA,
int64_t m,
int64_t k,
Expand Down Expand Up @@ -282,8 +282,8 @@ void cast_mul(customVector<TcCast>& dst,
const void* src,
hipDataType TiA,
bool isScaleAVec,
const Tc* scaleAVec,
const Tc* AlphaVec,
const TcCast* scaleAVec,
const TcCast* AlphaVec,
bool transA,
int64_t m,
int64_t k,
Expand Down Expand Up @@ -413,8 +413,8 @@ template <typename TcCast, typename Tc, typename TciACast, typename TiA>
void cast_mul_with_Tci(customVector<TcCast>& dst,
const TiA* A,
bool isScaleAVec,
const Tc* scaleAVec,
const Tc* AlphaVec,
const TcCast* scaleAVec,
const TcCast* AlphaVec,
bool transA,
int64_t m,
int64_t k,
Expand Down Expand Up @@ -489,8 +489,8 @@ void cast_mul_with_Tci(customVector<TcCast>& dst,
const void* src,
hipDataType TiA,
bool isScaleAVec,
const Tc* scaleAVec,
const Tc* AlphaVec,
const TcCast* scaleAVec,
const TcCast* AlphaVec,
bool transA,
int64_t m,
int64_t k,
Expand Down Expand Up @@ -627,8 +627,8 @@ void cast_mul_with_Tci(customVector<TcCast>& dst,
const void* src,
hipDataType TiA,
bool isScaleAVec,
const Tc* scaleAVec,
const Tc* AlphaVec,
const TcCast* scaleAVec,
const TcCast* AlphaVec,
bool transA,
int64_t m,
int64_t k,
Expand Down Expand Up @@ -880,9 +880,9 @@ void cblas_gemm(hipblasOperation_t transA,
Tc beta,
std::add_pointer_t<void> C,
int64_t ldc,
const Tc* AlphaVec,
const Tc* scaleAVec,
const Tc* scaleBVec,
const void* AlphaVec,
const void* scaleAVec,
const void* scaleBVec,
Tc scaleD,
bool isScaleAVec,
bool isScaleBVec,
Expand All @@ -894,8 +894,16 @@ void cblas_gemm(hipblasOperation_t transA,
hipDataType TciB,
bool alt)
{
using TcCast = std::conditional_t<std::is_same<Tc, int32_t>::value, double, Tc>;
Tc_enum = (Tc_enum == HIP_R_32I) ? HIP_R_64F : Tc_enum;
using IntTcCast = std::conditional_t<std::is_same<Tc, int32_t>::value, double, Tc>;
using HalfTcCast = std::conditional_t<std::is_same<Tc, hipblasLtHalf>::value, float, Tc>;
using TcCast = std::conditional_t<std::is_same<Tc, int32_t>::value, IntTcCast, HalfTcCast>;

if (Tc_enum == HIP_R_32I) {
Tc_enum = HIP_R_64F;
} else if (Tc_enum == HIP_R_16F) {
Tc_enum = HIP_R_32F;
}

hipDataType TciACast = (TciA == HIP_R_32I) ? HIP_R_64F : TciA;
hipDataType TciBCast = (TciB == HIP_R_32I) ? HIP_R_64F : TciB;

Expand All @@ -904,8 +912,28 @@ void cblas_gemm(hipblasOperation_t transA,
size_t sizeA = (transA == HIPBLAS_OP_N ? k : m) * size_t(lda);
size_t sizeB = (transB == HIPBLAS_OP_N ? n : k) * size_t(ldb);
size_t sizeC = n * size_t(ldc);
size_t scaleAVec_size = isScaleAVec ? (transA == HIPBLAS_OP_N ? k : m) : 1;
size_t scaleBVec_size = isScaleBVec ? (transB != HIPBLAS_OP_N ? n : k) : 1;

customVector<TcCast> A_Tc, B_Tc, C_Tc, AlphaVec_Tc, scaleA_Tc, scaleB_Tc;

customVector<TcCast> A_Tc, B_Tc, C_Tc;
if(AlphaVec)
{
AlphaVec_Tc.initialize(m);
cast_mul(AlphaVec_Tc, (Tc*)AlphaVec, m);
}

if(scaleAVec)
{
scaleA_Tc.initialize(scaleAVec_size);
cast_mul(scaleA_Tc, (Tc*)scaleAVec, scaleAVec_size);
}

if(scaleBVec)
{
scaleB_Tc.initialize(scaleBVec_size);
cast_mul(scaleB_Tc, (Tc*)scaleBVec, scaleBVec_size);
}

A_Tc.initialize(sizeA);
if(realDataTypeSize(TiA) > realDataTypeSize(TciACast))
Expand All @@ -914,8 +942,8 @@ void cblas_gemm(hipblasOperation_t transA,
A,
TiA,
isScaleAVec,
scaleAVec,
AlphaVec,
scaleA_Tc,
AlphaVec_Tc,
transA == HIPBLAS_OP_N,
m,
k,
Expand All @@ -925,7 +953,7 @@ void cblas_gemm(hipblasOperation_t transA,
else
{
cast_mul<TcCast, Tc>(
A_Tc, A, TiA, isScaleAVec, scaleAVec, AlphaVec, transA == HIPBLAS_OP_N, m, k, sizeA);
A_Tc, A, TiA, isScaleAVec, scaleA_Tc, AlphaVec_Tc, transA == HIPBLAS_OP_N, m, k, sizeA);
}

B_Tc.initialize(sizeB);
Expand All @@ -935,7 +963,7 @@ void cblas_gemm(hipblasOperation_t transA,
B,
TiB,
isScaleBVec,
scaleBVec,
scaleB_Tc,
nullptr,
transB != HIPBLAS_OP_N,
n,
Expand All @@ -946,7 +974,7 @@ void cblas_gemm(hipblasOperation_t transA,
else
{
cast_mul<TcCast, Tc>(
B_Tc, B, TiB, isScaleBVec, scaleBVec, nullptr, transB != HIPBLAS_OP_N, n, k, sizeB);
B_Tc, B, TiB, isScaleBVec, scaleB_Tc, nullptr, transB != HIPBLAS_OP_N, n, k, sizeB);
}

if(To == Tc_enum)
Expand Down Expand Up @@ -993,7 +1021,7 @@ void cblas_gemm(hipblasOperation_t transA,
{
static constexpr int64_t small = 600; // seeing random NaNs with blis on some small sizes
if(m > small || n > small || k > small || lda > small || ldb > small || ldc > small)
{
{
cblas_dgemm(CblasColMajor,
HIPOperationToCBLASTanspose(transA),
HIPOperationToCBLASTanspose(transB),
Expand Down Expand Up @@ -1042,9 +1070,9 @@ void cblas_gemm(hipblasOperation_t transA,
Tc beta, \
std::add_pointer_t<void> C, \
int64_t ldc, \
const Tc* AlphaVec, \
const Tc* scaleAVec, \
const Tc* scaleBVec, \
const void* AlphaVec, \
const void* scaleAVec, \
const void* scaleBVec, \
Tc scaleD, \
bool isScaleAVec, \
bool isScaleBVec, \
Expand All @@ -1056,6 +1084,7 @@ void cblas_gemm(hipblasOperation_t transA,
hipDataType TciB, \
bool alt);

CREATEFUNCTION(hipblasLtHalf)
CREATEFUNCTION(float)
CREATEFUNCTION(double)
CREATEFUNCTION(int32_t)
53 changes: 40 additions & 13 deletions clients/include/cblas_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ void cblas_gemm(hipblasOperation_t transA,
Tc beta,
std::add_pointer_t<void> C,
int64_t ldc,
const Tc* AlphaVec,
const Tc* scaleA,
const Tc* scaleB,
const void* AlphaVec,
const void* scaleA,
const void* scaleB,
Tc scaleD,
bool isScaleAVec,
bool isScaleBVec,
Expand Down Expand Up @@ -96,6 +96,33 @@ inline void cblas_gemm(hipblasOperation_t transA,
switch(tc)
{
case HIP_R_16F: // setting compute_type to f16_r will fallback to f32_r
cblas_gemm<hipblasLtHalf>(transA,
transB,
m,
n,
k,
alpha.f16,
A,
lda,
B,
ldb,
beta.f16,
C,
ldc,
AlphaVec,
scaleA,
scaleB,
*(hipblasLtHalf*)scaleD,
isScaleAVec,
isScaleBVec,
tiA,
tiB,
to,
tc,
tciA,
tciB,
alt);
return;
case HIP_R_32F:
cblas_gemm<float>(transA,
transB,
Expand All @@ -110,16 +137,16 @@ inline void cblas_gemm(hipblasOperation_t transA,
beta.f32,
C,
ldc,
(const float*)AlphaVec,
(const float*)scaleA,
(const float*)scaleB,
AlphaVec,
scaleA,
scaleB,
*(float*)scaleD,
isScaleAVec,
isScaleBVec,
tiA,
tiB,
to,
HIP_R_32F,
tc,
tciA,
tciB,
alt);
Expand All @@ -138,9 +165,9 @@ inline void cblas_gemm(hipblasOperation_t transA,
beta.f64,
C,
ldc,
(const double*)AlphaVec,
(const double*)scaleA,
(const double*)scaleB,
AlphaVec,
scaleA,
scaleB,
*(double*)scaleD,
isScaleAVec,
isScaleBVec,
Expand All @@ -166,9 +193,9 @@ inline void cblas_gemm(hipblasOperation_t transA,
beta.i32,
C,
ldc,
(const int32_t*)AlphaVec,
(const int32_t*)scaleA,
(const int32_t*)scaleB,
AlphaVec,
scaleA,
scaleB,
*(int32_t*)scaleD,
isScaleAVec,
isScaleBVec,
Expand Down
2 changes: 1 addition & 1 deletion clients/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ void testing_matmul_with_bias(const Arguments& arg,
hipblasOperation_t transA(char_to_hipblas_operation(arg.transA));
hipblasOperation_t transB(char_to_hipblas_operation(arg.transB));

hipDataType Talpha = (Tc == HIP_R_16F ? HIP_R_32F : Tc);
hipDataType Talpha = Tc;

bool do_grouped_gemm = arg.grouped_gemm > 0;
int32_t gemm_count = std::max(1, arg.grouped_gemm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ inline TensileLite::DataType hipDataType_to_tensile_type(hipDataType type)

namespace
{
TensileLite::DataType roc2TensileType(rocblaslt_compute_type);
TensileLite::DataType roc2TensileType(rocblaslt_compute_type, bool);
}

namespace TensileLite
Expand Down
5 changes: 5 additions & 0 deletions library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ inline void assignAlphaBeta1(const rocblaslt_compute_type& compute_type, void* a
*((int32_t*)alpha) = 1.f;
*((int32_t*)beta) = 1.f;
}
else if(compute_type == rocblaslt_compute_f16)
{
*((hipblasLtHalf*)alpha) = 1.f;
*((hipblasLtHalf*)beta) = 1.f;
}
else
{
*((float*)alpha) = 1.f;
Expand Down
Loading

0 comments on commit 3600099

Please sign in to comment.