diff --git a/c++_samples/simple_inference_ClipOCR.cpp b/c++_samples/simple_inference_nvCLIP4STR.cpp similarity index 77% rename from c++_samples/simple_inference_ClipOCR.cpp rename to c++_samples/simple_inference_nvCLIP4STR.cpp index 3a19de7..5feccaf 100644 --- a/c++_samples/simple_inference_ClipOCR.cpp +++ b/c++_samples/simple_inference_nvCLIP4STR.cpp @@ -75,7 +75,7 @@ int main() // Please pay attention to the following parameters. You may need to change them according to different models. nvOCDRParam param; param.input_data_format = NHWC; - param.ocdnet_trt_engine_path = (char *)"/localhome/local-bizhao/models/ocdnet.fp16.engine"; + param.ocdnet_trt_engine_path = (char *)"/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/onnx_models/ocdnet_vit.fp16.engine"; param.ocdnet_infer_input_shape[0] = 3; param.ocdnet_infer_input_shape[1] = 736; param.ocdnet_infer_input_shape[2] = 1280; @@ -83,16 +83,22 @@ int main() param.ocdnet_polygon_threshold = 0.3; param.ocdnet_max_candidate = 200; param.ocdnet_unclip_ratio = 1.5; - param.ocrnet_trt_engine_path = (char *)"/localhome/local-bizhao/output/vl4str_base_pcb_10_split_oversample.ckpt.img.fp32.onnx_sim.onnx.fp16.engine"; - param.ocrnet_dict_file = (char *)"/localhome/local-bizhao/models/character_list"; + param.ocrnet_trt_engine_path = (char *)"/home/binz/CLIP4STR_nvCLIP/trained_with_nvclip/best_ckpt/vl4str_2024-11-19-06-48-47_checkpoints_epoch_9-step_15580-val_accuracy_71.1684-val_NED_79.9133.visual.sim.fp16.engine"; + param.ocrnet_text_trt_engine_path = (char *)"/home/binz/CLIP4STR_nvCLIP/trained_with_nvclip/best_ckpt/vl4str_2024-11-19-06-48-47_checkpoints_epoch_9-step_15580-val_accuracy_71.1684-val_NED_79.9133.text.sim.fp16.engine"; + param.ocrnet_vocab_file = (char *)"/home/binz/CLIP4STR_nvCLIP/code/CLIP4STR/strhub/clip/bpe_simple_vocab_16e6.txt"; + param.ocrnet_vocab_size = 32000; + param.ocrnet_dict_file = (char *)"/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/onnx_models/character_list_clip4str"; param.ocrnet_infer_input_shape[0] = 3; param.ocrnet_infer_input_shape[1] = 224; param.ocrnet_infer_input_shape[2] = 224; - param.ocrnet_decode = CLIP; + param.ocrnet_decode = Transformer; + param.ocrnet_only_alnum = false; + param.ocrnet_only_lowercase = false; + nvOCDRp nvocdr_ptr = nvOCDR_init(param); // Load the input - const char* img_path = "/localhome/local-bizhao/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/c++_samples/test_img/scene_text.jpg"; + const char* img_path = "/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/c++_samples/test_img/nvocdr.jpg"; cv::Mat img = cv::imread(img_path); nvOCDRInput input; input.device_type = GPU; @@ -110,8 +116,8 @@ int main() nvOCDR_inference(input, &output, nvocdr_ptr); // filter the output text, and covert to lowercase - std::string keeped_charset = "0123456789abcdefghijklmnopqrstuvwxyz"; - textFilter(output, keeped_charset); + std::string keeped_charset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";; + textFilter(output, keeped_charset, param.ocrnet_only_lowercase); // Visualize the output int offset = 0; diff --git a/c++_samples/simple_inference_vit.cpp b/c++_samples/simple_inference_vit.cpp index 768a319..4aaa128 100644 --- a/c++_samples/simple_inference_vit.cpp +++ b/c++_samples/simple_inference_vit.cpp @@ -45,7 +45,7 @@ int main() // Please pay attention to the following parameters. You may need to change them according to different models. nvOCDRParam param; param.input_data_format = NHWC; - param.ocdnet_trt_engine_path = (char *)"/hdd_10t/tylerz/CTSE_DL/github/ptmv2_models/ocdnet.fp16.engine"; + param.ocdnet_trt_engine_path = (char *)"/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/onnx_models/ocdnet_vit.fp16.engine"; param.ocdnet_infer_input_shape[0] = 3; param.ocdnet_infer_input_shape[1] = 736; param.ocdnet_infer_input_shape[2] = 1280; @@ -53,8 +53,8 @@ int main() param.ocdnet_polygon_threshold = 0.3; param.ocdnet_max_candidate = 200; param.ocdnet_unclip_ratio = 1.5; - param.ocrnet_trt_engine_path = (char *)"/hdd_10t/tylerz/CTSE_DL/github/ptmv2_models/ocrnet.fp16.engine"; - param.ocrnet_dict_file = (char *)"/hdd_10t/tylerz/CTSE_DL/github/ptmv2_models/character_list"; + param.ocrnet_trt_engine_path = (char *)"/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/onnx_models/ocrnet_vit.fp16.engine"; + param.ocrnet_dict_file = (char *)"/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/onnx_models/character_list"; param.ocrnet_infer_input_shape[0] = 1; param.ocrnet_infer_input_shape[1] = 64; param.ocrnet_infer_input_shape[2] = 200; @@ -62,7 +62,7 @@ int main() nvOCDRp nvocdr_ptr = nvOCDR_init(param); // Load the input - const char* img_path = "/hdd_10t/tylerz/CTSE_DL/github/scene_text.jpg"; + const char* img_path = "/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/c++_samples/test_img/scene_text.jpg"; cv::Mat img = cv::imread(img_path); nvOCDRInput input; input.device_type = GPU; diff --git a/c++_samples/test_img/nvocdr.jpg b/c++_samples/test_img/nvocdr.jpg new file mode 100644 index 0000000..77aff3d Binary files /dev/null and b/c++_samples/test_img/nvocdr.jpg differ diff --git a/include/nvocdr.h b/include/nvocdr.h index 88a68dd..8c9c7bc 100644 --- a/include/nvocdr.h +++ b/include/nvocdr.h @@ -36,7 +36,7 @@ enum OCRNetDecode { CTC, Attention, - CLIP + Transformer }; typedef struct @@ -58,7 +58,15 @@ typedef struct char* ocrnet_trt_engine_path; char* ocrnet_dict_file; int32_t ocrnet_infer_input_shape[3]; - OCRNetDecode ocrnet_decode = CTC; + OCRNetDecode ocrnet_decode = Transformer; + // Param for clip4str + char* ocrnet_text_trt_engine_path; + bool ocrnet_only_alnum = false; + bool ocrnet_only_lowercase = false; + char* ocrnet_vocab_file; + int ocrnet_vocab_size = 32000; + // char* charset_train = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; + // char* charset_test = "0123456789abcdefghijklmnopqrstuvwxyz"; // common param } nvOCDRParam; diff --git a/src/OCRNetEngine.cpp b/src/OCRNetEngine.cpp index 2532d61..de43e6e 100644 --- a/src/OCRNetEngine.cpp +++ b/src/OCRNetEngine.cpp @@ -7,7 +7,8 @@ using namespace nvocdr; -OCRNetEngine::OCRNetEngine(const std::string& engine_path, const std::string& dict_path, const bool upside_down, const DecodeMode decode_mode) +OCRNetEngine::OCRNetEngine(const std::string& engine_path, const std::string& dict_path, const bool upside_down=0, const DecodeMode decode_mode=Transformer, + const std::string& text_engine_path, bool only_alnum, bool only_lowercase, const std::string& vocab_file, const int vocab_size) { // Init TRTEngine mEngine = std::move(std::unique_ptr(new TRTEngine(engine_path))); @@ -29,7 +30,7 @@ OCRNetEngine::OCRNetEngine(const std::string& engine_path, const std::string& di mDict.emplace_back("[GO]"); mDict.emplace_back("[s]"); } - else if (mDecodeMode == CLIP) + else if (mDecodeMode == Transformer) { mDict.emplace_back("[E]"); } @@ -47,14 +48,19 @@ OCRNetEngine::OCRNetEngine(const std::string& engine_path, const std::string& di } } - if (mDecodeMode == CLIP) + if (mDecodeMode == Transformer) { mDict.emplace_back("[B]"); mDict.emplace_back("[P]"); + // init text engine of CLIP4STR + mTextEngine = std::move(std::unique_ptr(new TRTEngine(text_engine_path))); + // vocab file + mTokenizer.initTokenizer(vocab_file, vocab_size); } mUDFlag = upside_down; - + mOnlyAlNum = only_alnum; + mOnlyLowerCase = only_lowercase; } @@ -70,10 +76,34 @@ OCRNetEngine::initTRTBuffer(BufferManager& buffer_mgr) // Init trt input gpu buffer mTRTInputBufferIndex = buffer_mgr.initDeviceBuffer(mEngine->getMaxInputBufferSize(), sizeof(float)); mEngine->setInputBuffer(buffer_mgr.mDeviceBuffer[mTRTInputBufferIndex].data()); - - // Init trt output gpu buffer - mTRTOutputBufferIndex = buffer_mgr.initDeviceBuffer(mEngine->getMaxOutputBufferSize(), sizeof(float)); - mEngine->setOutputBuffer(buffer_mgr.mDeviceBuffer[mTRTOutputBufferIndex].data()); + if (mDecodeMode != Transformer) + { + // Init trt output gpu buffer + mTRTOutputBufferIndex = buffer_mgr.initDeviceBuffer(mEngine->getMaxOutputBufferSize(), sizeof(float)); + mEngine->setOutputBuffer(buffer_mgr.mDeviceBuffer[mTRTOutputBufferIndex].data()); + } + else + { + // init CLIP4STR visual branch trt ouput gpu buffer + mVisTRTOutputImgFeatureBufferIndex = buffer_mgr.initDeviceBuffer(mEngine->getMaxTrtIoTensorSizeByName(mVisualOutImgFeaturebName), mEngine->getTrtIoTensorDtypeSizeByName(mVisualOutImgFeaturebName)); + mEngine->setOutputBufferByName(buffer_mgr.mDeviceBuffer[mVisTRTOutputImgFeatureBufferIndex].data(), mVisualOutImgFeaturebName); + mVisTRTOutputDecodeProbsBufferIndex = buffer_mgr.initDeviceBuffer(mEngine->getMaxTrtIoTensorSizeByName(mVisualOutDecodeProbName), mEngine->getTrtIoTensorDtypeSizeByName(mVisualOutDecodeProbName)); + mEngine->setOutputBufferByName(buffer_mgr.mDeviceBuffer[mVisTRTOutputDecodeProbsBufferIndex].data(), mVisualOutDecodeProbName); + mVisTRTOutputContextBufferIndex = buffer_mgr.initDeviceBuffer(mEngine->getMaxTrtIoTensorSizeByName(mVisualOutContextName), mEngine->getTrtIoTensorDtypeSizeByName(mVisualOutContextName)); + mEngine->setOutputBufferByName(buffer_mgr.mDeviceBuffer[mVisTRTOutputContextBufferIndex].data(), mVisualOutContextName); + // init CLIP4STR text branch trt in gpu buffers + mTextTRTInputTextTokenBufferIndex = buffer_mgr.initDeviceBuffer(mTextEngine->getMaxTrtIoTensorSizeByName(mTextInTokenName), mTextEngine->getTrtIoTensorDtypeSizeByName(mTextInTokenName)); + mTextEngine->setInputBufferbyName(buffer_mgr.mDeviceBuffer[mTextTRTInputTextTokenBufferIndex].data(), mTextInTokenName); + mTextEngine->setInputBufferbyName(buffer_mgr.mDeviceBuffer[mVisTRTOutputContextBufferIndex].data(), mVisualOutContextName); + mTextEngine->setInputBufferbyName(buffer_mgr.mDeviceBuffer[mVisTRTOutputImgFeatureBufferIndex].data(), mVisualOutImgFeaturebName); + mTextEngine->setInputBatchSizebyName(mTextInTokenName, mTextEngine->getMaxBatchSize()); + mTextEngine->setInputBatchSizebyName(mVisualOutContextName, mTextEngine->getMaxBatchSize()); + mTextEngine->setInputBatchSizebyName(mVisualOutImgFeaturebName, mTextEngine->getMaxBatchSize()); + // init CLIP4STR text branch trt out gpu buffer + mTextTRTOutputBufferIndex = buffer_mgr.initDeviceBuffer(mTextEngine->getMaxTrtIoTensorSizeByName(mTextOutLogitName), mTextEngine->getTrtIoTensorDtypeSizeByName(mTextOutLogitName)); + mTextEngine->setOutputBufferByName(buffer_mgr.mDeviceBuffer[mTextTRTOutputBufferIndex].data(), mTextOutLogitName); + } + return 0; } @@ -114,105 +144,138 @@ OCRNetEngine::infer(BufferManager& buffer_mgr, std::vector> temp_de_texts; mEngine->infer(stream); + + if (mDecodeMode != Transformer) + { + // CPU Decode: + Dims output_prob_shape = mEngine->getExactOutputShape(OCRNET_OUTPUT_PROB); + Dims output_id_shape = mEngine->getExactOutputShape(OCRNET_OUTPUT_ID); + batch_size = output_prob_shape.d[0]; + int output_len = output_prob_shape.d[1]; - // CPU Decode: - Dims output_prob_shape = mEngine->getExactOutputShape(OCRNET_OUTPUT_PROB); - Dims output_id_shape = mEngine->getExactOutputShape(OCRNET_OUTPUT_ID); - int batch_size = output_prob_shape.d[0]; - int output_len = output_prob_shape.d[1]; + std::vector output_prob(volume(output_prob_shape)); + std::vector output_id(volume(output_id_shape)); + cudaMemcpyAsync(output_prob.data(), mEngine->getOutputAddr(OCRNET_OUTPUT_PROB), + output_prob.size() * sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(output_id.data(), mEngine->getOutputAddr(OCRNET_OUTPUT_ID), + output_id.size() * sizeof(int), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); - std::vector output_prob(volume(output_prob_shape)); - std::vector output_id(volume(output_id_shape)); - cudaMemcpyAsync(output_prob.data(), mEngine->getOutputAddr(OCRNET_OUTPUT_PROB), - output_prob.size() * sizeof(float), cudaMemcpyDeviceToHost, stream); - cudaMemcpyAsync(output_id.data(), mEngine->getOutputAddr(OCRNET_OUTPUT_ID), - output_id.size() * sizeof(int), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); - std::vector> temp_de_texts; - if (mDecodeMode == CTC) - { - for(int batch_idx = 0; batch_idx < batch_size; ++batch_idx) + if (mDecodeMode == CTC) { - int b_offset = batch_idx * output_len; - int prev = output_id[b_offset]; - std::vector temp_seq_id = {prev}; - std::vector temp_seq_prob = {output_prob[b_offset]}; - for(int i = 1 ; i < output_len; ++i) + for(int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - if (output_id[b_offset + i] != prev) + int b_offset = batch_idx * output_len; + int prev = output_id[b_offset]; + std::vector temp_seq_id = {prev}; + std::vector temp_seq_prob = {output_prob[b_offset]}; + for(int i = 1 ; i < output_len; ++i) { - temp_seq_id.push_back(output_id[b_offset + i]); - temp_seq_prob.push_back(output_prob[b_offset + i]); - prev = output_id[b_offset + i]; - } - } - std::string de_text = ""; - float prob = 1.0; - for(size_t i = 0; i < temp_seq_id.size(); ++i) - { - if (temp_seq_id[i] != 0) - { - if (temp_seq_id[i] <= static_cast(mDict.size()) - 1) + if (output_id[b_offset + i] != prev) { - de_text += mDict[temp_seq_id[i]]; - prob *= temp_seq_prob[i]; + temp_seq_id.push_back(output_id[b_offset + i]); + temp_seq_prob.push_back(output_prob[b_offset + i]); + prev = output_id[b_offset + i]; } - else + } + std::string de_text = ""; + float prob = 1.0; + for(size_t i = 0; i < temp_seq_id.size(); ++i) + { + if (temp_seq_id[i] != 0) { - std::cerr << "[ERROR] Character dict is not compatible with OCRNet TRT engine." << std::endl; + if (temp_seq_id[i] <= static_cast(mDict.size()) - 1) + { + de_text += mDict[temp_seq_id[i]]; + prob *= temp_seq_prob[i]; + } + else + { + std::cerr << "[ERROR] Character dict is not compatible with OCRNet TRT engine." << std::endl; + } } } + temp_de_texts.emplace_back(std::make_pair(de_text, prob)); } - temp_de_texts.emplace_back(std::make_pair(de_text, prob)); } - } - else if (mDecodeMode == Attention) - { - for(int batch_idx = 0; batch_idx < batch_size; ++batch_idx) + else if (mDecodeMode == Attention) { - int b_offset = batch_idx * output_len; - int stop_idx = 0; - std::string de_text = ""; - float prob = 1.0; - for(int i = 0; i < output_len; ++i) + for(int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - if (mDict[output_id[b_offset + i]] != "[s]") - { - de_text += mDict[output_id[b_offset + i]]; - prob *= output_prob[b_offset + i]; - } - else + int b_offset = batch_idx * output_len; + int stop_idx = 0; + std::string de_text = ""; + float prob = 1.0; + for(int i = 0; i < output_len; ++i) { - break; + if (mDict[output_id[b_offset + i]] != "[s]") + { + de_text += mDict[output_id[b_offset + i]]; + prob *= output_prob[b_offset + i]; + } + else + { + break; + } } + temp_de_texts.emplace_back(std::make_pair(de_text, prob)); } - temp_de_texts.emplace_back(std::make_pair(de_text, prob)); } } - else if (mDecodeMode == CLIP) + + else if (mDecodeMode == Transformer) { - for(int batch_idx = 0; batch_idx < batch_size; ++batch_idx) + // CPU Decode: + Dims visOutputDecodeProbShape = mEngine->getExactOutputShape(mVisualOutDecodeProbName); + batch_size = visOutputDecodeProbShape.d[0]; + int context_max_length = visOutputDecodeProbShape.d[1]; + int charset_len = visOutputDecodeProbShape.d[2]; + + // Get visual branch output + std::vector output_prob(buffer_mgr.mDeviceBuffer[mVisTRTOutputDecodeProbsBufferIndex].size()); + cudaMemcpyAsync(output_prob.data(), buffer_mgr.mDeviceBuffer[mVisTRTOutputDecodeProbsBufferIndex].data(), + buffer_mgr.mDeviceBuffer[mVisTRTOutputDecodeProbsBufferIndex].nbBytes(), cudaMemcpyDeviceToHost, stream); + // CLIP4STR decode visual branch output + std::vector> all_text_tokens; + for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - int b_offset = batch_idx * output_len; - std::string de_text = ""; - float prob = 1.0; + std::pair de_text_prob = clip4strDecode(output_prob, batch_idx, context_max_length, charset_len); + // batch_captions.emplace_back(de_text); + std::vector text_tokens = mTokenizer.encode(de_text_prob.first); + text_tokens.insert(text_tokens.begin(), mTokenizer.getStartTextToken()); + text_tokens.push_back(mTokenizer.getEndTextToken()); - for(int i = 0; i < output_len; ++i) + std::vector truncated_text_tokens(mMaxContextLen,0); + if (text_tokens.size() > mMaxContextLen) { - if (mDict[output_id[b_offset + i]] == "[E]") - { - break; - } - de_text += mDict[output_id[b_offset + i]]; - prob *= output_prob[b_offset + i]; + std::copy(text_tokens.begin(), text_tokens.begin() + mMaxContextLen, truncated_text_tokens.begin()); + truncated_text_tokens.back() = mTokenizer.getEndTextToken(); } - - temp_de_texts.emplace_back(std::make_pair(de_text, prob)); + else + { + std::copy(text_tokens.begin(), text_tokens.end(), truncated_text_tokens.begin()); + } + all_text_tokens.emplace_back(truncated_text_tokens); } + // text branch inference + cudaMemcpyAsync(buffer_mgr.mDeviceBuffer[mTextTRTInputTextTokenBufferIndex].data(), all_text_tokens.data(), + batch_size*mMaxContextLen*sizeof(int), cudaMemcpyHostToDevice, stream); + + mTextEngine->infer(stream); + + Dims textOutputDecodeProbShape = mTextEngine->getExactOutputShape(mTextOutLogitName); + std::vector text_output_prob(volume(textOutputDecodeProbShape)); + cudaMemcpyAsync(text_output_prob.data(), buffer_mgr.mDeviceBuffer[mTextTRTOutputBufferIndex].data(), + buffer_mgr.mDeviceBuffer[mTextTRTOutputBufferIndex].nbBytes(), cudaMemcpyDeviceToHost, stream); + for (int i=0; i de_text_prob = clip4strDecode(text_output_prob, i, context_max_length, charset_len); + temp_de_texts.emplace_back(de_text_prob); + } } else { @@ -241,4 +304,40 @@ OCRNetEngine::infer(BufferManager& buffer_mgr, std::vector OCRNetEngine::clip4strDecode( const std::vector& output_prob, const int batch_idx, const int context_len, const int charset_len) +{ + std::string de_text = ""; + float prob = 1.0; + for (int context_id = 0; context_id < context_len; ++context_id) + { + int batch_context_row_start = batch_idx * context_len * charset_len + context_id * charset_len; + auto max_iter = std::max_element(output_prob.begin() + batch_context_row_start, output_prob.begin() + batch_context_row_start + charset_len); + prob *= *max_iter; + int id = std::distance(output_prob.begin() + batch_context_row_start, max_iter); + if (mDict[id] == "[E]") + { + break; + } + if (mOnlyAlNum && !std::regex_match(mDict[id], std::regex("[a-zA-Z0-9]"))) + { + continue; + } + if (mOnlyLowerCase) + { + std::string tmpCaption = mDict[id]; + for (char& c : tmpCaption) { + c = static_cast(std::tolower(c)); + } + de_text += tmpCaption; + } + else + { + de_text += mDict[id]; + } + + } + return std::make_pair(de_text, prob); +} diff --git a/src/OCRNetEngine.h b/src/OCRNetEngine.h index 4873b1b..cfaa88c 100644 --- a/src/OCRNetEngine.h +++ b/src/OCRNetEngine.h @@ -1,6 +1,12 @@ #ifndef __NVOCDR_OCRN_HEADER__ #define __NVOCDR_OCRN_HEADER__ #include +#include +#include +#include +#include +#include +#include #include "MemManager.h" #include "TRTEngine.h" @@ -15,14 +21,308 @@ enum DecodeMode { CTC, Attention, - CLIP + Transformer }; + +class SimpleTokenizer +{ + private: + // Member variables + std::map byte_encoder; // Byte-to-Unicode encoder + std::unordered_map byte_decoder; // Unicode-to-Byte decoder + std::vector> merges; // BPE merge rules + std::map, int> bpe_ranks; // Merge ranks + std::unordered_map encoder; // Token-to-ID mapping + std::unordered_map decoder; // ID-to-Token mapping + std::unordered_map cache; // Cache for BPE results + std::regex pat = std::regex(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]+|[^\s\w]+)"); // Regular expression for tokenization + std::wstring_convert> mConverter; + int mStartTextToken; + int mEndTextToken; + + public: + // Constructor: Initialize tokenizer with BPE rules and vocabulary size + SimpleTokenizer(){}; + + SimpleTokenizer(const std::string& bpe_path, int vocab_size = 32000){ + initTokenizer(bpe_path, vocab_size); + } + + void initTokenizer(const std::string& bpe_path, int vocab_size) + { + // Initialize byte encoder/decoder + std::vector> b2u = bytes_to_unicode(); + for (int i = 0; i < b2u.size(); ++i) { + byte_encoder[b2u[i].first] = b2u[i].second; + encoder[b2u[i].second] = i; + encoder[b2u[i].second + L""] = i + b2u.size(); + } + + for (const auto& [k, v] : byte_encoder) { + byte_decoder[v] = k; + } + + // Load BPE merges from file + std::ifstream file(bpe_path, std::ios::in); + if (!file.is_open()) { + throw std::runtime_error("Failed to open BPE file: " + bpe_path); + } + + // Read the BPE merges file and process it + std::string line; + std::getline(file, line); + // Calculate max merges based on encoder size + int max_merges = vocab_size - 256 - 256 - 2 ; + while (std::getline(file, line) && max_merges-- > 0) { + std::istringstream iss(line); + std::string first, second; + if (iss >> first >> second) { + merges.emplace_back(first, second); + } + } + file.close(); + + // Build BPE ranks + for (size_t i = 0; i < merges.size(); ++i) { + bpe_ranks[merges[i]] = static_cast(i); + } + + for (const auto& merge : merges) { + std::string key_merge = merge.first + merge.second; + std::wstring key_merge_w = mConverter.from_bytes(key_merge); + encoder[key_merge_w] = encoder.size(); + + } + encoder[L"<|startoftext|>"] = encoder.size(); + encoder[L"<|endoftext|>"] = encoder.size(); + for (const auto& [key, value] : encoder) { + decoder[value] = key; + } + + cache["<|startoftext|>"] = "<|startoftext|>"; + cache["<|endoftext|>"] = "<|endoftext|>"; + mStartTextToken = encoder[L"<|startoftext|>"]; + mEndTextToken = encoder[L"<|endoftext|>"]; + }; + + std::string bpe(const std::string& token) { + + if (cache.find(token) != cache.end()) { + return cache[token]; + } + + std::vector word; + for (size_t i = 0; i < token.size(); ++i) { + word.push_back(std::string(1, token[i])); + } + word.back() += ""; + + auto pairs = get_pairs(word); + + if (pairs.empty()) { + return token + ""; + } + + while (true) { + auto bigram = *std::min_element( + pairs.begin(), pairs.end(), + [this](const std::pair& a, const std::pair& b) { + return bpe_ranks.count(a) ? (bpe_ranks.count(b) ? bpe_ranks.at(a) < bpe_ranks.at(b) : true) : false; + }); + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::vector new_word; + size_t i = 0; + while (i < word.size()) { + auto it = std::find(word.begin() + i, word.end(), bigram.first); + if (it != word.end() && it + 1 != word.end() && *(it + 1) == bigram.second) { + new_word.insert(new_word.end(), word.begin() + i, it); + new_word.push_back(bigram.first + bigram.second); + i = it - word.begin() + 2; + } else { + new_word.push_back(word[i]); + ++i; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + + pairs = get_pairs(word); + } + + std::ostringstream result; + for (const auto& w : word) { + result << w << " "; + } + std::string final_result = result.str(); + final_result.pop_back(); + + cache[token] = final_result; + return final_result; + }; + + // Encode function + std::vector encode(const std::string& text) { + std::vector bpe_tokens; + + // Clean and lowercase the input text + std::string cleaned_text = whitespace_clean(basic_clean(text)); + std::transform(cleaned_text.begin(), cleaned_text.end(), cleaned_text.begin(), ::tolower); + // Tokenize the cleaned text using regex + std::sregex_iterator it(cleaned_text.begin(), cleaned_text.end(), pat); + std::sregex_iterator end; + + while (it != end) { + if (it->str().empty()) + { + continue; + } + std::string token = it->str(); + + // Convert token to its byte representation + std::string byte_representation; + for (char b : token) { + + byte_representation += mConverter.to_bytes(byte_encoder[static_cast(b)]); + } + // Apply BPE to the byte representation + std::string bpe_result = bpe(byte_representation); + + // Encode each BPE token to its corresponding ID + std::istringstream iss(bpe_result); + std::string bep_str; + while (std::getline(iss, bep_str, ' ')) { + std::wstring bpe_token = mConverter.from_bytes(bep_str); + bpe_tokens.push_back(encoder[bpe_token]); + } + + ++it; + } + + return bpe_tokens; + }; + + + // Decode function + std::string decode(const std::vector& tokens) { + // std::ostringstream decoded_text; + std::string result; + for (int token : tokens) { + if (decoder.find(token) != decoder.end()) { + result = mConverter.to_bytes(decoder[token]); + } + } + + // Convert the decoded string to a byte array and replace with space + result = std::regex_replace(result, std::regex(""), " "); // Replace with space + + return result; + }; + + std::vector> bytes_to_unicode() { + std::vector bs; + for (int i = static_cast(u'!'); i <= static_cast(u'~'); ++i) bs.push_back(i); + for (int i = static_cast(u'¡'); i <= static_cast(u'¬'); ++i) bs.push_back(i); + for (int i = static_cast(u'®'); i <= static_cast(u'ÿ'); ++i) bs.push_back(i); + + std::vector cs = bs; + int n = 0; + for (int b = 0; b < 256; b++) { + if (std::find(bs.begin(), bs.end(), b) == bs.end()) { + bs.push_back(b); + cs.push_back(256 + n); + n++; + } + } + + std::vector> result; + for (size_t i = 0; i < bs.size(); i++) { + result.push_back(std::make_pair(bs[i], std::wstring(1, static_cast(cs[i])))); + } + + return result; + } + + std::set> get_pairs(const std::vector& word) { + std::set> pairs; + for (size_t i = 0; i < word.size() - 1; ++i) { + pairs.emplace(word[i], word[i + 1]); + } + return pairs; + }; + + std::string whitespace_clean(const std::string& text) { + std::regex whitespace_regex("\\s+"); + std::string result = std::regex_replace(text, whitespace_regex, " "); + + size_t start = result.find_first_not_of(" "); + size_t end = result.find_last_not_of(" "); + + if (start == std::string::npos) { + return ""; + } + + return result.substr(start, end - start + 1); + }; + + // Function to decode HTML entities + std::string html_unescape(const std::string& input) { + static const std::unordered_map html_entities = { + {""", "\""}, {"&", "&"}, {"<", "<"}, + {">", ">"}, {" ", " "}, {"'", "'"} + }; + + std::string output = input; + for (const auto& [entity, character] : html_entities) { + size_t pos = 0; + while ((pos = output.find(entity, pos)) != std::string::npos) { + output.replace(pos, entity.length(), character); + pos += character.length(); + } + } + return output; + }; + + // Function to trim leading and trailing whitespace + std::string trim(const std::string& str) { + size_t start = str.find_first_not_of(" \t\n\r"); + size_t end = str.find_last_not_of(" \t\n\r"); + return (start == std::string::npos) ? "" : str.substr(start, end - start + 1); + }; + + // Main `basic_clean` function + std::string basic_clean(const std::string& text) { + std::string cleaned_text = text; + + // Decode HTML entities twice + cleaned_text = html_unescape(cleaned_text); + cleaned_text = html_unescape(cleaned_text); + + // Trim whitespace + cleaned_text = trim(cleaned_text); + + return cleaned_text; + }; + + int getStartTextToken() {return mStartTextToken;}; + int getEndTextToken() {return mEndTextToken;}; + +}; + + class OCRNetEngine { public: OCRNetEngine(const std::string& engine_path, const std::string& dict_path, - const bool upside_down=0, const DecodeMode decode_mode=CTC); + const bool upside_down, const DecodeMode decode_mode, const std::string& text_engine_path="", bool only_alnum=true, bool only_lowercase=true, const std::string& vocab_file="", const int vocab_size=32000); ~OCRNetEngine(); bool initTRTBuffer(BufferManager& buffer_mgr); @@ -32,16 +332,36 @@ class OCRNetEngine bool setOutputDeviceBuffer(DeviceBuffer& device_buffer, const int index); bool infer(BufferManager& buffer_mgr, std::vector>& de_texts, const cudaStream_t& stream = 0); - + std::pair clip4strDecode( const std::vector& output_prob, const int batch_idx, const int context_len, const int charset_len); int mTRTInputBufferIndex; int mTRTOutputBufferIndex; + // CLIP4STR buffer + int mVisTRTOutputDecodeProbsBufferIndex; + int mVisTRTOutputImgFeatureBufferIndex; + int mVisTRTOutputContextBufferIndex; + + int mTextTRTInputTextTokenBufferIndex; + int mTextTRTOutputBufferIndex; private: std::unique_ptr mEngine; std::vector mDict; bool mUDFlag; DecodeMode mDecodeMode; - // int mDecodeOutputBufferIndex; + // CLIP4STR + std::unique_ptr mTextEngine; + std::string mVisualOutDecodeProbName = "visual_decode_probs"; + std::string mVisualOutContextName = "tgt_in"; + std::string mVisualOutImgFeaturebName = "img_feature"; + std::string mTextInTokenName = "text_token"; + std::string mTextOutLogitName = "logits"; + int mVocabSize = 32000; + int mMaxContextLen = 16; + // only output alpha and digit + bool mOnlyAlNum = true; + // only output lower case alpha + bool mOnlyLowerCase = true; + SimpleTokenizer mTokenizer; }; } #endif \ No newline at end of file diff --git a/src/RectEngine.cpp b/src/RectEngine.cpp index e32add3..ee1bf41 100644 --- a/src/RectEngine.cpp +++ b/src/RectEngine.cpp @@ -15,7 +15,7 @@ RectEngine::RectEngine(const int& output_height, const int& output_width, const checkCudaErrors(cublasCreate(&mHandle)); mIsRGBOutput = (rec_output_channel==3) ? true : false; #ifdef RECT_DEBUG - mImgSavePath = "/localhome/local-bizhao/dataset/pcb_images/FRAME_0_1_H.jpg"; + mImgSavePath = "/home/binz/ssd_4t/NVIDIA-Optical-Character-Detection-and-Recognition-Solution/c++_samples/test_img/"; #endif } @@ -53,6 +53,10 @@ bool RectEngine::initBuffer(BufferManager& buffer_mgr) // if we use gray ouput, we need to init this RGB buff to calculate the gray ouput mRGBOutputBufferDevIdx = buffer_mgr.initDeviceBuffer(mOcrInferBatch*3*mOutputHeight*mOutputWidth, sizeof(float)); } + else + { + mGrayOutputBufferDevIdx = buffer_mgr.initDeviceBuffer(mOcrInferBatch*mOutputHeight*mOutputWidth, sizeof(float)); + } #ifdef RECT_DEBUG @@ -140,56 +144,72 @@ RectEngine::infer(void* input_data, const Dims& input_shape, #ifdef RECT_DEBUG int img_size = input_shape.d[1] * input_shape.d[2] * input_shape.d[3]; - checkCudaErrors(cudaMemcpy(buffer_mgr.mHostBuffer[mGrayOutputBufferHostIdx].data(), buffer_mgr.mDeviceBuffer[mGrayOutputBufferDevIdx].data(), buffer_mgr.mDeviceBuffer[mGrayOutputBufferDevIdx].nbBytes(), cudaMemcpyDeviceToHost)); - checkCudaErrors(cudaMemcpy(buffer_mgr.mHostBuffer[mRGBOutputBufferHostIdx].data(), buffer_mgr.mDeviceBuffer[mRGBOutputBufferDevIdx].data(), buffer_mgr.mHostBuffer[mRGBOutputBufferHostIdx].size()*sizeof(float), cudaMemcpyDeviceToHost)); + if (!mIsRGBOutput) + { + checkCudaErrors(cudaMemcpy(buffer_mgr.mHostBuffer[mGrayOutputBufferHostIdx].data(), buffer_mgr.mDeviceBuffer[mGrayOutputBufferDevIdx].data(), buffer_mgr.mDeviceBuffer[mGrayOutputBufferDevIdx].nbBytes(), cudaMemcpyDeviceToHost)); + } + else + { + checkCudaErrors(cudaMemcpy(buffer_mgr.mHostBuffer[mRGBOutputBufferHostIdx].data(), buffer_mgr.mDeviceBuffer[mRGBOutputBufferDevIdx].data(), buffer_mgr.mHostBuffer[mRGBOutputBufferHostIdx].size()*sizeof(float), cudaMemcpyDeviceToHost)); + } for (int idx_poly=0; idx_poly(buffer_mgr.mHostBuffer[mRGBOutputBufferHostIdx].data() + idx_poly*pt_img_size); - cv::Mat pt_frame(mOutputHeight, mOutputWidth, CV_8UC3, h_pt_data); - std::string pt_img_file = mImgSavePath + std::to_string(polys_to_imgs[idx_poly]) + '_' + std::to_string(idx_poly) + "_pt_cuda.png"; - cv::imwrite(pt_img_file, pt_frame); - - // write gray img - int gray_img_size = mOutputWidth*mOutputHeight; - float* h_gray_data = static_cast( buffer_mgr.mHostBuffer[mGrayOutputBufferHostIdx].data()+ idx_poly*gray_img_size*sizeof(float)); - uchar h_gray_data_uchar[224*224]; - float gray_data; - for(int n=0; n(buffer_mgr.mHostBuffer[mRGBOutputBufferHostIdx].data() + idx_poly*pt_img_size*sizeof(float)); + uchar h_pt_data_uchar[3*224*224]; + for(int n=0; n(buffer_mgr.mHostBuffer[mGrayOutputBufferHostIdx].data()+(idx_poly+polys_to_imgs.size())*gray_img_size*sizeof(float)); + // write gray img + int gray_img_size = mOutputWidth*mOutputHeight; + float* h_gray_data = static_cast( buffer_mgr.mHostBuffer[mGrayOutputBufferHostIdx].data()+ idx_poly*gray_img_size*sizeof(float)); + uchar h_gray_data_uchar[224*224]; + float gray_data; for(int n=0; n(buffer_mgr.mHostBuffer[mGrayOutputBufferHostIdx].data()+(idx_poly+polys_to_imgs.size())*gray_img_size*sizeof(float)); + for(int n=0; nsetTensorAddress(mInputName.c_str(), buffer); } +void +TRTEngine::setInputBufferbyName(void* buffer, std::string& tensorName) +{ + mContext->setTensorAddress(tensorName.c_str(), buffer); +} + void TRTEngine::setInputShape(const Dims shape) @@ -121,6 +127,20 @@ TRTEngine::setInputShape(const Dims shape) } +void +TRTEngine::setInputBatchSizebyName(const std::string& tensorName, const int batch_size) +{ + nvinfer1::Dims maxInputShape = mEngine->getProfileShape(tensorName.c_str(), 0, OptProfileSelector::kMAX); + maxInputShape.d[0] = batch_size; + mContext->setInputShape(tensorName.c_str(), maxInputShape); + mExactOutputShapes.clear(); + for(int i = 0; i < mOutputNames.size(); ++i) + { + mExactOutputShapes.emplace(mOutputNames[i], mContext->getTensorShape(mOutputNames[i].c_str())); + } +} + + size_t TRTEngine::getMaxOutputBufferSize() { @@ -139,6 +159,41 @@ TRTEngine::getMaxOutputBufferSize() } +size_t +TRTEngine::getMaxTrtIoTensorSizeByName(std::string& tensorName) +{ + nvinfer1::Dims tensorDims = mContext->getTensorShape(tensorName.c_str()); + + size_t data_size = 1; + for(int i = 0; i < tensorDims.nbDims; ++i) + { + if (tensorDims.d[i] == -1) + { + data_size *= mMaxBatchSize; + } + else + { + data_size *= tensorDims.d[i]; + } + + } + return data_size; +} + + +size_t +TRTEngine::getTrtIoTensorDtypeSizeByName(std::string& tensorName) +{ + int bindingIndex = mEngine->getBindingIndex(tensorName.c_str()); + if (bindingIndex == -1) { + std::cerr << "Tensor name not found in the engine." << std::endl; + return 0; + } + nvinfer1::DataType dataType = mEngine->getBindingDataType(bindingIndex); + return sizeof(dataType); +} + + void TRTEngine::setOutputBuffer(void* buffer) { @@ -159,6 +214,12 @@ TRTEngine::setOutputBuffer(void* buffer) } +void TRTEngine::setOutputBufferByName(void* buffer, std::string& tensorName) +{ + mContext->setTensorAddress(tensorName.c_str(), buffer); +} + + const void* TRTEngine::getOutputAddr(std::string output_name) { diff --git a/src/TRTEngine.h b/src/TRTEngine.h index a6f3205..2caf7c2 100644 --- a/src/TRTEngine.h +++ b/src/TRTEngine.h @@ -45,8 +45,12 @@ class TRTEngine // Manager will only give a buffer with max size and let TRT Engine to use the buffer void setInputBuffer(void* buffer); + void setInputBufferbyName(void* buffer, std::string& tensorName); + void setInputShape(const Dims shape); + void setInputBatchSizebyName(const std::string& tensorName, const int batch_size); + size_t getMaxInputBufferSize(); Dims getExactInputShape() {return mExactInputShape;}; @@ -57,8 +61,14 @@ class TRTEngine size_t getMaxOutputBufferSize(); + size_t getMaxTrtIoTensorSizeByName(std::string& tensorName); + + size_t getTrtIoTensorDtypeSizeByName(std::string& tensorName); + void setOutputBuffer(void* buffer); + void setOutputBufferByName(void* buffer, std::string& tensorName); + const void* getOutputAddr(std::string output_name); Dims getExactOutputShape(std::string output_name); diff --git a/src/nvOCDR.cpp b/src/nvOCDR.cpp index 48b70a7..d389d3b 100644 --- a/src/nvOCDR.cpp +++ b/src/nvOCDR.cpp @@ -285,9 +285,9 @@ nvOCDR::nvOCDR(nvOCDRParam param): { decode_mode = DecodeMode::Attention; } - else if (param.ocrnet_decode == OCRNetDecode::CLIP) + else if (param.ocrnet_decode == OCRNetDecode::Transformer) { - decode_mode = DecodeMode::CLIP; + decode_mode = DecodeMode::Transformer; } else { @@ -296,14 +296,32 @@ nvOCDR::nvOCDR(nvOCDRParam param): // Init ocrnet std::string ocr_engine_path(param.ocrnet_trt_engine_path); + std::string ocr_text_engine_path = ""; + std::string vocab_file = ""; + bool only_alnum = true; + bool only_lowercase = true; + int vocab_size = 0; + if (param.ocrnet_decode == OCRNetDecode::Transformer) + { + ocr_text_engine_path = std::string(param.ocrnet_text_trt_engine_path); + only_alnum = param.ocrnet_only_alnum; + only_lowercase = param.ocrnet_only_lowercase; + vocab_file = std::string(param.ocrnet_vocab_file); + vocab_size = param.ocrnet_vocab_size; + } std::string ocr_dict_path(param.ocrnet_dict_file); mOCRNet = std::move(std::unique_ptr(new OCRNetEngine(ocr_engine_path, ocr_dict_path, upsidedown, - decode_mode))); + decode_mode, + ocr_text_engine_path, + only_alnum, + only_lowercase, + vocab_file, + vocab_size))); // Init input and output buffer for OCRNet TRT inference mOCRNet->initTRTBuffer(mBuffMgr); - mOCRNetInputShape.nbDims=4; + mOCRNetInputShape.nbDims = 4; mOCRNetInputShape.d[0] = -1; // Dynamic batch size mOCRNetInputShape.d[1] = param.ocrnet_infer_input_shape[0]; mOCRNetInputShape.d[2] = param.ocrnet_infer_input_shape[1];