From 035f0cafc16ff06682d365bc8b41cb72f96c49a0 Mon Sep 17 00:00:00 2001 From: brianheineman Date: Mon, 6 Jan 2025 13:09:56 -0700 Subject: [PATCH] feat: add xml driver --- Cargo.lock | 1 + Cargo.toml | 2 +- README.md | 25 +- datasets/users.xml | 11 + rsql_cli/docs/src/chapter2/drivers/index.md | 17 +- rsql_core/Cargo.toml | 2 + rsql_core/src/commands/drivers.rs | 2 + rsql_drivers/Cargo.toml | 7 + rsql_drivers/src/driver.rs | 4 + rsql_drivers/src/file/driver.rs | 49 +++- rsql_drivers/src/lib.rs | 2 + rsql_drivers/src/xml/driver.rs | 309 ++++++++++++++++++++ rsql_drivers/src/xml/mod.rs | 3 + 13 files changed, 398 insertions(+), 36 deletions(-) create mode 100644 datasets/users.xml create mode 100644 rsql_drivers/src/xml/driver.rs create mode 100644 rsql_drivers/src/xml/mod.rs diff --git a/Cargo.lock b/Cargo.lock index c91f3b8e..12272fd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6478,6 +6478,7 @@ dependencies = [ "polars", "polars-sql", "postgresql_embedded", + "quick-xml 0.37.2", "regex", "reqwest", "rusqlite", diff --git a/Cargo.toml b/Cargo.toml index deba1a22..3842861f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ os_info = "3.9.0" polars = "0.45.1" polars-sql = "0.45.1" postgresql_embedded = "0.17.3" -quick-xml = "0.37.1" +quick-xml = "0.37.2" regex = "1.11.1" reqwest = "0.12.8" rusqlite = "0.30.0" diff --git a/README.md b/README.md index e29709df..f36a83de 100644 --- a/README.md +++ b/README.md @@ -38,18 +38,18 @@ visit the [rsql](https://theseus-rs.github.io/rsql/rsql_cli/) site. 
## Features -| Feature | | -|-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Databases | Arrow, Avro, CockroachDB, CSV, Delimited, DuckDB, JSON, JSONL, LibSQL (Turso), MariaDB, MySQL, Parquet, PostgreSQL, Redshift, Snowflake, SQLite3, SQL Server, TSV | -| Syntax Highlighting | ✅ | -| Result Highlighting | ✅ | -| Query Auto-completion | ✅ | -| History | ✅ | -| SQL File Execution | ✅ | -| Embedded PostgreSQL | ✅ | -| Output Formats | ascii, csv, expanded, html, json, jsonl, markdown, plain, psql, sqlite, tsv, unicode, xml, yaml | -| Localized Interface | 40+ languages¹ | -| Key Bindings | emacs, vi | +| Feature | | +|-----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Databases | Arrow, Avro, CockroachDB, CSV, Delimited, DuckDB, JSON, JSONL, LibSQL (Turso), MariaDB, MySQL, Parquet, PostgreSQL, Redshift, Snowflake, SQLite3, SQL Server, TSV, XML | +| Syntax Highlighting | ✅ | +| Result Highlighting | ✅ | +| Query Auto-completion | ✅ | +| History | ✅ | +| SQL File Execution | ✅ | +| Embedded PostgreSQL | ✅ | +| Output Formats | ascii, csv, expanded, html, json, jsonl, markdown, plain, psql, sqlite, tsv, unicode, xml, yaml | +| Localized Interface | 40+ languages¹ | +| Key Bindings | emacs, vi | ¹ Computer translations; human translations welcome @@ -90,6 +90,7 @@ rsql --url "" -- "" | sqlite (sqlx) | `sqlite://[]` | | sqlserver | `sqlserver://[:]@[:]/` | | tsv (polars) | `tsv://[?has_header=]["e=][&skip_rows=]` | +| xml | `xml://` | ¹ the `file` driver will attempt to detect the type of file and automatically use the appropriate driver. 
² `libsql` needs to be enabled with the `libsql` feature flag; it is disabled by default as it conflicts diff --git a/datasets/users.xml b/datasets/users.xml new file mode 100644 index 00000000..8db9062f --- /dev/null +++ b/datasets/users.xml @@ -0,0 +1,11 @@ + + + + 1 + John Doe + + + 2 + Jane Smith + + diff --git a/rsql_cli/docs/src/chapter2/drivers/index.md b/rsql_cli/docs/src/chapter2/drivers/index.md index e00c4421..60ebe91f 100644 --- a/rsql_cli/docs/src/chapter2/drivers/index.md +++ b/rsql_cli/docs/src/chapter2/drivers/index.md @@ -12,19 +12,19 @@ The drivers command displays the available database drivers. | Driver | Description | URL | |---------------|--------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------| -| `arrow` | Arrow IPC file driver provided by [Polars](https://github.com/pola-rs/polars) | `arrow://` | -| `avro` | Avro file driver provided by [Polars](https://github.com/pola-rs/polars) | `avro://` | +| `arrow` | Arrow IPC provided by [Polars](https://github.com/pola-rs/polars) | `arrow://` | +| `avro` | Avro provided by [Polars](https://github.com/pola-rs/polars) | `avro://` | | `cockroachdb` | CockroachDB driver provided by [SQLx](https://github.com/launchbadge/sqlx) | `cockroachdb://[:]@[:]/` | -| `csv` | Comma Separated Value (CSV) file driver provided by [Polars](https://github.com/pola-rs/polars) | `csv://[?has_header=]["e=][&skip_rows=]` | -| `delimited` | Delimited file driver provided by [Polars](https://github.com/pola-rs/polars) | `delimited://[?separator=][&has_header=]["e=][&skip_rows=]` | +| `csv` | Comma Separated Value (CSV) provided by [Polars](https://github.com/pola-rs/polars) | `csv://[?has_header=]["e=][&skip_rows=]` | +| `delimited` | Delimited provided by [Polars](https://github.com/pola-rs/polars) | 
`delimited://[?separator=][&has_header=]["e=][&skip_rows=]` | | `duckdb` | DuckDB provided by [DuckDB](https://duckdb.org/) | `duckdb://[]` | | `file` | File driver | `file://` | -| `json` | JSON file driver provided by [Polars](https://github.com/pola-rs/polars) | `json://` | -| `jsonl` | JSONL file driver provided by [Polars](https://github.com/pola-rs/polars) | `jsonl://` | +| `json` | JSON provided by [Polars](https://github.com/pola-rs/polars) | `json://` | +| `jsonl` | JSONL provided by [Polars](https://github.com/pola-rs/polars) | `jsonl://` | | `libsql` | LibSQL provided by [Turso](https://github.com/tursodatabase/libsql) | `libsql://?[][&file=][&auth_token=]` | | `mariadb` | MariaDB provided by [SQLx](https://github.com/launchbadge/sqlx) | `mariadb://[:]@[:]/` | | `mysql` | MySQL provided by [SQLx](https://github.com/launchbadge/sqlx) | `mysql://[:]@[:]/` | -| `parquet` | Parquet file driver provided by [Polars](https://github.com/pola-rs/polars) | `parquet://` | +| `parquet` | Parquet provided by [Polars](https://github.com/pola-rs/polars) | `parquet://` | | `postgres` | PostgreSQL driver provided by [rust-postgres](https://github.com/sfackler/rust-postgres) | `postgres://[:]@[:]/?` | | `postgresql` | PostgreSQL driver provided by [SQLx](https://github.com/launchbadge/sqlx) | `postgresql://[:]@[:]/?` | | `redshift` | Redshift driver provided by [SQLx](https://github.com/launchbadge/sqlx) | `redshift://[:]@[:]/` | @@ -32,7 +32,8 @@ The drivers command displays the available database drivers. 
| `snowflake` | Snowflake provided by [Snowflake SQL API](https://docs.snowflake.com/en/developer-guide/sql-api/index) | `snowflake://[:]@.snowflakecomputing.com/[?private_key_file=pkey_file&public_key_file=pubkey_file]` | | `sqlite` | SQLite provided by [SQLx](https://github.com/launchbadge/sqlx) | `sqlite://[]` | | `sqlserver` | SQL Server provided by [Tiberius](https://github.com/prisma/tiberius) | `sqlserver://[:]@[:]/` | -| `tsv` | Tab Separated Value (TSV) file driver provided by [Polars](https://github.com/pola-rs/polars) | `tsv://[?has_header=]["e=][&skip_rows=]` | +| `tsv` | Tab Separated Value (TSV) provided by [Polars](https://github.com/pola-rs/polars) | `tsv://[?has_header=]["e=][&skip_rows=]` | +| `xml` | Extensible Markup Language (XML) provided by [Polars](https://github.com/pola-rs/polars) | `xml://` | ### Examples diff --git a/rsql_core/Cargo.toml b/rsql_core/Cargo.toml index 988ed7d0..f4a77e5a 100644 --- a/rsql_core/Cargo.toml +++ b/rsql_core/Cargo.toml @@ -78,6 +78,7 @@ all-drivers = [ "driver-sqlite", "driver-sqlserver", "driver-tsv", + "driver-xml", ] driver-arrow = ["rsql_drivers/arrow"] driver-avro = ["rsql_drivers/avro"] @@ -100,6 +101,7 @@ driver-snowflake = ["rsql_drivers/snowflake"] driver-sqlite = ["rsql_drivers/sqlite"] driver-sqlserver = ["rsql_drivers/sqlserver"] driver-tsv = ["rsql_drivers/tsv"] +driver-xml = ["rsql_drivers/xml"] all-formats = [ "format-ascii", "format-csv", diff --git a/rsql_core/src/commands/drivers.rs b/rsql_core/src/commands/drivers.rs index 5e55f7a6..0db6154f 100644 --- a/rsql_core/src/commands/drivers.rs +++ b/rsql_core/src/commands/drivers.rs @@ -121,6 +121,8 @@ mod tests { "sqlserver", #[cfg(feature = "driver-tsv")] "tsv", + #[cfg(feature = "driver-xml")] + "xml", ]; let available_drivers = drivers.join(", "); diff --git a/rsql_drivers/Cargo.toml b/rsql_drivers/Cargo.toml index f40947aa..cdca8b87 100644 --- a/rsql_drivers/Cargo.toml +++ b/rsql_drivers/Cargo.toml @@ -29,6 +29,7 @@ num-format = { workspace = 
true } polars = { workspace = true, optional = true, features = ["avro", "ipc", "lazy", "json", "parquet", "polars-sql"] } polars-sql = { workspace = true, optional = true } postgresql_embedded = { workspace = true, optional = true } +quick-xml = { workspace = true, optional = true, features = ["serde"] } regex = { workspace = true } reqwest = { workspace = true, optional = true, features = ["json", "gzip"] } rusqlite = { workspace = true, features = ["bundled-full"], optional = true } @@ -87,6 +88,7 @@ all = [ "sqlite", "sqlserver", "tsv", + "xml", ] default = [] arrow = [ @@ -171,6 +173,11 @@ tsv = [ "dep:polars", "dep:polars-sql", ] +xml = [ + "dep:polars", + "dep:polars-sql", + "dep:quick-xml", +] [lints.clippy] unwrap_used = "deny" diff --git a/rsql_drivers/src/driver.rs b/rsql_drivers/src/driver.rs index 0d91cb9b..5fb2de3e 100644 --- a/rsql_drivers/src/driver.rs +++ b/rsql_drivers/src/driver.rs @@ -136,6 +136,8 @@ impl Default for DriverManager { drivers.add(Box::new(crate::sqlserver::Driver)); #[cfg(feature = "tsv")] drivers.add(Box::new(crate::tsv::Driver)); + #[cfg(feature = "xml")] + drivers.add(Box::new(crate::xml::Driver)); drivers } @@ -216,6 +218,8 @@ mod tests { let driver_count = driver_count + 1; #[cfg(feature = "tsv")] let driver_count = driver_count + 1; + #[cfg(feature = "xml")] + let driver_count = driver_count + 1; assert_eq!(driver_manager.drivers.len(), driver_count); } diff --git a/rsql_drivers/src/file/driver.rs b/rsql_drivers/src/file/driver.rs index 7da25cec..eb577585 100644 --- a/rsql_drivers/src/file/driver.rs +++ b/rsql_drivers/src/file/driver.rs @@ -53,33 +53,52 @@ impl crate::Driver for Driver { mod test { use crate::test::dataset_url; use crate::{DriverManager, Value}; + use indoc::indoc; #[tokio::test] async fn test_file_drivers() -> anyhow::Result<()> { let database_urls = vec![ - dataset_url("file", "users.arrow"), - dataset_url("file", "users.avro"), - dataset_url("file", "users.csv"), - dataset_url("file", "users.duckdb"), - 
dataset_url("file", "users.json"), - dataset_url("file", "users.jsonl"), - dataset_url("file", "users.parquet"), - dataset_url("file", "users.sqlite3"), - dataset_url("file", "users.tsv"), + #[cfg(feature = "arrow")] + (dataset_url("file", "users.arrow"), None), + #[cfg(feature = "avro")] + (dataset_url("file", "users.avro"), None), + #[cfg(feature = "csv")] + (dataset_url("file", "users.csv"), None), + #[cfg(feature = "duckdb")] + (dataset_url("file", "users.duckdb"), None), + #[cfg(feature = "json")] + (dataset_url("file", "users.json"), None), + #[cfg(feature = "jsonl")] + (dataset_url("file", "users.jsonl"), None), + #[cfg(feature = "parquet")] + (dataset_url("file", "users.parquet"), None), + #[cfg(feature = "sqlite")] + (dataset_url("file", "users.sqlite3"), None), + #[cfg(feature = "tsv")] + (dataset_url("file", "users.tsv"), None), + #[cfg(feature = "xml")] + ( + dataset_url("file", "users.xml"), + Some(indoc! {r" + WITH cte_user AS ( + SELECT unnest(data.user) FROM users + ) + SELECT user.* FROM cte_user + "}), + ), ]; - for database_url in database_urls { - test_file_driver(database_url.as_str()).await?; + for (database_url, sql) in database_urls { + test_file_driver(database_url.as_str(), sql).await?; } Ok(()) } - async fn test_file_driver(database_url: &str) -> anyhow::Result<()> { + async fn test_file_driver(database_url: &str, sql: Option<&str>) -> anyhow::Result<()> { + let sql = sql.unwrap_or("SELECT id, name FROM users ORDER BY id"); let driver_manager = DriverManager::default(); let mut connection = driver_manager.connect(database_url).await?; - let mut query_result = connection - .query("SELECT id, name FROM users ORDER BY id") - .await?; + let mut query_result = connection.query(sql).await?; assert_eq!(query_result.columns().await, vec!["id", "name"]); assert_eq!( diff --git a/rsql_drivers/src/lib.rs b/rsql_drivers/src/lib.rs index 78850659..16534db7 100644 --- a/rsql_drivers/src/lib.rs +++ b/rsql_drivers/src/lib.rs @@ -67,6 +67,8 @@ mod test; 
mod tsv;
 mod url;
 mod value;
+#[cfg(feature = "xml")]
+mod xml;

 pub use connection::{
     Connection, LimitQueryResult, MemoryQueryResult, MockConnection, QueryResult, StatementMetadata,
diff --git a/rsql_drivers/src/xml/driver.rs b/rsql_drivers/src/xml/driver.rs
new file mode 100644
index 00000000..e300447e
--- /dev/null
+++ b/rsql_drivers/src/xml/driver.rs
@@ -0,0 +1,387 @@
+use crate::error::Result;
+use crate::polars::Connection;
+use crate::url::UrlExtension;
+use crate::Error::{ConversionError, IoError};
+use async_trait::async_trait;
+use polars::io::SerReader;
+use polars::prelude::{IntoLazy, JsonReader};
+use polars_sql::SQLContext;
+use quick_xml::encoding::Decoder;
+use quick_xml::events::{BytesStart, Event};
+use quick_xml::Reader;
+use serde_json::{json, Number, Value};
+use std::collections::HashMap;
+use std::io::Cursor;
+use std::num::NonZeroUsize;
+use tokio::fs::read_to_string;
+use url::Url;
+
+/// Driver that exposes an XML file as a table that can be queried with SQL.
+///
+/// The document is converted to JSON by [`xml_to_json`] and loaded with the Polars JSON reader.
+#[derive(Debug)]
+pub struct Driver;
+
+#[async_trait]
+impl crate::Driver for Driver {
+    fn identifier(&self) -> &'static str {
+        "xml"
+    }
+
+    /// Reads the XML file referenced by `url`, converts it to JSON and registers the resulting
+    /// data frame under the name returned by `get_table_name` for the file.
+    ///
+    /// Supported query parameters:
+    /// - `ignore_errors`: `true` lets the JSON reader ignore errors (default `false`)
+    /// - `infer_schema_length`: value handed to the JSON reader for schema inference (default
+    ///   `100`); `0` is handed on as `None`
+    ///
+    /// # Errors
+    /// Returns an error if the URL is invalid, the file cannot be read, the XML is malformed,
+    /// `infer_schema_length` is not a number, or Polars fails to load the converted data.
+    async fn connect(
+        &self,
+        url: String,
+        _password: Option<String>,
+    ) -> Result<Box<dyn crate::Connection>> {
+        let parsed_url = Url::parse(url.as_str())?;
+        let query_parameters: HashMap<String, String> =
+            parsed_url.query_pairs().into_owned().collect();
+
+        let file_name = parsed_url.to_file()?.to_string_lossy().to_string();
+        let json = {
+            let xml = read_to_string(&file_name).await?;
+            let value = xml_to_json(&xml)?;
+            serde_json::to_string(&value).map_err(|error| IoError(error.into()))?
+        };
+
+        let ignore_errors = query_parameters
+            .get("ignore_errors")
+            .is_some_and(|value| value == "true");
+        // `NonZeroUsize::new` maps 0 to `None`, so a zero length needs no special case.
+        let infer_schema_length = match query_parameters.get("infer_schema_length") {
+            Some(length) => NonZeroUsize::new(
+                length
+                    .parse::<usize>()
+                    .map_err(|error| ConversionError(error.to_string()))?,
+            ),
+            None => NonZeroUsize::new(100),
+        };
+
+        let cursor = Cursor::new(json.as_bytes());
+        let data_frame = JsonReader::new(cursor)
+            .infer_schema_len(infer_schema_length)
+            .set_rechunk(true)
+            .with_ignore_errors(ignore_errors)
+            .finish()?;
+
+        let table_name = crate::polars::driver::get_table_name(file_name)?;
+        let mut context = SQLContext::new();
+        context.register(table_name.as_str(), data_frame.lazy());
+
+        let connection = Connection::new(url, context).await?;
+        Ok(Box::new(connection))
+    }
+
+    fn file_media_type(&self) -> Option<&'static str> {
+        Some("text/xml")
+    }
+}
+
+/// Convert XML to JSON so that it can be read by Polars
+///
+/// Elements are mapped as follows:
+/// - attributes become keys prefixed with `@`
+/// - text becomes the element's value, or its `#text` key when the element also has
+///   attributes or children
+/// - elements without attributes, text or children (including `<a/>`) become `null`
+/// - repeated sibling elements are collected into an array
+/// - attribute values and text are typed by [`infer_value`]
+///
+/// Returns `{ "<root name>": <root value> }`, or `{}` if the input ends before a root element
+/// has been closed.
+///
+/// # Errors
+/// Returns an `IoError` if the XML, an attribute or an escape sequence is malformed.
+fn xml_to_json(xml: &str) -> Result<Value> {
+    let mut reader = Reader::from_str(xml);
+    reader.config_mut().trim_text(true);
+    let mut buffer = Vec::new();
+    let mut stack: Vec<(String, HashMap<String, Value>)> = Vec::new();
+
+    loop {
+        match reader.read_event_into(&mut buffer) {
+            Ok(Event::Start(start)) => stack.push(open_element(&start, reader.decoder())?),
+            // A self-closing element (`<a/>`) has no `End` event, so it is opened and closed
+            // in one step.
+            Ok(Event::Empty(start)) => {
+                let (name, content) = open_element(&start, reader.decoder())?;
+                if let Some(document) = close_element(&mut stack, name, content) {
+                    return Ok(document);
+                }
+            }
+            Ok(Event::Text(text)) => {
+                if let Some((_, content)) = stack.last_mut() {
+                    // Report malformed escape sequences instead of silently dropping the text.
+                    let text = text.unescape().map_err(|error| IoError(error.into()))?;
+                    if !text.is_empty() {
+                        content.insert("#text".to_string(), infer_value(&text));
+                    }
+                }
+            }
+            Ok(Event::End(_)) => {
+                if let Some((name, content)) = stack.pop() {
+                    if let Some(document) = close_element(&mut stack, name, content) {
+                        return Ok(document);
+                    }
+                }
+            }
+            Ok(Event::Eof) => break,
+            Err(error) => return Err(IoError(error.into())),
+            _ => (),
+        }
+        buffer.clear();
+    }
+
+    Ok(json!({}))
+}
+
+/// Reads the name and attributes of a start (or self-closing) tag.
+///
+/// Each attribute is stored under its name prefixed with `@`; entity references in the value
+/// are resolved before the value is typed by [`infer_value`].
+fn open_element(
+    start: &BytesStart<'_>,
+    decoder: Decoder,
+) -> Result<(String, HashMap<String, Value>)> {
+    let name = String::from_utf8_lossy(start.name().as_ref()).to_string();
+    let mut content = HashMap::new();
+    for attribute in start.attributes() {
+        let attribute = attribute.map_err(|error| IoError(error.into()))?;
+        let key = String::from_utf8_lossy(attribute.key.as_ref());
+        let text = attribute
+            .decode_and_unescape_value(decoder)
+            .map_err(|error| IoError(error.into()))?;
+        content.insert(format!("@{key}"), infer_value(&text));
+    }
+    Ok((name, content))
+}
+
+/// Finishes an element and attaches its value to the parent, the last entry of `stack`.
+///
+/// Returns the finished document, `{ name: value }`, when `stack` is empty (the element is the
+/// root); otherwise returns `None`.
+fn close_element(
+    stack: &mut [(String, HashMap<String, Value>)],
+    name: String,
+    mut content: HashMap<String, Value>,
+) -> Option<Value> {
+    let only_text = content.len() == 1 && content.contains_key("#text");
+    let value = if content.is_empty() {
+        Value::Null
+    } else if only_text {
+        content.remove("#text").unwrap_or_default()
+    } else {
+        Value::Object(content.into_iter().collect())
+    };
+
+    let Some((_, parent)) = stack.last_mut() else {
+        return Some(json!({ name: value }));
+    };
+    match parent.get_mut(&name) {
+        Some(Value::Array(siblings)) => siblings.push(value),
+        // Second occurrence of this name: turn the single value into an array of both.
+        Some(existing) => {
+            let first = existing.take();
+            *existing = Value::Array(vec![first, value]);
+        }
+        None => {
+            parent.insert(name, value);
+        }
+    }
+    None
+}
+
+/// Infers the JSON type of an attribute value or element text.
+///
+/// The trimmed text becomes, in order of preference:
+/// - an integer, if it parses as `i64` and is not a multi-character string starting with `0`
+/// - a float, if it parses as a finite `f64` and does not start with `0` (unless with `0.`)
+/// - a boolean, if it is `true` or `false`
+/// - a string otherwise, which keeps leading zeros such as in `007` intact
+fn infer_value(text: &str) -> Value {
+    let text = text.trim();
+
+    if let Ok(v) = text.parse::<i64>() {
+        if text.starts_with('0') && text.len() > 1 {
+            return Value::String(text.into());
+        }
+        return Value::Number(Number::from(v));
+    }
+    if let Ok(v) = text.parse::<f64>() {
+        if text.starts_with('0') && !text.starts_with("0.") {
+            return Value::String(text.into());
+        }
+        if let Some(val) = Number::from_f64(v) {
+            return Value::Number(val);
+        }
+    }
+    if let Ok(v) = text.parse::<bool>() {
+        return Value::Bool(v);
+    }
+
+    Value::String(text.into())
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::test::dataset_url;
+    use crate::{DriverManager, Value};
+    use indoc::indoc;
+
+    fn database_url() -> String {
+        dataset_url("xml", "users.xml")
+    }
+
+    #[tokio::test]
+    async fn test_driver_connect() -> anyhow::Result<()> {
+        let database_url = database_url();
+        let driver_manager = DriverManager::default();
+        let mut connection = driver_manager.connect(&database_url).await?;
+        assert_eq!(&database_url, connection.url());
+        connection.close().await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_connection_interface() -> anyhow::Result<()> {
+        let database_url = database_url();
+        let driver_manager = DriverManager::default();
+        let mut connection = driver_manager.connect(&database_url).await?;
+
+        let mut query_result = connection
+            .query(indoc! {r"
+                WITH cte_user AS (
+                    SELECT unnest(data.user) FROM users
+                )
+                SELECT user.* FROM cte_user
+            "})
+            .await?;
+
+        assert_eq!(query_result.columns().await, vec!["id", "name"]);
+        assert_eq!(
+            query_result.next().await,
+            Some(vec![Value::I64(1), Value::String("John Doe".to_string())])
+        );
+        assert_eq!(
+            query_result.next().await,
+            Some(vec![Value::I64(2), Value::String("Jane Smith".to_string())])
+        );
+        assert!(query_result.next().await.is_none());
+
+        connection.close().await?;
+        Ok(())
+    }
+
+    #[test]
+    fn test_xml_to_json() -> Result<()> {
+        let xml = indoc! {r#"
+            <data foo="42">
+                <user score="1.234">
+                    <id>1</id>
+                    <name>John Doe</name>
+                    <email secure="false">john.doe@none.com</email>
+                </user>
+                <user>
+                    <name>Jane Smith</name>
+                    <id>2</id>
+                </user>
+            </data>
+        "#};
+
+        let value = xml_to_json(xml)?;
+        let data = value.get("data").expect("Expected data object");
+        let foo = data.get("@foo").expect("Expected foo attribute");
+        let foo_value = foo
+            .as_i64()
+            .expect("Expected foo attribute to be an integer");
+        assert_eq!(foo_value, 42);
+        let user = data.get("user").expect("Expected user value");
+        let user = user.as_array().expect("Expected user value to be an array");
+        assert_eq!(user.len(), 2);
+        let user1 = user.first().expect("Expected user 1");
+        let score = user1
+            .get("@score")
+            .expect("Expected score attribute")
+            .as_f64()
+            .expect("Expected score attribute to be a float");
+        let diff = score - 1.234;
+        assert!(diff.abs() < 0.01f64);
+        let user1_id = user1
+            .get("id")
+            .expect("Expected id")
+            .as_i64()
+            .expect("Expected id to be an integer");
+        let user1_name = user1
+            .get("name")
+            .expect("Expected name")
+            .as_str()
+            .expect("Expected name to be a string");
+        assert_eq!(user1_id, 1);
+        assert_eq!(user1_name, "John Doe");
+
+        // Test element with text and an attribute
+        let user1_email = user1.get("email").expect("Expected email");
+        let user1_email_secure = user1_email
+            .get("@secure")
+            .expect("Expected secure attribute")
+            .as_bool()
+            .expect("Expected secure attribute to be a boolean");
+        let user1_email = user1_email
+            .get("#text")
+            .expect("Expected email text")
+            .as_str()
+            .expect("Expected email text to be a string");
+        assert!(!user1_email_secure);
+        assert_eq!(user1_email, "john.doe@none.com");
+
+        let user2 = user.last().expect("Expected user 2");
+        let user2_id = user2
+            .get("id")
+            .expect("Expected id")
+            .as_i64()
+            .expect("Expected id to be an integer");
+        let user2_name = user2
+            .get("name")
+            .expect("Expected name")
+            .as_str()
+            .expect("Expected name to be a string");
+        assert_eq!(user2_id, 2);
+        assert_eq!(user2_name, "Jane Smith");
+        Ok(())
+    }
+
+    #[test]
+    fn test_xml_to_json_empty_elements_and_entities() -> Result<()> {
+        let xml = indoc! {r#"
+            <data>
+                <user id="1" name="Tom &amp; Jerry"/>
+                <note/>
+                <title>Q&amp;A</title>
+            </data>
+        "#};
+
+        let value = xml_to_json(xml)?;
+        let data = value.get("data").expect("Expected data object");
+        let user = data.get("user").expect("Expected user value");
+        assert_eq!(user.get("@id"), Some(&json!(1)));
+        assert_eq!(user.get("@name"), Some(&json!("Tom & Jerry")));
+        assert_eq!(data.get("note"), Some(&json!(null)));
+        assert_eq!(data.get("title"), Some(&json!("Q&A")));
+
+        // A self-closing root element is a complete document
+        assert_eq!(xml_to_json("<data/>")?, json!({ "data": null }));
+        Ok(())
+    }
+}
diff --git a/rsql_drivers/src/xml/mod.rs b/rsql_drivers/src/xml/mod.rs
new file mode 100644
index 00000000..dc243eec
--- /dev/null
+++ b/rsql_drivers/src/xml/mod.rs
@@ -0,0 +1,3 @@
+pub mod driver;
+
+pub use driver::Driver;