Skip to content

Commit

Permalink
Add gemma.cpp bindings / llm-chain-gemma. (#281)
Browse files Browse the repository at this point in the history
* Add gemma.cpp bindings / llm-chain-gemma.

* style fixes

* fix macos build

* possibly fix windows build

* minor fixes

* style fix again

* potential fix for windows

* Uprev gemma.cpp version

* exclude windows support

* style fixes

* update mio (unrelated but make the CI happy)
  • Loading branch information
jmuk authored Mar 21, 2024
1 parent 989b970 commit 5e15391
Show file tree
Hide file tree
Showing 15 changed files with 1,042 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "crates/llm-chain-llama/sys/llama.cpp"]
path = crates/llm-chain-llama-sys/llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "crates/llm-chain-gemma-sys/gemma.cpp"]
path = crates/llm-chain-gemma-sys/gemma.cpp
url = https://github.com/google/gemma.cpp.git
29 changes: 22 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions crates/llm-chain-gemma-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[package]
name = "llm-chain-gemma-sys"
description = "A library with bindings for gemma.cpp"
version = "0.1.0"
edition = "2021"
license = "MIT"
keywords = ["llm", "langchain", "gemma", "chain"]
categories = ["science"]
authors = [
"Jun Mukai <[email protected]>",
]
repository = "https://github.com/sobelio/llm-chain/"
readme = "README.md"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]

[build-dependencies]
cc = "1.0.87"
130 changes: 130 additions & 0 deletions crates/llm-chain-gemma-sys/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#![allow(clippy::uninlined_format_args)]

extern crate cc;

use std::env;

