Skip to content

Commit

Permalink
Fix typical.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 24, 2024
1 parent 720569e commit 116f885
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions crates/ai00-core/src/sampler/typical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ impl Sampler for TypicalSampler {

let probs = probs
.iter()
.filter(|&x| *x > 0.0)
.map(|&x| (x, -x.ln()))
.enumerate()
.filter(|(_, &x)| x > 0.0)
.map(|(id, &x)| (id, x, -x.ln()))
.collect_vec();
let entropy = probs.iter().map(|(x, y)| x * y).sum::<f32>();
let entropy = probs.iter().map(|(_, x, y)| x * y).sum::<f32>();
let sorted = probs
.into_iter()
.map(|(x, y)| (x, (y - entropy).abs()))
.enumerate()
.sorted_unstable_by(|(_, (_, x)), (_, (_, y))| x.total_cmp(y))
.map(|(id, (x, _))| (id, x))
.map(|(id, x, y)| (id, x, (y - entropy).abs()))
.sorted_unstable_by(|(_, _, x), (_, _, y)| x.total_cmp(y))
.map(|(id, x, _)| (id, x))
.take(params.top_k)
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
if *cum > params.tau {
Expand Down

0 comments on commit 116f885

Please sign in to comment.