Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stateless prepared statements #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use arrow_flight::sql::server::PeekableFlightDataStream;
use arrow_flight::sql::DoPutPreparedStatementResult;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::{stream, Stream, TryStreamExt};
Expand Down Expand Up @@ -619,7 +620,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
&self,
_query: CommandPreparedStatementQuery,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
) -> Result<DoPutPreparedStatementResult, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
))
Expand Down
20 changes: 20 additions & 0 deletions arrow-flight/src/sql/arrow.flight.protocol.sql.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 27 additions & 5 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use crate::sql::{
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys,
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
SqlInfo,
};
use crate::trailers::extract_lazy_trailers;
use crate::{
Expand Down Expand Up @@ -501,6 +502,7 @@ impl PreparedStatement<Channel> {
}

/// Submit parameters to the server, if any have been set on this prepared statement instance
/// Updates our stored prepared statement handle with the handle given by the server response.
async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
if let Some(ref params_batch) = self.parameter_binding {
let cmd = CommandPreparedStatementQuery {
Expand All @@ -519,17 +521,37 @@ impl PreparedStatement<Channel> {
.await
.map_err(flight_error_to_arrow_error)?;

self.flight_sql_client
// Attempt to update the stored handle with any updated handle in the DoPut result.
// Not all servers support this, so ignore any errors when attempting to decode.
if let Some(result) = self
.flight_sql_client
.do_put(stream::iter(flight_data))
.await?
.try_collect::<Vec<_>>()
.message()
.await
.map_err(status_to_arrow_error)?;
.map_err(status_to_arrow_error)?
{
if let Some(handle) = self.unpack_prepared_statement_handle(&result)? {
self.handle = handle;
}
}
}

Ok(())
}

/// Decodes the app_metadata stored in a [`PutResult`] as a
/// [`DoPutPreparedStatementResult`] and then returns
/// the inner prepared statement handle as [`Bytes`]
fn unpack_prepared_statement_handle(
&self,
put_result: &PutResult,
) -> Result<Option<Bytes>, ArrowError> {
let any = Any::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
Ok(any
.unpack::<DoPutPreparedStatementResult>()?
.and_then(|result| result.prepared_statement_handle))
}

/// Close the prepared statement, so that this PreparedStatement can not used
/// anymore and server can free up any resources.
pub async fn close(mut self) -> Result<(), ArrowError> {
Expand Down
2 changes: 2 additions & 0 deletions arrow-flight/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub use gen::CommandPreparedStatementUpdate;
pub use gen::CommandStatementQuery;
pub use gen::CommandStatementSubstraitPlan;
pub use gen::CommandStatementUpdate;
pub use gen::DoPutPreparedStatementResult;
pub use gen::DoPutUpdateResult;
pub use gen::Nullable;
pub use gen::Searchable;
Expand Down Expand Up @@ -251,6 +252,7 @@ prost_message_ext!(
CommandStatementSubstraitPlan,
CommandStatementUpdate,
DoPutUpdateResult,
DoPutPreparedStatementResult,
TicketStatementQuery,
);

Expand Down
17 changes: 14 additions & 3 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ use super::{
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
};
use crate::{
flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
Expand Down Expand Up @@ -397,11 +398,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
}

/// Bind parameters to given prepared statement.
///
/// Returns an opaque handle that the client should pass
/// back to the server during subsequent requests with this
/// prepared statement.
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
) -> Result<DoPutPreparedStatementResult, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query has no default implementation",
))
Expand Down Expand Up @@ -709,7 +714,13 @@ where
Ok(Response::new(Box::pin(output)))
}
Command::CommandPreparedStatementQuery(command) => {
self.do_put_prepared_statement_query(command, request).await
let result = self
.do_put_prepared_statement_query(command, request)
.await?;
let output = futures::stream::iter(vec![Ok(PutResult {
app_metadata: result.as_any().encode_to_vec().into(),
})]);
Ok(Response::new(Box::pin(output)))
}
Command::CommandStatementSubstraitPlan(command) => {
let record_count = self.do_put_substrait_plan(command, request).await?;
Expand Down
77 changes: 59 additions & 18 deletions arrow-flight/tests/flight_sql_client_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ use arrow_flight::{
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery,
CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
},
utils::batches_to_flight_data,
Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest,
HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket,
HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket,
};
use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use assert_cmd::Command;
use bytes::Bytes;
use futures::{Stream, StreamExt, TryStreamExt};
use futures::{Stream, TryStreamExt};
use prost::Message;
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{Request, Response, Status, Streaming};
Expand All @@ -51,7 +52,7 @@ const QUERY: &str = "SELECT * FROM table;";

