Skip to content

Commit

Permalink
feat(backend): remove constexpr
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Dec 3, 2024
1 parent 881527a commit 6253064
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 11 deletions.
4 changes: 2 additions & 2 deletions backends/trtllm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
target_link_libraries(tgi_trtllm_backend_tests PRIVATE CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_tests PRIVATE Catch2::Catch2WithMain tensorrt_llm nlohmann_json::nlohmann_json spdlog::spdlog)
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)

if(CMAKE_BUILD_TYPE MATCHES "Debug")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Werror -fsanitize=undefined -fsanitize=address")
Expand Down
8 changes: 4 additions & 4 deletions backends/trtllm/csrc/ffi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
#include <memory>
#include <thread>

#include <nvml.h>
#include <tensorrt_llm/common/tllmException.h>
#include <tensorrt_llm/plugins/api/tllmPlugin.h>

#include <spdlog/spdlog.h>
#include <spdlog/pattern_formatter.h>
#include <spdlog/fmt/fmt.h>

#include <backend.hpp>
#include <hardware.hpp>

namespace rust::behavior {
template<typename Try, typename Fail>
Expand Down Expand Up @@ -111,7 +111,7 @@ namespace huggingface::tgi::backends::trtllm {
}

void cancel(request_id_t requestId) noexcept {
SPDLOG_DEBUG(FMT_STRING("[FFI] cancelling request {:d}"), requestId);
SPDLOG_DEBUG("[FFI] cancelling request {:d}", requestId);
inner_.cancel(requestId);
}
};
Expand Down Expand Up @@ -144,7 +144,7 @@ namespace huggingface::tgi::backends::trtllm {

const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();
if (numGpus.has_value()) {
SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", numGpus.value());
SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", *numGpus);
} else {
SPDLOG_WARN("[FFI] Failed to detected Nvidia GPU(s) on the system");
// todo: throw
Expand Down
5 changes: 2 additions & 3 deletions backends/trtllm/csrc/hardware.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ namespace huggingface::tgi::hardware::cuda {
* Get the number of GPUs on the local machine
* @return std::nullopt if no device is available, otherwise >= 1
*/
std::optional<size_t> get_device_count() {
inline std::optional<size_t> get_device_count() {
uint32_t numGpus = 0;
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
return numGpus;
} else {
return std::nullopt;
}
return std::nullopt;
}

/**
Expand Down
68 changes: 66 additions & 2 deletions backends/trtllm/tests/test_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,43 @@

#include <catch2/catch_all.hpp>
#include <nlohmann/json.hpp>
#include "../csrc/backend.hpp"
#include <tensorrt_llm/executor/executor.h>

#include "backend.hpp"



using namespace huggingface::tgi::backends::trtllm;

TEST_CASE("parse generation_config.json", "[generation_config_t]")
/// Verify that a fully-populated generation_config.json is parsed
/// field-for-field: temperature, top_p and the per-token stop words.
TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
{
    const json raw = {{"temperature", 0.6}, {"top_p", 0.95}, {"eos_token_id", {1, 2, 3}}};
    const auto parsed = generation_config_t(raw);

    REQUIRE_THAT(parsed.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
    REQUIRE_THAT(parsed.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));

    // Every eos_token_id entry must surface as its own stop-word sequence.
    REQUIRE_FALSE(parsed.stop_words.empty());
    REQUIRE(parsed.stop_words.size() == raw["/eos_token_id"_json_pointer].size());

    const std::list<std::vector<int32_t>> expected{{1}, {2}, {3}};
    for (auto [got, want] : std::views::zip(parsed.stop_words, expected))
    {
        // Multi-token stop words are not supported yet, so each entry is a single token.
        REQUIRE(got.size() == 1);
        REQUIRE(want.size() == 1);
        REQUIRE_THAT(got, Catch::Matchers::UnorderedEquals(want));
    }
}

TEST_CASE("parse generation_config.json default", "[generation_config_t]")
{
const json config_j = {{"eos_token_id", {1,2,3}}};
const auto generation_config = generation_config_t(config_j);

REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));

REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());

Expand All @@ -25,8 +53,44 @@ TEST_CASE("parse generation_config.json", "[generation_config_t]")
}
}

/// Verify default behavior with an empty eos_token_id list and with a
/// completely empty config object: temperature/top_p fall back to 1.0
/// and no stop words are produced.
TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
{
    const json config_j = {{"eos_token_id", {}}};
    const auto generation_config = generation_config_t(config_j);

    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));

    REQUIRE(generation_config.stop_words.empty());

    const json config_j2 = {};
    // Fix: previously constructed from config_j by mistake, so the
    // empty-object case was never actually exercised.
    const auto generation_config2 = generation_config_t(config_j2);

    REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
    REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));

    REQUIRE(generation_config2.stop_words.empty());
}

/// Verify that an engine config with world_size > 1 selects the
/// ORCHESTRATOR communication mode for multi-GPU execution.
TEST_CASE("parallel_config", "[backend_workspace_t]")
{
    // Stage a fake engine folder in the system temp directory.
    const auto tmp_p = std::filesystem::temp_directory_path();
    const auto config_p = tmp_p / "config.json";
    const auto generation_config_p = tmp_p / "generation_config.json";

    // Scoped streams so the files are flushed and closed before being read back.
    {
        std::ofstream o_config(config_p);
        o_config << R"({"pretrained_config": {"mapping": {"world_size": 2}}})"_json;
    }
    {
        std::ofstream o_generation_config(generation_config_p);
        o_generation_config << R"({"eos_token_id": []})"_json;
    }

    const auto workspace = backend_workspace_t(absolute(tmp_p).generic_string(), absolute(tmp_p).generic_string());
    const auto parallel = workspace.parallel_config();
    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);

    // Clean up the shared temp directory so repeated runs (and sibling
    // tests) start from a fresh state.
    std::filesystem::remove(config_p);
    std::filesystem::remove(generation_config_p);
}

Expand Down

0 comments on commit 6253064

Please sign in to comment.