diff --git a/packages/cli/src/index.ts b/packages/cli/src/index.ts index f5bb075..48b5906 100644 --- a/packages/cli/src/index.ts +++ b/packages/cli/src/index.ts @@ -12,7 +12,7 @@ import { existsSync } from "fs"; const convertType = ["q4_0", "q4_1", "f16", "f32"] as const; -type ConvertType = typeof convertType[number]; +type ConvertType = (typeof convertType)[number]; interface CLIInferenceArguments extends LLamaInferenceArguments, LLamaConfig { logger?: boolean; @@ -75,7 +75,7 @@ class InferenceCommand implements yargs.CommandModule { if (logger) { LLama.enableLogger(); } - const llama = LLama.create({ path: absolutePath, numCtxTokens }); + const llama = await LLama.create({ path: absolutePath, numCtxTokens }); llama.inference(rest, (result) => { switch (result.type) { case InferenceResultType.Data: diff --git a/packages/core/__test__/index.spec.ts b/packages/core/__test__/index.spec.ts index 20eac5f..3772c51 100644 --- a/packages/core/__test__/index.spec.ts +++ b/packages/core/__test__/index.spec.ts @@ -6,7 +6,7 @@ test( async () => { LLama.enableLogger(); - const llama = LLama.create({ + const llama = await LLama.create({ path: process.env.model?.toString()!, numCtxTokens: 128, }); diff --git a/packages/core/example/cachesession.ts b/packages/core/example/cachesession.ts index 7661700..d28fbf3 100644 --- a/packages/core/example/cachesession.ts +++ b/packages/core/example/cachesession.ts @@ -6,14 +6,15 @@ const saveSession = path.resolve(process.cwd(), "./tmp/session.bin"); LLama.enableLogger(); -const llama = LLama.create({ - path: model, - numCtxTokens: 128, -}); +const run = async () => { + const llama = await LLama.create({ + path: model, + numCtxTokens: 128, + }); -const template = `how are you`; + const template = `how are you`; -const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request. + const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: @@ -21,31 +22,34 @@ ${template} ### Response:`; -llama.inference( - { - prompt, - numPredict: 128, - temp: 0.2, - topP: 1, - topK: 40, - repeatPenalty: 1, - repeatLastN: 64, - seed: 0, - feedPrompt: true, - feedPromptOnly: true, - saveSession, - }, - (response) => { - switch (response.type) { - case InferenceResultType.Data: { - process.stdout.write(response.data?.token ?? ""); - break; - } - case InferenceResultType.End: - case InferenceResultType.Error: { - console.log(response); - break; + llama.inference( + { + prompt, + numPredict: 128, + temp: 0.2, + topP: 1, + topK: 40, + repeatPenalty: 1, + repeatLastN: 64, + seed: 0, + feedPrompt: true, + feedPromptOnly: true, + saveSession, + }, + (response) => { + switch (response.type) { + case InferenceResultType.Data: { + process.stdout.write(response.data?.token ?? ""); + break; + } + case InferenceResultType.End: + case InferenceResultType.Error: { + console.log(response); + break; + } } } - } -); + ); +}; + +run(); diff --git a/packages/core/example/embedding.ts b/packages/core/example/embedding.ts index be8aeb2..897e306 100644 --- a/packages/core/example/embedding.ts +++ b/packages/core/example/embedding.ts @@ -1,4 +1,4 @@ -import { EmbeddingResultType, LLama } from "../index"; +import { LLama } from "../index"; import path from "path"; import fs from "fs"; @@ -6,46 +6,42 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin"); LLama.enableLogger(); -const llama = LLama.create({ - path: model, - numCtxTokens: 128, -}); - -const getWordEmbeddings = (prompt: string, file: string) => { - llama.getWordEmbeddings( - { - prompt, - numPredict: 128, - temp: 0.2, - topP: 1, - topK: 40, - repeatPenalty: 1, - repeatLastN: 64, - seed: 0, - }, - (response) => { - switch (response.type) { - case EmbeddingResultType.Data: { - fs.writeFileSync( - path.resolve(process.cwd(), file), - JSON.stringify(response.data) - ); - break; - } - case EmbeddingResultType.Error: { - console.log(response); - break; - } - } - } +const getWordEmbeddings = async ( + llama: LLama, + prompt: string, + file: string +) => { + const response = await llama.getWordEmbeddings({ + prompt, + numPredict: 128, + temp: 0.2, + topP: 1, + topK: 40, + repeatPenalty: 1, + repeatLastN: 64, + seed: 0, + }); + + fs.writeFileSync( + path.resolve(process.cwd(), file), + JSON.stringify(response) ); }; -const dog1 = `My favourite animal is the dog`; -getWordEmbeddings(dog1, "./example/semantic-compare/dog1.json"); +const run = async () => { + const llama = await LLama.create({ + path: model, + numCtxTokens: 128, + }); + + const dog1 = `My favourite animal is the dog`; + getWordEmbeddings(llama, dog1, "./example/semantic-compare/dog1.json"); -const dog2 = `I have just adopted a cute dog`; -getWordEmbeddings(dog2, "./example/semantic-compare/dog2.json"); + const dog2 = `I have just adopted a cute dog`; + getWordEmbeddings(llama, dog2, "./example/semantic-compare/dog2.json"); + + const cat1 = `My favourite animal is the cat`; + getWordEmbeddings(llama, cat1, "./example/semantic-compare/cat1.json"); +}; -const cat1 = `My favourite animal is the cat`; -getWordEmbeddings(cat1, "./example/semantic-compare/cat1.json"); +run(); diff --git a/packages/core/example/inference.ts b/packages/core/example/inference.ts index 617b9a0..4e487e9 100644 --- a/packages/core/example/inference.ts +++ b/packages/core/example/inference.ts @@ -6,14 +6,15 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin"); LLama.enableLogger(); -const llama = LLama.create({ - path: model, - numCtxTokens: 128, -}); +const run = async () => { + const llama = await LLama.create({ + path: model, + numCtxTokens: 128, + }); -const template = `how are you`; + const template = `how are you`; -const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request. + const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: @@ -21,30 +22,31 @@ ${template} ### Response:`; -llama.inference( - { - prompt, - numPredict: 128, - temp: 0.2, - topP: 1, - topK: 40, - repeatPenalty: 1, - repeatLastN: 64, - seed: 0, - feedPrompt: true, - // persistSession, - }, - (response) => { - switch (response.type) { - case InferenceResultType.Data: { - process.stdout.write(response.data?.token ?? ""); - break; - } - case InferenceResultType.End: - case InferenceResultType.Error: { - console.log(response); - break; + llama.inference( + { + prompt, + numPredict: 128, + temp: 0.2, + topP: 1, + topK: 40, + repeatPenalty: 1, + repeatLastN: 64, + seed: 0, + feedPrompt: true, + }, + (response) => { + switch (response.type) { + case InferenceResultType.Data: { + process.stdout.write(response.data?.token ?? ""); + break; + } + case InferenceResultType.End: + case InferenceResultType.Error: { + console.log(response); + break; + } } } - } -); + ); +}; +run(); diff --git a/packages/core/example/loadsession.ts b/packages/core/example/loadsession.ts index a8b2500..b0457ed 100644 --- a/packages/core/example/loadsession.ts +++ b/packages/core/example/loadsession.ts @@ -6,34 +6,38 @@ const loadSession = path.resolve(process.cwd(), "./tmp/session.bin"); LLama.enableLogger(); -const llama = LLama.create({ - path: model, - numCtxTokens: 128, -}); +const run = async () => { + const llama = await LLama.create({ + path: model, + numCtxTokens: 128, + }); -llama.inference( - { - prompt: "", - numPredict: 128, - temp: 0.2, - topP: 1, - topK: 40, - repeatPenalty: 1, - repeatLastN: 64, - seed: 0, - loadSession, - }, - (response) => { - switch (response.type) { - case InferenceResultType.Data: { - process.stdout.write(response.data?.token ?? ""); - break; - } - case InferenceResultType.End: - case InferenceResultType.Error: { - console.log(response); - break; + llama.inference( + { + prompt: "", + numPredict: 128, + temp: 0.2, + topP: 1, + topK: 40, + repeatPenalty: 1, + repeatLastN: 64, + seed: 0, + loadSession, + }, + (response) => { + switch (response.type) { + case InferenceResultType.Data: { + process.stdout.write(response.data?.token ?? ""); + break; + } + case InferenceResultType.End: + case InferenceResultType.Error: { + console.log(response); + break; + } } } - } -); + ); +}; + +run(); \ No newline at end of file diff --git a/packages/core/example/tokenize.ts b/packages/core/example/tokenize.ts index 8924442..6665552 100644 --- a/packages/core/example/tokenize.ts +++ b/packages/core/example/tokenize.ts @@ -5,14 +5,17 @@ const model = path.resolve(process.cwd(), "../../ggml-alpaca-7b-q4.bin"); LLama.enableLogger(); -const llama = LLama.create({ - path: model, - numCtxTokens: 128, -}); +const run = async () => { + const llama = await LLama.create({ + path: model, + numCtxTokens: 128, + }); -const prompt = "My favourite animal is the cat"; + const prompt = "My favourite animal is the cat"; -llama.tokenize(prompt, (response) => { - console.log(response); - console.log(response.data.length); // 7 -}); + const tokens = await llama.tokenize(prompt); + + console.log(tokens); +}; + +run(); \ No newline at end of file diff --git a/packages/core/index.d.ts b/packages/core/index.d.ts index 235eec9..f47eee1 100644 --- a/packages/core/index.d.ts +++ b/packages/core/index.d.ts @@ -17,28 +17,6 @@ export interface InferenceResult { message?: string data?: InferenceToken } -/** - * Embedding result -*/ -export const enum EmbeddingResultType { - Data = 'Data', - Error = 'Error' -} -export interface EmbeddingResult { - type: EmbeddingResultType - message?: string - data?: Array -} -/** - * Tokenize result -*/ -export const enum TokenizeResultType { - Data = 'Data' -} -export interface TokenizeResult { - type: TokenizeResultType - data: Array -} /** * LLama model load config */ @@ -66,10 +44,6 @@ export interface LLamaConfig { */ useMmap?: boolean } -export interface LoadModelResult { - error: boolean - message?: string -} export interface LLamaInferenceArguments { /** * Sets the number of threads to use @@ -194,11 +168,11 @@ export class LLama { /** Enable logger. */ static enableLogger(): void /** Create a new LLama instance. */ - static create(config: LLamaConfig): LLama + static create(config: LLamaConfig): Promise /** Get the tokenized result as number array, the result will be passed to the callback function. */ - tokenize(params: string, callback: (result: TokenizeResult) => void): void + tokenize(params: string): Promise> /** Get the embedding result as number array, the result will be passed to the callback function. */ - getWordEmbeddings(params: LLamaInferenceArguments, callback: (result: EmbeddingResult) => void): void + getWordEmbeddings(params: LLamaInferenceArguments): Promise> /** Streaming the inference result as string, the result will be passed to the callback function. */ inference(params: LLamaInferenceArguments, callback: (result: InferenceResult) => void): void } diff --git a/packages/core/index.js b/packages/core/index.js index 889e74a..26e0dc8 100644 --- a/packages/core/index.js +++ b/packages/core/index.js @@ -252,11 +252,9 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { InferenceResultType, EmbeddingResultType, TokenizeResultType, ElementType, convert, LLama } = nativeBinding +const { InferenceResultType, ElementType, convert, LLama } = nativeBinding module.exports.InferenceResultType = InferenceResultType -module.exports.EmbeddingResultType = EmbeddingResultType -module.exports.TokenizeResultType = TokenizeResultType module.exports.ElementType = ElementType module.exports.convert = convert module.exports.LLama = LLama diff --git a/packages/core/src/lib.rs b/packages/core/src/lib.rs index 3ddec11..ddcba24 100644 --- a/packages/core/src/lib.rs +++ b/packages/core/src/lib.rs @@ -7,18 +7,11 @@ extern crate napi_derive; mod llama; mod types; -use std::{ - path::Path, - sync::{mpsc::channel, Arc}, - thread, time, -}; +use std::{path::Path, sync::Arc}; -use llama::LLamaChannel; +use llama::LLamaInternal; use llama_rs::convert::convert_pth_to_ggml; -use types::{ - EmbeddingResult, InferenceResult, InferenceResultType, LLamaConfig, LLamaInferenceArguments, - LoadModelResult, TokenizeResult, -}; +use types::{InferenceResult, LLamaConfig, LLamaInferenceArguments}; use napi::{ bindgen_prelude::*, @@ -80,9 +73,8 @@ pub async fn convert(path: String, element_type: ElementType) -> Result<()> { } #[napi(js_name = "LLama")] -#[derive(Clone)] pub struct LLama { - llama_channel: Arc, + llama: Arc, } /// LLama class is a Rust wrapper for llama-rs. @@ -99,109 +91,24 @@ impl LLama { /// Create a new LLama instance. #[napi] - pub fn create(config: LLamaConfig) -> Result { - let (load_result_sender, load_result_receiver) = channel::(); - - let llama_channel = LLamaChannel::new(); - - llama_channel.load_model(config, load_result_sender); - - // currently this loop blocked main thread, will try improve in the future - 'waiting_load: loop { - let recv = load_result_receiver.recv(); - match recv { - Ok(r) => { - if r.error { - return Err(Error::new( - Status::InvalidArg, - r.message.unwrap_or("Unknown Error".to_string()), - )); - } - break 'waiting_load; - } - _ => { - thread::yield_now(); - } - } - } + pub async fn create(config: LLamaConfig) -> Result { + let llama = LLamaInternal::load_model(&config).await?; - Ok(LLama { llama_channel }) + Ok(LLama { + llama: Arc::new(llama), + }) } /// Get the tokenized result as number array, the result will be passed to the callback function. #[napi] - pub fn tokenize( - &self, - params: String, - #[napi(ts_arg_type = "(result: TokenizeResult) => void")] callback: JsFunction, - ) -> Result<()> { - let (tokenize_sender, tokenize_receiver) = channel::(); - - let tsfn: ThreadsafeFunction = callback - .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { - Ok(vec![ctx.value]) - })?; - - let llama_channel = self.llama_channel.clone(); - - llama_channel.tokenize(¶ms, tokenize_sender); - - thread::spawn(move || { - 'waiting_tokenize: loop { - let recv = tokenize_receiver.recv(); - match recv { - Ok(callback) => { - tsfn.call(callback, ThreadsafeFunctionCallMode::Blocking); - break 'waiting_tokenize; - } - _ => { - thread::yield_now(); - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); - }); - - Ok(()) + pub async fn tokenize(&self, params: String) -> Result> { + self.llama.tokenize(¶ms).await } /// Get the embedding result as number array, the result will be passed to the callback function. #[napi] - pub fn get_word_embeddings( - &self, - params: LLamaInferenceArguments, - #[napi(ts_arg_type = "(result: EmbeddingResult) => void")] callback: JsFunction, - ) -> Result<()> { - let (embedding_sender, embedding_receiver) = channel::(); - - let tsfn: ThreadsafeFunction = callback - .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { - Ok(vec![ctx.value]) - })?; - - let llama_channel = self.llama_channel.clone(); - - llama_channel.get_word_embedding(params, embedding_sender); - - thread::spawn(move || { - 'waiting_embedding: loop { - let recv = embedding_receiver.recv(); - match recv { - Ok(callback) => { - tsfn.call(callback, ThreadsafeFunctionCallMode::Blocking); - break 'waiting_embedding; - } - _ => { - thread::yield_now(); - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); - }); - - Ok(()) + pub async fn get_word_embeddings(&self, params: LLamaInferenceArguments) -> Result> { + self.llama.get_word_embedding(¶ms).await } /// Streaming the inference result as string, the result will be passed to the callback function. @@ -211,37 +118,19 @@ impl LLama { params: LLamaInferenceArguments, #[napi(ts_arg_type = "(result: InferenceResult) => void")] callback: JsFunction, ) -> Result<()> { - let (inference_sender, inference_receiver) = channel::(); - let tsfn: ThreadsafeFunction = callback .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { Ok(vec![ctx.value]) })?; - let llama_channel = self.llama_channel.clone(); - - llama_channel.inference(params, inference_sender); - - thread::spawn(move || { - 'waiting_inference: loop { - let recv = inference_receiver.recv(); - match recv { - Ok(callback) => match callback.r#type { - InferenceResultType::End => { - tsfn.call(callback, ThreadsafeFunctionCallMode::Blocking); - break 'waiting_inference; - } - _ => { - tsfn.call(callback, ThreadsafeFunctionCallMode::NonBlocking); - } - }, - _ => { - thread::yield_now(); - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); + let llama = self.llama.clone(); + + tokio::spawn(async move { + llama + .inference(¶ms, |result| { + tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); + }) + .await; }); Ok(()) diff --git a/packages/core/src/llama.rs b/packages/core/src/llama.rs index 32e5e1d..46d33d9 100644 --- a/packages/core/src/llama.rs +++ b/packages/core/src/llama.rs @@ -3,18 +3,8 @@ use std::{ fs::File, io::{BufReader, BufWriter}, path::Path, - sync::{ - mpsc::{channel, Receiver, Sender, TryRecvError}, - Arc, Mutex, - }, - thread, }; -use crate::types::{ - EmbeddingResult, EmbeddingResultType, InferenceResult, InferenceResultType, InferenceToken, - LLamaCommand, LLamaConfig, LLamaInferenceArguments, LoadModelResult, TokenizeResult, - TokenizeResultType, -}; use anyhow::{Error, Result}; use llama_rs::{ EvaluateOutputRequest, InferenceError, InferenceParameters, InferenceSession, @@ -23,16 +13,14 @@ use llama_rs::{ use rand::SeedableRng; use zstd::{zstd_safe::CompressionLevel, Decoder, Encoder}; -const CACHE_COMPRESSION_LEVEL: CompressionLevel = 1; +use crate::types::{ + InferenceResult, InferenceResultType, InferenceToken, LLamaConfig, LLamaInferenceArguments, +}; -#[derive(Clone)] -pub struct LLamaChannel { - command_sender: Sender, - command_receiver: Arc>>, -} +const CACHE_COMPRESSION_LEVEL: CompressionLevel = 1; -struct LLamaInternal { - model: Option, +pub struct LLamaInternal { + pub model: Model, } fn parse_bias(s: &str) -> Result { @@ -40,7 +28,7 @@ fn parse_bias(s: &str) -> Result { } impl LLamaInternal { - pub fn load_model(&mut self, params: &LLamaConfig, sender: &Sender) { + pub async fn load_model(params: &LLamaConfig) -> Result { let num_ctx_tokens = params.num_ctx_tokens.unwrap_or(512); let use_mmap = params.use_mmap.unwrap_or(true); log::info!("num_ctx_tokens: {}", num_ctx_tokens); @@ -102,43 +90,25 @@ impl LLamaInternal { } }, ) { - self.model = Some(model); - log::info!("Model fully loaded!"); - sender - .send(LoadModelResult { - error: false, - message: None, - }) - .unwrap(); + Ok(LLamaInternal { model }) } else { - sender - .send(LoadModelResult { - error: true, - message: Some("Could not load model".to_string()), - }) - .unwrap(); + // TODO: optimiza error handling + Err(napi::Error::from_reason("Could not load model")) } } - pub fn tokenize(&self, text: &str, sender: &Option>) -> Vec { - let vocab = self.model.as_ref().unwrap().vocabulary(); + pub async fn tokenize(&self, text: &str) -> Result, napi::Error> { + let vocab = self.model.vocabulary(); let tokens = vocab .tokenize(text, false) .unwrap() .iter() .map(|(_, tid)| *tid) .collect::>(); - if let Some(sender) = sender { - sender - .send(TokenizeResult { - data: tokens.clone(), - r#type: TokenizeResultType::Data, - }) - .unwrap(); - } - tokens + + Ok(tokens) } fn get_inference_params(&self, params: &LLamaInferenceArguments) -> InferenceParameters { @@ -209,7 +179,7 @@ impl LLamaInternal { persist_session: Option<&Path>, inference_session_params: InferenceSessionParameters, ) -> Result { - let model = self.model.as_ref().ok_or(Error::msg("Model not loaded"))?; + let model = &self.model; fn load(model: &Model, path: &Path) -> Result { let file = File::open(path)?; @@ -258,14 +228,13 @@ impl LLamaInternal { .unwrap() } - pub fn get_word_embedding( + pub async fn get_word_embedding( &self, params: &LLamaInferenceArguments, - sender: &Sender, - ) { + ) -> Result, napi::Error> { let mut session = self.start_new_session(params); let inference_params = self.get_inference_params(params); - let model = self.model.as_ref().unwrap(); + let model = &self.model; let prompt_for_feed = format!(" {}", params.prompt); if let Err(InferenceError::ContextFull) = @@ -273,16 +242,10 @@ impl LLamaInternal { Ok(()) }) { - sender - .send(EmbeddingResult { - r#type: EmbeddingResultType::Error, - message: Some("Context window full.".to_string()), - data: None, - }) - .unwrap(); + return Err(napi::Error::from_reason("Context window full.")); } - let end_token = self.tokenize("\n", &None); + let end_token = self.tokenize("\n").await.unwrap(); let mut output_request = EvaluateOutputRequest { all_logits: None, @@ -296,20 +259,20 @@ impl LLamaInternal { &mut output_request, ); - sender - .send(EmbeddingResult { - r#type: EmbeddingResultType::Data, - message: None, - data: output_request - .embeddings - .map(|embd| embd.into_iter().map(|data| data.into()).collect()), - }) - .unwrap(); + let output: Option> = output_request + .embeddings + .map(|embd| embd.into_iter().map(|data| data.into()).collect()); + + Ok(output.unwrap_or(Vec::new())) } - pub fn inference(&mut self, params: &LLamaInferenceArguments, sender: &Sender) { + pub async fn inference( + &self, + params: &LLamaInferenceArguments, + callback: impl Fn(InferenceResult), + ) { let num_predict = params.num_predict.unwrap_or(512) as usize; - let model = self.model.as_ref().unwrap(); + let model = &self.model; let prompt = ¶ms.prompt; let feed_prompt_only = params.feed_prompt_only.unwrap_or(false); @@ -332,13 +295,11 @@ impl LLamaInternal { if let Err(InferenceError::ContextFull) = session.feed_prompt::(model, &inference_params, prompt, |_| Ok(())) { - sender - .send(InferenceResult { - r#type: InferenceResultType::Error, - message: Some("Context window full.".to_string()), - data: None, - }) - .unwrap(); + callback(InferenceResult { + r#type: InferenceResultType::Error, + message: Some("Context window full.".to_string()), + data: None, + }); } if !feed_prompt_only { @@ -362,7 +323,7 @@ impl LLamaInternal { }), }; - sender.send(to_send).unwrap(); + callback(to_send); Ok(()) }, @@ -379,25 +340,23 @@ impl LLamaInternal { }), }; - sender.send(to_send).unwrap(); + callback(to_send); } Err(error) => { - sender - .send(InferenceResult { - r#type: InferenceResultType::Error, - message: match error { - llama_rs::InferenceError::EndOfText => Some("End of text.".to_string()), - llama_rs::InferenceError::ContextFull => { - Some("Context window full, stopping inference.".to_string()) - } - llama_rs::InferenceError::TokenizationFailed => { - Some("Tokenization failed.".to_string()) - } - llama_rs::InferenceError::UserCallback(_) => Some("Inference failed.".to_string()), - }, - data: None, - }) - .unwrap(); + callback(InferenceResult { + r#type: InferenceResultType::Error, + message: match error { + llama_rs::InferenceError::EndOfText => Some("End of text.".to_string()), + llama_rs::InferenceError::ContextFull => { + Some("Context window full, stopping inference.".to_string()) + } + llama_rs::InferenceError::TokenizationFailed => { + Some("Tokenization failed.".to_string()) + } + llama_rs::InferenceError::UserCallback(_) => Some("Inference failed.".to_string()), + }, + data: None, + }); } } } else { @@ -410,101 +369,17 @@ impl LLamaInternal { }), }; - sender.send(to_send).unwrap(); + callback(to_send); } if let Some(session_path) = params.save_session.as_ref() { self.write_session(session, session_path).unwrap(); } - sender - .send(InferenceResult { - r#type: InferenceResultType::End, - message: None, - data: None, - }) - .unwrap(); - } -} - -impl LLamaChannel { - pub fn new() -> Arc { - let (command_sender, command_receiver) = channel::(); - - let channel = LLamaChannel { - command_receiver: Arc::new(Mutex::new(command_receiver)), - command_sender, - }; - - channel.spawn(); - - Arc::new(channel) - } - - pub fn load_model(&self, params: LLamaConfig, sender: Sender) { - self - .command_sender - .send(LLamaCommand::LoadModel(params, sender)) - .unwrap(); - } - - pub fn inference(&self, params: LLamaInferenceArguments, sender: Sender) { - self - .command_sender - .send(LLamaCommand::Inference(params, sender)) - .unwrap(); - } - - pub fn get_word_embedding( - &self, - params: LLamaInferenceArguments, - sender: Sender, - ) { - self - .command_sender - .send(LLamaCommand::Embedding(params, sender)) - .unwrap() - } - - pub fn tokenize(&self, text: &str, sender: Sender) { - self - .command_sender - .send(LLamaCommand::Tokenize(text.to_string(), sender)) - .unwrap(); - } - - // llama instance main loop - pub fn spawn(&self) { - let rv = self.command_receiver.clone(); - - thread::spawn(move || { - let mut llama = LLamaInternal { model: None }; - - let rv = rv.lock().unwrap(); - - 'llama_loop: loop { - let command = rv.try_recv(); - match command { - Ok(LLamaCommand::Inference(params, sender)) => { - llama.inference(¶ms, &sender); - } - Ok(LLamaCommand::LoadModel(params, sender)) => { - llama.load_model(¶ms, &sender); - } - Ok(LLamaCommand::Embedding(params, sender)) => { - llama.get_word_embedding(¶ms, &sender); - } - Ok(LLamaCommand::Tokenize(text, sender)) => { - llama.tokenize(&text, &Some(sender)); - } - Err(TryRecvError::Disconnected) => { - break 'llama_loop; - } - _ => { - thread::yield_now(); - } - } - } + callback(InferenceResult { + r#type: InferenceResultType::End, + message: None, + data: None, }); } } diff --git a/packages/core/src/types.rs b/packages/core/src/types.rs index d8a6153..f85724a 100644 --- a/packages/core/src/types.rs +++ b/packages/core/src/types.rs @@ -1,5 +1,4 @@ use napi::bindgen_prelude::*; -use std::sync::mpsc::Sender; #[napi(object)] #[derive(Clone, Debug)] @@ -24,40 +23,6 @@ pub struct InferenceResult { pub data: Option, } -/** - * Embedding result - */ -#[napi(string_enum)] -#[derive(Debug)] -pub enum EmbeddingResultType { - Data, - Error, -} - -#[napi(object)] -#[derive(Clone, Debug)] -pub struct EmbeddingResult { - pub r#type: EmbeddingResultType, - pub message: Option, - pub data: Option>, -} - -/** - * Tokenize result - */ -#[napi(string_enum)] -#[derive(Debug)] -pub enum TokenizeResultType { - Data, -} - -#[napi(object)] -#[derive(Clone, Debug)] -pub struct TokenizeResult { - pub r#type: TokenizeResultType, - pub data: Vec, -} - /** * LLama model load config */ @@ -86,13 +51,6 @@ pub struct LLamaConfig { pub use_mmap: Option, } -#[napi(object)] -#[derive(Clone, Debug)] -pub struct LoadModelResult { - pub error: bool, - pub message: Option, -} - #[napi(object)] #[derive(Clone, Debug)] pub struct LLamaInferenceArguments { @@ -178,11 +136,3 @@ pub struct LLamaInferenceArguments { /// Default is None pub save_session: Option, } - -#[derive(Clone, Debug)] -pub enum LLamaCommand { - LoadModel(LLamaConfig, Sender), - Inference(LLamaInferenceArguments, Sender), - Embedding(LLamaInferenceArguments, Sender), - Tokenize(String, Sender), -} diff --git a/packages/llama-cpp/example/embedding.ts b/packages/llama-cpp/example/embedding.ts index 8ba9042..2ace888 100644 --- a/packages/llama-cpp/example/embedding.ts +++ b/packages/llama-cpp/example/embedding.ts @@ -1,34 +1,38 @@ import { LLama, LlamaContextParams, LlamaInvocation } from "../index"; import path from "path"; -const llama = LLama.load( - path.resolve(process.cwd(), "../../ggml-vicuna-7b-1.1-q4_1.bin"), - { - nCtx: 512, - nParts: -1, - seed: 0, - f16Kv: false, - logitsAll: false, - vocabOnly: false, - useMlock: false, - embedding: true, - useMmap: true, - }, - false -); +const run = async () => { + const llama = await LLama.load( + path.resolve(process.cwd(), "../../ggml-vicuna-7b-1.1-q4_1.bin"), + { + nCtx: 512, + nParts: -1, + seed: 0, + f16Kv: false, + logitsAll: false, + vocabOnly: false, + useMlock: false, + embedding: true, + useMmap: true, + }, + false + ); -const prompt = `Who is the president of the United States?`; + const prompt = `Who is the president of the United States?`; -const params: LlamaInvocation = { - nThreads: 4, - nTokPredict: 2048, - topK: 40, - topP: 0.1, - temp: 0.2, - repeatPenalty: 1, - prompt, + const params: LlamaInvocation = { + nThreads: 4, + nTokPredict: 2048, + topK: 40, + topP: 0.1, + temp: 0.2, + repeatPenalty: 1, + prompt, + }; + + llama.getWordEmbedding(params).then((data) => { + console.log(data); + }); }; -llama.getWordEmbedding(params, (data) => { - console.log(data.data); -}); +run(); diff --git a/packages/llama-cpp/example/load.ts b/packages/llama-cpp/example/load.ts index d68f0ca..931a48f 100644 --- a/packages/llama-cpp/example/load.ts +++ b/packages/llama-cpp/example/load.ts @@ -1,28 +1,32 @@ import { LLama, LlamaInvocation } from "../index"; import path from "path"; -const llama = LLama.load( - path.resolve(process.cwd(), "../../ggml-vicuna-7b-1.1-q4_1.bin"), - null, - true -); +const run = async () => { + const llama = await LLama.load( + path.resolve(process.cwd(), "../../ggml-vicuna-7b-1.1-q4_1.bin"), + null, + true + ); -const template = `Who is the president of the United States?`; + const template = `Who is the president of the United States?`; -const prompt = `A chat between a user and an assistant. + const prompt = `A chat between a user and an assistant. USER: ${template} ASSISTANT:`; -const params: LlamaInvocation = { - nThreads: 4, - nTokPredict: 2048, - topK: 40, - topP: 0.1, - temp: 0.2, - repeatPenalty: 1, - prompt, + const params: LlamaInvocation = { + nThreads: 4, + nTokPredict: 2048, + topK: 40, + topP: 0.1, + temp: 0.2, + repeatPenalty: 1, + prompt, + }; + + llama.inference(params, (data) => { + process.stdout.write(data.data?.token ?? ""); + }); }; -llama.inference(params, (data) => { - process.stdout.write(data.data?.token ?? ""); -}); +run(); diff --git a/packages/llama-cpp/example/tokenize.ts b/packages/llama-cpp/example/tokenize.ts index 60603c0..74fb80b 100644 --- a/packages/llama-cpp/example/tokenize.ts +++ b/packages/llama-cpp/example/tokenize.ts @@ -1,14 +1,18 @@ import { LLama } from "../index"; import path from "path"; -const llama = LLama.load( - path.resolve(process.cwd(), "../../ggml-vicuna-7b-1.1-q4_1.bin"), - null, - false -); +const run = async () => { + const llama = await LLama.load( + path.resolve(process.cwd(), "../../ggml-vicuna-7b-1.1-q4_1.bin"), + null, + false + ); -const template = `Who is the president of the United States?`; + const template = `Who is the president of the United States?`; -llama.tokenize(template, 2048, (data) => { - console.log(data.data); -}); + llama.tokenize(template, 2048).then((data) => { + console.log(data); + }); +}; + +run(); diff --git a/packages/llama-cpp/index.d.ts b/packages/llama-cpp/index.d.ts index 01fe9ef..1eaaec8 100644 --- a/packages/llama-cpp/index.d.ts +++ b/packages/llama-cpp/index.d.ts @@ -3,6 +3,20 @@ /* auto-generated by NAPI-RS */ +export interface InferenceToken { + token: string + completed: boolean +} +export const enum InferenceResultType { + Error = 'Error', + Data = 'Data', + End = 'End' +} +export interface InferenceResult { + type: InferenceResultType + data?: InferenceToken + message?: string +} export interface LlamaInvocation { nThreads: number nTokPredict: number @@ -30,39 +44,9 @@ export interface LlamaContextParams { embedding: boolean useMmap: boolean } -export const enum TokenizeResultType { - Error = 'Error', - Data = 'Data' -} -export interface TokenizeResult { - type: TokenizeResultType - data: Array -} -export interface InferenceToken { - token: string - completed: boolean -} -export const enum InferenceResultType { - Error = 'Error', - Data = 'Data', - End = 'End' -} -export interface InferenceResult { - type: InferenceResultType - data?: InferenceToken - message?: string -} -export const enum EmbeddingResultType { - Error = 'Error', - Data = 'Data' -} -export interface EmbeddingResult { - type: EmbeddingResultType - data: Array -} export class LLama { - static load(path: string, params: LlamaContextParams | undefined | null, enableLogger: boolean): LLama - getWordEmbedding(input: LlamaInvocation, callback: (result: EmbeddingResult) => void): void - tokenize(params: string, nCtx: number, callback: (result: TokenizeResult) => void): void - inference(input: LlamaInvocation, callback: (result: InferenceResult) => void): void + static load(path: string, params: LlamaContextParams | undefined | null, enableLogger: boolean): Promise + getWordEmbedding(params: LlamaInvocation): Promise> + tokenize(params: string, nCtx: number): Promise> + inference(params: LlamaInvocation, callback: (result: InferenceResult) => void): void } diff --git a/packages/llama-cpp/index.js b/packages/llama-cpp/index.js index 46bb1a6..c6b1612 100644 --- a/packages/llama-cpp/index.js +++ b/packages/llama-cpp/index.js @@ -252,9 +252,7 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { TokenizeResultType, InferenceResultType, EmbeddingResultType, LLama } = nativeBinding +const { InferenceResultType, LLama } = nativeBinding -module.exports.TokenizeResultType = TokenizeResultType module.exports.InferenceResultType = InferenceResultType -module.exports.EmbeddingResultType = EmbeddingResultType module.exports.LLama = LLama diff --git a/packages/llama-cpp/src/context.rs b/packages/llama-cpp/src/context.rs index 1577f03..1b06939 100644 --- a/packages/llama-cpp/src/context.rs +++ b/packages/llama-cpp/src/context.rs @@ -11,39 +11,7 @@ use llama_sys::{ llama_token_to_str, }; -#[napi(object)] -#[derive(Debug, Clone)] -pub struct LlamaInvocation { - pub n_threads: i32, - pub n_tok_predict: i32, - pub top_k: i32, // 40 - pub top_p: Option, // default 0.95f, 1.0 = disabled - pub tfs_z: Option, // default 1.00f, 1.0 = disabled - pub temp: Option, // default 0.80f, 1.0 = disabled - pub typical_p: Option, // default 1.00f, 1.0 = disabled - pub repeat_penalty: Option, // default 1.10f, 1.0 = disabled - pub repeat_last_n: Option, // default 64, last n tokens to penalize (0 = disable penalty, -1 = context size) - pub frequency_penalty: Option, // default 0.00f, 1.0 = disabled - pub presence_penalty: Option, // default 0.00f, 1.0 = disabled - pub stop_sequence: Option, - pub penalize_nl: Option, - pub prompt: String, -} - -// Represents the configuration parameters for a LLamaContext. -#[napi(object)] -#[derive(Debug, Clone)] -pub struct LlamaContextParams { - pub n_ctx: i32, - pub n_parts: i32, - pub seed: i32, - pub f16_kv: bool, - pub logits_all: bool, - pub vocab_only: bool, - pub use_mlock: bool, - pub embedding: bool, - pub use_mmap: bool, -} +use crate::types::{LlamaContextParams, LlamaInvocation}; impl LlamaContextParams { // Returns the default parameters or the user-specified parameters. @@ -74,13 +42,14 @@ impl From for llama_context_params { } // Represents the LLamaContext which wraps FFI calls to the llama.cpp library. +#[derive(Clone)] pub struct LLamaContext { ctx: *mut llama_context, } impl LLamaContext { // Creates a new LLamaContext from the specified file and configuration parameters. - pub fn from_file_and_params(path: &str, params: &Option) -> Self { + pub async fn from_file_and_params(path: &str, params: &Option) -> Self { let params = LlamaContextParams::or_default(params); let ctx = unsafe { llama_init_from_file(path.as_ptr() as *const i8, params) }; Self { ctx } diff --git a/packages/llama-cpp/src/lib.rs b/packages/llama-cpp/src/lib.rs index 4d216ca..9217a4b 100644 --- a/packages/llama-cpp/src/lib.rs +++ b/packages/llama-cpp/src/lib.rs @@ -8,13 +8,9 @@ mod llama; mod tokenizer; mod types; -use std::{ - sync::{mpsc::channel, Arc}, - thread, time, -}; +use std::sync::Arc; -use context::{LlamaContextParams, LlamaInvocation}; -use llama::LLamaChannel; +use llama::LLamaInternal; use napi::{ bindgen_prelude::*, threadsafe_function::{ @@ -22,17 +18,18 @@ use napi::{ }, JsFunction, }; -use types::{EmbeddingResult, InferenceResult, TokenizeResult}; +use tokio::sync::Mutex; +use types::{InferenceResult, LlamaContextParams, LlamaInvocation}; #[napi] pub struct LLama { - llama_channel: Arc, + llama: Arc>, } #[napi] impl LLama { #[napi] - pub fn load( + pub async fn load( path: String, params: Option, enable_logger: bool, @@ -44,122 +41,43 @@ impl LLama { .init(); } - let (load_result_sender, load_result_receiver) = channel::(); - let llama_channel = LLamaChannel::new(path, params, load_result_sender, enable_logger); - 'waiting_load: loop { - let recv = load_result_receiver.recv(); - match recv { - Ok(r) => { - if !r { - return Err(Error::new(Status::InvalidArg, "Load error".to_string())); - } - break 'waiting_load; - } - _ => { - thread::yield_now(); - } - } - } - Ok(Self { llama_channel }) + Ok(Self { + llama: LLamaInternal::load(path, params, enable_logger).await, + }) } #[napi] - pub fn get_word_embedding( - &self, - input: LlamaInvocation, - #[napi(ts_arg_type = "(result: EmbeddingResult) => void")] callback: JsFunction, - ) -> Result<()> { - let tsfn: ThreadsafeFunction = - callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; - let (embeddings_sender, embeddings_receiver) = channel(); - let llama_channel = self.llama_channel.clone(); - - llama_channel.embedding(input, embeddings_sender); - - thread::spawn(move || { - loop { - let result = embeddings_receiver.recv(); - match result { - Ok(result) => { - tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); - } - Err(_) => { - break; - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); - }); - - Ok(()) + pub async fn get_word_embedding(&self, params: LlamaInvocation) -> Result> { + let llama = self.llama.lock().await; + llama.embedding(¶ms).await } #[napi] - pub fn tokenize( - &self, - params: String, - n_ctx: i32, - #[napi(ts_arg_type = "(result: TokenizeResult) => void")] callback: JsFunction, - ) -> Result<()> { - let (tokenize_sender, tokenize_receiver) = channel::(); - - let tsfn: ThreadsafeFunction = callback - .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { - Ok(vec![ctx.value]) - })?; - - let llama_channel = self.llama_channel.clone(); - - llama_channel.tokenize(params, n_ctx as usize, tokenize_sender); - - thread::spawn(move || { - 'waiting_tokenize: loop { - let recv = tokenize_receiver.recv(); - match recv { - Ok(callback) => { - tsfn.call(callback, ThreadsafeFunctionCallMode::Blocking); - break 'waiting_tokenize; - } - _ => { - thread::yield_now(); - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); - }); - - Ok(()) + pub async fn tokenize(&self, params: String, n_ctx: i32) -> Result> { + let llama = self.llama.lock().await; + llama.tokenize(¶ms, n_ctx as usize).await } #[napi] pub fn inference( &self, - input: LlamaInvocation, + params: LlamaInvocation, #[napi(ts_arg_type = "(result: InferenceResult) => void")] callback: JsFunction, ) -> Result<()> { - let tsfn: ThreadsafeFunction = - callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; - let (inference_sender, inference_receiver) = channel(); - let llama_channel = self.llama_channel.clone(); + let tsfn: ThreadsafeFunction = callback + .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { + Ok(vec![ctx.value]) + })?; - llama_channel.inference(input, inference_sender); + let llama = self.llama.clone(); - thread::spawn(move || { - loop { - let result = inference_receiver.recv(); - match result { - Ok(result) => { - tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); - } - Err(_) => { - break; - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); + tokio::spawn(async move { + let llama = llama.lock().await; + llama + .inference(¶ms, |result| { + tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); + }) + .await; }); Ok(()) diff --git a/packages/llama-cpp/src/llama.rs b/packages/llama-cpp/src/llama.rs index 3c23b95..46f2cbb 100644 --- a/packages/llama-cpp/src/llama.rs +++ b/packages/llama-cpp/src/llama.rs @@ -1,51 +1,45 @@ -use std::{ - sync::{ - mpsc::{channel, Receiver, Sender, TryRecvError}, - Arc, Mutex, - }, - thread, -}; +use std::sync::Arc; + +use tokio::sync::Mutex; use crate::{ - context::{LLamaContext, LlamaContextParams, LlamaInvocation}, + context::{LLamaContext}, tokenizer::{llama_token_eos, tokenize}, - types::{ - EmbeddingResult, EmbeddingResultType, InferenceResult, InferenceResultType, InferenceToken, - LLamaCommand, TokenizeResult, TokenizeResultType, - }, + types::{InferenceResult, InferenceResultType, InferenceToken, LlamaContextParams, LlamaInvocation}, }; #[derive(Clone)] -pub struct LLamaChannel { - command_sender: Sender, - command_receiver: Arc>>, -} - pub struct LLamaInternal { context: LLamaContext, context_params: Option, } impl LLamaInternal { - pub fn tokenize(&self, input: &str, n_ctx: usize, sender: &Sender) { + pub async fn load( + path: String, + params: Option, + enable_logger: bool, + ) -> Arc> { + let llama = LLamaInternal { + context: LLamaContext::from_file_and_params(&path, ¶ms).await, + context_params: params, + }; + + if enable_logger { + llama.context.llama_print_system_info(); + } + + Arc::new(Mutex::new(llama)) + } + pub async fn tokenize(&self, input: &str, n_ctx: usize) -> Result, napi::Error> { if let Ok(data) = tokenize(&self.context, input, n_ctx, false) { - sender - .send(TokenizeResult { - data, - r#type: TokenizeResultType::Data, - }) - .unwrap(); + Ok(data) } else { - sender - .send(TokenizeResult { - data: vec![], - r#type: TokenizeResultType::Error, - }) - .unwrap(); + Err(napi::Error::from_reason("Failed to tokenize")) } } - pub fn embedding(&self, input: &LlamaInvocation, sender: &Sender) { + pub async fn embedding(&self, input: &LlamaInvocation) -> Result, napi::Error> { let context_params_c = LlamaContextParams::or_default(&self.context_params); let input_ctx = &self.context; let embd_inp = tokenize( @@ -67,23 +61,13 @@ impl LLamaInternal { let embeddings = input_ctx.llama_get_embeddings(); if let Ok(embeddings) = embeddings { - sender - .send(EmbeddingResult { - r#type: EmbeddingResultType::Data, - data: embeddings.iter().map(|&x| x as f64).collect(), - }) - .unwrap(); + Ok(embeddings.iter().map(|&x| x as f64).collect()) } else { - sender - .send(EmbeddingResult { - r#type: EmbeddingResultType::Error, - data: vec![], - }) - .unwrap(); + Err(napi::Error::from_reason("Failed to get embeddings")) } } - pub fn inference(&self, input: &LlamaInvocation, sender: &Sender) { + pub async fn inference(&self, input: &LlamaInvocation, callback: impl Fn(InferenceResult)) { let context_params_c = LlamaContextParams::or_default(&self.context_params); let input_ctx = &self.context; // Tokenize the stop sequence and input prompt. @@ -139,13 +123,11 @@ impl LLamaInternal { if input.n_tok_predict != 0 && n_used > (input.n_tok_predict as usize) + tokenized_input.len() - 1 { - sender - .send(InferenceResult { - r#type: InferenceResultType::Error, - data: None, - message: Some("Too many tokens predicted".to_string()), - }) - .unwrap(); + callback(InferenceResult { + r#type: InferenceResultType::Error, + data: None, + message: Some("Too many tokens predicted".to_string()), + }); break; } @@ -169,128 +151,33 @@ impl LLamaInternal { if let Some(output) = output { if stop_sequence_i == 0 { - sender - .send(InferenceResult { - r#type: InferenceResultType::Data, - data: Some(InferenceToken { - token: output, - completed: false, - }), - message: None, - }) - .unwrap(); + callback(InferenceResult { + r#type: InferenceResultType::Data, + data: Some(InferenceToken { + token: output, + completed: false, + }), + message: None, + }); } } } if completed { - sender - .send(InferenceResult { - r#type: InferenceResultType::Data, - data: Some(InferenceToken { - token: "\n\n\n".to_string(), - completed: true, - }), - message: None, - }) - .unwrap(); - } - - sender - .send(InferenceResult { - r#type: InferenceResultType::End, - data: None, + callback(InferenceResult { + r#type: InferenceResultType::Data, + data: Some(InferenceToken { + token: "\n\n\n".to_string(), + completed: true, + }), message: None, - }) - .unwrap(); - // embedding_to_output( - // input_ctx, - // &embd[tokenized_input.len()..n_used + 1 - stop_sequence_i], - // ); - } -} - -impl LLamaChannel { - pub fn new( - path: String, - params: Option, - load_result_sender: Sender, - enable_logger: bool, - ) -> Arc { - let (command_sender, command_receiver) = channel::(); - - let channel = LLamaChannel { - command_receiver: Arc::new(Mutex::new(command_receiver)), - command_sender, - }; - - channel.spawn(path, params, load_result_sender, enable_logger); - - Arc::new(channel) - } - - pub fn tokenize(&self, input: String, n_ctx: usize, sender: Sender) { - self.command_sender - .send(LLamaCommand::Tokenize(input, n_ctx, sender)) - .unwrap(); - } - - pub fn embedding(&self, params: LlamaInvocation, sender: Sender) { - self.command_sender - .send(LLamaCommand::Embedding(params, sender)) - .unwrap(); - } - - pub fn inference(&self, params: LlamaInvocation, sender: Sender) { - self.command_sender - .send(LLamaCommand::Inference(params, sender)) - .unwrap(); - } - - // llama instance main loop - pub fn spawn( - &self, - path: String, - params: Option, - load_result_sender: Sender, - enable_logger: bool, - ) { - let rv = self.command_receiver.clone(); - - thread::spawn(move || { - let llama = LLamaInternal { - context: LLamaContext::from_file_and_params(&path, ¶ms), - context_params: params, - }; - - if enable_logger { - llama.context.llama_print_system_info(); - } - - load_result_sender.send(true).unwrap(); - - let rv = rv.lock().unwrap(); + }); + } - 'llama_loop: loop { - let command = rv.try_recv(); - match command { - Ok(LLamaCommand::Inference(params, sender)) => { - llama.inference(¶ms, &sender); - } - Ok(LLamaCommand::Embedding(params, sender)) => { - llama.embedding(¶ms, &sender); - } - Ok(LLamaCommand::Tokenize(text, n_ctx, sender)) => { - llama.tokenize(&text, n_ctx, &sender); - } - Err(TryRecvError::Disconnected) => { - break 'llama_loop; - } - _ => { - thread::yield_now(); - } - } - } + callback(InferenceResult { + r#type: InferenceResultType::End, + data: None, + message: None, }); } } diff --git a/packages/llama-cpp/src/types.rs b/packages/llama-cpp/src/types.rs index c421920..d7daa60 100644 --- a/packages/llama-cpp/src/types.rs +++ b/packages/llama-cpp/src/types.rs @@ -1,25 +1,4 @@ -use crate::context::LlamaInvocation; use napi::bindgen_prelude::*; -use std::sync::mpsc::Sender; - -#[derive(Clone, Debug)] -pub enum LLamaCommand { - Inference(LlamaInvocation, Sender), - Tokenize(String, usize, Sender), - Embedding(LlamaInvocation, Sender), -} - -#[napi(string_enum)] -pub enum TokenizeResultType { - Error, - Data, -} - -#[napi(object)] -pub struct TokenizeResult { - pub r#type: TokenizeResultType, - pub data: Vec, -} #[napi(object)] #[derive(Clone, Debug)] @@ -42,14 +21,36 @@ pub struct InferenceResult { pub message: Option, } -#[napi(string_enum)] -pub enum EmbeddingResultType { - Error, - Data, -} - #[napi(object)] -pub struct EmbeddingResult { - pub r#type: EmbeddingResultType, - pub data: Vec, +#[derive(Debug, Clone)] +pub struct LlamaInvocation { + pub n_threads: i32, + pub n_tok_predict: i32, + pub top_k: i32, // 40 + pub top_p: Option, // default 0.95f, 1.0 = disabled + pub tfs_z: Option, // default 1.00f, 1.0 = disabled + pub temp: Option, // default 0.80f, 1.0 = disabled + pub typical_p: Option, // default 1.00f, 1.0 = disabled + pub repeat_penalty: Option, // default 1.10f, 1.0 = disabled + pub repeat_last_n: Option, // default 64, last n tokens to penalize (0 = disable penalty, -1 = context size) + pub frequency_penalty: Option, // default 0.00f, 1.0 = disabled + pub presence_penalty: Option, // default 0.00f, 1.0 = disabled + pub stop_sequence: Option, + pub penalize_nl: Option, + pub prompt: String, +} + +// Represents the configuration parameters for a LLamaContext. +#[napi(object)] +#[derive(Debug, Clone)] +pub struct LlamaContextParams { + pub n_ctx: i32, + pub n_parts: i32, + pub seed: i32, + pub f16_kv: bool, + pub logits_all: bool, + pub vocab_only: bool, + pub use_mlock: bool, + pub embedding: bool, + pub use_mmap: bool, } diff --git a/packages/rwkv-cpp/example/load.ts b/packages/rwkv-cpp/example/load.ts index a8df07a..8ee0bfb 100644 --- a/packages/rwkv-cpp/example/load.ts +++ b/packages/rwkv-cpp/example/load.ts @@ -1,31 +1,35 @@ import { Rwkv, RwkvInvocation } from "../index"; import path from "path"; -const rwkv = Rwkv.load( - path.resolve( - process.cwd(), - "../../ggml-rwkv-4_raven-7b-v9-Eng99%-20230412-ctx8192-Q4_1_0.bin" - ), - path.resolve(process.cwd(), "../../20B_tokenizer.json"), - 4, - true -); +const run = async () => { + const rwkv = await Rwkv.load( + path.resolve( + process.cwd(), + "../../ggml-rwkv-4_raven-7b-v9-Eng99%-20230412-ctx8192-Q4_1_0.bin" + ), + path.resolve(process.cwd(), "../../20B_tokenizer.json"), + 4, + true + ); -const template = `Who is the president of the United States?`; + const template = `Who is the president of the United States?`; -const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request. + const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: ${template} ### Response:`; -const params: RwkvInvocation = { - maxPredictLength: 2048, - topP: 0.1, - temp: 0.1, - prompt, + const params: RwkvInvocation = { + maxPredictLength: 2048, + topP: 0.1, + temp: 0.1, + prompt, + }; + + rwkv.inference(params, (data) => { + process.stdout.write(data.data?.token ?? ""); + }); }; -rwkv.inference(params, (data) => { - process.stdout.write(data.data?.token ?? ""); -}); +run(); diff --git a/packages/rwkv-cpp/example/tokenize.ts b/packages/rwkv-cpp/example/tokenize.ts index c37fd3f..b06c2f2 100644 --- a/packages/rwkv-cpp/example/tokenize.ts +++ b/packages/rwkv-cpp/example/tokenize.ts @@ -1,18 +1,22 @@ import { Rwkv } from "../index"; import path from "path"; -const rwkv = Rwkv.load( - path.resolve( - process.cwd(), - "../../ggml-rwkv-4_raven-7b-v9-Eng99%-20230412-ctx8192-Q4_1_0.bin" - ), - path.resolve(process.cwd(), "../../20B_tokenizer.json"), - 4, - true -); +const run = async () => { + const rwkv = await Rwkv.load( + path.resolve( + process.cwd(), + "../../ggml-rwkv-4_raven-7b-v9-Eng99%-20230412-ctx8192-Q4_1_0.bin" + ), + path.resolve(process.cwd(), "../../20B_tokenizer.json"), + 4, + true + ); -const template = `Who is the president of the United States?`; + const template = `Who is the president of the United States?`; -rwkv.tokenize(template, (data) => { - console.log(data.data); -}); + const tokens = await rwkv.tokenize(template); + + console.log(tokens); +}; + +run(); diff --git a/packages/rwkv-cpp/index.d.ts b/packages/rwkv-cpp/index.d.ts index 61ef40e..6b0fbb5 100644 --- a/packages/rwkv-cpp/index.d.ts +++ b/packages/rwkv-cpp/index.d.ts @@ -11,14 +11,6 @@ export interface RwkvInvocation { seed?: number prompt: string } -export const enum TokenizeResultType { - Error = 'Error', - Data = 'Data' -} -export interface TokenizeResult { - type: TokenizeResultType - data: Array -} export interface InferenceToken { token: string completed: boolean @@ -33,18 +25,9 @@ export interface InferenceResult { data?: InferenceToken message?: string } -export const enum EmbeddingResultType { - Error = 'Error', - Data = 'Data' -} -export interface EmbeddingResult { - type: EmbeddingResultType - data: Array -} export type RWKV = Rwkv export class Rwkv { - static load(modelPath: string, tokenizerPath: string, nThreads: number, enableLogger: boolean): Rwkv - getWordEmbedding(input: RwkvInvocation, callback: (result: EmbeddingResult) => void): void - tokenize(params: string, callback: (result: TokenizeResult) => void): void - inference(input: RwkvInvocation, callback: (result: InferenceResult) => void): void + static load(modelPath: string, tokenizerPath: string, nThreads: number, enableLogger: boolean): Promise + tokenize(params: string): Promise> + inference(params: RwkvInvocation, callback: (result: InferenceResult) => void): void } diff --git a/packages/rwkv-cpp/index.js b/packages/rwkv-cpp/index.js index 5ad6531..158a693 100644 --- a/packages/rwkv-cpp/index.js +++ b/packages/rwkv-cpp/index.js @@ -252,9 +252,7 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { TokenizeResultType, InferenceResultType, EmbeddingResultType, Rwkv } = nativeBinding +const { InferenceResultType, Rwkv } = nativeBinding -module.exports.TokenizeResultType = TokenizeResultType module.exports.InferenceResultType = InferenceResultType -module.exports.EmbeddingResultType = EmbeddingResultType module.exports.Rwkv = Rwkv diff --git a/packages/rwkv-cpp/src/lib.rs b/packages/rwkv-cpp/src/lib.rs index 23a3496..5e6a260 100644 --- a/packages/rwkv-cpp/src/lib.rs +++ b/packages/rwkv-cpp/src/lib.rs @@ -8,10 +8,7 @@ mod rwkv; mod sampling; mod types; -use std::{ - sync::{mpsc::channel, Arc}, - thread, time, -}; +use std::sync::Arc; use context::RWKVInvocation; use napi::{ @@ -21,18 +18,19 @@ use napi::{ }, JsFunction, }; -use rwkv::RWKVChannel; -use types::{EmbeddingResult, InferenceResult, TokenizeResult}; +use rwkv::RWKVInternal; +use tokio::sync::Mutex; +use types::InferenceResult; #[napi] pub struct RWKV { - rwkv_channel: Arc, + rwkv: Arc>, } #[napi] impl RWKV { #[napi] - pub fn load( + pub async fn load( model_path: String, tokenizer_path: String, n_threads: u32, @@ -45,127 +43,35 @@ impl RWKV { .init(); } - let (load_result_sender, load_result_receiver) = channel::(); - let rwkv_channel = RWKVChannel::new( - model_path, - tokenizer_path, - n_threads, - load_result_sender, - enable_logger, - ); - 'waiting_load: loop { - let recv = load_result_receiver.recv(); - match recv { - Ok(r) => { - if !r { - return Err(Error::new(Status::InvalidArg, "Load error".to_string())); - } - break 'waiting_load; - } - _ => { - thread::yield_now(); - } - } - } - Ok(Self { rwkv_channel }) + Ok(Self { + rwkv: RWKVInternal::load(model_path, tokenizer_path, n_threads, enable_logger).await, + }) } #[napi] - pub fn get_word_embedding( - &self, - input: RWKVInvocation, - #[napi(ts_arg_type = "(result: EmbeddingResult) => void")] callback: JsFunction, - ) -> Result<()> { - let tsfn: ThreadsafeFunction = - callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; - let (embeddings_sender, embeddings_receiver) = channel(); - let rwkv_channel = self.rwkv_channel.clone(); - - rwkv_channel.embedding(input, embeddings_sender); - - thread::spawn(move || { - loop { - let result = embeddings_receiver.recv(); - match result { - Ok(result) => { - tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); - } - Err(_) => { - break; - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); - }); - - Ok(()) - } - - #[napi] - pub fn tokenize( - &self, - params: String, - #[napi(ts_arg_type = "(result: TokenizeResult) => void")] callback: JsFunction, - ) -> Result<()> { - let (tokenize_sender, tokenize_receiver) = channel::(); - - let tsfn: ThreadsafeFunction = callback - .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { - Ok(vec![ctx.value]) - })?; - - let rwkv_channel = self.rwkv_channel.clone(); - - rwkv_channel.tokenize(params, tokenize_sender); - - thread::spawn(move || { - 'waiting_tokenize: loop { - let recv = tokenize_receiver.recv(); - match recv { - Ok(callback) => { - tsfn.call(callback, ThreadsafeFunctionCallMode::Blocking); - break 'waiting_tokenize; - } - _ => { - thread::yield_now(); - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); - }); - - Ok(()) + pub async fn tokenize(&self, params: String) -> Result> { + let rwkv = self.rwkv.lock().await; + rwkv.tokenize(¶ms).await } #[napi] pub fn inference( &self, - input: RWKVInvocation, + params: RWKVInvocation, #[napi(ts_arg_type = "(result: InferenceResult) => void")] callback: JsFunction, ) -> Result<()> { - let tsfn: ThreadsafeFunction = - callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; - let (inference_sender, inference_receiver) = channel(); - let rwkv_channel = self.rwkv_channel.clone(); - - rwkv_channel.inference(input, inference_sender); + let tsfn: ThreadsafeFunction = callback + .create_threadsafe_function(0, |ctx: ThreadSafeCallContext| { + Ok(vec![ctx.value]) + })?; - thread::spawn(move || { - loop { - let result = inference_receiver.recv(); - match result { - Ok(result) => { - tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); - } - Err(_) => { - break; - } - } - } - thread::sleep(time::Duration::from_millis(300)); // wait for end signal - tsfn.abort().unwrap(); + let rwkv = self.rwkv.clone(); + tokio::spawn(async move { + let mut rwkv = rwkv.lock().await; + rwkv.inference(¶ms, |result| { + tsfn.call(result, ThreadsafeFunctionCallMode::NonBlocking); + }) + .await; }); Ok(()) diff --git a/packages/rwkv-cpp/src/rwkv.rs b/packages/rwkv-cpp/src/rwkv.rs index 7afd70c..21c188b 100644 --- a/packages/rwkv-cpp/src/rwkv.rs +++ b/packages/rwkv-cpp/src/rwkv.rs @@ -1,92 +1,47 @@ -use std::{ - sync::{ - mpsc::{channel, Receiver, Sender, TryRecvError}, - Arc, Mutex, - }, - thread, -}; +use std::sync::Arc; + +use tokio::sync::Mutex; use crate::{ context::{RWKVContext, RWKVInvocation}, sampling::sample_logits, - types::{ - EmbeddingResult, InferenceResult, InferenceResultType, InferenceToken, RWKVCommand, - TokenizeResult, TokenizeResultType, - }, + types::{InferenceResult, InferenceResultType, InferenceToken}, }; -#[derive(Clone)] -pub struct RWKVChannel { - command_sender: Sender, - command_receiver: Arc>>, -} - #[derive(Clone)] pub struct RWKVInternal { context: RWKVContext, } impl RWKVInternal { - pub fn tokenize(&self, input: &str, sender: &Sender) { + pub async fn load( + mode_path: String, + tokenizer_path: String, + n_threads: u32, + enable_logger: bool, + ) -> Arc> { + let rwkv = RWKVInternal { + context: RWKVContext::new(&mode_path, &tokenizer_path, n_threads), + }; + + if enable_logger { + rwkv.context.rwkv_print_system_info_string(); + } + + Arc::new(Mutex::new(rwkv)) + } + pub async fn tokenize(&self, input: &str) -> Result, napi::Error> { let tokenizer = &self.context.tokenizer; let tokens_result = tokenizer.encode(input, false).map(Some).unwrap_or(None); if let Some(result) = tokens_result { let tokens = result.get_ids().to_vec(); - sender - .send(TokenizeResult { - r#type: TokenizeResultType::Data, - data: tokens.iter().map(|x| *x as i32).collect(), - }) - .unwrap(); + Ok(tokens.iter().map(|x| *x as i32).collect()) } else { - sender - .send(TokenizeResult { - r#type: TokenizeResultType::Error, - data: vec![], - }) - .unwrap(); + Err(napi::Error::from_reason("Failed to tokenize")) } } - /* pub fn embedding(&self, input: &LlamaInvocation, sender: &Sender) { - let context_params_c = LlamaContextParams::or_default(&self.context_params); - let input_ctx = &self.context; - let embd_inp = tokenize( - input_ctx, - input.prompt.as_str(), - context_params_c.n_ctx as usize, - true, - ) - .unwrap(); - - // let end_text = "\n"; - // let end_token = - // tokenize(input_ctx, end_text, context_params_c.n_ctx as usize, false).unwrap(); - - input_ctx - .llama_eval(embd_inp.as_slice(), embd_inp.len() as i32, 0, input) - .unwrap(); - - let embeddings = input_ctx.llama_get_embeddings(); - - if let Ok(embeddings) = embeddings { - sender - .send(EmbeddingResult { - r#type: EmbeddingResultType::Data, - data: embeddings.iter().map(|&x| x as f64).collect(), - }) - .unwrap(); - } else { - sender - .send(EmbeddingResult { - r#type: EmbeddingResultType::Error, - data: vec![], - }) - .unwrap(); - } - } */ - - pub fn inference(&mut self, input: &RWKVInvocation, sender: &Sender) { + pub async fn inference(&mut self, input: &RWKVInvocation, callback: impl Fn(InferenceResult)) { let end_token = input.end_token.unwrap_or(0) as usize; let temp = input.temp as f32; let top_p = input.top_p as f32; @@ -96,7 +51,11 @@ impl RWKVInternal { let tokenizer = &context.tokenizer; let prompt = &input.prompt; let binding = tokenizer.encode(prompt.as_str(), false).unwrap(); - let tokens = binding.get_ids().iter().map(|x| *x as i32).collect::>(); + let tokens = binding + .get_ids() + .iter() + .map(|x| *x as i32) + .collect::>(); let mut session = context.create_new_session(); @@ -109,16 +68,14 @@ impl RWKVInternal { let token = sample_logits(logits, temp, top_p, &seed); if token >= 50276 || token == end_token { - sender - .send(InferenceResult { - r#type: InferenceResultType::Data, - message: None, - data: Some(InferenceToken { - token: "\n\n\n".to_string(), - completed: true, - }), - }) - .unwrap(); + callback(InferenceResult { + r#type: InferenceResultType::Data, + message: None, + data: Some(InferenceToken { + token: "\n\n\n".to_string(), + completed: true, + }), + }); return; } @@ -128,111 +85,17 @@ impl RWKVInternal { if !decoded.contains('\u{FFFD}') { accumulated_token.clear(); - sender - .send(InferenceResult { - r#type: InferenceResultType::Data, - message: None, - data: Some(InferenceToken { - token: decoded, - completed: false, - }), - }) - .unwrap(); + callback(InferenceResult { + r#type: InferenceResultType::Data, + message: None, + data: Some(InferenceToken { + token: decoded, + completed: false, + }), + }); } session.process_tokens(&[token.try_into().unwrap()]); } } } - -impl RWKVChannel { - pub fn new( - model_path: String, - tokenizer_path: String, - n_threads: u32, - load_result_sender: Sender, - enable_logger: bool, - ) -> Arc { - let (command_sender, command_receiver) = channel::(); - - let channel = RWKVChannel { - command_receiver: Arc::new(Mutex::new(command_receiver)), - command_sender, - }; - - channel.spawn( - model_path, - tokenizer_path, - n_threads, - load_result_sender, - enable_logger, - ); - - Arc::new(channel) - } - - pub fn tokenize(&self, input: String, sender: Sender) { - self.command_sender - .send(RWKVCommand::Tokenize(input, sender)) - .unwrap(); - } - - pub fn embedding(&self, params: RWKVInvocation, sender: Sender) { - self.command_sender - .send(RWKVCommand::Embedding(params, sender)) - .unwrap(); - } - - pub fn inference(&self, params: RWKVInvocation, sender: Sender) { - self.command_sender - .send(RWKVCommand::Inference(params, sender)) - .unwrap(); - } - - // rwkv instance main loop - pub fn spawn( - &self, - mode_path: String, - tokenizer_path: String, - n_threads: u32, - load_result_sender: Sender, - enable_logger: bool, - ) { - let rv = self.command_receiver.clone(); - - thread::spawn(move || { - let mut rwkv = RWKVInternal { - context: RWKVContext::new(&mode_path, &tokenizer_path, n_threads), - }; - - if enable_logger { - rwkv.context.rwkv_print_system_info_string(); - } - - load_result_sender.send(true).unwrap(); - - let rv = rv.lock().unwrap(); - - 'rwkv_loop: loop { - let command = rv.try_recv(); - match command { - Ok(RWKVCommand::Inference(params, sender)) => { - rwkv.inference(¶ms, &sender); - } - // Ok(RWKVCommand::Embedding(params, sender)) => { - // rwkv.embedding(¶ms, &sender); - // } - Ok(RWKVCommand::Tokenize(text, sender)) => { - rwkv.tokenize(&text, &sender); - } - Err(TryRecvError::Disconnected) => { - break 'rwkv_loop; - } - _ => { - thread::yield_now(); - } - } - } - }); - } -} diff --git a/packages/rwkv-cpp/src/types.rs b/packages/rwkv-cpp/src/types.rs index e5f38aa..cab002d 100644 --- a/packages/rwkv-cpp/src/types.rs +++ b/packages/rwkv-cpp/src/types.rs @@ -1,25 +1,4 @@ -use crate::context::RWKVInvocation; use napi::bindgen_prelude::*; -use std::sync::mpsc::Sender; - -#[derive(Clone, Debug)] -pub enum RWKVCommand { - Inference(RWKVInvocation, Sender), - Tokenize(String, Sender), - Embedding(RWKVInvocation, Sender), -} - -#[napi(string_enum)] -pub enum TokenizeResultType { - Error, - Data, -} - -#[napi(object)] -pub struct TokenizeResult { - pub r#type: TokenizeResultType, - pub data: Vec, -} #[napi(object)] #[derive(Clone, Debug)] @@ -41,15 +20,3 @@ pub struct InferenceResult { pub data: Option, pub message: Option, } - -#[napi(string_enum)] -pub enum EmbeddingResultType { - Error, - Data, -} - -#[napi(object)] -pub struct EmbeddingResult { - pub r#type: EmbeddingResultType, - pub data: Vec, -} diff --git a/src/llm/llama-cpp.ts b/src/llm/llama-cpp.ts index 4cb0ba6..5ece954 100644 --- a/src/llm/llama-cpp.ts +++ b/src/llm/llama-cpp.ts @@ -1,10 +1,8 @@ import { - EmbeddingResultType, InferenceResultType, LLama, LlamaContextParams, LlamaInvocation, - TokenizeResultType, } from "@llama-node/llama-cpp"; import { type ILLM, type LLMResult, LLMError } from "./type"; @@ -31,9 +29,9 @@ export class LLamaCpp { instance!: LLama; - load(config: LoadConfig) { + async load(config: LoadConfig) { const { path, enableLogging, ...rest } = config; - this.instance = LLama.load(path, rest, enableLogging); + this.instance = await LLama.load(path, rest, enableLogging); } async createCompletion( @@ -84,18 +82,7 @@ export class LLamaCpp } async getEmbedding(params: LlamaInvocation): Promise { - return new Promise((res, rej) => { - this.instance.getWordEmbedding(params, (response) => { - switch (response.type) { - case EmbeddingResultType.Data: - res(response.data ?? []); - break; - case EmbeddingResultType.Error: - rej(new Error("Unknown Error")); - break; - } - }); - }); + return await this.instance.getWordEmbedding(params); } async getDefaultEmbedding(text: string): Promise { @@ -111,14 +98,6 @@ export class LLamaCpp } async tokenize(params: TokenizeArguments): Promise { - return new Promise((res, rej) => { - this.instance.tokenize(params.content, params.nCtx, (response) => { - if (response.type === TokenizeResultType.Data) { - res(response.data); - } else { - rej(new Error("Unknown Error")); - } - }); - }); + return await this.instance.tokenize(params.content, params.nCtx); } } diff --git a/src/llm/llama-rs.ts b/src/llm/llama-rs.ts index f0eb11b..c4501c2 100644 --- a/src/llm/llama-rs.ts +++ b/src/llm/llama-rs.ts @@ -1,5 +1,4 @@ import { - EmbeddingResultType, InferenceResultType, LLama, LLamaConfig, @@ -20,8 +19,8 @@ export class LLamaRS { instance!: LLama; - load(config: LLamaConfig) { - this.instance = LLama.create(config); + async load(config: LLamaConfig) { + this.instance = await LLama.create(config); } async createCompletion( @@ -72,18 +71,7 @@ export class LLamaRS } async getEmbedding(params: LLamaInferenceArguments): Promise { - return new Promise((res, rej) => { - this.instance.getWordEmbeddings(params, (response) => { - switch (response.type) { - case EmbeddingResultType.Data: - res(response.data ?? []); - break; - case EmbeddingResultType.Error: - rej(response.message); - break; - } - }); - }); + return await this.instance.getWordEmbeddings(params); } async getDefaultEmbedding(text: string): Promise { @@ -99,10 +87,6 @@ export class LLamaRS } async tokenize(params: string): Promise { - return new Promise((res) => { - this.instance.tokenize(params, (response) => { - res(response.data); - }); - }); + return await this.instance.tokenize(params); } } diff --git a/src/llm/rwkv-cpp.ts b/src/llm/rwkv-cpp.ts index d13c9e1..626a195 100644 --- a/src/llm/rwkv-cpp.ts +++ b/src/llm/rwkv-cpp.ts @@ -3,7 +3,6 @@ import { InferenceResultType, Rwkv, RwkvInvocation, - TokenizeResultType, } from "@llama-node/rwkv-cpp"; import { type ILLM, type LLMResult, LLMError } from "./type"; @@ -25,9 +24,9 @@ export class RwkvCpp { instance!: Rwkv; - load(config: LoadConfig) { + async load(config: LoadConfig) { const { modelPath, tokenizerPath, nThreads, enableLogging } = config; - this.instance = Rwkv.load( + this.instance = await Rwkv.load( modelPath, tokenizerPath, nThreads, @@ -109,14 +108,6 @@ export class RwkvCpp } */ async tokenize(params: TokenizeArguments): Promise { - return new Promise((res, rej) => { - this.instance.tokenize(params.content, (response) => { - if (response.type === TokenizeResultType.Data) { - res(response.data); - } else { - rej(new Error("Unknown Error")); - } - }); - }); + return await this.instance.tokenize(params.content); } } diff --git a/src/llm/type.ts b/src/llm/type.ts index 3fae491..50802d6 100644 --- a/src/llm/type.ts +++ b/src/llm/type.ts @@ -11,7 +11,7 @@ export interface ILLM< > { readonly instance: Instance; - load(config: LoadConfig): void; + load(config: LoadConfig): Promise; createCompletion( params: LLMInferenceArguments,