Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added Params for CA Certs in MSSQL Config #698

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 13 additions & 98 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions connectorx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ chrono = "0.4"
arrow = {workspace = true, optional = true}
arrow2 = {workspace = true, default-features = false, optional = true}
bb8 = {version = "0.7", optional = true}
bb8-tiberius = {version = "0.5", optional = true}
bb8-tiberius = {version = "0.8", optional = true}
csv = {version = "1", optional = true}
fallible-streaming-iterator = {version = "0.1", optional = true}
futures = {version = "0.3", optional = true}
Expand All @@ -50,7 +50,7 @@ regex = {version = "1", optional = true}
rusqlite = {version = "0.30.0", features = ["column_decltype", "chrono", "bundled"], optional = true}
rust_decimal = {version = "1", features = ["db-postgres"], optional = true}
rust_decimal_macros = {version = "1", optional = true}
tiberius = {version = "0.5", features = ["rust_decimal", "chrono", "integrated-auth-gssapi"], optional = true}
tiberius = {version = "0.7.3", features = ["rust_decimal", "chrono", "integrated-auth-gssapi"], optional = true}
tokio = {version = "1", features = ["rt", "rt-multi-thread", "net"], optional = true}
tokio-util = {version = "0.6", optional = true}
urlencoding = {version = "2.1", optional = true}
Expand Down
44 changes: 31 additions & 13 deletions connectorx/src/sources/mssql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use rust_decimal::Decimal;
use sqlparser::dialect::MsSqlDialect;
use std::collections::HashMap;
use std::sync::Arc;
use tiberius::{AuthMethod, Config, EncryptionLevel, QueryResult, Row};
use tiberius::{AuthMethod, Config, EncryptionLevel, QueryItem, QueryStream, Row};
use tokio::runtime::{Handle, Runtime};
use url::Url;
use urlencoding::decode;
Expand Down Expand Up @@ -84,6 +84,16 @@ pub fn mssql_config(url: &Url) -> Config {
decode(url.password().unwrap_or(""))?.to_owned(),
));

match params.get("trust_server_certificate") {
Some(v) if v.to_lowercase() == "true" => config.trust_cert(),
_ => {}
};

match params.get("trust_server_certificate_ca") {
Some(v) => config.trust_cert_ca(v),
_ => {}
};

match params.get("encrypt") {
Some(v) if v.to_lowercase() == "true" => config.encryption(EncryptionLevel::Required),
_ => config.encryption(EncryptionLevel::NotSupported),
Expand Down Expand Up @@ -147,22 +157,27 @@ where
let mut conn = self.rt.block_on(self.pool.get())?;
let first_query = &self.queries[0];
let (names, types) = match self.rt.block_on(conn.query(first_query.as_str(), &[])) {
Ok(stream) => {
let columns = stream.columns().ok_or_else(|| {
anyhow!("MsSQL failed to get the columns of query: {}", first_query)
})?;
columns
Ok(mut stream) => match self.rt.block_on(async { stream.columns().await }) {
Ok(Some(columns)) => columns
.iter()
.map(|col| {
(
col.name().to_string(),
MsSQLTypeSystem::from(&col.column_type()),
)
})
.unzip()
}
.unzip(),
Ok(None) => {
throw!(anyhow!(
"MsSQL returned no columns for query: {}",
first_query
));
}
Err(e) => {
throw!(anyhow!("Error fetching columns: {}", e));
}
},
Err(e) => {
// tried the last query but still get an error
debug!(
"cannot get metadata for '{}', try next query: {}",
first_query, e
Expand Down Expand Up @@ -269,7 +284,7 @@ impl SourcePartition for MsSQLSourcePartition {
#[throws(MsSQLSourceError)]
fn parser<'a>(&'a mut self) -> Self::Parser<'a> {
let conn = self.rt.block_on(self.pool.get())?;
let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryResult<'a>>> =
let rows: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>> =
OwningHandle::new_with_fn(Box::new(conn), |conn: *const Conn<'a>| unsafe {
let conn = &mut *(conn as *mut Conn<'a>);

Expand All @@ -294,7 +309,7 @@ impl SourcePartition for MsSQLSourcePartition {

pub struct MsSQLSourceParser<'a> {
rt: &'a Handle,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryResult<'a>>>,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
rowbuf: Vec<Row>,
ncols: usize,
current_col: usize,
Expand All @@ -305,7 +320,7 @@ pub struct MsSQLSourceParser<'a> {
impl<'a> MsSQLSourceParser<'a> {
fn new(
rt: &'a Handle,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryResult<'a>>>,
iter: OwningHandle<Box<Conn<'a>>, DummyBox<QueryStream<'a>>>,
schema: &[MsSQLTypeSystem],
) -> Self {
Self {
Expand Down Expand Up @@ -348,7 +363,10 @@ impl<'a> PartitionParser<'a> for MsSQLSourceParser<'a> {

for _ in 0..DB_BUFFER_SIZE {
if let Some(item) = self.rt.block_on(self.iter.next()) {
self.rowbuf.push(item?);
match item.map_err(MsSQLSourceError::MsSQLError)? {
QueryItem::Row(row) => self.rowbuf.push(row),
_ => continue,
}
} else {
self.is_finished = true;
break;
Expand Down
14 changes: 10 additions & 4 deletions docs/databases/mssql.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ SQLServer does not need to specify protocol.

### MsSQL Connection
```{hint}
By adding `trusted_connection=true` to connection uri parameter, windows authentication will be enabled. Example: `mssql://host:port/db?trusted_connection=true`
By adding `encrypt=true` to connection uri parameter, SQLServer will use SSL encryption. Example: `mssql://host:port/db?encrypt=true&trusted_connection=true`
```
```{hint}
if the user password has special characters, they need to be sanitized. example: `from urllib import parse; password = parse.quote_plus(password)`
```

Expand All @@ -20,6 +16,16 @@ query = 'SELECT * FROM table' # query string
cx.read_sql(conn, query) # read data from MsSQL
```

### Connection Parameters
* By adding `trusted_connection=true` to connection uri parameter, windows authentication will be enabled.
* Example: `mssql://host:port/db?trusted_connection=true`
* By adding `encrypt=true` to connection uri parameter, SQLServer will use SSL encryption.
* Example: `mssql://host:port/db?encrypt=true&trusted_connection=true`
* By adding `trust_server_certificate=true` to connection uri parameter, the SQLServer certificate will not be validated and it is accepted as-is.
* Example: `mssql://host:port/db?trust_server_certificate=true&encrypt=true`
* By adding `trust_server_certificate_ca=/path/to/ca-cert.crt` to connection uri parameter, the SQLServer certificate will be validated against the given CA certificate in addition to the system-truststore.
* Example: `mssql://host:port/db?encrypt=true&trust_server_certificate_ca=/path/to/ca-cert.crt`

### SQLServer-Pandas Type Mapping
| SQLServer Type | Pandas Type | Comment |
|:---------------:|:---------------------------:|:----------------------------------:|
Expand Down
Loading