Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gemma.cpp bindings / llm-chain-gemma. #281

Merged
merged 11 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "crates/llm-chain-llama/sys/llama.cpp"]
path = crates/llm-chain-llama-sys/llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "crates/llm-chain-gemma-sys/gemma.cpp"]
path = crates/llm-chain-gemma-sys/gemma.cpp
url = https://github.com/google/gemma.cpp.git
29 changes: 22 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions crates/llm-chain-gemma-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[package]
name = "llm-chain-gemma-sys"
description = "A library with bindings for gemma.cpp"
version = "0.1.0"
edition = "2021"
license = "MIT"
keywords = ["llm", "langchain", "gemma", "chain"]
categories = ["science"]
authors = [
"Jun Mukai <[email protected]>",
]
repository = "https://github.com/sobelio/llm-chain/"
readme = "README.md"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]

[build-dependencies]
cc = "1.0.87"
130 changes: 130 additions & 0 deletions crates/llm-chain-gemma-sys/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#![allow(clippy::uninlined_format_args)]

extern crate cc;

use std::env;

/// Build script entry point for `llm-chain-gemma-sys`.
///
/// Steps, in order:
/// 1. Emit the Cargo link directives for the C++ standard library, the
///    gemma.cpp static libraries and the `bindings` shim.
/// 2. Stop early when the `DOCS_RS` environment variable is set.
/// 3. Configure and build the `libgemma` target with CMake inside
///    `gemma.cpp/build`.
/// 4. Copy the resulting static libraries into `OUT_DIR`.
/// 5. Compile the hand-written `src/bindings.cc` with the `cc` crate.
///
/// # Panics
///
/// Panics when `TARGET` or `OUT_DIR` is unset, when a directory change fails,
/// when CMake cannot be spawned or exits unsuccessfully, or when one of the
/// built libraries cannot be copied.
fn main() {
    // NOTE(review): `cfg(target_os)` inside a build script is evaluated for the
    // platform the script itself is compiled for -- confirm this is the intended
    // check when cross-compiling.
    #[cfg(target_os = "windows")]
    {
        // Gemma.cpp does not support MSBuild at this point --
        // it does support clang-cl though. At this time, Windows
        // is out of the support because of this, so only
        // `src/bindings_win.cc` is compiled there.
        // See: https://github.com/google/gemma.cpp/pull/6
        cc::Build::new()
            .cpp(true)
            .file("src/bindings_win.cc")
            .std("c++17")
            .compile("bindings");
        return;
    }
    let target = env::var("TARGET").unwrap();
    // Read OUT_DIR once; it is the link search path and the copy destination.
    let out_dir = std::path::PathBuf::from(env::var("OUT_DIR").unwrap());

    // Link C++ standard library
    if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) {
        println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib);
        println!("cargo:rustc-link-arg=-l{}", cpp_stdlib);
    }
    // Link macOS Accelerate framework for matrix calculations
    if target.contains("apple") {
        println!("cargo:rustc-link-lib=framework=Accelerate");
    }
    println!("cargo:rustc-link-search={}", out_dir.display());
    for lib in &["gemma", "hwy", "hwy_contrib", "sentencepiece", "bindings"] {
        println!("cargo:rustc-link-lib=static={}", lib);
    }
    // NOTE(review): no `wrapper.h` is visible next to this script -- confirm it
    // exists; the directive is kept for backward compatibility.
    println!("cargo:rerun-if-changed=wrapper.h");
    // The shim compiled at the end of this script must trigger a rebuild too.
    println!("cargo:rerun-if-changed=src/bindings.cc");

    // stop if we're on docs.rs
    if env::var("DOCS_RS").is_ok() {
        return;
    }

    // Run cmake to generate build files.
    env::set_current_dir("gemma.cpp").expect("Unable to change directory to gemma.cpp");
    env::set_current_dir("build").expect("Unable to change directory to gemma.cpp build");

    // The static libraries end up linked into Rust artifacts, so build them as
    // position-independent code.
    env::set_var("CXXFLAGS", "-fPIC");
    env::set_var("CFLAGS", "-fPIC");

    run_cmake(
        std::process::Command::new("cmake")
            .arg("..")
            .arg("-DCMAKE_BUILD_TYPE=Release")
            .arg("-DBUILD_SHARED_LIBS=OFF")
            .arg("-DWEIGHT_TYPE=hwy::bfloat16_t")
            .arg("-DSPM_ENABLE_SHARED=OFF"),
        "Failed to generate build script",
    );

    // Build binary.
    run_cmake(
        std::process::Command::new("cmake")
            .arg("--build")
            .arg(".")
            .arg("--config")
            .arg("Release")
            .arg("--target")
            .arg("libgemma"),
        "Failed to build lib",
    );

    // Move libgemma.a and its dependencies to where Cargo expects them (OUT_DIR).
    for built_lib in &[
        "libgemma.a",
        "_deps/highway-build/libhwy.a",
        "_deps/highway-build/libhwy_contrib.a",
        "_deps/sentencepiece-build/src/libsentencepiece.a",
    ] {
        copy_to_out_dir(built_lib, &out_dir);
    }

    // Finally, build bindings.cc to allow access for gemma.cpp.
    // So far, bindgen does not correctly generate buildable rust file,
    // so I manually wrote bindings.rs for hand-written src/bindings.cc file.
    env::set_current_dir("..").expect("Unable to change directory back to gemma.cpp");
    env::set_current_dir("..").expect("Unable to change directory back to crate top");

    cc::Build::new()
        .cpp(true)
        .file("src/bindings.cc")
        .std("c++17")
        .include("./gemma.cpp")
        .include("./gemma.cpp/build/_deps/highway-src")
        .include("./gemma.cpp/build/_deps/sentencepiece-src")
        .compile("bindings");
}

