diff --git a/CMakeLists.txt b/CMakeLists.txt index fa6395f..baf6679 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,23 @@ cmake_minimum_required(VERSION 3.10) project(MetalTranslate) +include(ExternalProject) + +set(EXTERNAL_INSTALL_LOCATION ${CMAKE_BINARY_DIR}/external) + +ExternalProject_Add(CTranslate2 + GIT_REPOSITORY https://github.com/OpenNMT/CTranslate2 + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${EXTERNAL_INSTALL_LOCATION} -DWITH_MKL=OFF -DWITH_DNNL=ON +) + +ExternalProject_Add(Tokenizer + GIT_REPOSITORY https://github.com/OpenNMT/Tokenizer + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${EXTERNAL_INSTALL_LOCATION} +) + +include_directories(${EXTERNAL_INSTALL_LOCATION}/include) +link_directories(${EXTERNAL_INSTALL_LOCATION}/lib) + add_executable(metaltranslate src/main.cpp) set(TARGET_H @@ -10,8 +27,11 @@ set(TARGET_H target_sources(metaltranslate PRIVATE src/MetalTranslate.cpp) -add_subdirectory(third_party/CTranslate2/) +#add_subdirectory(third_party/CTranslate2/) target_link_libraries(metaltranslate ctranslate2) -add_subdirectory(third_party/Tokenizer/) +##add_subdirectory(third_party/Tokenizer/) target_link_libraries(metaltranslate OpenNMTTokenizer) + +#target_link_libraries(metaltranslate cpu_features) + diff --git a/src/MetalTranslate.cpp b/src/MetalTranslate.cpp index d039778..962ba60 100644 --- a/src/MetalTranslate.cpp +++ b/src/MetalTranslate.cpp @@ -1,52 +1,53 @@ #include "MetalTranslate.h" -#include +#include #include #include namespace MetalTranslate { -MetalTranslate::MetalTranslate(MetalTranslateConfig config) { - this->_config = config; -} + MetalTranslate::MetalTranslate(MetalTranslateConfig config) { + this->_config = config; + } -std::string MetalTranslate::Translate(std::string source, - std::string source_code, - std::string target_code) { + std::string MetalTranslate::Translate(std::string source, + std::string source_code, + std::string target_code) { - // Tokenizer - onmt::Tokenizer tokenizer(this->_config.ModelPath + "sentencepiece.model"); - std::vector tokens; - tokenizer.tokenize(source, tokens); + // Tokenizer + onmt::Tokenizer tokenizer(this->_config.ModelPath + "sentencepiece.model"); + std::vector tokens; + tokenizer.tokenize(source, tokens); - std::string source_prefix = "__" + source_code + "__"; - tokens.insert(tokens.begin(), source_prefix); + std::string source_prefix = "__" + source_code + "__"; + tokens.insert(tokens.begin(), source_prefix); - // CTranslate2 - const size_t num_translators = 1; - const size_t num_threads_per_translator = 0; // Unused with DNNL - ctranslate2::TranslatorPool translator( - num_translators, num_threads_per_translator, - this->_config.ModelPath + "model", ctranslate2::Device::CPU); + // CTranslate2 + const size_t num_translators = 1; + const size_t num_threads_per_translator = 0; // Unused with DNNL - const std::vector> batch = {tokens}; - const std::vector> target_prefix = { - {"__" + target_code + "__"}}; - const int max_batch_size = 2024; + const std::vector> batch = { {"▁H", "ello", "▁world", "!"} }; - const std::vector results = - translator.translate_batch(batch, target_prefix); + ctranslate2::Translator translator(this->_config.ModelPath, ctranslate2::Device::CPU); + //const std::vector results = translator.translate_batch(batch); - const std::vector translatedTokens = results[0].output(); + const std::vector> target_prefix = { + {"__" + target_code + "__"} }; + const int max_batch_size = 2024; - std::string result = tokenizer.detokenize(translatedTokens); + const std::vector results = + translator.translate_batch(batch, target_prefix); - // Remove target prefix - // __es__ Traducción de texto con MetalTranslate - // -> Traducción de texto con MetalTranslate - result = result.substr(7); + const std::vector translatedTokens = results[0].output(); - return result; -} + std::string result = tokenizer.detokenize(translatedTokens); -} // namespace MetalTranslate + // Remove target prefix + // __es__ Traducción de texto con MetalTranslate + // -> Traducción de texto con MetalTranslate + result = result.substr(7); + + return result; + } + +} // namespace MetalTranslate \ No newline at end of file diff --git a/src/MetalTranslateConfig.h b/src/MetalTranslateConfig.h index 26ba793..1f8cd87 100644 --- a/src/MetalTranslateConfig.h +++ b/src/MetalTranslateConfig.h @@ -3,6 +3,6 @@ namespace MetalTranslate { class MetalTranslateConfig { public: - std::string ModelPath = "models/translate-fairseq_m2m_100_418M/"; + std::string ModelPath = "models/nllb-200-distilled-600M-int8/"; }; } // namespace MetalTranslate diff --git a/third_party/CTranslate2 b/third_party/CTranslate2 deleted file mode 160000 index 4908b9d..0000000 --- a/third_party/CTranslate2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4908b9d39ab135ec51f0209ef1ed74b267428c32 diff --git a/third_party/README.md b/third_party/README.md deleted file mode 100644 index 36702b7..0000000 --- a/third_party/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# Third party dependencies -- [CTranslate2](https://github.com/OpenNMT/CTranslate2) -- [OpenNMT Tokenizer](https://github.com/OpenNMT/Tokenizer) -- [SentencePiece](https://github.com/google/sentencepiece) diff --git a/third_party/Tokenizer b/third_party/Tokenizer deleted file mode 160000 index 559b8e7..0000000 --- a/third_party/Tokenizer +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 559b8e716091be26084eea01f0404ecaa2939f80