Skip to content

Commit

Permalink
Infer struct field accesses
Browse files Browse the repository at this point in the history
Co-authored-by: Dave Shirley <[email protected]>
  • Loading branch information
emk and dave-shirley-faraday committed Nov 10, 2023
1 parent e5d76ce commit 86f36ad
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 13 deletions.
21 changes: 19 additions & 2 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ pub enum Expression {
BoolValue(Keyword),
Null(Keyword),
Interval(IntervalExpression),
ColumnName(Name),
Name(Name),
Cast(Cast),
Is(IsExpression),
In(InExpression),
Expand Down Expand Up @@ -874,6 +874,7 @@ pub enum Expression {
SpecialDateFunctionCall(SpecialDateFunctionCall),
FunctionCall(FunctionCall),
Index(IndexExpression),
FieldAccess(FieldAccessExpression),
}

impl Expression {
Expand Down Expand Up @@ -1647,6 +1648,14 @@ pub enum IndexOffset {
},
}

/// A field access expression: a base expression, a `.`, and a field name
/// (for example, `STRUCT(1 AS a).a`).
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct FieldAccessExpression {
/// The expression whose field is being accessed. Type inference requires
/// this to have a struct type.
pub expression: Box<Expression>,
/// The `.` token between the base expression and the field name.
pub dot: Punct,
/// The name of the field to look up.
pub field_name: Ident,
}

/// An `AS` alias.
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)]
pub struct Alias {
Expand Down Expand Up @@ -2248,6 +2257,14 @@ peg::parser! {
})
}
--
expression:(@) dot:p(".") field_name:ident() {
Expression::FieldAccess(FieldAccessExpression {
expression: Box::new(expression),
dot,
field_name,
})
}
--
case_token:k("CASE")
case_expr:expression()?
when_clauses:(case_when_clause()*)
Expand Down Expand Up @@ -2288,7 +2305,7 @@ peg::parser! {
// Things from here down might start with arbitrary identifiers, so
// we need to be careful about the order.
function_call:function_call() { Expression::FunctionCall(function_call) }
column_name:name() { Expression::ColumnName(column_name) }
column_name:name() { Expression::Name(column_name) }
}