/// Runs a prepared cmake command and panics with `failure` when it cannot be
/// spawned or does not exit successfully.
fn run_cmake(command: &mut std::process::Command, failure: &str) {
    let status = command.status().expect(failure);
    if !status.success() {
        panic!("{}", failure);
    }
}

/// Copies the file at `src` (relative to the current directory) into
/// `out_dir`, keeping its file name; panics with the path and I/O error on failure.
fn copy_to_out_dir(src: &str, out_dir: &std::path::Path) {
    let file_name = std::path::Path::new(src)
        .file_name()
        .expect("library path must end in a file name");
    if let Err(err) = std::fs::copy(src, out_dir.join(file_name)) {
        panic!("Failed to copy {}: {}", src, err);
    }
}

// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462
/// Picks the C++ standard library to link for the given target triple.
///
/// Returns `None` for triples containing `msvc`; otherwise `c++` for
/// `apple`/`freebsd`/`openbsd` triples, `c++_shared` for `android` triples and
/// `stdc++` for everything else. The checks are substring matches applied in
/// that order.
fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> {
    let mentions_any = |needles: &[&str]| needles.iter().any(|needle| target.contains(needle));

    if mentions_any(&["msvc"]) {
        return None;
    }
    if mentions_any(&["apple", "freebsd", "openbsd"]) {
        return Some("c++");
    }
    if mentions_any(&["android"]) {
        return Some("c++_shared");
    }
    Some("stdc++")
}
1 change: 1 addition & 0 deletions crates/llm-chain-gemma-sys/gemma.cpp
Submodule gemma.cpp added at 022195
124 changes: 124 additions & 0 deletions crates/llm-chain-gemma-sys/src/bindings.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include <gemma.h>

extern "C" {

// Heap-allocates a hwy::ThreadPool constructed with `num_threads` and returns
// the raw pointer. The matching release function is hwy_ThreadPool_destructor.
hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) {
return new hwy::ThreadPool(num_threads);
}

