Skip to content

Commit

Permalink
Arrow Flight SQL example JDBC driver incompatibility (#5666)
Browse files Browse the repository at this point in the history
* feat: handshake returns auth header + clear_token

* fixed location

* Update arrow-flight/examples/flight_sql_server.rs

Cleaner type for response variable

Co-authored-by: Jeffrey Vo <[email protected]>

* removed location for more sensible default behavior

* Removed unused import

* Switched back to 0.0.0.0 IP

---------

Co-authored-by: Jeffrey Vo <[email protected]>
  • Loading branch information
istvan-fodor and Jefffrey authored Apr 24, 2024
1 parent 0230795 commit 11450ae
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
48 changes: 27 additions & 21 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use once_cell::sync::Lazy;
use prost::Message;
use std::collections::HashSet;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use tonic::metadata::MetadataValue;
use tonic::transport::Server;
use tonic::transport::{Certificate, Identity, ServerTlsConfig};
use tonic::{Request, Response, Status, Streaming};
Expand Down Expand Up @@ -52,7 +54,7 @@ use arrow_flight::utils::batches_to_flight_data;
use arrow_flight::{
flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action,
FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
IpcMessage, Location, SchemaAsIpc, Ticket,
IpcMessage, SchemaAsIpc, Ticket,
};
use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{ArrowError, DataType, Field, Schema};
Expand Down Expand Up @@ -184,7 +186,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
};
let result = Ok(result);
let output = futures::stream::iter(vec![result]);
return Ok(Response::new(Box::pin(output)));

let token = format!("Bearer {}", FAKE_TOKEN);
let mut response: Response<Pin<Box<dyn Stream<Item = _> + Send>>> =
Response::new(Box::pin(output));
response.metadata_mut().append(
"authorization",
MetadataValue::from_str(token.as_str()).unwrap(),
);
return Ok(response);
}

async fn do_get_fallback(
Expand Down Expand Up @@ -235,21 +245,20 @@ impl FlightSqlService for FlightSqlServiceImpl {
self.check_token(&request)?;
let handle = std::str::from_utf8(&cmd.prepared_statement_handle)
.map_err(|e| status!("Unable to parse handle", e))?;

let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
let schema = (*batch.schema()).clone();
let num_rows = batch.num_rows();
let num_bytes = batch.get_array_memory_size();
let loc = Location {
uri: "grpc+tcp://127.0.0.1".to_string(),
};

let fetch = FetchResults {
handle: handle.to_string(),
};
let buf = fetch.as_any().encode_to_vec().into();
let ticket = Ticket { ticket: buf };
let endpoint = FlightEndpoint {
ticket: Some(ticket),
location: vec![loc],
location: vec![],
expiration_time: None,
app_metadata: vec![].into(),
};
Expand Down Expand Up @@ -662,9 +671,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
_query: ActionClosePreparedStatementRequest,
_request: Request<Action>,
) -> Result<(), Status> {
Err(Status::unimplemented(
"Implement do_action_close_prepared_statement",
))
Ok(())
}

async fn do_action_create_prepared_substrait_plan(
Expand Down Expand Up @@ -725,9 +732,8 @@ impl FlightSqlService for FlightSqlServiceImpl {
/// This example shows how to run a FlightSql server
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "0.0.0.0:50051".parse()?;

let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
let addr_str = "0.0.0.0:50051";
let addr = addr_str.parse()?;

println!("Listening on {:?}", addr);

Expand All @@ -736,6 +742,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let key = std::fs::read_to_string("arrow-flight/examples/data/server.key")?;
let client_ca = std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?;

let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
let tls_config = ServerTlsConfig::new()
.identity(Identity::from_pem(&cert, &key))
.client_ca_root(Certificate::from_pem(&client_ca));
Expand All @@ -746,6 +753,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.serve(addr)
.await?;
} else {
let svc = FlightServiceServer::new(FlightSqlServiceImpl {});

Server::builder().add_service(svc).serve(addr).await?;
}

Expand Down Expand Up @@ -999,15 +1008,6 @@ mod tests {
.to_string()
.contains("Invalid credentials"));

// forget to set_token
client.handshake("admin", "password").await.unwrap();
assert!(client
.prepare("select 1;".to_string(), None)
.await
.unwrap_err()
.to_string()
.contains("No authorization header"));

// Invalid Tokens
client.handshake("admin", "password").await.unwrap();
client.set_token("wrong token".to_string());
Expand All @@ -1017,6 +1017,12 @@ mod tests {
.unwrap_err()
.to_string()
.contains("invalid token"));

client.clear_token();

// Successful call (token is automatically set by handshake)
client.handshake("admin", "password").await.unwrap();
client.prepare("select 1;".to_string(), None).await.unwrap();
})
.await
}
Expand Down
5 changes: 5 additions & 0 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ impl FlightSqlServiceClient<Channel> {
self.token = Some(token);
}

/// Clear the auth token.
pub fn clear_token(&mut self) {
self.token = None;
}

/// Set header value.
pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
let key: String = key.into();
Expand Down

0 comments on commit 11450ae

Please sign in to comment.