diff --git a/.gitmodules b/.gitmodules index 7029b7ed..38f0d41d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "crates/llm-chain-llama/sys/llama.cpp"] path = crates/llm-chain-llama-sys/llama.cpp url = https://github.com/ggerganov/llama.cpp.git +[submodule "crates/llm-chain-gemma-sys/gemma.cpp"] + path = crates/llm-chain-gemma-sys/gemma.cpp + url = https://github.com/google/gemma.cpp.git diff --git a/Cargo.lock b/Cargo.lock index 29d6965b..5f000c71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -698,12 +698,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] +checksum = "3286b845d0fccbdd15af433f61c5970e711987036cb468f437ff6badd70f4e24" [[package]] name = "cexpr" @@ -1800,6 +1797,24 @@ dependencies = [ "uuid", ] +[[package]] +name = "llm-chain-gemma" +version = "0.1.0" +dependencies = [ + "async-trait", + "llm-chain", + "llm-chain-gemma-sys", + "thiserror", + "tokio", +] + +[[package]] +name = "llm-chain-gemma-sys" +version = "0.1.0" +dependencies = [ + "cc", +] + [[package]] name = "llm-chain-hnsw" version = "0.13.0" @@ -2102,9 +2117,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "wasi 0.11.0+wasi-snapshot-preview1", diff --git a/crates/llm-chain-gemma-sys/Cargo.toml b/crates/llm-chain-gemma-sys/Cargo.toml new file mode 100644 index 00000000..58c22a20 --- /dev/null +++ b/crates/llm-chain-gemma-sys/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "llm-chain-gemma-sys" +description = "A library with bindings for gemma.cpp" +version = "0.1.0" +edition = "2021" +license = "MIT" +keywords = ["llm", "langchain", "gemma", "chain"] +categories = ["science"] +authors = [ + "Jun Mukai ", +] +repository = "https://github.com/sobelio/llm-chain/" +readme = "README.md" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[build-dependencies] +cc = "1.0.87" diff --git a/crates/llm-chain-gemma-sys/build.rs b/crates/llm-chain-gemma-sys/build.rs new file mode 100644 index 00000000..e1064a75 --- /dev/null +++ b/crates/llm-chain-gemma-sys/build.rs @@ -0,0 +1,130 @@ +#![allow(clippy::uninlined_format_args)] + +extern crate cc; + +use std::env; + +fn main() { + #[cfg(target_os = "windows")] + { + // Gemma.cpp does not support MSBuild at this point -- + // it does support clang-cl though. At this time, Windows + // is out of the support because of this. 
+ // See: https://github.com/google/gemma.cpp/pull/6 + cc::Build::new() + .cpp(true) + .file("src/bindings_win.cc") + .std("c++17") + .compile("bindings"); + return; + } + let target = env::var("TARGET").unwrap(); + // Link C++ standard library + if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) { + println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib); + println!("cargo:rustc-link-arg=-l{}", cpp_stdlib); + } + // Link macOS Accelerate framework for matrix calculations + if target.contains("apple") { + println!("cargo:rustc-link-lib=framework=Accelerate"); + } + println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap()); + println!("cargo:rustc-link-lib=static=gemma"); + println!("cargo:rustc-link-lib=static=hwy"); + println!("cargo:rustc-link-lib=static=hwy_contrib"); + println!("cargo:rustc-link-lib=static=sentencepiece"); + println!("cargo:rustc-link-lib=static=bindings"); + println!("cargo:rerun-if-changed=wrapper.h"); + + // stop if we're on docs.rs + if env::var("DOCS_RS").is_ok() { + return; + } + + // Run cmake to generate build files. + env::set_current_dir("gemma.cpp").expect("Unable to change directory to gemma.cpp"); + env::set_current_dir("build").expect("Unable to change directory to gemma.cpp build"); + + env::set_var("CXXFLAGS", "-fPIC"); + env::set_var("CFLAGS", "-fPIC"); + + let mut code = std::process::Command::new("cmake"); + let code = code + .arg("..") + .arg("-DCMAKE_BUILD_TYPE=Release") + .arg("-DBUILD_SHARED_LIBS=OFF") + .arg("-DWEIGHT_TYPE=hwy::bfloat16_t") + .arg("-DSPM_ENABLE_SHARED=OFF"); + let code = code.status().expect("Failed to generate build script"); + if code.code() != Some(0) { + panic!("Failed to generate build script"); + } + + // Build binary. + #[allow(clippy::suspicious_command_arg_space)] + let code = std::process::Command::new("cmake") + .arg("--build") + .arg(".") + .arg("--config") + .arg("Release") + .arg("--target") + .arg("libgemma") + .status() + .expect("Failed to build lib"); + if code.code() != Some(0) { + panic!("Failed to build lib"); + } + + // move libllama.a to where Cargo expects it (OUT_DIR) + std::fs::copy( + "libgemma.a", + format!("{}/libgemma.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy lib"); + + std::fs::copy( + "_deps/highway-build/libhwy.a", + format!("{}/libhwy.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy libhwy.a"); + + std::fs::copy( + "_deps/highway-build/libhwy_contrib.a", + format!("{}/libhwy_contrib.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy libhwy_contrib.a"); + + std::fs::copy( + "_deps/sentencepiece-build/src/libsentencepiece.a", + format!("{}/libsentencepiece.a", env::var("OUT_DIR").unwrap()), + ) + .expect("Failed to copy libsentencepiece.a"); + + // Finally, build bindings.cc to allow access for gemma.cpp. + // So far, bindgen does not correctly generate buildable rust file, + // so I manually wrote bindings.rs for hand-written src/bindings.cc file. 
+    env::set_current_dir("..").expect("Unable to change directory back to gemma.cpp");
+    env::set_current_dir("..").expect("Unable to change directory back to crate top");
+
+    cc::Build::new()
+        .cpp(true)
+        .file("src/bindings.cc")
+        .std("c++17")
+        .include("./gemma.cpp")
+        .include("./gemma.cpp/build/_deps/highway-src")
+        .include("./gemma.cpp/build/_deps/sentencepiece-src")
+        .compile("bindings");
+}
+
+// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462
+fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> {
+    if target.contains("msvc") {
+        None
+    } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") {
+        Some("c++")
+    } else if target.contains("android") {
+        Some("c++_shared")
+    } else {
+        Some("stdc++")
+    }
+}
diff --git a/crates/llm-chain-gemma-sys/gemma.cpp b/crates/llm-chain-gemma-sys/gemma.cpp
new file mode 160000
index 00000000..0221956b
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/gemma.cpp
@@ -0,0 +1 @@
+Subproject commit 0221956b2e4fb5ec65d3685fad09f257cf5700e7
diff --git a/crates/llm-chain-gemma-sys/src/bindings.cc b/crates/llm-chain-gemma-sys/src/bindings.cc
new file mode 100644
index 00000000..73b54d62
--- /dev/null
+++ b/crates/llm-chain-gemma-sys/src/bindings.cc
@@ -0,0 +1,124 @@
+#include <gemma.h>
+
+extern "C" {
+
+hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) {
+  return new hwy::ThreadPool(num_threads);
+}
+
+void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) {
+  delete pool;
+}
+
+gcpp::Gemma* gcpp_Gemma_Gemma(
+    const char* tokenizer_path, size_t tokenizer_path_len,
+    const char* compressed_weights_path, size_t compressed_weights_path_len,
+    const char* weights_path, size_t weights_path_len,
+    gcpp::Model model_type, hwy::ThreadPool* pool) {
+  gcpp::Path tpath;
+  tpath.path = std::string(tokenizer_path, tokenizer_path_len);
+  gcpp::Path cwpath;
+  cwpath.path = std::string(compressed_weights_path, compressed_weights_path_len);
+  gcpp::Path wpath;
+  wpath.path = std::string(weights_path, weights_path_len);
+  return new gcpp::Gemma(tpath, cwpath, wpath, model_type, *pool);
+}
+
+void gcpp_Gemma_destructor(gcpp::Gemma* gemma) {
+  delete gemma;
+}
+
+void gcpp_Gemma_SetModelTraining(gcpp::Gemma* gemma, gcpp::ModelTraining training) {
+  gemma->model_training = training;
+}
+
+gcpp::KVCache* gcpp_CreateKVCache(gcpp::Model model_type) {
+  gcpp::KVCache* cache = new gcpp::KVCache{};
+  *cache = gcpp::CreateKVCache(model_type);
+  return cache;
+}
+
+void gcpp_KVCache_destructor(gcpp::KVCache* kvcache) {
+  delete kvcache;
+}
+
+std::vector<int>* std_vector_int_vector() {
+  return new std::vector<int>();
+}
+
+void std_vector_int_destructor(std::vector<int>* v) {
+  delete v;
+}
+
+size_t std_vector_int_size(const std::vector<int>* v) {
+  return v->size();
+}
+
+int std_vector_int_at(const std::vector<int>* v, size_t i) {
+  return v->at(i);
+}
+
+std::string* std_string_string() {
+  return new std::string();
+}
+
+void std_string_destructor(std::string* s) {
+  delete s;
+}
+
+const char* std_string_c_str(const std::string* s) {
+  return s->c_str();
+}
+
+bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) {
+  return gemma->Tokenizer()->Encode(std::string(input, len), out).ok();
+}
+
+bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) {
+  return gemma->Tokenizer()->Decode(std::vector<int>{token}, out).ok();
+}
+
+bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) {
+  std::vector<int> v;
+ 
v.reserve(num_tokens); + for (int i = 0; i < num_tokens; i++) { + v.push_back(tokens[i]); + } + return gemma->Tokenizer()->Decode(v, out).ok(); +} + +std::mt19937* std_mt19937_mt19937() { + return new std::mt19937(); +} + +void std_mt19937_destructor(std::mt19937* gen) { + delete gen; +} + +void std_mt19937_seed(std::mt19937* gen, int seed) { + gen->seed(seed); +} + +void std_mt19937_random_seed(std::mt19937* gen) { + std::random_device rd; + gen->seed(rd()); +} + +typedef bool (*stream_callback)(void*, int, float); +typedef bool (*accept_callback)(void*, int); + +void gcpp_GenerateGemma( + gcpp::Gemma* gemma, const gcpp::RuntimeConfig* config, + const std::vector* prompt, size_t start_pos, + gcpp::KVCache* kvcache, hwy::ThreadPool* pool, + void* stream_context, stream_callback stream_token, + std::mt19937* gen) { + gcpp::GenerateGemma( + *gemma, *config, *prompt, start_pos, *kvcache, *pool, + [&stream_context, &stream_token](int token, float value) { + return stream_token(stream_context, token, value); + }, + *gen); +} + +} \ No newline at end of file diff --git a/crates/llm-chain-gemma-sys/src/bindings.rs b/crates/llm-chain-gemma-sys/src/bindings.rs new file mode 100644 index 00000000..ab5c2cd1 --- /dev/null +++ b/crates/llm-chain-gemma-sys/src/bindings.rs @@ -0,0 +1,162 @@ +use std::ffi; + +pub type gcpp_Model = ffi::c_int; +pub const gcpp_Model_GEMMA_2B: gcpp_Model = 0; +pub const gcpp_Model_GEMMA_7B: gcpp_Model = 1; + +pub type gcpp_ModelTraining = ffi::c_int; +pub const gcpp_ModelTraining_GEMMA_IT: gcpp_ModelTraining = 0; +pub const gcpp_ModelTraining_GEMMA_PT: gcpp_ModelTraining = 1; + +pub const EOS_ID: i32 = 1; + +#[repr(C)] +pub struct gcpp_RuntimeConfig { + pub max_tokens: ffi::c_uint, + pub max_generated_tokens: ffi::c_uint, + pub temperature: ffi::c_float, + pub verbosity: ffi::c_int, +} + +#[repr(C)] +pub struct hwy_ThreadPool { + _data: [u8; 0], + _marker: + core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +extern "C" { + pub fn hwy_ThreadPool_ThreadPool(num_threads: ffi::c_uint) -> *mut hwy_ThreadPool; + pub fn hwy_ThreadPool_destructor(pool: *mut hwy_ThreadPool); +} + +#[repr(C)] +pub struct gcpp_Gemma { + _data: [u8; 0], + _marker: + core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +extern "C" { + pub fn gcpp_Gemma_Gemma( + tokenizer_path: *const ffi::c_char, tokenizer_path_len: ffi::c_uint, + compressed_weights_path: *const ffi::c_char, compressed_weights_path_len: ffi::c_uint, + weights_path: *const ffi::c_char, weights_path_len: ffi::c_uint, + model_type: gcpp_Model, pool: *mut hwy_ThreadPool) -> *mut gcpp_Gemma; + pub fn gcpp_Gemma_destructor(gemma: *mut gcpp_Gemma); + pub fn gcpp_Gemma_SetModelTraining(gemma: *mut gcpp_Gemma, training: gcpp_ModelTraining); + pub fn gcpp_Gemma_Encode(gemma: *mut gcpp_Gemma, input: *mut ffi::c_char, len: ffi::c_uint, out: *mut std_vector_int) -> ffi::c_char; + pub fn gcpp_Gemma_Decode(gemma: *mut gcpp_Gemma, token: ffi::c_int, out: *mut std_string) -> ffi::c_char; + pub fn gcpp_Gemma_Decodes(gemma: *mut gcpp_Gemma, tokens: *const ffi::c_int, num_tokens: ffi::c_int, out: *mut std_string) -> ffi::c_char; + + pub fn gcpp_GenerateGemma( + gemma: *mut gcpp_Gemma, config: *const gcpp_RuntimeConfig, + prompt: *const std_vector_int, start_pos: ffi::c_uint, + kvcache: *mut gcpp_KVCache, pool: *mut hwy_ThreadPool, + stream_context: *mut ffi::c_void, + stream_token: extern fn(*mut ffi::c_void, ffi::c_int, ffi::c_float) -> ffi::c_char, + gen: *mut std_mt19937, + ); +} + +#[repr(C)] +pub struct 
gcpp_KVCache { + _data: [u8; 0], + _marker: + core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +extern "C" { + pub fn gcpp_CreateKVCache(model_type: gcpp_Model) -> *mut gcpp_KVCache; + pub fn gcpp_KVCache_destructor(cache: *mut gcpp_KVCache); +} + +#[repr(C)] +pub struct std_vector_int { + _data: [u8; 0], + _marker: + core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +extern "C" { + pub fn std_vector_int_vector() -> *mut std_vector_int; + pub fn std_vector_int_destructor(v: *mut std_vector_int); + pub fn std_vector_int_size(v: *const std_vector_int) -> ffi::c_uint; + pub fn std_vector_int_at(v: *const std_vector_int, i: ffi::c_uint) -> ffi::c_int; +} + +pub struct std_vector_int_iter { + v: *mut std_vector_int, + i: ffi::c_uint, +} + +impl std_vector_int_iter { + pub fn new(v: *mut std_vector_int) -> std_vector_int_iter { + std_vector_int_iter{ + v: v, + i: 0, + } + } +} + +impl ExactSizeIterator for std_vector_int_iter { + fn len(&self) -> usize { + unsafe { std_vector_int_size(self.v) as usize } + } +} + +impl Iterator for std_vector_int_iter { + type Item = i32; + + fn next(&mut self) -> Option { + unsafe { + if self.i < std_vector_int_size(self.v) { + let v = std_vector_int_at(self.v, self.i); + self.i += 1; + Some(v as i32) + } else { + None + } + } + } +} + +#[repr(C)] +pub struct std_string { + _data: [u8; 0], + _marker: + core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +extern "C" { + pub fn std_string_string() -> *mut std_string; + pub fn std_string_destructor(s: *mut std_string); + pub fn std_string_c_str(s: *const std_string) -> *mut ffi::c_char; +} + +#[repr(C)] +pub struct std_mt19937 { + _data: [u8; 0], + _marker: + core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, +} + +extern "C" { + pub fn std_mt19937_mt19937() -> *mut std_mt19937; + pub fn std_mt19937_destructor(gen: *mut std_mt19937); + pub fn std_mt19937_seed(gen: *mut std_mt19937, seed: ffi::c_int); + pub fn std_mt19937_random_seed(gen: *mut std_mt19937); +} + +#[cfg(test)] +mod test { + use crate::*; + + #[test] + fn create_and_delete_pool() { + unsafe { + let pool = hwy_ThreadPool_ThreadPool(1); + hwy_ThreadPool_destructor(pool); + } + } +} \ No newline at end of file diff --git a/crates/llm-chain-gemma-sys/src/bindings_win.cc b/crates/llm-chain-gemma-sys/src/bindings_win.cc new file mode 100644 index 00000000..a70927cd --- /dev/null +++ b/crates/llm-chain-gemma-sys/src/bindings_win.cc @@ -0,0 +1,95 @@ +#include + +extern "C" { + +void* hwy_ThreadPool_ThreadPool(size_t num_threads) { + return nullptr; +} + +void hwy_ThreadPool_destructor(void* pool) { +} + +void* gcpp_Gemma_Gemma( + const char* tokenizer_path, size_t tokenizer_path_len, + const char* compressed_weights_path, size_t compressed_weights_path_len, + const char* weights_path, size_t weights_path_len, + int model_type, void* pool) { + return nullptr; +} + +void gcpp_Gemma_destructor(void* gemma) { +} + +void gcpp_Gemma_SetModelTraining(void* gemma, int training) { +} + +void* gcpp_CreateKVCache(int model_type) { + return nullptr; +} + +void gcpp_KVCache_destructor(void* kvcache) { +} + +void* std_vector_int_vector() { + return nullptr; +} + +void std_vector_int_destructor(void* v) { +} + +size_t std_vector_int_size(const void* v) { + return 0; +} + +int std_vector_int_at(void* v, size_t i) { + return 0; +} + +void* std_string_string() { + return nullptr; +} + +void std_string_destructor(void* s) { +} + +const char* std_string_c_str(const void* s) { + return nullptr; +} + 
+bool gcpp_Gemma_Encode(void* gemma, const char* input, size_t len, void* out) { + return false; +} + +bool gcpp_Gemma_Decode(void* gemma, int token, void* out) { + return false; +} + +bool gcpp_Gemma_Decodes(void* gemma, const int* tokens, int num_tokens, void* out) { + return false; +} + +void* std_mt19937_mt19937() { + return nullptr; +} + +void std_mt19937_destructor(void* gen) { +} + +void std_mt19937_seed(void* gen, int seed) { +} + +void std_mt19937_random_seed(void* gen) { +} + +typedef bool (*stream_callback)(void*, int, float); +typedef bool (*accept_callback)(void*, int); + +void gcpp_GenerateGemma( + void* gemma, const void* config, + const void* prompt, size_t start_pos, + void* kvcache, void* pool, + void* stream_context, stream_callback stream_token, + void* gen) { +} + +} \ No newline at end of file diff --git a/crates/llm-chain-gemma-sys/src/lib.rs b/crates/llm-chain-gemma-sys/src/lib.rs new file mode 100644 index 00000000..05d92f29 --- /dev/null +++ b/crates/llm-chain-gemma-sys/src/lib.rs @@ -0,0 +1,5 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!("./bindings.rs"); diff --git a/crates/llm-chain-gemma/Cargo.toml b/crates/llm-chain-gemma/Cargo.toml new file mode 100644 index 00000000..d65f8313 --- /dev/null +++ b/crates/llm-chain-gemma/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "llm-chain-gemma" +description = "The llm-chain implementation for Gemma." +version = "0.1.0" +edition = "2021" +license = "MIT" +keywords = ["llm", "langchain", "gemma", "chain"] +categories = ["science"] +authors = [ + "Jun Mukai ", +] +readme = "./README.md" +repository = "https://github.com/sobelio/llm-chain/" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait.workspace = true +llm-chain = {path = "../llm-chain", version="0.13.0"} +llm-chain-gemma-sys = {path = "../llm-chain-gemma-sys", version="0.1.0"} +tokio.workspace = true +thiserror.workspace = true diff --git a/crates/llm-chain-gemma/examples/simple.rs b/crates/llm-chain-gemma/examples/simple.rs new file mode 100644 index 00000000..777c744e --- /dev/null +++ b/crates/llm-chain-gemma/examples/simple.rs @@ -0,0 +1,54 @@ +use llm_chain::options; +use llm_chain::options::ModelRef; +use llm_chain::{executor, parameters, prompt}; +use std::env::args; +use std::path::Path; + +/// This example demonstrates how to use the llm-chain-gemma crate to generate text using a +/// Gemma. +/// +/// Usage: cargo run --example simple path/to/model prompt +/// +/// Note: gemma requires 2 files to load, one for the model itself and the other is for +/// sentencepiece. 
Currently it assumes both reside in the same directory, and the
+/// sentencepiece file name is tokenizer.spm.
+
+fn get_model_type(model_path: &str) -> &str {
+    let p = Path::new(model_path);
+    if let Some(stem) = p.file_stem() {
+        if let Some(model_type) = stem.to_str() {
+            return model_type;
+        }
+    }
+    "2b-it"
+}
+
+#[tokio::main(flavor = "current_thread")]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let raw_args: Vec<String> = args().collect();
+    let args = match &raw_args.len() {
+        2 => (
+            raw_args[1].as_str(),
+            "Rust is a cool programming language because",
+        ),
+        3 => (raw_args[1].as_str(), raw_args[2].as_str()),
+        _ => {
+            panic!("Usage: cargo run --release --example simple <path/to/model> <prompt>")
+        }
+    };
+
+    let model_path = args.0;
+    let prompt = args.1;
+    let opts = options!(
+        Model: ModelRef::from_path(model_path),
+        ModelType: get_model_type(model_path),
+        ModelType: "gemma",
+        Temperature: 0.8
+    );
+    let exec = executor!(gemma, opts.clone())?;
+
+    let res = prompt!(prompt).run(&parameters!(), &exec).await?;
+
+    println!("{}", res.to_immediate().await?);
+    Ok(())
+}
diff --git a/crates/llm-chain-gemma/src/context.rs b/crates/llm-chain-gemma/src/context.rs
new file mode 100644
index 00000000..e7c78168
--- /dev/null
+++ b/crates/llm-chain-gemma/src/context.rs
@@ -0,0 +1,285 @@
+use llm_chain::options::{Opt, OptDiscriminants, Options};
+use llm_chain::output::StreamSegment;
+use llm_chain::tokens::{TokenCollection, Tokenizer, TokenizerError};
+use llm_chain::traits::ExecutorCreationError;
+use llm_chain_gemma_sys::{
+    gcpp_CreateKVCache, gcpp_Gemma, gcpp_Gemma_Decode, gcpp_Gemma_Decodes, gcpp_Gemma_Encode,
+    gcpp_Gemma_Gemma, gcpp_Gemma_SetModelTraining, gcpp_Gemma_destructor, gcpp_GenerateGemma,
+    gcpp_KVCache, gcpp_KVCache_destructor, gcpp_Model, gcpp_ModelTraining,
+    gcpp_ModelTraining_GEMMA_IT, gcpp_ModelTraining_GEMMA_PT, gcpp_Model_GEMMA_2B,
+    gcpp_Model_GEMMA_7B, gcpp_RuntimeConfig, hwy_ThreadPool, hwy_ThreadPool_ThreadPool,
+    hwy_ThreadPool_destructor, std_mt19937, std_mt19937_destructor, std_mt19937_mt19937,
+    std_mt19937_random_seed, std_string_c_str, std_string_destructor, std_string_string,
+    std_vector_int_destructor, std_vector_int_iter, std_vector_int_size, std_vector_int_vector,
+    EOS_ID,
+};
+use std::ffi;
+use std::path::Path;
+use tokio::sync::mpsc;
+
+#[derive(thiserror::Error, Debug)]
+#[error("Gemma.cpp is not supported")]
+pub struct GemmaNotSupportedError {}
+
+pub struct GemmaContext {
+    gemma: *mut gcpp_Gemma,
+    model_training: gcpp_ModelTraining,
+    gen: *mut std_mt19937,
+    pub config: gcpp_RuntimeConfig,
+    kvcache: *mut gcpp_KVCache,
+    pool: *mut hwy_ThreadPool,
+    pos: u32,
+}
+
+impl GemmaContext {
+    pub fn new(options: &Options) -> Result<GemmaContext, ExecutorCreationError> {
+        let mut model_type: gcpp_Model = gcpp_Model_GEMMA_2B;
+        let mut model_training: gcpp_ModelTraining = gcpp_ModelTraining_GEMMA_IT;
+        let mut tokenizer_path = String::new();
+        let mut compressed_weights_path = String::new();
+        let mut config = gcpp_RuntimeConfig {
+            max_tokens: 3072,
+            max_generated_tokens: 2048,
+            temperature: 1.0,
+            verbosity: 0,
+        };
+        if let Some(Opt::ModelType(mt)) = options.get(OptDiscriminants::ModelType) {
+            let parts = Vec::from_iter(mt.split("-").into_iter());
+            if parts.len() != 2 {
+                return Err(ExecutorCreationError::InvalidValue(format!(
+                    "model type {} is invalid",
+                    mt
+                )));
+            }
+            match parts[0] {
+                "2b" => {}
+                "7b" => {
+                    model_type = gcpp_Model_GEMMA_7B;
+                }
+                _ => {
+                    return Err(ExecutorCreationError::InvalidValue(format!(
+                        "model type {} must be 2b or 7b",
+                        parts[0]
+                    )));
+                }
+            }
+            match parts[1] {
+                "it"
=> {} + "pt" => { + model_training = gcpp_ModelTraining_GEMMA_PT; + } + _ => { + return Err(ExecutorCreationError::InvalidValue(format!( + "model training {} must be it or pt", + parts[1] + ))); + } + } + } + if let Some(Opt::Model(m)) = options.get(OptDiscriminants::Model) { + compressed_weights_path = m.to_path(); + let parent = Path::new(&compressed_weights_path).parent(); + if parent.is_none() { + return Err(ExecutorCreationError::InvalidValue(String::from( + "no parent for path", + ))); + } + if let Some(tpath) = parent.unwrap().join("tokenizer.spm").to_str() { + tokenizer_path = String::from(tpath); + } + } + if let Some(Opt::Temperature(t)) = options.get(OptDiscriminants::Temperature) { + config.temperature = *t; + } + if let Some(Opt::MaxTokens(m)) = options.get(OptDiscriminants::MaxTokens) { + config.max_tokens = *m as ffi::c_uint; + } + unsafe { + let pool = hwy_ThreadPool_ThreadPool( + if let Some(Opt::NThreads(nt)) = options.get(OptDiscriminants::NThreads) { + *nt as ffi::c_uint + } else { + 0 + }, + ); + if pool == std::ptr::null_mut() { + return Err(ExecutorCreationError::InnerError(Box::new( + GemmaNotSupportedError {}, + ))); + } + + let gemma = gcpp_Gemma_Gemma( + tokenizer_path.as_ptr() as *const i8, + tokenizer_path.len() as ffi::c_uint, + compressed_weights_path.as_ptr() as *const i8, + compressed_weights_path.len() as ffi::c_uint, + std::ptr::null(), + 0, + model_type, + pool, + ); + if gemma == std::ptr::null_mut() { + return Err(ExecutorCreationError::InnerError(Box::new( + GemmaNotSupportedError {}, + ))); + } + gcpp_Gemma_SetModelTraining(gemma, model_training); + + let gen = std_mt19937_mt19937(); + if gen == std::ptr::null_mut() { + return Err(ExecutorCreationError::InnerError(Box::new( + GemmaNotSupportedError {}, + ))); + } + std_mt19937_random_seed(gen); + + Ok(GemmaContext { + gemma: gemma, + gen: gen, + model_training: model_training as gcpp_ModelTraining, + config: config, + kvcache: gcpp_CreateKVCache(model_type), + pool: pool, + pos: 0, + }) + } + } +} + +impl Drop for GemmaContext { + fn drop(&mut self) { + unsafe { + gcpp_Gemma_destructor(self.gemma); + std_mt19937_destructor(self.gen); + gcpp_KVCache_destructor(self.kvcache); + hwy_ThreadPool_destructor(self.pool); + } + } +} + +#[repr(C)] +struct GenerateContext { + gemma: *mut gcpp_Gemma, + pos: u32, + tokens_processed: u32, + input_tokens: u32, + out: mpsc::UnboundedSender, +} + +extern "C" fn stream_token( + ctx: *mut ffi::c_void, + token: ffi::c_int, + _: ffi::c_float, +) -> ffi::c_char { + unsafe { + let gctx = ctx as *mut GenerateContext; + (*gctx).pos += 1; + (*gctx).tokens_processed += 1; + if (*gctx).tokens_processed < (*gctx).input_tokens { + return true as ffi::c_char; + } + if token == EOS_ID { + return true as ffi::c_char; + } + let s = std_string_string(); + if gcpp_Gemma_Decode((*gctx).gemma, token, s) == 0 { + return false as ffi::c_char; + } + let decoded = ffi::CString::from_raw(std_string_c_str(s)).into_string(); + if decoded.is_err() { + return false as ffi::c_char; + } + (*gctx) + .out + .send(StreamSegment::Content(decoded.unwrap())) + .is_ok() as ffi::c_char + } +} + +impl GemmaContext { + pub fn generate<'a>(&mut self, prompt: String, out: mpsc::UnboundedSender) { + unsafe { + if self.model_training != gcpp_ModelTraining_GEMMA_IT { + self.pos = 0 + } + let mut prompt_text = if self.model_training == gcpp_ModelTraining_GEMMA_IT { + format!("{prompt}model\n") + } else { + prompt + }; + if self.pos > 0 { + prompt_text = format!("{prompt_text}"); + } + let tokens = 
std_vector_int_vector(); + gcpp_Gemma_Encode( + self.gemma, + prompt_text.as_mut_ptr() as *mut ffi::c_char, + prompt_text.len() as ffi::c_uint, + tokens, + ); + let mut genctx = GenerateContext { + gemma: self.gemma, + pos: self.pos, + tokens_processed: 0, + input_tokens: std_vector_int_size(tokens) as u32, + out: out, + }; + gcpp_GenerateGemma( + self.gemma, + &mut self.config, + tokens, + self.pos, + self.kvcache, + self.pool, + (&mut genctx as *mut GenerateContext) as *mut ffi::c_void, + stream_token, + self.gen, + ); + self.pos = genctx.pos; + std_vector_int_destructor(tokens); + } + } + + pub fn max_generated_tokens(&self) -> u32 { + self.config.max_generated_tokens as u32 + } +} + +impl Tokenizer for GemmaContext { + fn tokenize_str(&self, doc: &str) -> Result { + unsafe { + let mut doc_copied = String::from(doc); + let tokens = std_vector_int_vector(); + let result = gcpp_Gemma_Encode( + self.gemma, + doc_copied.as_mut_ptr() as *mut ffi::c_char, + doc.len() as ffi::c_uint, + tokens, + ); + if result == 0 { + return Err(TokenizerError::ToStringError); + } + Ok(TokenCollection::from(Vec::from_iter( + std_vector_int_iter::new(tokens), + ))) + } + } + + fn to_string(&self, tokens: TokenCollection) -> Result { + let ts = tokens.as_i32()?; + unsafe { + let out = std_string_string(); + let ok = gcpp_Gemma_Decodes(self.gemma, ts.as_ptr(), ts.len() as ffi::c_int, out); + if ok == 0 { + std_string_destructor(out); + return Err(TokenizerError::ToStringError); + } + let out_str = ffi::CString::from_raw(std_string_c_str(out)).into_string(); + std_string_destructor(out); + out_str.map_err(|_| TokenizerError::ToStringError) + } + } +} + +unsafe impl Sync for GemmaContext {} +unsafe impl Send for GemmaContext {} diff --git a/crates/llm-chain-gemma/src/executor.rs b/crates/llm-chain-gemma/src/executor.rs new file mode 100644 index 00000000..b4383d61 --- /dev/null +++ b/crates/llm-chain-gemma/src/executor.rs @@ -0,0 +1,111 @@ +use crate::context::GemmaContext; +use async_trait::async_trait; +use llm_chain::options::{Opt, OptDiscriminants, Options}; +use llm_chain::output::Output; +use llm_chain::prompt::Prompt; +use llm_chain::tokens::{ + PromptTokensError, TokenCollection, TokenCount, Tokenizer, TokenizerError, +}; +use llm_chain::traits::{Executor as ExecutorTrait, ExecutorCreationError, ExecutorError}; +use std::sync::{Arc, Mutex}; +use tokio; + +pub struct Executor { + context: Arc>, + stream: bool, +} + +#[async_trait] +impl ExecutorTrait for Executor { + type StepTokenizer<'a> = GemmaTokenizer; + + fn new_with_options(options: Options) -> Result { + let gemma_context = GemmaContext::new(&options)?; + Ok(Executor { + context: Arc::new(Mutex::new(gemma_context)), + stream: if let Some(Opt::Stream(s)) = options.get(OptDiscriminants::Stream) { + *s + } else { + false + }, + }) + } + + async fn execute(&self, options: &Options, prompt: &Prompt) -> Result { + let is_stream = if let Some(Opt::Stream(s)) = options.get(OptDiscriminants::Stream) { + *s + } else { + self.stream + }; + let (sender, stream) = Output::new_stream(); + let context = self.context.clone(); + let prompt_text = prompt.to_string(); + if is_stream { + tokio::task::spawn_blocking(move || { + if let Ok(mut ctx) = context.lock() { + ctx.generate(prompt_text, sender); + } + }); + return Ok(stream); + } else { + let mut ctx = context.lock().map_err(|_| ExecutorError::InvalidOptions)?; + ctx.generate(prompt_text, sender); + } + stream + .to_immediate() + .await + .map(|imm| Output::Immediate(imm)) + } + + fn tokens_used( + &self, + 
options: &Options, + prompt: &Prompt, + ) -> Result { + let tokenizer = self.get_tokenizer(options)?; + let tokens = tokenizer.tokenize_str(prompt.to_string().as_str())?; + Ok(TokenCount::new( + self.max_tokens_allowed(options), + tokens.len() as i32, + )) + } + + fn max_tokens_allowed(&self, options: &Options) -> i32 { + if let Some(Opt::MaxTokens(mt)) = options.get(OptDiscriminants::MaxTokens) { + return *mt as i32; + } + self.context.lock().unwrap().max_generated_tokens() as i32 + } + + fn get_tokenizer(&self, _options: &Options) -> Result, TokenizerError> { + Ok(GemmaTokenizer { + context: self.context.clone(), + }) + } + + fn answer_prefix(&self, _prompt: &Prompt) -> Option { + None + } +} + +pub struct GemmaTokenizer { + context: Arc>, +} + +impl Tokenizer for GemmaTokenizer { + fn tokenize_str(&self, doc: &str) -> Result { + let ctx = self + .context + .lock() + .map_err(|_| TokenizerError::TokenizationError)?; + ctx.tokenize_str(doc) + } + + fn to_string(&self, tokens: TokenCollection) -> Result { + let ctx = self + .context + .lock() + .map_err(|_| TokenizerError::ToStringError)?; + ctx.to_string(tokens) + } +} diff --git a/crates/llm-chain-gemma/src/lib.rs b/crates/llm-chain-gemma/src/lib.rs new file mode 100644 index 00000000..8f1bb6e8 --- /dev/null +++ b/crates/llm-chain-gemma/src/lib.rs @@ -0,0 +1,4 @@ +mod context; +mod executor; + +pub use executor::Executor; diff --git a/crates/llm-chain/src/executor.rs b/crates/llm-chain/src/executor.rs index dbf044cb..6de5fe52 100644 --- a/crates/llm-chain/src/executor.rs +++ b/crates/llm-chain/src/executor.rs @@ -56,6 +56,10 @@ macro_rules! executor { use llm_chain::traits::Executor; llm_chain_openai::chatgpt::Executor::new_with_options($options) }}; + (gemma, $options:expr) => {{ + use llm_chain::traits::Executor; + llm_chain_gemma::Executor::new_with_options($options) + }}; (llama) => {{ use llm_chain::traits::Executor; llm_chain_llama::Executor::new()
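Usage sketch (not part of the patch): the new `(gemma, $options:expr)` arm of the `executor!` macro is exercised by the added `crates/llm-chain-gemma/examples/simple.rs`. The condensed example below shows how a downstream caller might construct the Gemma executor through that arm. The weights path `models/2b-it.sbs` and the prompt are placeholder values, and the model-type string follows the `2b-it`/`7b-pt` format that `GemmaContext::new` parses.

use llm_chain::options::ModelRef;
use llm_chain::{executor, options, parameters, prompt};

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder weights path; GemmaContext::new expects tokenizer.spm in the same directory.
    let opts = options!(
        Model: ModelRef::from_path("models/2b-it.sbs"),
        ModelType: "2b-it",
        Temperature: 0.8
    );
    // Uses the macro arm added to crates/llm-chain/src/executor.rs in this change.
    let exec = executor!(gemma, opts)?;
    let res = prompt!("Rust is a cool programming language because")
        .run(&parameters!(), &exec)
        .await?;
    println!("{}", res.to_immediate().await?);
    Ok(())
}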