diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index a8f8d1606506..031628eaa833 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -24,7 +24,9 @@ use once_cell::sync::Lazy; use prost::Message; use std::collections::HashSet; use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; +use tonic::metadata::MetadataValue; use tonic::transport::Server; use tonic::transport::{Certificate, Identity, ServerTlsConfig}; use tonic::{Request, Response, Status, Streaming}; @@ -52,7 +54,7 @@ use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, - IpcMessage, Location, SchemaAsIpc, Ticket, + IpcMessage, SchemaAsIpc, Ticket, }; use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; @@ -184,7 +186,15 @@ impl FlightSqlService for FlightSqlServiceImpl { }; let result = Ok(result); let output = futures::stream::iter(vec![result]); - return Ok(Response::new(Box::pin(output))); + + let token = format!("Bearer {}", FAKE_TOKEN); + let mut response: Response + Send>>> = + Response::new(Box::pin(output)); + response.metadata_mut().append( + "authorization", + MetadataValue::from_str(token.as_str()).unwrap(), + ); + return Ok(response); } async fn do_get_fallback( @@ -235,13 +245,12 @@ impl FlightSqlService for FlightSqlServiceImpl { self.check_token(&request)?; let handle = std::str::from_utf8(&cmd.prepared_statement_handle) .map_err(|e| status!("Unable to parse handle", e))?; + let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; let schema = (*batch.schema()).clone(); let num_rows = batch.num_rows(); let num_bytes = batch.get_array_memory_size(); - let loc = Location { - uri: "grpc+tcp://127.0.0.1".to_string(), - }; + let fetch = FetchResults { handle: handle.to_string(), }; @@ -249,7 +258,7 @@ impl FlightSqlService for FlightSqlServiceImpl { let ticket = Ticket { ticket: buf }; let endpoint = FlightEndpoint { ticket: Some(ticket), - location: vec![loc], + location: vec![], expiration_time: None, app_metadata: vec![].into(), }; @@ -662,9 +671,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: ActionClosePreparedStatementRequest, _request: Request, ) -> Result<(), Status> { - Err(Status::unimplemented( - "Implement do_action_close_prepared_statement", - )) + Ok(()) } async fn do_action_create_prepared_substrait_plan( @@ -725,9 +732,8 @@ impl FlightSqlService for FlightSqlServiceImpl { /// This example shows how to run a FlightSql server #[tokio::main] async fn main() -> Result<(), Box> { - let addr = "0.0.0.0:50051".parse()?; - - let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + let addr_str = "0.0.0.0:50051"; + let addr = addr_str.parse()?; println!("Listening on {:?}", addr); @@ -736,6 +742,7 @@ async fn main() -> Result<(), Box> { let key = std::fs::read_to_string("arrow-flight/examples/data/server.key")?; let client_ca = std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?; + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); let tls_config = ServerTlsConfig::new() .identity(Identity::from_pem(&cert, &key)) .client_ca_root(Certificate::from_pem(&client_ca)); @@ -746,6 +753,8 @@ async fn main() -> Result<(), Box> { .serve(addr) .await?; } else { + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + Server::builder().add_service(svc).serve(addr).await?; } @@ -999,15 +1008,6 @@ mod tests { .to_string() .contains("Invalid credentials")); - // forget to set_token - client.handshake("admin", "password").await.unwrap(); - assert!(client - .prepare("select 1;".to_string(), None) - .await - .unwrap_err() - .to_string() - .contains("No authorization header")); - // Invalid Tokens client.handshake("admin", "password").await.unwrap(); client.set_token("wrong token".to_string()); @@ -1017,6 +1017,12 @@ mod tests { .unwrap_err() .to_string() .contains("invalid token")); + + client.clear_token(); + + // Successful call (token is automatically set by handshake) + client.handshake("admin", "password").await.unwrap(); + client.prepare("select 1;".to_string(), None).await.unwrap(); }) .await } diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 44250fbe63e2..29782a2bc44b 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -97,6 +97,11 @@ impl FlightSqlServiceClient { self.token = Some(token); } + /// Clear the auth token. + pub fn clear_token(&mut self) { + self.token = None; + } + /// Set header value. pub fn set_header(&mut self, key: impl Into, value: impl Into) { let key: String = key.into();