Skip to content

Commit

Permalink
Support recursive references in json schema
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Dec 6, 2024
1 parent 3a8332f commit e4a8e55
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 216 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ repository = "https://github.com/dottxt-ai/outlines-core"

[dependencies]
once_cell = "1.20"
anyhow = "1.0.86"
thiserror = "2.0"
pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
regex = "1.10.6"
Expand Down
51 changes: 51 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,57 @@ pub enum Error {
ByteFallbackProcessorFailed,
}

// TODO: Integrate JsonSchema errors and simplify
#[derive(Error, Debug)]
pub enum JsonSchemaParserError {
#[error("serde json error")]
SerdeJsonError(#[from] serde_json::Error),
#[error("Unsupported JSON Schema structure {0} \nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\nIf it should be supported, please open an issue.")]
UnsupportedJsonSchema(Box<serde_json::Value>),
#[error("'properties' not found or not an object")]
PropertiesNotFound,
#[error("'allOf' must be an array")]
AllOfMustBeAnArray,
#[error("'anyOf' must be an array")]
AnyOfMustBeAnArray,
#[error("'oneOf' must be an array")]
OneOfMustBeAnArray,
#[error("'prefixItems' must be an array")]
PrefixItemsMustBeAnArray,
#[error("Unsupported data type in enum: {0}")]
UnsupportedEnumDataType(Box<serde_json::Value>),
#[error("'enum' must be an array")]
EnumMustBeAnArray,
#[error("Unsupported data type in const: {0}")]
UnsupportedConstDataType(Box<serde_json::Value>),
#[error("'const' key not found in object")]
ConstKeyNotFound,
#[error("'$ref' must be a string")]
RefMustBeAString,
#[error("External references are not supported: {0}")]
ExternalReferencesNotSupported(Box<str>),
#[error("Invalid reference format: {0}")]
InvalidReferenceFormat(Box<str>),
#[error("'type' must be a string")]
TypeMustBeAString,
#[error("Unsupported type: {0}")]
UnsupportedType(Box<str>),
#[error("maxLength must be greater than or equal to minLength")]
MaxBoundError,
#[error("Format {0} is not supported by Outlines")]
StringTypeUnsupportedFormat(Box<str>),
#[error("Invalid reference path: {0}")]
InvalidRefecencePath(Box<str>),
#[error("Ref recusion limit reached: {0}")]
RefRecursionLimitReached(usize),
}

impl JsonSchemaParserError {
/// Returns `true` if this error is the
/// [`RefRecursionLimitReached`](Self::RefRecursionLimitReached) variant
/// (whatever its payload), and `false` for every other variant.
pub fn is_recursion_limit(&self) -> bool {
matches!(self, Self::RefRecursionLimitReached(_))
}
}

#[cfg(feature = "python-bindings")]
impl From<Error> for pyo3::PyErr {
fn from(e: Error) -> Self {
Expand Down
221 changes: 126 additions & 95 deletions src/json_schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,129 +3,45 @@ mod types;

pub use types::*;

use anyhow::Result;
use serde_json::Value;

use crate::JsonSchemaParserError;

type Result<T> = std::result::Result<T, JsonSchemaParserError>;

/// Parses `json` as a JSON document and converts the resulting schema into a
/// regex string.
///
/// `whitespace_pattern` is forwarded unchanged to [`to_regex`].
///
/// # Errors
///
/// Returns an error if `json` cannot be parsed by `serde_json`, or any error
/// reported by [`to_regex`].
pub fn build_regex_from_schema(json: &str, whitespace_pattern: Option<&str>) -> Result<String> {
    let schema = serde_json::from_str::<Value>(json)?;
    to_regex(&schema, whitespace_pattern)
}

/// Converts an already-parsed JSON Schema value into a regex string.
///
/// A `parsing::Parser` is created for `json`. When `whitespace_pattern` is
/// `Some`, it is handed to the parser through `with_whitespace_pattern`;
/// when it is `None`, the parser is used as constructed.
///
/// # Errors
///
/// Returns any error reported by the parser's `to_regex`.
#[allow(clippy::wrong_self_convention)]
pub fn to_regex(json: &Value, whitespace_pattern: Option<&str>) -> Result<String> {
    let mut parser = parsing::Parser::new(json);
    if let Some(pattern) = whitespace_pattern {
        parser = parser.with_whitespace_pattern(pattern);
    }
    // Return the parser's result directly: the stale `Ok(parser.to_regex(json)?)`
    // re-wrap is dropped so the function has a single tail expression.
    parser.to_regex(json)
}

#[cfg(test)]
mod tests {
use super::*;
use regex::Regex;

/// A self-referencing `node` definition, converted with a maximum recursion
/// depth of 1, must produce exactly the expected regex.
#[test]
fn recursive_ref() {
    let schema_src = r##"{
"type": "object",
"properties": {
"node": { "$ref": "#/definitions/node" }
},
"definitions": {
"node": {
"type": "object",
"properties": {
"value": { "type": "integer" },
"next": { "$ref": "#/definitions/node" }
}
}
}
}"##;
    let expected = r#"\{([ ]?"node"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)([ ]?,[ ]?"next"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)([ ]?,[ ]?"next"[ ]?:[ ]?{})?|([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?,)?[ ]?"next"[ ]?:[ ]?{})?[ ]?\})?|([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?,)?[ ]?"next"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)([ ]?,[ ]?"next"[ ]?:[ ]?{})?|([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?,)?[ ]?"next"[ ]?:[ ]?{})?[ ]?\})?[ ]?\})?[ ]?\}"#;

    let schema: Value = serde_json::from_str(schema_src).expect("Can't parse json");
    let mut parser = parsing::Parser::new(&schema).with_max_recursion_depth(1);

    let outcome = parser.to_regex(&schema);
    // Check for success first so a failure prints the error itself.
    assert!(outcome.is_ok(), "{:?}", outcome);
    assert_eq!(expected, outcome.unwrap());
}

/// Two properties that reference the same internal definition must both show
/// up in the generated regex.
#[test]
fn internal_ref_works() {
    let json = r##"
{
"definitions": {
"address": {
"type": "object",
"properties": {
"street": { "type": "string" },
"city": { "type": "string" }
}
}
},
"type": "object",
"properties": {
"home_address": { "$ref": "#/definitions/address" },
"work_address": { "$ref": "#/definitions/address" }
}
}"##;

    let json_value: Value = serde_json::from_str(json).expect("Can't parse json");
    // `expect` prints the actual conversion error on failure, where the
    // previous `_ => unreachable!()` arm discarded it.
    let regex = to_regex(&json_value, None).expect("schema with internal refs should convert");

    assert!(regex.contains("home_address"));
    assert!(regex.contains("work_address"));
}

// #[test]
// fn error_on_indirect_recursion() {
// let json = r##"{
// "type": "object",
// "properties": {
// "node": { "$ref": "#/definitions/node" }
// },
// "definitions": {
// "node": {
// "type": "object",
// "properties": {
// "value": { "type": "integer" },
// "next": { "$ref": "#/definitions/node" }
// }
// }
// }
// }"##;

// let json_value: Value = serde_json::from_str(json).expect("Can't parse json");
// let result = to_regex(&json_value, None, &json_value);

// match result {
// Err(e) => {
// let message = "Recursive references are not supported for now";
// assert_eq!(message, e.to_string());
// }
// _ => unreachable!(),
// }
// }

fn should_match(re: &Regex, value: &str) {
// Asserts that value is fully matched.
match re.find(value) {
Some(matched) => {
assert_eq!(
matched.as_str(),
value,
"Value should match, but does not for: {value}"
"Value should match, but does not for: {value}, re:\n{re}"
);
assert_eq!(matched.range(), 0..value.len());
}
None => unreachable!("Value should match, but does not, in unreachable for: {value}"),
None => unreachable!(
"Value should match, but does not, in unreachable for: {value}, re:\n{re}"
),
}
}

Expand All @@ -135,7 +51,7 @@ mod tests {
assert_ne!(
matched.as_str(),
value,
"Value should NOT match, but does for: {value}"
"Value should NOT match, but does for: {value}, re:\n{re}"
);
assert_ne!(matched.range(), 0..value.len());
}
Expand Down Expand Up @@ -1030,4 +946,119 @@ mod tests {
}
}
}

/// A schema whose array items point back at the root (`"$ref": "#"`) converts
/// without an explicit recursion depth, and the resulting regex accepts
/// documents nested up to level 3 while rejecting level 4.
#[test]
fn direct_recursion_in_array_and_default_behaviour() {
    let schema_src = r##"
{
"type": "object",
"properties": {
"name": { "type": "string" },
"children": {
"type": "array",
"items": { "$ref": "#" }
}
}
}"##;

    let schema: Value = serde_json::from_str(schema_src).expect("Can't parse json");
    let outcome = to_regex(&schema, None);
    assert!(outcome.is_ok(), "{:?}", outcome);
    let re = Regex::new(&outcome.unwrap()).expect("Regex failed");

    // By default three recursion levels are expected to be supported. A level
    // is counted from the point where children begin to have children.
    let accepted = [
        // level 0
        r#"{ "name": "Az"}"#,
        r#"{ "name": "Az", "children": [] }"#,
        r#"{ "name": "Az", "children": [{"name": "Bo"}] }"#,
        // level 1
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [] }] }"#,
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li"}] }] }"#,
        // level 2
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li", "children": [] }] }] }"#,
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li", "children": [{"name": "Ho"}] }] }] }"#,
        // level 3
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li", "children": [{"name": "Ho", "children": [] }] }] }] }"#,
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li", "children": [{"name": "Ho", "children": [{"name": "Ro"}] }] }] }] }"#,
    ];
    let rejected = [
        // level 4
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li", "children": [{"name": "Ho", "children": [{"name": "Ro", "children": [] }] }] }] }] }"#,
        r#"{ "name": "Az", "children": [{"name": "Bo", "children": [{"name": "Li", "children": [{"name": "Ho", "children": [{"name": "Ro", "children": [{"name": "Ks"}] }] }] }] }] }"#,
    ];

    accepted.iter().for_each(|doc| should_match(&re, doc));
    rejected.iter().for_each(|doc| should_not_match(&re, doc));
}

