Skip to content

Commit

Permalink
Add a top_p filter to mirostat sampler.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Apr 21, 2024
1 parent 8329aff commit 03451bd
Showing 1 changed file with 26 additions and 73 deletions.
99 changes: 26 additions & 73 deletions src/sampler/mirostat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ pub struct MirostatParams {
#[derivative(Default(value = "0.1"))]
#[serde(alias = "learning_rate")]
pub rate: f32,
#[derivative(Default(value = "128"))]
#[serde(default = "default_threshold")]
pub threshold: usize,
#[serde(default = "default_top_p")]
#[derivative(Default(value = "0.95"))]
pub top_p: f32,
}

fn default_threshold() -> usize {
MirostatParams::default().threshold
/// Fallback value for the `top_p` field, named by its `#[serde(default = "default_top_p")]`
/// attribute.
///
/// Reads the value from `MirostatParams::default()` so the serde fallback and the
/// `Default` implementation cannot drift apart.
fn default_top_p() -> f32 {
    let defaults = MirostatParams::default();
    defaults.top_p
}

#[derive(Debug, Clone, Default)]
Expand All @@ -32,38 +32,13 @@ pub struct MirostatSampler {
pub state: MirostatState,
}

// Inherent helpers for `MirostatSampler`: construction plus two numeric helpers
// (`estimate_s`, `compute_k`).
// NOTE(review): the only references to `estimate_s` / `compute_k` visible in this view are
// inside commented-out code in `sample`; presumably that is why `unused` is allowed here —
// confirm before removing the attribute.
#[allow(unused)]
impl MirostatSampler {
/// Creates a sampler from `params`, seeding `state.max_surprise` with `2 * params.tau`.
pub fn new(params: MirostatParams) -> Self {
let state = MirostatState {
max_surprise: params.tau * 2.0,
};
Self { params, state }
}

/// Computes `sum(b_i * t_i) / sum(t_i * t_i)` over the first `params.threshold` entries,
/// where `b_i = ln(probs[i]) - ln(probs[i + 1])` and `t_i = ln(i + 2) - ln(i + 1)`.
///
/// The loop stops early at the first entry below `0.0625 / probs.len()`.
/// Presumably `probs` is sorted in descending order and the result is the Zipf-like
/// exponent used by Mirostat — TODO confirm against the caller.
///
/// # Panics
///
/// Panics if `probs.len() < params.threshold`.
fn estimate_s(&self, probs: &[f32]) -> f32 {
assert!(probs.len() >= self.params.threshold);
let mut num = 0.0;
let mut den = 0.0;
for i in 0..self.params.threshold {
// Entries this small end the accumulation.
if probs[i] < 0.0625 / probs.len() as f32 {
break;
}
// NOTE(review): `probs[i + 1]` reaches index `threshold` on the last iteration, but the
// assert above only guarantees `len >= threshold`; with `len == threshold` and no early
// break this indexes out of bounds.
let b = probs[i].ln() - probs[i + 1].ln();
let t = ((i + 2) as f32).ln() - ((i + 1) as f32).ln();
num += b * t;
den += t * t;
}
// NOTE(review): if the loop body never accumulates (early break at `i == 0`, or
// `threshold == 0`), this evaluates `0.0 / 0.0`, i.e. NaN.
num / den
}

/// Evaluates `k = (eps * 2^tau / (1 - n^(-eps)))^(1 / s)` with `n = probs.len()`,
/// `tau = state.max_surprise` and `eps = s - 1`, then rounds up and clamps to `[0, n - 1]`.
fn compute_k(&self, probs: &[f32], s: f32) -> usize {
let n = probs.len() as f32;
let tau = self.state.max_surprise;
let eps = s - 1.0;
let k = (eps * 2.0_f32.powf(tau) / (1.0 - n.powf(-eps))).powf(1.0 / s);
// NOTE(review): for an empty `probs`, `n - 1.0` is `-1.0` and `clamp` panics because
// its minimum exceeds its maximum — confirm callers never pass an empty slice.
k.ceil().clamp(0.0, n - 1.0) as usize
}
}

impl Sampler for MirostatSampler {
Expand All @@ -72,63 +47,41 @@ impl Sampler for MirostatSampler {
fn transform(&self, _output: &mut [f32]) {}

fn sample(&mut self, probs: &[f32]) -> u16 {
// let sorted = probs
// .iter()
// .enumerate()
// .sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse())
// .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
// *cum += x;
// Some((id, *cum, *x))
// })
// .collect_vec();
// let sorted_probs = sorted.iter().map(|x| x.2).collect_vec();

// let s = self.estimate_s(&sorted_probs);
// let k = self.compute_k(&sorted_probs, s);

// let sum = sorted.get(k).map(|&(_, cum, _)| cum).unwrap_or_default();
// let rand = fastrand::f32() * sum;
// let (token, _, prob) = sorted
// .into_iter()
// .find_or_first(|&(_, cum, _)| rand <= cum)
// .unwrap_or_default();

// let token_surprise = (1.0 / prob).log2();
// let error_surprise = token_surprise - self.params.tau;
// self.state.max_surprise -= self.params.rate * error_surprise;
let MirostatSampler { params, state } = self;

// sort the surprise values and truncate
let sorted = probs
.iter()
.enumerate()
.sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse());
let k = sorted
.clone()
.find_position(|&(_, x)| -x.log2() > self.state.max_surprise)
.map(|(k, _)| k + 1)
.unwrap_or(probs.len());
let sorted = sorted.take(k).collect_vec();

// normalize the probs
let sum: f32 = sorted.iter().map(|&(_, x)| x).sum();
let sorted = sorted
.into_iter()
.map(|(id, x)| (id, x / sum))
.sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse())
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
*cum += x;
Some((id, *cum, x))
if *cum > params.top_p {
None
} else {
*cum += x;
Some((id, *cum, *x))
}
})
.collect_vec();
let k = sorted
.iter()
.find_position(|&(_, _, x)| -x.log2() > state.max_surprise)
.map(|(k, _)| k + 1)
.unwrap_or(sorted.len());
let sorted = sorted.into_iter().take(k).collect_vec();

let rand = fastrand::f32();
// normalize the probs
let sum = sorted.last().map(|(_, x, _)| *x).unwrap();
let rand = fastrand::f32() * sum;
let (token, _, prob) = sorted
.into_iter()
.find_or_first(|&(_, cum, _)| rand <= cum)
.unwrap();

let token_surprise = -prob.log2();
let error_surprise = token_surprise - self.params.tau;
self.state.max_surprise -= self.params.rate * error_surprise;
let token_surprise = sum.log2() - prob.log2();
let error_surprise = token_surprise - params.tau;
state.max_surprise -= params.rate * error_surprise;
state.max_surprise = state.max_surprise.min(4.0 * params.tau);

token as u16
}
Expand Down

0 comments on commit 03451bd

Please sign in to comment.