Skip to content

Commit

Permalink
trino: Show source location for query errors
Browse files Browse the repository at this point in the history
This will make it easier to find errors in large transpiled queries.
  • Loading branch information
emk committed Nov 7, 2023
1 parent d0beb1d commit 6c65082
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 12 deletions.
73 changes: 61 additions & 12 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
//! Trino and maybe Presto driver.

use std::{fmt, str::FromStr};
use std::{fmt, str::FromStr, sync::Arc};

use async_trait::async_trait;
use codespan_reporting::{diagnostic::Diagnostic, files::Files};
use once_cell::sync::Lazy;
use prusto::{Client, ClientBuilder, Presto, Row};
use prusto::{error::Error as PrustoError, Client, ClientBuilder, Presto, QueryError, Row};
use regex::Regex;
use tracing::debug;

use crate::{
ast::Target,
errors::{format_err, Context, Error, Result},
errors::{format_err, Context, Error, Result, SourceError},
known_files::KnownFiles,
transforms::{self, Transform, Udf},
util::AnsiIdent,
};
Expand Down Expand Up @@ -161,8 +163,7 @@ impl Driver for TrinoDriver {
self.client
.execute(sql.to_owned())
.await
.map_err(abbreviate_trino_error)
.with_context(|| format!("Failed to execute SQL: {}", sql))?;
.map_err(|err| abbreviate_trino_error(sql, err))?;
Ok(())
}

Expand Down Expand Up @@ -193,10 +194,11 @@ impl Driver for TrinoDriver {

#[tracing::instrument(skip(self))]
async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> {
let sql = format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name));
self.client
.execute(format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name)))
.execute(sql.clone())
.await
.map_err(abbreviate_trino_error)
.map_err(|err| abbreviate_trino_error(&sql, err))
.with_context(|| format!("Failed to drop table: {}", table_name))?;
Ok(())
}
Expand Down Expand Up @@ -233,9 +235,9 @@ impl DriverImpl for TrinoDriver {
);
Ok(self
.client
.get_all::<Col>(sql)
.get_all::<Col>(sql.clone())
.await
.map_err(abbreviate_trino_error)
.map_err(|err| abbreviate_trino_error(&sql, err))
.with_context(|| format!("Failed to get columns for table: {}", table_name))?
.into_vec()
.into_iter()
Expand Down Expand Up @@ -265,9 +267,9 @@ impl DriverImpl for TrinoDriver {
);
let rows = self
.client
.get_all::<Row>(sql)
.get_all::<Row>(sql.clone())
.await
.map_err(abbreviate_trino_error)
.map_err(|err| abbreviate_trino_error(&sql, err))
.with_context(|| format!("Failed to query table: {}", table_name))?
.into_vec()
.into_iter()
Expand Down Expand Up @@ -313,7 +315,54 @@ impl fmt::Display for TrinoString<'_> {
}

/// These errors are pages long.
fn abbreviate_trino_error(e: prusto::error::Error) -> Error {
fn abbreviate_trino_error(sql: &str, e: PrustoError) -> Error {
if let PrustoError::QueryError(e) = &e {
// We can make these look pretty.
let QueryError {
message,
error_code,
error_location,
..
} = e;
let mut files = KnownFiles::default();
let file_id = files.add_string("trino.sql", sql);

let offset = if let Some(loc) = error_location {
// We don't want to panic, because we're already processing an
// error, and the error comes from an external source. So just
// muddle through and return Span::Unknown or a bogus location
// if our input data is too odd.
//
// Convert from u32, defaulting negative values to 1. (Although
// lines count from 1.)
let line_number = usize::try_from(loc.line_number).unwrap_or(0);
let column_number = usize::try_from(loc.column_number).unwrap_or(0);
files
.line_range(file_id, line_number.saturating_sub(1))
.ok()
.map(|r| r.start + column_number.saturating_sub(1))
} else {
None
};

if let Some(offset) = offset {
let diagnostic = Diagnostic::error()
.with_message(message.clone())
.with_code(format!("TRINO {}", error_code))
.with_labels(vec![codespan_reporting::diagnostic::Label::primary(
file_id,
offset..offset,
)
.with_message("Trino error")]);

return Error::Source(Box::new(SourceError {
alternate_summary: message.clone(),
diagnostic,
files_override: Some(Arc::new(files)),
}));
}
}

let msg = e
.to_string()
.lines()
Expand Down
8 changes: 8 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::{
error::{self, Error as _},
fmt, result,
sync::Arc,
};

use anstream::eprintln;
Expand Down Expand Up @@ -190,6 +191,10 @@ where
pub struct SourceError {
pub alternate_summary: String,
pub diagnostic: Diagnostic<FileId>,
/// If you're not using the standard set of known files, perhaps because
/// you're in a database driver, you can override the [`KnownFiles`] used to
/// display this error.
pub files_override: Option<Arc<KnownFiles>>,
}

impl SourceError {
Expand All @@ -212,13 +217,15 @@ impl SourceError {
SourceError {
alternate_summary,
diagnostic,
files_override: None,
}
} else {
let alternate_summary = format!("{} (at unknown location): {}", summary, annotation);
let diagnostic = Diagnostic::error().with_message(alternate_summary.clone());
SourceError {
alternate_summary,
diagnostic,
files_override: None,
}
}
}
Expand All @@ -227,6 +234,7 @@ impl SourceError {
pub fn emit(&self, files: &KnownFiles) {
let writer = StandardStream::stderr(ColorChoice::Auto);
let config = term::Config::default();
let files = self.files_override.as_deref().unwrap_or(files);
term::emit(&mut writer.lock(), &config, files, &self.diagnostic)
.expect("could not write to stderr");
}
Expand Down

0 comments on commit 6c65082

Please sign in to comment.