Skip to content

Commit

Permalink
Merge pull request #6 from jirigav/evaluate
Browse files Browse the repository at this point in the history
Evaluate
  • Loading branch information
jirigav authored Jun 25, 2024
2 parents b07a6f2 + 1e4424a commit 1f21b26
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 35 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/target
/data/*
config
/.vscode

47 changes: 46 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cooltest"
version = "0.1.0"
version = "0.1.1"
edition = "2021"


Expand All @@ -11,3 +11,5 @@ rayon="1.7"
clap = { version = "4.3", features = ["derive"] }
itertools="0.11"
pyo3={ version = "0.19", features = ["auto-initialize"] }
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
3 changes: 2 additions & 1 deletion src/bottomup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use itertools::Itertools;
use rayon::iter::*;

use crate::common::{bits_block_eval, multi_eval, transform_data, z_score, Data};
use serde::{Deserialize, Serialize};

#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct Histogram {
pub(crate) bits: Vec<usize>,
pub(crate) sorted_indices: Vec<usize>,
Expand Down
43 changes: 33 additions & 10 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use clap::Parser;
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
use std::fs;

pub(crate) fn z_score(sample_size: usize, positive: usize, p: f64) -> f64 {
((positive as f64) - p * (sample_size as f64)) / f64::sqrt(p * (1.0 - p) * (sample_size as f64))
}

#[derive(Parser, Debug)]
#[derive(Parser, Debug, Serialize, Deserialize, Clone)]
#[command(version)]
pub(crate) struct Args {
/// Path of file with input data.
Expand Down Expand Up @@ -35,8 +36,24 @@ pub(crate) struct Args {
/// Number of threads for multi-thread run. 0 means that efficient single thread implementation is used.
#[arg(short, long, default_value_t = 0)]
pub(crate) threads: usize,

/// Path where json output should be stored. If no path provided, json output is not stored.
#[arg(short, long)]
pub(crate) json: Option<String>,

#[clap(subcommand)]
pub subcommand: Option<SubCommand>,
}

#[derive(Parser, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub(crate) enum SubCommand {
/// Evaluate a given distinguisher on given data and report p-value.
Evaluate {
/// Path of file with distinguisher which should be evaluated.
#[arg(short, long)]
dis_path: String,
},
}
pub(crate) fn bits_block_eval(bits: &[usize], block: &[u8]) -> usize {
let mut result = 0;

Expand Down Expand Up @@ -92,12 +109,19 @@ fn load_data(path: &str, block_size: usize) -> Vec<Vec<u8>> {
data
}

pub(crate) fn prepare_data(data_source: &str, block_size: usize) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
pub(crate) fn prepare_data(
data_source: &str,
block_size: usize,
training_data: bool,
) -> (Vec<Vec<u8>>, Option<Vec<Vec<u8>>>) {
let data = load_data(data_source, block_size);
if !training_data {
(data, None)
} else {
let (tr_data, testing_data) = data.split_at(data.len() / 2);

let (tr_data, testing_data) = data.split_at(data.len() / 2);

(tr_data.to_vec(), testing_data.to_vec())
(tr_data.to_vec(), Some(testing_data.to_vec()))
}
}

/// Returns data transformed into vectors of u64, where i-th u64 contains values of 64 i-th bits of consecutive blocks.
Expand Down Expand Up @@ -139,12 +163,11 @@ pub(crate) fn transform_data(data: &[Vec<u8>]) -> Data {

pub(crate) fn p_value(positive: usize, sample_size: usize, probability: f64) -> f64 {
Python::with_gil(|py| {
let scipy = PyModule::import(py, "scipy").unwrap();
let result: f64 = scipy
.getattr("stats")
.unwrap()
let scipy_stats = PyModule::import(py, "scipy.stats")
.expect("SciPy not installed! Use `pip install scipy` to install the library.");
let result: f64 = scipy_stats
.getattr("binomtest")
.unwrap()
.expect("Scipy binomtest not found! Make sure that your version os SciPy is >=1.7.0.")
.call1((positive, sample_size, probability, "two-sided"))
.unwrap()
.getattr("pvalue")
Expand Down
76 changes: 54 additions & 22 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ use crate::bottomup::bottomup;
use crate::common::{p_value, z_score, Args};
use bottomup::Histogram;
use clap::Parser;
use common::prepare_data;
use common::{prepare_data, SubCommand};
use serde_json::json;
use std::fs::{self, File};
use std::io::Write;
use std::time::Instant;

fn print_results(p_value: f64, z_score: f64, alpha: f64, hist: Histogram, bins: Vec<usize>) {
fn print_results(p_value: f64, z_score: f64, alpha: f64, hist: &Histogram, bins: Vec<usize>) {
println!("----------------------------------------------------------------------");
println!("RESULTS:\n");

Expand Down Expand Up @@ -51,8 +54,44 @@ fn print_results(p_value: f64, z_score: f64, alpha: f64, hist: Histogram, bins:
}
}

fn results(hist: Histogram, testing_data: &[Vec<u8>], args: Args) {
let (count, bins) = hist.evaluate(testing_data);
let prob = 2.0_f64.powf(-(hist.bits.len() as f64));
let z = z_score(
testing_data.len(),
count,
prob * (hist.best_division as f64),
);
let p_val = p_value(
count,
testing_data.len(),
prob * (hist.best_division as f64),
);
print_results(p_val, z, args.alpha, &hist, bins);

if let Some(path) = args.json.clone() {
let mut file =
File::create(&path).unwrap_or_else(|_| panic!("File {} couldn't be created", path));

let output = json!({
"args": args,
"dis": hist,
"result": if p_val < args.alpha {"random"} else {"non-random"},
"p-value": p_val
});

file.write_all(
serde_json::to_string_pretty(&output)
.expect("Failed to produce json!")
.as_bytes(),
)
.unwrap();
}
}

fn run_bottomup(args: Args) {
let (training_data, testing_data) = prepare_data(&args.data_source, args.block);
let (training_data, testing_data) = prepare_data(&args.data_source, args.block, true);
let testing_data = testing_data.unwrap();

let start = Instant::now();
let hist = bottomup(
Expand All @@ -65,29 +104,22 @@ fn run_bottomup(args: Args) {
);
println!("training finished in {:?}", start.elapsed());

let (count, bins) = hist.evaluate(&testing_data);
let prob = 2.0_f64.powf(-(hist.bits.len() as f64));
let z = z_score(
testing_data.len(),
count,
prob * (hist.best_division as f64),
);
print_results(
p_value(
count,
testing_data.len(),
prob * (hist.best_division as f64),
),
z,
args.alpha,
hist,
bins,
);
results(hist, &testing_data, args)
}

fn main() {
let args = Args::parse();
println!("\n{args:?}\n");

run_bottomup(args)
match args.subcommand.clone() {
Some(SubCommand::Evaluate { dis_path }) => {
let contents = fs::read_to_string(&dis_path)
.unwrap_or_else(|_| panic!("Failed to read contents of {}", &dis_path));
let hist: Histogram =
serde_json::from_str(&contents).expect("Invalid distinguisher json!");
let (testing_data, _) = prepare_data(&args.data_source, args.block, false);
results(hist, &testing_data, args)
}
None => run_bottomup(args),
}
}

0 comments on commit 1f21b26

Please sign in to comment.