From e244b7b2a296583c21cc3f769a5cc6eeb18bab71 Mon Sep 17 00:00:00 2001 From: bhsueh Date: Mon, 15 Aug 2022 20:16:56 -0700 Subject: [PATCH] fix: remove allow_gemm_test flag of bert_example.cc --- examples/cpp/bert/bert_example.cc | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/examples/cpp/bert/bert_example.cc b/examples/cpp/bert/bert_example.cc index a33554120..04a033e9a 100644 --- a/examples/cpp/bert/bert_example.cc +++ b/examples/cpp/bert/bert_example.cc @@ -25,20 +25,15 @@ int bertExample(size_t batch_size, size_t seq_len, size_t head_num, size_t size_per_head, - bool is_remove_padding, - bool allow_gemm_test = false); + bool is_remove_padding); int main(int argc, char** argv) { - if (argc != 8 && argc != 9) { + if (argc != 8) { FT_LOG_ERROR("bert_example batch_size num_layers seq_len head_num size_per_head data_type is_remove_padding"); FT_LOG_ERROR("e.g., ./bin/bert_example 32 12 32 12 64 0 0"); return 0; } - bool allow_gemm_test = false; - if (argc == 9) { - allow_gemm_test = (atoi(argv[8]) == 1) ? true : false; - } int batch_size = atoi(argv[1]); int num_layers = atoi(argv[2]); @@ -50,16 +45,16 @@ int main(int argc, char** argv) if (data_type == FLOAT_DATATYPE) { return bertExample( - batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test); + batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding); } else if (data_type == HALF_DATATYPE) { return bertExample( - batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test); + batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding); } #ifdef ENABLE_BF16 else if (data_type == BFLOAT16_DATATYPE) { return bertExample<__nv_bfloat16>( - batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test); + batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding); } #endif else { @@ -74,8 +69,7 @@ int bertExample(size_t batch_size, size_t seq_len, size_t head_num, size_t size_per_head, - bool is_remove_padding, - bool allow_gemm_test) + bool is_remove_padding) { printf("[INFO] Device: %s \n", getDeviceName().c_str()); print_mem_usage("Before loading model");