✨ adds mixformer model
chriamue committed Jan 7, 2024
1 parent c702876 commit f767255
Showing 4 changed files with 78 additions and 13 deletions.
68 changes: 62 additions & 6 deletions src/llm/loader.rs
@@ -5,6 +5,8 @@
 use std::path::PathBuf;
 
+use crate::llm::Model;
+
 use super::models::Models;
 use anyhow::{Error as E, Result};
 use candle_core::quantized::{ggml_file, gguf_file};
@@ -42,7 +44,7 @@ fn format_size(size_in_bytes: usize) -> String {
 pub fn create_model(
     model: Models,
     cache_dir: &Option<PathBuf>,
-) -> Result<(ModelWeights, Device), Box<dyn std::error::Error>> {
+) -> Result<(Model, Device), Box<dyn std::error::Error>> {
     info!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
         candle_core::utils::with_avx(),
@@ -76,20 +78,48 @@ pub fn create_model(

     let model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
-            let model = gguf_file::Content::read(&mut file)?;
+            let content = gguf_file::Content::read(&mut file)?;
             let mut total_size_in_bytes = 0;
-            for (_, tensor) in model.tensor_infos.iter() {
+            for (_, tensor) in content.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
                     elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
             }
             debug!(
                 "loaded {:?} tensors ({}) in {:.2}s",
-                model.tensor_infos.len(),
+                content.tensor_infos.len(),
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file)?
+            match model {
+                Models::L7b
+                | Models::L13b
+                | Models::L7bChat
+                | Models::L13bChat
+                | Models::L7bCode
+                | Models::L13bCode
+                | Models::L34bCode
+                | Models::Leo7b
+                | Models::Leo13b => Model::Llama(ModelWeights::from_gguf(content, &mut file)?),
+                Models::Mixtral
+                | Models::MixtralInstruct
+                | Models::Mistral7b
+                | Models::Mistral7bInstruct
+                | Models::Zephyr7bAlpha
+                | Models::Zephyr7bBeta
+                | Models::L70b
+                | Models::L70bChat
+                | Models::OpenChat35
+                | Models::Starling7bAlpha => {
+                    Model::Llama(ModelWeights::from_gguf(content, &mut file)?)
+                }
+                Models::PhiV1 | Models::PhiV1_5 | Models::PhiV2 | Models::PhiHermes => {
+                    let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+                        model_path,
+                    )?;
+                    Model::MixFormer(candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM::new_v2(&candle_transformers::models::mixformer::Config::v2(), vb)?)
+                }
+            }
         }
         Some("ggml" | "bin") | Some(_) | None => {
             let content = ggml_file::Content::read(&mut file)?;
@@ -128,7 +158,33 @@ pub fn create_model(
                 | Models::Starling7bAlpha => 8,
                 Models::PhiHermes | Models::PhiV1 | Models::PhiV1_5 | Models::PhiV2 => 4,
             };
-            ModelWeights::from_ggml(content, default_gqa)?
+
+            match model {
+                Models::L7b
+                | Models::L13b
+                | Models::L7bChat
+                | Models::L13bChat
+                | Models::L7bCode
+                | Models::L13bCode
+                | Models::L34bCode
+                | Models::Leo7b
+                | Models::Leo13b => Model::Llama(ModelWeights::from_ggml(content, default_gqa)?),
+                Models::Mixtral
+                | Models::MixtralInstruct
+                | Models::Mistral7b
+                | Models::Mistral7bInstruct
+                | Models::Zephyr7bAlpha
+                | Models::Zephyr7bBeta
+                | Models::L70b
+                | Models::L70bChat
+                | Models::OpenChat35
+                | Models::Starling7bAlpha => {
+                    Model::Llama(ModelWeights::from_ggml(content, default_gqa)?)
+                }
+                Models::PhiV1 | Models::PhiV1_5 | Models::PhiV2 | Models::PhiHermes => {
+                    Model::Llama(ModelWeights::from_ggml(content, default_gqa)?)
+                }
+            }
         }
     };
     Ok((model, Device::Cpu))
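For orientation, a minimal usage sketch of the new loader (the module paths `crate::llm::loader` and `crate::llm::models` and the cache location are assumptions; `Models::PhiV2` exercises the MixFormer branch added above):

use std::path::PathBuf;

use crate::llm::loader::create_model;
use crate::llm::models::Models;
use crate::llm::Model;

// Sketch only: load a quantized Phi-2 model through the enum-returning loader.
fn load_phi2() -> Result<(), Box<dyn std::error::Error>> {
    let cache_dir = Some(PathBuf::from(".cache/models"));
    // `create_model` now yields the `Model` enum instead of llama-only `ModelWeights`.
    let (model, device) = create_model(Models::PhiV2, &cache_dir)?;
    match model {
        Model::MixFormer(_) => println!("loaded a phi (mixformer) model on {device:?}"),
        Model::Llama(_) => println!("loaded a llama-family model on {device:?}"),
    }
    Ok(())
}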
6 changes: 6 additions & 0 deletions src/llm/mod.rs
@@ -65,3 +65,9 @@ pub enum FinishReason {
     /// Generation stopped because a specified stop sequence was encountered.
     StopSequence,
 }
+
+#[derive(Clone)]
+pub enum Model {
+    Llama(candle_transformers::models::quantized_llama::ModelWeights),
+    MixFormer(candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM),
+}
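A note on the design: wrapping the two backends in an enum rather than a `Box<dyn ...>` trait object lets `#[derive(Clone)]` work without extra bounds and makes every dispatch an exhaustive `match`, so a third architecture becomes a compile-time checklist. A hypothetical helper, not part of this commit, illustrates the pattern:

// Sketch only: name the wrapped backend, e.g. for logging.
impl Model {
    pub fn architecture(&self) -> &'static str {
        match self {
            Model::Llama(_) => "quantized llama",
            Model::MixFormer(_) => "quantized mixformer (phi)",
        }
    }
}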
10 changes: 6 additions & 4 deletions src/llm/model_processor.rs
@@ -4,8 +4,8 @@
 //! which are used for processing input tensors and generating output tensors
 //! representing logits from a language model.
+use super::Model;
 use candle_core::{Result, Tensor};
-use candle_transformers::models::quantized_llama::ModelWeights;
 
 /// A trait for processing model inputs and generating outputs.
 ///
@@ -25,10 +25,12 @@ pub trait ModelProcessor {
     fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor>;
 }
 
-/// Implementation of `ModelProcessor` for the `ModelWeights` from `candle_transformers`.
-impl ModelProcessor for ModelWeights {
+impl ModelProcessor for Model {
     fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        Self::forward(self, x, index_pos)
+        match self {
+            Model::Llama(model) => model.forward(x, index_pos),
+            Model::MixFormer(model) => model.forward(x),
+        }
     }
 }
 
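A sketch of driving the unified `forward` (the `Device::Cpu` choice and import paths are assumptions). Note the asymmetry the `match` absorbs: the quantized llama forward takes an explicit `index_pos` for its KV cache, while the MixFormer forward tracks position internally, so `index_pos` is simply dropped on that arm:

use candle_core::{Device, Result, Tensor};

use crate::llm::model_processor::ModelProcessor; // assumed path; brings `forward` into scope
use crate::llm::Model;

// Sketch only: one decode step through the `ModelProcessor` impl on `Model`.
fn next_logits(model: &mut Model, tokens: &[u32], index_pos: usize) -> Result<Tensor> {
    // Shape the token ids as a batch of one: (1, seq_len).
    let input = Tensor::new(tokens, &Device::Cpu)?.unsqueeze(0)?;
    // Dispatches to `forward(x, index_pos)` for llama, `forward(x)` for mixformer.
    model.forward(&input, index_pos)
}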
7 changes: 4 additions & 3 deletions src/llm/text_generation.rs
@@ -10,7 +10,7 @@ use crate::llm::generate_parameter::GenerateParameter;
 use anyhow::Result;
 use candle_core::Device;
 use candle_examples::token_output_stream::TokenOutputStream;
-use candle_transformers::{generation::LogitsProcessor, models::quantized_llama::ModelWeights};
+use candle_transformers::generation::LogitsProcessor;
 use futures::Stream;
 use log::{info, trace};
 use std::{collections::HashSet, path::PathBuf, sync::Arc};
@@ -23,16 +23,17 @@ use super::{
     models::Models,
     text_generator::{self, TextGenerator},
     token_generator::{TokenGenerator, TokenGeneratorTrait},
+    Model,
 };
 
 pub struct TextGeneration {
-    model: Arc<Mutex<ModelWeights>>,
+    model: Arc<Mutex<Model>>,
     tokenizer: Arc<Mutex<TokenOutputStream>>,
 }
 
 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
-    pub fn new(model: ModelWeights, tokenizer: Tokenizer, _device: &Device) -> Self {
+    pub fn new(model: Model, tokenizer: Tokenizer, _device: &Device) -> Self {
         Self {
             model: Arc::new(Mutex::new(model)),
             tokenizer: Arc::new(Mutex::new(TokenOutputStream::new(tokenizer))),
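End to end, the pieces wire together roughly as below (a sketch: the re-export paths and the tokenizer file name are assumptions, and `E` follows the `anyhow::Error` alias used in loader.rs):

use anyhow::Error as E;
use tokenizers::Tokenizer;

use crate::llm::loader::create_model;
use crate::llm::models::Models;
use crate::llm::text_generation::TextGeneration;

// Sketch only: build a `TextGeneration` from the enum-based loader output.
fn build_pipeline() -> anyhow::Result<TextGeneration> {
    let (model, device) = create_model(Models::PhiV2, &None)
        .map_err(|e| E::msg(e.to_string()))?;
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(E::msg)?;
    Ok(TextGeneration::new(model, tokenizer, &device))
}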
