Skip to content

Commit

Permalink
Apply formatting suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Finlay <[email protected]>
  • Loading branch information
t4c1 and FMarno authored Dec 5, 2024
1 parent 81b70b6 commit dc12326
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions examples/35_gemm_softmax/gemm_online_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "14_ampere_tf32_tensorop_gemm_cute example\n\n"
out << "35_gemm_softmax example\n\n"
<< " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
Expand Down Expand Up @@ -238,7 +238,7 @@ struct ExampleRunner {
// Methods
//
template<typename Element>
bool verify_tensor(std::vector<Element> vector_Input, \
bool verify_tensor(std::vector<Element> vector_Input,
std::vector<Element> vector_Input_Ref, const Options& options) {

auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size());
Expand Down
20 changes: 10 additions & 10 deletions examples/35_gemm_softmax/softmax_finalize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class SoftmaxFinalize {
const int y_size = BlockDimY();
const int batch_id = BlockIdxY();

if(m>=params.args.M){
if (m >= params.args.M) {
return;
}

Expand All @@ -146,43 +146,43 @@ class SoftmaxFinalize {
make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp)));

ElementPartial max_val = std::numeric_limits<ElementPartial>::lowest();
for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){
for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){
ElementPartial partial_max = mPartialMax(m, partial_n, batch_id);
max_val = cutlass::fast_max(max_val, partial_max);
}
sPartial(idx_x,idx_y) = max_val;
sPartial(idx_x, idx_y) = max_val;
syncthreads();
// tree-reduction could be better, although it does not seem to be a bottleneck
for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){
ElementPartial partial_max = sPartial(idx_x,idx_y2);
for (int idx_y2 = 0; idx_y2 < y_size; idx_y2++){
ElementPartial partial_max = sPartial(idx_x, idx_y2);
max_val = cutlass::fast_max(max_val, partial_max);
}

ElementPartial sum_val = 0;
for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){
for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){
ElementPartial partial_max = mPartialMax(m, partial_n, batch_id);
ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id);
sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val);
}
syncthreads();
sPartial(idx_x,idx_y) = sum_val;
sPartial(idx_x, idx_y) = sum_val;
syncthreads();
sum_val = 0;
// tree-reduction could be better, although it does not seem to be a bottleneck
for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){
ElementPartial partial_sum = sPartial(idx_x,idx_y2);
ElementPartial partial_sum = sPartial(idx_x, idx_y2);
sum_val += partial_sum;
}

ElementPartial norm = 1 / sum_val;

for(int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){
for (int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){
auto inVal = mIn(m, n, batch_id);
auto inVal2 = mIn(m, n+1, batch_id);
mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm;
mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm;
}
if(params.args.dataN % 2 == 1){
if (params.args.dataN % 2 == 1){
int n = params.args.dataN - 1;
auto inVal = mIn(m, n, batch_id);
mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm;
Expand Down

0 comments on commit dc12326

Please sign in to comment.