#[tokio::test]
async fn test_simple() {
let test_server = FlightSqlServiceImpl {};
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(&test_server).await;
let addr = fixture.addr;

Expand Down Expand Up @@ -92,10 +93,9 @@ async fn test_simple() {

const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1";
const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle";
const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle";

#[tokio::test]
async fn test_do_put_prepared_statement() {
let test_server = FlightSqlServiceImpl {};
async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) {
let fixture = TestFixture::new(&test_server).await;
let addr = fixture.addr;

Expand Down Expand Up @@ -136,11 +136,40 @@ async fn test_do_put_prepared_statement() {
);
}

#[tokio::test]
pub async fn test_do_put_prepared_statement_stateless() {
test_do_put_prepared_statement(FlightSqlServiceImpl {
stateless_prepared_statements: true,
})
.await
}

#[tokio::test]
pub async fn test_do_put_prepared_statement_stateful() {
test_do_put_prepared_statement(FlightSqlServiceImpl {
stateless_prepared_statements: false,
})
.await
}

/// All tests must complete within this many seconds or else the test server is shutdown
const DEFAULT_TIMEOUT_SECONDS: u64 = 30;

#[derive(Clone, Default)]
pub struct FlightSqlServiceImpl {}
#[derive(Clone)]
pub struct FlightSqlServiceImpl {
/// Whether to emulate stateless (true) or stateful (false) behavior for
/// prepared statements. stateful servers will not return an updated
/// handle after executing `DoPut(CommandPreparedStatementQuery)`
stateless_prepared_statements: bool,
}

impl Default for FlightSqlServiceImpl {
fn default() -> Self {
Self {
stateless_prepared_statements: true,
}
}
}

impl FlightSqlServiceImpl {
/// Return an [`FlightServiceServer`] that can be used with a
Expand Down Expand Up @@ -274,10 +303,17 @@ impl FlightSqlService for FlightSqlServiceImpl {
cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
assert_eq!(
cmd.prepared_statement_handle,
PREPARED_STATEMENT_HANDLE.as_bytes()
);
if self.stateless_prepared_statements {
assert_eq!(
cmd.prepared_statement_handle,
UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes()
);
} else {
assert_eq!(
cmd.prepared_statement_handle,
PREPARED_STATEMENT_HANDLE.as_bytes()
);
}
let resp = Response::new(self.fake_flight_info().unwrap());
Ok(resp)
}
Expand Down Expand Up @@ -524,7 +560,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
&self,
_query: CommandPreparedStatementQuery,
request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
) -> Result<DoPutPreparedStatementResult, Status> {
// just make sure decoding the parameters works
let parameters = FlightRecordBatchStream::new_from_flight_data(
request.into_inner().map_err(|e| e.into()),
Expand All @@ -543,10 +579,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
)));
}
}

Ok(Response::new(
futures::stream::once(async { Ok(PutResult::default()) }).boxed(),
))
let handle = if self.stateless_prepared_statements {
UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into()
} else {
PREPARED_STATEMENT_HANDLE.to_string().into()
};
let result = DoPutPreparedStatementResult {
prepared_statement_handle: Some(handle),
};
Ok(result)
}

async fn do_put_prepared_statement_update(
Expand Down
23 changes: 22 additions & 1 deletion format/FlightSql.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,28 @@
// an unknown updated record count.
int64 record_count = 1;
}


/* An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`.
*
* *Note on legacy behavior*: previous versions of the protocol did not return any result for
* this command, and that behavior should still be supported by clients. See documentation
* of individual fields for more details on expected client behavior in this case.
*/
message DoPutPreparedStatementResult {
option (experimental) = true;

// Represents a (potentially updated) opaque handle for the prepared statement on the server.
// Because the handle could potentially be updated, any previous handles for this prepared
// statement should be considered invalid, and all subsequent requests for this prepared
// statement must use this new handle, if specified.
// The updated handle allows implementing query parameters with stateless services
// as described in https://github.com/apache/arrow/issues/37720.
//
// When an updated handle is not provided by the server, clients should contiue
// using the previous handle provided by `ActionCreatePreparedStatementResonse`.
optional bytes prepared_statement_handle = 1;
}

/*
* Request message for the "CancelQuery" action.
*
Expand Down
Loading