Skip to content

Commit

Permalink
deal with small data sets
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrsd committed Apr 2, 2024
1 parent aff585e commit 559d06d
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 51 deletions.
66 changes: 47 additions & 19 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState {
// --- String merging ---
(
SchemaState::String(StringType::Unknown {
char_distribution: mut charset,
mut chars_seen,
mut strings_seen,
min_length,
max_length,
}),
SchemaState::String(StringType::Unknown {
char_distribution: second_charset,
chars_seen: second_chars_seen,
strings_seen: second_strings_seen,
min_length: second_min_length,
max_length: second_max_length,
}),
Expand All @@ -57,18 +59,21 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState {
second_max_length
};

charset.extend(second_charset);
chars_seen.extend(second_chars_seen);
strings_seen.extend(second_strings_seen);

SchemaState::String(StringType::Unknown {
char_distribution: charset,
chars_seen,
strings_seen,
min_length,
max_length,
})
}

(
SchemaState::String(StringType::Unknown {
char_distribution: charset,
chars_seen,
strings_seen,
min_length,
max_length,
}),
Expand All @@ -77,12 +82,14 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState {
| (
SchemaState::String(_),
SchemaState::String(StringType::Unknown {
char_distribution: charset,
chars_seen,
strings_seen,
min_length,
max_length,
}),
) => SchemaState::String(StringType::Unknown {
char_distribution: charset,
chars_seen,
strings_seen,
min_length,
max_length,
}),
Expand All @@ -92,7 +99,8 @@ fn merge(initial: SchemaState, new: SchemaState) -> SchemaState {
SchemaState::String(first_type)
} else {
SchemaState::String(StringType::Unknown {
char_distribution: vec![],
chars_seen: vec![],
strings_seen: std::collections::HashSet::new(),
min_length: None,
max_length: None,
})
Expand Down Expand Up @@ -285,7 +293,7 @@ fn infer_array_schema(values: &[serde_json::Value]) -> SchemaState {
///
/// ```
/// use serde_json::json;
/// use std::collections::HashMap;
/// use std::collections::{HashMap, HashSet};
/// use drivel::{infer_schema, SchemaState, StringType, NumberType};
///
/// // Define a JSON value
Expand All @@ -301,7 +309,8 @@ fn infer_array_schema(values: &[serde_json::Value]) -> SchemaState {
/// SchemaState::Object {
/// required: HashMap::from_iter([
/// ("name".to_string(), SchemaState::String(StringType::Unknown {
/// char_distribution: vec!['J', 'o', 'h', 'n'],
/// chars_seen: vec!['J', 'o', 'h', 'n'],
/// strings_seen: HashSet::from_iter(["John".to_string()]),
/// min_length: Some(4),
/// max_length: Some(4)
/// })),
Expand Down Expand Up @@ -331,7 +340,8 @@ pub fn infer_schema(json: &serde_json::Value) -> SchemaState {
StringType::UUID
} else {
StringType::Unknown {
char_distribution: value.chars().collect(),
chars_seen: value.chars().collect(),
strings_seen: std::collections::HashSet::from_iter([value.clone()]),
min_length: Some(value.len()),
max_length: Some(value.len()),
}
Expand Down Expand Up @@ -387,7 +397,8 @@ mod tests {
assert_eq!(
schema,
SchemaState::String(StringType::Unknown {
char_distribution: vec!['f', 'o', 'o'],
chars_seen: vec!['f', 'o', 'o'],
strings_seen: std::collections::HashSet::from_iter(["foo".to_owned()]),
min_length: Some(3),
max_length: Some(3)
})
Expand Down Expand Up @@ -497,7 +508,8 @@ mod tests {
(
"string".to_string(),
SchemaState::String(StringType::Unknown {
char_distribution: vec!['f', 'o', 'o'],
chars_seen: vec!['f', 'o', 'o'],
strings_seen: std::collections::HashSet::from_iter(["foo".to_owned()]),
min_length: Some(3),
max_length: Some(3)
})
Expand All @@ -520,7 +532,10 @@ mod tests {
min_length: 1,
max_length: 1,
schema: Box::new(SchemaState::String(StringType::Unknown {
char_distribution: vec!['b', 'a', 'z'],
chars_seen: vec!['b', 'a', 'z'],
strings_seen: std::collections::HashSet::from_iter([
"baz".to_owned()
]),
min_length: Some(3),
max_length: Some(3)
}))
Expand All @@ -533,7 +548,10 @@ mod tests {
required: std::collections::HashMap::from_iter([(
"string".to_owned(),
SchemaState::String(StringType::Unknown {
char_distribution: vec!['f', 'o', 'o'],
chars_seen: vec!['f', 'o', 'o'],
strings_seen: std::collections::HashSet::from_iter([
"foo".to_owned()
]),
min_length: Some(3),
max_length: Some(3)
})
Expand Down Expand Up @@ -573,7 +591,11 @@ mod tests {
min_length: 2,
max_length: 2,
schema: Box::new(SchemaState::String(StringType::Unknown {
char_distribution: vec!['f', 'o', 'o', 'b', 'a', 'r', 'b', 'a', 'r'],
chars_seen: vec!['f', 'o', 'o', 'b', 'a', 'r', 'b', 'a', 'r'],
strings_seen: std::collections::HashSet::from_iter([
"foo".to_owned(),
"barbar".to_owned(),
]),
min_length: Some(3),
max_length: Some(6)
}))
Expand All @@ -592,7 +614,8 @@ mod tests {
min_length: 2,
max_length: 2,
schema: Box::new(SchemaState::String(StringType::Unknown {
char_distribution: vec!['b', 'a', 'r', 'b', 'a', 'r'],
chars_seen: vec!['b', 'a', 'r', 'b', 'a', 'r'],
strings_seen: std::collections::HashSet::from_iter(["barbar".to_owned(),]),
min_length: Some(6),
max_length: Some(6),
}))
Expand Down Expand Up @@ -689,7 +712,11 @@ mod tests {
optional: std::collections::HashMap::from_iter([(
"foo".to_owned(),
SchemaState::String(StringType::Unknown {
char_distribution: vec!['b', 'a', 'r', 'b', 'a', 'r', 'b', 'a', 'r'],
chars_seen: vec!['b', 'a', 'r', 'b', 'a', 'r', 'b', 'a', 'r'],
strings_seen: std::collections::HashSet::from_iter([
"bar".to_owned(),
"barbar".to_owned(),
]),
min_length: Some(3),
max_length: Some(6)
})
Expand Down Expand Up @@ -733,7 +760,8 @@ mod tests {
max_length: 2,
schema: Box::new(SchemaState::Nullable(Box::new(SchemaState::String(
StringType::Unknown {
char_distribution: vec!['f', 'o', 'o'],
chars_seen: vec!['f', 'o', 'o'],
strings_seen: std::collections::HashSet::from_iter(["foo".to_owned()]),
min_length: Some(3),
max_length: Some(3)
}
Expand Down
49 changes: 33 additions & 16 deletions src/produce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use fake::{Fake, Faker};
use rand::{random, thread_rng, Rng};
use serde_json::Number;

use crate::{NumberType, SchemaState, StringType};
use crate::{NumberType, SchemaState, StringType, STRING_VARIANTS_SMALL_DATA_BOUNDARY};

fn produce_inner(schema: &SchemaState, repeat_n: usize, current_depth: usize) -> serde_json::Value {
match schema {
Expand Down Expand Up @@ -37,27 +37,44 @@ fn produce_inner(schema: &SchemaState, repeat_n: usize, current_depth: usize) ->
uuid.to_string()
}
StringType::Unknown {
char_distribution: charset,
chars_seen,
strings_seen,
min_length,
max_length,
} => {
let min = min_length.unwrap_or(0);
let max = max_length.unwrap_or(32);
let take_n = if min != max {
thread_rng().gen_range(min..=max)
// heuristic: if we have only seen a very small set of strings,
// let's not generate a random string but pick one from the set.
// this might be an enum we're dealing with, or otherwise we won't have
// enough data to have a nice sample anyway.
if strings_seen.len() > 0
&& strings_seen.len() < STRING_VARIANTS_SMALL_DATA_BOUNDARY
{
let as_vec = Vec::from_iter(strings_seen);
let choice = as_vec[thread_rng().gen_range(0..as_vec.len())];
choice.clone()
} else {
min
};
let min = min_length.unwrap_or(0);
let max = max_length.unwrap_or(32);
let take_n = if min != max {
thread_rng().gen_range(min..=max)
} else {
min
};

if charset.is_empty() {
take_n.fake()
} else {
let mut s = String::with_capacity(take_n);
for _ in 0..take_n {
let idx = thread_rng().gen_range(0..charset.len());
s.push(charset[idx]);
if chars_seen.is_empty() {
// we have no data at all to go by; generate a totally random string
take_n.fake()
} else {
// otherwise we use the fact that we have collected all characters seen
// to generate a random string with a similar character distribution to the
// input data.
let mut s = String::with_capacity(take_n);
for _ in 0..take_n {
let idx = thread_rng().gen_range(0..chars_seen.len());
s.push(chars_seen[idx]);
}
s
}
s
}
}
};
Expand Down
44 changes: 28 additions & 16 deletions src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::fmt::Display;

pub(crate) const STRING_VARIANTS_SMALL_DATA_BOUNDARY: usize = 10;

#[derive(PartialEq, Debug)]
pub enum StringType {
Unknown {
char_distribution: Vec<char>,
chars_seen: Vec<char>,
strings_seen: std::collections::HashSet<String>,
min_length: Option<usize>,
max_length: Option<usize>,
},
Expand All @@ -17,23 +20,31 @@ impl Display for StringType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let text = match self {
StringType::Unknown {
char_distribution: _,
chars_seen: _,
strings_seen,
min_length,
max_length,
} => {
let length = match (min_length, max_length) {
(Some(min), Some(max)) => {
if min != max {
format!("({}-{})", min, max)
} else {
format!("({})", min)
if strings_seen.len() > 0
&& strings_seen.len() <= STRING_VARIANTS_SMALL_DATA_BOUNDARY
{
let as_vec = Vec::from_iter(strings_seen.iter().cloned());
format!("string (limited dataset: {})", &as_vec.join(","))
} else {
let length = match (min_length, max_length) {
(Some(min), Some(max)) => {
if min != max {
format!("({}-{})", min, max)
} else {
format!("({})", min)
}
}
}
(Some(min), None) => format!("({}-?)", min),
(None, Some(max)) => format!("(?-{})", max),
(None, None) => "(length unknown)".to_string(),
};
format!("string {}", length)
(Some(min), None) => format!("({}-?)", min),
(None, Some(max)) => format!("(?-{})", max),
(None, None) => "(length unknown)".to_string(),
};
format!("string {}", length)
}
}
StringType::IsoDate => "string (date - ISO 8601)".to_owned(),
StringType::DateTimeRFC2822 => "string (datetime - RFC 2822)".to_owned(),
Expand Down Expand Up @@ -190,12 +201,13 @@ impl SchemaState {
///
/// ```
/// use drivel::{SchemaState, StringType, NumberType};
/// use std::collections::HashMap;
/// use std::collections::{HashMap, HashSet};
/// use std::iter::FromIterator;
///
/// let required = HashMap::from_iter(vec![
/// ("name".to_string(), SchemaState::String(StringType::Unknown {
/// char_distribution: vec!['a', 'b', 'c'],
/// chars_seen: vec!['a', 'b', 'c'],
/// strings_seen: HashSet::from_iter(["bac".to_string(), "abc".to_string(), "baa".to_string(), "bcc".to_string()]),
/// min_length: Some(1),
/// max_length: Some(10),
/// }))
Expand Down

0 comments on commit 559d06d

Please sign in to comment.