// Deletes a thread pool previously returned by hwy_ThreadPool_ThreadPool.
// Passing nullptr is a no-op (plain `delete`).
void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) {
delete pool;
}

// Creates a heap-allocated gcpp::Gemma.
//
// Each path is passed as a (pointer, length) pair and copied into a
// gcpp::Path, so the buffers do not need to be NUL-terminated. `pool` is
// dereferenced and handed to the Gemma constructor by reference.
// Release the returned object with gcpp_Gemma_destructor.
gcpp::Gemma* gcpp_Gemma_Gemma(
    const char* tokenizer_path, size_t tokenizer_path_len,
    const char* compressed_weights_path, size_t compressed_weights_path_len,
    const char* weights_path, size_t weights_path_len,
    gcpp::Model model_type, hwy::ThreadPool* pool) {
  // Wraps an explicit-length byte range in a gcpp::Path.
  const auto to_path = [](const char* data, size_t len) {
    gcpp::Path wrapped;
    wrapped.path = std::string(data, len);
    return wrapped;
  };

  gcpp::Path tokenizer = to_path(tokenizer_path, tokenizer_path_len);
  gcpp::Path compressed_weights =
      to_path(compressed_weights_path, compressed_weights_path_len);
  gcpp::Path weights = to_path(weights_path, weights_path_len);

  return new gcpp::Gemma(tokenizer, compressed_weights, weights, model_type, *pool);
}

// Deletes a model previously returned by gcpp_Gemma_Gemma.
// Passing nullptr is a no-op (plain `delete`).
void gcpp_Gemma_destructor(gcpp::Gemma* gemma) {
delete gemma;
}

// Overwrites the `model_training` field of `gemma` with `training`.
// `gemma` must be non-null; it is dereferenced unconditionally.
void gcpp_Gemma_SetModelTraining(gcpp::Gemma* gemma, gcpp::ModelTraining training) {
gemma->model_training = training;
}

// Heap-allocates a gcpp::KVCache holding the result of
// gcpp::CreateKVCache(model_type) and returns the raw pointer.
// Release it with gcpp_KVCache_destructor.
gcpp::KVCache* gcpp_CreateKVCache(gcpp::Model model_type) {
  // Construct the heap object directly from the factory's return value.
  return new gcpp::KVCache(gcpp::CreateKVCache(model_type));
}

// Deletes a cache previously returned by gcpp_CreateKVCache.
// Passing nullptr is a no-op (plain `delete`).
void gcpp_KVCache_destructor(gcpp::KVCache* kvcache) {
delete kvcache;
}

// Heap-allocates an empty std::vector<int> and returns the raw pointer.
// Release it with std_vector_int_destructor.
std::vector<int>* std_vector_int_vector() {
return new std::vector<int>();
}

// Deletes a vector previously returned by std_vector_int_vector.
// Passing nullptr is a no-op (plain `delete`).
void std_vector_int_destructor(std::vector<int>* v) {
delete v;
}

// Returns the number of elements currently stored in `*v`.
// `v` must be non-null; it is dereferenced unconditionally.
size_t std_vector_int_size(const std::vector<int>* v) {
return v->size();
}

// Returns the element of `*v` at index `i`.
// Uses the bounds-checked std::vector::at, which throws std::out_of_range when
// `i >= v->size()`.
// NOTE(review): nothing here catches that exception, so it would leave this
// extern "C" function -- callers should keep `i` below std_vector_int_size(v).
int std_vector_int_at(const std::vector<int>* v, size_t i) {
return v->at(i);
}

// Heap-allocates an empty std::string and returns the raw pointer.
// Release it with std_string_destructor.
std::string* std_string_string() {
return new std::string();
}

// Deletes a string previously returned by std_string_string.
// Passing nullptr is a no-op (plain `delete`).
void std_string_destructor(std::string* s) {
delete s;
}

// Returns the NUL-terminated character buffer of `*s`.
// The pointer is owned by the string: it stays valid only until the string is
// modified or destroyed, so callers must copy the data if they need it longer.
const char* std_string_c_str(const std::string* s) {
return s->c_str();
}