/// Build script entry point for `llm-chain-gemma-sys`.
///
/// Steps, in order:
/// 1. On a Windows build, compile only `src/bindings_win.cc` and stop.
/// 2. Emit the Cargo link directives (C++ standard library, Accelerate on
///    Apple targets, and the static archives expected in `OUT_DIR`).
/// 3. Stop early when `DOCS_RS` is set.
/// 4. Configure and build the `libgemma` target with CMake inside
///    `gemma.cpp/build`, then copy the produced archives into `OUT_DIR`.
/// 5. Compile the hand-written `src/bindings.cc` shim with the `cc` crate.
///
/// # Panics
///
/// Panics if `TARGET` or `OUT_DIR` is unset, if changing directory fails, if
/// a CMake invocation cannot be spawned or exits unsuccessfully, or if an
/// archive cannot be copied into `OUT_DIR`.
fn main() {
    #[cfg(target_os = "windows")]
    {
        // gemma.cpp cannot be built with MSBuild at this point (it does
        // build with clang-cl), so Windows is unsupported for now and only
        // the Windows-specific bindings file is compiled.
        // See: https://github.com/google/gemma.cpp/pull/6
        cc::Build::new()
            .cpp(true)
            .file("src/bindings_win.cc")
            .std("c++17")
            .compile("bindings");
        return;
    }

    let target = env::var("TARGET").expect("TARGET must be set by Cargo");
    // Read OUT_DIR once; every later use joins onto this path.
    let out_dir =
        std::path::PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR must be set by Cargo"));

    // Link C++ standard library
    if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) {
        println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib);
        println!("cargo:rustc-link-arg=-l{}", cpp_stdlib);
    }
    // Link macOS Accelerate framework for matrix calculations
    if target.contains("apple") {
        println!("cargo:rustc-link-lib=framework=Accelerate");
    }
    println!("cargo:rustc-link-search={}", out_dir.display());
    for lib in ["gemma", "hwy", "hwy_contrib", "sentencepiece", "bindings"] {
        println!("cargo:rustc-link-lib=static={}", lib);
    }
    // NOTE(review): wrapper.h is not referenced anywhere else in this
    // script -- confirm the file exists in the crate.
    println!("cargo:rerun-if-changed=wrapper.h");
    // The shim compiled at the end of this script; rebuild when it changes.
    println!("cargo:rerun-if-changed=src/bindings.cc");

    // stop if we're on docs.rs
    if env::var("DOCS_RS").is_ok() {
        return;
    }

    // Run cmake from inside gemma.cpp/build so the relative paths below resolve.
    env::set_current_dir("gemma.cpp").expect("Unable to change directory to gemma.cpp");
    env::set_current_dir("build").expect("Unable to change directory to gemma.cpp build");

    // Set on this process so the cmake child processes inherit the flags.
    env::set_var("CXXFLAGS", "-fPIC");
    env::set_var("CFLAGS", "-fPIC");

    // Generate the build files.
    let configured = std::process::Command::new("cmake")
        .arg("..")
        .args([
            "-DCMAKE_BUILD_TYPE=Release",
            "-DBUILD_SHARED_LIBS=OFF",
            "-DWEIGHT_TYPE=hwy::bfloat16_t",
            "-DSPM_ENABLE_SHARED=OFF",
        ])
        .status()
        .expect("Failed to generate build script");
    if !configured.success() {
        panic!("Failed to generate build script");
    }

    // Build the static library.
    let built = std::process::Command::new("cmake")
        .args(["--build", ".", "--config", "Release", "--target", "libgemma"])
        .status()
        .expect("Failed to build lib");
    if !built.success() {
        panic!("Failed to build lib");
    }

    // Copy libgemma.a and its dependencies to where Cargo expects them (OUT_DIR).
    for archive in [
        "libgemma.a",
        "_deps/highway-build/libhwy.a",
        "_deps/highway-build/libhwy_contrib.a",
        "_deps/sentencepiece-build/src/libsentencepiece.a",
    ] {
        let file_name = std::path::Path::new(archive)
            .file_name()
            .expect("archive path ends in a file name");
        std::fs::copy(archive, out_dir.join(file_name))
            .unwrap_or_else(|err| panic!("Failed to copy {}: {}", archive, err));
    }

    // Finally, build bindings.cc to allow access for gemma.cpp.
    // So far, bindgen does not correctly generate a buildable Rust file,
    // so bindings.rs was written by hand for the hand-written src/bindings.cc.
    env::set_current_dir("..").expect("Unable to change directory back to gemma.cpp");
    env::set_current_dir("..").expect("Unable to change directory back to crate top");

    cc::Build::new()
        .cpp(true)
        .file("src/bindings.cc")
        .std("c++17")
        .include("./gemma.cpp")
        .include("./gemma.cpp/build/_deps/highway-src")
        .include("./gemma.cpp/build/_deps/sentencepiece-src")
        .compile("bindings");
}

// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462
/// Picks the C++ standard library name to link for `target`.
///
/// Matching is by substring of the target string: `msvc` yields `None`;
/// `apple`, `freebsd` and `openbsd` yield `"c++"`; `android` yields
/// `"c++_shared"`; anything else yields `"stdc++"`.
fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> {
    if target.contains("msvc") {
        return None;
    }

    let links_libcxx = ["apple", "freebsd", "openbsd"]
        .iter()
        .any(|&os| target.contains(os));

    let stdlib = if links_libcxx {
        "c++"
    } else if target.contains("android") {
        "c++_shared"
    } else {
        "stdc++"
    };
    Some(stdlib)
}
1 change: 1 addition & 0 deletions crates/llm-chain-gemma-sys/gemma.cpp
Submodule gemma.cpp added at 022195
124 changes: 124 additions & 0 deletions crates/llm-chain-gemma-sys/src/bindings.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include <gemma.h>

extern "C" {

// Heap-allocates a hwy::ThreadPool constructed with `num_threads` and returns
// the raw pointer. hwy_ThreadPool_destructor deletes such a pointer.
hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) {
  auto* pool = new hwy::ThreadPool(num_threads);
  return pool;
}

