From 45dc4d93db65f239280d63d51f3f24601ca7fb41 Mon Sep 17 00:00:00 2001 From: jirigav Date: Mon, 6 Nov 2023 15:49:30 +0100 Subject: [PATCH] reuse blocks in tr when size is multiple of output --- src/bottomup.rs | 4 ++-- src/common.rs | 49 +++++++++++++++++++++++++++++++++++++++++-------- src/main.rs | 18 ++++++++++-------- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/src/bottomup.rs b/src/bottomup.rs index 2be80ba..0fa221a 100644 --- a/src/bottomup.rs +++ b/src/bottomup.rs @@ -257,7 +257,7 @@ pub(crate) fn bottomup( args: &Args, ) -> Vec { let mut start = Instant::now(); - let top_k = phase_one(data, args.k, args.block_size, args.base_pattern_size); + let top_k = phase_one(data, args.k, args.block_size * args.block_size_multiple, args.base_pattern_size); println!("phase one {:.2?}", start.elapsed()); start = Instant::now(); let r = phase_two( @@ -266,7 +266,7 @@ pub(crate) fn bottomup( data, validation_data_option, args.min_difference, - args.block_size, + args.block_size * args.block_size_multiple, args.max_bits, ); println!("phase two {:.2?}", start.elapsed()); diff --git a/src/common.rs b/src/common.rs index 5d17ddb..9823bd7 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,4 +1,5 @@ use clap::Parser; +use itertools::Itertools; use pyo3::prelude::*; use std::fs; @@ -16,6 +17,10 @@ pub(crate) struct Args { #[arg(short, long, default_value_t = 128)] pub(crate) block_size: usize, + /// If the value is greater than 1, CoolTest looks for distinguisher on block size that is a multiple of block_size and utilizes all such consecutive tuples. + #[arg(long, default_value_t = 1)] + pub(crate) block_size_multiple: usize, + /// Number of explored pattern branches. #[arg(short, long, default_value_t = 100)] pub(crate) k: usize, @@ -104,22 +109,28 @@ pub(crate) fn multi_eval_count( multi_eval(bits_signs, bits, tr_data, mask, is_last).count_ones() } -fn load_data(path: &str, block_size: usize) -> Vec> { - let len_of_block_in_bytes = block_size / 8; - fs::read(path) +fn load_data(path: &str, block_size: usize, block_size_multiple: usize) -> Vec> { + let len_of_block_in_bytes = (block_size * block_size_multiple) / 8; + let mut data: Vec<_> = fs::read(path) .unwrap() .chunks(len_of_block_in_bytes) .map(<[u8]>::to_vec) - .collect() + .collect(); + if data[data.len() - 1].len() != len_of_block_in_bytes { + println!("Data are not aligned with block size, dropping last block!"); + data.pop(); + } + data } pub(crate) fn prepare_data( data_source: &str, block_size: usize, + block_size_multiple: usize, halving: bool, validation: bool, ) -> (Data, Option, Option) { - let data = load_data(data_source, block_size); + let data = load_data(data_source, block_size, block_size_multiple); let training_data; let mut testing_data_option = None; let mut validation_data_option = None; @@ -129,17 +140,39 @@ pub(crate) fn prepare_data( let (val_data, test_data) = testing_data.split_at(testing_data.len() / 2); testing_data_option = Some(transform_data(test_data.to_vec())); validation_data_option = Some(transform_data(val_data.to_vec())); - training_data = transform_data(tr_data.to_vec()); + training_data = transform_training_data(tr_data.to_vec(), block_size, block_size_multiple); } else if halving { let (tr_data, testing_data) = data.split_at(data.len() / 2); testing_data_option = Some(transform_data(testing_data.to_vec())); - training_data = transform_data(tr_data.to_vec()); + training_data = transform_training_data(tr_data.to_vec(), block_size, block_size_multiple); } else { - training_data = transform_data(data); + training_data = transform_training_data(data, block_size, block_size_multiple); } + println!("tr {}, te {}", training_data.data.len(), testing_data_option.as_ref().unwrap().data.len()); (training_data, validation_data_option, testing_data_option) } +fn transform_training_data(data: Vec>, block_size: usize, block_size_multiple: usize) -> Data { + if block_size_multiple == 1{ + return transform_data(data); + } + + let data_flattened: Vec> = data.into_iter().flat_map(|x| x).collect_vec().chunks(block_size/8).map(<[u8]>::to_vec).collect(); + + let mut data_duplicated = Vec::new(); + + for i in 0..(data_flattened.len() - block_size_multiple + 1) { + let mut block: Vec = Vec::new(); + for j in 0..block_size_multiple { + block.append(&mut data_flattened[i + j].clone()); + } + data_duplicated.push(block); + + } + + transform_data(data_duplicated) +} + /// Returns data transformed into vectors of u64, where i-th u64 contains values of 64 i-th bits of consecutive blocks. pub(crate) fn transform_data(data: Vec>) -> Data { let mut result = Vec::new(); diff --git a/src/main.rs b/src/main.rs index 5948f3c..60db614 100644 --- a/src/main.rs +++ b/src/main.rs @@ -94,6 +94,7 @@ fn run_bottomup(args: Args) -> (f64, f64) { let (training_data, validation_data_option, testing_data_option) = prepare_data( &args.data_source, args.block_size, + args.block_size_multiple, true, args.validation_and_testing_split, ); @@ -116,14 +117,15 @@ fn parse_args(s: Vec<&str>) -> Args { Args { data_source: s[0].to_string(), block_size: s[1].trim().parse().unwrap(), - k: s[2].trim().parse().unwrap(), - min_difference: s[3].trim().parse().unwrap(), - top_n: s[4].trim().parse().unwrap(), - max_bits: Some(s[5].trim().parse().unwrap()), - patterns_combined: s[6].trim().parse().unwrap(), - base_pattern_size: s[7].trim().parse().unwrap(), - validation_and_testing_split: s[8].trim().parse().unwrap(), - hist: s[9].trim().parse().unwrap(), + block_size_multiple: s[2].trim().parse().unwrap(), + k: s[3].trim().parse().unwrap(), + min_difference: s[4].trim().parse().unwrap(), + top_n: s[5].trim().parse().unwrap(), + max_bits: Some(s[6].trim().parse().unwrap()), + patterns_combined: s[7].trim().parse().unwrap(), + base_pattern_size: s[8].trim().parse().unwrap(), + validation_and_testing_split: s[9].trim().parse().unwrap(), + hist: s[10].trim().parse().unwrap(), config: false, } }