// Encodes the `len`-byte buffer at `input` with the tokenizer of `gemma`,
// passing `out` as the destination for the token ids.
// The input is copied into a std::string using the explicit length, so it does
// not need to be NUL-terminated.
// Returns the `.ok()` value of the status reported by the tokenizer.
bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) {
return gemma->Tokenizer()->Encode(std::string(input, len), out).ok();
}

// Decodes the single token id `token` with the tokenizer of `gemma`, passing
// `out` as the destination for the text.
// The token is wrapped in a one-element vector because Decode is called with a
// vector of ids.
// Returns the `.ok()` value of the status reported by the tokenizer.
bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) {
return gemma->Tokenizer()->Decode(std::vector<int>{token}, out).ok();
}

// Decodes `num_tokens` token ids starting at `tokens` with the tokenizer of
// `gemma`, passing `out` as the destination for the text.
//
// Returns false without calling the tokenizer when `num_tokens` is negative or
// when `tokens` is null while `num_tokens` is positive; otherwise returns the
// `.ok()` value of the status reported by the tokenizer. A zero count decodes
// an empty id list.
bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) {
  // A negative count would convert to a huge size_t and make the vector
  // allocation throw std::length_error out of this C-ABI function; a null
  // buffer with a positive count would be read out of bounds.
  if (num_tokens < 0 || (tokens == nullptr && num_tokens > 0)) {
    return false;
  }
  // Copy the caller's buffer in one step via the iterator-range constructor.
  const std::vector<int> ids(tokens, tokens + num_tokens);
  return gemma->Tokenizer()->Decode(ids, out).ok();
}

// Heap-allocates a default-constructed std::mt19937 and returns the raw
// pointer. A default-constructed engine uses the library's fixed default seed,
// so call std_mt19937_seed or std_mt19937_random_seed to change the sequence.
// Release it with std_mt19937_destructor.
std::mt19937* std_mt19937_mt19937() {
return new std::mt19937();
}

// Deletes an engine previously returned by std_mt19937_mt19937.
// Passing nullptr is a no-op (plain `delete`).
void std_mt19937_destructor(std::mt19937* gen) {
delete gen;
}

// Reseeds `*gen` with `seed`.
// The signed `seed` is implicitly converted to the engine's unsigned seed type
// by the call.
void std_mt19937_seed(std::mt19937* gen, int seed) {
gen->seed(seed);
}

// Reseeds `*gen` with one value drawn from a fresh std::random_device, so the
// resulting sequence is not reproducible across calls.
void std_mt19937_random_seed(std::mt19937* gen) {
std::random_device rd;
gen->seed(rd());
}

// C-compatible callback signatures. Each receives the caller-supplied opaque
// context pointer first.
using stream_callback = bool (*)(void*, int, float);
using accept_callback = bool (*)(void*, int);

// Forwards to gcpp::GenerateGemma, dereferencing every pointer argument.
//
// For each (token, value) pair that gcpp::GenerateGemma reports through its
// stream function, `stream_token(stream_context, token, value)` is invoked and
// its boolean result is returned to the generator.
void gcpp_GenerateGemma(
    gcpp::Gemma* gemma, const gcpp::RuntimeConfig* config,
    const std::vector<int>* prompt, size_t start_pos,
    gcpp::KVCache* kvcache, hwy::ThreadPool* pool,
    void* stream_context, stream_callback stream_token,
    std::mt19937* gen) {
  // Adapts the C function pointer + context pair to the callable that
  // gcpp::GenerateGemma expects.
  const auto forward_token = [stream_context, stream_token](int token, float value) {
    return stream_token(stream_context, token, value);
  };
  gcpp::GenerateGemma(
      *gemma, *config, *prompt, start_pos, *kvcache, *pool, forward_token, *gen);
}

}
Loading
Loading