Skip to content

Commit

Permalink
bump tiberius & bb8-tiberius, replaced deprecated queryresult method …
Browse files Browse the repository at this point in the history
…w/ querystream in mssql impl
  • Loading branch information
pangjunrong committed Oct 18, 2024
1 parent 83b4d50 commit e6b197d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 113 deletions.
111 changes: 13 additions & 98 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions connectorx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ chrono = "0.4"
arrow = {workspace = true, optional = true}
arrow2 = {workspace = true, default-features = false, optional = true}
bb8 = {version = "0.7", optional = true}
bb8-tiberius = {version = "0.5", optional = true}
bb8-tiberius = {version = "0.8", optional = true}
csv = {version = "1", optional = true}
fallible-streaming-iterator = {version = "0.1", optional = true}
futures = {version = "0.3", optional = true}
Expand All @@ -50,7 +50,7 @@ regex = {version = "1", optional = true}
rusqlite = {version = "0.30.0", features = ["column_decltype", "chrono", "bundled"], optional = true}
rust_decimal = {version = "1", features = ["db-postgres"], optional = true}
rust_decimal_macros = {version = "1", optional = true}
tiberius = {version = "0.5", features = ["rust_decimal", "chrono", "integrated-auth-gssapi"], optional = true}
tiberius = {version = "0.7.3", features = ["rust_decimal", "chrono", "integrated-auth-gssapi"], optional = true}
tokio = {version = "1", features = ["rt", "rt-multi-thread", "net"], optional = true}
tokio-util = {version = "0.6", optional = true}
urlencoding = {version = "2.1", optional = true}
Expand Down
34 changes: 21 additions & 13 deletions connectorx/src/sources/mssql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use rust_decimal::Decimal;
use sqlparser::dialect::MsSqlDialect;
use std::collections::HashMap;
use std::sync::Arc;
use tiberius::{AuthMethod, Config, EncryptionLevel, QueryResult, Row};
use tiberius::{AuthMethod, Config, EncryptionLevel, QueryItem, QueryStream, Row};
use tokio::runtime::{Handle, Runtime};
use url::Url;
use urlencoding::decode;
Expand Down Expand Up @@ -157,22 +157,27 @@ where
let mut conn = self.rt.block_on(self.pool.get())?;
let first_query = &self.queries[0];
let (names, types) = match self.rt.block_on(conn.query(first_query.as_str(), &[])) {
Ok(stream) => {
let columns = stream.columns().ok_or_else(|| {
anyhow!("MsSQL failed to get the columns of query: {}", first_query)
})?;
columns
Ok(mut stream) => match self.rt.block_on(async { stream.columns().await }) {
Ok(Some(columns)) => columns
.iter()
.map(|col| {
(
col.name().to_string(),
MsSQLTypeSystem::from(&col.column_type()),
)
})
.unzip()
}
.unzip(),
Ok(None) => {
throw!(anyhow!(
"MsSQL returned no columns for query: {}",
first_query
));
}
Err(e) => {
throw!(anyhow!("Error fetching columns: {}", e));
}
},
Err(e) => {
// tried the last query but still get an error
debug!(
"cannot get metadata for '{}', try next query: {}",
first_query, e
Expand Down Expand Up @@ -279,7 +284,7 @@ impl SourcePartition for MsSQLSourcePartition {
#[throws(MsSQLSourceError)]
fn parser<'a>(&'a mut self) -> Self::Parser<'a> {
let conn = self.rt.block_on(self.pool.get())?;
let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryResult<'a>>> =
let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>> =
OwningHandle::new_with_fn(Box::new(conn), |conn: *const Conn<'a>| unsafe {
let conn = &mut *(conn as *mut Conn<'a>);

Expand All @@ -304,7 +309,7 @@ impl SourcePartition for MsSQLSourcePartition {

pub struct MsSQLSourceParser<'a> {
rt: &'a Handle,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryResult<'a>>>,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
rowbuf: Vec<Row>,
ncols: usize,
current_col: usize,
Expand All @@ -315,7 +320,7 @@ pub struct MsSQLSourceParser<'a> {
impl<'a> MsSQLSourceParser<'a> {
fn new(
rt: &'a Handle,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryResult<'a>>>,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
schema: &[MsSQLTypeSystem],
) -> Self {
Self {
Expand Down Expand Up @@ -358,7 +363,10 @@ impl<'a> PartitionParser<'a> for MsSQLSourceParser<'a> {

for _ in 0..DB_BUFFER_SIZE {
if let Some(item) = self.rt.block_on(self.iter.next()) {
self.rowbuf.push(item?);
match item.map_err(MsSQLSourceError::MsSQLError)? {
QueryItem::Row(row) => self.rowbuf.push(row),
_ => continue,
}
} else {
self.is_finished = true;
break;
Expand Down

0 comments on commit e6b197d

Please sign in to comment.