// Deletes `pool`. A null pointer is accepted and ignored.
void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) {
  if (pool == nullptr) {
    return;  // same outcome as `delete nullptr`, which is a no-op
  }
  delete pool;
}

// Heap-allocates a gcpp::Gemma and returns the raw pointer.
//
// Each path is passed as a (pointer, length) pair and copied into a
// gcpp::Path, so the buffers need not be NUL-terminated. `pool` is
// dereferenced and handed to the constructor by reference, so it must be a
// valid pointer.
gcpp::Gemma* gcpp_Gemma_Gemma(
    const char* tokenizer_path, size_t tokenizer_path_len,
    const char* compressed_weights_path, size_t compressed_weights_path_len,
    const char* weights_path, size_t weights_path_len,
    gcpp::Model model_type, hwy::ThreadPool* pool) {
  gcpp::Path tokenizer;
  gcpp::Path compressed_weights;
  gcpp::Path weights;

  tokenizer.path.assign(tokenizer_path, tokenizer_path_len);
  compressed_weights.path.assign(compressed_weights_path, compressed_weights_path_len);
  weights.path.assign(weights_path, weights_path_len);

  return new gcpp::Gemma(tokenizer, compressed_weights, weights, model_type, *pool);
}

// Deletes `gemma`. A null pointer is accepted and ignored.
void gcpp_Gemma_destructor(gcpp::Gemma* gemma) {
  if (gemma != nullptr) {
    delete gemma;
  }
}

// Stores `training` in the `model_training` field of `*gemma`.
// `gemma` is dereferenced unconditionally and must not be null.
void gcpp_Gemma_SetModelTraining(gcpp::Gemma* gemma, gcpp::ModelTraining training) {
  gcpp::Gemma& model = *gemma;
  model.model_training = training;
}

// Heap-allocates a gcpp::KVCache initialised from
// gcpp::CreateKVCache(model_type) and returns the raw pointer.
// gcpp_KVCache_destructor deletes such a pointer.
gcpp::KVCache* gcpp_CreateKVCache(gcpp::Model model_type) {
  return new gcpp::KVCache(gcpp::CreateKVCache(model_type));
}

// Deletes `kvcache`. A null pointer is accepted and ignored.
void gcpp_KVCache_destructor(gcpp::KVCache* kvcache) {
  if (kvcache == nullptr) {
    return;  // same outcome as `delete nullptr`
  }
  delete kvcache;
}

// Heap-allocates an empty std::vector<int> and returns the raw pointer.
// std_vector_int_destructor deletes such a pointer.
std::vector<int>* std_vector_int_vector() {
  auto* values = new std::vector<int>;
  return values;
}

// Deletes `v`. A null pointer is accepted and ignored.
void std_vector_int_destructor(std::vector<int>* v) {
  if (v != nullptr) {
    delete v;
  }
}

// Returns the number of elements in `*v`. `v` must not be null.
size_t std_vector_int_size(const std::vector<int>* v) {
  const std::vector<int>& values = *v;
  return values.size();
}

// Returns element `i` of `*v` using the bounds-checked accessor.
// An out-of-range `i` makes std::vector::at throw std::out_of_range, so
// callers should keep `i` below std_vector_int_size(v).
int std_vector_int_at(const std::vector<int>* v, size_t i) {
  const std::vector<int>& values = *v;
  return values.at(i);
}

// Heap-allocates an empty std::string and returns the raw pointer.
// std_string_destructor deletes such a pointer.
std::string* std_string_string() {
  auto* text = new std::string;
  return text;
}

// Deletes `s`. A null pointer is accepted and ignored.
void std_string_destructor(std::string* s) {
  if (s == nullptr) {
    return;  // same outcome as `delete nullptr`
  }
  delete s;
}

