diff --git a/src/bottomup.rs b/src/bottomup.rs index b55ca5b..7541b8f 100644 --- a/src/bottomup.rs +++ b/src/bottomup.rs @@ -159,15 +159,15 @@ fn compute_bins( } } -fn brute_force(data: &Data, block_size: usize, deg: usize, k: usize) -> Vec { +fn brute_force(data: &Data, block_size: usize, k: usize, top: usize) -> Vec { let mut hists: Vec> = Vec::new(); for i in 0..block_size { let ones = multi_eval(&[i], data); hists.push(vec![(data._num_of_blocks as usize) - ones, ones]) } - for d in 2..deg { - let mut new_hists = Vec::with_capacity(2_usize.pow(deg as u32)); + for d in 2..k { + let mut new_hists = Vec::with_capacity(2_usize.pow(k as u32)); for bits in (0..block_size).combinations(d) { let mut bins = vec![0; 2_usize.pow(d as u32)]; @@ -177,17 +177,28 @@ fn brute_force(data: &Data, block_size: usize, deg: usize, k: usize) -> Vec 1 { + let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); top]; + let mut bins = vec![0; 2_usize.pow(k as u32)]; + for bits in (0..block_size).combinations(k) { + compute_bins(&bits, data, k, &hists, &mut bins, block_size); + let hist = Histogram::from_bins(bits, &bins); + best_hists.push(hist); + best_hists.sort_by(|a, b| b.z_score.abs().partial_cmp(&a.z_score.abs()).unwrap()); + best_hists.pop(); + } + best_hists + } else { + let bits = (0..block_size).combinations(k).collect_vec(); + let mut best: Vec<_> = hists + .into_iter() + .enumerate() + .map(|(i, bins)| Histogram::from_bins(bits[i].clone(), &bins)) + .collect(); + + best.sort_by(|a, b| b.z_score.partial_cmp(&a.z_score).unwrap()); + best.into_iter().take(top).collect() } - - best_hists } fn _combine_bins(hists: &[Histogram], n: usize, data: &[Vec]) -> Histogram { @@ -209,18 +220,18 @@ fn _combine_bins(hists: &[Histogram], n: usize, data: &[Vec]) -> Histogram { pub(crate) fn bottomup( data: &[Vec], block_size: usize, - base_degree: usize, k: usize, + top: usize, max_bits: usize, threads: usize, ) -> Histogram { let mut top_k = if threads == 0 { - brute_force(&transform_data(data), block_size, base_degree, k) + brute_force(&transform_data(data), block_size, k, top) } else { - brute_force_threads(&transform_data(data), block_size, base_degree, k, threads) + brute_force_threads(&transform_data(data), block_size, k, top, threads) }; - if max_bits > base_degree { + if max_bits > k { top_k = phase_two(data, block_size, top_k, max_bits); } @@ -315,8 +326,8 @@ pub(crate) fn multi_eval_neg( fn brute_force_threads( data: &Data, block_size: usize, - deg: usize, k: usize, + top: usize, threads: usize, ) -> Vec { rayon::ThreadPoolBuilder::new() @@ -338,12 +349,12 @@ fn brute_force_threads( let mut hists: Vec = (0..threads) .into_par_iter() .map(|i| { - let combs = (0..block_size).combinations(deg).skip(i); + let combs = (0..block_size).combinations(k).skip(i); - let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); k]; + let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); top]; for bits in combs.step_by(threads) { - let mut bins = vec![0; 2_usize.pow(deg as u32)]; + let mut bins = vec![0; 2_usize.pow(k as u32)]; for (i, bin) in bins.iter_mut().enumerate() { *bin = multi_eval_neg(&bits, data, &neg_data, i); } @@ -358,7 +369,7 @@ fn brute_force_threads( .collect(); hists.sort_by(|a, b| b.z_score.abs().partial_cmp(&a.z_score.abs()).unwrap()); - hists = hists.into_iter().take(k).collect(); + hists = hists.into_iter().take(top).collect(); hists }