From ed3febd7bea8d427398e0364af34e01208ffb70d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 9 Dec 2024 16:36:19 +0100 Subject: [PATCH 1/8] initial bitnet support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- candle-examples/examples/llama-bitnet/main.rs | 236 ++++++++ candle-nn/src/bit_linear.rs | 129 +++++ candle-nn/src/lib.rs | 3 +- .../src/models/llama_bitnet.rs | 544 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + .../src/models/with_tracing.rs | 39 ++ 6 files changed, 951 insertions(+), 1 deletion(-) create mode 100644 candle-examples/examples/llama-bitnet/main.rs create mode 100644 candle-nn/src/bit_linear.rs create mode 100644 candle-transformers/src/models/llama_bitnet.rs diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs new file mode 100644 index 0000000000..8dca26dc39 --- /dev/null +++ b/candle-examples/examples/llama-bitnet/main.rs @@ -0,0 +1,236 @@ +// An implementation of LLaMA https://github.com/facebookresearch/llama +// +// This is based on nanoGPT in a similar way to: +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// +// The tokenizer config can be retrieved from: +// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::{bail, Error as E, Result}; +use clap::{Parser, ValueEnum}; + +use candle::{DType, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::{LogitsProcessor, Sampling}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use std::io::Write; + +use candle_transformers::models::llama_bitnet as model; +use model::{Llama, LlamaConfig}; + +const EOS_TOKEN: &str = ""; +const DEFAULT_PROMPT: &str = "My favorite theorem is "; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + BitnetB1_58Large, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// The temperature used to generate samples. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 10000)] + sample_len: usize, + + /// Disable the key-value cache. + #[arg(long)] + no_kv_cache: bool, + + /// The initial prompt. + #[arg(long)] + prompt: Option, + + /// Use different dtype than f16 + #[arg(long)] + dtype: Option, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + revision: Option, + + /// The model size to use. + #[arg(long, default_value = "bitnet-b1-58large")] + which: Which, + + #[arg(long)] + use_flash_attn: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. 
+ #[arg(long, default_value_t = 128)] + repeat_last_n: usize, +} + +fn main() -> Result<()> { + use tokenizers::Tokenizer; + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let device = candle_examples::device(args.cpu)?; + let dtype = match args.dtype.as_deref() { + Some("f16") => DType::F16, + Some("bf16") => DType::BF16, + Some("f32") => DType::F32, + Some(dtype) => bail!("Unsupported dtype {dtype}"), + None => DType::F16, + }; + let (llama, tokenizer_filename, mut cache, config) = { + let api = Api::new()?; + let model_id = args.model_id.unwrap_or_else(|| { + let str = match args.which { + Which::BitnetB1_58Large => "1bitLLM/bitnet_b1_58-large", + }; + str.to_string() + }); + println!("loading the model weights from {model_id}"); + let revision = args.revision.unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + + let tokenizer_filename = api.get("tokenizer.json")?; + let config_filename = api.get("config.json")?; + let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config = config.into_config(args.use_flash_attn); + + let filenames = match args.which { + | Which::BitnetB1_58Large => { + vec![api.get("model.safetensors")?] + } + }; + let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + (Llama::load(vb, &config)?, tokenizer_filename, cache, config) + }; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let eos_token_id = config.eos_token_id.or_else(|| { + tokenizer + .token_to_id(EOS_TOKEN) + .map(model::LlamaEosToks::Single) + }); + let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer); + + println!("starting the inference loop"); + print!("{prompt}"); + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let mut start_gen = std::time::Instant::now(); + let mut index_pos = 0; + let mut token_generated = 0; + for index in 0..args.sample_len { + let (context_size, context_index) = if cache.use_kv_cache && index > 0 { + (1, index_pos) + } else { + (tokens.len(), 0) + }; + if index == 1 { + start_gen = std::time::Instant::now() + } + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + let logits = llama.forward(&input, context_index, &mut cache)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. 
{ + logits + } else { + let start_at = tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &tokens[start_at..], + )? + }; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + token_generated += 1; + tokens.push(next_token); + + match eos_token_id { + Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => { + break; + } + Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => { + break; + } + _ => (), + } + if let Some(t) = tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + let dt = start_gen.elapsed(); + println!( + "\n\n{} tokens generated ({} token/s)\n", + token_generated, + (token_generated - 1) as f64 / dt.as_secs_f64(), + ); + Ok(()) +} diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs new file mode 100644 index 0000000000..ded3c351c8 --- /dev/null +++ b/candle-nn/src/bit_linear.rs @@ -0,0 +1,129 @@ +//! BitLinear layer +//! +//! This layer applies a bit_linear transformation to the incoming data, `y = x@w.t() + b`. +//! The bias is optional. The `forward` method can be used to apply the layer, it supports input +//! with a batch dimension (so of shape `(b_sz, in_c)`) or without (of shape `(in_c,)`), the +//! output has shape `(b_sz, out_c)` and `(out_c,)` respectively. +//! +//! ```rust +//! use candle::{Tensor, Device::Cpu}; +//! use candle_nn::{BitLinear, Module}; +//! # fn main() -> candle::Result<()> { +//! +//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?; +//! let layer = BitLinear::new(w, None); // Use no bias. +//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?; +//! let ys = layer.forward(&xs)?; +//! assert_eq!(ys.to_vec2::()?, &[[210.0, 430.0, 650.0]]); +//! # Ok(()) } +//! ``` +use candle::{Result, Tensor, D}; + +#[derive(Clone, Debug)] +pub struct BitLinear { + weight: Tensor, + bias: Option, +} + +fn weight_quant(x: &Tensor) -> Result { + let scale = (1.0 + / x.to_dtype(candle::DType::F32)? + .abs()? + .mean_all()? + .clamp(1e-5, f32::INFINITY)?)? + .to_dtype(x.dtype())?; + + let u = (x.broadcast_mul(&scale))? + .round()? + .clamp(-1.0, 1.0)? + .broadcast_div(&scale)?; + + Ok(u) +} + +fn activation_quant(x: &Tensor) -> Result { + let scale = (127.0 + / x.abs()? + .max(D::Minus1)? + .max(D::Minus1)? + .clamp(1e-5, f32::INFINITY)?)? + .to_dtype(x.dtype())?; + + let y = x + .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? + .clamp(-128.0, 127.0)? + .broadcast_div(&scale)?; + + Ok(y) +} + + +impl BitLinear { + pub fn new(weight: Tensor, bias: Option) -> Self { + let weight = weight_quant(&weight).unwrap(); + Self { weight, bias } + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } +} + +impl super::Module for BitLinear { + fn forward(&self, x: &Tensor) -> candle::Result { + let w = self.weight(); + let w = match *x.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + + let x = activation_quant(x)?; + + let x = x.matmul(&w)?; + + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +/// Create or initialize a new bit_linear layer. +/// +/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`. 
+pub fn bit_linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + let bound = 1. / (in_dim as f64).sqrt(); + let init_bs = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vb.get_with_hints(out_dim, "bias", init_bs)?; + Ok(BitLinear::new(ws, Some(bs))) +} + +/// Create or initialize a new bit_linear layer without biases. +pub fn bit_linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + Ok(BitLinear::new(ws, None)) +} + +pub fn bit_linear_b( + in_dim: usize, + out_dim: usize, + bias: bool, + vb: crate::VarBuilder, +) -> Result { + if bias { + bit_linear(in_dim, out_dim, vb) + } else { + bit_linear_no_bias(in_dim, out_dim, vb) + } +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index eb3cde4a75..2917c60ebe 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -34,7 +34,7 @@ pub mod rotary_emb; pub mod sequential; pub mod var_builder; pub mod var_map; - +pub mod bit_linear; pub use activation::{prelu, Activation, PReLU}; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; pub use conv::{ @@ -48,6 +48,7 @@ pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use linear::{linear, linear_b, linear_no_bias, Linear}; +pub use bit_linear::{bit_linear, BitLinear, bit_linear_b, bit_linear_no_bias}; pub use ops::Dropout; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; diff --git a/candle-transformers/src/models/llama_bitnet.rs b/candle-transformers/src/models/llama_bitnet.rs new file mode 100644 index 0000000000..c1b9ffd26a --- /dev/null +++ b/candle-transformers/src/models/llama_bitnet.rs @@ -0,0 +1,544 @@ +//! Llama inference implementation. +//! +//! See ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971) +//! +//! 
Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) + +use super::with_tracing::{bit_linear_no_bias as bit_linear, BitLinear, RmsNorm, Linear, linear}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use std::{collections::HashMap, f32::consts::PI}; + +pub const DEFAULT_MAX_SEQ_LEN: usize = 4096; + +#[derive(Debug, Clone, serde::Deserialize, Default)] +pub enum Llama3RopeType { + #[serde(rename = "llama3")] + Llama3, + #[default] + #[serde(rename = "default")] + Default, +} + +#[derive(Debug, Clone, serde::Deserialize, Default)] +pub struct Llama3RopeConfig { + pub factor: f32, + pub low_freq_factor: f32, + pub high_freq_factor: f32, + pub original_max_position_embeddings: usize, + pub rope_type: Llama3RopeType, +} +#[derive(Debug, Clone, serde::Deserialize)] +#[serde(untagged)] +pub enum LlamaEosToks { + Single(u32), + Multiple(Vec), +} + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct LlamaConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: Option, + pub rms_norm_eps: f64, + #[serde(default = "default_rope")] + pub rope_theta: f32, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub rope_scaling: Option, + pub max_position_embeddings: usize, + pub tie_word_embeddings: Option, +} + +impl LlamaConfig { + pub fn num_key_value_heads(&self) -> usize { + self.num_key_value_heads.unwrap_or(self.num_attention_heads) + } +} + +fn default_rope() -> f32 { + 10_000.0 +} + +impl LlamaConfig { + pub fn into_config(self, use_flash_attn: bool) -> Config { + Config { + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads(), + rms_norm_eps: self.rms_norm_eps, + rope_theta: self.rope_theta, + use_flash_attn, + bos_token_id: self.bos_token_id, + eos_token_id: self.eos_token_id, + rope_scaling: self.rope_scaling, + max_position_embeddings: self.max_position_embeddings, + tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false), + } + } +} + +#[derive(Debug, Clone)] +pub struct Config { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub use_flash_attn: bool, + pub rms_norm_eps: f64, + pub rope_theta: f32, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub rope_scaling: Option, + pub max_position_embeddings: usize, + pub tie_word_embeddings: bool, +} + +impl Config { + pub fn config_7b_v1(use_flash_attn: bool) -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + use_flash_attn, + rms_norm_eps: 1e-6, + rope_theta: 10_000.0, + bos_token_id: None, + eos_token_id: None, + rope_scaling: None, + max_position_embeddings: DEFAULT_MAX_SEQ_LEN, + tie_word_embeddings: false, + } + } + + pub fn config_7b_v2(use_flash_attn: bool) -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + use_flash_attn, + rms_norm_eps: 1e-5, + 
rope_theta: 10_000.0, + bos_token_id: None, + eos_token_id: None, + rope_scaling: None, + max_position_embeddings: DEFAULT_MAX_SEQ_LEN, + tie_word_embeddings: false, + } + } +} + +#[derive(Debug, Clone)] +pub struct Cache { + masks: HashMap, + pub use_kv_cache: bool, + kvs: Vec>, + cos: Tensor, + sin: Tensor, + device: Device, +} + +fn calculate_default_inv_freq(cfg: &Config) -> Vec { + let head_dim = cfg.hidden_size / cfg.num_attention_heads; + (0..head_dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32)) + .collect() +} + +impl Cache { + pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result { + // precompute freqs_cis + let theta = match &config.rope_scaling { + None + | Some(Llama3RopeConfig { + rope_type: Llama3RopeType::Default, + .. + }) => calculate_default_inv_freq(config), + Some(rope_scaling) => { + let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32 + / rope_scaling.low_freq_factor; + let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32 + / rope_scaling.high_freq_factor; + + calculate_default_inv_freq(config) + .into_iter() + .map(|freq| { + let wavelen = 2. * PI / freq; + if wavelen < high_freq_wavelen { + freq + } else if wavelen > low_freq_wavelen { + freq / rope_scaling.factor + } else { + let smooth = (rope_scaling.original_max_position_embeddings as f32 + / wavelen + - rope_scaling.low_freq_factor) + / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor); + (1. - smooth) * freq / rope_scaling.factor + smooth * freq + } + }) + .collect::>() + } + }; + + let theta = Tensor::new(theta, device)?; + + let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((config.max_position_embeddings, 1))? 
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let cos = idx_theta.cos()?.to_dtype(dtype)?; + let sin = idx_theta.sin()?.to_dtype(dtype)?; + Ok(Self { + masks: HashMap::new(), + use_kv_cache, + kvs: vec![None; config.num_hidden_layers], + device: device.clone(), + cos, + sin, + }) + } + + fn mask(&mut self, t: usize) -> Result { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +#[derive(Debug, Clone)] +struct CausalSelfAttention { + q_proj: BitLinear, + k_proj: BitLinear, + v_proj: BitLinear, + o_proj: BitLinear, + inner_attn_ln: RmsNorm, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + use_flash_attn: bool, + span: tracing::Span, + span_rot: tracing::Span, + max_position_embeddings: usize, +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result { + let _enter = self.span_rot.enter(); + let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?; + let cos = cache.cos.narrow(0, index_pos, seq_len)?; + let sin = cache.sin.narrow(0, index_pos, seq_len)?; + candle_nn::rotary_emb::rope(x, &cos, &sin) + } + + fn forward( + &self, + x: &Tensor, + index_pos: usize, + block_idx: usize, + cache: &mut Cache, + ) -> Result { + let _enter = self.span.enter(); + let (b_sz, seq_len, hidden_size) = x.dims3()?; + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos, cache)?; + let mut k = self.apply_rotary_emb(&k, index_pos, cache)?; + + if cache.use_kv_cache { + if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] { + k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > self.max_position_embeddings { + k = k + .narrow( + D::Minus1, + k_seq_len - self.max_position_embeddings, + self.max_position_embeddings, + )? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * self.max_position_embeddings { + v = v + .narrow( + D::Minus1, + v_seq_len - self.max_position_embeddings, + self.max_position_embeddings, + )? + .contiguous()? 
+ } + } + cache.kvs[block_idx] = Some((k.clone(), v.clone())) + } + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let y = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)? + } else { + let in_dtype = q.dtype(); + let q = q.to_dtype(DType::F32)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = if seq_len == 1 { + att + } else { + let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; + masked_fill(&att, &mask, f32::NEG_INFINITY)? + }; + + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? + }; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; + let y = self.inner_attn_ln.forward(&y)?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result { + crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let size_in = cfg.hidden_size; + let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; + let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; + let q_proj = bit_linear(size_in, size_q, vb.pp("q_proj"))?; + let k_proj = bit_linear(size_in, size_kv, vb.pp("k_proj"))?; + let v_proj = bit_linear(size_in, size_kv, vb.pp("v_proj"))?; + let o_proj = bit_linear(size_q, size_in, vb.pp("o_proj"))?; + let inner_attn_ln = RmsNorm::new(size_q, cfg.rms_norm_eps, vb.pp("inner_attn_ln"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + inner_attn_ln, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + head_dim: cfg.hidden_size / cfg.num_attention_heads, + use_flash_attn: cfg.use_flash_attn, + span, + span_rot, + max_position_embeddings: cfg.max_position_embeddings, + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug, Clone)] +struct Mlp { + c_fc1: BitLinear, + c_fc2: BitLinear, + c_proj: BitLinear, + ffn_layernorm: RmsNorm, + span: tracing::Span, +} + +impl Mlp { + fn forward(&self, x: &Tensor) -> Result { + let _enter = self.span.enter(); + let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? 
* self.c_fc2.forward(x)?)?; + let x = self.ffn_layernorm.forward(&x)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let c_fc1 = bit_linear(h_size, i_size, vb.pp("gate_proj"))?; + let c_fc2 = bit_linear(h_size, i_size, vb.pp("up_proj"))?; + let c_proj = bit_linear(i_size, h_size, vb.pp("down_proj"))?; + let ffn_layernorm = RmsNorm::new(i_size, cfg.rms_norm_eps, vb.pp("ffn_layernorm"))?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + ffn_layernorm, + span, + }) + } +} + +#[derive(Debug, Clone)] +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, + span: tracing::Span, +} + +impl Block { + fn forward( + &self, + x: &Tensor, + index_pos: usize, + block_idx: usize, + cache: &mut Cache, + ) -> Result { + let _enter = self.span.enter(); + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "block"); + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; + let mlp = Mlp::load(vb.pp("mlp"), cfg)?; + let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let rms_2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + rms_1, + attn, + rms_2, + mlp, + span, + }) + } +} + +#[derive(Debug, Clone)] +pub struct Llama { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + // required by LLaVA + pub fn embed(&self, x: &Tensor) -> Result { + self.wte.forward(x) + } + // required by LLaVA + pub fn forward_input_embed( + &self, + input_embed: &Tensor, + index_pos: usize, + cache: &mut Cache, + ) -> Result { + let (_, seq_len, _) = input_embed.dims3()?; + let mut x = input_embed.clone(); + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx, cache)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result { + let (_b_sz, seq_len) = x.dims2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx, cache)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cfg: &Config) -> Result { + let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; + let lm_head = if cfg.tie_word_embeddings { + Linear::from_weights(wte.embeddings().clone(), None) + } else { + linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? 
+ }; + let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.num_hidden_layers) + .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap()) + .collect(); + + Ok(Self { + wte, + blocks, + ln_f, + lm_head, + }) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index be1f15c413..94eb796792 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -110,3 +110,4 @@ pub mod whisper; pub mod with_tracing; pub mod wuerstchen; pub mod yi; +pub mod llama_bitnet; \ No newline at end of file diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index f4706c7e95..542f6997ea 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -72,6 +72,45 @@ impl Module for Linear { } } +#[derive(Debug, Clone)] +pub struct BitLinear { + inner: candle_nn::BitLinear, + span: tracing::Span, +} + +impl BitLinear { + pub fn from_weights(weights: Tensor, bias: Option) -> Self { + let inner = candle_nn::BitLinear::new(weights, bias); + let span = tracing::span!(tracing::Level::TRACE, "bit_linear"); + Self { inner, span } + } +} + +pub fn bit_linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result { + let inner = candle_nn::bit_linear_b(d1, d2, b, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "bit_linear"); + Ok(BitLinear { inner, span }) +} + +pub fn bit_linear(d1: usize, d2: usize, vb: VarBuilder) -> Result { + let inner = candle_nn::bit_linear(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "bit_linear"); + Ok(BitLinear { inner, span }) +} + +pub fn bit_linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result { + let inner = candle_nn::bit_linear_no_bias(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "bit_linear"); + Ok(BitLinear { inner, span }) +} + +impl Module for BitLinear { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + // Wrap the conv2d op to provide some tracing. #[derive(Debug, Clone)] pub struct Conv2d { From 231a83eed6bf297481fe0dc9db88b322fe4be339 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 9 Dec 2024 16:56:50 +0100 Subject: [PATCH 2/8] add more models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- candle-examples/examples/llama-bitnet/main.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs index 8dca26dc39..fe093fe2ad 100644 --- a/candle-examples/examples/llama-bitnet/main.rs +++ b/candle-examples/examples/llama-bitnet/main.rs @@ -30,6 +30,8 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is "; #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] enum Which { BitnetB1_58Large, + Bitnet51_58XL, + Bitnet51_38_3B, } #[derive(Parser, Debug)] @@ -124,6 +126,8 @@ fn main() -> Result<()> { let model_id = args.model_id.unwrap_or_else(|| { let str = match args.which { Which::BitnetB1_58Large => "1bitLLM/bitnet_b1_58-large", + Which::Bitnet51_58XL => "1bitLLM/bitnet_b1_58-xl", + Which::Bitnet51_38_3B => "1bitLLM/bitnet_b1_38-3b", }; str.to_string() }); @@ -140,6 +144,9 @@ fn main() -> Result<()> { | Which::BitnetB1_58Large => { vec![api.get("model.safetensors")?] 
} + | Which::Bitnet51_38_3B | Which::Bitnet51_58XL => { + candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? + } }; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; From f64d885d6eb0560ac69e72f0980f87f227e3348d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 9 Dec 2024 21:25:22 +0100 Subject: [PATCH 3/8] add one test for forwarding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- candle-nn/tests/bit_linear.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 candle-nn/tests/bit_linear.rs diff --git a/candle-nn/tests/bit_linear.rs b/candle-nn/tests/bit_linear.rs new file mode 100644 index 0000000000..1c9322448a --- /dev/null +++ b/candle-nn/tests/bit_linear.rs @@ -0,0 +1,15 @@ +use candle::{Device::Cpu, Result, Tensor}; +use candle_nn::{BitLinear, Module}; + +#[test] +fn test_forward_no_bias() -> Result<()> { + let weight = Tensor::new(&[[1f32, -1.], [-1., 1.], [1., 1.]], &Cpu)?; + let layer = BitLinear::new(weight.clone(), None); + + let input = Tensor::new(&[[1f32, -1.]], &Cpu)?; + let output = layer.forward(&input)?; + let expected_output = Tensor::new(&[[2.0f32, -2.0, 0.0]], &Cpu)?; + + assert_eq!(output.to_vec2::()?, expected_output.to_vec2::()?); + Ok(()) +} From 8367e6a3327bcf38823c9ad19cb7f139a7a3259e Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 9 Dec 2024 22:27:55 +0100 Subject: [PATCH 4/8] Apply cargo fmt. --- candle-examples/examples/llama-bitnet/main.rs | 4 ++-- candle-nn/src/bit_linear.rs | 17 ++++++++++------- candle-nn/src/lib.rs | 4 ++-- candle-transformers/src/models/llama_bitnet.rs | 2 +- candle-transformers/src/models/mod.rs | 2 +- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs index fe093fe2ad..2a261d8bd7 100644 --- a/candle-examples/examples/llama-bitnet/main.rs +++ b/candle-examples/examples/llama-bitnet/main.rs @@ -141,10 +141,10 @@ fn main() -> Result<()> { let config = config.into_config(args.use_flash_attn); let filenames = match args.which { - | Which::BitnetB1_58Large => { + Which::BitnetB1_58Large => { vec![api.get("model.safetensors")?] } - | Which::Bitnet51_38_3B | Which::Bitnet51_58XL => { + Which::Bitnet51_38_3B | Which::Bitnet51_58XL => { candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? } }; diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs index ded3c351c8..b20e8dd710 100644 --- a/candle-nn/src/bit_linear.rs +++ b/candle-nn/src/bit_linear.rs @@ -31,7 +31,7 @@ fn weight_quant(x: &Tensor) -> Result { .abs()? .mean_all()? .clamp(1e-5, f32::INFINITY)?)? - .to_dtype(x.dtype())?; + .to_dtype(x.dtype())?; let u = (x.broadcast_mul(&scale))? .round()? @@ -47,17 +47,16 @@ fn activation_quant(x: &Tensor) -> Result { .max(D::Minus1)? .max(D::Minus1)? .clamp(1e-5, f32::INFINITY)?)? - .to_dtype(x.dtype())?; + .to_dtype(x.dtype())?; let y = x - .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? - .clamp(-128.0, 127.0)? - .broadcast_div(&scale)?; + .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? + .clamp(-128.0, 127.0)? 
+ .broadcast_div(&scale)?; Ok(y) } - impl BitLinear { pub fn new(weight: Tensor, bias: Option) -> Self { let weight = weight_quant(&weight).unwrap(); @@ -109,7 +108,11 @@ pub fn bit_linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Resul } /// Create or initialize a new bit_linear layer without biases. -pub fn bit_linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { +pub fn bit_linear_no_bias( + in_dim: usize, + out_dim: usize, + vb: crate::VarBuilder, +) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; Ok(BitLinear::new(ws, None)) diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 2917c60ebe..530e8fb2b4 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -17,6 +17,7 @@ pub mod activation; pub mod batch_norm; +pub mod bit_linear; pub mod conv; pub mod embedding; pub mod encoding; @@ -34,9 +35,9 @@ pub mod rotary_emb; pub mod sequential; pub mod var_builder; pub mod var_map; -pub mod bit_linear; pub use activation::{prelu, Activation, PReLU}; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; +pub use bit_linear::{bit_linear, bit_linear_b, bit_linear_no_bias, BitLinear}; pub use conv::{ conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, @@ -48,7 +49,6 @@ pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use linear::{linear, linear_b, linear_no_bias, Linear}; -pub use bit_linear::{bit_linear, BitLinear, bit_linear_b, bit_linear_no_bias}; pub use ops::Dropout; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; diff --git a/candle-transformers/src/models/llama_bitnet.rs b/candle-transformers/src/models/llama_bitnet.rs index c1b9ffd26a..b76c77d42b 100644 --- a/candle-transformers/src/models/llama_bitnet.rs +++ b/candle-transformers/src/models/llama_bitnet.rs @@ -4,7 +4,7 @@ //! //! 
Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) -use super::with_tracing::{bit_linear_no_bias as bit_linear, BitLinear, RmsNorm, Linear, linear}; +use super::with_tracing::{bit_linear_no_bias as bit_linear, linear, BitLinear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use std::{collections::HashMap, f32::consts::PI}; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 94eb796792..e47174b378 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -48,6 +48,7 @@ pub mod jina_bert; pub mod llama; pub mod llama2_c; pub mod llama2_c_weights; +pub mod llama_bitnet; pub mod llava; pub mod mamba; pub mod marian; @@ -110,4 +111,3 @@ pub mod whisper; pub mod with_tracing; pub mod wuerstchen; pub mod yi; -pub mod llama_bitnet; \ No newline at end of file From 81067cb668debe2c0bdd162523120251c65f6d14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 9 Dec 2024 22:41:27 +0100 Subject: [PATCH 5/8] bugfixing: tests should works fine now MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- candle-nn/src/bit_linear.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs index b20e8dd710..2fdd4a9aeb 100644 --- a/candle-nn/src/bit_linear.rs +++ b/candle-nn/src/bit_linear.rs @@ -10,11 +10,11 @@ //! use candle_nn::{BitLinear, Module}; //! # fn main() -> candle::Result<()> { //! -//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?; +//! let w = Tensor::new(&[[1f32, -1.], [-1., 1.], [1., 1.]], &Cpu)?; //! let layer = BitLinear::new(w, None); // Use no bias. -//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?; +//! let xs = Tensor::new(&[[1f32, -1.]], &Cpu)?; //! let ys = layer.forward(&xs)?; -//! assert_eq!(ys.to_vec2::()?, &[[210.0, 430.0, 650.0]]); +//! assert_eq!(ys.to_vec2::()?, &[[2.0f32, -2.0, 0.0]]); //! # Ok(()) } //! 
``` use candle::{Result, Tensor, D}; From eb2f37c60957e9945625df2437c543e85a0613fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Tue, 17 Dec 2024 16:14:33 +0100 Subject: [PATCH 6/8] working on supporting falcon model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- candle-examples/examples/llama-bitnet/main.rs | 4 +++- candle-transformers/src/models/llama_bitnet.rs | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs index 2a261d8bd7..e55c2c912d 100644 --- a/candle-examples/examples/llama-bitnet/main.rs +++ b/candle-examples/examples/llama-bitnet/main.rs @@ -32,6 +32,7 @@ enum Which { BitnetB1_58Large, Bitnet51_58XL, Bitnet51_38_3B, + Falcon3_7bInstruct158, } #[derive(Parser, Debug)] @@ -128,6 +129,7 @@ fn main() -> Result<()> { Which::BitnetB1_58Large => "1bitLLM/bitnet_b1_58-large", Which::Bitnet51_58XL => "1bitLLM/bitnet_b1_58-xl", Which::Bitnet51_38_3B => "1bitLLM/bitnet_b1_38-3b", + Which::Falcon3_7bInstruct158 => "tiiuae/Falcon3-7B-Instruct-1.58bit", }; str.to_string() }); @@ -141,7 +143,7 @@ fn main() -> Result<()> { let config = config.into_config(args.use_flash_attn); let filenames = match args.which { - Which::BitnetB1_58Large => { + Which::Falcon3_7bInstruct158 | Which::BitnetB1_58Large => { vec![api.get("model.safetensors")?] } Which::Bitnet51_38_3B | Which::Bitnet51_58XL => { diff --git a/candle-transformers/src/models/llama_bitnet.rs b/candle-transformers/src/models/llama_bitnet.rs index b76c77d42b..1fc870cb78 100644 --- a/candle-transformers/src/models/llama_bitnet.rs +++ b/candle-transformers/src/models/llama_bitnet.rs @@ -4,7 +4,7 @@ //! //! 
Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) -use super::with_tracing::{bit_linear_no_bias as bit_linear, linear, BitLinear, Linear, RmsNorm}; +use super::with_tracing::{bit_linear_no_bias as bit_linear, linear_no_bias as linear, BitLinear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use std::{collections::HashMap, f32::consts::PI}; From f2a38092b746ee2e3e84f6eba9e79e7cc6af4928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Wed, 18 Dec 2024 06:30:10 +0100 Subject: [PATCH 7/8] remove support for falcon because is quantized MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- candle-examples/examples/llama-bitnet/main.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs index e55c2c912d..2a261d8bd7 100644 --- a/candle-examples/examples/llama-bitnet/main.rs +++ b/candle-examples/examples/llama-bitnet/main.rs @@ -32,7 +32,6 @@ enum Which { BitnetB1_58Large, Bitnet51_58XL, Bitnet51_38_3B, - Falcon3_7bInstruct158, } #[derive(Parser, Debug)] @@ -129,7 +128,6 @@ fn main() -> Result<()> { Which::BitnetB1_58Large => "1bitLLM/bitnet_b1_58-large", Which::Bitnet51_58XL => "1bitLLM/bitnet_b1_58-xl", Which::Bitnet51_38_3B => "1bitLLM/bitnet_b1_38-3b", - Which::Falcon3_7bInstruct158 => "tiiuae/Falcon3-7B-Instruct-1.58bit", }; str.to_string() }); @@ -143,7 +141,7 @@ fn main() -> Result<()> { let config = config.into_config(args.use_flash_attn); let filenames = match args.which { - Which::Falcon3_7bInstruct158 | Which::BitnetB1_58Large => { + Which::BitnetB1_58Large => { vec![api.get("model.safetensors")?] } Which::Bitnet51_38_3B | Which::Bitnet51_58XL => { From c00f3c8bf30bb4a72858c476dcc9247cf7d8f0d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Tue, 31 Dec 2024 10:53:53 +0100 Subject: [PATCH 8/8] Improve activation quant function --- candle-nn/src/bit_linear.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs index 2fdd4a9aeb..97c4a54bde 100644 --- a/candle-nn/src/bit_linear.rs +++ b/candle-nn/src/bit_linear.rs @@ -42,17 +42,10 @@ fn weight_quant(x: &Tensor) -> Result { } fn activation_quant(x: &Tensor) -> Result { - let scale = (127.0 - / x.abs()? - .max(D::Minus1)? - .max(D::Minus1)? - .clamp(1e-5, f32::INFINITY)?)? - .to_dtype(x.dtype())?; + let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?; + let scale = (127.0 / scale)?; - let y = x - .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? - .clamp(-128.0, 127.0)? - .broadcast_div(&scale)?; + let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?.broadcast_div(&scale)?; Ok(y) }
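
For readers following the series, the core of the change is the pair of quantization helpers in candle-nn/src/bit_linear.rs: weight_quant ternarizes the stored weights once when the layer is constructed, and activation_quant fake-quantizes the input to the int8 range on every forward pass (patch 8 switches its scale from the doubly reduced max to a per-row max_keepdim over the last dimension). The standalone sketch below is illustrative only and is not part of the patches; it assumes the `candle` crate alias and the same tensor ops the diffs themselves call, and it reproduces the values the new unit test and doc-test expect.

// Illustrative sketch only -- not part of the series. It mirrors weight_quant from
// patch 1 and the reworked activation_quant from patch 8 as free functions.
use candle::{DType, Device::Cpu, Result, Tensor, D};

/// Ternarize a weight matrix: one global scale 1 / mean(|w|) (clamped away from 0),
/// round to {-1, 0, +1}, then rescale back so the tensor keeps its original dtype.
fn weight_quant(w: &Tensor) -> Result<Tensor> {
    let scale = (1.0
        / w.to_dtype(DType::F32)?
            .abs()?
            .mean_all()?
            .clamp(1e-5, f32::INFINITY)?)?
        .to_dtype(w.dtype())?;
    w.broadcast_mul(&scale)?
        .round()?
        .clamp(-1.0, 1.0)?
        .broadcast_div(&scale)
}

/// Fake-quantize activations to the int8 range with a per-row absmax scale
/// (127 / max(|x|) over the last dimension, kept for broadcasting), as in patch 8.
fn activation_quant(x: &Tensor) -> Result<Tensor> {
    let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
    let scale = (127.0 / scale)?;
    x.broadcast_mul(&scale)?
        .round()?
        .clamp(-128.0, 127.0)?
        .broadcast_div(&scale)
}

fn main() -> Result<()> {
    // Same tensors as candle-nn/tests/bit_linear.rs and the BitLinear doc-test.
    let w = Tensor::new(&[[1f32, -1.], [-1., 1.], [1., 1.]], &Cpu)?;
    let x = Tensor::new(&[[1f32, -1.]], &Cpu)?;
    // BitLinear::forward computes activation_quant(x) @ weight_quant(w).t().
    let y = activation_quant(&x)?.matmul(&weight_quant(&w)?.t()?)?;
    assert_eq!(y.to_vec2::<f32>()?, &[[2.0, -2.0, 0.0]]);
    Ok(())
}

On this input the weights are already ternary, so mean(|w|) is 1 and weight_quant is the identity, while the absmax activation scale of 127 round-trips x exactly; that is why both the doc-test and candle-nn/tests/bit_linear.rs expect [[2.0, -2.0, 0.0]].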