/// A `node` definition that references itself through `next` is converted
/// twice with the same parser, first with a maximum recursion depth of 0 and
/// then of 1, and each regex is compared with its exact expected text.
#[test]
fn indirect_recursion_with_recursion_level_regex_match() {
    let schema_src = r##"{
"type": "object",
"properties": {
"node": { "$ref": "#/definitions/node" }
},
"definitions": {
"node": {
"type": "object",
"properties": {
"value": { "type": "integer" },
"next": { "$ref": "#/definitions/node" }
}
}
}
}"##;
    let expected_depth_0 = r#"\{([ ]?"node"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*))?[ ]?\})?[ ]?\}"#;
    let expected_depth_1 = r#"\{([ ]?"node"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)([ ]?,[ ]?"next"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*))?[ ]?\})?|([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?,)?[ ]?"next"[ ]?:[ ]?\{([ ]?"value"[ ]?:[ ]?(-)?(0|[1-9][0-9]*))?[ ]?\})?[ ]?\})?[ ]?\}"#;

    let schema: Value = serde_json::from_str(schema_src).expect("Can't parse json");

    // Depth 0: no `next` level is expanded.
    let mut parser = parsing::Parser::new(&schema).with_max_recursion_depth(0);
    let shallow = parser.to_regex(&schema);
    assert!(shallow.is_ok(), "{:?}", shallow);
    assert_eq!(expected_depth_0, shallow.unwrap());

    // Readable sketches of the expected shape (whitespace groups omitted),
    // kept to make the logic easier to check by eye.
    //
    // Recursion depth 1:
    // {
    //   ("node":
    //     {
    //       ("value":(-)?(0|[1-9][0-9]*)(,"next":{("value":(-)?(0|[1-9][0-9]*))?})?
    //       |
    //       ("value":(-)?(0|[1-9][0-9]*),)?"next":{("value":(-)?(0|[1-9][0-9]*))?})?
    //     }
    //   )?
    // }
    //
    // Recursion depth 2 (shown for reference; not asserted in this test):
    // {
    //   ("node":
    //     {
    //       ("value":(-)?(0|[1-9][0-9]*)(,"next":{
    //         ("value":(-)?(0|[1-9][0-9]*)(,"next":{("value":(-)?(0|[1-9][0-9]*))?})?
    //         |
    //         ("value":(-)?(0|[1-9][0-9]*),)?"next":{("value":(-)?(0|[1-9][0-9]*))?})?
    //       })?
    //       |
    //       ("value":(-)?(0|[1-9][0-9]*),)?"next":{
    //         ("value":(-)?(0|[1-9][0-9]*)(,"next":{("value":(-)?(0|[1-9][0-9]*))?})?
    //         |
    //         ("value":(-)?(0|[1-9][0-9]*),)?"next":{("value":(-)?(0|[1-9][0-9]*))?})?
    //       })?
    //     }
    //   )?
    // }

    // Depth 1: the same parser is reconfigured rather than rebuilt.
    let mut parser = parser.with_max_recursion_depth(1);
    let nested = parser.to_regex(&schema);
    assert!(nested.is_ok(), "{:?}", nested);
    assert_eq!(expected_depth_1, nested.unwrap());
}
}
Loading

0 comments on commit e4a8e55

Please sign in to comment.