Skip to content

Commit

Permalink
Don't include SimpleFiles in errors
Browse files Browse the repository at this point in the history
This is a big refactoring needed to support much nicer errors during
type inference. We need to make Span and FileId have effectively global
scope.
  • Loading branch information
emk committed Oct 26, 2023
1 parent 0e9d5c9 commit b597a29
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 100 deletions.
20 changes: 8 additions & 12 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,9 @@ use std::{
fmt::{self},
io::{self, Write as _},
mem::take,
path::Path,
};

use codespan_reporting::{
diagnostic::{Diagnostic, Label},
files::SimpleFiles,
};
use codespan_reporting::diagnostic::{Diagnostic, Label};
use derive_visitor::{Drive, DriveMut};
use joinery_macros::{Emit, EmitDefault, Spanned, ToTokens};

Expand All @@ -39,6 +35,7 @@ use crate::{
trino::{TrinoString, KEYWORDS as TRINO_KEYWORDS},
},
errors::{Result, SourceError},
known_files::{FileId, KnownFiles},
tokenizer::{
tokenize_sql, EmptyFile, Ident, Keyword, Literal, LiteralValue, PseudoKeyword, Punct,
RawToken, Spanned, ToTokens, Token, TokenStream, TokenWriter,
Expand Down Expand Up @@ -1673,10 +1670,8 @@ pub struct IfExists {
}

/// Parse BigQuery SQL.
pub fn parse_sql(filename: &Path, sql: &str) -> Result<SqlProgram> {
let mut files = SimpleFiles::new();
let file_id = files.add(filename.to_string_lossy().into_owned(), sql.to_string());
let token_stream = tokenize_sql(&files, file_id)?;
pub fn parse_sql(files: &KnownFiles, file_id: FileId) -> Result<SqlProgram> {
let token_stream = tokenize_sql(files, file_id)?;
//println!("token_stream = {:?}", token_stream);

// Parse with or without tracing, as appropriate. The tracing code throws
Expand Down Expand Up @@ -1707,7 +1702,6 @@ pub fn parse_sql(filename: &Path, sql: &str) -> Result<SqlProgram> {
};
Err(SourceError {
expected: e.to_string(),
files,
diagnostic,
}
.into())
Expand Down Expand Up @@ -2972,13 +2966,15 @@ CREATE OR REPLACE TABLE `project-123`.proxies.t2 AS (
.await
.expect("failed to create SQLite3 fixtures");

let mut files = KnownFiles::new();
for &(sql, normalized) in sql_examples {
println!("parsing: {}", sql);
let file_id = files.add_string("test.sql", sql);
let normalized = normalized.unwrap_or(sql);
let parsed = match parse_sql(Path::new("test.sql"), sql) {
let parsed = match parse_sql(&files, file_id) {
Ok(parsed) => parsed,
Err(err) => {
err.emit();
err.emit(&files);
panic!("{}", err);
}
};
Expand Down
8 changes: 5 additions & 3 deletions src/cmd/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
analyze::FunctionCallCounts,
ast::{self},
errors::{Context, Result},
known_files::KnownFiles,
};

/// Parse SQL from a CSV file containing `id` and `query` columns.
Expand All @@ -34,7 +35,7 @@ struct Row {

/// Parse queries from a CSV file.
#[instrument(skip(opt))]
pub fn cmd_parse(opt: &ParseOpt) -> Result<()> {
pub fn cmd_parse(files: &mut KnownFiles, opt: &ParseOpt) -> Result<()> {
// Keep track of how many rows we've processed and how many queries we've
// successfully parsed.
let mut row_count = 0;
Expand Down Expand Up @@ -64,7 +65,8 @@ pub fn cmd_parse(opt: &ParseOpt) -> Result<()> {
}

// Parse query.
match ast::parse_sql(&opt.csv_path, &row.query) {
let file_id = files.add_string(&opt.csv_path, &row.query);
match ast::parse_sql(files, file_id) {
Ok(sql_program) => {
ok_count += 1;
ok_line_count += row.query.lines().count();
Expand All @@ -75,7 +77,7 @@ pub fn cmd_parse(opt: &ParseOpt) -> Result<()> {
}
Err(e) => {
println!("ERR {}", row.id);
e.emit();
e.emit(files);
}
}
}
Expand Down
20 changes: 11 additions & 9 deletions src/cmd/sql_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
ast::{self, parse_sql, CreateTableStatement, CreateViewStatement, Target},
drivers::{self, Driver},
errors::{format_err, Context, Error, Result},
known_files::{FileId, KnownFiles},
};

/// Run SQL tests from a directory.
Expand All @@ -36,7 +37,7 @@ pub struct SqlTestOpt {

/// Run our SQL test suite.
#[instrument(skip(opt))]
pub async fn cmd_sql_test(opt: &SqlTestOpt) -> Result<()> {
pub async fn cmd_sql_test(files: &mut KnownFiles, opt: &SqlTestOpt) -> Result<()> {
// Get a database driver for our target.
let locator = opt.database.parse::<Box<dyn drivers::Locator>>()?;

Expand Down Expand Up @@ -68,13 +69,14 @@ pub async fn cmd_sql_test(opt: &SqlTestOpt) -> Result<()> {
let path = entry.context("Failed to read test file")?;

// Read file.
let query = std::fs::read_to_string(&path).context("Failed to read test file")?;
let file_id = files.add(&path)?;
let sql = files.source_code(file_id)?;

// Skip pending tests unless asked to run them.
if !opt.pending {
let short_path = path.strip_prefix(&base_dir).unwrap_or(&path);
if let Some(pending_test_info) =
PendingTestInfo::for_target(locator.target(), short_path, &query)
PendingTestInfo::for_target(locator.target(), short_path, sql)
{
progress('P');
pending.push(pending_test_info);
Expand All @@ -84,7 +86,7 @@ pub async fn cmd_sql_test(opt: &SqlTestOpt) -> Result<()> {

// Test query.
let mut driver = locator.driver().await?;
match run_test(&mut *driver, &path, &query).await {
match run_test(&mut *driver, files, file_id).await {
Ok(_) => {
progress('.');
test_ok_count += 1;
Expand All @@ -99,7 +101,7 @@ pub async fn cmd_sql_test(opt: &SqlTestOpt) -> Result<()> {

for (i, (path, e)) in test_failures.iter().enumerate() {
println!("\n{} {}: {}", "FAILED".red(), i + 1, path.display());
e.emit();
e.emit(files);
}

if !pending.is_empty() {
Expand Down Expand Up @@ -135,13 +137,13 @@ fn progress(c: char) {
let _ = io::stdout().flush();
}

#[instrument(skip_all, fields(path = %path.display()))]
#[instrument(skip_all)]
async fn run_test(
driver: &mut dyn Driver,
path: &Path,
sql: &str,
files: &mut KnownFiles,
file_id: FileId,
) -> std::result::Result<(), Error> {
let ast = parse_sql(path, sql)?;
let ast = parse_sql(files, file_id)?;
//eprintln!("SQLite3: {}", ast.emit_to_string(Target::SQLite3));
let output_tables = find_output_tables(&ast)?;

Expand Down
11 changes: 5 additions & 6 deletions src/cmd/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use tracing::instrument;
use crate::{
ast::{parse_sql, Emit},
drivers,
errors::{Context, Result},
errors::Result,
known_files::KnownFiles,
};

/// Transpile a SQL file for a target database.
Expand All @@ -25,16 +26,14 @@ pub struct TranspileOpt {

/// Transpile a SQL file for the target database and print the rewritten AST.
#[instrument(skip(opt))]
pub async fn cmd_transpile(opt: &TranspileOpt) -> Result<()> {
pub async fn cmd_transpile(files: &mut KnownFiles, opt: &TranspileOpt) -> Result<()> {
// Get a database driver for our target.
let locator = opt.database.parse::<Box<dyn drivers::Locator>>()?;
let driver = locator.driver().await?;

// Parse our SQL.
let sql = tokio::fs::read_to_string(&opt.sql_path)
.await
.with_context(|| format!("could not read SQL file {}", opt.sql_path.display()))?;
let ast = parse_sql(&opt.sql_path, &sql)?;
let file_id = files.add(&opt.sql_path)?;
let ast = parse_sql(files, file_id)?;
let rewritten_ast = driver.rewrite_ast(&ast)?;

// Print our rewritten AST.
Expand Down
14 changes: 7 additions & 7 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ use anstream::eprintln;
use async_rusqlite::AlreadyClosed;
use codespan_reporting::{
diagnostic::Diagnostic,
files::SimpleFiles,
term::{
self,
termcolor::{ColorChoice, StandardStream},
},
};
use owo_colors::OwoColorize;

use crate::known_files::{FileId, KnownFiles};

/// Our standard result type.
pub type Result<T, E = Error> = result::Result<T, E>;

Expand Down Expand Up @@ -57,10 +58,10 @@ impl Error {

/// Emit this error to stderr. This does extra formatting for `SourceError`,
/// with colors and source code snippets.
pub fn emit(&self) {
pub fn emit(&self, files: &KnownFiles) {
match self {
Error::Source(e) => {
e.emit();
e.emit(files);
}
_ => {
let first = if self.is_transparent() {
Expand Down Expand Up @@ -172,15 +173,14 @@ where
#[derive(Debug)]
pub struct SourceError {
pub expected: String,
pub files: SimpleFiles<String, String>,
pub diagnostic: Diagnostic<usize>,
pub diagnostic: Diagnostic<FileId>,
}

impl SourceError {
pub fn emit(&self) {
pub fn emit(&self, files: &KnownFiles) {
let writer = StandardStream::stderr(ColorChoice::Auto);
let config = term::Config::default();
term::emit(&mut writer.lock(), &config, &self.files, &self.diagnostic)
term::emit(&mut writer.lock(), &config, files, &self.diagnostic)
.expect("could not write to stderr");
}
}
Expand Down
28 changes: 15 additions & 13 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,29 +434,31 @@ impl InferColumnName for ast::Alias {

#[cfg(test)]
mod tests {
use std::path::Path;

use pretty_assertions::assert_eq;

use crate::{
ast::parse_sql,
known_files::KnownFiles,
scope::{CaseInsensitiveIdent, Scope, ScopeValue},
tokenizer::Span,
types::tests::ty,
};

use super::*;

fn infer(sql: &str) -> Result<(Option<TableType>, ScopeHandle)> {
let mut program = match parse_sql(Path::new("test.sql"), sql) {
fn infer(sql: &str) -> Result<(KnownFiles, Option<TableType>, ScopeHandle)> {
let mut files = KnownFiles::new();
let file_id = files.add_string("test.sql", sql);
let mut program = match parse_sql(&files, file_id) {
Ok(program) => program,
Err(e) => {
e.emit();
e.emit(&files);
panic!("parse error");
}
};
let scope = Scope::root();
program.infer_types(&scope)
let (ty, scope) = program.infer_types(&scope)?;
Ok((files, ty, scope))
}

fn lookup(scope: &ScopeHandle, name: &str) -> Option<ScopeValue> {
Expand All @@ -479,27 +481,27 @@ mod tests {

#[test]
fn root_scope_defines_functions() {
let (_, scope) = infer("SELECT 1 AS x").unwrap();
let (_, _, scope) = infer("SELECT 1 AS x").unwrap();
assert_defines!(scope, "LOWER", "Fn(STRING) -> STRING");
assert_defines!(scope, "lower", "Fn(STRING) -> STRING");
assert_not_defines!(scope, "NO_SUCH_FUNCTION");
}

#[test]
fn create_table_adds_table_to_scope() {
let (_, scope) = infer("CREATE TABLE foo (x INT64, y STRING)").unwrap();
let (_, _, scope) = infer("CREATE TABLE foo (x INT64, y STRING)").unwrap();
assert_defines!(scope, "foo", "TABLE<x INT64, y STRING>");
}

#[test]
fn drop_table_removes_table_from_scope() {
let (_, scope) = infer("CREATE TABLE foo (x INT64, y STRING); DROP TABLE foo").unwrap();
let (_, _, scope) = infer("CREATE TABLE foo (x INT64, y STRING); DROP TABLE foo").unwrap();
assert_not_defines!(scope, "foo");
}

#[test]
fn create_table_as_infers_column_types() {
let (_, scope) = infer("CREATE TABLE foo AS SELECT 'a' AS x, TRUE AS y").unwrap();
let (_, _, scope) = infer("CREATE TABLE foo AS SELECT 'a' AS x, TRUE AS y").unwrap();
assert_defines!(scope, "foo", "TABLE<x STRING, y BOOL>");
}

Expand All @@ -511,13 +513,13 @@ WITH
t1 AS (SELECT 'a' AS x),
t2 AS (SELECT x FROM t1)
SELECT x FROM t2";
let (_, scope) = infer(sql).unwrap();
let (_, _, scope) = infer(sql).unwrap();
assert_defines!(scope, "foo", "TABLE<x STRING>");
}

#[test]
fn anon_and_aliased_columns() {
let (_, scope) = infer("CREATE TABLE foo AS SELECT 1, 2 AS x, 3").unwrap();
let (_, _, scope) = infer("CREATE TABLE foo AS SELECT 1, 2 AS x, 3").unwrap();
assert_defines!(scope, "foo", "TABLE<_f0 INT64, x INT64, _f1 INT64>");
}

Expand All @@ -527,7 +529,7 @@ SELECT x FROM t2";
CREATE TABLE foo AS
WITH t AS (SELECT 'a' AS x)
SELECT t.x FROM t";
let (_, scope) = infer(sql).unwrap();
let (_, _, scope) = infer(sql).unwrap();
assert_defines!(scope, "foo", "TABLE<x STRING>");
}
}
Loading

0 comments on commit b597a29

Please sign in to comment.