Skip to content

Commit

Permalink
Modify to check if alpha is in host memory. (#1356)
Browse files Browse the repository at this point in the history
* use correct data type 'rocblaslt_pointer_mode' for pointer mode.

* Modify to check if alpha is in host memory.
  • Loading branch information
geotseng-amd authored Dec 5, 2024
1 parent 9f30df5 commit b4e5042
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 42 deletions.
2 changes: 1 addition & 1 deletion library/include/hipblaslt-ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ namespace hipblaslt_ext
int8_t beta[16]; //!< The beta value.
// Epilogue inputs
void* scaleA; //!< The scaleA input pointer.
void* scaleB; //!< The scaleA input pointer.
void* scaleB; //!< The scaleB input pointer.
void* scaleC; //!< The scaleC input pointer.
void* scaleD; //!< The scaleD input pointer.
void* scaleAlphaVec; //!< The scaleAlpha vector input pointer.
Expand Down
18 changes: 11 additions & 7 deletions library/src/amd_detail/rocblaslt/include/rocblaslt-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ typedef enum rocblaslt_epilogue_
typedef enum rocblaslt_pointer_mode_
{
rocblaslt_pointer_mode_host = 0, /**< scalar pointers are in host memory. */
rocblaslt_pointer_mode_device = 1 /**< scalar pointers are in device memory. */
rocblaslt_pointer_mode_device = 1, /**< scalar pointers are in device memory. */
rocblaslt_pointer_mode_alpha_device_vector_beta_host
= 4 /** alpha pointer targets a device memory vector of length equal to the number of rows of matrix D, and beta is a single value in host memory. */
} rocblaslt_pointer_mode;

/*! \ingroup types_module
Expand Down Expand Up @@ -265,16 +267,18 @@ typedef enum rocblaslt_compute_type_
rocblaslt_compute_f64_pedantic = 8, /**< compute will be exactly 64-bit precision */
rocblaslt_compute_i32 = 9, /**< 32-bit integer precision. */
rocblaslt_compute_i32_pedantic = 10, /**< compute will be exactly 32-bit integer precision */
rocblaslt_compute_f32_fast_f8_fnuz = 100, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_fnuz = 101, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8_fnuz = 100, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_fnuz = 101, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8bf8_fnuz
= 102, /**< 32-bit input can use fp8 for A and bf8 for B compute */
rocblaslt_compute_f32_fast_bf8f8_fnuz
= 103, /**< 32-bit input can use bf8 for A and fp8 for B compute */
rocblaslt_compute_f32_fast_f8_ocp = 104, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_ocp = 105, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8bf8_ocp = 106, /**< 32-bit input can use fp8 for A and bf8 for B compute */
rocblaslt_compute_f32_fast_bf8f8_ocp = 107, /**< 32-bit input can use bf8 for A and fp8 for B compute */
rocblaslt_compute_f32_fast_f8_ocp = 104, /**< 32-bit input can use fp8 compute */
rocblaslt_compute_f32_fast_bf8_ocp = 105, /**< 32-bit input can use bf8 compute */
rocblaslt_compute_f32_fast_f8bf8_ocp
= 106, /**< 32-bit input can use fp8 for A and bf8 for B compute */
rocblaslt_compute_f32_fast_bf8f8_ocp
= 107, /**< 32-bit input can use bf8 for A and fp8 for B compute */
} rocblaslt_compute_type;

