diff --git a/tests/data/plpgsql_query.json b/tests/data/plpgsql_query.json new file mode 100644 index 0000000..f0f4c9f --- /dev/null +++ b/tests/data/plpgsql_query.json @@ -0,0 +1,91 @@ +[ + { + "PLpgSQL_function": { + "action": { + "PLpgSQL_stmt_block": { + "body": [ + { + "PLpgSQL_stmt_execsql": { + "into": true, + "lineno": 5, + "sqlstmt": { + "PLpgSQL_expr": { + "query": "SELECT details FROM t WHERE col = input" + } + }, + "target": { + "PLpgSQL_row": { + "fields": [ + { + "name": "result", + "varno": 2 + } + ], + "lineno": 5, + "refname": "(unnamed row)" + } + } + } + }, + { + "PLpgSQL_stmt_return": { + "expr": { + "PLpgSQL_expr": { + "query": "result" + } + }, + "lineno": 6 + } + } + ], + "lineno": 4 + } + }, + "datums": [ + { + "PLpgSQL_var": { + "datatype": { + "PLpgSQL_type": { + "typname": "UNKNOWN" + } + }, + "refname": "input" + } + }, + { + "PLpgSQL_var": { + "datatype": { + "PLpgSQL_type": { + "typname": "UNKNOWN" + } + }, + "refname": "found" + } + }, + { + "PLpgSQL_var": { + "datatype": { + "PLpgSQL_type": { + "typname": "jsonb" + } + }, + "lineno": 3, + "refname": "result" + } + }, + { + "PLpgSQL_row": { + "fields": [ + { + "name": "result", + "varno": 2 + } + ], + "lineno": 5, + "refname": "(unnamed row)" + } + } + ] + } + } +] \ No newline at end of file diff --git a/tests/data/simple_plpgsql.json b/tests/data/plpgsql_simple.json similarity index 93% rename from tests/data/simple_plpgsql.json rename to tests/data/plpgsql_simple.json index a1d2115..6af8524 100644 --- a/tests/data/simple_plpgsql.json +++ b/tests/data/plpgsql_simple.json @@ -11,7 +11,7 @@ "query": "v_version IS NULL" } }, - "lineno": 1, + "lineno": 3, "then_body": [ { "PLpgSQL_stmt_return": { @@ -20,7 +20,7 @@ "query": "v_name" } }, - "lineno": 1 + "lineno": 4 } } ] @@ -33,11 +33,11 @@ "query": "v_name || '/' || v_version" } }, - "lineno": 1 + "lineno": 6 } } ], - "lineno": 1 + "lineno": 2 } }, "datums": [ diff --git a/tests/parse_plpgsql_tests.rs b/tests/parse_plpgsql_tests.rs index 353f017..8edafdd 100644 --- a/tests/parse_plpgsql_tests.rs +++ b/tests/parse_plpgsql_tests.rs @@ -1,20 +1,47 @@ +#[macro_use] +mod support; +use support::*; + #[test] fn it_can_parse_a_simple_function() { let result = pg_query::parse_plpgsql( - " \ - CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar) \ - RETURNS varchar AS $$ \ - BEGIN \ - IF v_version IS NULL THEN \ - RETURN v_name; \ - END IF; \ - RETURN v_name || '/' || v_version; \ + " + CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar) + RETURNS varchar AS $$ + BEGIN + IF v_version IS NULL THEN + RETURN v_name; + END IF; + RETURN v_name || '/' || v_version; END; \ - $$ LANGUAGE plpgsql;", + $$ LANGUAGE plpgsql; + ", + ); + assert!(result.is_ok()); + let result = result.unwrap(); + let expected = include_str!("data/plpgsql_simple.json"); + let actual = serde_json::to_string_pretty(&result).unwrap(); + assert_eq!(expected, &actual); +} + +#[test] +fn it_can_parse_a_query_function() { + let result = pg_query::parse_plpgsql( + " + CREATE OR REPLACE FUNCTION fn(input integer) RETURNS jsonb LANGUAGE plpgsql STABLE AS + ' + DECLARE + result jsonb; + BEGIN + SELECT details FROM t INTO result WHERE col = input; + RETURN result; + END; + '; + ", ); assert!(result.is_ok()); let result = result.unwrap(); - let expected = include_str!("data/simple_plpgsql.json"); + let expected = include_str!("data/plpgsql_query.json"); let actual = serde_json::to_string_pretty(&result).unwrap(); assert_eq!(expected, actual); } diff --git a/tests/support.rs b/tests/support.rs index 569a731..0f03869 100644 --- a/tests/support.rs +++ b/tests/support.rs @@ -23,6 +23,16 @@ macro_rules! assert_debug_eq { }; } +macro_rules! assert_eq { + ($left:expr, $right:expr) => { + if let Ok(_diff) = std::env::var("DIFF") { + pretty_assertions::assert_eq!($left, $right); + } else { + std::assert_eq!($left, $right); + } + }; +} + pub fn assert_vec_matches(a: &Vec, b: &Vec) { let matching = a.iter().zip(b.iter()).filter(|&(a, b)| a == b).count(); assert!(matching == a.len() && matching == b.len())