Skip to content

Commit

Permalink
fix: remove allow_gemm_test flag of bert_example.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
byshiue committed Aug 16, 2022
1 parent bc21406 commit e244b7b
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions examples/cpp/bert/bert_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,15 @@ int bertExample(size_t batch_size,
size_t seq_len,
size_t head_num,
size_t size_per_head,
bool is_remove_padding,
bool allow_gemm_test = false);
bool is_remove_padding);

int main(int argc, char** argv)
{
if (argc != 8 && argc != 9) {
if (argc != 8) {
FT_LOG_ERROR("bert_example batch_size num_layers seq_len head_num size_per_head data_type is_remove_padding");
FT_LOG_ERROR("e.g., ./bin/bert_example 32 12 32 12 64 0 0");
return 0;
}
bool allow_gemm_test = false;
if (argc == 9) {
allow_gemm_test = (atoi(argv[8]) == 1) ? true : false;
}

int batch_size = atoi(argv[1]);
int num_layers = atoi(argv[2]);
Expand All @@ -50,16 +45,16 @@ int main(int argc, char** argv)

if (data_type == FLOAT_DATATYPE) {
return bertExample<float>(
batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test);
batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding);
}
else if (data_type == HALF_DATATYPE) {
return bertExample<half>(
batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test);
batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding);
}
#ifdef ENABLE_BF16
else if (data_type == BFLOAT16_DATATYPE) {
return bertExample<__nv_bfloat16>(
batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding, allow_gemm_test);
batch_size, num_layers, seq_len, head_num, size_per_head, is_remove_padding);
}
#endif
else {
Expand All @@ -74,8 +69,7 @@ int bertExample(size_t batch_size,
size_t seq_len,
size_t head_num,
size_t size_per_head,
bool is_remove_padding,
bool allow_gemm_test)
bool is_remove_padding)
{
printf("[INFO] Device: %s \n", getDeviceName().c_str());
print_mem_usage("Before loading model");
Expand Down

0 comments on commit e244b7b

Please sign in to comment.