diff --git a/Cargo.lock b/Cargo.lock index 8a3f8e2..9991ee6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,7 @@ dependencies = [ "serde_json", "structopt", "uuid", + "zstd", ] [[package]] @@ -121,6 +122,9 @@ name = "cc" version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +dependencies = [ + "jobserver", +] [[package]] name = "cfg-if" @@ -424,6 +428,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" +[[package]] +name = "jobserver" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af25a77299a7f711a01975c35a6a424eb6862092cc2d6c72c4ed6cbc56dfc1fa" +dependencies = [ + "libc", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1293,3 +1306,32 @@ name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.1+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fd07cbbc53846d9145dbffdf6dd09a7a0aa52be46741825f5c97bdd4f73f12b" +dependencies = [ + "cc", + "libc", +] diff --git a/Cargo.toml b/Cargo.toml index 39c8eab..98a24f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ serde = { version = "1.0", features = ["derive"] } structopt = "0.3" uuid = { version = "0.8", features = [ "v4"] } mimalloc = "0.1.29" +zstd = "0.11.2" [dev-dependencies] pretty_assertions = "1.2.1" diff --git a/src/anonymiser.rs b/src/anonymiser.rs index 6219483..e4e7a12 100644 --- a/src/anonymiser.rs +++ b/src/anonymiser.rs @@ -6,11 +6,12 @@ pub fn anonymise( input_file: String, output_file: String, strategy_file: String, + compress_output: bool, transformer_overrides: TransformerOverrides, ) -> Result<(), std::io::Error> { match strategy_file::read(&strategy_file, transformer_overrides) { Ok(strategies) => { - file_reader::read(input_file, output_file, &strategies)?; + file_reader::read(input_file, output_file, &strategies, compress_output)?; Ok(()) } Err(_) => { @@ -33,6 +34,7 @@ mod tests { "test_files/dump_file.sql".to_string(), "test_files/results.sql".to_string(), "non_existing_strategy_file.json".to_string(), + false, TransformerOverrides::none(), ) .is_ok()); @@ -45,6 +47,7 @@ mod tests { "non_existing_input_file.sql".to_string(), "test_files/results.sql".to_string(), "test_files/strategy.json".to_string(), + false, TransformerOverrides::none(), ) .is_ok()); @@ -56,6 +59,7 @@ mod tests { "test_files/dump_file.sql".to_string(), "test_files/results.sql".to_string(), "test_files/strategy.json".to_string(), + false, TransformerOverrides::none(), ) .is_ok()); diff --git a/src/file_reader.rs b/src/file_reader.rs index ef81c2f..2c28d8b 100644 --- a/src/file_reader.rs +++ b/src/file_reader.rs @@ -11,9 +11,13 @@ pub fn read( input_file_path: String, output_file_path: String, strategies: &Strategies, + compress_output: bool, ) -> Result<(), std::io::Error> { - let output_file = File::create(output_file_path).unwrap(); - let mut file_writer = BufWriter::new(output_file); + let output_file = File::create(output_file_path)?; + let mut file_writer: Box = match compress_output { + true => Box::new(zstd::Encoder::new(output_file, 1)?.auto_finish()), + false => Box::new(BufWriter::new(output_file)), + }; let file_reader = File::open(&input_file_path) .unwrap_or_else(|_| panic!("Input file '{}' does not exist", input_file_path)); @@ -26,21 +30,14 @@ pub fn read( let mut rng = rng::get(); loop { - match reader.read_line(&mut line) { - Ok(bytes_read) => { - if bytes_read == 0 { - break; - } - - let transformed_row = - row_parser::parse(&mut rng, &line, &mut row_parser_state, strategies); - file_writer.write_all(transformed_row.as_bytes())?; - line.clear(); - } - Err(err) => { - return Err(err); - } + let bytes_read = reader.read_line(&mut line)?; + if bytes_read == 0 { + break; } + + let transformed_row = row_parser::parse(&mut rng, &line, &mut row_parser_state, strategies); + file_writer.write_all(transformed_row.as_bytes())?; + line.clear(); } Ok(()) } @@ -49,14 +46,13 @@ pub fn read( mod tests { use super::*; use crate::parsers::strategy_structs::*; + use crate::uncompress::uncompress; use pretty_assertions::assert_eq; use std::collections::HashMap; use std::fs; + use std::path::PathBuf; - #[test] - fn can_read() { - let input_file = "test_files/dump_file.sql".to_string(); - let output_file = "test_files/file_reader_test_results.sql".to_string(); + fn default_strategies() -> Strategies { let mut strategies = Strategies::new(); strategies.insert( "public.orders".to_string(), @@ -90,8 +86,17 @@ mod tests { strategy_tuple("phone_number"), ]), ); + strategies + } + + #[test] + fn can_read() { + let input_file = "test_files/dump_file.sql".to_string(); + let output_file = "test_files/file_reader_test_results.sql".to_string(); + let _ = fs::remove_file(&output_file).ok(); + let strategies = default_strategies(); - assert!(read(input_file.clone(), output_file.clone(), &strategies).is_ok()); + assert!(read(input_file.clone(), output_file.clone(), &strategies, false).is_ok()); let original = fs::read_to_string(&input_file).expect("Something went wrong reading the file"); @@ -102,6 +107,40 @@ mod tests { assert_eq!(original, processed); } + #[test] + fn can_read_and_output_compressed() { + let input_file = "test_files/dump_file.sql".to_string(); + let compressed_file = "test_files/compressed_file_reader_test_results.sql".to_string(); + let uncompressed_file_name = "test_files/uncompressed_file_reader_test_results.sql"; + + let _ = fs::remove_file(&compressed_file); + let _ = fs::remove_file(&uncompressed_file_name); + + let strategies = default_strategies(); + + assert!(read( + input_file.clone(), + compressed_file.clone(), + &strategies, + true + ) + .is_ok()); + + uncompress( + PathBuf::from(&compressed_file), + Some(PathBuf::from(uncompressed_file_name)), + ) + .expect("Should not fail to uncompress!"); + + let original = + fs::read_to_string(&input_file).expect("Something went wrong reading the file"); + + let processed = fs::read_to_string(&uncompressed_file_name) + .expect("Something went wrong reading the file"); + + assert_eq!(original, processed); + } + fn strategy_tuple(column_name: &str) -> (String, ColumnInfo) { ( column_name.to_string(), diff --git a/src/main.rs b/src/main.rs index 4f11da3..58df7d1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,17 @@ mod anonymiser; mod file_reader; -use std::fmt::Write; mod fixer; mod opts; mod parsers; +mod uncompress; + use crate::opts::{Anonymiser, Opts}; use crate::parsers::strategies::Strategies; use crate::parsers::strategy_structs::{MissingColumns, SimpleColumn, TransformerOverrides}; use itertools::Itertools; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; +use std::fmt::Write; use parsers::{db_schema, strategy_file}; use structopt::StructOpt; @@ -27,6 +29,7 @@ fn main() -> Result<(), std::io::Error> { input_file, output_file, strategy_file, + compress_output, allow_potential_pii, allow_commercially_sensitive, } => { @@ -39,6 +42,7 @@ fn main() -> Result<(), std::io::Error> { input_file, output_file, strategy_file, + compress_output, transformer_overrides, )? } @@ -84,6 +88,10 @@ fn main() -> Result<(), std::io::Error> { } } } + Anonymiser::Uncompress { + input_file, + output_file, + } => uncompress::uncompress(input_file, output_file).expect("failed to uncompress"), } Ok(()) } diff --git a/src/opts.rs b/src/opts.rs index c987cdb..dba141a 100644 --- a/src/opts.rs +++ b/src/opts.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use structopt::StructOpt; #[derive(Debug, StructOpt)] #[structopt(name = "Anonymiser", about = "Anonymise your database backups!")] @@ -16,6 +18,9 @@ pub enum Anonymiser { output_file: String, #[structopt(short, long, default_value = "./strategy.json")] strategy_file: String, + /// Compress output using zstd + #[structopt(short, long)] + compress_output: bool, /// Does not transform PotentiallPii data types #[structopt(long)] allow_potential_pii: bool, @@ -49,4 +54,13 @@ pub enum Anonymiser { #[structopt(short, long, env = "DATABASE_URL")] db_url: String, }, + /// Uncompress a zstd sql dump to a file, or stdout if no file specified + Uncompress { + /// Input file (*.sql.zst) + #[structopt(short, long)] + input_file: PathBuf, + /// Output file, will write to standard output if not specified + #[structopt(short, long)] + output_file: Option, + }, } diff --git a/src/uncompress.rs b/src/uncompress.rs new file mode 100644 index 0000000..9854184 --- /dev/null +++ b/src/uncompress.rs @@ -0,0 +1,64 @@ +use std::fs::File; +use std::path::PathBuf; + +pub fn uncompress(input_file: PathBuf, output_file: Option) -> Result<(), std::io::Error> { + let input = File::open(input_file)?; + match output_file { + Some(output) => zstd::stream::copy_decode(input, File::create(output)?), + None => zstd::stream::copy_decode(input, std::io::stdout()), + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use crate::{anonymiser::anonymise, parsers::strategy_structs::TransformerOverrides}; + + use super::uncompress; + + #[test] + fn compress_gives_correct_output() { + let test_dir_path = PathBuf::from("test_files/compress"); + std::fs::create_dir_all(&test_dir_path).unwrap(); + + anonymise( + "test_files/dump_file.sql".to_string(), + "test_files/compress/results.sql".to_string(), + "test_files/strategy.json".to_string(), + false, + TransformerOverrides::none(), + ) + .unwrap(); + + anonymise( + "test_files/dump_file.sql".to_string(), + "test_files/compress/results.sql.zst".to_string(), + "test_files/strategy.json".to_string(), + true, + TransformerOverrides::none(), + ) + .unwrap(); + + uncompress( + PathBuf::from("test_files/compress/results.sql.zst"), + Some(test_dir_path.join("uncompressed.sql")), + ) + .unwrap(); + + // Can't compare actual content because of randomization, but # of lines + // should be the same + assert_eq!( + std::fs::read_to_string("test_files/compress/results.sql") + .unwrap() + .lines() + .count(), + std::fs::read_to_string("test_files/compress/uncompressed.sql") + .unwrap() + .lines() + .count() + ); + + std::fs::remove_dir_all(test_dir_path).unwrap(); + } +}