From 559d06dc0713e8ecf3f3e9c4a489db5f50d64232 Mon Sep 17 00:00:00 2001 From: hgrsd Date: Tue, 2 Apr 2024 17:22:27 +0100 Subject: [PATCH] deal with small data sets --- src/infer.rs | 66 +++++++++++++++++++++++++++++++++++--------------- src/produce.rs | 49 +++++++++++++++++++++++++------------ src/schema.rs | 44 +++++++++++++++++++++------------ 3 files changed, 108 insertions(+), 51 deletions(-) diff --git a/src/infer.rs b/src/infer.rs index c74dedd..cadc788 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -31,12 +31,14 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState { // --- String merging --- ( SchemaState::String(StringType::Unknown { - char_distribution: mut charset, + mut chars_seen, + mut strings_seen, min_length, max_length, }), SchemaState::String(StringType::Unknown { - char_distribution: second_charset, + chars_seen: second_chars_seen, + strings_seen: second_strings_seen, min_length: second_min_length, max_length: second_max_length, }), @@ -57,10 +59,12 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState { second_max_length }; - charset.extend(second_charset); + chars_seen.extend(second_chars_seen); + strings_seen.extend(second_strings_seen); SchemaState::String(StringType::Unknown { - char_distribution: charset, + chars_seen, + strings_seen, min_length, max_length, }) @@ -68,7 +72,8 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState { ( SchemaState::String(StringType::Unknown { - char_distribution: charset, + chars_seen, + strings_seen, min_length, max_length, }), @@ -77,12 +82,14 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState { | ( SchemaState::String(_), SchemaState::String(StringType::Unknown { - char_distribution: charset, + chars_seen, + strings_seen, min_length, max_length, }), ) => SchemaState::String(StringType::Unknown { - char_distribution: charset, + chars_seen, + strings_seen, min_length, max_length, }), @@ -92,7 +99,8 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState { SchemaState::String(first_type) } else { SchemaState::String(StringType::Unknown { - char_distribution: vec![], + chars_seen: vec![], + strings_seen: std::collections::HashSet::new(), min_length: None, max_length: None, }) @@ -285,7 +293,7 @@ fn infer_array_schema(values: &[serde_json::Value]) -> SchemaState { /// /// ``` /// use serde_json::json; -/// use std::collections::HashMap; +/// use std::collections::{HashMap, HashSet}; /// use drivel::{infer_schema, SchemaState, StringType, NumberType}; /// /// // Define a JSON value @@ -301,7 +309,8 @@ fn infer_array_schema(values: &[serde_json::Value]) -> SchemaState { /// SchemaState::Object { /// required: HashMap::from_iter([ /// ("name".to_string(), SchemaState::String(StringType::Unknown { -/// char_distribution: vec!['J', 'o', 'h', 'n'], +/// chars_seen: vec!['J', 'o', 'h', 'n'], +/// strings_seen: HashSet::from_iter(["John".to_string()]), /// min_length: Some(4), /// max_length: Some(4) /// })), @@ -331,7 +340,8 @@ pub fn infer_schema(json: &serde_json::Value) -> SchemaState { StringType::UUID } else { StringType::Unknown { - char_distribution: value.chars().collect(), + chars_seen: value.chars().collect(), + strings_seen: std::collections::HashSet::from_iter([value.clone()]), min_length: Some(value.len()), max_length: Some(value.len()), } @@ -387,7 +397,8 @@ mod tests { assert_eq!( schema, SchemaState::String(StringType::Unknown { - char_distribution: vec!['f', 'o', 'o'], + chars_seen: vec!['f', 'o', 'o'], + strings_seen: std::collections::HashSet::from_iter(["foo".to_owned()]), min_length: Some(3), max_length: Some(3) }) @@ -497,7 +508,8 @@ mod tests { ( "string".to_string(), SchemaState::String(StringType::Unknown { - char_distribution: vec!['f', 'o', 'o'], + chars_seen: vec!['f', 'o', 'o'], + strings_seen: std::collections::HashSet::from_iter(["foo".to_owned()]), min_length: Some(3), max_length: Some(3) }) @@ -520,7 +532,10 @@ mod tests { min_length: 1, max_length: 1, schema: Box::new(SchemaState::String(StringType::Unknown { - char_distribution: vec!['b', 'a', 'z'], + chars_seen: vec!['b', 'a', 'z'], + strings_seen: std::collections::HashSet::from_iter([ + "baz".to_owned() + ]), min_length: Some(3), max_length: Some(3) })) @@ -533,7 +548,10 @@ mod tests { required: std::collections::HashMap::from_iter([( "string".to_owned(), SchemaState::String(StringType::Unknown { - char_distribution: vec!['f', 'o', 'o'], + chars_seen: vec!['f', 'o', 'o'], + strings_seen: std::collections::HashSet::from_iter([ + "foo".to_owned() + ]), min_length: Some(3), max_length: Some(3) }) @@ -573,7 +591,11 @@ mod tests { min_length: 2, max_length: 2, schema: Box::new(SchemaState::String(StringType::Unknown { - char_distribution: vec!['f', 'o', 'o', 'b', 'a', 'r', 'b', 'a', 'r'], + chars_seen: vec!['f', 'o', 'o', 'b', 'a', 'r', 'b', 'a', 'r'], + strings_seen: std::collections::HashSet::from_iter([ + "foo".to_owned(), + "barbar".to_owned(), + ]), min_length: Some(3), max_length: Some(6) })) @@ -592,7 +614,8 @@ mod tests { min_length: 2, max_length: 2, schema: Box::new(SchemaState::String(StringType::Unknown { - char_distribution: vec!['b', 'a', 'r', 'b', 'a', 'r'], + chars_seen: vec!['b', 'a', 'r', 'b', 'a', 'r'], + strings_seen: std::collections::HashSet::from_iter(["barbar".to_owned(),]), min_length: Some(6), max_length: Some(6), })) @@ -689,7 +712,11 @@ mod tests { optional: std::collections::HashMap::from_iter([( "foo".to_owned(), SchemaState::String(StringType::Unknown { - char_distribution: vec!['b', 'a', 'r', 'b', 'a', 'r', 'b', 'a', 'r'], + chars_seen: vec!['b', 'a', 'r', 'b', 'a', 'r', 'b', 'a', 'r'], + strings_seen: std::collections::HashSet::from_iter([ + "bar".to_owned(), + "barbar".to_owned(), + ]), min_length: Some(3), max_length: Some(6) }) @@ -733,7 +760,8 @@ mod tests { max_length: 2, schema: Box::new(SchemaState::Nullable(Box::new(SchemaState::String( StringType::Unknown { - char_distribution: vec!['f', 'o', 'o'], + chars_seen: vec!['f', 'o', 'o'], + strings_seen: std::collections::HashSet::from_iter(["foo".to_owned()]), min_length: Some(3), max_length: Some(3) } diff --git a/src/produce.rs b/src/produce.rs index 211dc45..9c9caba 100644 --- a/src/produce.rs +++ b/src/produce.rs @@ -3,7 +3,7 @@ use fake::{Fake, Faker}; use rand::{random, thread_rng, Rng}; use serde_json::Number; -use crate::{NumberType, SchemaState, StringType}; +use crate::{NumberType, SchemaState, StringType, STRING_VARIANTS_SMALL_DATA_BOUNDARY}; fn produce_inner(schema: &SchemaState, repeat_n: usize, current_depth: usize) -> serde_json::Value { match schema { @@ -37,27 +37,44 @@ fn produce_inner(schema: &SchemaState, repeat_n: usize, current_depth: usize) -> uuid.to_string() } StringType::Unknown { - char_distribution: charset, + chars_seen, + strings_seen, min_length, max_length, } => { - let min = min_length.unwrap_or(0); - let max = max_length.unwrap_or(32); - let take_n = if min != max { - thread_rng().gen_range(min..=max) + // heuristic: if we have only seen a very small set of strings, + // let's not generate a random string but pick one from the set. + // this might be an enum we're dealing with, or otherwise we won't have + // enough data to have a nice sample anyway. + if strings_seen.len() > 0 + && strings_seen.len() < STRING_VARIANTS_SMALL_DATA_BOUNDARY + { + let as_vec = Vec::from_iter(strings_seen); + let choice = as_vec[thread_rng().gen_range(0..as_vec.len())]; + choice.clone() } else { - min - }; + let min = min_length.unwrap_or(0); + let max = max_length.unwrap_or(32); + let take_n = if min != max { + thread_rng().gen_range(min..=max) + } else { + min + }; - if charset.is_empty() { - take_n.fake() - } else { - let mut s = String::with_capacity(take_n); - for _ in 0..take_n { - let idx = thread_rng().gen_range(0..charset.len()); - s.push(charset[idx]); + if chars_seen.is_empty() { + // we have no data at all to go by; generate a totally random string + take_n.fake() + } else { + // otherwise we use the fact that we have collected all characters seen + // to generate a random string with a similar character distribution to the + // input data. + let mut s = String::with_capacity(take_n); + for _ in 0..take_n { + let idx = thread_rng().gen_range(0..chars_seen.len()); + s.push(chars_seen[idx]); + } + s } - s } } }; diff --git a/src/schema.rs b/src/schema.rs index d14c5ae..a608818 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,9 +1,12 @@ use std::fmt::Display; +pub(crate) const STRING_VARIANTS_SMALL_DATA_BOUNDARY: usize = 10; + #[derive(PartialEq, Debug)] pub enum StringType { Unknown { - char_distribution: Vec, + chars_seen: Vec, + strings_seen: std::collections::HashSet, min_length: Option, max_length: Option, }, @@ -17,23 +20,31 @@ impl Display for StringType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let text = match self { StringType::Unknown { - char_distribution: _, + chars_seen: _, + strings_seen, min_length, max_length, } => { - let length = match (min_length, max_length) { - (Some(min), Some(max)) => { - if min != max { - format!("({}-{})", min, max) - } else { - format!("({})", min) + if strings_seen.len() > 0 + && strings_seen.len() <= STRING_VARIANTS_SMALL_DATA_BOUNDARY + { + let as_vec = Vec::from_iter(strings_seen.iter().cloned()); + format!("string (limited dataset: {})", &as_vec.join(",")) + } else { + let length = match (min_length, max_length) { + (Some(min), Some(max)) => { + if min != max { + format!("({}-{})", min, max) + } else { + format!("({})", min) + } } - } - (Some(min), None) => format!("({}-?)", min), - (None, Some(max)) => format!("(?-{})", max), - (None, None) => "(length unknown)".to_string(), - }; - format!("string {}", length) + (Some(min), None) => format!("({}-?)", min), + (None, Some(max)) => format!("(?-{})", max), + (None, None) => "(length unknown)".to_string(), + }; + format!("string {}", length) + } } StringType::IsoDate => "string (date - ISO 8601)".to_owned(), StringType::DateTimeRFC2822 => "string (datetime - RFC 2822)".to_owned(), @@ -190,12 +201,13 @@ impl SchemaState { /// /// ``` /// use drivel::{SchemaState, StringType, NumberType}; - /// use std::collections::HashMap; + /// use std::collections::{HashMap, HashSet}; /// use std::iter::FromIterator; /// /// let required = HashMap::from_iter(vec![ /// ("name".to_string(), SchemaState::String(StringType::Unknown { - /// char_distribution: vec!['a', 'b', 'c'], + /// chars_seen: vec!['a', 'b', 'c'], + /// strings_seen: HashSet::from_iter(["bac", "abc", "baa", "bcc"]) /// min_length: Some(1), /// max_length: Some(10), /// }))