From 7321702c4c419bb873a6c26d53fe8c643653a420 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Sat, 11 Jan 2025 16:00:43 +0800 Subject: [PATCH] feat: use string enum for pooling_type to make it the same as llama.rn --- lib/binding.ts | 2 +- src/LlamaContext.cpp | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/lib/binding.ts b/lib/binding.ts index b8a4459..83e60fc 100644 --- a/lib/binding.ts +++ b/lib/binding.ts @@ -9,7 +9,7 @@ export type LlamaModelOptions = { model: string embedding?: boolean embd_normalize?: number - pooling_type?: number + pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank' n_ctx?: number n_batch?: number n_ubatch?: number diff --git a/src/LlamaContext.cpp b/src/LlamaContext.cpp index 0bf511a..eb9ee8d 100644 --- a/src/LlamaContext.cpp +++ b/src/LlamaContext.cpp @@ -82,6 +82,15 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { throw std::runtime_error("Unsupported cache type: " + s); } +static int32_t pooling_type_from_str(const std::string & s) { + if (s == "none") return LLAMA_POOLING_TYPE_NONE; + if (s == "mean") return LLAMA_POOLING_TYPE_MEAN; + if (s == "cls") return LLAMA_POOLING_TYPE_CLS; + if (s == "last") return LLAMA_POOLING_TYPE_LAST; + if (s == "rank") return LLAMA_POOLING_TYPE_RANK; + return LLAMA_POOLING_TYPE_UNSPECIFIED; +} + // construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers, // use_mlock, use_mmap }): LlamaContext throws error LlamaContext::LlamaContext(const Napi::CallbackInfo &info) @@ -112,8 +121,9 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info) params.n_ubatch = params.n_batch; } params.embd_normalize = get_option(options, "embd_normalize", 2); - int32_t pooling_type = get_option(options, "pooling_type", -1); - params.pooling_type = (enum llama_pooling_type) pooling_type; + params.pooling_type = (enum llama_pooling_type) pooling_type_from_str( + get_option(options, "pooling_type", "").c_str() + ); params.cpuparams.n_threads = 
get_option(options, "n_threads", cpu_get_num_math() / 2);