Skip to content

Commit

Permalink
rework api
Browse files Browse the repository at this point in the history
  • Loading branch information
suremarc committed Sep 15, 2023
1 parent 80855ec commit 6cfc5c4
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 21 deletions.
10 changes: 5 additions & 5 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use arrow_flight::sql::server::PeekableFlightDataStream;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::stream::Peekable;
use futures::{stream, Stream, TryStreamExt};
use once_cell::sync::Lazy;
use prost::Message;
Expand Down Expand Up @@ -603,15 +603,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan not implemented",
Expand All @@ -621,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
Expand All @@ -631,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update not implemented",
Expand Down
2 changes: 1 addition & 1 deletion arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
match err {
FlightError::Arrow(e) => e,
e => ArrowError::ExternalError(Box::new(e))
e => ArrowError::ExternalError(Box::new(e)),
}
}

Expand Down
100 changes: 93 additions & 7 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use std::pin::Pin;

use futures::{stream::Peekable, Stream};
use futures::{stream::Peekable, Stream, StreamExt};
use prost::Message;
use tonic::{Request, Response, Status, Streaming};

Expand Down Expand Up @@ -366,7 +366,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
/// Implementors may override to handle additional calls to do_put()
async fn do_put_fallback(
&self,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
message: Any,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(format!(
Expand All @@ -379,7 +379,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_statement_update has no default implementation",
Expand All @@ -390,7 +390,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query has no default implementation",
Expand All @@ -401,7 +401,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update has no default implementation",
Expand All @@ -412,7 +412,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_put_substrait_plan(
&self,
_query: CommandStatementSubstraitPlan,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan has no default implementation",
Expand Down Expand Up @@ -696,7 +696,7 @@ where
// To allow the first message to be reused by the `do_put` handler,
// we wrap this stream in a `Peekable` one, which allows us to peek at
// the first message without discarding it.
let mut request = request.map(futures::StreamExt::peekable);
let mut request = request.map(PeekableFlightDataStream::new);
let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;

let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd)
Expand Down Expand Up @@ -965,3 +965,89 @@ fn decode_error_to_status(err: prost::DecodeError) -> Status {
fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {
Status::internal(format!("{err:?}"))
}

/// A wrapper around [`Streaming<FlightData>`] that allows "peeking" at the
/// message at the front of the stream without consuming it.
/// This is needed because sometimes the first message in the stream will contain
/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic
/// must inspect this information.
///
/// # Example
///
/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without
/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream.
/// See the following example:
///
/// ```no_run
/// use arrow_array::RecordBatch;
/// use arrow_flight::decode::FlightRecordBatchStream;
/// use arrow_flight::FlightDescriptor;
/// use arrow_flight::error::FlightError;
/// use arrow_flight::sql::server::PeekableFlightDataStream;
/// use tonic::{Request, Status};
/// use futures::TryStreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Status> {
/// let request: Request<PeekableFlightDataStream> = todo!();
/// let stream: PeekableFlightDataStream = request.into_inner();
///
/// // The first message contains the flight descriptor and the schema.
/// // Read the flight descriptor without discarding the schema:
/// let flight_descriptor: FlightDescriptor = stream
/// .peek()
/// .await
/// .cloned()
/// .transpose()?
/// .and_then(|data| data.flight_descriptor)
/// .expect("first message should contain flight descriptor");
///
/// // Pass the stream through a decoder
/// let batches: Vec<RecordBatch> = FlightRecordBatchStream::new_from_flight_data(
/// request.into_inner().map_err(|e| e.into()),
/// )
/// .try_collect()
/// .await?;
/// }
/// ```
pub struct PeekableFlightDataStream {
inner: Peekable<Streaming<FlightData>>,
}

impl PeekableFlightDataStream {
fn new(stream: Streaming<FlightData>) -> Self {
Self {
inner: stream.peekable(),
}
}

/// Convert this stream into a `Streaming<FlightData>`.
/// Any messages observed through [`Self::peek`] will be lost
/// after the conversion.
pub fn into_inner(self) -> Streaming<FlightData> {
self.inner.into_inner()
}

/// Convert this stream into a `Peekable<Streaming<FlightData>>`.
/// Preserves the state of the stream, so that calls to [`Self::peek`]
/// and [`Self::poll_next`] are the same.
pub fn into_peekable(self) -> Peekable<Streaming<FlightData>> {
self.inner
}

/// Peek at the head of this stream without advancing it.
pub async fn peek(&mut self) -> Option<&Result<FlightData, Status>> {
Pin::new(&mut self.inner).peek().await
}
}

impl Stream for PeekableFlightDataStream {
type Item = Result<FlightData, Status>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx)
}
}
17 changes: 9 additions & 8 deletions arrow-flight/tests/flight_sql_client_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ use arrow_flight::{
decode::FlightRecordBatchStream,
flight_service_server::{FlightService, FlightServiceServer},
sql::{
server::FlightSqlService, ActionBeginSavepointRequest,
ActionBeginSavepointResult, ActionBeginTransactionRequest,
ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
server::{FlightSqlService, PeekableFlightDataStream},
ActionBeginSavepointRequest, ActionBeginSavepointResult,
ActionBeginTransactionRequest, ActionBeginTransactionResult,
ActionCancelQueryRequest, ActionCancelQueryResult,
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
Expand All @@ -43,7 +44,7 @@ use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use assert_cmd::Command;
use bytes::Bytes;
use futures::{stream::Peekable, Stream, StreamExt, TryStreamExt};
use futures::{Stream, StreamExt, TryStreamExt};
use prost::Message;
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{Request, Response, Status, Streaming};
Expand Down Expand Up @@ -505,7 +506,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_statement_update not implemented",
Expand All @@ -515,7 +516,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan not implemented",
Expand All @@ -525,7 +526,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
request: Request<Peekable<Streaming<FlightData>>>,
request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
// just make sure decoding the parameters works
let parameters = FlightRecordBatchStream::new_from_flight_data(
Expand Down Expand Up @@ -554,7 +555,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Peekable<Streaming<FlightData>>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update not implemented",
Expand Down

0 comments on commit 6cfc5c4

Please sign in to comment.