From 4ca066cdf6e1b87c5e4ce50a7dfae58f28bf4e25 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 12 Nov 2023 21:53:14 +0100 Subject: [PATCH] Merge branch 'develop' in --- Cargo.lock | 102 +-- Cargo.toml | 13 +- README.md | 1 + binaries/generate-ggml-bindings/src/main.rs | 5 + binaries/llm-cli/src/cli_args.rs | 15 +- binaries/llm-cli/src/interactive.rs | 5 +- binaries/llm-cli/src/main.rs | 4 +- binaries/llm-test/configs/mpt.json | 2 +- binaries/llm-test/src/inference.rs | 10 +- crates/ggml/Cargo.toml | 3 +- crates/ggml/README.md | 2 +- crates/ggml/src/accelerator/metal.rs | 13 +- crates/ggml/src/accelerator/mod.rs | 1 + crates/ggml/src/context.rs | 102 ++- crates/ggml/src/format/loader.rs | 2 +- crates/ggml/src/lib.rs | 154 +++- crates/ggml/src/tensor.rs | 19 +- crates/ggml/sys/build.rs | 25 +- crates/ggml/sys/llama-cpp | 2 +- crates/ggml/sys/src/cuda.rs | 43 +- crates/ggml/sys/src/lib.rs | 932 ++++++++++++++++---- crates/ggml/sys/src/llama.rs | 13 +- crates/ggml/sys/src/metal.rs | 36 +- crates/llm-base/src/inference_session.rs | 197 +++-- crates/llm-base/src/lib.rs | 4 +- crates/llm-base/src/loader.rs | 40 +- crates/llm-base/src/lora.rs | 5 +- crates/llm-base/src/model/mod.rs | 11 + crates/llm-base/src/samplers.rs | 52 +- crates/llm-base/src/util.rs | 22 +- crates/llm/Cargo.toml | 4 +- crates/llm/src/lib.rs | 1 + crates/models/bert/Cargo.toml | 14 + crates/models/bert/src/lib.rs | 464 ++++++++++ crates/models/bloom/src/lib.rs | 29 +- crates/models/falcon/src/lib.rs | 39 +- crates/models/gpt2/src/lib.rs | 36 +- crates/models/gptj/src/lib.rs | 30 +- crates/models/gptneox/src/lib.rs | 40 +- crates/models/llama/src/lib.rs | 38 +- crates/models/mpt/src/lib.rs | 40 +- 41 files changed, 1974 insertions(+), 596 deletions(-) create mode 100644 crates/models/bert/Cargo.toml create mode 100644 crates/models/bert/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 049d70df..6f053a17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -101,17 +101,6 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -120,9 +109,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" dependencies = [ "addr2line", "cc", @@ -191,9 +180,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "block-buffer" @@ -404,13 +393,13 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "colored" -version = "2.0.0" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6" dependencies = [ - "atty", + "is-terminal", "lazy_static", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -755,7 +744,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", - "rustix 0.38.1", + "rustix 0.38.19", "windows-sys 0.48.0", ] @@ -932,9 +921,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] name = "glob" @@ -982,15 +971,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.3.1" @@ -1151,7 +1131,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", "windows-sys 0.48.0", ] @@ -1164,12 +1144,12 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "is-terminal" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.1", - "rustix 0.38.1", + "hermit-abi", + "rustix 0.38.19", "windows-sys 0.48.0", ] @@ -1229,9 +1209,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libloading" @@ -1251,9 +1231,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" [[package]] name = "llm" @@ -1262,6 +1242,7 @@ dependencies = [ "bytesize", "clap", "llm-base", + "llm-bert", "llm-bloom", "llm-falcon", "llm-gpt2", @@ -1297,6 +1278,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "llm-bert" +version = "0.2.0-dev" +dependencies = [ + "bytemuck", + "llm-base", + "tracing", +] + [[package]] name = "llm-bloom" version = "0.2.0-dev" @@ -1374,9 +1364,9 @@ dependencies = [ [[package]] name = "llm-samplers" -version = "0.0.6" +version = "0.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7553f60d113c9cdc6a5402456a31cd9a273bef79f6f16d8a4f7b4bedf5f754b2" +checksum = "7e85df656cd89e7702cb56171d75aa77c7bec828af7d2054d9987c34411cf896" dependencies = [ "anyhow", "num-traits", @@ -1601,7 +1591,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", ] @@ -1613,9 +1603,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.31.1" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -2015,9 +2005,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.37.21" +version = "0.37.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c" +checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" dependencies = [ "bitflags 1.3.2", "errno", @@ -2029,14 +2019,14 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.1" +version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3" +checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.3", + "linux-raw-sys 0.4.11", "windows-sys 0.48.0", ] @@ -2256,9 +2246,9 @@ dependencies = [ [[package]] name = "spinoff" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fee259f96b31e7a18657d11741fe30d63f98e07de70e7a19d2b705ab9b331cdc" +checksum = "20aa2ed67fbb202e7b716ff8bfc6571dd9301617767380197d701c31124e88f6" dependencies = [ "colored", "once_cell", @@ -2344,7 +2334,7 @@ dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix 0.37.21", + "rustix 0.37.27", "windows-sys 0.48.0", ] diff --git a/Cargo.toml b/Cargo.toml index ae5b22f7..045ecc9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ "crates/llm", "crates/llm-base", "crates/models/*", - "binaries/*" + "binaries/*", ] resolver = "2" default-members = ["binaries/llm-cli", "crates/llm"] @@ -27,12 +27,12 @@ anyhow = "1.0" rustyline = { version = "11.0.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } -spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] } +spinoff = { version = "0.8.0", default-features = false, features = ["dots2"] } clap = { version = "4.1.8", features = ["derive"] } memmap2 = "0.5.10" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0.1", features = ["log"] } -llm-samplers = "=0.0.6" +llm-samplers = "=0.0.7" # Config for 'cargo dist' [workspace.metadata.dist] @@ -45,7 +45,12 @@ ci = ["github"] # The installers to generate for each app installers = ["shell", "powershell"] # Target platforms to build apps for (Rust target-triple syntax) -targets = ["x86_64-unknown-linux-gnu", "x86_64-apple-darwin", "x86_64-pc-windows-msvc", "aarch64-apple-darwin"] +targets = [ + "x86_64-unknown-linux-gnu", + "x86_64-apple-darwin", + "x86_64-pc-windows-msvc", + "aarch64-apple-darwin", +] # The profile that 'cargo dist' will build with [profile.dist] diff --git a/README.md b/README.md index 8a0bd7ae..b027cfdd 100644 --- a/README.md +++ b/README.md @@ -287,6 +287,7 @@ Absolutely! Please see the [contributing guide](./doc/CONTRIBUTING.md). inference API on your local machine using `llm`. - [secondbrain](https://github.com/juliooa/secondbrain): Desktop app to download and run LLMs locally in your computer using `llm`. - [floneum](https://floneum.com/): A graph editor for local AI workflows. +- [poly](https://github.com/pixelspark/poly): A versatile LLM serving back-end with tasks, streaming completion, memory retrieval, and more. #### Libraries diff --git a/binaries/generate-ggml-bindings/src/main.rs b/binaries/generate-ggml-bindings/src/main.rs index 39acbb86..ad73ddba 100644 --- a/binaries/generate-ggml-bindings/src/main.rs +++ b/binaries/generate-ggml-bindings/src/main.rs @@ -27,6 +27,8 @@ fn generate_main(ggml_path: &Path, src_path: &Path) { .allowlist_file(r".*ggml.h") .header(ggml_path.join("k_quants.h").to_string_lossy()) .allowlist_file(r".*k_quants.h") + .header(ggml_path.join("ggml-alloc.h").to_string_lossy()) + .allowlist_file(r".*ggml-alloc.h") // Suppress some warnings .raw_line("#![allow(non_upper_case_globals)]") .raw_line("#![allow(non_camel_case_types)]") @@ -88,6 +90,9 @@ fn generate_metal(ggml_path: &Path, src_path: &Path) { generate_extra("metal", ggml_path, src_path, |b| { b.header(ggml_path.join("ggml-metal.h").to_string_lossy()) .allowlist_file(r".*ggml-metal\.h") + .raw_line("use super::ggml_tensor;") + .raw_line("use super::ggml_log_callback;") + .raw_line("use super::ggml_cgraph;") }); } diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 21b4a897..e158db68 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -290,6 +290,15 @@ pub struct Generate { /// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen. /// p(0.95): The cumulative probability after which no more tokens are kept for sampling. /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// top_a (default: disabled) - This sampler prunes tokens that don't meet a threshold based on the most probable token. The formula is `a1 * pow(max_prob, a2)`. See https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method for more information. + /// a1(0.0): Threshold scale. A reasonable value is 0.2. Setting either a1 or a2 to 0 disables the sampler. + /// a2(0.0): Threshold power. A reasonable value is 2. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// min_p (default: disabled) - This sampler prunes tokens that don't meet a certain percentage of the most probable token. For example if `p` is `0.05` then after `min_keep` is satisfied, other tokens must be at least 5% of the most probable token. See https://github.com/ggerganov/llama.cpp/issues/3483 for more information. + /// p(0.0): Probability threshold. 0.05 to 0.2 are good starting values to try. Setting this to 0 disables the sampler. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. #[arg(long = "sampler", short = 's', verbatim_doc_comment)] pub sampler_options: Vec, @@ -533,7 +542,7 @@ impl ModelLoad { let tokenizer_source = match self.model_and_tokenizer.to_source() { Ok(vs) => vs, Err(err) => { - if let Some(sp) = sp.take() { + if let Some(mut sp) = sp.take() { sp.fail(&format!("Failed to load tokenizer: {}", err)); } return Err(err); @@ -586,7 +595,7 @@ impl ModelLoad { file_size, tensor_count, } => { - if let Some(sp) = sp.take() { + if let Some(mut sp) = sp.take() { sp.success(&format!( "Loaded {tensor_count} tensors ({}) after {}ms", bytesize::to_string(file_size, false), @@ -601,7 +610,7 @@ impl ModelLoad { if model.is_err() { // If we've failed at loading the model, we probably haven't stopped the spinner yet. // Cancel it now if needed. - if let Some(sp) = sp { + if let Some(mut sp) = sp { sp.fail("Failed to load model") } } diff --git a/binaries/llm-cli/src/interactive.rs b/binaries/llm-cli/src/interactive.rs index b971d5d9..3ad7e486 100644 --- a/binaries/llm-cli/src/interactive.rs +++ b/binaries/llm-cli/src/interactive.rs @@ -141,7 +141,7 @@ fn feed_prompt_with_spinner( prompt.insert(0, '\n'); } - let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); + let mut sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); let result = session.feed_prompt( model, &prompt, @@ -165,8 +165,7 @@ fn session_ends_with_newline(session: &llm::InferenceSession) -> bool { session .decoded_tokens() .last() - .map(|t| *t == b'\n') - .unwrap_or(true) + .map_or(true, |t| *t == b'\n') } fn readline_loop(mut body: impl FnMut(String) -> eyre::Result<()>) -> eyre::Result<()> { diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index b0eabece..242b42c8 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -164,9 +164,7 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { } fn utf8_or_array(token: &[u8]) -> String { - std::str::from_utf8(token) - .map(|s| s.to_owned()) - .unwrap_or(format!("{:?}", token)) + std::str::from_utf8(token).map_or(format!("{:?}", token), |s| s.to_owned()) } Ok(()) diff --git a/binaries/llm-test/configs/mpt.json b/binaries/llm-test/configs/mpt.json index 57a8bc89..c5d9d8d0 100644 --- a/binaries/llm-test/configs/mpt.json +++ b/binaries/llm-test/configs/mpt.json @@ -6,7 +6,7 @@ { "Inference": { "input": "When a llama rides a crab, ", - "output": "When a llama rides a crab,  the llama is called the \"crab rider\".\nThe crabs are very popular in South America, especially Brazil. They have been used as transportation for many years and they can carry up to five people at once!", + "output": "When a llama rides a crab,  the llama is called the \"crab rider\"\nThe Llamas are an animal that can be found in The Maze. They have no special abilities, but they do drop Llamaskin and occasionally some other items when killed by players or monsters alike (see below). It's unknown if there was ever any sort of breeding system for these animals as it seems to only exist on this one world so far; however their existence has been confirmed through player reports from multiple worlds where people claim having seen them before being able see anything else about what happened after seeing just 1-2 at most per game session which makes me believe", "maximum_token_count": 128 } }, diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs index a9ace889..3666167e 100644 --- a/binaries/llm-test/src/inference.rs +++ b/binaries/llm-test/src/inference.rs @@ -92,14 +92,14 @@ fn run_inference( // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` // at all #[derive(Debug, Default)] -struct DeterministicSampler(SampleGreedy); +struct DeterministicSampler(SampleGreedy); -impl Sampler for DeterministicSampler { +impl Sampler for DeterministicSampler { fn sample<'a>( &mut self, - res: &mut dyn HasSamplerResources, - logits: &'a mut Logits, - ) -> anyhow::Result<&'a mut Logits> { + res: &mut dyn HasSamplerResources, + logits: &'a mut Logits, + ) -> anyhow::Result<&'a mut Logits> { let mut flat_bias = Default::default(); // This might look a little weird, but it's necessary because the resource diff --git a/crates/ggml/Cargo.toml b/crates/ggml/Cargo.toml index 2c4f43bd..fe60f7a9 100644 --- a/crates/ggml/Cargo.toml +++ b/crates/ggml/Cargo.toml @@ -7,8 +7,9 @@ description = "Semi-idiomatic Rust bindings for the ggml library (from `ggml-sys license = "MIT" [dependencies] -thiserror = { workspace = true } ggml-sys = { path = "sys", version = "0.2.0-dev" } + +thiserror = { workspace = true } memmap2 = { workspace = true } [dev-dependencies] diff --git a/crates/ggml/README.md b/crates/ggml/README.md index 6849e8b7..a165e855 100644 --- a/crates/ggml/README.md +++ b/crates/ggml/README.md @@ -119,7 +119,7 @@ following 2 ⨯ 2 tensor named `tensor_a0`: A simplification of the GGML representation of `tensor_a0` is -`{"tensor_a0", [2, 2, 1, 1], [1.0, 0.0, 0.1, 1.0]}`. Note that the 4-element +`{"tensor_a0", [2, 2, 1, 1], [1.0, 0.0, 0.1, 1.1]}`. Note that the 4-element list of dimensions uses `1` as a placeholder for unused dimensions - this is because the product of the dimensions should not equal zero. diff --git a/crates/ggml/src/accelerator/metal.rs b/crates/ggml/src/accelerator/metal.rs index 8fced466..a15e39f1 100644 --- a/crates/ggml/src/accelerator/metal.rs +++ b/crates/ggml/src/accelerator/metal.rs @@ -14,8 +14,8 @@ pub struct MetalContext { impl MetalContext { /// Create a new Metal context - pub fn new(n_threads: usize) -> Self { - let raw = unsafe { metal::ggml_metal_init(n_threads.try_into().unwrap()) }; + pub fn new() -> Self { + let raw = unsafe { metal::ggml_metal_init(1) }; MetalContext { contexts: vec![], @@ -83,19 +83,14 @@ impl MetalContext { unsafe { metal::ggml_metal_graph_compute( self.ptr.as_ptr(), - graph.inner as *mut ggml_sys::ggml_cgraph as *mut metal::ggml_cgraph, + graph.inner as *mut ggml_sys::ggml_cgraph, ); } } /// Reads a tensor from Metal pub fn get_tensor(&self, tensor: &Tensor) { - unsafe { - metal::ggml_metal_get_tensor( - self.ptr.as_ptr(), - tensor.ptr.as_ptr() as *mut metal::ggml_tensor, - ) - } + unsafe { metal::ggml_metal_get_tensor(self.ptr.as_ptr(), tensor.ptr.as_ptr()) } } } diff --git a/crates/ggml/src/accelerator/mod.rs b/crates/ggml/src/accelerator/mod.rs index 2e1cef17..731de9bc 100644 --- a/crates/ggml/src/accelerator/mod.rs +++ b/crates/ggml/src/accelerator/mod.rs @@ -71,6 +71,7 @@ pub fn initialize(device: i32) { //TODO: Make this configurable sys::cuda::ggml_init_cublas(); sys::cuda::ggml_cuda_set_main_device(device); + sys::cuda::ggml_cuda_set_mul_mat_q(true); let split = 1.0f32; sys::cuda::ggml_cuda_set_tensor_split(&split as *const f32); } diff --git a/crates/ggml/src/context.rs b/crates/ggml/src/context.rs index e5c4cdb0..96f81b4f 100644 --- a/crates/ggml/src/context.rs +++ b/crates/ggml/src/context.rs @@ -21,7 +21,7 @@ pub struct Context { /// allocated tensors. Tensors are owned by the object, so a [`Tensor`] /// contains a `Weak` reference underneath and doesn't let you do anything /// with it if the underlying context has been deallocated. - inner: Arc, + pub inner: Arc, /// The storage for this context. This is stored so that the buffer can be dropped when the context is dropped. storage: Option, @@ -31,7 +31,7 @@ pub struct Context { } /// Contains state shared between a context and its tensors -pub(crate) struct ContextInner { +pub struct ContextInner { pub ptr: NonNull, /// Offloaded tensors. Used to free them when the context is dropped. @@ -56,6 +56,13 @@ impl PartialEq for ContextInner { impl Eq for ContextInner {} impl ContextInner { pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc { + // This context can only be used from one thread at a time - hence why + // it doesn't implement `Send/Sync` - but higher-level abstractions may + // choose to layer their own abstractions that implement higher-level + // synchronization that can offer thread-safety guarantees. To ensure + // that we don't break those, we still use an `Arc` here. + // TODO: check if this is correct? + #[allow(clippy::arc_with_non_send_sync)] Arc::new(Self { ptr: NonNull::new(ptr).expect("Should not be null"), offloaded_tensors: Default::default(), @@ -66,7 +73,12 @@ impl ContextInner { /// Controls how the context uses memory. pub enum ContextStorage { /// Use the provided buffer as memory. - Buffer(Buffer), + Buffer { + /// The buffer to use as memory. + buffer: Buffer, + /// Whether to allocate tensors into this buffer. + allocate: bool, + }, /// Use the provided memory mapped file as memory. Mmap(Mmap), /// Allocate `mem_size` bytes of memory. @@ -87,7 +99,10 @@ impl ContextStorage { /// Returns the `Buffer` if this is a `Buffer` variant. pub fn as_buffer(&self) -> Option<&Buffer> { match self { - Self::Buffer(v) => Some(v), + Self::Buffer { + buffer: v, + allocate: _, + } => Some(v), _ => None, } } @@ -108,7 +123,16 @@ impl PartialEq for ContextStorage { fn eq(&self, other: &Self) -> bool { use ContextStorage::*; match (self, other) { - (Buffer(l0), Buffer(r0)) => l0 == r0, + ( + Buffer { + buffer: l0, + allocate: l1, + }, + Buffer { + buffer: r0, + allocate: r1, + }, + ) => l0 == r0 && l1 == r1, (Mmap(l0), Mmap(r0)) => l0.as_ptr() == r0.as_ptr(), (Allocate { mem_size: l }, Allocate { mem_size: r }) => l == r, _ => false, @@ -118,13 +142,15 @@ impl PartialEq for ContextStorage { impl Eq for ContextStorage {} impl Context { - /// Creates a new [Context] with the given storage.. + // See explanation in [`ContextInner::new`]. + #[allow(clippy::arc_with_non_send_sync)] + /// Creates a new [Context] with the given storage. pub fn new(storage: ContextStorage) -> Self { let init_params = match &storage { - ContextStorage::Buffer(buffer) => sys::ggml_init_params { + ContextStorage::Buffer { buffer, allocate } => sys::ggml_init_params { mem_size: buffer.size(), mem_buffer: buffer.data, - no_alloc: false, + no_alloc: !allocate, }, ContextStorage::Mmap(mmap) => sys::ggml_init_params { mem_size: mmap.len(), @@ -151,8 +177,8 @@ impl Context { /// Creates a new [Context] with the specified buffer. /// The buffer will be used by GGML. - pub fn new_with_buffer(buffer: Buffer) -> Self { - Self::new(ContextStorage::Buffer(buffer)) + pub fn new_with_buffer(buffer: Buffer, allocate: bool) -> Self { + Self::new(ContextStorage::Buffer { buffer, allocate }) } /// Creates a new [Context] with the specified memory mapped file. @@ -197,28 +223,6 @@ impl Context { unsafe { sys::ggml_used_mem(self.as_ptr()) } } - /// Sets the scratch buffer to be used by this [Context]. - /// - /// If `scratch_buffer` is `None`, the scratch buffer will be disabled. - pub fn use_scratch<'a>(&'a self, scratch_buffer: Option<&'a Buffer>) { - let (size, data) = if let Some(buffer) = scratch_buffer { - (buffer.size(), buffer.data) - } else { - (0, std::ptr::null_mut()) - }; - // SAFETY: this just passes (most likely uninitialized) memory buffer to the ggml C API - unsafe { - sys::ggml_set_scratch( - self.as_ptr(), - sys::ggml_scratch { - offs: 0, - size, - data, - }, - ); - } - } - /// Creates a new 1D tensor. pub fn new_tensor_1d(&self, typ: Type, ne0: usize) -> Tensor { let raw = unsafe { sys::ggml_new_tensor_1d(self.as_ptr(), typ.into(), usize_to_i64(ne0)) }; @@ -262,6 +266,12 @@ impl Context { pub fn storage(&self) -> &ContextStorage { self.storage.as_ref().unwrap() } + + /// Set all values of the tensor with the specified value. + pub fn set_f32(&self, a: &Tensor, x: f32) -> Tensor { + let raw = unsafe { sys::ggml_set_f32(a.ptr.as_ptr(), x) }; + self.new_tensor_raw(raw) + } } // Operations impl Context { @@ -279,7 +289,7 @@ impl Context { /// Creates a new tensor with the values of `a`, but normalized. pub fn op_norm(&self, a: &Tensor) -> Tensor { - let tensor = unsafe { sys::ggml_norm(self.as_ptr(), a.ptr.as_ptr()) }; + let tensor = unsafe { sys::ggml_norm(self.as_ptr(), a.ptr.as_ptr(), crate::DEFAULT_EPS) }; self.new_tensor_raw(tensor) } @@ -296,7 +306,7 @@ impl Context { self.new_tensor_raw(tensor) } - /// Repeats the `a` tensor along the first dimension of the `b` tensor. + /// Repeats the `a` tensor along the first dimension of the `b` tensor. pub fn op_repeat(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { sys::ggml_repeat(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) @@ -614,6 +624,30 @@ impl Context { }; self.new_tensor_raw(tensor) } + + /// Creates a new tensor with the square of `a` + pub fn op_sqr(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sqr(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Creates a new tensor with the square-root of `a` + pub fn op_sqrt(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sqrt(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Unknown + pub fn op_sum(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sum(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Unknown + pub fn op_div(&self, a: &Tensor, b: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_div(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } } // Public to this crate methods impl Context { diff --git a/crates/ggml/src/format/loader.rs b/crates/ggml/src/format/loader.rs index 8a1a42ae..8b94e6a3 100644 --- a/crates/ggml/src/format/loader.rs +++ b/crates/ggml/src/format/loader.rs @@ -167,7 +167,7 @@ pub fn load( match container_type { ContainerType::Ggml | ContainerType::Ggmf(1) - | ContainerType::Ggjt(1 | 2 | 3) + | ContainerType::Ggjt(1..=3) | ContainerType::Ggla(1) => {} _ => return Err(LoadError::InvalidFormatVersion(container_type)), } diff --git a/crates/ggml/src/lib.rs b/crates/ggml/src/lib.rs index 8b6910eb..26bcc548 100644 --- a/crates/ggml/src/lib.rs +++ b/crates/ggml/src/lib.rs @@ -10,6 +10,8 @@ use std::{ alloc::Layout, os::raw::{c_int, c_void}, + ptr::NonNull, + sync::Arc, }; mod context; @@ -129,7 +131,13 @@ pub const OBJECT_SIZE: usize = sys::GGML_OBJECT_SIZE; pub const MAX_NAME_LENGTH: usize = sys::GGML_MAX_NAME as usize; /// Default epsilon to use for RMS computation. -pub const DEFAULT_EPS: f32 = sys::llama::LLAMA_DEFAULT_RMS_EPS as f32; +pub const DEFAULT_EPS: f32 = 0.000005; + +/// Maximum number of nodes in a `ggml` graph. +pub const MAX_NODES: usize = sys::GGML_MAX_NODES as usize; + +/// Alignment used for the Tensors in a `ggml` graph. +pub const TENSOR_ALIGNMENT: usize = 32; /// Value overrides to use for RoPE. /// @@ -280,10 +288,8 @@ impl Type { } } -/// A buffer of memory that can be used as a scratch buffer for a [Context]. -/// -/// See [Context::use_scratch]. -#[derive(PartialEq, Eq)] +/// A buffer of memory that can be used as a buffer for a [Context] or [GraphAllocator]. +#[derive(PartialEq, Eq, Debug)] pub struct Buffer { data: *mut c_void, layout: Layout, @@ -304,10 +310,27 @@ impl Buffer { } } + /// Creates a new buffer of the specified size, without aligning it. + pub fn new_unaligned(size: usize) -> Self { + let layout = Layout::from_size_align(size, 1).unwrap(); + + unsafe { + Buffer { + data: std::alloc::alloc(layout).cast(), + layout, + } + } + } + /// Returns the size of the buffer in bytes pub fn size(&self) -> usize { self.layout.size() } + + /// Returns a pointer to the data in this buffer. + pub fn data(&mut self) -> *mut c_void { + self.data + } } impl Drop for Buffer { @@ -333,6 +356,37 @@ impl ComputationGraph { pub fn build_forward_expand(&mut self, tensor: &Tensor) { unsafe { sys::ggml_build_forward_expand(self.inner, tensor.ptr.as_ptr()) } } + + /// Returns the leafs in this graph. + pub fn leafs(&self, context: &Context) -> Vec { + let mut wrapped_leafs: Vec = vec![]; + unsafe { + for leaf in self.inner.as_ref().unwrap().leafs { + if !leaf.is_null() { + wrapped_leafs.push(Tensor { + ptr: NonNull::new(leaf).expect("Should not be null"), + inner: Arc::downgrade(&context.inner), + }) + } + } + wrapped_leafs + } + } + /// Returns the nodes in this graph. + pub fn nodes(&self, context: &Context) -> Vec { + let mut wrapped_nodes: Vec = vec![]; + unsafe { + for leaf in self.inner.as_ref().unwrap().leafs { + if !leaf.is_null() { + wrapped_nodes.push(Tensor { + ptr: NonNull::new(leaf).expect("Should not be null"), + inner: Arc::downgrade(&context.inner), + }) + } + } + wrapped_nodes + } + } } /// A `ggml` execution plan. Contains the information needed to execute a computation graph. @@ -350,30 +404,79 @@ impl GraphExecutionPlan { } } - /// Creates a [Type::I8] work buffer with size `plan.work_size` for this [GraphExecutionPlan] in the given [Context]. - fn create_work_buffer(&mut self, context: &Context) -> Tensor { - context.new_tensor_1d(Type::I8, self.inner.work_size) - } + /// Execute this [GraphExecutionPlan] in the given [Context]. + pub fn execute(&mut self, buffer: &mut Vec) { + if self.inner.work_size > 0 { + buffer.resize(self.inner.work_size, 0); + self.inner.work_data = buffer.as_mut_ptr().cast(); + } - /// Assign a work buffer to this [GraphExecutionPlan]. - fn assign_work_buffer(&mut self, buffer: &mut Tensor) { - assert!( - buffer.get_type() == Type::I8, - "Work buffer must be of type i8" - ); unsafe { - self.inner.work_data = buffer.data().cast(); + sys::ggml_graph_compute(self.inner_graph, &mut self.inner); } } +} - /// Execute this [GraphExecutionPlan] in the given [Context]. - pub fn execute(&mut self, context: &Context) { - let mut work_buffer = self.create_work_buffer(context); - self.assign_work_buffer(&mut work_buffer); +#[derive(PartialEq, Eq, Debug)] +/// Acts as a RAII-guard over a `sys::ggml_allocr`, allocating via +/// `ggml_allocr_new` and dropping via `ggml_allocr_free`. +/// Used to allocate the memory used by a computational graph. +pub struct GraphAllocator { + /// The underlying `sys::ggml_allocr` pointer. + pub ptr: *mut sys::ggml_allocr, + /// The buffer used by this allocator. + pub buffer: Buffer, +} - unsafe { - sys::ggml_graph_compute(self.inner_graph, &mut self.inner); - } +impl GraphAllocator { + /// Create a new allocator with the specified buffer. + pub fn new(buffer: Buffer, tensor_alignment: usize) -> Self { + let ptr = unsafe { sys::ggml_allocr_new(buffer.data, buffer.size(), tensor_alignment) }; + Self { ptr, buffer } + } + + /// Create a new allocator to measure a computational graph. + pub fn new_measurement(tensor_alignment: usize) -> Self { + let ptr = unsafe { sys::ggml_allocr_new_measure(tensor_alignment) }; + let buffer = Buffer::new(tensor_alignment); + Self { ptr, buffer } + } + + /// Allocates a computational graph in the allocator and returns the size in bytes. + pub fn allocate_graph(&self, graph: &ComputationGraph) -> usize { + unsafe { sys::ggml_allocr_alloc_graph(self.ptr, graph.inner) } + } + + /// Resets the allocator for a new forward pass. + pub fn reset(&self) { + unsafe { sys::ggml_allocr_reset(self.ptr) } + } + + /// Returns true if the allocator is in measuring mode. + pub fn in_measuring_mode(&self) -> bool { + unsafe { sys::ggml_allocr_is_measure(self.ptr) } + } + + /// Allocates memory for a given tensor in the allocator. + pub fn allocate(&self, tensor: &Tensor) { + unsafe { sys::ggml_allocr_alloc(self.ptr, tensor.ptr.as_ptr()) } + } + + /// Switches the buffer used by the allocator. + pub fn resize_buffer(&mut self, graph_size: usize, tensor_alignment: usize) { + // Free the old allocator + unsafe { sys::ggml_allocr_free(self.ptr) } + //Resize the buffer + self.buffer = Buffer::new_unaligned(graph_size); + // Create a new allocator with the new buffer + self.ptr = + unsafe { sys::ggml_allocr_new(self.buffer.data, self.buffer.size(), tensor_alignment) }; + } +} + +impl Drop for GraphAllocator { + fn drop(&mut self) { + unsafe { sys::ggml_allocr_free(self.ptr) } } } @@ -496,3 +599,8 @@ pub fn cpu_has_gpublas() -> bool { pub fn graph_overhead() -> usize { unsafe { sys::ggml_graph_overhead() } } + +/// Returns the tensor overhead in bytes. +pub fn tensor_overhead() -> usize { + unsafe { sys::ggml_tensor_overhead() } +} diff --git a/crates/ggml/src/tensor.rs b/crates/ggml/src/tensor.rs index 33d7114c..ee5354c2 100644 --- a/crates/ggml/src/tensor.rs +++ b/crates/ggml/src/tensor.rs @@ -52,6 +52,11 @@ impl Tensor { }) } + /// Returns true if the 'extra' field of this tensor is set. e.g. by ggml-cuda + pub fn has_extras(&self) -> bool { + self.with_alive_ctx(|| unsafe { !self.ptr.as_ref().extra.is_null() }) + } + /// Sets the tensor's acceleration backend and moves the tensor's data to the new backend. pub fn transfer_to(mut self, backend: Backend) -> Tensor { self.with_alive_ctx_mut(|t| { @@ -88,7 +93,7 @@ impl Tensor { self.with_alive_ctx(|| { #[cfg(feature = "cublas")] unsafe { - sys::cuda::ggml_cuda_assign_buffers(self.ptr.as_ptr()); + sys::cuda::ggml_cuda_assign_buffers_no_alloc(self.ptr.as_ptr()); } }) } @@ -111,6 +116,18 @@ impl Tensor { }) } + /// If ggml-sys is compiled with CUDA support, this function will set the tensor's scratch offset. + /// If not, this is a no-op. + #[allow(unused_variables)] + pub fn assign_scratch_offset(&self, offset: usize) { + self.with_alive_ctx(|| { + #[cfg(feature = "cublas")] + unsafe { + sys::cuda::ggml_cuda_assign_scratch_offset(self.ptr.as_ptr(), offset); + } + }) + } + /// Creates a shared copy of this tensor pointer. pub fn share(&self) -> Self { Tensor { diff --git a/crates/ggml/sys/build.rs b/crates/ggml/sys/build.rs index f69cee3b..ba7e876b 100644 --- a/crates/ggml/sys/build.rs +++ b/crates/ggml/sys/build.rs @@ -12,8 +12,13 @@ fn main() { let mut builder = cc::Build::new(); let build = builder - .files(["llama-cpp/ggml.c", "llama-cpp/k_quants.c"]) + .files([ + "llama-cpp/ggml.c", + "llama-cpp/k_quants.c", + "llama-cpp/ggml-alloc.c", + ]) .define("GGML_USE_K_QUANTS", None) + .define("QK_K", Some("256")) .includes(["llama-cpp"]); // This is a very basic heuristic for applying compile flags. @@ -75,17 +80,11 @@ fn main() { } "aarch64" => { if compiler.is_like_clang() || compiler.is_like_gnu() { - if std::env::var("HOST") == std::env::var("TARGET") { + if target_os == "macos" { + build.flag("-mcpu=apple-m1"); + } else if std::env::var("HOST") == std::env::var("TARGET") { build.flag("-mcpu=native"); - } else { - #[allow(clippy::single_match)] - match target_os.as_str() { - "macos" => { - build.flag("-mcpu=apple-m1"); - build.flag("-mfpu=neon"); - } - _ => {} - } + build.flag("-mfpu=neon"); } build.flag("-pthread"); } @@ -93,6 +92,10 @@ fn main() { _ => {} } + if compiler.is_like_gnu() && target_os == "linux" { + build.define("_GNU_SOURCE", None); + } + if is_release { build.define("NDEBUG", None); } diff --git a/crates/ggml/sys/llama-cpp b/crates/ggml/sys/llama-cpp index 8183159c..da040034 160000 --- a/crates/ggml/sys/llama-cpp +++ b/crates/ggml/sys/llama-cpp @@ -1 +1 @@ -Subproject commit 8183159cf3def112f6d1fe94815fce70e1bffa12 +Subproject commit da0400344be12074e67dcabc565140289cf7efaa diff --git a/crates/ggml/sys/src/cuda.rs b/crates/ggml/sys/src/cuda.rs index a9ae1a8d..5208b66e 100644 --- a/crates/ggml/sys/src/cuda.rs +++ b/crates/ggml/sys/src/cuda.rs @@ -3,15 +3,17 @@ use super::ggml_compute_params; use super::ggml_tensor; +pub const GGML_CUDA_NAME: &[u8; 5usize] = b"CUDA\0"; +pub const GGML_CUBLAS_NAME: &[u8; 7usize] = b"cuBLAS\0"; pub const GGML_CUDA_MAX_DEVICES: u32 = 16; extern "C" { pub fn ggml_init_cublas(); } extern "C" { - pub fn ggml_cuda_set_tensor_split(tensor_split: *const f32); + pub fn ggml_cuda_host_malloc(size: usize) -> *mut ::std::os::raw::c_void; } extern "C" { - pub fn ggml_cuda_mul(src0: *const ggml_tensor, src1: *const ggml_tensor, dst: *mut ggml_tensor); + pub fn ggml_cuda_host_free(ptr: *mut ::std::os::raw::c_void); } extern "C" { pub fn ggml_cuda_can_mul_mat( @@ -21,26 +23,7 @@ extern "C" { ) -> bool; } extern "C" { - pub fn ggml_cuda_mul_mat_get_wsize( - src0: *const ggml_tensor, - src1: *const ggml_tensor, - dst: *mut ggml_tensor, - ) -> usize; -} -extern "C" { - pub fn ggml_cuda_mul_mat( - src0: *const ggml_tensor, - src1: *const ggml_tensor, - dst: *mut ggml_tensor, - wdata: *mut ::std::os::raw::c_void, - wsize: usize, - ); -} -extern "C" { - pub fn ggml_cuda_host_malloc(size: usize) -> *mut ::std::os::raw::c_void; -} -extern "C" { - pub fn ggml_cuda_host_free(ptr: *mut ::std::os::raw::c_void); + pub fn ggml_cuda_set_tensor_split(tensor_split: *const f32); } extern "C" { pub fn ggml_cuda_transform_tensor(data: *mut ::std::os::raw::c_void, tensor: *mut ggml_tensor); @@ -57,6 +40,12 @@ extern "C" { extern "C" { pub fn ggml_cuda_assign_buffers_force_inplace(tensor: *mut ggml_tensor); } +extern "C" { + pub fn ggml_cuda_assign_buffers_no_alloc(tensor: *mut ggml_tensor); +} +extern "C" { + pub fn ggml_cuda_assign_scratch_offset(tensor: *mut ggml_tensor, offset: usize); +} extern "C" { pub fn ggml_cuda_set_main_device(main_device: ::std::os::raw::c_int); } @@ -75,3 +64,13 @@ extern "C" { tensor: *mut ggml_tensor, ) -> bool; } +extern "C" { + pub fn ggml_cuda_get_device_count() -> ::std::os::raw::c_int; +} +extern "C" { + pub fn ggml_cuda_get_device_description( + device: ::std::os::raw::c_int, + description: *mut ::std::os::raw::c_char, + description_size: usize, + ); +} diff --git a/crates/ggml/sys/src/lib.rs b/crates/ggml/sys/src/lib.rs index 77b47802..71b34251 100644 --- a/crates/ggml/sys/src/lib.rs +++ b/crates/ggml/sys/src/lib.rs @@ -22,12 +22,17 @@ pub const GGML_MAX_NODES: u32 = 4096; pub const GGML_MAX_PARAMS: u32 = 256; pub const GGML_MAX_CONTEXTS: u32 = 64; pub const GGML_MAX_SRC: u32 = 6; -pub const GGML_MAX_NAME: u32 = 48; +pub const GGML_MAX_NAME: u32 = 64; pub const GGML_MAX_OP_PARAMS: u32 = 32; pub const GGML_DEFAULT_N_THREADS: u32 = 4; +pub const GGML_MEM_ALIGN: u32 = 16; pub const GGML_EXIT_SUCCESS: u32 = 0; pub const GGML_EXIT_ABORTED: u32 = 1; +pub const GGUF_MAGIC: u32 = 1179993927; +pub const GGUF_VERSION: u32 = 2; +pub const GGUF_DEFAULT_ALIGNMENT: u32 = 32; pub const GGML_GRAPH_HASHTABLE_SIZE: u32 = 8273; +pub const GGML_N_TASKS_MAX: i32 = -1; pub const QK_K: u32 = 256; pub const K_SCALE_SIZE: u32 = 12; pub type ggml_fp16_t = u16; @@ -103,49 +108,58 @@ pub const ggml_op_GGML_OP_MEAN: ggml_op = 13; pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 14; pub const ggml_op_GGML_OP_REPEAT: ggml_op = 15; pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 16; -pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 17; -pub const ggml_op_GGML_OP_NORM: ggml_op = 18; -pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 19; -pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 20; -pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 21; -pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 22; -pub const ggml_op_GGML_OP_SCALE: ggml_op = 23; -pub const ggml_op_GGML_OP_SET: ggml_op = 24; -pub const ggml_op_GGML_OP_CPY: ggml_op = 25; -pub const ggml_op_GGML_OP_CONT: ggml_op = 26; -pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 27; -pub const ggml_op_GGML_OP_VIEW: ggml_op = 28; -pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 29; -pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 30; -pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 31; -pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 32; -pub const ggml_op_GGML_OP_DIAG: ggml_op = 33; -pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 34; -pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 35; -pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 36; -pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 37; -pub const ggml_op_GGML_OP_ROPE: ggml_op = 38; -pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 39; -pub const ggml_op_GGML_OP_ALIBI: ggml_op = 40; -pub const ggml_op_GGML_OP_CLAMP: ggml_op = 41; -pub const ggml_op_GGML_OP_CONV_1D: ggml_op = 42; -pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 43; -pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 44; -pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 45; -pub const ggml_op_GGML_OP_FLASH_ATTN: ggml_op = 46; -pub const ggml_op_GGML_OP_FLASH_FF: ggml_op = 47; -pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 48; -pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 49; -pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 50; -pub const ggml_op_GGML_OP_UNARY: ggml_op = 51; -pub const ggml_op_GGML_OP_MAP_UNARY: ggml_op = 52; -pub const ggml_op_GGML_OP_MAP_BINARY: ggml_op = 53; -pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 54; -pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 55; -pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 56; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 57; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 58; -pub const ggml_op_GGML_OP_COUNT: ggml_op = 59; +pub const ggml_op_GGML_OP_CONCAT: ggml_op = 17; +pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 18; +pub const ggml_op_GGML_OP_NORM: ggml_op = 19; +pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 20; +pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 21; +pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 22; +pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 23; +pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 24; +pub const ggml_op_GGML_OP_SCALE: ggml_op = 25; +pub const ggml_op_GGML_OP_SET: ggml_op = 26; +pub const ggml_op_GGML_OP_CPY: ggml_op = 27; +pub const ggml_op_GGML_OP_CONT: ggml_op = 28; +pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 29; +pub const ggml_op_GGML_OP_VIEW: ggml_op = 30; +pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 31; +pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 32; +pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 33; +pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 34; +pub const ggml_op_GGML_OP_DIAG: ggml_op = 35; +pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 36; +pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 37; +pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 38; +pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 39; +pub const ggml_op_GGML_OP_ROPE: ggml_op = 40; +pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 41; +pub const ggml_op_GGML_OP_ALIBI: ggml_op = 42; +pub const ggml_op_GGML_OP_CLAMP: ggml_op = 43; +pub const ggml_op_GGML_OP_CONV_1D: ggml_op = 44; +pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 45; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 46; +pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 47; +pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 48; +pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 49; +pub const ggml_op_GGML_OP_FLASH_ATTN: ggml_op = 50; +pub const ggml_op_GGML_OP_FLASH_FF: ggml_op = 51; +pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 52; +pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 53; +pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 54; +pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 55; +pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 56; +pub const ggml_op_GGML_OP_UNARY: ggml_op = 57; +pub const ggml_op_GGML_OP_MAP_UNARY: ggml_op = 58; +pub const ggml_op_GGML_OP_MAP_BINARY: ggml_op = 59; +pub const ggml_op_GGML_OP_MAP_CUSTOM1_F32: ggml_op = 60; +pub const ggml_op_GGML_OP_MAP_CUSTOM2_F32: ggml_op = 61; +pub const ggml_op_GGML_OP_MAP_CUSTOM3_F32: ggml_op = 62; +pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 63; +pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 64; +pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 65; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 66; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 67; +pub const ggml_op_GGML_OP_COUNT: ggml_op = 68; pub type ggml_op = ::std::os::raw::c_uint; pub const ggml_unary_op_GGML_UNARY_OP_ABS: ggml_unary_op = 0; pub const ggml_unary_op_GGML_UNARY_OP_SGN: ggml_unary_op = 1; @@ -157,11 +171,15 @@ pub const ggml_unary_op_GGML_UNARY_OP_RELU: ggml_unary_op = 6; pub const ggml_unary_op_GGML_UNARY_OP_GELU: ggml_unary_op = 7; pub const ggml_unary_op_GGML_UNARY_OP_GELU_QUICK: ggml_unary_op = 8; pub const ggml_unary_op_GGML_UNARY_OP_SILU: ggml_unary_op = 9; -pub type ggml_unary_op = ::std::os::raw::c_int; +pub type ggml_unary_op = ::std::os::raw::c_uint; pub const ggml_object_type_GGML_OBJECT_TENSOR: ggml_object_type = 0; pub const ggml_object_type_GGML_OBJECT_GRAPH: ggml_object_type = 1; pub const ggml_object_type_GGML_OBJECT_WORK_BUFFER: ggml_object_type = 2; -pub type ggml_object_type = ::std::os::raw::c_int; +pub type ggml_object_type = ::std::os::raw::c_uint; +pub const ggml_log_level_GGML_LOG_LEVEL_ERROR: ggml_log_level = 2; +pub const ggml_log_level_GGML_LOG_LEVEL_WARN: ggml_log_level = 3; +pub const ggml_log_level_GGML_LOG_LEVEL_INFO: ggml_log_level = 4; +pub type ggml_log_level = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_object { @@ -253,8 +271,10 @@ pub struct ggml_tensor { pub perf_runs: ::std::os::raw::c_int, pub perf_cycles: i64, pub perf_time_us: i64, + pub view_src: *mut ggml_tensor, + pub view_offs: usize, pub data: *mut ::std::os::raw::c_void, - pub name: [::std::os::raw::c_char; 48usize], + pub name: [::std::os::raw::c_char; 64usize], pub extra: *mut ::std::os::raw::c_void, pub padding: [::std::os::raw::c_char; 4usize], } @@ -264,7 +284,7 @@ fn bindgen_test_layout_ggml_tensor() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 272usize, + 304usize, concat!("Size of: ", stringify!(ggml_tensor)) ); assert_eq!( @@ -403,8 +423,28 @@ fn bindgen_test_layout_ggml_tensor() { ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).data) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).view_src) as usize - ptr as usize }, 200usize, + concat!( + "Offset of field: ", + stringify!(ggml_tensor), + "::", + stringify!(view_src) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).view_offs) as usize - ptr as usize }, + 208usize, + concat!( + "Offset of field: ", + stringify!(ggml_tensor), + "::", + stringify!(view_offs) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).data) as usize - ptr as usize }, + 216usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -414,7 +454,7 @@ fn bindgen_test_layout_ggml_tensor() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).name) as usize - ptr as usize }, - 208usize, + 224usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -424,7 +464,7 @@ fn bindgen_test_layout_ggml_tensor() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).extra) as usize - ptr as usize }, - 256usize, + 288usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -434,7 +474,7 @@ fn bindgen_test_layout_ggml_tensor() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).padding) as usize - ptr as usize }, - 264usize, + 296usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -443,7 +483,7 @@ fn bindgen_test_layout_ggml_tensor() { ) ); } -pub const GGML_TENSOR_SIZE: usize = 272; +pub const GGML_TENSOR_SIZE: usize = 304; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_cplan { @@ -867,6 +907,9 @@ extern "C" { extern "C" { pub fn ggml_nbytes(tensor: *const ggml_tensor) -> usize; } +extern "C" { + pub fn ggml_nbytes_pad(tensor: *const ggml_tensor) -> usize; +} extern "C" { pub fn ggml_nbytes_split( tensor: *const ggml_tensor, @@ -909,6 +952,9 @@ extern "C" { extern "C" { pub fn ggml_is_permuted(tensor: *const ggml_tensor) -> bool; } +extern "C" { + pub fn ggml_are_same_shape(t0: *const ggml_tensor, t1: *const ggml_tensor) -> bool; +} extern "C" { pub fn ggml_tensor_overhead() -> usize; } @@ -991,7 +1037,7 @@ extern "C" { pub fn ggml_dup_tensor(ctx: *mut ggml_context, src: *const ggml_tensor) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_view_tensor(ctx: *mut ggml_context, src: *const ggml_tensor) -> *mut ggml_tensor; + pub fn ggml_view_tensor(ctx: *mut ggml_context, src: *mut ggml_tensor) -> *mut ggml_tensor; } extern "C" { pub fn ggml_get_tensor( @@ -1187,6 +1233,13 @@ extern "C" { b: *mut ggml_tensor, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_concat( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_abs(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } @@ -1256,10 +1309,14 @@ extern "C" { ) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_norm(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; + pub fn ggml_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_norm_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; + pub fn ggml_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + eps: f32, + ) -> *mut ggml_tensor; } extern "C" { pub fn ggml_rms_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) @@ -1272,11 +1329,26 @@ extern "C" { eps: f32, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_group_norm( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_groups: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_group_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_groups: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_rms_norm_back( ctx: *mut ggml_context, a: *mut ggml_tensor, b: *mut ggml_tensor, + eps: f32, ) -> *mut ggml_tensor; } extern "C" { @@ -1591,6 +1663,16 @@ extern "C" { freq_scale: f32, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_rope_xpos_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + n_dims: ::std::os::raw::c_int, + base: f32, + down: bool, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_rope_back( ctx: *mut ggml_context, @@ -1599,6 +1681,10 @@ extern "C" { n_dims: ::std::os::raw::c_int, mode: ::std::os::raw::c_int, n_ctx: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + xpos_base: f32, + xpos_down: bool, ) -> *mut ggml_tensor; } extern "C" { @@ -1628,6 +1714,15 @@ extern "C" { d0: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_conv_1d_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s: ::std::os::raw::c_int, + d: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_conv_2d( ctx: *mut ggml_context, @@ -1642,18 +1737,31 @@ extern "C" { ) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_conv_1d_ph( + pub fn ggml_conv_2d_sk_p0( ctx: *mut ggml_context, a: *mut ggml_tensor, b: *mut ggml_tensor, - s: ::std::os::raw::c_int, - d: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_conv_2d_s1_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_conv_transpose_2d_p0( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + stride: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } pub const ggml_op_pool_GGML_OP_POOL_MAX: ggml_op_pool = 0; pub const ggml_op_pool_GGML_OP_POOL_AVG: ggml_op_pool = 1; pub const ggml_op_pool_GGML_OP_POOL_COUNT: ggml_op_pool = 2; -pub type ggml_op_pool = ::std::os::raw::c_int; +pub type ggml_op_pool = ::std::os::raw::c_uint; extern "C" { pub fn ggml_pool_1d( ctx: *mut ggml_context, @@ -1677,6 +1785,13 @@ extern "C" { p1: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_upscale( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + scale_factor: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_flash_attn( ctx: *mut ggml_context, @@ -1722,6 +1837,44 @@ extern "C" { w: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_unary( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_unary_op, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_unary_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_unary_op, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_get_rel_pos( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + qh: ::std::os::raw::c_int, + kh: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_add_rel_pos( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + pw: *mut ggml_tensor, + ph: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_add_rel_pos_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + pw: *mut ggml_tensor, + ph: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} pub type ggml_unary_op_f32_t = ::std::option::Option< unsafe extern "C" fn(arg1: ::std::os::raw::c_int, arg2: *mut f32, arg3: *const f32), >; @@ -1750,20 +1903,6 @@ pub type ggml_custom3_op_f32_t = ::std::option::Option< arg4: *const ggml_tensor, ), >; -extern "C" { - pub fn ggml_unary( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - op: ggml_unary_op, - ) -> *mut ggml_tensor; -} -extern "C" { - pub fn ggml_unary_inplace( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - op: ggml_unary_op, - ) -> *mut ggml_tensor; -} extern "C" { pub fn ggml_map_unary_f32( ctx: *mut ggml_context, @@ -1842,6 +1981,96 @@ extern "C" { fun: ggml_custom3_op_f32_t, ) -> *mut ggml_tensor; } +pub type ggml_custom1_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +pub type ggml_custom2_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + b: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +pub type ggml_custom3_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + b: *const ggml_tensor, + c: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +extern "C" { + pub fn ggml_map_custom1( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + fun: ggml_custom1_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom1_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + fun: ggml_custom1_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom2( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + fun: ggml_custom2_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom2_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + fun: ggml_custom2_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom3( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + fun: ggml_custom3_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom3_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + fun: ggml_custom3_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_cross_entropy_loss( ctx: *mut ggml_context, @@ -1863,6 +2092,14 @@ extern "C" { extern "C" { pub fn ggml_build_forward_expand(cgraph: *mut ggml_cgraph, tensor: *mut ggml_tensor); } +extern "C" { + pub fn ggml_build_backward_expand( + ctx: *mut ggml_context, + gf: *mut ggml_cgraph, + gb: *mut ggml_cgraph, + keep: bool, + ); +} extern "C" { pub fn ggml_build_forward(tensor: *mut ggml_tensor) -> ggml_cgraph; } @@ -1952,6 +2189,15 @@ pub const ggml_opt_result_GGML_LINESEARCH_MAXIMUM_STEP: ggml_opt_result = -126; pub const ggml_opt_result_GGML_LINESEARCH_MAXIMUM_ITERATIONS: ggml_opt_result = -125; pub const ggml_opt_result_GGML_LINESEARCH_INVALID_PARAMETERS: ggml_opt_result = -124; pub type ggml_opt_result = ::std::os::raw::c_int; +pub type ggml_opt_callback = + ::std::option::Option; +pub type ggml_log_callback = ::std::option::Option< + unsafe extern "C" fn( + level: ggml_log_level, + text: *const ::std::os::raw::c_char, + user_data: *mut ::std::os::raw::c_void, + ), +>; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_opt_params { @@ -1971,12 +2217,14 @@ pub struct ggml_opt_params__bindgen_ty_1 { pub n_iter: ::std::os::raw::c_int, pub sched: f32, pub decay: f32, + pub decay_min_ndim: ::std::os::raw::c_int, pub alpha: f32, pub beta1: f32, pub beta2: f32, pub eps: f32, pub eps_f: f32, pub eps_g: f32, + pub gclip: f32, } #[test] fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { @@ -1985,7 +2233,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 36usize, + 44usize, concat!("Size of: ", stringify!(ggml_opt_params__bindgen_ty_1)) ); assert_eq!( @@ -2024,8 +2272,18 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).alpha) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).decay_min_ndim) as usize - ptr as usize }, 12usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_params__bindgen_ty_1), + "::", + stringify!(decay_min_ndim) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).alpha) as usize - ptr as usize }, + 16usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2035,7 +2293,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).beta1) as usize - ptr as usize }, - 16usize, + 20usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2045,7 +2303,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).beta2) as usize - ptr as usize }, - 20usize, + 24usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2055,7 +2313,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).eps) as usize - ptr as usize }, - 24usize, + 28usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2065,7 +2323,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).eps_f) as usize - ptr as usize }, - 28usize, + 32usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2075,7 +2333,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).eps_g) as usize - ptr as usize }, - 32usize, + 36usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2083,6 +2341,16 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { stringify!(eps_g) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).gclip) as usize - ptr as usize }, + 40usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_params__bindgen_ty_1), + "::", + stringify!(gclip) + ) + ); } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -2209,7 +2477,7 @@ fn bindgen_test_layout_ggml_opt_params() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 96usize, + 104usize, concat!("Size of: ", stringify!(ggml_opt_params)) ); assert_eq!( @@ -2299,7 +2567,7 @@ fn bindgen_test_layout_ggml_opt_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).lbfgs) as usize - ptr as usize }, - 60usize, + 68usize, concat!( "Offset of field: ", stringify!(ggml_opt_params), @@ -2316,19 +2584,16 @@ pub struct ggml_opt_context { pub iter: ::std::os::raw::c_int, pub nx: i64, pub just_initialized: bool, + pub loss_before: f32, + pub loss_after: f32, pub adam: ggml_opt_context__bindgen_ty_1, pub lbfgs: ggml_opt_context__bindgen_ty_2, } #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_opt_context__bindgen_ty_1 { - pub x: *mut ggml_tensor, - pub g1: *mut ggml_tensor, - pub g2: *mut ggml_tensor, pub m: *mut ggml_tensor, pub v: *mut ggml_tensor, - pub mh: *mut ggml_tensor, - pub vh: *mut ggml_tensor, pub pf: *mut ggml_tensor, pub fx_best: f32, pub fx_prev: f32, @@ -2341,7 +2606,7 @@ fn bindgen_test_layout_ggml_opt_context__bindgen_ty_1() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 80usize, + 40usize, concat!("Size of: ", stringify!(ggml_opt_context__bindgen_ty_1)) ); assert_eq!( @@ -2350,113 +2615,63 @@ fn bindgen_test_layout_ggml_opt_context__bindgen_ty_1() { concat!("Alignment of ", stringify!(ggml_opt_context__bindgen_ty_1)) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).x) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).m) as usize - ptr as usize }, 0usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(x) + stringify!(m) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).g1) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).v) as usize - ptr as usize }, 8usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(g1) + stringify!(v) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).g2) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).pf) as usize - ptr as usize }, 16usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(g2) + stringify!(pf) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).m) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).fx_best) as usize - ptr as usize }, 24usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(m) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).v) as usize - ptr as usize }, - 32usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(v) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).mh) as usize - ptr as usize }, - 40usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(mh) + stringify!(fx_best) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).vh) as usize - ptr as usize }, - 48usize, + unsafe { ::std::ptr::addr_of!((*ptr).fx_prev) as usize - ptr as usize }, + 28usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(vh) + stringify!(fx_prev) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).pf) as usize - ptr as usize }, - 56usize, + unsafe { ::std::ptr::addr_of!((*ptr).n_no_improvement) as usize - ptr as usize }, + 32usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(pf) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).fx_best) as usize - ptr as usize }, - 64usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(fx_best) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).fx_prev) as usize - ptr as usize }, - 68usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(fx_prev) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).n_no_improvement) as usize - ptr as usize }, - 72usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(n_no_improvement) + stringify!(n_no_improvement) ) ); } @@ -2662,7 +2877,7 @@ fn bindgen_test_layout_ggml_opt_context() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 312usize, + 288usize, concat!("Size of: ", stringify!(ggml_opt_context)) ); assert_eq!( @@ -2692,7 +2907,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).iter) as usize - ptr as usize }, - 104usize, + 112usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2702,7 +2917,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).nx) as usize - ptr as usize }, - 112usize, + 120usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2712,7 +2927,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).just_initialized) as usize - ptr as usize }, - 120usize, + 128usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2720,9 +2935,29 @@ fn bindgen_test_layout_ggml_opt_context() { stringify!(just_initialized) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).loss_before) as usize - ptr as usize }, + 132usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_context), + "::", + stringify!(loss_before) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).loss_after) as usize - ptr as usize }, + 136usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_context), + "::", + stringify!(loss_after) + ) + ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).adam) as usize - ptr as usize }, - 128usize, + 144usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2732,7 +2967,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).lbfgs) as usize - ptr as usize }, - 208usize, + 184usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2773,6 +3008,8 @@ extern "C" { f: *mut ggml_tensor, gf: *mut ggml_cgraph, gb: *mut ggml_cgraph, + callback: ggml_opt_callback, + callback_data: *mut ::std::os::raw::c_void, ) -> ggml_opt_result; } extern "C" { @@ -2830,6 +3067,282 @@ extern "C" { hist: *mut i64, ) -> usize; } +pub const gguf_type_GGUF_TYPE_UINT8: gguf_type = 0; +pub const gguf_type_GGUF_TYPE_INT8: gguf_type = 1; +pub const gguf_type_GGUF_TYPE_UINT16: gguf_type = 2; +pub const gguf_type_GGUF_TYPE_INT16: gguf_type = 3; +pub const gguf_type_GGUF_TYPE_UINT32: gguf_type = 4; +pub const gguf_type_GGUF_TYPE_INT32: gguf_type = 5; +pub const gguf_type_GGUF_TYPE_FLOAT32: gguf_type = 6; +pub const gguf_type_GGUF_TYPE_BOOL: gguf_type = 7; +pub const gguf_type_GGUF_TYPE_STRING: gguf_type = 8; +pub const gguf_type_GGUF_TYPE_ARRAY: gguf_type = 9; +pub const gguf_type_GGUF_TYPE_UINT64: gguf_type = 10; +pub const gguf_type_GGUF_TYPE_INT64: gguf_type = 11; +pub const gguf_type_GGUF_TYPE_FLOAT64: gguf_type = 12; +pub const gguf_type_GGUF_TYPE_COUNT: gguf_type = 13; +pub type gguf_type = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct gguf_context { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct gguf_init_params { + pub no_alloc: bool, + pub ctx: *mut *mut ggml_context, +} +#[test] +fn bindgen_test_layout_gguf_init_params() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(gguf_init_params)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(gguf_init_params)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).no_alloc) as usize - ptr as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(gguf_init_params), + "::", + stringify!(no_alloc) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ctx) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(gguf_init_params), + "::", + stringify!(ctx) + ) + ); +} +extern "C" { + pub fn gguf_init_empty() -> *mut gguf_context; +} +extern "C" { + pub fn gguf_init_from_file( + fname: *const ::std::os::raw::c_char, + params: gguf_init_params, + ) -> *mut gguf_context; +} +extern "C" { + pub fn gguf_free(ctx: *mut gguf_context); +} +extern "C" { + pub fn gguf_type_name(type_: gguf_type) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_version(ctx: *const gguf_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_alignment(ctx: *const gguf_context) -> usize; +} +extern "C" { + pub fn gguf_get_data_offset(ctx: *const gguf_context) -> usize; +} +extern "C" { + pub fn gguf_get_data(ctx: *const gguf_context) -> *mut ::std::os::raw::c_void; +} +extern "C" { + pub fn gguf_get_n_kv(ctx: *const gguf_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_find_key( + ctx: *const gguf_context, + key: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_key( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_kv_type(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> gguf_type; +} +extern "C" { + pub fn gguf_get_arr_type(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> gguf_type; +} +extern "C" { + pub fn gguf_get_val_u8(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u8; +} +extern "C" { + pub fn gguf_get_val_i8(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i8; +} +extern "C" { + pub fn gguf_get_val_u16(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u16; +} +extern "C" { + pub fn gguf_get_val_i16(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i16; +} +extern "C" { + pub fn gguf_get_val_u32(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u32; +} +extern "C" { + pub fn gguf_get_val_i32(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i32; +} +extern "C" { + pub fn gguf_get_val_f32(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> f32; +} +extern "C" { + pub fn gguf_get_val_u64(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u64; +} +extern "C" { + pub fn gguf_get_val_i64(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i64; +} +extern "C" { + pub fn gguf_get_val_f64(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> f64; +} +extern "C" { + pub fn gguf_get_val_bool(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> bool; +} +extern "C" { + pub fn gguf_get_val_str( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_arr_n( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_arr_data( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_void; +} +extern "C" { + pub fn gguf_get_arr_str( + ctx: *const gguf_context, + key_id: ::std::os::raw::c_int, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_n_tensors(ctx: *const gguf_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_find_tensor( + ctx: *const gguf_context, + name: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_tensor_offset(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> usize; +} +extern "C" { + pub fn gguf_get_tensor_name( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_set_val_u8(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u8); +} +extern "C" { + pub fn gguf_set_val_i8(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i8); +} +extern "C" { + pub fn gguf_set_val_u16(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u16); +} +extern "C" { + pub fn gguf_set_val_i16(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i16); +} +extern "C" { + pub fn gguf_set_val_u32(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u32); +} +extern "C" { + pub fn gguf_set_val_i32(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i32); +} +extern "C" { + pub fn gguf_set_val_f32(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: f32); +} +extern "C" { + pub fn gguf_set_val_u64(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u64); +} +extern "C" { + pub fn gguf_set_val_i64(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i64); +} +extern "C" { + pub fn gguf_set_val_f64(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: f64); +} +extern "C" { + pub fn gguf_set_val_bool(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: bool); +} +extern "C" { + pub fn gguf_set_val_str( + ctx: *mut gguf_context, + key: *const ::std::os::raw::c_char, + val: *const ::std::os::raw::c_char, + ); +} +extern "C" { + pub fn gguf_set_arr_data( + ctx: *mut gguf_context, + key: *const ::std::os::raw::c_char, + type_: gguf_type, + data: *const ::std::os::raw::c_void, + n: ::std::os::raw::c_int, + ); +} +extern "C" { + pub fn gguf_set_arr_str( + ctx: *mut gguf_context, + key: *const ::std::os::raw::c_char, + data: *mut *const ::std::os::raw::c_char, + n: ::std::os::raw::c_int, + ); +} +extern "C" { + pub fn gguf_set_kv(ctx: *mut gguf_context, src: *mut gguf_context); +} +extern "C" { + pub fn gguf_add_tensor(ctx: *mut gguf_context, tensor: *const ggml_tensor); +} +extern "C" { + pub fn gguf_set_tensor_type( + ctx: *mut gguf_context, + name: *const ::std::os::raw::c_char, + type_: ggml_type, + ); +} +extern "C" { + pub fn gguf_set_tensor_data( + ctx: *mut gguf_context, + name: *const ::std::os::raw::c_char, + data: *const ::std::os::raw::c_void, + size: usize, + ); +} +extern "C" { + pub fn gguf_write_to_file( + ctx: *const gguf_context, + fname: *const ::std::os::raw::c_char, + only_meta: bool, + ); +} +extern "C" { + pub fn gguf_get_meta_size(ctx: *const gguf_context) -> usize; +} +extern "C" { + pub fn gguf_get_meta_data(ctx: *const gguf_context, data: *mut ::std::os::raw::c_void); +} extern "C" { pub fn ggml_cpu_has_avx() -> ::std::os::raw::c_int; } @@ -2854,6 +3367,9 @@ extern "C" { extern "C" { pub fn ggml_cpu_has_arm_fma() -> ::std::os::raw::c_int; } +extern "C" { + pub fn ggml_cpu_has_metal() -> ::std::os::raw::c_int; +} extern "C" { pub fn ggml_cpu_has_f16c() -> ::std::os::raw::c_int; } @@ -2878,6 +3394,9 @@ extern "C" { extern "C" { pub fn ggml_cpu_has_sse3() -> ::std::os::raw::c_int; } +extern "C" { + pub fn ggml_cpu_has_ssse3() -> ::std::os::raw::c_int; +} extern "C" { pub fn ggml_cpu_has_vsx() -> ::std::os::raw::c_int; } @@ -2898,6 +3417,10 @@ pub type ggml_vec_dot_t = ::std::option::Option< #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_type_traits_t { + pub type_name: *const ::std::os::raw::c_char, + pub blck_size: ::std::os::raw::c_int, + pub type_size: usize, + pub is_quantized: bool, pub to_float: ggml_to_float_t, pub from_float: ggml_from_float_t, pub from_float_reference: ggml_from_float_t, @@ -2910,7 +3433,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 40usize, + 72usize, concat!("Size of: ", stringify!(ggml_type_traits_t)) ); assert_eq!( @@ -2919,8 +3442,48 @@ fn bindgen_test_layout_ggml_type_traits_t() { concat!("Alignment of ", stringify!(ggml_type_traits_t)) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).to_float) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).type_name) as usize - ptr as usize }, 0usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(type_name) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).blck_size) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(blck_size) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).type_size) as usize - ptr as usize }, + 16usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(type_size) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).is_quantized) as usize - ptr as usize }, + 24usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(is_quantized) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).to_float) as usize - ptr as usize }, + 32usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2930,7 +3493,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).from_float) as usize - ptr as usize }, - 8usize, + 40usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2940,7 +3503,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).from_float_reference) as usize - ptr as usize }, - 16usize, + 48usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2950,7 +3513,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).vec_dot) as usize - ptr as usize }, - 24usize, + 56usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2960,7 +3523,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).vec_dot_type) as usize - ptr as usize }, - 32usize, + 64usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2970,7 +3533,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); } extern "C" { - pub fn ggml_internal_get_type_traits(i: ggml_type) -> ggml_type_traits_t; + pub fn ggml_internal_get_type_traits(type_: ggml_type) -> ggml_type_traits_t; } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -3513,3 +4076,40 @@ extern "C" { hist: *mut i64, ) -> usize; } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_allocr { + _unused: [u8; 0], +} +extern "C" { + pub fn ggml_allocr_new( + data: *mut ::std::os::raw::c_void, + size: usize, + alignment: usize, + ) -> *mut ggml_allocr; +} +extern "C" { + pub fn ggml_allocr_new_measure(alignment: usize) -> *mut ggml_allocr; +} +extern "C" { + pub fn ggml_allocr_set_parse_seq( + alloc: *mut ggml_allocr, + list: *const ::std::os::raw::c_int, + n: ::std::os::raw::c_int, + ); +} +extern "C" { + pub fn ggml_allocr_free(alloc: *mut ggml_allocr); +} +extern "C" { + pub fn ggml_allocr_is_measure(alloc: *mut ggml_allocr) -> bool; +} +extern "C" { + pub fn ggml_allocr_reset(alloc: *mut ggml_allocr); +} +extern "C" { + pub fn ggml_allocr_alloc(alloc: *mut ggml_allocr, tensor: *mut ggml_tensor); +} +extern "C" { + pub fn ggml_allocr_alloc_graph(alloc: *mut ggml_allocr, graph: *mut ggml_cgraph) -> usize; +} diff --git a/crates/ggml/sys/src/llama.rs b/crates/ggml/sys/src/llama.rs index a8aa42ef..d3552cd9 100644 --- a/crates/ggml/sys/src/llama.rs +++ b/crates/ggml/sys/src/llama.rs @@ -1,18 +1,10 @@ /* automatically generated by rust-bindgen 0.65.1 */ pub const LLAMA_MAX_DEVICES: u32 = 1; -pub const LLAMA_FILE_MAGIC_GGJT: u32 = 1734830708; -pub const LLAMA_FILE_MAGIC_GGLA: u32 = 1734831201; -pub const LLAMA_FILE_MAGIC_GGMF: u32 = 1734831462; -pub const LLAMA_FILE_MAGIC_GGML: u32 = 1734831468; +pub const LLAMA_DEFAULT_SEED: u32 = 4294967295; pub const LLAMA_FILE_MAGIC_GGSN: u32 = 1734833006; -pub const LLAMA_FILE_VERSION: u32 = 3; -pub const LLAMA_FILE_MAGIC: u32 = 1734830708; -pub const LLAMA_FILE_MAGIC_UNVERSIONED: u32 = 1734831468; pub const LLAMA_SESSION_MAGIC: u32 = 1734833006; pub const LLAMA_SESSION_VERSION: u32 = 1; -pub const LLAMA_DEFAULT_SEED: u32 = 4294967295; -pub const LLAMA_DEFAULT_RMS_EPS: f64 = 0.000005; pub const LLAMA_FTYPE_ALL_F32: llama_ftype = 0; pub const LLAMA_FTYPE_MOSTLY_F16: llama_ftype = 1; pub const LLAMA_FTYPE_MOSTLY_Q4_0: llama_ftype = 2; @@ -30,4 +22,5 @@ pub const LLAMA_FTYPE_MOSTLY_Q4_K_M: llama_ftype = 15; pub const LLAMA_FTYPE_MOSTLY_Q5_K_S: llama_ftype = 16; pub const LLAMA_FTYPE_MOSTLY_Q5_K_M: llama_ftype = 17; pub const LLAMA_FTYPE_MOSTLY_Q6_K: llama_ftype = 18; -pub type llama_ftype = ::std::os::raw::c_int; +pub const LLAMA_FTYPE_GUESSED: llama_ftype = 1024; +pub type llama_ftype = ::std::os::raw::c_uint; diff --git a/crates/ggml/sys/src/metal.rs b/crates/ggml/sys/src/metal.rs index bbd16034..e2d7c621 100644 --- a/crates/ggml/sys/src/metal.rs +++ b/crates/ggml/sys/src/metal.rs @@ -1,15 +1,16 @@ /* automatically generated by rust-bindgen 0.65.1 */ +use super::ggml_tensor; +use super::ggml_log_callback; +use super::ggml_cgraph; + pub const GGML_METAL_MAX_BUFFERS: u32 = 16; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct ggml_tensor { - _unused: [u8; 0], -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct ggml_cgraph { - _unused: [u8; 0], +pub const GGML_METAL_MAX_COMMAND_BUFFERS: u32 = 32; +extern "C" { + pub fn ggml_metal_log_set_callback( + log_callback: ggml_log_callback, + user_data: *mut ::std::os::raw::c_void, + ); } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -22,6 +23,12 @@ extern "C" { extern "C" { pub fn ggml_metal_free(ctx: *mut ggml_metal_context); } +extern "C" { + pub fn ggml_metal_host_malloc(n: usize) -> *mut ::std::os::raw::c_void; +} +extern "C" { + pub fn ggml_metal_host_free(data: *mut ::std::os::raw::c_void); +} extern "C" { pub fn ggml_metal_set_n_cb(ctx: *mut ggml_metal_context, n_cb: ::std::os::raw::c_int); } @@ -41,10 +48,17 @@ extern "C" { pub fn ggml_metal_get_tensor(ctx: *mut ggml_metal_context, t: *mut ggml_tensor); } extern "C" { - pub fn ggml_metal_graph_find_concurrency(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph); + pub fn ggml_metal_graph_find_concurrency( + ctx: *mut ggml_metal_context, + gf: *mut ggml_cgraph, + check_mem: bool, + ); +} +extern "C" { + pub fn ggml_metal_if_optimized(ctx: *mut ggml_metal_context) -> ::std::os::raw::c_int; } extern "C" { - pub fn ggml_metal_if_optimized(ctx: *mut ggml_metal_context) -> bool; + pub fn ggml_metal_get_concur_list(ctx: *mut ggml_metal_context) -> *mut ::std::os::raw::c_int; } extern "C" { pub fn ggml_metal_graph_compute(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph); diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 12f66f5e..0c9f3a5b 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -1,4 +1,4 @@ -use ggml::{Buffer, ComputationGraph, Context, GraphExecutionPlan, Tensor}; +use ggml::{Buffer, ComputationGraph, Context, GraphAllocator, GraphExecutionPlan, Tensor}; use serde::Serialize; use std::{cell::RefCell, fmt::Display, sync::Arc}; use thiserror::Error; @@ -8,25 +8,10 @@ use tracing::{instrument, log}; use ggml::accelerator::metal::MetalContext; use crate::{ - mulf, util, InferenceParameters, Model, ModelParameters, OutputRequest, Prompt, TokenId, - TokenUtf8Buffer, TokenizationError, + mulf, util, InferenceParameters, Model, ModelContext, ModelParameters, OutputRequest, Prompt, + TokenId, TokenUtf8Buffer, TokenizationError, }; -// The size of a scratch buffer used for inference. This is used for temporary -// storage of intermediate results during inference. -// -// The specific value was copied from `llama.cpp`. -const SCRATCH_SIZE: usize = 512 * 1024 * 1024; - -type ScratchBuffers = [ggml::Buffer; 2]; - -fn scratch_buffers() -> ScratchBuffers { - [ - ggml::Buffer::new(SCRATCH_SIZE), - ggml::Buffer::new(SCRATCH_SIZE), - ] -} - /// Result of graph building pub struct GraphOutputs { /// The output containing the model's result @@ -34,6 +19,9 @@ pub struct GraphOutputs { /// The output containing embeddings pub embedding_result: Tensor, + + /// The length of the output + pub output_length: usize, } /// An inference session represents the state of the text generation. This holds @@ -66,7 +54,7 @@ pub struct InferenceSession { /// How many tokens have been fed into the model's working memory so far. #[doc(hidden)] - pub n_past: usize, + n_past: usize, /// How much memory is required per token for the temporary context used /// during inference. @@ -90,21 +78,32 @@ pub struct InferenceSession { n_embd: usize, - scratch: ScratchBuffers, + /// Allocator used by this session + allocator: GraphAllocator, + + ///Context size of this session + context_size: usize, + + /// Work buffer for graph planing + work_buffer: Vec, + + /// If the session can use the gpu + use_gpu: bool, } pub struct BuildContext<'session> { //FIXME: Borrowing issue, dont know how to fix it pub ctx0: RefCell<&'session mut Context>, + pub allocator: RefCell<&'session GraphAllocator>, pub embd: &'session Tensor, pub memory_k: &'session Tensor, pub memory_v: &'session Tensor, - pub scratch: &'session ScratchBuffers, + pub n_past: usize, } impl<'session> BuildContext<'session> { - pub fn get_scratch(&self, idx: usize) -> Option<&Buffer> { - Some(&self.scratch[idx]) + pub fn input_length(&self) -> usize { + self.embd.nelements() } } @@ -124,7 +123,7 @@ impl InferenceSession { .. } = *params; - let context_byte_size = { + let cache_byte_size = { let mut size = 0; size += mulf!( context_size, @@ -138,49 +137,48 @@ impl InferenceSession { n_embd, ggml::type_sizef(config.memory_v_type.into()) ); // memory_v - size += (5 + 10 * n_layer) * 256; // object overhead + size += 2 * 1024 * 1024; // overhead size }; + log::info!( + "Allocating {:.2} MB for KV-memory", + cache_byte_size / (1024 * 1024) + ); + if use_gpu { ggml::accelerator::initialize(0); - ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024); + ggml::accelerator::set_scratch_size(0); } - let session_ctx = Arc::new(ggml::Context::new_with_allocate(context_byte_size)); + // TODO: revisit this with `Rc`, maybe? We should be able to prove that the session + // context is only accessed from one thread at a time, but I've already spent enough + // time on this as-is. + #[allow(clippy::arc_with_non_send_sync)] + let session_ctx = Arc::new(ggml::Context::new_with_allocate(cache_byte_size)); // Initialize key + value memory tensors let n_mem = n_layer * context_size; let n_elements = n_embd * n_mem; let (memory_k, memory_v) = kv_memory(&session_ctx, &config, use_gpu, n_elements); - let scratch = scratch_buffers(); - - // Allocate buffer for storing intermediate values during evaluation (ctx0 backing) - // For the first run, we need to guess a maximum buffer size so we can measure - // the actual memory consumption of the temporary ggml context. - // - // These numbers are from `llama.cpp`, and could potentially be more efficient. - let buf_size = { - let buf_size_mb = if n_layer >= 80 { - 1536 - } else if n_layer >= 60 { - 1280 - } else { - 1024 - }; - buf_size_mb * 1024 * 1024 + ggml::graph_overhead() - }; - + // Allocate buffer for storing tensor and graph structs + let buf_size = ggml::graph_overhead() + (ggml::tensor_overhead() * ggml::MAX_NODES); let eval = Buffer::new(buf_size); - let ctx0 = ggml::Context::new_with_buffer(eval); + log::info!( + "Allocating {:.2} MB for eval-context", + buf_size / (1024 * 1024) + ); + + let ctx0 = ggml::Context::new_with_buffer(eval, false); + let allocator = GraphAllocator::new_measurement(ggml::TENSOR_ALIGNMENT); // Set up Metal support #[cfg(feature = "metal")] let metal_context = { if use_gpu { - let mut metal_context = MetalContext::new(config.n_threads); + let mut metal_context = MetalContext::new(); metal_context.add_scratch_buffer(ctx0.storage().as_buffer().unwrap()); for buf in scratch.iter() { @@ -195,7 +193,7 @@ impl InferenceSession { InferenceSession { _session_ctx: session_ctx, - _memory_size: context_byte_size, + _memory_size: cache_byte_size, config, memory_k, memory_v, @@ -208,50 +206,124 @@ impl InferenceSession { metal_context, ctx0, n_embd, - scratch, + allocator, + context_size, + work_buffer: vec![0], + use_gpu, } } /// Compute a model (possibly building a graph in the provided closure when called for the first time and/or when parameters have) pub fn compute( &mut self, - #[allow(unused_variables)] model_context: Arc, + #[allow(unused_variables)] model_context: ModelContext, input_tokens: &[TokenId], builder: F, ) -> GraphOutputs where - F: FnOnce(BuildContext) -> (ComputationGraph, GraphOutputs), + F: Fn(BuildContext) -> (ComputationGraph, GraphOutputs), { - // Build a graph + // Check if we need to allocate the graph + if self.allocator.in_measuring_mode() { + // Build a graph + self.ctx0.recreate(); + let ctx0 = &mut self.ctx0; + + // If we are in measuring mode, we need to build a "worst case" graph, meaning the input has either `batch_size` or `context_size` tokens. + let max_n_tokens = self.config.n_batch.min(self.context_size); + // We assume the history is full + let max_n_past = self.context_size - max_n_tokens; + let embd = ctx0 + .new_tensor_1d(ggml::Type::I32, max_n_tokens) + .set_name("embd"); + + self.allocator.allocate(&embd); + + let bc = BuildContext { + ctx0: RefCell::new(ctx0), + allocator: RefCell::new(&self.allocator), + embd: &embd, + memory_k: &self.memory_k, + memory_v: &self.memory_v, + n_past: max_n_past, + }; + + let (mut worst_case_graph, built_result) = builder(bc); + // Expand the graph + worst_case_graph.build_forward_expand(&built_result.result); + + // Allocate the graph + let graph_size = + self.allocator.allocate_graph(&worst_case_graph) + ggml::TENSOR_ALIGNMENT; + log::info!("Allocating {:.2} MB for graph", graph_size / (1024 * 1024)); + // Pre-allocate the buffer for future use + self.allocator + .resize_buffer(graph_size, ggml::TENSOR_ALIGNMENT); + + if self.use_gpu { + ggml::accelerator::set_scratch_size(graph_size); + } + } + + // Reset the context and allocator self.ctx0.recreate(); + self.allocator.reset(); let ctx0 = &mut self.ctx0; + let mut embd = ctx0 .new_tensor_1d(ggml::Type::I32, input_tokens.len()) .set_name("embd"); + self.allocator.allocate(&embd); + let bc = BuildContext { ctx0: RefCell::new(ctx0), + allocator: RefCell::new(&self.allocator), embd: &embd, memory_k: &self.memory_k, memory_v: &self.memory_v, - scratch: &mut self.scratch, + n_past: self.n_past, }; + let (mut built_gf, built_result) = builder(bc); + // Build the graph + built_gf.build_forward_expand(&built_result.result); + + // Allocate the graph + self.allocator.allocate_graph(&built_gf); + + #[cfg(feature = "cublas")] + { + for mut leaf in built_gf.leafs(&ctx0) { + if leaf.backend() == ggml::accelerator::Backend::Gpu && !leaf.has_extras() { + unsafe { + let offset = leaf.data().offset_from(self.allocator.buffer.data()) as usize; + leaf.assign_scratch_offset(offset); + } + } + } + + for mut node in built_gf.nodes(&ctx0) { + if node.backend() == ggml::accelerator::Backend::Gpu && !node.has_extras() { + unsafe { + let offset = node.data().offset_from(self.allocator.buffer.data()) as usize; + node.assign_scratch_offset(offset); + } + } + } + } // Do Metal'y stuff #[cfg(feature = "metal")] { if let Some(ref mut metal_context) = self.metal_context { - metal_context.add_context(model_context); + metal_context.add_context(model_context.0); } } // Write input tokens unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; - // Compute the graph - built_gf.build_forward_expand(&built_result.result); - #[cfg(feature = "metal")] { // FIXME can only process one token at a time currently @@ -272,7 +344,7 @@ impl InferenceSession { #[cfg(not(feature = "metal"))] { let mut plan = GraphExecutionPlan::new(&mut built_gf, self.config.n_threads); - plan.execute(ctx0); + plan.execute(&mut self.work_buffer); } // Adjust the required memory per token if we didn't know that already @@ -287,6 +359,7 @@ impl InferenceSession { GraphOutputs { result: built_result.result.share(), embedding_result: built_result.embedding_result.share(), + output_length: input_tokens.len(), } } @@ -399,7 +472,7 @@ impl InferenceSession { } // Remove the tokens from self.tokens. - let token_start = self.n_past - num; + let token_start = self.tokens.len() - num; let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect(); // Remove the corresponding chars from decoded @@ -764,7 +837,7 @@ impl InferenceSession { config: self.config, tokens: self.tokens.clone(), decoded_tokens: self.decoded_tokens.clone(), - logits: self.last_logits.clone(), + last_logits: self.last_logits.clone(), memory_k, memory_v, } @@ -905,7 +978,7 @@ pub struct InferenceSnapshotRef<'a> { /// All decoded tokens generated by this inference session. pub decoded_tokens: Vec, /// The vector of logits that was produced after the last inference. - pub logits: Vec, + pub last_logits: Vec, /// The contents of the 'key' memory tensor. #[serde(with = "serde_bytes")] pub memory_k: &'a [u8], @@ -923,7 +996,7 @@ impl InferenceSnapshotRef<'_> { config: self.config, tokens: self.tokens.clone(), decoded_tokens: self.decoded_tokens.clone(), - last_logits: self.logits.clone(), + last_logits: self.last_logits.clone(), memory_k: self.memory_k.to_vec(), memory_v: self.memory_v.to_vec(), } diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index dd13d7bc..f0a88a8a 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -35,7 +35,7 @@ pub use loader::{ }; pub use lora::{LoraAdapter, LoraParameters}; pub use memmap2::Mmap; -pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest}; +pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest}; pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use regex::Regex; pub use tokenizer::{ @@ -60,7 +60,7 @@ pub struct InferenceParameters { /// This can be anything that implements [Sampler]. Refer to /// the `llm-samplers` documentation for possible samplers and suggested /// combinations: - pub sampler: Arc>>, + pub sampler: Arc>, } //Since Sampler implements Send and Sync, InferenceParameters should too. diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index c98ab048..2e80495c 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -5,11 +5,12 @@ use std::{ fs::File, io::{BufRead, BufReader, Read, Seek, SeekFrom}, path::{Path, PathBuf}, + sync::Arc, }; use crate::{ - util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId, - Tokenizer, TokenizerLoadError, TokenizerSource, + util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, ModelParameters, + TokenId, Tokenizer, TokenizerLoadError, TokenizerSource, }; pub use ggml::{format::FormatMagic, ContainerType}; use ggml::{ @@ -31,13 +32,13 @@ pub struct FileType { impl From for i32 { fn from(value: FileType) -> Self { (value.quantization_version * ggml::QNT_VERSION_FACTOR) as i32 - + ggml::sys::llama::llama_ftype::from(value.format) + + ggml::sys::llama::llama_ftype::from(value.format) as i32 } } -impl TryFrom for FileType { +impl TryFrom for FileType { type Error = (); - fn try_from(value: i32) -> Result { + fn try_from(value: u32) -> Result { let format = FileTypeFormat::try_from( ((value as u32) % ggml::QNT_VERSION_FACTOR) as ggml::sys::llama::llama_ftype, )?; @@ -251,7 +252,7 @@ pub enum LoadError { #[error("unsupported ftype: {0}")] /// The `ftype` hyperparameter had an invalid value. This usually means that the format used /// by this file is unrecognized by this version of `llm`. - UnsupportedFileType(i32), + UnsupportedFileType(u32), #[error("invalid magic number {magic} for {path:?}")] /// An invalid magic number was encountered during the loading process. InvalidMagic { @@ -398,7 +399,7 @@ pub trait TensorLoader { /// Gets a tensor from the loader. fn load(&mut self, name: &str) -> Result; /// Finish loading the model, returning the context. - fn finish(self) -> Context; + fn finish(self) -> ModelContext; } /// Load a GGML model from the `path` and configure it per the `params`. The status @@ -653,12 +654,7 @@ impl TensorLoader for MmapCompatibleLoader<'_> { path: Default::default(), })?; - let mut main_context = FileContext::new( - &self.context, - &mut self.file, - &self.path, - self.context.storage().as_mmap(), - ); + let mut main_context = FileContext::new(&self.context, &mut self.file, &self.path); let mut tensor = main_context.get_tensor(info)?; @@ -681,8 +677,11 @@ impl TensorLoader for MmapCompatibleLoader<'_> { Ok(tensor) } - fn finish(self) -> Context { - self.context + fn finish(self) -> ModelContext { + // We can ignore this warning as it's OK to share this particular + // context around, being that it is immutable. + #[allow(clippy::arc_with_non_send_sync)] + ModelContext(Arc::new(self.context)) } } @@ -690,20 +689,13 @@ pub(crate) struct FileContext<'a> { context: &'a Context, file: &'a mut File, path: &'a Path, - mmap: Option<&'a Mmap>, } impl<'a> FileContext<'a> { - pub(crate) fn new( - context: &'a Context, - file: &'a mut File, - path: &'a Path, - mmap: Option<&'a Mmap>, - ) -> Self { + pub(crate) fn new(context: &'a Context, file: &'a mut File, path: &'a Path) -> Self { Self { context, file, path, - mmap, } } @@ -738,7 +730,7 @@ impl<'a> FileContext<'a> { } }; - match self.mmap { + match self.context.storage().as_mmap() { Some(mmap) => unsafe { let ptr = mmap.as_ptr().offset(info.start_offset as isize); tensor.set_data(ptr as *mut std::ffi::c_void); diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index b6ed4a0f..f433931e 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -106,7 +106,7 @@ impl LoraAdapter { // Create a temporary context for the patching operations // TODO: test if GPU can be enabled (make it configurable) let patch_context = ggml::Context::new_with_allocate(patch_context_size); - let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path, None); + let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path); // Load the A and B tensors let a = patch_file.get_tensor(&a_info)?; @@ -128,8 +128,9 @@ impl LoraAdapter { gf.build_forward_expand(&output); //TODO: maybe pass the model's thread count to this context + let mut work_buffer = vec![0u8]; let mut plan = GraphExecutionPlan::new(&mut gf, 8); - plan.execute(&patch_context); + plan.execute(&mut work_buffer); // Overwrite the original tensor. // The `output` and the `target_tensor` are not from the same context, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index b31faf56..ab30e4f2 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -5,6 +5,7 @@ use std::{ fmt::Debug, io::{BufRead, Write}, path::{Path, PathBuf}, + sync::Arc, }; use ggml::accelerator::Backend; @@ -263,3 +264,13 @@ pub struct OutputRequest { /// `n_batch * n_embd`. pub embeddings: Option>, } + +/// Contains the GGML context for a [`Model`]. Implements `Send` and `Sync` +/// to allow for the free transfer of models; this is made possible by this +/// context being effectively inert after creation, so that it cannot be +/// modified across threads. +#[derive(Clone)] +#[allow(clippy::arc_with_non_send_sync)] +pub struct ModelContext(pub(crate) Arc); +unsafe impl Send for ModelContext {} +unsafe impl Sync for ModelContext {} diff --git a/crates/llm-base/src/samplers.rs b/crates/llm-base/src/samplers.rs index 7a179f0b..f0b07b9e 100644 --- a/crates/llm-base/src/samplers.rs +++ b/crates/llm-base/src/samplers.rs @@ -59,7 +59,7 @@ pub enum SamplingError { /// to ensure a valid configuration. pub struct ConfiguredSamplers { /// A builder from the `llm-samplers` crate. - pub builder: SamplerChainBuilder, + pub builder: SamplerChainBuilder, /// Mirostat 1 is present. pub mirostat1: bool, /// Mirostat 2 is present. @@ -74,15 +74,17 @@ pub struct ConfiguredSamplers { /// We call a configuration of samplers that run in a certain order a "chain". /// Here is a description of the default chain `llm` uses: /// -/// 1. Repetition (present by default, multiple allowed) -/// 2. Frequency/Presence (optional, multiple allowed) -/// 3. Sequence Repetition (optional, multiple allowed) -/// 4. Top-K (present by default - incompatible with Mirostat) -/// 5. Tail Free (optional - incompatible with Mirostat) -/// 6. Locally Typical (optional - incompatible with Mirostat) -/// 7. Top-P (present by default - incompatible with Mirostat) -/// 8. Temperature (present by default) -/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. +/// 1. Repetition (present by default, multiple allowed) +/// 2. Frequency/Presence (optional, multiple allowed) +/// 3. Sequence Repetition (optional, multiple allowed) +/// 4. Top-K (present by default - incompatible with Mirostat) +/// 5. Tail Free (optional - incompatible with Mirostat) +/// 6. Locally Typical (optional - incompatible with Mirostat) +/// 7. Top-P (present by default - incompatible with Mirostat) +/// 8. Top-A (optional - incompatible with Mirostat) +/// 9. Min-P (optional - incompatible with Mirostat) +/// 10. Temperature (present by default) +/// 11. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. /// /// Samplers listed as "present by default" but incompatible with Mirostat will /// only be enabled by default if there is no Mirostat sampler enabled. @@ -142,6 +144,20 @@ impl Default for ConfiguredSamplers { Option::::None, ), ), + ( + "topa", + SamplerSlot::new_single( + || Box::new(SampleTopA::default().a1(0.0).a2(0.0)), + Option::::None, + ), + ), + ( + "minp", + SamplerSlot::new_single( + || Box::new(SampleMinP::default().p(0.0)), + Option::::None, + ), + ), ( "temperature", SamplerSlot::new_single( @@ -203,7 +219,7 @@ impl ConfiguredSamplers { ))? } else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat { Err(SamplerConfigurationError::SamplerCombinationError( - "Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), + "Cannot enable top-p, top-k, top-a, min-p, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), ))? } Ok(()) @@ -245,7 +261,9 @@ impl FromStr for ConfiguredSamplers { .inspect(|(name, _slot)| match name.as_str() { "mirostat1" => result.mirostat1 = true, "mirostat2" => result.mirostat2 = true, - "topp" | "topk" | "locallytypical" | "tailfree" => result.incompat_mirostat = true, + "topa" | "minp" | "topp" | "topk" | "locallytypical" | "tailfree" => { + result.incompat_mirostat = true + } _ => (), }) .collect::>(); @@ -269,7 +287,7 @@ impl FromStr for ConfiguredSamplers { /// Sample a token. This convenience function handles building /// the sampler resources and logits objects the sampler needs. pub fn sample_token( - mut sampler: impl Sampler, + mut sampler: impl Sampler, rng: &mut impl rand::Rng, previous_tokens: &[TokenId], last_logits: impl IntoIterator, @@ -297,7 +315,7 @@ pub fn build_sampler( n_vocab: usize, bias: &[(TokenId, f32)], args: &[impl AsRef], -) -> Result>>, SamplerConfigurationError> { +) -> Result>, SamplerConfigurationError> { let mut samplers = SamplerChain::new(); if !bias.is_empty() { @@ -326,7 +344,7 @@ pub fn build_sampler( } /// Get the default sampler chain. -pub fn default_samplers() -> Arc>> { +pub fn default_samplers() -> Arc> { let mut result = ConfiguredSamplers::default(); result.ensure_default_slots(); Arc::new(Mutex::new(result.builder.into_chain())) @@ -349,8 +367,6 @@ impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> { } impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { - type TokenId = TokenId; - fn with_rng_mut( &mut self, fun: &mut dyn FnMut(&mut dyn rand::RngCore), @@ -359,7 +375,7 @@ impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { Ok(()) } - fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> { + fn with_last_tokens(&self, fun: &mut dyn FnMut(&[TokenId])) -> Result<(), SamplerError> { fun(self.previous_tokens); Ok(()) } diff --git a/crates/llm-base/src/util.rs b/crates/llm-base/src/util.rs index 55cda41c..70fe2994 100644 --- a/crates/llm-base/src/util.rs +++ b/crates/llm-base/src/util.rs @@ -28,7 +28,7 @@ use crate::{FileType, LoadError}; /// Read the filetype from a reader. pub fn read_filetype(reader: &mut dyn BufRead) -> Result { - let ftype = read_i32(reader)?; + let ftype = read_u32(reader)?; FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype)) } @@ -117,19 +117,17 @@ fn collect_related_paths( p.file_name() .and_then(|p| p.to_str()) .zip(main_filename) - .map(|(part_filename, main_filename)| { - match part_filename.strip_prefix(main_filename) { - Some(suffix) => { - suffix.is_empty() - || (suffix - .strip_prefix('.') - .map(|s| s.parse::().is_ok()) - .unwrap_or(false)) - } - None => false, + .map_or(false, |(part_filename, main_filename)| match part_filename + .strip_prefix(main_filename) + { + Some(suffix) => { + suffix.is_empty() + || (suffix + .strip_prefix('.') + .map_or(false, |s| s.parse::().is_ok())) } + None => false, }) - .unwrap_or(false) }) .collect(); paths.sort(); diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 0f395f5a..efff39e5 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -16,6 +16,7 @@ llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" } llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" } llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" } llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" } +llm-bert = { path = "../models/bert", optional = true, version = "0.2.0-dev" } serde = { workspace = true } tracing = { workspace = true } @@ -34,13 +35,14 @@ default = ["models", "tokenizers-remote"] tokenizers-remote = ["llm-base/tokenizers-remote"] -models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"] +models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt", "bert"] llama = ["dep:llm-llama"] gpt2 = ["dep:llm-gpt2"] gptj = ["dep:llm-gptj"] bloom = ["dep:llm-bloom"] gptneox = ["dep:llm-gptneox"] mpt = ["dep:llm-mpt"] +bert = ["dep:llm-bert"] # Falcon is off by default. See `llm_falcon`'s module documentation for more information. falcon = ["dep:llm-falcon"] diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index febe2441..14800686 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -172,6 +172,7 @@ macro_rules! define_models { } define_models!( + (bert, "bert", Bert, llm_bert, "Bert"), (bloom, "bloom", Bloom, llm_bloom, "BLOOM"), (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"), (gptj, "gptj", GptJ, llm_gptj, "GPT-J"), diff --git a/crates/models/bert/Cargo.toml b/crates/models/bert/Cargo.toml new file mode 100644 index 00000000..0be81b40 --- /dev/null +++ b/crates/models/bert/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "llm-bert" +version = "0.2.0-dev" +license = { workspace = true } +repository = { workspace = true } +description = "An implementation of BERT for the `llm` ecosystem." +edition = "2021" +readme = "../../../README.md" + +[dependencies] +bytemuck.workspace = true +llm-base = { path = "../../llm-base", version = "0.2.0-dev" } +tracing = { version = "0.1", features = ["log"] } + diff --git a/crates/models/bert/src/lib.rs b/crates/models/bert/src/lib.rs new file mode 100644 index 00000000..9a8daf6e --- /dev/null +++ b/crates/models/bert/src/lib.rs @@ -0,0 +1,464 @@ +//! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem. +#![deny(missing_docs)] + +use std::error::Error; + +use llm_base::{ + ggml, + model::{common, HyperparametersWriteError}, + util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, + ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, +}; + +/// The BERT model. +/// +/// # Safety +/// This implements [Send] and [Sync] as it is immutable after construction. +pub struct Bert { + params: ModelParameters, + hyperparameters: Hyperparameters, + tokenizer: Tokenizer, + + word_embeddings: ggml::Tensor, + token_type_embeddings: ggml::Tensor, + position_embeddings: ggml::Tensor, + ln_e_w: ggml::Tensor, + ln_e_b: ggml::Tensor, + + // weights for the model + layers: Vec, + + // must be kept alive for the model + context: ModelContext, +} + +unsafe impl Send for Bert {} +unsafe impl Sync for Bert {} + +/// BERT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct Hyperparameters { + /// Size of the model's vocabulary + pub n_vocab: usize, + + /// Maximum number of tokens + pub n_max_tokens: usize, + + /// Size of the model's embedding layer + pub n_embd: usize, + + /// n_head + pub n_intermediate: usize, + + /// Number of attention heads + pub n_head: usize, + + /// Number of layers in the model + pub n_layer: usize, + + /// file_type + pub file_type: FileType, +} + +impl KnownModel for Bert { + type Hyperparameters = Hyperparameters; + + fn new( + hyperparameters: Self::Hyperparameters, + params: ModelParameters, + tokenizer: Tokenizer, + tensor_loader: impl TensorLoader, + ) -> Result { + let mut tl = tensor_loader; + + let word_embeddings = tl.load("embeddings.word_embeddings.weight")?; + let token_type_embeddings = tl.load("embeddings.token_type_embeddings.weight")?; + let position_embeddings = tl.load("embeddings.position_embeddings.weight")?; + + let ln_e_w = tl.load("embeddings.LayerNorm.weight")?; + let ln_e_b = tl.load("embeddings.LayerNorm.bias")?; + + let mut layers = Vec::new(); + + for i in 0..hyperparameters.n_layer { + let backend = params.backend(i); + + let layer = Layer { + ln_att_w: tl + .load(&format!( + "encoder.layer.{i}.attention.output.LayerNorm.weight" + ))? + .transfer_to(backend), + ln_att_b: tl + .load(&format!( + "encoder.layer.{i}.attention.output.LayerNorm.bias" + ))? + .transfer_to(backend), + + // attention + q_w: tl + .load(&format!("encoder.layer.{i}.attention.self.query.weight"))? + .transfer_to(backend), + q_b: tl + .load(&format!("encoder.layer.{i}.attention.self.query.bias"))? + .transfer_to(backend), + k_w: tl + .load(&format!("encoder.layer.{i}.attention.self.key.weight"))? + .transfer_to(backend), + k_b: tl + .load(&format!("encoder.layer.{i}.attention.self.key.bias"))? + .transfer_to(backend), + v_w: tl + .load(&format!("encoder.layer.{i}.attention.self.value.weight"))? + .transfer_to(backend), + v_b: tl + .load(&format!("encoder.layer.{i}.attention.self.value.bias"))? + .transfer_to(backend), + + o_w: tl + .load(&format!("encoder.layer.{i}.attention.output.dense.weight"))? + .transfer_to(backend), + o_b: tl + .load(&format!("encoder.layer.{i}.attention.output.dense.bias"))? + .transfer_to(backend), + + // ff + ff_i_w: tl + .load(&format!("encoder.layer.{i}.intermediate.dense.weight"))? + .transfer_to(backend), + ff_i_b: tl + .load(&format!("encoder.layer.{i}.intermediate.dense.bias"))? + .transfer_to(backend), + + ln_out_w: tl + .load(&format!("encoder.layer.{i}.output.LayerNorm.weight"))? + .transfer_to(backend), + ln_out_b: tl + .load(&format!("encoder.layer.{i}.output.LayerNorm.bias"))? + .transfer_to(backend), + ff_o_w: tl + .load(&format!("encoder.layer.{i}.output.dense.weight"))? + .transfer_to(backend), + ff_o_b: tl + .load(&format!("encoder.layer.{i}.output.dense.bias"))? + .transfer_to(backend), + }; + + layers.push(layer); + } + let context = tl.finish(); + + Ok(Self { + ln_e_b, + ln_e_w, + position_embeddings, + token_type_embeddings, + word_embeddings, + hyperparameters, + params, + tokenizer, + layers, + context, + }) + } + + /// Starts a new `InferenceSession` for this model. + fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { + InferenceSession::new( + config, + &self.params, + self.hyperparameters.n_layer, + self.hyperparameters.n_embd, + self.hyperparameters.n_vocab, + ) + } + + #[tracing::instrument(level = "trace", skip_all)] + fn evaluate( + &self, + session: &mut InferenceSession, + input_tokens: &[TokenId], + output_request: &mut OutputRequest, + ) { + let input_len = input_tokens.len(); + let _ctx_size = self.params.context_size; + + let Hyperparameters { + n_vocab, + n_max_tokens: _, + n_embd, + n_intermediate: _, + n_head, + n_layer, + file_type: _, + } = self.hyperparameters; + + let d_head = n_embd / n_head; + + let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let mut ctx0 = builder.ctx0.borrow_mut(); + let gf = ctx0.create_compute_graph(); + + let embd = builder.embd; + + let mut input_layer = ctx0.op_get_rows(&self.word_embeddings, embd); + + // IL = word_embeddings + token_types + position_embeddingso + { + // token-types: a zero tensor + let mut token_types = ctx0.new_tensor_1d(llm_base::ElementType::I32, input_len); + token_types.zero_data(); + + // position embeddings: another tensor + let position_buf: Vec = (0..input_len as i32).collect(); + let mut positions = ctx0.new_tensor_1d(llm_base::ElementType::I32, input_len); + unsafe { positions.write_data(bytemuck::cast_slice(&position_buf)) }; + + // IL += token_types + input_layer = ctx0.op_add( + &input_layer, + &ctx0.op_get_rows(&self.token_type_embeddings, &token_types), + ); + + // IL += position_embeddings + input_layer = ctx0.op_add( + &input_layer, + &ctx0.op_get_rows(&self.position_embeddings, &positions), + ); + } + + // embd norm + { + input_layer = ctx0.op_norm(&input_layer); + input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_e_w), &self.ln_e_b); + } + + for il in 0..n_layer { + ctx0.set_offloading(self.params.should_offload(il)); + + let mut current = input_layer.share(); + + // self-attention + { + let q_current = ctx0.op_reshape_3d( + &ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].q_w, ¤t), + &self.layers[il].q_b, + ), + d_head, + n_head, + input_len, + ); + let q = ctx0.op_permute(&q_current, (0, 2, 1, 3)); + + let k_current = ctx0.op_reshape_3d( + &ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].k_w, ¤t), + &self.layers[il].k_b, + ), + d_head, + n_head, + input_len, + ); + let k = ctx0.op_permute(&k_current, (0, 2, 1, 3)); + + let v_current = ctx0.op_reshape_3d( + &ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].v_w, ¤t), + &self.layers[il].v_b, + ), + d_head, + n_head, + input_len, + ); + let mut v = ctx0.op_permute(&v_current, (0, 2, 1, 3)); + + let mut kq = ctx0.op_mul_mat(&k, &q); + + // TODO: look into op_scale_inplace and op_soft_max_inplace + kq = ctx0.op_scale( + &kq, + &ctx0.new_f32(1.0 / ((n_embd as f32 / n_head as f32).sqrt())), + ); + kq = ctx0.op_soft_max(&kq); + + v = ctx0.op_cont(&ctx0.op_transpose(&v)); + + let mut kqv = ctx0.op_mul_mat(&v, &kq); + kqv = ctx0.op_permute(&kqv, (0, 2, 1, 3)); + + current = ctx0.op_cpy( + &kqv, + &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), + ); + } + + // attention output + current = ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].o_w, ¤t), + &self.layers[il].o_b, + ); + + // re-add the layer input + current = ctx0.op_add(¤t, &input_layer); + + // attention norm + { + current = ctx0.op_norm(¤t); + current = ctx0.op_add( + &ctx0.op_mul(¤t, &self.layers[il].ln_att_w), + &self.layers[il].ln_att_b, + ); + } + + let att_output = current.share(); + + // intermediate output + current = ctx0.op_mul_mat(&self.layers[il].ff_i_w, ¤t); + current = ctx0.op_add(¤t, &self.layers[il].ff_i_b); + current = ctx0.op_gelu(¤t); + + // layer output + current = ctx0.op_mul_mat(&self.layers[il].ff_o_w, ¤t); + current = ctx0.op_add(¤t, &self.layers[il].ff_o_b); + + // attentions bypass the intermediate layer + current = ctx0.op_add(&att_output, ¤t); + + // output norm + { + current = ctx0.op_norm(¤t); + current = ctx0.op_add( + &ctx0.op_mul(¤t, &self.layers[il].ln_out_w), + &self.layers[il].ln_out_b, + ); + } + + // input for next layer + input_layer = current; + } + input_layer = ctx0.op_cont(&ctx0.op_transpose(&input_layer)); + + ctx0.set_offloading(false); + // pooler + let mut sum = ctx0.new_tensor_2d(llm_base::ElementType::F32, input_len, 1); + sum = ctx0.set_f32(&sum, 1.0 / (input_len as f32)); + input_layer = ctx0.op_mul_mat(&input_layer, &sum); + + // normalizer + let length = ctx0.op_sqrt(&ctx0.op_sum(&ctx0.op_sqr(&input_layer))); + + input_layer = ctx0.op_scale(&input_layer, &ctx0.op_div(&ctx0.new_f32(1.0), &length)); + + ( + gf, + GraphOutputs { + result: input_layer.share(), + embedding_result: input_layer.share(), + output_length: input_len, + }, + ) + }); + + // finish evaluation + common::read_last_token(session, &outputs.result, n_vocab, input_len); + common::extract_logits(output_request, &outputs.result, n_vocab, input_len); + common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, 1); + } + + fn hyperparameters(&self) -> &Self::Hyperparameters { + &self.hyperparameters + } + + fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + fn context_size(&self) -> usize { + self.params.context_size + } + + fn bot_token_id(&self) -> Option { + self.tokenizer.id("[PAD]".as_bytes()) + } + + fn eot_token_id(&self) -> TokenId { + self.tokenizer.id("".as_bytes()).unwrap_or(2) + } + + fn quantize_tensors() -> Vec { + vec![Regex::new(".*weight").unwrap()] + } + + fn skip_quantize_tensors() -> Vec { + vec![] + } + + fn supports_rewind(&self) -> bool { + true + } +} + +impl llm_base::Hyperparameters for Hyperparameters { + fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { + Ok(Hyperparameters { + n_vocab: util::read_i32(reader)?.try_into()?, + n_max_tokens: util::read_i32(reader)?.try_into()?, + n_embd: util::read_i32(reader)?.try_into()?, + n_intermediate: util::read_i32(reader)?.try_into()?, + n_head: util::read_i32(reader)?.try_into()?, + n_layer: util::read_i32(reader)?.try_into()?, + file_type: util::read_filetype(reader)?, + }) + } + + fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { + util::write_i32(writer, self.n_vocab.try_into()?)?; + util::write_i32(writer, self.n_max_tokens.try_into()?)?; + util::write_i32(writer, self.n_embd.try_into()?)?; + util::write_i32(writer, self.n_intermediate.try_into()?)?; + util::write_i32(writer, self.n_head.try_into()?)?; + util::write_i32(writer, self.n_layer.try_into()?)?; + util::write_i32(writer, self.file_type.into())?; + Ok(()) + } + + fn n_vocabulary(&self) -> usize { + self.n_vocab + } + + fn file_type(&self) -> Option { + Some(self.file_type) + } + + fn file_type_mut(&mut self) -> Option<&mut FileType> { + Some(&mut self.file_type) + } +} + +struct Layer { + // normalization + ln_att_w: ggml::Tensor, + ln_att_b: ggml::Tensor, + + ln_out_w: ggml::Tensor, + ln_out_b: ggml::Tensor, + + // attention + q_w: ggml::Tensor, + q_b: ggml::Tensor, + k_w: ggml::Tensor, + k_b: ggml::Tensor, + v_w: ggml::Tensor, + v_b: ggml::Tensor, + + o_w: ggml::Tensor, + o_b: ggml::Tensor, + + // ff + ff_i_w: ggml::Tensor, + ff_i_b: ggml::Tensor, + + ff_o_w: ggml::Tensor, + ff_o_b: ggml::Tensor, +} diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index 349ebab6..fb26ff3d 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -2,13 +2,11 @@ //! for the `llm` ecosystem. #![deny(missing_docs)] -use std::sync::Arc; - use llm_base::{ ggml, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, - ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, }; /// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom) @@ -37,7 +35,7 @@ pub struct Bloom { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for Bloom {} @@ -101,7 +99,7 @@ impl KnownModel for Bloom { output_norm_bias, output, layers, - context: Arc::new(context), + context, }) } @@ -121,8 +119,6 @@ impl KnownModel for Bloom { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let ctx_size = self.params.context_size; let Hyperparameters { @@ -135,6 +131,8 @@ impl KnownModel for Bloom { } = self.hyperparameters; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let session_len = builder.n_past; + let input_len = builder.input_length(); let ctx0 = builder.ctx0.borrow(); let (memory_k_size, memory_v_size) = ( builder.memory_k.element_size(), @@ -333,14 +331,25 @@ impl KnownModel for Bloom { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: input_len, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs index 914c22bb..f9f6c5d7 100644 --- a/crates/models/falcon/src/lib.rs +++ b/crates/models/falcon/src/lib.rs @@ -7,14 +7,12 @@ //! supported. It is currently only available as a preview. #![deny(missing_docs)] -use std::sync::Arc; - use ggml::Tensor; use llm_base::{ ggml, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, }; /// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae) @@ -39,7 +37,7 @@ pub struct Falcon { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for Falcon {} @@ -138,7 +136,7 @@ impl KnownModel for Falcon { output_norm_b, lm_head, layers, - context: Arc::new(context), + context, }) } @@ -158,8 +156,6 @@ impl KnownModel for Falcon { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let ctx_size = self.params.context_size; let Hyperparameters { @@ -172,9 +168,12 @@ impl KnownModel for Falcon { } = self.hyperparameters; let head_dim = n_embd / n_head; - let n = input_len; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let input_len = builder.input_length(); + let n = input_len; + let session_len = builder.n_past; + let mut ctx0 = builder.ctx0.borrow_mut(); let embd = builder.embd; let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd); @@ -194,7 +193,6 @@ impl KnownModel for Falcon { for il in 0..n_layer { // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); ctx0.set_offloading(self.params.should_offload(il)); // self-attention @@ -321,9 +319,6 @@ impl KnownModel for Falcon { // projection current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); - // feed forward uses second scratch buffer - ctx0.use_scratch(builder.get_scratch(1)); - let inp_ff = layernorm_output.share(); let attn_out = ctx0.op_cpy(¤t, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); @@ -338,8 +333,6 @@ impl KnownModel for Falcon { input_layer = current.share(); } - ctx0.use_scratch(builder.get_scratch(0)); - // norm input_layer = ctx0.op_norm(&input_layer); @@ -351,7 +344,6 @@ impl KnownModel for Falcon { let embeddings_tensor: ggml::Tensor = input_layer.share(); ctx0.set_offloading(false); - ctx0.use_scratch(None); // lm_head input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer); @@ -361,14 +353,25 @@ impl KnownModel for Falcon { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: n, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index ccd1d012..d06eb1ec 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -1,14 +1,12 @@ //! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem. #![deny(missing_docs)] -use std::sync::Arc; - use ggml::Tensor; use llm_base::{ ggml, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, }; /// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/) @@ -38,7 +36,7 @@ pub struct Gpt2 { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for Gpt2 {} @@ -123,7 +121,7 @@ impl KnownModel for Gpt2 { wte, wpe, lm_head, - context: Arc::new(context), + context, }) } @@ -143,8 +141,6 @@ impl KnownModel for Gpt2 { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let ctx_size = self.params.context_size; let Hyperparameters { @@ -156,6 +152,8 @@ impl KnownModel for Gpt2 { } = self.hyperparameters; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let input_len = builder.input_length(); + let session_len = builder.n_past; let mut ctx0 = builder.ctx0.borrow_mut(); let (memory_k_size, memory_v_size) = ( builder.memory_k.element_size(), @@ -176,7 +174,7 @@ impl KnownModel for Gpt2 { let mut gf = ctx0.create_compute_graph(); for il in 0..n_layer { ctx0.set_offloading(self.params.should_offload(il)); - ctx0.use_scratch(builder.get_scratch(0)); + // norm let mut current = ctx0.op_norm(&input_layer); current = ctx0.op_add( @@ -283,8 +281,6 @@ impl KnownModel for Gpt2 { // feed-forward let ff_in = current.share(); - ctx0.use_scratch(builder.get_scratch(1)); - // feed-forward normalization current = ctx0.op_norm(&ff_in); current = ctx0.op_add( @@ -307,13 +303,10 @@ impl KnownModel for Gpt2 { input_layer = ctx0.op_add(¤t, &ff_in); } - ctx0.use_scratch(builder.get_scratch(0)); - // normalization input_layer = ctx0.op_norm(&input_layer); input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); - ctx0.use_scratch(None); ctx0.set_offloading(false); let embeddings_tensor: ggml::Tensor = input_layer.share(); @@ -326,14 +319,25 @@ impl KnownModel for Gpt2 { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: input_len, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index b5fd4fc5..b4ee3d82 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -1,14 +1,14 @@ //! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem. #![deny(missing_docs)] -use std::{error::Error, sync::Arc}; +use std::error::Error; use ggml::Tensor; use llm_base::{ ggml, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, }; /// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b) @@ -35,7 +35,7 @@ pub struct GptJ { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for GptJ {} @@ -117,7 +117,7 @@ impl KnownModel for GptJ { lmh_g, lmh_b, layers, - context: Arc::new(context), + context, }) } @@ -137,8 +137,6 @@ impl KnownModel for GptJ { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let ctx_size = self.params.context_size; let Hyperparameters { @@ -151,6 +149,9 @@ impl KnownModel for GptJ { } = self.hyperparameters; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let input_len = builder.input_length(); + let session_len = builder.n_past; + let mut ctx0 = builder.ctx0.borrow_mut(); let (memory_k_size, memory_v_size) = ( builder.memory_k.element_size(), @@ -300,14 +301,25 @@ impl KnownModel for GptJ { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: input_len, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index b420ec13..e355fe22 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -2,14 +2,14 @@ //! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model. #![deny(missing_docs)] -use std::{error::Error, sync::Arc}; +use std::error::Error; use ggml::Tensor; use llm_base::{ ggml, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, }; /// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox) @@ -35,7 +35,7 @@ pub struct GptNeoX { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for GptNeoX {} @@ -137,7 +137,7 @@ impl KnownModel for GptNeoX { wte, lmh_g, layers, - context: Arc::new(context), + context, }) } @@ -159,8 +159,6 @@ impl KnownModel for GptNeoX { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let n = input_tokens.len(); - let n_past = session.n_past; let n_ctx = self.params.context_size; let Hyperparameters { @@ -174,6 +172,9 @@ impl KnownModel for GptNeoX { } = self.hyperparameters; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let n = builder.input_length(); + let n_past = builder.n_past; + let mut ctx0 = builder.ctx0.borrow_mut(); let embd = builder.embd; let mut input_layer = ctx0.op_get_rows(&self.wte, embd); @@ -186,8 +187,6 @@ impl KnownModel for GptNeoX { for il in 0..n_layer { ctx0.set_offloading(self.params.should_offload(il)); - // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); // self-attention let mut current = ctx0.op_norm(&input_layer); @@ -301,9 +300,6 @@ impl KnownModel for GptNeoX { current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); - // use the second scratch for the feed forward - ctx0.use_scratch(builder.get_scratch(1)); - let feedforward_input: Tensor; if !use_parallel_residual { feedforward_input = ctx0.op_add(¤t, &input_layer); @@ -326,9 +322,6 @@ impl KnownModel for GptNeoX { } } - // use the first scratch for the norm - ctx0.use_scratch(builder.get_scratch(0)); - // normalize the output input_layer = ctx0.op_norm(&input_layer); // inpL = ln_f_g*inpL + ln_f_b @@ -336,8 +329,6 @@ impl KnownModel for GptNeoX { let embeddings_tensor: ggml::Tensor = input_layer.share(); - // Disable the scratchbuffer - ctx0.use_scratch(None); ctx0.set_offloading(false); // apply language model head input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); @@ -347,14 +338,25 @@ impl KnownModel for GptNeoX { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: n, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, n); - common::extract_logits(output_request, &outputs.result, n_vocab, n); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index db78b3a0..69ab5aa8 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -1,13 +1,13 @@ //! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem. #![deny(missing_docs)] -use std::{error::Error, sync::Arc}; +use std::error::Error; use llm_base::{ ggml::{self}, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, }; /// The LLaMA model. Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) @@ -31,7 +31,7 @@ pub struct Llama { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for Llama {} @@ -125,7 +125,7 @@ impl KnownModel for Llama { norm, output, layers, - context: Arc::new(context), + context, }) } @@ -147,8 +147,6 @@ impl KnownModel for Llama { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let ctx_size = self.params.context_size; let Hyperparameters { @@ -164,7 +162,11 @@ impl KnownModel for Llama { let n_embd_gqa = n_embd / (n_head / n_head_kv); let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let session_len = builder.n_past; + let input_len = builder.input_length(); + let mut ctx0 = builder.ctx0.borrow_mut(); + let embd = builder.embd; let mut input_layer = ctx0.op_get_rows(&self.wte, embd); @@ -177,8 +179,6 @@ impl KnownModel for Llama { let input_self_attention = input_layer.share(); let mut current: ggml::Tensor; - ctx0.use_scratch(builder.get_scratch(0)); - // norm current = ctx0.op_rms_norm(&input_layer); @@ -309,8 +309,6 @@ impl KnownModel for Llama { // projection (no bias) current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); - ctx0.use_scratch(builder.get_scratch(1)); - let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); // feed-forward network @@ -337,8 +335,6 @@ impl KnownModel for Llama { input_layer = current; } - ctx0.use_scratch(builder.get_scratch(0)); - // norm input_layer = ctx0.op_rms_norm(&input_layer); @@ -351,20 +347,30 @@ impl KnownModel for Llama { // lm_head input_layer = ctx0.op_mul_mat(&self.output, &input_layer); - ctx0.use_scratch(None); ( gf, GraphOutputs { result: input_layer, embedding_result, + output_length: input_len, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 351ddd6c..1e52d2d0 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -1,14 +1,12 @@ //! An implementation of [MPT](https://huggingface.co/mosaicml) for the `llm` ecosystem. #![deny(missing_docs)] -use std::sync::Arc; - use ggml::Tensor; use llm_base::{ ggml::{self}, model::{common, HyperparametersWriteError}, util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, + ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, }; /// The MosaicML Pretrained Transformer (MPT) model. Ref: [Mosaic ML](https://www.mosaicml.com/blog/mpt-7b) @@ -31,7 +29,7 @@ pub struct Mpt { layers: Vec, // must be kept alive for the model - context: Arc, + context: ModelContext, } unsafe impl Send for Mpt {} @@ -78,7 +76,7 @@ impl KnownModel for Mpt { wte, norm, layers, - context: Arc::new(context), + context, }) } @@ -98,8 +96,6 @@ impl KnownModel for Mpt { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let n = input_tokens.len(); - let session_len = session.n_past; let ctx_size = self.params.context_size; let Hyperparameters { @@ -112,6 +108,8 @@ impl KnownModel for Mpt { } = self.hyperparameters; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let n = builder.input_length(); + let session_len = builder.n_past; let ctx0 = builder.ctx0.borrow(); let (memory_k_size, memory_v_size) = ( builder.memory_k.element_size(), @@ -125,9 +123,6 @@ impl KnownModel for Mpt { let mut gf = ctx0.create_compute_graph(); for il in 0..n_layer { - // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); - let mut current = ctx0.op_norm(&input_layer); current = ctx0.op_mul(¤t, &self.layers[il].norm_1_weight); @@ -215,9 +210,6 @@ impl KnownModel for Mpt { input_layer = ctx0.op_add(&input_layer, ¤t); - // feed forward uses second scratch buffer - ctx0.use_scratch(builder.get_scratch(1)); - current = ctx0.op_norm(&input_layer); current = ctx0.op_mul(¤t, &self.layers[il].norm_2_weight); @@ -231,17 +223,12 @@ impl KnownModel for Mpt { input_layer = ctx0.op_add(&input_layer, ¤t); } - //use scratch buffer 0 for the rest - ctx0.use_scratch(builder.get_scratch(0)); - // norm input_layer = ctx0.op_norm(&input_layer); input_layer = ctx0.op_mul(&input_layer, &self.norm); let embeddings_tensor: ggml::Tensor = input_layer.share(); - // disable scratch buffer for last layer - ctx0.use_scratch(None); // output embedding weight tied to input embedding input_layer = ctx0.op_mul_mat(&self.wte, &input_layer); @@ -250,14 +237,25 @@ impl KnownModel for Mpt { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: n, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, n); - common::extract_logits(output_request, &outputs.result, n_vocab, n); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); + common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); + common::extract_logits( + output_request, + &outputs.result, + n_vocab, + outputs.output_length, + ); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + n_embd, + outputs.output_length, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters {