Skip to content

Commit

Permalink
Start setting up flightsql pagination (#163)
Browse files Browse the repository at this point in the history
This adds basic pagination for FlightSQL. Only basic because for this
initial implementation it only paginates over records in a record batch.
It does not yet handle paginating over multiple record batches. The
structure is in place for that (with `take_record_batches`), but the
implementation details still need to be figured out.
  • Loading branch information
matthewmturner authored Oct 5, 2024
1 parent a35c340 commit a293817
Show file tree
Hide file tree
Showing 16 changed files with 698 additions and 279 deletions.
29 changes: 28 additions & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,39 @@ jobs:
run: |
# All features except FlightSQL which requires being run on a single thread for determinism
cargo test --features=deltalake,s3,functions-json
test-flightsql:
name: Test FlightSQL on AMD64 Rust ${{ matrix.rust }}
runs-on: ubuntu-latest
strategy:
matrix:
arch: [amd64]
rust: [stable]
steps:
- uses: actions/checkout@v2
with:
submodules: true
- name: Cache Cargo
uses: actions/cache@v2
with:
path: /home/runner/.cargo
key: cargo-dft-cache-
- name: Cache Rust dependencies
uses: actions/cache@v2
with:
path: /home/runner/target
key: target-dft-cache-
- name: Setup Rust toolchain
run: |
rustup toolchain install ${{ matrix.rust }}
rustup default ${{ matrix.rust }}
rustup component add rustfmt
- name: Run FlightSQL tests
run: |
# Single thread needed because we spin up a server that listens on port and we need each
# test to only be run against the server spun up in that test. With parallelism tests
# can connect to a server in a different test, which breaks determinism.
cargo test --features=flightsql cli_cases::flightsql -- --test-threads=1
cargo test --features=flightsql -- --test-threads=1
fmt:
name: Rust formatting
runs-on: ubuntu-latest
Expand Down
150 changes: 146 additions & 4 deletions src/app/app_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@ use log::{error, info};
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;
#[cfg(feature = "flightsql")]
use tokio_stream::StreamMap;

#[cfg(feature = "flightsql")]
use {
crate::config::FlightSQLConfig, arrow_flight::decode::FlightRecordBatchStream,
arrow_flight::sql::client::FlightSqlServiceClient, arrow_flight::Ticket,
tonic::transport::Channel, tonic::IntoRequest,
};

/// Handles executing queries for the TUI application, formatting results
/// and sending them to the UI.
///
/// TODO: I think we want to store the SQL associated with a stream
pub struct AppExecution {
inner: Arc<ExecutionContext>,
result_stream: Arc<Mutex<Option<SendableRecordBatchStream>>>,
#[cfg(feature = "flightsql")]
flightsql_result_stream: Arc<Mutex<Option<StreamMap<String, FlightRecordBatchStream>>>>,
}

impl AppExecution {
Expand All @@ -42,6 +55,8 @@ impl AppExecution {
Self {
inner,
result_stream: Arc::new(Mutex::new(None)),
#[cfg(feature = "flightsql")]
flightsql_result_stream: Arc::new(Mutex::new(None)),
}
}

Expand All @@ -54,8 +69,32 @@ impl AppExecution {
*s = Some(stream)
}

/// Run the sequence of SQL queries, sending the results as [`AppEvent::QueryResult`] via the sender.
///
/// Register `stream` as a FlightSQL result stream keyed by `ticket`.
///
/// Streams are keyed by the ticket's string form so that results from
/// multiple endpoints can be multiplexed through a single [`StreamMap`].
/// The map is created lazily on the first insertion.
#[cfg(feature = "flightsql")]
pub async fn set_flightsql_result_stream(
    &self,
    ticket: Ticket,
    stream: FlightRecordBatchStream,
) {
    let mut s = self.flightsql_result_stream.lock().await;
    // Compute the key once (previously `ticket.to_string()` was called
    // twice) and log every insertion, not just the first one.
    let key = ticket.to_string();
    info!("Adding {key} to FlightSQL streams");
    s.get_or_insert_with(StreamMap::new).insert(key, stream);
}

/// Drop any in-flight FlightSQL result streams, returning the slot to its
/// initial empty state.
#[cfg(feature = "flightsql")]
pub async fn reset_flightsql_result_stream(&self) {
    *self.flightsql_result_stream.lock().await = None;
}

/// Run the sequence of SQL queries, sending the results as
/// [`AppEvent::ExecutionResultsBatch`].
/// All queries except the last one will have their results discarded.
///
/// Error handling: If an error occurs while executing a query, the error is
Expand Down Expand Up @@ -93,7 +132,7 @@ impl AppExecution {
batch: b,
duration,
};
sender.send(AppEvent::ExecutionResultsNextPage(
sender.send(AppEvent::ExecutionResultsNextBatch(
results,
))?;
}
Expand Down Expand Up @@ -141,6 +180,99 @@ impl AppExecution {
Ok(())
}

/// Run the sequence of SQL statements against the FlightSQL server.
///
/// Only the final non-empty statement is executed and displayed; earlier
/// statements are skipped entirely (mirroring the local execution path,
/// where only the last query's results are shown). The first record batch
/// of the final statement is forwarded to the UI as
/// [`AppEvent::FlightSQLExecutionResultsNextBatch`]; any failure is
/// reported as [`AppEvent::FlightSQLExecutionResultsError`].
#[cfg(feature = "flightsql")]
pub async fn run_flightsqls(
    self: Arc<Self>,
    sqls: Vec<String>,
    sender: UnboundedSender<AppEvent>,
) -> Result<()> {
    info!("Running sqls: {:?}", sqls);
    // Clear any streams left over from a previous execution.
    self.reset_flightsql_result_stream().await;
    let non_empty_sqls: Vec<String> = sqls.into_iter().filter(|s| !s.is_empty()).collect();
    let statement_count = non_empty_sqls.len();
    for (i, sql) in non_empty_sqls.into_iter().enumerate() {
        let _sender = sender.clone();
        if i == statement_count - 1 {
            info!("Executing last query and display results");
            sender.send(AppEvent::FlightSQLNewExecution)?;
            if let Some(ref mut client) = *self.flightsql_client().lock().await {
                let start = std::time::Instant::now();
                // Build a uniform FlightSQL error event for this query.
                let error_event = |error: String, elapsed| {
                    AppEvent::FlightSQLExecutionResultsError(ExecutionError {
                        query: sql.to_string(),
                        error,
                        duration: elapsed,
                    })
                };
                match client.execute(sql.clone(), None).await {
                    Ok(flight_info) => {
                        for endpoint in flight_info.endpoint {
                            let Some(ticket) = endpoint.ticket else { continue };
                            match client.do_get(ticket.clone().into_request()).await {
                                Ok(stream) => {
                                    self.set_flightsql_result_stream(ticket, stream).await;
                                    if let Some(streams) =
                                        self.flightsql_result_stream.lock().await.as_mut()
                                    {
                                        // Pagination is per-batch for now: only the
                                        // first batch of the stream is forwarded.
                                        match streams.next().await {
                                            Some((ticket, Ok(batch))) => {
                                                info!("Received batch for {ticket}");
                                                let results = ExecutionResultsBatch {
                                                    batch,
                                                    duration: start.elapsed(),
                                                    query: sql.to_string(),
                                                };
                                                sender.send(
                                                    AppEvent::FlightSQLExecutionResultsNextBatch(
                                                        results,
                                                    ),
                                                )?;
                                            }
                                            Some((ticket, Err(e))) => {
                                                error!(
                                                    "Error executing stream for ticket {ticket}: {:?}",
                                                    e
                                                );
                                                sender.send(
                                                    error_event(e.to_string(), start.elapsed()),
                                                )?;
                                            }
                                            None => {}
                                        }
                                    }
                                }
                                Err(e) => {
                                    error!("Error creating result stream: {:?}", e);
                                    // BUG FIX: this path previously sent the local
                                    // `AppEvent::ExecutionResultsError`, so FlightSQL
                                    // do_get failures were routed to the wrong tab.
                                    sender
                                        .send(error_event(e.to_string(), start.elapsed()))?;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        error!("Error getting flight info: {:?}", e);
                        sender.send(error_event(e.to_string(), start.elapsed()))?;
                    }
                }
            }
        }
    }

    Ok(())
}

pub async fn next_batch(&self, sql: String, sender: UnboundedSender<AppEvent>) {
let mut stream = self.result_stream.lock().await;
if let Some(s) = stream.as_mut() {
Expand All @@ -154,7 +286,7 @@ impl AppExecution {
batch: b,
duration,
};
let _ = sender.send(AppEvent::ExecutionResultsNextPage(results));
let _ = sender.send(AppEvent::ExecutionResultsNextBatch(results));
}
Err(e) => {
error!("Error getting RecordBatch: {:?}", e);
Expand All @@ -163,4 +295,14 @@ impl AppExecution {
}
}
}

/// Create and store a FlightSQL client built from `config`, delegating to
/// the inner `ExecutionContext`. Errors from establishing the connection
/// are propagated to the caller.
#[cfg(feature = "flightsql")]
pub async fn create_flightsql_client(&self, config: FlightSQLConfig) -> Result<()> {
self.inner.create_flightsql_client(config).await
}

/// Borrow the shared FlightSQL client slot held by the inner
/// `ExecutionContext`. `None` until `create_flightsql_client` has
/// successfully connected.
#[cfg(feature = "flightsql")]
pub fn flightsql_client(&self) -> &Mutex<Option<FlightSqlServiceClient<Channel>>> {
self.inner.flightsql_client()
}
}
100 changes: 22 additions & 78 deletions src/app/handlers/flightsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@
// under the License.

use std::sync::Arc;
use std::time::{Duration, Instant};

use datafusion::arrow::array::RecordBatch;
use log::{error, info};
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use tokio_stream::StreamExt;
use tonic::IntoRequest;

use crate::app::state::tabs::flightsql::FlightSQLQuery;
use crate::app::{handlers::tab_navigation_handler, AppEvent};

use super::App;
Expand Down Expand Up @@ -66,74 +61,28 @@ pub fn normal_mode_handler(app: &mut App, key: KeyEvent) {
}
}

// KeyCode::Enter => {
// info!("Run FS query");
// let sql = app.state.flightsql_tab.editor().lines().join("");
// info!("SQL: {}", sql);
// let execution = Arc::clone(&app.execution);
// let _event_tx = app.event_tx();
// tokio::spawn(async move {
// let client = execution.flightsql_client();
// let mut query =
// FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None);
// let start = Instant::now();
// if let Some(ref mut c) = *client.lock().await {
// info!("Sending query");
// match c.execute(sql, None).await {
// Ok(flight_info) => {
// for endpoint in flight_info.endpoint {
// if let Some(ticket) = endpoint.ticket {
// match c.do_get(ticket.into_request()).await {
// Ok(mut stream) => {
// let mut batches: Vec<RecordBatch> = Vec::new();
// // temporarily only show the first batch to avoid
// // buffering massive result sets. Eventually there should
// // be some sort of paging logic
// // see https://github.com/datafusion-contrib/datafusion-tui/pull/133#discussion_r1756680874
// // while let Some(maybe_batch) = stream.next().await {
// if let Some(maybe_batch) = stream.next().await {
// match maybe_batch {
// Ok(batch) => {
// info!("Batch rows: {}", batch.num_rows());
// batches.push(batch);
// }
// Err(e) => {
// error!("Error getting batch: {:?}", e);
// let elapsed = start.elapsed();
// query.set_error(Some(e.to_string()));
// query.set_execution_time(elapsed);
// }
// }
// }
// let elapsed = start.elapsed();
// let rows: usize =
// batches.iter().map(|r| r.num_rows()).sum();
// query.set_results(Some(batches));
// query.set_num_rows(Some(rows));
// query.set_execution_time(elapsed);
// }
// Err(e) => {
// error!("Error getting response: {:?}", e);
// let elapsed = start.elapsed();
// query.set_error(Some(e.to_string()));
// query.set_execution_time(elapsed);
// }
// }
// }
// }
// }
// Err(e) => {
// error!("Error getting response: {:?}", e);
// let elapsed = start.elapsed();
// query.set_error(Some(e.to_string()));
// query.set_execution_time(elapsed);
// }
// }
// }
//
// let _ = _event_tx.send(AppEvent::FlightSQLQueryResult(query));
// });
// }
KeyCode::Enter => {
info!("Executing FlightSQL query");
let sql = app.state.flightsql_tab.editor().lines().join("");
info!("SQL: {}", sql);
let sqls: Vec<String> = sql.split(';').map(|s| s.to_string()).collect();
let execution = Arc::clone(&app.execution);
let _event_tx = app.event_tx();
let handle = tokio::spawn(execution.run_flightsqls(sqls, _event_tx));
app.state.flightsql_tab.set_execution_task(handle);
}
KeyCode::Right => {
let _event_tx = app.event_tx();
if let Err(e) = _event_tx.send(AppEvent::FlightSQLExecutionResultsNextPage) {
error!("Error going to next FlightSQL results page: {e}");
}
}
KeyCode::Left => {
let _event_tx = app.event_tx();
if let Err(e) = _event_tx.send(AppEvent::FlightSQLExecutionResultsPreviousPage) {
error!("Error going to previous FlightSQL results page: {e}");
}
}
_ => {}
}
}
Expand All @@ -154,11 +103,6 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) {
true => editable_handler(app, key),
false => normal_mode_handler(app, key),
},
AppEvent::FlightSQLQueryResult(r) => {
info!("Query results: {:?}", r);
app.state.flightsql_tab.set_query(r);
app.state.flightsql_tab.refresh_query_results_state();
}
AppEvent::Error => {}
_ => {}
};
Expand Down
Loading

0 comments on commit a293817

Please sign in to comment.