Skip to content

Commit

Permalink
feat: use string enum for pooling_type
Browse files Browse the repository at this point in the history
to make it the same as llama.rn
  • Loading branch information
jhen0409 committed Jan 11, 2025
1 parent ff47142 commit 7321702
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export type LlamaModelOptions = {
model: string
embedding?: boolean
embd_normalize?: number
pooling_type?: number
pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank'
n_ctx?: number
n_batch?: number
n_ubatch?: number
Expand Down
14 changes: 12 additions & 2 deletions src/LlamaContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
throw std::runtime_error("Unsupported cache type: " + s);
}

// Translates a pooling type name into the corresponding LLAMA_POOLING_TYPE_*
// value. Recognized names are "none", "mean", "cls", "last" and "rank"; the
// comparison is exact (case-sensitive). Any other string, including the empty
// string, yields LLAMA_POOLING_TYPE_UNSPECIFIED rather than raising an error.
static int32_t pooling_type_from_str(const std::string & s) {
  struct NamedPoolingType {
    const char * name;
    int32_t value;
  };

  // Lookup table of every accepted name and the value it maps to.
  static const NamedPoolingType known_types[] = {
    {"none", LLAMA_POOLING_TYPE_NONE},
    {"mean", LLAMA_POOLING_TYPE_MEAN},
    {"cls", LLAMA_POOLING_TYPE_CLS},
    {"last", LLAMA_POOLING_TYPE_LAST},
    {"rank", LLAMA_POOLING_TYPE_RANK},
  };

  for (const auto & entry : known_types) {
    if (s == entry.name) {
      return entry.value;
    }
  }

  // No match: fall back to the unspecified value.
  return LLAMA_POOLING_TYPE_UNSPECIFIED;
}

// construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers,
// use_mlock, use_mmap }): LlamaContext throws error
LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
Expand Down Expand Up @@ -112,8 +121,9 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
params.n_ubatch = params.n_batch;
}
params.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
int32_t pooling_type = get_option<int32_t>(options, "pooling_type", -1);
params.pooling_type = (enum llama_pooling_type) pooling_type;
params.pooling_type = (enum llama_pooling_type) pooling_type_from_str(
get_option<std::string>(options, "pooling_type", "").c_str()
);

params.cpuparams.n_threads =
get_option<int32_t>(options, "n_threads", cpu_get_num_math() / 2);
Expand Down

0 comments on commit 7321702

Please sign in to comment.