From cc6630c2c12028a14ba56245d610e3581a5ac3f4 Mon Sep 17 00:00:00 2001
From: Christian M
Date: Mon, 25 Dec 2023 11:42:49 +0100
Subject: [PATCH] :sparkles: add repeat penalty parameter

---
 src/api/routes/generate_text.rs |  2 ++
 src/llm/generate_parameter.rs   |  4 +++
 src/llm/text_generation.rs     |  5 +++
 src/llm/token_generator/mod.rs | 56 ++++++++++++++++++++++-----------
 src/main.rs                    | 10 ++++++
 5 files changed, 59 insertions(+), 18 deletions(-)

diff --git a/src/api/routes/generate_text.rs b/src/api/routes/generate_text.rs
index 1b147f4..51f9d0c 100644
--- a/src/api/routes/generate_text.rs
+++ b/src/api/routes/generate_text.rs
@@ -59,6 +59,8 @@ pub async fn generate_text_handler(
         top_p: top_p.unwrap_or_default(),
         max_new_tokens: sample_len,
         seed: 42,
+        repeat_penalty,
+        repeat_last_n,
     };
 
     let generated_text = generator.run(&payload.inputs, parameter);
diff --git a/src/llm/generate_parameter.rs b/src/llm/generate_parameter.rs
index 8b26a57..6825460 100644
--- a/src/llm/generate_parameter.rs
+++ b/src/llm/generate_parameter.rs
@@ -9,4 +9,8 @@ pub struct GenerateParameter {
     pub temperature: f64,
     /// Nucleus sampling probability cutoff
     pub top_p: f64,
+    /// Penalty to be applied for repeating tokens, 1.0 means no penalty
+    pub repeat_penalty: f32,
+    /// The context size to consider for the repeat penalty
+    pub repeat_last_n: usize,
 }
diff --git a/src/llm/text_generation.rs b/src/llm/text_generation.rs
index 90b20a8..567595e 100644
--- a/src/llm/text_generation.rs
+++ b/src/llm/text_generation.rs
@@ -50,6 +50,11 @@ impl TextGeneration {
     }
 
     pub fn run(&mut self, prompt: &str, parameter: GenerateParameter) -> Result<Option<String>> {
+        info!(
+            "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+            parameter.temperature, parameter.repeat_penalty, parameter.repeat_last_n
+        );
+
         let locked_tokenizer = self.tokenizer.try_lock().unwrap();
         let locked_model = self.model.try_lock().unwrap();
 
diff --git a/src/llm/token_generator/mod.rs b/src/llm/token_generator/mod.rs
index a9fde0a..ee9df40 100644
--- a/src/llm/token_generator/mod.rs
+++ b/src/llm/token_generator/mod.rs
@@ -94,6 +94,7 @@ pub struct TokenGenerator2 {
     sampler: Box<dyn Sampler>,
     model: Box<dyn ModelProcessor>,
     next_token: Option<u32>,
+    all_tokens: Vec<u32>,
 }
 
 impl TokenGenerator2 {
@@ -111,26 +112,51 @@ impl TokenGenerator2 {
             model,
             sampler,
             next_token: None,
+            all_tokens: Vec::new(),
         }
     }
-}
 
-impl TokenGenerator2 {
-    fn next_token(&mut self) -> Result<u32> {
+    // `position` is the index of the first input token within the sequence,
+    // as expected by the model's KV cache: 0 for the prompt pass,
+    // `prompt_tokens.len() + index` for each subsequent single-token pass.
+    fn next_token(&mut self, input: &[u32], position: usize) -> Result<u32> {
         let next_token = {
-            let input = Tensor::new(self.prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input, 0)?;
+            let input = Tensor::new(input, &Device::Cpu)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, position)?;
             let logits = logits.squeeze(0)?;
-            self.sampler.sample(&logits)?
+
+            let adjusted_logits = if self.parameter.repeat_penalty != 1.0 {
+                self.apply_repeat_penalty(&logits)?
+            } else {
+                logits
+            };
+            self.sampler.sample(&adjusted_logits)?
         };
         Ok(next_token)
     }
+
+    fn apply_repeat_penalty(&self, logits: &Tensor) -> Result<Tensor> {
+        // Only the most recent `repeat_last_n` tokens are penalized.
+        let start_at = self
+            .all_tokens
+            .len()
+            .saturating_sub(self.parameter.repeat_last_n);
+        let logits = candle_transformers::utils::apply_repeat_penalty(
+            logits,
+            self.parameter.repeat_penalty,
+            &self.all_tokens[start_at..],
+        )?;
+        Ok(logits)
+    }
 }
 
 impl TokenGeneratorTrait for TokenGenerator2 {
     fn init(&mut self, prompt_tokens: Vec<u32>) -> Result<()> {
-        self.prompt_tokens = prompt_tokens;
-        self.next_token = Some(self.next_token().unwrap_or_default());
+        self.prompt_tokens = prompt_tokens.clone();
+        self.all_tokens = prompt_tokens.clone();
+
+        // The prompt pass starts at position 0; record the sampled token so
+        // the repeat penalty sees it as part of the context.
+        let next_token = self
+            .next_token(prompt_tokens.as_slice(), 0)
+            .unwrap_or_default();
+        self.next_token = Some(next_token);
+        self.all_tokens.push(next_token);
         Ok(())
     }
 
@@ -139,21 +165,13 @@ impl TokenGeneratorTrait for TokenGenerator2 {
             return Ok(TokenGeneratorResult::Finish(FinishReason::Length));
         }
 
-        let next_token = {
-            let input = Tensor::new(&[self.next_token.unwrap()], &Device::Cpu)?.unsqueeze(0)?;
-            let logits = self
-                .model
-                .forward(&input, self.prompt_tokens.len() + self.index)?;
-            let logits = logits.squeeze(0)?;
-            self.sampler.sample(&logits)?
-        };
-
-        // todo: repeat penalty
+        let next_token = self.next_token(
+            &[self.next_token.unwrap_or_default()],
+            self.prompt_tokens.len() + self.index,
+        )?;
 
         if self.stop_tokens.contains(&next_token) {
             return Ok(TokenGeneratorResult::Finish(FinishReason::EosToken));
         }
 
         self.next_token = Some(next_token);
+        self.all_tokens.push(next_token);
         self.index += 1;
         Ok(TokenGeneratorResult::Token((next_token, 1.0)))
     }
@@ -171,6 +189,7 @@ mod tests {
             HashSet::new(),
             GenerateParameter {
                 max_new_tokens: 10,
+                repeat_penalty: 1.0,
                 ..Default::default()
             },
             Box::new(DummyModelProcessor::new()),
@@ -197,6 +216,7 @@ mod tests {
             vec![stop_token].into_iter().collect(),
             GenerateParameter {
                 max_new_tokens: 10,
+                repeat_penalty: 1.0,
                 ..Default::default()
             },
             Box::new(DummyModelProcessor::new()),
diff --git a/src/main.rs b/src/main.rs
index 3f31740..283c4c0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -40,6 +40,14 @@ struct Opt {
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
 
+    /// Penalty to be applied for repeating tokens, 1.0 means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+
     /// Optional model to use for text generation. If not provided, defaults to 7b-open-chat-3.5.
     #[structopt(long)]
     model: Option<String>,
@@ -95,6 +103,8 @@ async fn main() {
         top_p: opt.top_p.unwrap_or_default(),
         max_new_tokens: opt.sample_len.unwrap_or(50),
         seed: opt.seed,
+        repeat_penalty: opt.repeat_penalty,
+        repeat_last_n: opt.repeat_last_n,
     };
 
     generate_text(prompt, parameter, opt.model.unwrap_or_default(), config).await;
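
Note for reviewers: `candle_transformers::utils::apply_repeat_penalty` rescales the logit of every token id that occurs in the supplied context slice, dividing positive logits by the penalty and multiplying negative ones, so recently emitted tokens become less likely to be sampled again. The sketch below reproduces that transform on a plain `Vec<f32>` so the math can be tested in isolation; the function name `repeat_penalty_sketch` and the toy values are illustrative only, not part of this patch or of candle's API.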
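
use std::collections::HashSet;

/// Dampens the logits of recently generated tokens: a positive logit is
/// divided by `penalty`, a negative one multiplied, so both move toward
/// "less likely". A penalty of 1.0 leaves everything unchanged.
fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let mut seen = HashSet::new();
    for &token_id in context {
        // Penalize each distinct token id only once, however often it repeats.
        if !seen.insert(token_id) {
            continue;
        }
        if let Some(logit) = logits.get_mut(token_id as usize) {
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0, 4.0, -1.0];
    // Token 1 repeats and token 2 appears once in the recent window.
    repeat_penalty_sketch(&mut logits, 1.1, &[1, 1, 2]);
    assert!((logits[0] - 2.0).abs() < 1e-6); // token 0 not in context: untouched
    assert!((logits[1] - 4.0 / 1.1).abs() < 1e-6); // divided once, not twice
    assert!((logits[2] + 1.1).abs() < 1e-6); // -1.0 * 1.1
    println!("{logits:?}");
}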
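
Two practical notes on the defaults: clap derives the kebab-case flags `--repeat-penalty` (default 1.1) and `--repeat-last-n` (default 64) from the new `Opt` fields, and passing `--repeat-penalty 1.0` skips the extra logits pass entirely because `next_token` guards on `repeat_penalty != 1.0`. The window size is a trade-off: a small `repeat_last_n` only breaks immediate loops, while a very large one can also suppress legitimately frequent tokens such as punctuation over a long generation.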