Skip to content

Commit

Permalink
feat: support template interceptor (#63)
Browse files Browse the repository at this point in the history
* feat: support template interceptor

* fix header

* fix CI
  • Loading branch information
jiacai2050 authored Mar 22, 2024
1 parent 8aada16 commit 3a661b2
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 19 deletions.
2 changes: 2 additions & 0 deletions sqlness/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ readme = { workspace = true }
[dependencies]
async-trait = "0.1"
derive_builder = "0.11"
minijinja = "1"
mysql = { version = "23.0.1", optional = true }
postgres = { version = "0.19.7", optional = true }
prettydiff = { version = "0.6.2", default_features = false }
regex = "1.7.1"
serde_json = "1"
thiserror = "1.0"
toml = "0.5"
walkdir = "2.3"
Expand Down
26 changes: 26 additions & 0 deletions sqlness/examples/interceptor-replace/simple/replace.result
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,29 @@ SELECT 1;

03/14/2012, 01/01/2013 and 07/05/2014;

-- SQLNESS TEMPLATE {"name": "test"}
SELECT * FROM table where name = "{{name}}";

SELECT * FROM table where name = "test";

-- SQLNESS TEMPLATE {"aggr": ["sum", "avg", "count"]}
{% for item in aggr %}
SELECT {{item}}(c) from t {%if not loop.last %} {{sql_delimiter()}} {% endif %}
{% endfor %}
;

SELECT sum(c) from t ;

SELECT avg(c) from t ;

SELECT count(c) from t ;

-- SQLNESS TEMPLATE
INSERT INTO t (c) VALUES
{% for num in range(1, 5) %}
({{ num }}) {%if not loop.last %} , {% endif %}
{% endfor %}
;

INSERT INTO t (c) VALUES(1) , (2) , (3) , (4) ;

16 changes: 16 additions & 0 deletions sqlness/examples/interceptor-replace/simple/replace.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,19 @@ SELECT 0;
-- example of capture group replacement
-- SQLNESS REPLACE (?P<y>\d{4})-(?P<m>\d{2})-(?P<d>\d{2}) $m/$d/$y
2012-03-14, 2013-01-01 and 2014-07-05;

-- SQLNESS TEMPLATE {"name": "test"}
SELECT * FROM table where name = "{{name}}";

-- SQLNESS TEMPLATE {"aggr": ["sum", "avg", "count"]}
{% for item in aggr %}
SELECT {{item}}(c) from t {%if not loop.last %} {{sql_delimiter()}} {% endif %}
{% endfor %}
;

-- SQLNESS TEMPLATE
INSERT INTO t (c) VALUES
{% for num in range(1, 5) %}
({{ num }}) {%if not loop.last %} , {% endif %}
{% endfor %}
;
42 changes: 24 additions & 18 deletions sqlness/src/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
};

const COMMENT_PREFIX: &str = "--";
const QUERY_DELIMITER: char = ';';