rule interval_expression() -> IntervalExpression
Expand Down
5 changes: 4 additions & 1 deletion src/infer/contains_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl ContainsAggregate for ast::Expression {
ast::Expression::BoolValue(_) => false,
ast::Expression::Null(_) => false,
ast::Expression::Interval(interval) => interval.contains_aggregate(scope),
ast::Expression::ColumnName(_) => false,
ast::Expression::Name(_) => false,
ast::Expression::Cast(cast) => cast.contains_aggregate(scope),
ast::Expression::Is(is) => is.contains_aggregate(scope),
ast::Expression::In(in_expr) => in_expr.contains_aggregate(scope),
Expand All @@ -85,6 +85,9 @@ impl ContainsAggregate for ast::Expression {
ast::Expression::SpecialDateFunctionCall(fcall) => fcall.contains_aggregate(scope),
ast::Expression::FunctionCall(fcall) => fcall.contains_aggregate(scope),
ast::Expression::Index(idx) => idx.contains_aggregate(scope),
// Putting an aggregate inside the base expression of a field access would
// be very weird, so we report `false` without inspecting it. Do not allow
// it until forced to do so.
ast::Expression::FieldAccess(_) => false,
}
}
}
Expand Down
63 changes: 58 additions & 5 deletions src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ impl InferTypes for ast::GroupBy {
for expr in self.expressions.node_iter_mut() {
let _ty = expr.infer_types(scope)?;
match expr {
Expression::ColumnName(name) => {
Expression::Name(name) => {
group_by_names.push(name.clone());
}
_ => {
Expand Down Expand Up @@ -653,7 +653,7 @@ impl InferTypes for ast::Expression {
ast::Expression::BoolValue(_) => Ok(ArgumentType::bool()),
ast::Expression::Null { .. } => Ok(ArgumentType::null()),
ast::Expression::Interval(_) => Err(nyi(self, "INTERVAL expression")),
ast::Expression::ColumnName(name) => name.infer_types(scope),
ast::Expression::Name(name) => name.infer_types(scope),
ast::Expression::Cast(cast) => cast.infer_types(scope),
ast::Expression::Is(is) => is.infer_types(scope),
ast::Expression::In(in_expr) => in_expr.infer_types(scope),
Expand All @@ -676,6 +676,7 @@ impl InferTypes for ast::Expression {
ast::Expression::SpecialDateFunctionCall(_) => Err(nyi(self, "special date function")),
ast::Expression::FunctionCall(fcall) => fcall.infer_types(scope),
ast::Expression::Index(index) => index.infer_types(scope),
ast::Expression::FieldAccess(field_access) => field_access.infer_types(scope),
}
}
}
Expand Down Expand Up @@ -709,7 +710,47 @@ impl InferTypes for ast::Name {
type Output = ArgumentType;

/// Infer the type of a (possibly dotted) name.
///
/// The full name is looked up in `scope` first. If that fails, one
/// component is split off with `split_table_and_column` and the remaining
/// prefix is retried, until some prefix resolves. The components that were
/// split off are then applied, in source order, as struct field lookups on
/// the resolved type.
///
/// # Errors
///
/// Returns an "unknown name" error spanning the original name if no prefix
/// resolves, and propagates any error from `expect_struct_type` or
/// `expect_field` while walking the fields.
fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
// NOTE(review): the next line has no trailing `;` yet is followed by more
// statements, so it looks like the pre-change body left in by the diff
// view — confirm before relying on it.
scope.get_argument_type(self)
// `scope` contains the columns we can see. But this might be something
// like `[my_table.]my_struct_column.field1.field2`, so we need to be
// prepared to split this.

// First, find the split between our "base name" (presumably a column)
// and any field accesses.
//
// Each entry pairs a split-off field name with the prefix it was split
// from; the prefix is kept so it can be passed to `expect_struct_type`
// as the span for type errors.
let mut field_names_with_base_names = vec![];
let mut candidate = self.clone();
let mut base_type = loop {
match scope.get_argument_type(&candidate) {
Ok(base_type) => break base_type,
// The lookup error itself is discarded: any failure just means we
// retry with a shorter prefix.
Err(_) => {
// Presumably this splits off the last dotted component; `next` is
// `None` once there is no prefix left to try.
let (next, field_name) = candidate.split_table_and_column();
if let Some(next) = next {
field_names_with_base_names.push((field_name, next.clone()));
candidate = next;
} else {
// Report an error containing the original name.
return Err(Error::annotated(
format!("unknown name: {}", self.unescaped_bigquery()),
self.span(),
"not found",
));
}
}
}
};

// Now that we have a base type, look up each field.
//
// Field names were pushed starting from the end of the name, so iterate
// in reverse to walk from the base outward.
for (field_name, base_name) in field_names_with_base_names.into_iter().rev() {
// TODO: These errors here _might_ not always be optimal for the
// user if we're not actually dealing with structs? But maybe
// they're OK in most cases.
base_type = ArgumentType::Value(
base_type
.expect_struct_type(&base_name)?
.expect_field(&Name::from(field_name))?
.clone(),
);
}
Ok(base_type)
}
}

Expand Down Expand Up @@ -1159,7 +1200,7 @@ impl InferTypes for ast::PartitionBy {
let mut partition_by_names = vec![];
for expr in self.expressions.node_iter_mut() {
match expr {
ast::Expression::ColumnName(name) => {
ast::Expression::Name(name) => {
scope.get_argument_type(name)?;
partition_by_names.push(name.clone());
}
Expand Down Expand Up @@ -1188,6 +1229,18 @@ impl InferTypes for ast::IndexExpression {
}
}

impl InferTypes for ast::FieldAccessExpression {
    type Scope = ColumnSetScope;
    type Output = ArgumentType;

    /// Infer the type of `expression.field_name`.
    ///
    /// The base expression is inferred first and must have a struct type;
    /// the result is the declared type of the named field in that struct.
    ///
    /// # Errors
    ///
    /// Propagates errors from inferring the base expression, from
    /// `expect_struct_type` when the base is not a struct, and from
    /// `expect_field` when the struct has no such field.
    fn infer_types(&mut self, scope: &Self::Scope) -> Result<Self::Output> {
        let base_ty = self.expression.infer_types(scope)?;
        // The base expression doubles as the span for a "not a struct" error.
        let field_ty = base_ty
            .expect_struct_type(&self.expression)?
            .expect_field(&Name::from(self.field_name.clone()))?
            .clone();
        Ok(ArgumentType::Value(field_ty))
    }
}

/// Figure out whether an expression defines an implicit column name.
pub trait InferColumnName {
/// Infer the column name, if any.
Expand All @@ -1206,7 +1259,7 @@ impl<T: InferColumnName> InferColumnName for Option<T> {
impl InferColumnName for ast::Expression {
fn infer_column_name(&mut self) -> Option<Ident> {
match self {
ast::Expression::ColumnName(name) => {
ast::Expression::Name(name) => {
let (_table, col) = name.split_table_and_column();
Some(col)
}
Expand Down
28 changes: 28 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ impl<TV: TypeVarSupport> ArgumentType<TV> {
}
}

/// Expect a [`StructType`].
pub fn expect_struct_type(&self, spanned: &dyn Spanned) -> Result<&StructType<TV>> {
match self {
ArgumentType::Value(ValueType::Simple(SimpleType::Struct(t))) => Ok(t),
_ => Err(Error::annotated(
format!("expected struct type, found {}", self),
spanned.span(),
"type mismatch",
)),
}
}

/// Expect an [`ArrayType`].
pub fn expect_array_type(&self, spanned: &dyn Spanned) -> Result<&ValueType<TV>> {
match self {
Expand Down Expand Up @@ -739,6 +751,22 @@ pub struct StructType<TV: TypeVarSupport = ResolvedTypeVarsOnly> {
}

impl<TV: TypeVarSupport> StructType<TV> {
/// Get the type of a field, or raise an error if the field does not exist.
pub fn expect_field(&self, name: &Name) -> Result<&ValueType<TV>> {
for field in &self.fields {
if let Some(field_name) = &field.name {
if Name::from(field_name.clone()) == *name {
return Ok(&field.ty);
}
}
}
Err(Error::annotated(
format!("no such field {} in {}", name.unescaped_bigquery(), self),
name.span(),
"no such field",
))
}

/// Is this a subtype of `other`?
pub fn is_subtype_of(&self, other: &StructType<TV>) -> bool {
// We are a subtype of `other` if we have the same fields, and each of
Expand Down
13 changes: 8 additions & 5 deletions tests/sql/data_types/structs.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
-- pending: sqlite3 Need to build structs from scratch

CREATE OR REPLACE TABLE __result1 AS
WITH t AS (SELECT 1 AS a)
WITH t AS (SELECT 1 AS a, STRUCT<field INT64>(2) AS s)
SELECT
-- Not allowed on Trino.
-- STRUCT() AS empty_struct,
Expand All @@ -14,7 +14,8 @@ SELECT
STRUCT<a INT64, b INT64>(1, 2) AS named_values_with_type,
STRUCT([1] AS arr) AS struct_with_array,
STRUCT(STRUCT(1 AS a) AS `inner`) AS struct_with_struct,
--STRUCT(1 AS a).a AS struct_field_access,
s.field AS struct_field_access,
STRUCT(1 AS a).a AS struct_expr_field_access,
FROM t;

CREATE OR REPLACE TABLE __expected1 (
Expand All @@ -27,7 +28,8 @@ CREATE OR REPLACE TABLE __expected1 (
named_values_with_type STRUCT<a INT64, b INT64>,
struct_with_array STRUCT<arr ARRAY<INT64>>,
struct_with_struct STRUCT<`inner` STRUCT<a INT64>>,
--struct_field_access INT64,
struct_field_access INT64,
struct_expr_field_access INT64,
);
INSERT INTO __expected1
SELECT
Expand All @@ -39,5 +41,6 @@ SELECT
STRUCT<INT64>(NULL), -- anon_value_with_type
STRUCT(1, 2), -- named_values_with_type
STRUCT([1]), -- struct_with_array
STRUCT(STRUCT(1)); -- struct_with_struct
--1; -- struct_field_access
STRUCT(STRUCT(1)), -- struct_with_struct
2, -- struct_field_access
1; -- struct_expr_field_access

0 comments on commit 86f36ad

Please sign in to comment.