diff --git a/src/sampler/mirostat.rs b/src/sampler/mirostat.rs index 02c3998b..5d0988e6 100644 --- a/src/sampler/mirostat.rs +++ b/src/sampler/mirostat.rs @@ -12,13 +12,13 @@ pub struct MirostatParams { #[derivative(Default(value = "0.1"))] #[serde(alias = "learning_rate")] pub rate: f32, - #[derivative(Default(value = "128"))] - #[serde(default = "default_threshold")] - pub threshold: usize, + #[serde(default = "default_top_p")] + #[derivative(Default(value = "0.95"))] + pub top_p: f32, } -fn default_threshold() -> usize { - MirostatParams::default().threshold +fn default_top_p() -> f32 { + MirostatParams::default().top_p } #[derive(Debug, Clone, Default)] @@ -32,7 +32,6 @@ pub struct MirostatSampler { pub state: MirostatState, } -#[allow(unused)] impl MirostatSampler { pub fn new(params: MirostatParams) -> Self { let state = MirostatState { @@ -40,30 +39,6 @@ impl MirostatSampler { }; Self { params, state } } - - fn estimate_s(&self, probs: &[f32]) -> f32 { - assert!(probs.len() >= self.params.threshold); - let mut num = 0.0; - let mut den = 0.0; - for i in 0..self.params.threshold { - if probs[i] < 0.0625 / probs.len() as f32 { - break; - } - let b = probs[i].ln() - probs[i + 1].ln(); - let t = ((i + 2) as f32).ln() - ((i + 1) as f32).ln(); - num += b * t; - den += t * t; - } - num / den - } - - fn compute_k(&self, probs: &[f32], s: f32) -> usize { - let n = probs.len() as f32; - let tau = self.state.max_surprise; - let eps = s - 1.0; - let k = (eps * 2.0_f32.powf(tau) / (1.0 - n.powf(-eps))).powf(1.0 / s); - k.ceil().clamp(0.0, n - 1.0) as usize - } } impl Sampler for MirostatSampler { @@ -72,63 +47,41 @@ impl Sampler for MirostatSampler { fn transform(&self, _output: &mut [f32]) {} fn sample(&mut self, probs: &[f32]) -> u16 { - // let sorted = probs - // .iter() - // .enumerate() - // .sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse()) - // .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| { - // *cum += x; - // Some((id, *cum, *x)) - // }) - // .collect_vec(); - // let sorted_probs = sorted.iter().map(|x| x.2).collect_vec(); - - // let s = self.estimate_s(&sorted_probs); - // let k = self.compute_k(&sorted_probs, s); - - // let sum = sorted.get(k).map(|&(_, cum, _)| cum).unwrap_or_default(); - // let rand = fastrand::f32() * sum; - // let (token, _, prob) = sorted - // .into_iter() - // .find_or_first(|&(_, cum, _)| rand <= cum) - // .unwrap_or_default(); - - // let token_surprise = (1.0 / prob).log2(); - // let error_surprise = token_surprise - self.params.tau; - // self.state.max_surprise -= self.params.rate * error_surprise; + let MirostatSampler { params, state } = self; // sort the surprise values and truncate let sorted = probs .iter() .enumerate() - .sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse()); - let k = sorted - .clone() - .find_position(|&(_, x)| -x.log2() > self.state.max_surprise) - .map(|(k, _)| k + 1) - .unwrap_or(probs.len()); - let sorted = sorted.take(k).collect_vec(); - - // normalize the probs - let sum: f32 = sorted.iter().map(|&(_, x)| x).sum(); - let sorted = sorted - .into_iter() - .map(|(id, x)| (id, x / sum)) + .sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse()) .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| { - *cum += x; - Some((id, *cum, x)) + if *cum > params.top_p { + None + } else { + *cum += x; + Some((id, *cum, *x)) + } }) .collect_vec(); + let k = sorted + .iter() + .find_position(|&(_, _, x)| -x.log2() > state.max_surprise) + .map(|(k, _)| k + 1) + .unwrap_or(sorted.len()); + let sorted = sorted.into_iter().take(k).collect_vec(); - let rand = fastrand::f32(); + // normalize the probs + let sum = sorted.last().map(|(_, x, _)| *x).unwrap(); + let rand = fastrand::f32() * sum; let (token, _, prob) = sorted .into_iter() .find_or_first(|&(_, cum, _)| rand <= cum) .unwrap(); - let token_surprise = -prob.log2(); - let error_surprise = token_surprise - self.params.tau; - self.state.max_surprise -= self.params.rate * error_surprise; + let token_surprise = sum.log2() - prob.log2(); + let error_surprise = token_surprise - params.tau; + state.max_surprise -= params.rate * error_surprise; + state.max_surprise = state.max_surprise.min(4.0 * params.tau); token as u16 }