Skip to content

Commit

Permalink
[GPU] Optimized operations in the blas kernels with the latest buffer…
Browse files Browse the repository at this point in the history
… changes.

Updated the pipeline for both fp32 and fp16.
SGEMM, SGEMV, DotCL, SSACL, Transpose ops updated.

Signed-off-by: Niket Agarwal <[email protected]>
  • Loading branch information
niket-agarwal committed Jan 5, 2025
1 parent 8184b61 commit ae6b8c0
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 123 deletions.
91 changes: 42 additions & 49 deletions nntrainer/tensor/cl_operations/blas_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,26 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) {

size_t dim1_size = sizeof(float) * dim1;

opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
nullptr);

opencl::Buffer inputX(cl_context_ref.context_inst_, dim1_size, true,
nullptr);

opencl::Buffer dotResult(cl_context_ref.context_inst_, sizeof(float), true,
&cl_ret);

result = inputA.WriteData(cl_context_ref.command_queue_inst_, vecAdata);
result = clbuffInstance.getInBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim1_size, vecAdata);
if (!result) {
break;
}

result = inputX.WriteData(cl_context_ref.command_queue_inst_, vecXdata);
result = clbuffInstance.getInBufferB()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim1_size, vecXdata);
if (!result) {
break;
}

result = kernel_dot_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
result = kernel_dot_ptr->SetKernelArguments(
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
if (!result) {
break;
}

result = kernel_dot_ptr->SetKernelArguments(1, &inputX, sizeof(cl_mem));
result = kernel_dot_ptr->SetKernelArguments(
1, clbuffInstance.getInBufferB(), sizeof(cl_mem));
if (!result) {
break;
}
Expand All @@ -154,7 +149,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) {
break;
}

result = kernel_dot_ptr->SetKernelArguments(3, &dotResult, sizeof(cl_mem));
result = kernel_dot_ptr->SetKernelArguments(
3, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
if (!result) {
break;
}
Expand All @@ -168,7 +164,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) {
break;
}

result = dotResult.ReadData(cl_context_ref.command_queue_inst_, &cl_ret);
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
cl_context_ref.command_queue_inst_, sizeof(float), &cl_ret);
if (!result) {
break;
}
Expand Down Expand Up @@ -213,41 +210,38 @@ void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
size_t k_n_size = K * N * sizeof(float);
size_t m_n_size = M * N * sizeof(float);

opencl::Buffer inputA(cl_context_ref.context_inst_, m_k_size, true,
nullptr);

opencl::Buffer inputB(cl_context_ref.context_inst_, k_n_size, true,
nullptr);

opencl::Buffer inOutC(cl_context_ref.context_inst_, m_n_size, true,
nullptr);

result = inputA.WriteData(cl_context_ref.command_queue_inst_, A);
result = clbuffInstance.getInBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, m_k_size, A);
if (!result) {
break;
}

result = inputB.WriteData(cl_context_ref.command_queue_inst_, B);
result = clbuffInstance.getInBufferB()->WriteDataRegion(
cl_context_ref.command_queue_inst_, k_n_size, B);
if (!result) {
break;
}

result = inOutC.WriteData(cl_context_ref.command_queue_inst_, C);
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, m_n_size, C);
if (!result) {
break;
}

result = kernel_sgemm_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
result = kernel_sgemm_ptr->SetKernelArguments(
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
if (!result) {
break;
}

result = kernel_sgemm_ptr->SetKernelArguments(1, &inputB, sizeof(cl_mem));
result = kernel_sgemm_ptr->SetKernelArguments(
1, clbuffInstance.getInBufferB(), sizeof(cl_mem));
if (!result) {
break;
}

result = kernel_sgemm_ptr->SetKernelArguments(2, &inOutC, sizeof(cl_mem));
result = kernel_sgemm_ptr->SetKernelArguments(
2, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
if (!result) {
break;
}
Expand Down Expand Up @@ -281,7 +275,8 @@ void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
break;
}

result = inOutC.ReadData(cl_context_ref.command_queue_inst_, C);
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
cl_context_ref.command_queue_inst_, m_n_size, C);
if (!result) {
break;
}
Expand Down Expand Up @@ -372,14 +367,14 @@ void sscal_cl(float *X, const unsigned int N, const float alpha) {

size_t x_size = N * sizeof(float);

opencl::Buffer inputX(cl_context_ref.context_inst_, x_size, false, nullptr);

result = inputX.WriteData(cl_context_ref.command_queue_inst_, X);
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, x_size, X);
if (!result) {
break;
}

result = kernel_ptr->SetKernelArguments(0, &inputX, sizeof(cl_mem));
result = kernel_ptr->SetKernelArguments(0, clbuffInstance.getOutBufferA(),
sizeof(cl_mem));
if (!result) {
break;
}
Expand All @@ -398,7 +393,8 @@ void sscal_cl(float *X, const unsigned int N, const float alpha) {
break;
}

result = inputX.ReadData(cl_context_ref.command_queue_inst_, X);
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
cl_context_ref.command_queue_inst_, x_size, X);
if (!result) {
break;
}
Expand Down Expand Up @@ -439,30 +435,26 @@ void transpose_cl_axis(const float *in, float *res,
size_t dim_size = sizeof(float) * input_batch_size * input_height *
input_width * input_channels;

opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true,
nullptr);

opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true,
nullptr);

result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
result = clbuffInstance.getInBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim_size, in);
if (!result) {
break;
}

result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res);
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim_size, res);
if (!result) {
break;
}

result =
kernel_transpose_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
result = kernel_transpose_ptr->SetKernelArguments(
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
if (!result) {
break;
}

result =
kernel_transpose_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
result = kernel_transpose_ptr->SetKernelArguments(
1, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
if (!result) {
break;
}
Expand Down Expand Up @@ -503,7 +495,8 @@ void transpose_cl_axis(const float *in, float *res,
break;
}

result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res);
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
cl_context_ref.command_queue_inst_, dim_size, res);
if (!result) {
break;
}
Expand Down
Loading

0 comments on commit ae6b8c0

Please sign in to comment.