✨ adds parameter repeat penalty
chriamue committed Dec 25, 2023
1 parent 1cccfc7 commit cc6630c
Showing 5 changed files with 59 additions and 18 deletions.
2 changes: 2 additions & 0 deletions src/api/routes/generate_text.rs
@@ -59,6 +59,8 @@ pub async fn generate_text_handler(
top_p: top_p.unwrap_or_default(),
max_new_tokens: sample_len,
seed: 42,
repeat_penalty,
repeat_last_n,
};

let generated_text = generator.run(&payload.inputs, parameter);
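The field shorthand above implies that repeat_penalty and repeat_last_n are bound earlier in the handler from the incoming request. A minimal sketch of what that binding might look like, assuming the payload exposes both as optional parameters (hypothetical field names; the real extraction happens in lines this hunk elides):

// Hypothetical bindings; the actual payload fields are outside this hunk.
let repeat_penalty = payload.parameters.repeat_penalty.unwrap_or(1.1);
let repeat_last_n = payload.parameters.repeat_last_n.unwrap_or(64);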
4 changes: 4 additions & 0 deletions src/llm/generate_parameter.rs
@@ -9,4 +9,8 @@ pub struct GenerateParameter {
pub temperature: f64,
/// Nucleus sampling probability cutoff
pub top_p: f64,
/// Penalty applied to repeated tokens; 1.0 means no penalty
pub repeat_penalty: f32,
/// The context size to consider for the repeat penalty
pub repeat_last_n: usize,
}
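Since GenerateParameter can be filled in with ..Default::default() (the tests below rely on this), a caller opts into the penalty explicitly. A sketch with illustrative values, matching the CLI defaults added in src/main.rs:

let parameter = GenerateParameter {
    repeat_penalty: 1.1, // 1.0 disables the penalty entirely
    repeat_last_n: 64,   // only the most recent 64 tokens are considered
    ..Default::default()
};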
5 changes: 5 additions & 0 deletions src/llm/text_generation.rs
@@ -50,6 +50,11 @@ impl TextGeneration {
}

pub fn run(&mut self, prompt: &str, parameter: GenerateParameter) -> Result<Option<String>> {
info!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
parameter.temperature, parameter.repeat_penalty, parameter.repeat_last_n
);

let locked_tokenizer = self.tokenizer.try_lock().unwrap();
let locked_model = self.model.try_lock().unwrap();
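With the defaults added below in src/main.rs, the new log line reads something like temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64, where the temperature value is illustrative; the {:.2} specifier renders the 1.1 default as 1.10.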

56 changes: 38 additions & 18 deletions src/llm/token_generator/mod.rs
@@ -94,6 +94,7 @@ pub struct TokenGenerator2 {
sampler: Box<dyn Sampler>,
model: Box<dyn ModelProcessor>,
next_token: Option<u32>,
all_tokens: Vec<u32>,
}

impl TokenGenerator2 {
@@ -111,26 +112,51 @@ impl TokenGenerator2 {
model,
sampler,
next_token: None,
all_tokens: Vec::new(),
}
}
}

impl TokenGenerator2 {
fn next_token(&mut self) -> Result<u32> {
fn next_token(&mut self, input: &[u32]) -> Result<u32> {
let next_token = {
let input = Tensor::new(self.prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = self.model.forward(&input, 0)?;
let input = Tensor::new(input, &Device::Cpu)?.unsqueeze(0)?;
let logits = self
.model
.forward(&input, self.prompt_tokens.len() + self.index)?;
let logits = logits.squeeze(0)?;
self.sampler.sample(&logits)?

let adjusted_logits = if self.parameter.repeat_penalty != 1.0 {
self.apply_repeat_penalty(&logits)?
} else {
logits
};
self.sampler.sample(&adjusted_logits)?
};
Ok(next_token)
}

fn apply_repeat_penalty(&self, logits: &Tensor) -> Result<Tensor> {
let start_at = self
.all_tokens
.len()
.saturating_sub(self.parameter.repeat_last_n);
let logits = candle_transformers::utils::apply_repeat_penalty(
logits,
self.parameter.repeat_penalty,
&self.all_tokens[start_at..],
)?;
Ok(logits)
}
}
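candle_transformers::utils::apply_repeat_penalty rescales the logit of every distinct token id found in the supplied context slice: positive logits are divided by the penalty and negative logits are multiplied by it, so recently generated tokens become less likely either way. The saturating_sub above keeps the slice bounds valid while fewer than repeat_last_n tokens exist. A rough, self-contained sketch of that penalty logic over a raw logits vector (an illustration of the idea, not the library code itself):

use std::collections::HashSet;

// Each distinct token id in `context` gets its logit pushed toward
// "less likely" exactly once, regardless of how often it repeats.
fn apply_repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let seen: HashSet<u32> = context.iter().copied().collect();
    for &token_id in &seen {
        if let Some(logit) = logits.get_mut(token_id as usize) {
            if *logit >= 0.0 {
                *logit /= penalty; // shrink positive scores
            } else {
                *logit *= penalty; // push negative scores further down
            }
        }
    }
}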

impl TokenGeneratorTrait for TokenGenerator2 {
fn init(&mut self, prompt_tokens: Vec<u32>) -> Result<()> {
self.prompt_tokens = prompt_tokens;
self.next_token = Some(self.next_token().unwrap_or_default());
self.prompt_tokens = prompt_tokens.clone();
self.all_tokens = prompt_tokens.clone();

self.next_token = Some(
self.next_token(prompt_tokens.as_slice())
.unwrap_or_default(),
);
Ok(())
}

@@ -139,21 +165,13 @@ impl TokenGeneratorTrait for TokenGenerator2 {
return Ok(TokenGeneratorResult::Finish(FinishReason::Length));
}

let next_token = {
let input = Tensor::new(&[self.next_token.unwrap()], &Device::Cpu)?.unsqueeze(0)?;
let logits = self
.model
.forward(&input, self.prompt_tokens.len() + self.index)?;
let logits = logits.squeeze(0)?;
self.sampler.sample(&logits)?
};

// todo: repeat penalty
let next_token = self.next_token(&[self.next_token.unwrap_or_default()])?;

if self.stop_tokens.contains(&next_token) {
return Ok(TokenGeneratorResult::Finish(FinishReason::EosToken));
}
self.next_token = Some(next_token);
self.all_tokens.push(next_token);
self.index += 1;
Ok(TokenGeneratorResult::Token((next_token, 1.0)))
}
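After init, each next() call feeds only the previously sampled token, passing prompt_tokens.len() + self.index as the position offset so a model that caches earlier key/value states can attend over the full sequence; every sampled token is appended to all_tokens, which is what keeps the penalty window in apply_repeat_penalty tracking the whole generation.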
@@ -171,6 +189,7 @@ mod tests {
HashSet::new(),
GenerateParameter {
max_new_tokens: 10,
repeat_penalty: 1.0,
..Default::default()
},
Box::new(DummyModelProcessor::new()),
@@ -197,6 +216,7 @@
vec![stop_token].into_iter().collect(),
GenerateParameter {
max_new_tokens: 10,
repeat_penalty: 1.0,
..Default::default()
},
Box::new(DummyModelProcessor::new()),
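Both tests pin repeat_penalty to 1.0, so the new adjustment branch is skipped and the DummyModelProcessor logits reach the sampler unchanged.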
10 changes: 10 additions & 0 deletions src/main.rs
@@ -40,6 +40,14 @@ struct Opt {
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// Penalty applied to repeated tokens; 1.0 means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,

/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// Optional model to use for text generation. If not provided, defaults to 7b-open-chat-3.5.
#[structopt(long)]
model: Option<Models>,
@@ -95,6 +103,8 @@ async fn main() {
top_p: opt.top_p.unwrap_or_default(),
max_new_tokens: opt.sample_len.unwrap_or(50),
seed: opt.seed,
repeat_penalty: opt.repeat_penalty,
repeat_last_n: opt.repeat_last_n,
};

generate_text(prompt, parameter, opt.model.unwrap_or_default(), config).await;
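Assuming clap's usual kebab-case naming for the derive above, the new options should be reachable from the command line roughly as follows (other flags omitted; 1.0 turns the penalty off):

cargo run --release -- --repeat-penalty 1.2 --repeat-last-n 128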