Skip to content

Commit

Permalink
reuse blocks in tr when size is multiple of output
Browse files Browse the repository at this point in the history
  • Loading branch information
jirigav committed Nov 6, 2023
1 parent 7a4e25e commit 45dc4d9
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/bottomup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ pub(crate) fn bottomup(
args: &Args,
) -> Vec<Pattern> {
let mut start = Instant::now();
let top_k = phase_one(data, args.k, args.block_size, args.base_pattern_size);
let top_k = phase_one(data, args.k, args.block_size * args.block_size_multiple, args.base_pattern_size);
println!("phase one {:.2?}", start.elapsed());
start = Instant::now();
let r = phase_two(
Expand All @@ -266,7 +266,7 @@ pub(crate) fn bottomup(
data,
validation_data_option,
args.min_difference,
args.block_size,
args.block_size * args.block_size_multiple,
args.max_bits,
);
println!("phase two {:.2?}", start.elapsed());
Expand Down
49 changes: 41 additions & 8 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use clap::Parser;
use itertools::Itertools;
use pyo3::prelude::*;
use std::fs;

Expand All @@ -16,6 +17,10 @@ pub(crate) struct Args {
#[arg(short, long, default_value_t = 128)]
pub(crate) block_size: usize,

/// If the value is greater than 1, CoolTest looks for distinguisher on block size that is a multiple of block_size and utilizes all such consecutive tuples.
#[arg(long, default_value_t = 1)]
pub(crate) block_size_multiple: usize,

/// Number of explored pattern branches.
#[arg(short, long, default_value_t = 100)]
pub(crate) k: usize,
Expand Down Expand Up @@ -104,22 +109,28 @@ pub(crate) fn multi_eval_count(
multi_eval(bits_signs, bits, tr_data, mask, is_last).count_ones()
}

fn load_data(path: &str, block_size: usize) -> Vec<Vec<u8>> {
let len_of_block_in_bytes = block_size / 8;
fs::read(path)
fn load_data(path: &str, block_size: usize, block_size_multiple: usize) -> Vec<Vec<u8>> {
let len_of_block_in_bytes = (block_size * block_size_multiple) / 8;
let mut data: Vec<_> = fs::read(path)
.unwrap()
.chunks(len_of_block_in_bytes)
.map(<[u8]>::to_vec)
.collect()
.collect();
if data[data.len() - 1].len() != len_of_block_in_bytes {
println!("Data are not aligned with block size, dropping last block!");
data.pop();
}
data
}

pub(crate) fn prepare_data(
data_source: &str,
block_size: usize,
block_size_multiple: usize,
halving: bool,
validation: bool,
) -> (Data, Option<Data>, Option<Data>) {
let data = load_data(data_source, block_size);
let data = load_data(data_source, block_size, block_size_multiple);
let training_data;
let mut testing_data_option = None;
let mut validation_data_option = None;
Expand All @@ -129,17 +140,39 @@ pub(crate) fn prepare_data(
let (val_data, test_data) = testing_data.split_at(testing_data.len() / 2);
testing_data_option = Some(transform_data(test_data.to_vec()));
validation_data_option = Some(transform_data(val_data.to_vec()));
training_data = transform_data(tr_data.to_vec());
training_data = transform_training_data(tr_data.to_vec(), block_size, block_size_multiple);
} else if halving {
let (tr_data, testing_data) = data.split_at(data.len() / 2);
testing_data_option = Some(transform_data(testing_data.to_vec()));
training_data = transform_data(tr_data.to_vec());
training_data = transform_training_data(tr_data.to_vec(), block_size, block_size_multiple);
} else {
training_data = transform_data(data);
training_data = transform_training_data(data, block_size, block_size_multiple);
}
println!("tr {}, te {}", training_data.data.len(), testing_data_option.as_ref().unwrap().data.len());
(training_data, validation_data_option, testing_data_option)
}

fn transform_training_data(data: Vec<Vec<u8>>, block_size: usize, block_size_multiple: usize) -> Data {
if block_size_multiple == 1{
return transform_data(data);
}

let data_flattened: Vec<Vec<u8>> = data.into_iter().flat_map(|x| x).collect_vec().chunks(block_size/8).map(<[u8]>::to_vec).collect();

let mut data_duplicated = Vec::new();

for i in 0..(data_flattened.len() - block_size_multiple + 1) {
let mut block: Vec<u8> = Vec::new();
for j in 0..block_size_multiple {
block.append(&mut data_flattened[i + j].clone());
}
data_duplicated.push(block);

}

transform_data(data_duplicated)
}

/// Returns data transformed into vectors of u64, where i-th u64 contains values of 64 i-th bits of consecutive blocks.
pub(crate) fn transform_data(data: Vec<Vec<u8>>) -> Data {
let mut result = Vec::new();
Expand Down
18 changes: 10 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ fn run_bottomup(args: Args) -> (f64, f64) {
let (training_data, validation_data_option, testing_data_option) = prepare_data(
&args.data_source,
args.block_size,
args.block_size_multiple,
true,
args.validation_and_testing_split,
);
Expand All @@ -116,14 +117,15 @@ fn parse_args(s: Vec<&str>) -> Args {
Args {
data_source: s[0].to_string(),
block_size: s[1].trim().parse().unwrap(),
k: s[2].trim().parse().unwrap(),
min_difference: s[3].trim().parse().unwrap(),
top_n: s[4].trim().parse().unwrap(),
max_bits: Some(s[5].trim().parse().unwrap()),
patterns_combined: s[6].trim().parse().unwrap(),
base_pattern_size: s[7].trim().parse().unwrap(),
validation_and_testing_split: s[8].trim().parse().unwrap(),
hist: s[9].trim().parse().unwrap(),
block_size_multiple: s[2].trim().parse().unwrap(),
k: s[3].trim().parse().unwrap(),
min_difference: s[4].trim().parse().unwrap(),
top_n: s[5].trim().parse().unwrap(),
max_bits: Some(s[6].trim().parse().unwrap()),
patterns_combined: s[7].trim().parse().unwrap(),
base_pattern_size: s[8].trim().parse().unwrap(),
validation_and_testing_split: s[9].trim().parse().unwrap(),
hist: s[10].trim().parse().unwrap(),
config: false,
}
}
Expand Down

0 comments on commit 45dc4d9

Please sign in to comment.