/*! \ingroup types_module
Expand Down
19 changes: 10 additions & 9 deletions library/src/amd_detail/rocblaslt/src/include/handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ struct _rocblaslt_matmul_desc
hipblasOperation_t op_B = HIPBLAS_OP_N;
// epilogue operation
rocblaslt_epilogue epilogue = ROCBLASLT_EPILOGUE_DEFAULT;
// alpha,beta pointer mode
rocblaslt_pointer_mode pointermode = rocblaslt_pointer_mode_host;
// bias vector pointer
void* bias = nullptr;
void* scaleA = nullptr;
void* scaleB = nullptr;
void* scaleC = nullptr;
void* scaleD = nullptr;
void* scaleE = nullptr;
void* pointermode = nullptr;
void* amaxD = nullptr;
hipDataType bias_type = HIPBLASLT_DATATYPE_INVALID;
void* bias = nullptr;
void* scaleA = nullptr;
void* scaleB = nullptr;
void* scaleC = nullptr;
void* scaleD = nullptr;
void* scaleE = nullptr;
void* amaxD = nullptr;
hipDataType bias_type = HIPBLASLT_DATATYPE_INVALID;
// E
void* e = nullptr;
int64_t lde = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,25 @@ inline rocblaslt_status validateMatmulDescrArgs(rocblaslt_handle handle,
/*******************************************************************************
* Validate Matmul Arguments
******************************************************************************/
inline rocblaslt_status validateMatmulArgs(int64_t m,
int64_t n,
int64_t k,
const void* alpha,
const void* a,
const void* b,
const void* beta,
const void* c,
const void* d,
int num_batches_a = 1,
int num_batches_b = 1,
int num_batches_c = 1,
int num_batches_d = 1,
int64_t batch_stride_a = 0,
int64_t batch_stride_b = 0,
int64_t batch_stride_c = 0,
int64_t batch_stride_d = 0)
inline rocblaslt_status validateMatmulArgs(int64_t m,
int64_t n,
int64_t k,
const void* alpha,
const void* a,
const void* b,
const void* beta,
const void* c,
const void* d,
int num_batches_a = 1,
int num_batches_b = 1,
int num_batches_c = 1,
int num_batches_d = 1,
int64_t batch_stride_a = 0,
int64_t batch_stride_b = 0,
int64_t batch_stride_c = 0,
int64_t batch_stride_d = 0,
const rocblaslt_pointer_mode& pointermode
= rocblaslt_pointer_mode_host)
{
// sizes must not be negative
if(batch_stride_a < 0 || batch_stride_b < 0 || batch_stride_c < 0 || batch_stride_d < 0)
Expand Down Expand Up @@ -183,9 +185,10 @@ inline rocblaslt_status validateMatmulArgs(int64_t m,
if(!beta)
return rocblaslt_status_invalid_pointer;

// Update for the valid case: ((alpha_in_host && alpha=0) && (A=NULL || B=NULL))
bool alpha_A_B_violation = (!alpha || ((pointermode || (*((float*)alpha))) && (!a || !b)));
// pointers must be valid
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(n && ((k && (!alpha || ((*((float*)alpha)) && (!a || !b)))) || !c || !d))
if(n && ((k && alpha_A_B_violation) || !c || !d))
return rocblaslt_status_invalid_pointer;

return rocblaslt_status_continue;
Expand Down Expand Up @@ -339,7 +342,8 @@ inline rocblaslt_status rocblaslt_matmul_valid_args(const rocblaslt_matmul_desc
batch_stride_a,
batch_stride_b,
batch_stride_c,
batch_stride_d);
batch_stride_d,
matmul_descr->pointermode);

if(status == rocblaslt_status_continue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ rocblaslt_status rocblaslt_matmul_desc_get_attribute(rocblaslt_matmul_desc
log_error(__func__, "invalid buf size", sizeInBytes);
return rocblaslt_status_invalid_value;
}
memcpy(buf, &matmulDesc->pointermode, sizeof(void*));
memcpy(buf, &matmulDesc->pointermode, sizeof(int32_t));
break;
case ROCBLASLT_MATMUL_DESC_BIAS_DATA_TYPE:
if(sizeWritten)
Expand Down
10 changes: 6 additions & 4 deletions library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ rocblaslt_status
batch_stride_a,
batch_stride_b,
batch_stride_c,
batch_stride_d);
batch_stride_d,
matmul_descr[i]->pointermode);
if(validArgs == rocblaslt_status_success)
continue;

Expand Down Expand Up @@ -650,10 +651,11 @@ rocblaslt_status rocblaslt_matmul(rocblaslt_handle handle,
return rocblaslt_status_invalid_handle;
}

// Update for the valid case: ((alpha_in_host && alpha=0) && (A=NULL || B=NULL))
bool alpha_A_B_violation
= (!alpha || ((matmul_descr->pointermode || (*((float*)alpha))) && (!A || !B)));
// Check if pointer is valid
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(alpha == nullptr || beta == nullptr || C == nullptr || D == nullptr
|| ((*((float*)alpha)) && (A == nullptr || B == nullptr)))
if(alpha == nullptr || beta == nullptr || C == nullptr || D == nullptr || alpha_A_B_violation)
{
log_error(__func__, "invalid data pointer");
return rocblaslt_status_invalid_pointer;
Expand Down

0 comments on commit b4e5042

Please sign in to comment.