diff --git a/crates/ai00-core/src/sampler/typical.rs b/crates/ai00-core/src/sampler/typical.rs index 9f6a3ea4..4029a764 100644 --- a/crates/ai00-core/src/sampler/typical.rs +++ b/crates/ai00-core/src/sampler/typical.rs @@ -71,16 +71,16 @@ impl Sampler for TypicalSampler { let probs = probs .iter() - .filter(|&x| *x > 0.0) - .map(|&x| (x, -x.ln())) + .enumerate() + .filter(|(_, &x)| x > 0.0) + .map(|(id, &x)| (id, x, -x.ln())) .collect_vec(); - let entropy = probs.iter().map(|(x, y)| x * y).sum::(); + let entropy = probs.iter().map(|(_, x, y)| x * y).sum::(); let sorted = probs .into_iter() - .map(|(x, y)| (x, (y - entropy).abs())) - .enumerate() - .sorted_unstable_by(|(_, (_, x)), (_, (_, y))| x.total_cmp(y)) - .map(|(id, (x, _))| (id, x)) + .map(|(id, x, y)| (id, x, (y - entropy).abs())) + .sorted_unstable_by(|(_, _, x), (_, _, y)| x.total_cmp(y)) + .map(|(id, x, _)| (id, x)) .take(params.top_k) .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| { if *cum > params.tau {