From 62530649b82876278b0d3d81356cb74367bfcc5c Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Tue, 3 Dec 2024 16:47:48 +0100
Subject: [PATCH] feat(backend): remove constexpr

---
 backends/trtllm/CMakeLists.txt         |  4 +-
 backends/trtllm/csrc/ffi.hpp           |  8 +--
 backends/trtllm/csrc/hardware.hpp      |  5 +-
 backends/trtllm/tests/test_backend.cpp | 68 +++++++++++++++++++++++++-
 4 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
index 61058fd4e6e..49e597d0c74 100644
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -77,8 +77,8 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
     add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
     target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
-    target_link_libraries(tgi_trtllm_backend_tests PRIVATE CUDA::cudart CUDA::nvml)
-    target_link_libraries(tgi_trtllm_backend_tests PRIVATE Catch2::Catch2WithMain tensorrt_llm nlohmann_json::nlohmann_json spdlog::spdlog)
+    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
+    target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
 
     if(CMAKE_BUILD_TYPE MATCHES "Debug")
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp
index dc9fdd0fbc4..de2333afe37 100644
--- a/backends/trtllm/csrc/ffi.hpp
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -4,14 +4,14 @@
 #include 
 #include 
+#include 
 
 #include 
 #include 
 #include 
-#include 
-#include 
 
 #include 
+#include 
 
 namespace rust::behavior {
     template
@@ -111,7 +111,7 @@
         }
 
         void cancel(request_id_t requestId) noexcept {
-            SPDLOG_DEBUG(FMT_STRING("[FFI] cancelling request {:d}"), requestId);
+            SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
             inner_.cancel(requestId);
         }
     };
@@ -144,7 +144,7 @@
 
             const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();
             if (numGpus.has_value()) {
-                SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", numGpus.value());
+                SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", *numGpus);
             } else {
                 SPDLOG_WARN("[FFI] Failed to detected Nvidia GPU(s) on the system");
                 // todo: throw
diff --git a/backends/trtllm/csrc/hardware.hpp b/backends/trtllm/csrc/hardware.hpp
index 480cf6800d4..8e5fa696dbb 100644
--- a/backends/trtllm/csrc/hardware.hpp
+++ b/backends/trtllm/csrc/hardware.hpp
@@ -16,13 +16,12 @@ namespace huggingface::tgi::hardware::cuda {
      * Get the number of GPUs on the local machine
      * @return std::nullopt if no device is available, otherwise >= 1
      */
-    std::optional<uint32_t> get_device_count() {
+    inline std::optional<uint32_t> get_device_count() {
         uint32_t numGpus = 0;
         if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
             return numGpus;
-        } else {
-            return std::nullopt;
         }
+        return std::nullopt;
     }
 
     /**
diff --git a/backends/trtllm/tests/test_backend.cpp b/backends/trtllm/tests/test_backend.cpp
index 215cb114778..e58a7e1a6db 100644
--- a/backends/trtllm/tests/test_backend.cpp
+++ b/backends/trtllm/tests/test_backend.cpp
@@ -4,15 +4,43 @@
 #include <catch2/catch_all.hpp>
 #include <nlohmann/json.hpp>
 
-#include "../csrc/backend.hpp"
+#include <fstream>
+
+#include "backend.hpp"
+
+
 using namespace huggingface::tgi::backends::trtllm;
 
-TEST_CASE("parse generation_config.json", "[generation_config_t]")
+TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
 {
     const json config_j = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1,2,3}}};
     const auto generation_config = generation_config_t(config_j);
 
+    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
+    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));
+
+    // Stop words
+    REQUIRE_FALSE(generation_config.stop_words.empty());
+    REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
+    for (auto [lhs, rhs] : std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1}, {2}, {3}}))
+    {
+        // Currently we do not support multi-token stop words
+        REQUIRE(lhs.size() == 1);
+        REQUIRE(rhs.size() == 1);
+        REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
+    }
+}
+
+TEST_CASE("parse generation_config.json default", "[generation_config_t]")
+{
+    const json config_j = {{"eos_token_id", {1,2,3}}};
+    const auto generation_config = generation_config_t(config_j);
+
+    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
     REQUIRE_FALSE(generation_config.stop_words.empty());
     REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
 
@@ -25,8 +53,44 @@ TEST_CASE("parse generation_config.json", "[generation_config_t]")
     }
 }
 
+TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
+{
+    const json config_j = {{"eos_token_id", {}}};
+    const auto generation_config = generation_config_t(config_j);
+
+    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+    REQUIRE(generation_config.stop_words.empty());
+
+    const json config_j2 = {};
+    const auto generation_config2 = generation_config_t(config_j2);
+
+    REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+    REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+    REQUIRE(generation_config2.stop_words.empty());
+}
+
 TEST_CASE("parallel_config", "[backend_workspace_t]")
 {
+    // Generate temporary folder
+    const auto tmp_p = std::filesystem::temp_directory_path();
+    const auto config_p = tmp_p / "config.json";
+    const auto generation_config_p = tmp_p / "generation_config.json";
+
+    // Generate content
+    std::ofstream o_config(config_p);
+    o_config << R"({"pretrained_config": {"mapping": {"world_size": 2}}})"_json;
+    o_config.close();
+
+    std::ofstream o_generation_config(generation_config_p);
+    o_generation_config << R"({"eos_token_id": []})"_json;
+    o_generation_config.close();
+
+    const auto workspace = backend_workspace_t(absolute(tmp_p).generic_string(), absolute(tmp_p).generic_string());
+    const auto parallel = workspace.parallel_config();
+    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
 }