Add Support for BitNet Architecture Inference #2664

Open · wants to merge 8 commits into main
243 changes: 243 additions & 0 deletions candle-examples/examples/llama-bitnet/main.rs
@@ -0,0 +1,243 @@
// An implementation of LLaMA (https://github.com/facebookresearch/llama) adapted for BitNet b1.58 inference.
//
// This is based on nanoGPT in a similar way to:
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
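//
// Running the example (assuming the usual candle-examples setup; pick the feature
// flags matching your backend, e.g. `--features cuda`):
//   cargo run --example llama-bitnet --release -- --prompt "My favorite theorem is "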

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::llama_bitnet as model;
use model::{Llama, LlamaConfig};

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    BitnetB1_58Large,
    BitnetB1_58XL,
    BitnetB1_58_3B,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 10000)]
    sample_len: usize,

    /// Disable the key-value cache.
    #[arg(long)]
    no_kv_cache: bool,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// Use a different dtype than f16.
    #[arg(long)]
    dtype: Option<String>,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    /// The model size to use.
    #[arg(long, default_value = "bitnet-b1-58large")]
    which: Which,

    #[arg(long)]
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(args.cpu)?;
    let dtype = match args.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    let (llama, tokenizer_filename, mut cache, config) = {
        let api = Api::new()?;
        let model_id = args.model_id.unwrap_or_else(|| {
            let str = match args.which {
                Which::BitnetB1_58Large => "1bitLLM/bitnet_b1_58-large",
                Which::BitnetB1_58XL => "1bitLLM/bitnet_b1_58-xl",
                Which::BitnetB1_58_3B => "1bitLLM/bitnet_b1_58-3B",
            };
            str.to_string()
        });
        println!("loading the model weights from {model_id}");
        let revision = args.revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(args.use_flash_attn);

        let filenames = match args.which {
            Which::BitnetB1_58Large => {
                vec![api.get("model.safetensors")?]
            }
            Which::BitnetB1_58_3B | Which::BitnetB1_58XL => {
                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
        };
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
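    // Prefer the EOS token id(s) from config.json, falling back to the
    // tokenizer's "</s>" token when the config does not specify one.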
    let eos_token_id = config.eos_token_id.or_else(|| {
        tokenizer
            .token_to_id(EOS_TOKEN)
            .map(model::LlamaEosToks::Single)
    });
    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    println!("starting the inference loop");
    print!("{prompt}");
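    // Build the sampling strategy: greedy decoding (argmax) when the temperature
    // is non-positive, otherwise temperature-based sampling, optionally
    // restricted by top-k and/or top-p.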
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let mut start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    for index in 0..args.sample_len {
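        // With the KV cache enabled, only the last generated token needs to be
        // fed back in after the first step; otherwise the whole context is
        // re-processed from position 0 on every iteration.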
        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
            (1, index_pos)
        } else {
            (tokens.len(), 0)
        };
        if index == 1 {
            start_gen = std::time::Instant::now()
        }
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, context_index, &mut cache)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &tokens[start_at..],
            )?
        };
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        token_generated += 1;
        tokens.push(next_token);

        match eos_token_id {
            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
                break;
            }
            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
                break;
            }
            _ => (),
        }
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    let dt = start_gen.elapsed();
    println!(
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
125 changes: 125 additions & 0 deletions candle-nn/src/bit_linear.rs
@@ -0,0 +1,125 @@
//! BitLinear layer
//!
//! This layer applies a bit_linear transformation to the incoming data, `y = x@w.t() + b`.
//! The bias is optional. The `forward` method can be used to apply the layer; it supports input
//! with a batch dimension (so of shape `(b_sz, in_c)`) or without (of shape `(in_c,)`); the
//! output has shape `(b_sz, out_c)` and `(out_c,)` respectively.
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::{BitLinear, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, -1.], [-1., 1.], [1., 1.]], &Cpu)?;
//! let layer = BitLinear::new(w, None); // Use no bias.
//! let xs = Tensor::new(&[[1f32, -1.]], &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(ys.to_vec2::<f32>()?, &[[2.0f32, -2.0, 0.0]]);
//! # Ok(()) }
//! ```
use candle::{Result, Tensor, D};

#[derive(Clone, Debug)]
pub struct BitLinear {
    weight: Tensor,
    bias: Option<Tensor>,
}

fn weight_quant(x: &Tensor) -> Result<Tensor> {
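    // BitNet b1.58 weight fake-quantization: scale by the reciprocal of the mean
    // absolute value, round and clamp to the ternary set {-1, 0, +1}, then
    // rescale back to the original floating point range.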
    let scale = (1.0
        / x.to_dtype(candle::DType::F32)?
            .abs()?
            .mean_all()?
            .clamp(1e-5, f32::INFINITY)?)?
    .to_dtype(x.dtype())?;

    let u = (x.broadcast_mul(&scale))?
        .round()?
        .clamp(-1.0, 1.0)?
        .broadcast_div(&scale)?;

    Ok(u)
}

fn activation_quant(x: &Tensor) -> Result<Tensor> {
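    // Per-token (last dimension) absmax activation fake-quantization: scale so
    // the largest magnitude maps to 127, round and clamp to the signed 8-bit
    // range [-128, 127], then rescale back.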
    let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
    let scale = (127.0 / scale)?;

    let y = (x.broadcast_mul(&scale))?
        .round()?
        .clamp(-128., 127.)?
        .broadcast_div(&scale)?;

    Ok(y)
}

impl BitLinear {
    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
        let weight = weight_quant(&weight).unwrap();
        Self { weight, bias }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl super::Module for BitLinear {
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let w = self.weight();
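        // Broadcast the (out_dim, in_dim) weight over any leading batch
        // dimensions and transpose it so the fake-quantized activations can be
        // multiplied with it via `matmul` below.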
        let w = match *x.dims() {
            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
            _ => w.t()?,
        };

        let x = activation_quant(x)?;

        let x = x.matmul(&w)?;

        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

/// Create or initialize a new bit_linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
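///
/// A minimal usage sketch (not run as a doctest), assuming a fresh
/// `VarMap`-backed `VarBuilder` and illustrative dimensions:
///
/// ```ignore
/// use candle::{DType, Device, Tensor};
/// use candle_nn::{bit_linear, Module, VarBuilder, VarMap};
///
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
/// let layer = bit_linear(2, 3, vb.pp("proj"))?;
/// let xs = Tensor::new(&[[1f32, -1.]], &Device::Cpu)?;
/// let ys = layer.forward(&xs)?; // shape (1, 3)
/// ```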
pub fn bit_linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<BitLinear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let bound = 1. / (in_dim as f64).sqrt();
    let init_bs = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
    Ok(BitLinear::new(ws, Some(bs)))
}

/// Create or initialize a new bit_linear layer without biases.
pub fn bit_linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    vb: crate::VarBuilder,
) -> Result<BitLinear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    Ok(BitLinear::new(ws, None))
}

pub fn bit_linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::VarBuilder,
) -> Result<BitLinear> {
    if bias {
        bit_linear(in_dim, out_dim, vb)
    } else {
        bit_linear_no_bias(in_dim, out_dim, vb)
    }
}
3 changes: 2 additions & 1 deletion candle-nn/src/lib.rs
@@ -17,6 +17,7 @@

pub mod activation;
pub mod batch_norm;
pub mod bit_linear;
pub mod conv;
pub mod embedding;
pub mod encoding;
@@ -34,9 +35,9 @@ pub mod rotary_emb;
pub mod sequential;
pub mod var_builder;
pub mod var_map;

pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use bit_linear::{bit_linear, bit_linear_b, bit_linear_no_bias, BitLinear};
pub use conv::{
    conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
    conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,