♻️ refactored modules
chriamue committed Dec 25, 2023
1 parent 845f418 commit 1c13a8d
Showing 6 changed files with 158 additions and 156 deletions.
4 changes: 2 additions & 2 deletions src/api/routes/generate_stream.rs
@@ -1,6 +1,6 @@
-use crate::api::model::GenerateRequest;
 use crate::llm::generate_parameter::GenerateParameter;
-use crate::{config::Config, llm::create_text_generation};
+use crate::llm::text_generation::create_text_generation;
+use crate::{api::model::GenerateRequest, config::Config};
 use axum::{
     extract::State,
     response::{sse::Event, IntoResponse, Sse},
2 changes: 1 addition & 1 deletion src/api/routes/generate_text.rs
@@ -1,7 +1,7 @@
 use crate::{
     api::model::{ErrorResponse, GenerateRequest, GenerateResponse},
     config::Config,
-    llm::{create_text_generation, generate_parameter::GenerateParameter},
+    llm::{generate_parameter::GenerateParameter, text_generation::create_text_generation},
 };
 use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};

128 changes: 128 additions & 0 deletions src/llm/loader.rs
@@ -0,0 +1,128 @@
+use std::path::PathBuf;
+
+use super::models::Models;
+use anyhow::{Error as E, Result};
+use candle_core::quantized::{ggml_file, gguf_file};
+use candle_core::Device;
+use candle_transformers::models::quantized_llama::ModelWeights;
+use hf_hub::api::sync::{Api, ApiBuilder};
+use hf_hub::{Repo, RepoType};
+use log::{debug, info};
+use tokenizers::Tokenizer;
+
+fn format_size(size_in_bytes: usize) -> String {
+    if size_in_bytes < 1_000 {
+        format!("{}B", size_in_bytes)
+    } else if size_in_bytes < 1_000_000 {
+        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
+    } else if size_in_bytes < 1_000_000_000 {
+        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
+    } else {
+        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
+    }
+}
+
+pub fn create_model(
+    model: Models,
+    cache_dir: &Option<PathBuf>,
+) -> Result<(ModelWeights, Device), Box<dyn std::error::Error>> {
+    info!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle_core::utils::with_avx(),
+        candle_core::utils::with_neon(),
+        candle_core::utils::with_simd128(),
+        candle_core::utils::with_f16c()
+    );
+    let revision = "main".to_string();
+
+    let start = std::time::Instant::now();
+    let api = match cache_dir {
+        Some(cache_dir) => ApiBuilder::default()
+            .with_cache_dir(cache_dir.clone())
+            .build()?,
+        None => Api::new()?,
+    };
+
+    let model_path = model.repo_path();
+
+    debug!("model paths: {:?}", model_path);
+
+    let repo = api.repo(Repo::with_revision(
+        model_path.0.to_string(),
+        RepoType::Model,
+        revision,
+    ));
+
+    let model_path = &repo.get(model_path.1)?;
+    let mut file = std::fs::File::open(model_path)?;
+    info!("retrieved the model files in {:?}", start.elapsed());
+
+    let model = match model_path.extension().and_then(|v| v.to_str()) {
+        Some("gguf") => {
+            let model = gguf_file::Content::read(&mut file)?;
+            let mut total_size_in_bytes = 0;
+            for (_, tensor) in model.tensor_infos.iter() {
+                let elem_count = tensor.shape.elem_count();
+                total_size_in_bytes +=
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+            }
+            debug!(
+                "loaded {:?} tensors ({}) in {:.2}s",
+                model.tensor_infos.len(),
+                &format_size(total_size_in_bytes),
+                start.elapsed().as_secs_f32(),
+            );
+            ModelWeights::from_gguf(model, &mut file)?
+        }
+        Some("ggml" | "bin") | Some(_) | None => {
+            let content = ggml_file::Content::read(&mut file)?;
+            let mut total_size_in_bytes = 0;
+            for (_, tensor) in content.tensors.iter() {
+                let elem_count = tensor.shape().elem_count();
+                total_size_in_bytes +=
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+            }
+            debug!(
+                "loaded {:?} tensors ({}) in {:.2}s",
+                content.tensors.len(),
+                &format_size(total_size_in_bytes),
+                start.elapsed().as_secs_f32(),
+            );
+            debug!("params: {:?}", content.hparams);
+            let default_gqa = match model {
+                Models::L7b
+                | Models::L13b
+                | Models::L7bChat
+                | Models::L13bChat
+                | Models::L7bCode
+                | Models::L13bCode
+                | Models::L34bCode
+                | Models::Leo7b
+                | Models::Leo13b => 1,
+                Models::Mixtral
+                | Models::MixtralInstruct
+                | Models::Mistral7b
+                | Models::Mistral7bInstruct
+                | Models::Zephyr7bAlpha
+                | Models::Zephyr7bBeta
+                | Models::L70b
+                | Models::L70bChat
+                | Models::OpenChat35
+                | Models::Starling7bAlpha => 8,
+            };
+            ModelWeights::from_ggml(content, default_gqa)?
+        }
+    };
+    Ok((model, Device::Cpu))
+}
+
+pub fn create_tokenizer(model: Models) -> Result<Tokenizer, Box<dyn std::error::Error>> {
+    let tokenizer_path = {
+        let api = hf_hub::api::sync::Api::new()?;
+        let repo = model.tokenizer_repo();
+        let api = api.model(repo.to_string());
+        api.get("tokenizer.json")?
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
+    Ok(tokenizer)
+}
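For orientation, here is a minimal sketch of how the relocated helpers could be called after this refactor. It is not part of the commit: the Models::Mistral7bInstruct variant and the cache path are illustrative assumptions, and the calls simply mirror the signatures shown in loader.rs and text_generation.rs above, assuming the crate is also usable as a library named chat_flame_backend (as main.rs suggests).

use std::path::PathBuf;

use chat_flame_backend::llm::{
    loader::{create_model, create_tokenizer},
    models::Models,
    text_generation::create_text_generation,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical cache directory; passing None falls back to the default hf-hub cache.
    let cache_dir: Option<PathBuf> = Some(PathBuf::from("/tmp/model-cache"));

    // Load the quantized weights and the matching tokenizer through the new loader module.
    let (_weights, _device) = create_model(Models::Mistral7bInstruct, &cache_dir)?;
    let _tokenizer = create_tokenizer(Models::Mistral7bInstruct)?;

    // Or build the whole pipeline in one call, as the routes and main.rs now do.
    let _generator = create_text_generation(Models::Mistral7bInstruct, &cache_dir)?;

    Ok(())
}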
148 changes: 2 additions & 146 deletions src/llm/mod.rs
@@ -1,160 +1,16 @@
// source: https://github.com/huggingface/candle/blob/main/candle-examples/examples/mistral/main.rs
pub mod generate_parameter;
pub mod loader;
pub mod model_processor;
pub mod models;
pub mod sampler;
mod text_generation;
pub mod text_generation;
pub mod text_generator;
pub mod token_generator;

pub use text_generator::TextGeneratorTrait;

use std::path::PathBuf;

use anyhow::{Error as E, Result};
use candle_core::quantized::{ggml_file, gguf_file};
use candle_core::Device;
use candle_transformers::models::quantized_llama::ModelWeights;
use hf_hub::api::sync::{Api, ApiBuilder};
use hf_hub::{Repo, RepoType};
use log::{debug, info};
pub use text_generation::TextGeneration;
use tokenizers::Tokenizer;

use self::models::Models;

#[derive(Debug, PartialEq)]
pub enum FinishReason {
    Length,
    EosToken,
    StopSequence,
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

pub fn create_model(
    model: Models,
    cache_dir: &Option<PathBuf>,
) -> Result<(ModelWeights, Device), Box<dyn std::error::Error>> {
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    let revision = "main".to_string();

    let start = std::time::Instant::now();
    let api = match cache_dir {
        Some(cache_dir) => ApiBuilder::default()
            .with_cache_dir(cache_dir.clone())
            .build()?,
        None => Api::new()?,
    };

    let model_path = model.repo_path();

    debug!("model paths: {:?}", model_path);

    let repo = api.repo(Repo::with_revision(
        model_path.0.to_string(),
        RepoType::Model,
        revision,
    ));

    let model_path = &repo.get(model_path.1)?;
    let mut file = std::fs::File::open(model_path)?;
    info!("retrieved the model files in {:?}", start.elapsed());

    let model = match model_path.extension().and_then(|v| v.to_str()) {
        Some("gguf") => {
            let model = gguf_file::Content::read(&mut file)?;
            let mut total_size_in_bytes = 0;
            for (_, tensor) in model.tensor_infos.iter() {
                let elem_count = tensor.shape.elem_count();
                total_size_in_bytes +=
                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
            }
            debug!(
                "loaded {:?} tensors ({}) in {:.2}s",
                model.tensor_infos.len(),
                &format_size(total_size_in_bytes),
                start.elapsed().as_secs_f32(),
            );
            ModelWeights::from_gguf(model, &mut file)?
        }
        Some("ggml" | "bin") | Some(_) | None => {
            let content = ggml_file::Content::read(&mut file)?;
            let mut total_size_in_bytes = 0;
            for (_, tensor) in content.tensors.iter() {
                let elem_count = tensor.shape().elem_count();
                total_size_in_bytes +=
                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
            }
            debug!(
                "loaded {:?} tensors ({}) in {:.2}s",
                content.tensors.len(),
                &format_size(total_size_in_bytes),
                start.elapsed().as_secs_f32(),
            );
            debug!("params: {:?}", content.hparams);
            let default_gqa = match model {
                Models::L7b
                | Models::L13b
                | Models::L7bChat
                | Models::L13bChat
                | Models::L7bCode
                | Models::L13bCode
                | Models::L34bCode
                | Models::Leo7b
                | Models::Leo13b => 1,
                Models::Mixtral
                | Models::MixtralInstruct
                | Models::Mistral7b
                | Models::Mistral7bInstruct
                | Models::Zephyr7bAlpha
                | Models::Zephyr7bBeta
                | Models::L70b
                | Models::L70bChat
                | Models::OpenChat35
                | Models::Starling7bAlpha => 8,
            };
            ModelWeights::from_ggml(content, default_gqa)?
        }
    };
    Ok((model, Device::Cpu))
}

pub fn create_tokenizer(model: Models) -> Result<Tokenizer, Box<dyn std::error::Error>> {
    let tokenizer_path = {
        let api = hf_hub::api::sync::Api::new()?;
        let repo = model.tokenizer_repo();
        let api = api.model(repo.to_string());
        api.get("tokenizer.json")?
    };
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
    Ok(tokenizer)
}

pub fn create_text_generation(
    model: Models,
    cache_dir: &Option<PathBuf>,
) -> Result<TextGeneration, Box<dyn std::error::Error>> {
    let tokenizer = create_tokenizer(model)?;
    let model = create_model(model, cache_dir)?;

    let device = Device::Cpu;

    Ok(TextGeneration::new(model.0, tokenizer, &device))
}
22 changes: 19 additions & 3 deletions src/llm/text_generation.rs
@@ -1,6 +1,9 @@
 use crate::{
     api::model::{FinishReason, StreamDetails, StreamResponse, Token},
-    llm::{self, text_generator::TextGeneratorResult},
+    llm::{
+        self,
+        text_generator::{TextGeneratorResult, TextGeneratorTrait},
+    },
 };

 use crate::llm::generate_parameter::GenerateParameter;
@@ -10,15 +13,16 @@ use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::{generation::LogitsProcessor, models::quantized_llama::ModelWeights};
 use futures::Stream;
 use log::{info, trace};
-use std::{collections::HashSet, sync::Arc};
+use std::{collections::HashSet, path::PathBuf, sync::Arc};
 use tokenizers::Tokenizer;
 use tokio::sync::Mutex;
 use tokio_stream::wrappers::ReceiverStream;

 use super::{
+    loader::{create_model, create_tokenizer},
+    models::Models,
     text_generator::{self, TextGenerator},
     token_generator::{TokenGenerator, TokenGeneratorTrait},
-    TextGeneratorTrait,
 };

 pub struct TextGeneration {
@@ -232,3 +236,15 @@ impl TextGeneration {
         ReceiverStream::new(rx)
     }
 }
+
+pub fn create_text_generation(
+    model: Models,
+    cache_dir: &Option<PathBuf>,
+) -> Result<TextGeneration, Box<dyn std::error::Error>> {
+    let tokenizer = create_tokenizer(model)?;
+    let model = create_model(model, cache_dir)?;
+
+    let device = Device::Cpu;
+
+    Ok(TextGeneration::new(model.0, tokenizer, &device))
+}
10 changes: 6 additions & 4 deletions src/main.rs
@@ -1,6 +1,9 @@
 use chat_flame_backend::{
     config::{load_config, Config},
-    llm::{generate_parameter::GenerateParameter, models::Models},
+    llm::{
+        generate_parameter::GenerateParameter, loader::create_model, models::Models,
+        text_generation::create_text_generation,
+    },
     server::server,
 };
 use clap::Parser;
@@ -60,8 +63,7 @@ async fn generate_text(
     config: Config,
 ) {
     info!("Generating text for prompt: {}", prompt);
-    let mut text_generation =
-        chat_flame_backend::llm::create_text_generation(model, &config.cache_dir).unwrap();
+    let mut text_generation = create_text_generation(model, &config.cache_dir).unwrap();

     let generated_text = text_generation.run(&prompt, parameter).unwrap();
     println!("{}", generated_text.unwrap_or_default());
@@ -70,7 +72,7 @@
 async fn start_server(model: Models, config: Config) {
     info!("Starting server");
     info!("preload model");
-    let _ = chat_flame_backend::llm::create_model(model, &config.cache_dir);
+    let _ = create_model(model, &config.cache_dir);

     info!("Running on port: {}", config.port);
     let addr = SocketAddr::from(([0, 0, 0, 0], config.port));