Skip to content

Commit

Permalink
fix bug for k=1
Browse files: browse the repository at this point in the history
  • Loading branch information
jirigav committed May 14, 2024
1 parent 0f489a1 commit b07a6f2
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions src/bottomup.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -159,15 +159,15 @@ fn compute_bins(
}
}

fn brute_force(data: &Data, block_size: usize, deg: usize, k: usize) -> Vec<Histogram> {
fn brute_force(data: &Data, block_size: usize, k: usize, top: usize) -> Vec<Histogram> {
let mut hists: Vec<Vec<usize>> = Vec::new();
for i in 0..block_size {
let ones = multi_eval(&[i], data);
hists.push(vec![(data._num_of_blocks as usize) - ones, ones])
}

for d in 2..deg {
let mut new_hists = Vec::with_capacity(2_usize.pow(deg as u32));
for d in 2..k {
let mut new_hists = Vec::with_capacity(2_usize.pow(k as u32));

for bits in (0..block_size).combinations(d) {
let mut bins = vec![0; 2_usize.pow(d as u32)];
Expand All @@ -177,17 +177,28 @@ fn brute_force(data: &Data, block_size: usize, deg: usize, k: usize) -> Vec<Hist
}
hists = new_hists;
}
let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); k];
let mut bins = vec![0; 2_usize.pow(deg as u32)];
for bits in (0..block_size).combinations(deg) {
compute_bins(&bits, data, deg, &hists, &mut bins, block_size);
let hist = Histogram::from_bins(bits, &bins);
best_hists.push(hist);
best_hists.sort_by(|a, b| b.z_score.abs().partial_cmp(&a.z_score.abs()).unwrap());
best_hists.pop();
if k > 1 {
let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); top];
let mut bins = vec![0; 2_usize.pow(k as u32)];
for bits in (0..block_size).combinations(k) {
compute_bins(&bits, data, k, &hists, &mut bins, block_size);
let hist = Histogram::from_bins(bits, &bins);
best_hists.push(hist);
best_hists.sort_by(|a, b| b.z_score.abs().partial_cmp(&a.z_score.abs()).unwrap());
best_hists.pop();
}
best_hists
} else {
let bits = (0..block_size).combinations(k).collect_vec();
let mut best: Vec<_> = hists
.into_iter()
.enumerate()
.map(|(i, bins)| Histogram::from_bins(bits[i].clone(), &bins))
.collect();

best.sort_by(|a, b| b.z_score.partial_cmp(&a.z_score).unwrap());
best.into_iter().take(top).collect()
}

best_hists
}

fn _combine_bins(hists: &[Histogram], n: usize, data: &[Vec<u8>]) -> Histogram {
Expand All @@ -209,18 +220,18 @@ fn _combine_bins(hists: &[Histogram], n: usize, data: &[Vec<u8>]) -> Histogram {
pub(crate) fn bottomup(
data: &[Vec<u8>],
block_size: usize,
base_degree: usize,
k: usize,
top: usize,
max_bits: usize,
threads: usize,
) -> Histogram {
let mut top_k = if threads == 0 {
brute_force(&transform_data(data), block_size, base_degree, k)
brute_force(&transform_data(data), block_size, k, top)
} else {
brute_force_threads(&transform_data(data), block_size, base_degree, k, threads)
brute_force_threads(&transform_data(data), block_size, k, top, threads)
};

if max_bits > base_degree {
if max_bits > k {
top_k = phase_two(data, block_size, top_k, max_bits);
}

Expand Down Expand Up @@ -315,8 +326,8 @@ pub(crate) fn multi_eval_neg(
fn brute_force_threads(
data: &Data,
block_size: usize,
deg: usize,
k: usize,
top: usize,
threads: usize,
) -> Vec<Histogram> {
rayon::ThreadPoolBuilder::new()
Expand All @@ -338,12 +349,12 @@ fn brute_force_threads(
let mut hists: Vec<Histogram> = (0..threads)
.into_par_iter()
.map(|i| {
let combs = (0..block_size).combinations(deg).skip(i);
let combs = (0..block_size).combinations(k).skip(i);

let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); k];
let mut best_hists = vec![Histogram::from_bins(vec![0], &[1, 1]); top];

for bits in combs.step_by(threads) {
let mut bins = vec![0; 2_usize.pow(deg as u32)];
let mut bins = vec![0; 2_usize.pow(k as u32)];
for (i, bin) in bins.iter_mut().enumerate() {
*bin = multi_eval_neg(&bits, data, &neg_data, i);
}
Expand All @@ -358,7 +369,7 @@ fn brute_force_threads(
.collect();
hists.sort_by(|a, b| b.z_score.abs().partial_cmp(&a.z_score.abs()).unwrap());

hists = hists.into_iter().take(k).collect();
hists = hists.into_iter().take(top).collect();

hists
}

0 comments on commit b07a6f2

Please sign in to comment.