Skip to content

Commit

Permalink
trino: Implement most date functions
Browse files Browse the repository at this point in the history
We don't implement date formatting or date array generation, but this
should contain everything else.

Co-authored-by: Dave Shirley <[email protected]>
  • Loading branch information
emk and dave-shirley-faraday committed Nov 10, 2023
1 parent e5aa307 commit 0e56c8b
Show file tree
Hide file tree
Showing 13 changed files with 238 additions and 22 deletions.
35 changes: 21 additions & 14 deletions sql/trino_compat.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
-- Compatibility functions for Trino, needed to run code generated by Joinery.

CREATE FUNCTION memory.joinery_compat.DATETIME(date VARCHAR)
RETURNS TIMESTAMP
RETURNS NULL ON NULL INPUT
RETURN (
CAST(date AS TIMESTAMP)
);

CREATE FUNCTION memory.joinery_compat.DATETIME(date DATE)
RETURNS TIMESTAMP
RETURNS NULL ON NULL INPUT
RETURN (
CAST(date AS TIMESTAMP)
);

CREATE FUNCTION memory.joinery_compat.GENERATE_UUID()
RETURNS VARCHAR
NOT DETERMINISTIC
Expand All @@ -9,24 +23,17 @@ RETURN (

CREATE FUNCTION memory.joinery_compat.SHA256_COMPAT(input VARCHAR)
RETURNS VARBINARY
CALLED ON NULL INPUT
RETURNS NULL ON NULL INPUT
RETURN (
SHA256(TO_UTF8(input))
);

-- This is how we'd define a version for BINARY values, but doing so causes
-- `SHA256_COMPAT(NULL)` to fail with:
--
-- > Could not choose a best candidate operator. Explicit type casts must be
-- > added.
--
-- `CALLED ON NULL INPUT` doesn't seem to help?
--
-- CREATE FUNCTION memory.joinery_compat.SHA256_COMPAT(input VARBINARY)
-- RETURNS VARBINARY
-- RETURN (
-- SHA256(input)
-- );
CREATE FUNCTION memory.joinery_compat.SHA256_COMPAT(input VARBINARY)
RETURNS VARBINARY
RETURNS NULL ON NULL INPUT
RETURN (
SHA256(input)
);

CREATE FUNCTION memory.joinery_compat.TO_HEX_COMPAT(input VARBINARY)
RETURNS VARCHAR
Expand Down
51 changes: 48 additions & 3 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,7 @@ pub enum Expression {
Not(NotExpression),
If(IfExpression),
Case(CaseExpression),
Unary(UnaryExpression),
Binop(BinopExpression),
Query {
paren1: Punct,
Expand Down Expand Up @@ -910,6 +911,14 @@ pub struct DatePart {
pub date_part_token: PseudoKeyword,
}

impl DatePart {
/// Get a lowercase ASCII string.
pub fn to_literal(&self) -> Literal {
let s = self.date_part_token.ident.name.to_ascii_lowercase();
Literal::string(&s, self.date_part_token.span())
}
}

/// A cast expression.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct Cast {
Expand Down Expand Up @@ -1060,6 +1069,13 @@ pub struct CaseExpression {
pub end_token: Keyword,
}

/// A unary operator.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct UnaryExpression {
pub op_token: Punct,
pub expression: Box<Expression>,
}

/// A binary operator.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct BinopExpression {
Expand Down Expand Up @@ -1341,6 +1357,32 @@ pub enum SpecialDateExpression {
DatePart(DatePart),
}

impl SpecialDateExpression {
/// If this contains an expression, return it.
pub fn try_as_expression(&self) -> Option<&Expression> {
match self {
SpecialDateExpression::Expression(expression) => Some(expression),
_ => None,
}
}

/// If this contains an interval expression, return it.
pub fn try_as_interval(&self) -> Option<&IntervalExpression> {
match self {
SpecialDateExpression::Interval(interval) => Some(interval),
_ => None,
}
}

/// If this contains a date part, return it.
pub fn try_as_date_part(&self) -> Option<&DatePart> {
match self {
SpecialDateExpression::DatePart(date_part) => Some(date_part),
_ => None,
}
}
}

/// An `ARRAY_AGG` expression.
#[derive(Clone, Debug, Drive, DriveMut, EmitDefault, Spanned, ToTokens)]
pub struct ArrayAggExpression {
Expand Down Expand Up @@ -2269,6 +2311,10 @@ peg::parser! {
})
}
--
op_token:p("-") right:@ {
Expression::Unary(UnaryExpression { op_token, expression: Box::new(right) })
}
--
expression:(@) dot:p(".") field_name:ident() {
Expression::FieldAccess(FieldAccessExpression {
expression: Box::new(expression),
Expand Down Expand Up @@ -2527,7 +2573,7 @@ peg::parser! {
}
}

rule special_date_function_call() -> SpecialDateFunctionCall
pub rule special_date_function_call() -> SpecialDateFunctionCall
= function_name:special_date_function_name() paren1:p("(")
args:sep(<special_date_expression()>, ",") paren2:p(")") {
SpecialDateFunctionCall {
Expand All @@ -2540,7 +2586,7 @@ peg::parser! {

rule special_date_function_name() -> PseudoKeyword
= pk("DATE_DIFF") / pk("DATE_TRUNC") / pk("DATE_ADD") / pk("DATE_SUB")
/ pk("DATETIME_DIFF") / pk("DATETIME_TRUNC") / pk("DATETIME_SUB")
/ pk("DATETIME_DIFF") / pk("DATETIME_TRUNC") / pk("DATETIME_ADD") / pk("DATETIME_SUB")

rule special_date_expression() -> SpecialDateExpression
= interval:interval_expression() { SpecialDateExpression::Interval(interval) }
Expand Down Expand Up @@ -3093,7 +3139,6 @@ mod tests {
(r"SELECT * FROM t WHERE a < 0.5", None),
(r"SELECT * FROM t WHERE a BETWEEN 1 AND 10", None),
(r"SELECT * FROM t WHERE a NOT BETWEEN 1 AND 10", None),
(r"SELECT INTERVAL -3 DAY", None),
(r"SELECT * FROM t WHERE a IN (1,2)", None),
(r"SELECT * FROM t WHERE a NOT IN (1,2)", None),
(r"SELECT * FROM t WHERE a IN (SELECT b FROM t)", None),
Expand Down
2 changes: 2 additions & 0 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! {
"ARRAY_LENGTH" => "CARDINALITY",
"ARRAY_TO_STRING" => "ARRAY_JOIN",
"CURRENT_DATETIME" => "CURRENT_TIMESTAMP",
"DATETIME" => "memory.joinery_compat.DATETIME",
"GENERATE_UUID" => "memory.joinery_compat.GENERATE_UUID",
"SHA256" => "memory.joinery_compat.SHA256_COMPAT",
"TO_HEX" => "memory.joinery_compat.TO_HEX_COMPAT",
Expand Down Expand Up @@ -188,6 +189,7 @@ impl Driver for TrinoDriver {
&UDFS,
&format_udf,
)),
Box::new(transforms::SpecialDateFunctionsToTrino),
Box::new(transforms::StandardizeCurrentTimeUnit::no_parens()),
Box::new(transforms::CleanUpTempManually {
format_name: &|table_name| AnsiIdent(&table_name.unescaped_bigquery()).to_string(),
Expand Down
7 changes: 7 additions & 0 deletions src/infer/contains_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl ContainsAggregate for ast::Expression {
ast::Expression::Not(not) => not.contains_aggregate(scope),
ast::Expression::If(if_expr) => if_expr.contains_aggregate(scope),
ast::Expression::Case(case) => case.contains_aggregate(scope),
ast::Expression::Unary(unary) => unary.contains_aggregate(scope),
ast::Expression::Binop(binop) => binop.contains_aggregate(scope),
// We never look into sub-queries.
ast::Expression::Query { .. } => false,
Expand Down Expand Up @@ -179,6 +180,12 @@ impl ContainsAggregate for ast::CaseElseClause {
}
}

impl ContainsAggregate for ast::UnaryExpression {
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
self.expression.contains_aggregate(scope)
}
}

impl ContainsAggregate for ast::BinopExpression {
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool {
self.left.contains_aggregate(scope) || self.right.contains_aggregate(scope)
Expand Down
16 changes: 16 additions & 0 deletions src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ impl InferTypes for ast::Expression {
ast::Expression::Not(not) => not.infer_types(scope),
ast::Expression::If(if_expr) => if_expr.infer_types(scope),
ast::Expression::Case(case) => case.infer_types(scope),
ast::Expression::Unary(unary) => unary.infer_types(scope),
ast::Expression::Binop(binop) => binop.infer_types(scope),
ast::Expression::Query { query, .. } => {
let table_ty = query.infer_types(&scope.clone().try_into_handle_for_subquery()?)?;
Expand Down Expand Up @@ -958,6 +959,21 @@ impl InferTypes for ast::CaseExpression {
}
}

impl InferTypes for ast::UnaryExpression {
type Scope = ColumnSetScope;
type Output = ArgumentType;

fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
let func_name = &Name::new(
&format!("%UNARY{}", self.op_token.token.as_str()),
self.op_token.span(),
);
let func_ty = scope.get_function_type(func_name)?;
let arg_ty = self.expression.infer_types(scope)?;
func_ty.return_type_for(&[arg_ty], false, func_name)
}
}

impl InferTypes for ast::BinopExpression {
type Scope = ColumnSetScope;
type Output = ArgumentType;
Expand Down
3 changes: 3 additions & 0 deletions src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ static BUILT_IN_FUNCTIONS: &str = "
%NOT = Fn(BOOL) -> BOOL;
%OR = Fn(BOOL, BOOL) -> BOOL;
%UNARY- = Fn(INT64) -> INT64 | Fn(FLOAT64) -> FLOAT64;
%= = Fn<?T>(?T, ?T) -> BOOL;
%!= = Fn<?T>(?T, ?T) -> BOOL;
%<= = Fn<?T>(?T, ?T) -> BOOL;
Expand Down Expand Up @@ -254,6 +256,7 @@ DATE_DIFF = Fn(DATE, DATE, DATEPART) -> INT64;
DATE_SUB = Fn(DATE, INTERVAL) -> DATE;
DATE_TRUNC = Fn(DATE, DATEPART) -> DATE;
DATETIME = Fn(STRING) -> DATETIME | Fn(DATE) -> DATETIME;
DATETIME_ADD = Fn(DATETIME, INTERVAL) -> DATETIME;
DATETIME_DIFF = Fn(DATETIME, DATETIME, DATEPART) -> INT64;
DATETIME_SUB = Fn(DATETIME, INTERVAL) -> DATETIME;
DATETIME_TRUNC = Fn(DATETIME, DATEPART) -> DATETIME;
Expand Down
13 changes: 13 additions & 0 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,14 @@ impl Ident {
name: name.to_owned(),
}
}

/// Create a new [`Ident`], overriding the string.
pub fn with_str(&self, s: &str) -> Self {
Self {
token: self.token.with_str(s),
name: s.to_owned(),
}
}
}

impl Spanned for Ident {
Expand Down Expand Up @@ -627,6 +635,11 @@ impl TokenStream {
self.try_into_parsed(ast::sql_program::expression)
}

/// Try to parse this stream as a [`ast::SpecialDateFunctionCall`].
pub fn try_into_special_date_function_call(self) -> Result<ast::SpecialDateFunctionCall> {
self.try_into_parsed(ast::sql_program::special_date_function_call)
}

/// Try to parse this stream as a [`ast::FunctionCall`].
#[allow(dead_code)]
pub fn try_into_function_call(self) -> Result<ast::FunctionCall> {
Expand Down
2 changes: 2 additions & 0 deletions src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use self::{
or_replace_to_drop_if_exists::OrReplaceToDropIfExists,
qualify_to_subquery::QualifyToSubquery,
rename_functions::{RenameFunctions, Udf},
special_date_functions_to_trino::SpecialDateFunctionsToTrino,
standardize_current_time_unit::StandardizeCurrentTimeUnit,
wrap_nested_queries::WrapNestedQueries,
};
Expand All @@ -36,6 +37,7 @@ mod is_bool_to_case;
mod or_replace_to_drop_if_exists;
mod qualify_to_subquery;
mod rename_functions;
mod special_date_functions_to_trino;
mod standardize_current_time_unit;
mod wrap_nested_queries;

Expand Down
Loading

0 comments on commit 0e56c8b

Please sign in to comment.