Skip to content

Commit

Permalink
Switch to rusty-ggml as GGML source (#8)
Browse files Browse the repository at this point in the history
Many other changes/improvements
  • Loading branch information
KerfuffleV2 authored Jun 16, 2023
1 parent fb62c2e commit 27da93f
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 638 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["smolrwkv", "smolrwkv-cli"]
resolver = "2"

[workspace.package]
version = "0.4.1"
version = "0.5.0"

[profile.release]
# Needed for profiling. Shouldn't affect performance.
Expand Down
2 changes: 1 addition & 1 deletion smolrwkv-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ rand = "0.8"
clap = { version = "4.2", features=["derive", "cargo"]}
tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] }
tracing = "0.1"
memmap2 = "0.5"
memmap2 = "0.7"
# tracing-flame = "0.2"

smolrwkv = { path = "../smolrwkv" }
43 changes: 42 additions & 1 deletion smolrwkv-cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const DEFAULT_MODEL: &str = "./RWKV-4-Pile-430M-20220808-8066.safetensors";
/// Tokenizer definition file. See README.
const DEFAULT_TOKENIZER: &str = "./20B_tokenizer.json";

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, ValueEnum)]
pub enum EvalType {
#[value(name = "ndf32")]
Expand All @@ -24,16 +25,56 @@ pub enum EvalType {
/// GGML-backed 32 bit. As above, uses a lot of memory.
GGMLf32,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq8_0")]
/// GGML-backed 8 bit quantized, method 1.
GGMLQ8_0,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq4_0")]
/// GGML-backed 4 bit quantized, method 1. Poor quality.
GGMLQ4_0,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq4_1")]
/// GGML-backed 4 bit quantized, method 2. Decenent quality,
/// GGML-backed 4 bit quantized, method 2. Decent quality,
/// but slower (to load?)
GGMLQ4_1,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq5_0")]
/// GGML-backed 5 bit quantized, method 1.
GGMLQ5_0,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq5_1")]
/// GGML-backed 5 bit quantized, method 2.
GGMLQ5_1,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq2_k")]
/// GGML-backed k_quants 2 bit.
GGMLQ2_K,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq3_k")]
/// GGML-backed k_quants 3 bit.
GGMLQ3_K,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq4_k")]
/// GGML-backed k_quants 4 bit.
GGMLQ4_K,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq5_k")]
/// GGML-backed k_quants 5 bit.
GGMLQ5_K,

#[cfg(feature = "ggml")]
#[value(name = "ggmlq6_k")]
/// GGML-backed k_quants 6 bit.
GGMLQ6_K,
}

#[derive(Clone, Debug, Parser)]
Expand Down
34 changes: 26 additions & 8 deletions smolrwkv-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,43 @@ fn go() -> Result<()> {
})?)
}
#[cfg(feature = "ggml")]
args::EvalType::GGMLf32 | args::EvalType::GGMLQ4_0 | args::EvalType::GGMLQ4_1 => {
args::EvalType::GGMLf32
| args::EvalType::GGMLQ8_0
| args::EvalType::GGMLQ4_0
| args::EvalType::GGMLQ4_1
| args::EvalType::GGMLQ5_0
| args::EvalType::GGMLQ2_K
| args::EvalType::GGMLQ3_K
| args::EvalType::GGMLQ4_K
| args::EvalType::GGMLQ5_K
| args::EvalType::GGMLQ5_1
| args::EvalType::GGMLQ6_K => {
use smolrwkv::ggml::{
context::RWKVContext,
loader::{load_rwkv, RwkvGgmlType},
loader::{load_rwkv, ElementType},
};

let wtype = match args.eval_mode {
args::EvalType::GGMLf32 => RwkvGgmlType::Float32,
args::EvalType::GGMLQ4_0 => RwkvGgmlType::Q4_0,
args::EvalType::GGMLQ4_1 => RwkvGgmlType::Q4_1,
args::EvalType::GGMLf32 => ElementType::F32,
args::EvalType::GGMLQ8_0 => ElementType::Q8_0,
args::EvalType::GGMLQ4_0 => ElementType::Q4_0,
args::EvalType::GGMLQ4_1 => ElementType::Q4_1,
args::EvalType::GGMLQ5_0 => ElementType::Q5_0,
args::EvalType::GGMLQ5_1 => ElementType::Q5_1,
args::EvalType::GGMLQ2_K => ElementType::Q2_K,
args::EvalType::GGMLQ3_K => ElementType::Q3_K,
args::EvalType::GGMLQ4_K => ElementType::Q4_K,
args::EvalType::GGMLQ5_K => ElementType::Q5_K,
args::EvalType::GGMLQ6_K => ElementType::Q6_K,
_ => panic!("Impossible: Bad eval mode!"),
};
info!("Backend type: GGML {wtype:?}");
let ltensors = load_rwkv(args.max_load_threads, RwkvGgmlType::Float32, wtype, tdm)?;
let ltensors = load_rwkv(args.max_load_threads, ElementType::F32, wtype, tdm)?;
Ctx::Ggml(RWKVContext::new(
(wtype, ltensors).try_into()?,
tokenizer,
args.max_eval_threads,
))
)?)
}
};

