From dc123265709763230ac5159c4fd1c3c426d8a807 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 5 Dec 2024 11:52:29 +0100 Subject: [PATCH] Apply formatting suggestions from code review Co-authored-by: Finlay --- .../35_gemm_softmax/gemm_online_softmax.cpp | 4 ++-- examples/35_gemm_softmax/softmax_finalize.hpp | 20 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 67dff1460..2a6e6f1d5 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -128,7 +128,7 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "14_ampere_tf32_tensorop_gemm_cute example\n\n" + out << "35_gemm_softmax example\n\n" << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" @@ -238,7 +238,7 @@ struct ExampleRunner { // Methods // template - bool verify_tensor(std::vector vector_Input, \ + bool verify_tensor(std::vector vector_Input, std::vector vector_Input_Ref, const Options& options) { auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? 
vector_Input.size() : vector_Input_Ref.size()); diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 30ecd90fe..ca6e6ac93 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -129,7 +129,7 @@ class SoftmaxFinalize { const int y_size = BlockDimY(); const int batch_id = BlockIdxY(); - if(m>=params.args.M){ + if (m >= params.args.M) { return; } @@ -146,43 +146,43 @@ class SoftmaxFinalize { make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp))); ElementPartial max_val = std::numeric_limits::lowest(); - for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size) { ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(idx_x,idx_y) = max_val; + sPartial(idx_x, idx_y) = max_val; syncthreads(); // tree-reduction could be better, although it does not seem to be a bottleneck - for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ - ElementPartial partial_max = sPartial(idx_x,idx_y2); + for (int idx_y2 = 0; idx_y2 < y_size; idx_y2++) { + ElementPartial partial_max = sPartial(idx_x, idx_y2); max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size) { ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(idx_x,idx_y) = sum_val; + sPartial(idx_x, idx_y) = sum_val; syncthreads(); sum_val = 0; // tree-reduction could be better, although it does not seem to be a bottleneck for(int idx_y2 = 0; idx_y2 < 
y_size; idx_y2++){ - ElementPartial partial_sum = sPartial(idx_x,idx_y2); + ElementPartial partial_sum = sPartial(idx_x, idx_y2); sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ + for (int n = idx_y * 2; n < params.args.dataN; n += y_size * 2) { auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; } - if(params.args.dataN % 2 == 1){ + if (params.args.dataN % 2 == 1) { int n = params.args.dataN - 1; auto inVal = mIn(m, n, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm;