From da23d14b8fca0696a091bf158c16f64a75ad66a1 Mon Sep 17 00:00:00 2001 From: MarcoDiFrancesco Date: Fri, 3 May 2024 16:16:25 +0200 Subject: [PATCH] Add synthetic dataset download --- src/datasets/keystroke.rs | 2 +- src/datasets/synthetic.rs | 10 ++++++++-- src/datasets/utils.rs | 11 ++++++----- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/datasets/keystroke.rs b/src/datasets/keystroke.rs index e96c0d3..b1cdac8 100644 --- a/src/datasets/keystroke.rs +++ b/src/datasets/keystroke.rs @@ -23,7 +23,7 @@ impl Keystroke { let file_name = "keystroke.csv"; if !Path::new(file_name).exists() { - utils::download_csv_file(url, file_name)? + utils::download_csv_file(url, file_name); } let file = File::open(file_name)?; let y_cols = Some(Target::Name("subject".to_string())); diff --git a/src/datasets/synthetic.rs b/src/datasets/synthetic.rs index 1e0c80d..8bd8f76 100644 --- a/src/datasets/synthetic.rs +++ b/src/datasets/synthetic.rs @@ -1,6 +1,9 @@ use crate::stream::data_stream::Target; use crate::stream::iter_csv::IterCsv; use std::fs::File; +use std::path::Path; + +use super::utils; /// ChatGPT Generated synthetic dataset. /// @@ -8,8 +11,11 @@ use std::fs::File; pub struct Synthetic; impl Synthetic { pub fn load_data() -> IterCsv { - // let file_name = "syntetic_dataset_paper.csv"; - let file_name = "syntetic_dataset_int.csv"; + let url = "https://marcodifrancesco.com/assets/img/LightRiver/syntetic_dataset.csv"; + let file_name = "syntetic_dataset.csv"; + if !Path::new(file_name).exists() { + utils::download_csv_file(url, file_name); + } let file = File::open(file_name).unwrap(); let y_cols = Some(Target::Name("label".to_string())); IterCsv::::new(file, y_cols).unwrap() diff --git a/src/datasets/utils.rs b/src/datasets/utils.rs index 4e8f938..bda89ed 100644 --- a/src/datasets/utils.rs +++ b/src/datasets/utils.rs @@ -1,5 +1,6 @@ use reqwest::blocking::Client; use std::fs::File; +use std::io; use std::path::Path; use zip::ZipArchive; @@ -31,9 +32,9 @@ pub(crate) fn download_zip_file( Ok(()) } -pub(crate) fn download_csv_file( - url: &str, - file_name: &str, -) -> Result<(), Box> { - unimplemented!("For now download the file in the root directory of the project and rename it to 'keystroke.csv'"); +pub(crate) fn download_csv_file(url: &str, file_name: &str) { + let resp = reqwest::blocking::get(url).expect("request failed"); + let body = resp.text().expect("body invalid"); + let mut out = File::create(file_name).expect("failed to create file"); + io::copy(&mut body.as_bytes(), &mut out).expect("failed to copy content"); }