// Returns a pointer to the NUL-terminated character buffer owned by `*s`.
// The pointer refers to storage inside the string, so it is only usable
// while the string is alive and unmodified. `s` must not be null.
const char* std_string_c_str(const std::string* s) {
  const std::string& text = *s;
  return text.data();  // for a const std::string, data() and c_str() are the same pointer
}

// Encodes the `len` bytes starting at `input` with the tokenizer of `*gemma`,
// passing `out` through as the destination for the token ids.
// Returns whether the tokenizer reported an OK status.
bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) {
  const std::string text(input, len);
  const auto status = gemma->Tokenizer()->Encode(text, out);
  return status.ok();
}

// Decodes the single id `token` with the tokenizer of `*gemma`, passing `out`
// through as the destination string.
// Returns whether the tokenizer reported an OK status.
bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) {
  const std::vector<int> single_token(1, token);
  const auto status = gemma->Tokenizer()->Decode(single_token, out);
  return status.ok();
}

// Decodes `num_tokens` ids starting at `tokens` with the tokenizer of
// `*gemma`, passing `out` through as the destination string.
//
// Returns false without calling the tokenizer when `num_tokens` is negative,
// or when `tokens` is null while `num_tokens` is positive. Previously a
// negative count reached std::vector::reserve as a huge unsigned size and
// threw std::length_error out of this extern "C" function.
// Otherwise returns whether the tokenizer reported an OK status.
bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) {
  if (num_tokens < 0 || (tokens == nullptr && num_tokens > 0)) {
    return false;
  }
  // Range constructor copies the ids in one step; an empty range is valid.
  const std::vector<int> ids(tokens, tokens + num_tokens);
  return gemma->Tokenizer()->Decode(ids, out).ok();
}

// Heap-allocates a default-constructed std::mt19937 and returns the raw
// pointer. std_mt19937_destructor deletes such a pointer.
std::mt19937* std_mt19937_mt19937() {
  auto* engine = new std::mt19937;
  return engine;
}

// Deletes `gen`. A null pointer is accepted and ignored.
void std_mt19937_destructor(std::mt19937* gen) {
  if (gen != nullptr) {
    delete gen;
  }
}

// Re-seeds `*gen` with `seed`, converted to the engine's unsigned result type.
// `gen` must not be null.
void std_mt19937_seed(std::mt19937* gen, int seed) {
  const auto value = static_cast<std::mt19937::result_type>(seed);
  gen->seed(value);
}

// Re-seeds `*gen` with one value drawn from a fresh std::random_device.
// `gen` must not be null.
void std_mt19937_random_seed(std::mt19937* gen) {
  std::random_device entropy;
  const auto value = entropy();
  gen->seed(value);
}

// C-compatible callback types. The first argument is the opaque context
// pointer supplied by the caller.
// stream_callback also receives the two values (int, float) that
// gcpp::GenerateGemma passes to its streaming callback.
using stream_callback = bool (*)(void*, int, float);
// NOTE(review): accept_callback is not used by any function visible in this
// file section.
using accept_callback = bool (*)(void*, int);

// Forwards to gcpp::GenerateGemma, dereferencing every pointer argument, so
// all of them must be valid. Each streamed (token, value) pair is relayed to
// `stream_token` together with `stream_context`, and the callback's bool
// result is returned to gcpp::GenerateGemma.
void gcpp_GenerateGemma(
    gcpp::Gemma* gemma, const gcpp::RuntimeConfig* config,
    const std::vector<int>* prompt, size_t start_pos,
    gcpp::KVCache* kvcache, hwy::ThreadPool* pool,
    void* stream_context, stream_callback stream_token,
    std::mt19937* gen) {
  // Both captures are plain pointers, so they are copied into the closure.
  const auto relay = [stream_context, stream_token](int token, float value) {
    return stream_token(stream_context, token, value);
  };
  gcpp::GenerateGemma(*gemma, *config, *prompt, start_pos, *kvcache, *pool, relay, *gen);
}

}
Loading

0 comments on commit 5e15391

Please sign in to comment.