pub(crate) struct TestCase {
name: String,
Expand Down Expand Up @@ -55,7 +56,7 @@ impl TestCase {
query.append_query_line(&line);

// SQL statement ends with ';'
if line.ends_with(';') {
if line.ends_with(QUERY_DELIMITER) {
queries.push(query);
query = Query::with_interceptor_factories(cfg.interceptor_factories.clone());
} else {
Expand Down Expand Up @@ -88,7 +89,7 @@ impl Display for TestCase {
}

/// A String-to-String map used as query context.
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone)]
pub struct QueryContext {
pub context: HashMap<String, String>,
}
Expand Down Expand Up @@ -137,14 +138,27 @@ impl Query {
W: Write,
{
let context = self.before_execute_intercept();

let mut result = db
.query(context, self.concat_query_lines())
.await
.to_string();

self.after_execute_intercept(&mut result);
self.write_result(writer, result)?;
for comment in &self.comment_lines {
writer.write_all(comment.as_bytes())?;
writer.write_all("\n".as_bytes())?;
}
for comment in &self.display_query {
writer.write_all(comment.as_bytes())?;
}
writer.write_all("\n\n".as_bytes())?;

let sql = self.concat_query_lines();
// An intercetor may generate multiple SQLs, so we need to split them.
for sql in sql.split(QUERY_DELIMITER) {
if !sql.trim().is_empty() {
let mut result = db
.query(context.clone(), format!("{sql};"))
.await
.to_string();
self.after_execute_intercept(&mut result);
self.write_result(writer, result)?;
}
}

Ok(())
}
Expand Down Expand Up @@ -183,14 +197,6 @@ impl Query {
where
W: Write,
{
for comment in &self.comment_lines {
writer.write_all(comment.as_bytes())?;
writer.write("\n".as_bytes())?;
}
for line in &self.display_query {
writer.write_all(line.as_bytes())?;
}
writer.write("\n\n".as_bytes())?;
writer.write_all(result.as_bytes())?;
writer.write("\n\n".as_bytes())?;

Expand Down
4 changes: 3 additions & 1 deletion sqlness/src/interceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ use crate::{
case::QueryContext,
interceptor::{
arg::ArgInterceptorFactory, env::EnvInterceptorFactory, replace::ReplaceInterceptorFactory,
sort_result::SortResultInterceptorFactory,
sort_result::SortResultInterceptorFactory, template::TemplateInterceptorFactory,
},
};

pub mod arg;
pub mod env;
pub mod replace;
pub mod sort_result;
pub mod template;

pub type InterceptorRef = Box<dyn Interceptor>;

Expand All @@ -40,5 +41,6 @@ pub fn builtin_interceptors() -> Vec<InterceptorFactoryRef> {
Arc::new(ReplaceInterceptorFactory {}),
Arc::new(EnvInterceptorFactory {}),
Arc::new(SortResultInterceptorFactory {}),
Arc::new(TemplateInterceptorFactory {}),
]
}
159 changes: 159 additions & 0 deletions sqlness/src/interceptor/template.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2024 CeresDB Project Authors. Licensed under Apache-2.0.

use minijinja::Environment;
use serde_json::Value;

use super::{Interceptor, InterceptorFactory, InterceptorRef};

pub struct TemplateInterceptorFactory;

const PREFIX: &str = "TEMPLATE";

/// Templated query, powered by [minijinja](https://github.com/mitsuhiko/minijinja).
/// The template syntax can be found [here](https://docs.rs/minijinja/latest/minijinja/syntax/index.html).
///
/// Grammar:
/// ``` text
/// -- SQLNESS TEMPLATE <json>
/// ```
///
/// `json` define data bindings passed to template, it should be a valid JSON string.
///
/// # Example
/// `.sql` file:
/// ``` sql
/// -- SQLNESS TEMPLATE {"name": "test"}
/// SELECT * FROM table where name = "{{name}}"
/// ```
///
/// `.result` file:
/// ``` sql
/// -- SQLNESS TEMPLATE {"name": "test"}
/// SELECT * FROM table where name = "test";
/// ```
///
/// In order to generate multiple queries, you can use the builtin function
/// `sql_delimiter()` to insert a delimiter.
///
#[derive(Debug)]
pub struct TemplateInterceptor {
json_ctx: String,
}

fn sql_delimiter() -> Result<String, minijinja::Error> {
Ok(";".to_string())
}

impl Interceptor for TemplateInterceptor {
fn before_execute(&self, execute_query: &mut Vec<String>, _context: &mut crate::QueryContext) {
let input = execute_query.join("\n");
let mut env = Environment::new();
env.add_function("sql_delimiter", sql_delimiter);
env.add_template("sql", &input).unwrap();
let tmpl = env.get_template("sql").unwrap();
let bindings: Value = if self.json_ctx.is_empty() {
serde_json::from_str("{}").unwrap()
} else {
serde_json::from_str(&self.json_ctx).unwrap()
};
let rendered = tmpl.render(bindings).unwrap();
*execute_query = rendered
.split('\n')
.map(|v| v.to_string())
.collect::<Vec<_>>();
}

fn after_execute(&self, _result: &mut String) {}
}

impl InterceptorFactory for TemplateInterceptorFactory {
fn try_new(&self, interceptor: &str) -> Option<InterceptorRef> {
Self::try_new_from_str(interceptor).map(|i| Box::new(i) as _)
}
}

impl TemplateInterceptorFactory {
fn try_new_from_str(interceptor: &str) -> Option<TemplateInterceptor> {
if interceptor.starts_with(PREFIX) {
let json_ctx = interceptor.trim_start_matches(PREFIX).to_string();
Some(TemplateInterceptor { json_ctx })
} else {
None
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn basic_template() {
let interceptor = TemplateInterceptorFactory
.try_new(r#"TEMPLATE {"name": "test"}"#)
.unwrap();

let mut input = vec!["SELECT * FROM table where name = '{{name}}'".to_string()];
interceptor.before_execute(&mut input, &mut crate::QueryContext::default());

assert_eq!(input, vec!["SELECT * FROM table where name = 'test'"]);
}

#[test]
fn vector_template() {
let interceptor = TemplateInterceptorFactory
.try_new(r#"TEMPLATE {"aggr": ["sum", "count", "avg"]}"#)
.unwrap();

let mut input = [
"{%- for item in aggr %}",
"SELECT {{item}}(c) from t;",
"{%- endfor %}",
]
.map(|v| v.to_string())
.to_vec();
interceptor.before_execute(&mut input, &mut crate::QueryContext::default());

assert_eq!(
input,
[
"",
"SELECT sum(c) from t;",
"SELECT count(c) from t;",
"SELECT avg(c) from t;"
]
.map(|v| v.to_string())
.to_vec()
);
}

#[test]
fn range_template() {
let interceptor = TemplateInterceptorFactory.try_new(r#"TEMPLATE"#).unwrap();

let mut input = [
"INSERT INTO t (c) VALUES",
"{%- for num in range(1, 5) %}",
"({{ num }}){%if not loop.last %}, {% endif %}",
"{%- endfor %}",
";",
]
.map(|v| v.to_string())
.to_vec();
interceptor.before_execute(&mut input, &mut crate::QueryContext::default());

assert_eq!(
input,
[
"INSERT INTO t (c) VALUES",
"(1), ",
"(2), ",
"(3), ",
"(4)",
";"
]
.map(|v| v.to_string())
.to_vec()
);
}
}

0 comments on commit 3a661b2

Please sign in to comment.