Expand Down Expand Up @@ -179,7 +197,7 @@ fn go() -> Result<()> {
}
}
let etime = Instant::now();
let used_mem = context.rwkv.ctx.used_mem();
let used_mem = context.rwkv.ctx.used_mem()?;
println!();
info!(
"GGML memory used: {:.3}GiB",
Expand Down
19 changes: 8 additions & 11 deletions smolrwkv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,31 @@ resolver = "2"
default = ["torch", "ggml"]
torch = ["dep:repugnant-pickle"]
simd = ["dep:simba"]
ggml = ["dep:ggml", "dep:ggml-sys"]
ggml = ["dep:rusty-ggml"]

[dependencies]
safetensors = { version = "0.3", default-features = false }
tokenizers = { version = "0.13", default-features = false, features = ["onig"] }
anyhow = "1"
memmap2 = "0.5"
memmap2 = "0.7"
ndarray = { version = "0.15", features = ["rayon"] }
rayon = "1.7"
half = "2.2"
rand = "0.8"
num-traits = "0.2"
num-derive="0.3"
tracing = "0.1"
bytemuck = { version = "1", features = ["extern_crate_alloc"] }
simba = { version = "0.8", features = ["wide"], optional = true }

[dependencies.repugnant-pickle]
git = "https://github.com/kerfufflev2/repugnant-pickle"
git = "https://github.com/KerfuffleV2/repugnant-pickle"
tag = "v0.0.1"
features = ["torch"]
optional = true

[dependencies.ggml]
[dependencies.rusty-ggml]
git = "https://github.com/KerfuffleV2/rusty-ggml"
tag = "0.0.6"
optional = true
git = "https://github.com/rustformers/llama-rs"
rev = "84656bedce38a6be8070e13506cb6c6245e7ef75"

[dependencies.ggml-sys]
optional = true
git = "https://github.com/rustformers/llama-rs"
rev = "84656bedce38a6be8070e13506cb6c6245e7ef75"
# path = "../../rusty-ggml"
87 changes: 42 additions & 45 deletions smolrwkv/src/ggml/context.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, ensure, Result};
use ndarray::{Array1, ArrayView1};
use tokenizers::Tokenizer;

use ggml::{ComputationGraph, Tensor, Type as GT};
use rusty_ggml::prelude::*;

use super::model::{RWKVLayerState, RWKV};

Expand All @@ -15,46 +15,47 @@ pub struct RWKVContext {
/// Probabilities from the last step (starts filled with zeros).
pub last_probs: Array1<f32>,
/// It's a 1d tensor with length 1 that contains the token ID.
pub token_tensor: Tensor,
pub token_tensor: GTensor1,
/// This is the base of the graph and also where the probs appear.
pub result_tensor: Tensor,
pub result_tensor: GTensor1,
/// The GGML computation graph.
pub ggml_graph: ComputationGraph,
pub ggml_graph: GGraph,
/// The tokenizer.
pub tokenizer: Tokenizer,
}

