
♻️ use text generator in text generation
chriamue committed Dec 21, 2023
1 parent 01118fa commit 391e05e
Showing 1 changed file with 51 additions and 71 deletions.
122 changes: 51 additions & 71 deletions src/llm/text_generation.rs
@@ -3,17 +3,23 @@ use crate::{
     llm::token_generator::TokenGenerator,
 };
 
+use crate::llm::generate_parameter::GenerateParameter;
 use anyhow::{Error as E, Result};
 use candle_core::Device;
 use candle_transformers::{generation::LogitsProcessor, models::quantized_llama::ModelWeights};
 use futures::Stream;
 use log::{debug, info, trace};
-use std::sync::Arc;
+use std::{collections::HashSet, sync::Arc};
 use tokenizers::Tokenizer;
 use tokio::sync::Mutex;
 use tokio_stream::wrappers::ReceiverStream;
 
-use super::token_output_stream::TokenOutputStream;
+use super::{
+    text_generator::{self, TextGenerator},
+    token_generator::{TokenGenerator2, TokenGeneratorTrait},
+    token_output_stream::TokenOutputStream,
+    TextGeneratorTrait,
+};
 
 pub struct TextGeneration {
     model: Arc<Mutex<ModelWeights>>,
@@ -43,87 +49,61 @@ impl TextGeneration {
         }
     }
 
-    fn prepare_tokens(tokenizer: &mut TokenOutputStream, prompt: String) -> Result<Vec<u32>> {
-        tokenizer.clear();
-        let binding = tokenizer.tokenizer().encode(prompt, true).map_err(E::msg)?;
-        Ok(binding.get_ids().to_vec())
-    }
+    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<Option<String>> {
+        let locked_tokenizer = self.tokenizer.try_lock().unwrap();
+        let locked_model = self.model.try_lock().unwrap();
+
+        let stop_tokens = vec!["<|endoftext|>", "</s>"];
+
+        let eos_tokens: HashSet<u32> = stop_tokens
+            .into_iter()
+            .map(|token| {
+                locked_tokenizer
+                    .tokenizer()
+                    .token_to_id(token)
+                    .unwrap_or_default()
+            })
+            .collect::<HashSet<u32>>();
+
+        let parameter = GenerateParameter {
+            max_new_tokens: sample_len,
+        };
+
+        let model = Box::new(locked_model.clone());
+        let sampler = Box::new(LogitsProcessor::new(42, None, None));
+
+        let token_generator: Box<dyn TokenGeneratorTrait> =
+            Box::new(TokenGenerator2::new(eos_tokens, parameter, model, sampler));
+
+        let mut text_generator = TextGenerator::new(
+            TokenOutputStream::new(locked_tokenizer.tokenizer().clone()),
+            token_generator,
+        );
+
+        text_generator.init(prompt.to_string())?;
 
-    fn process_tokens<F>(
-        tokenizer: &Mutex<TokenOutputStream>,
-        model: &Mutex<ModelWeights>,
-        prompt: String,
-        sample_len: usize,
-        stop_tokens: Option<Vec<String>>,
-        mut handle_token: F,
-    ) -> Result<()>
-    where
-        F: FnMut(String, usize, bool) -> Result<()>,
-    {
-        let mut tokenizer = tokenizer.try_lock().unwrap();
-        let mut model = model.try_lock().unwrap();
-        let mut generated_tokens = Self::prepare_tokens(&mut tokenizer, prompt).unwrap();
-
-        let repeat_penalty = 1.1;
-        let repeat_last_n = 64;
-        let mut logits_processor = LogitsProcessor::new(42, None, None);
         let start_gen = std::time::Instant::now();
+        let mut token_count = 0;
 
-        let mut token_generator = TokenGenerator::new();
-        token_generator.set_stop_tokens(stop_tokens, &mut tokenizer);
-
-        let mut generated_text: Vec<String> = Vec::new();
-
-        for _ in 0..sample_len {
-            let next_token = {
-                token_generator.next(
-                    &generated_tokens,
-                    &mut logits_processor,
-                    &mut model,
-                    repeat_penalty,
-                    repeat_last_n,
-                )?
-            };
-
-            if let Some(token) = next_token {
-                generated_tokens.push(token);
-                generated_text.push(tokenizer.next_token(token).unwrap_or_default().unwrap());
-                if token_generator.is_stop_token(&token) {
+        let mut generated_text = String::new();
+        while let Ok(result) = text_generator.next() {
+            token_count += 1;
+            match result {
+                text_generator::TextGeneratorResult::Token((text, _)) => {
+                    generated_text.push_str(&text);
+                }
+                text_generator::TextGeneratorResult::Finish(_) => {
                     break;
                 }
             }
         }
 
-        for (index, text) in generated_text.iter().enumerate() {
-            handle_token(text.clone(), index, false)?;
-        }
-
         info!(
             "{} tokens generated ({:.2} token/s)",
-            generated_tokens.len(),
-            generated_tokens.len() as f64 / start_gen.elapsed().as_secs_f64(),
+            token_count,
+            token_count as f64 / start_gen.elapsed().as_secs_f64(),
        );
 
-        Ok(())
-    }
-
-    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<Option<String>> {
-        let mut generated_text = String::new();
-        Self::process_tokens(
-            &self.tokenizer,
-            &self.model,
-            prompt.to_string(),
-            sample_len,
-            None,
-            |text, _, is_final| {
-                if is_final {
-                    Ok(())
-                } else {
-                    generated_text.push_str(&text);
-                    Ok(())
-                }
-            },
-        )?;
         Ok(Some(generated_text))
     }
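For reference, here is the stop-token mapping from the new `run` method exercised in isolation. This is a hedged sketch, not part of the commit: it assumes only the `tokenizers` crate API already used in the diff, where `Tokenizer::token_to_id` returns `Option<u32>`, so stop strings missing from the vocabulary fall back to id 0 through `unwrap_or_default()`.

    use std::collections::HashSet;
    use tokenizers::Tokenizer;

    // Mirror of the `eos_tokens` construction in the new `run`:
    // map textual stop tokens to vocabulary ids; tokens the
    // vocabulary does not know become 0 via `unwrap_or_default()`.
    fn eos_ids(tokenizer: &Tokenizer, stop_tokens: &[&str]) -> HashSet<u32> {
        stop_tokens
            .iter()
            .map(|t| tokenizer.token_to_id(t).unwrap_or_default())
            .collect()
    }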

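And a caller-side view of the refactored entry point. This is hypothetical usage, not code from the repository: `run` keeps its pre-refactor signature (a prompt plus a maximum sample length, returning `Result<Option<String>>`), so existing callers should be unaffected; the construction of `TextGeneration` itself lies outside the hunks shown above and is assumed here.

    // Hypothetical caller; `text_generation` is assumed to be built
    // elsewhere, since the constructor is not part of this diff.
    fn demo(text_generation: &mut TextGeneration) -> anyhow::Result<()> {
        // Same call shape before and after the refactor.
        if let Some(text) = text_generation.run("What is a llama?", 128)? {
            println!("{text}");
        }
        Ok(())
    }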