impl RWKVContext {
pub fn new(rwkv: RWKV, tokenizer: Tokenizer, eval_threads: usize) -> Self {
pub fn new(rwkv: RWKV, tokenizer: Tokenizer, eval_threads: usize) -> Result<Self> {
let ctx = &rwkv.ctx;

let token_tensor = ctx.new_tensor_1d(GT::I32, 1);
let token_tensor = ctx.tensor(GType::I32, [1])?;
let mut initial_state = (0..rwkv.n_layers)
.map(|_| RWKVLayerState::new(ctx, rwkv.n_embed))
.collect::<Vec<_>>();

let initial_probs = Array1::zeros(rwkv.n_vocab);
let rwkv_ops_graph = rwkv.evaluate_ops(ctx, &mut initial_state, token_tensor.share());
let rwkv_ops_graph = rwkv.evaluate_ops(&mut initial_state, token_tensor.clone());

let mut ggml_graph = ComputationGraph::new(eval_threads);
ggml_graph.build_forward_expand(&rwkv_ops_graph);
initial_state.iter().for_each(|s| {
ggml_graph.build_forward_expand(&s.tm_last_x);
ggml_graph.build_forward_expand(&s.cm_last_x);
ggml_graph.build_forward_expand(&s.tm_aa);
ggml_graph.build_forward_expand(&s.tm_bb);
ggml_graph.build_forward_expand(&s.tm_pp);
});
let mut ggml_graph = GGraph::new(eval_threads);
ggml_graph.build_forward_expand(&rwkv_ops_graph)?;
initial_state.iter().try_for_each(|s| {
ggml_graph.build_forward_expand(&s.tm_last_x)?;
ggml_graph.build_forward_expand(&s.cm_last_x)?;
ggml_graph.build_forward_expand(&s.tm_aa)?;
ggml_graph.build_forward_expand(&s.tm_bb)?;
ggml_graph.build_forward_expand(&s.tm_pp)?;
anyhow::Ok(())
})?;

Self {
Ok(Self {
rwkv,
state: initial_state,
last_probs: initial_probs,
token_tensor,
result_tensor: rwkv_ops_graph,
ggml_graph,
tokenizer,
}
})
}

/// Feeds some text to the model. A closure can be specified here to allow
Expand All @@ -69,27 +70,25 @@ impl RWKVContext {
.map_err(|e| anyhow!(e))?;

for tid in toks.get_ids().iter() {
unsafe {
self.token_tensor
.write_data(bytemuck::bytes_of(&(*tid as i32)));
}
ctx.graph_compute(&mut self.ggml_graph);
self.token_tensor.set_i32_1d(0, *tid as i32);
ctx.compute(&mut self.ggml_graph)?;
if let Some(f) = &f {
self.tokenizer
.decode(vec![*tid], false)
.map(f)
.map_err(|e| anyhow!(e))?;
}
}
assert_eq!(
self.result_tensor.get_ne()[0] as usize,
self.last_probs.len()
ensure!(
self.result_tensor.shape()[0] == self.last_probs.len()
&& self.result_tensor.elements() == self.last_probs.len(),
"Unexpected shape for result tensor"
);
// FIXME: Use ggml tensor manipulation methods?
unsafe {
(self.result_tensor.data() as *const f32)
.copy_to_nonoverlapping(self.last_probs.as_mut_ptr(), self.last_probs.len());
}
self.result_tensor.copy_to_slice_f32(
self.last_probs
.as_slice_mut()
.expect("Could get slice from last_probs?"),
)?;
Ok(())
}

Expand All @@ -110,21 +109,19 @@ impl RWKVContext {
.decode(vec![tokid as u32], false)
.map_err(|e| anyhow!(e))?;

unsafe {
self.token_tensor
.write_data(bytemuck::bytes_of(&(tokid as i32)));
}
self.token_tensor.set_i32_1d(0, tokid as i32);

ctx.graph_compute(&mut self.ggml_graph);
assert_eq!(
self.result_tensor.get_ne()[0] as usize,
self.last_probs.len()
ctx.compute(&mut self.ggml_graph)?;
ensure!(
self.result_tensor.shape()[0] == self.last_probs.len()
&& self.result_tensor.elements() == self.last_probs.len(),
"Unexpected shape for result tensor"
);
// FIXME: Use ggml tensor manipulation methods?
unsafe {
(self.result_tensor.data() as *const f32)
.copy_to_nonoverlapping(self.last_probs.as_mut_ptr(), self.last_probs.len());
}
self.result_tensor.copy_to_slice_f32(
self.last_probs
.as_slice_mut()
.expect("Could get slice from last_probs?"),
)?;
Ok(Some(output))
}
}
Loading

0 comments on commit 27da93f

Please sign in to comment.