Skip to content

Commit

Permalink
feat(function): add CreateFunctionExecutor & bind_create_function (
Browse files Browse the repository at this point in the history
…#828)

* feat(function): add `CreateFunctionExecutor` & `bind_create_function`

Signed-off-by: Michael Xu <[email protected]>

* fix check

Signed-off-by: Michael Xu <[email protected]>

---------

Signed-off-by: Michael Xu <[email protected]>
  • Loading branch information
xzhseh authored Jan 29, 2024
1 parent d12049a commit 7b49788
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 16 deletions.
78 changes: 68 additions & 10 deletions src/binder/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ use pretty_xmlish::Pretty;
use serde::{Deserialize, Serialize};

use super::*;
use crate::types::DataType as RlDataType;

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct CreateFunction {
name: String,
arg_types: Vec<DataType>,
return_types: DataType,
language: String,
body: String,
pub schema_name: String,
pub name: String,
pub arg_types: Vec<RlDataType>,
pub return_type: RlDataType,
pub language: String,
pub body: String,
}

impl fmt::Display for CreateFunction {
Expand Down Expand Up @@ -46,11 +48,67 @@ impl CreateFunction {
impl Binder {
pub(super) fn bind_create_function(
&mut self,
_name: ObjectName,
_args: Option<Vec<OperateFunctionArg>>,
_return_type: Option<DataType>,
_params: CreateFunctionBody,
name: ObjectName,
args: Option<Vec<OperateFunctionArg>>,
return_type: Option<DataType>,
params: CreateFunctionBody,
) -> Result {
todo!()
let Ok((schema_name, function_name)) = split_name(&name) else {
return Err(BindError::BindFunctionError(
"failed to parse the input function name".to_string(),
));
};

let schema_name = schema_name.to_string();
let name = function_name.to_string();

let Some(return_type) = return_type else {
return Err(BindError::BindFunctionError(
"`return type` must be specified".to_string(),
));
};
let return_type = RlDataType::new(DataTypeKind::from(&return_type), false);

// TODO: language check (e.g., currently only support sql)
let Some(language) = params.language.clone() else {
return Err(BindError::BindFunctionError(
"`language` must be specified".to_string(),
));
};
let language = language.to_string();

// SQL udf function supports both single quote (i.e., as 'select $1 + $2')
// and double dollar (i.e., as $$select $1 + $2$$) for as clause
let body = match &params.as_ {
Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(),
Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(),
None => {
if params.return_.is_none() {
return Err(BindError::BindFunctionError(
"AS or RETURN must be specified".to_string(),
));
}
// Otherwise this is a return expression
// Note: this is a current work around, and we are assuming return sql udf
// will NOT involve complex syntax, so just reuse the logic for select definition
format!("select {}", &params.return_.unwrap().to_string())
}
};

let mut arg_types = vec![];
for arg in args.unwrap_or_default() {
arg_types.push(RlDataType::new(DataTypeKind::from(&arg.data_type), false));
}

let f = self.egraph.add(Node::CreateFunction(CreateFunction {
schema_name,
name,
arg_types,
return_type,
language,
body,
}));

Ok(f)
}
}
10 changes: 10 additions & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,16 @@ impl Binder {
FunctionArgExpr::QualifiedWildcard(_) => todo!("support qualified wildcard"),
}
}

// TODO: sql udf inlining
let _catalog = self.catalog();
let Ok((_schema_name, _function_name)) = split_name(&func.name) else {
return Err(BindError::BindFunctionError(format!(
"failed to parse the function name {}",
func.name
)));
};

let node = match func.name.to_string().to_lowercase().as_str() {
"count" if args.is_empty() => Node::RowCount,
"count" => Node::Count(args[0]),
Expand Down
6 changes: 5 additions & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use egg::{Id, Language};
use itertools::Itertools;

use crate::array;
use crate::catalog::{RootCatalog, TableRefId, DEFAULT_SCHEMA_NAME};
use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId, DEFAULT_SCHEMA_NAME};
use crate::parser::*;
use crate::planner::{Expr as Node, RecExpr, TypeError, TypeSchemaAnalysis};
use crate::types::{DataTypeKind, DataValue};
Expand Down Expand Up @@ -234,6 +234,10 @@ impl Binder {
&self.egraph[id].nodes[0]
}

fn catalog(&self) -> RootCatalogRef {
self.catalog.clone()
}

fn bind_explain(&mut self, query: Statement) -> Result {
let id = self.bind_stmt(query)?;
let id = self.egraph.add(Node::Explain(id));
Expand Down
10 changes: 5 additions & 5 deletions src/catalog/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use crate::types::DataType;

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct FunctionCatalog {
name: String,
arg_types: Vec<DataType>,
return_type: DataType,
language: String,
body: String,
pub name: String,
pub arg_types: Vec<DataType>,
pub return_type: DataType,
pub language: String,
pub body: String,
}

impl FunctionCatalog {
Expand Down
25 changes: 25 additions & 0 deletions src/catalog/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use super::function::FunctionCatalog;
use super::*;

/// The root of all catalogs.
Expand Down Expand Up @@ -99,6 +100,30 @@ impl RootCatalog {
table_id: table.id(),
})
}

pub fn get_function_by_name(
&self,
schema_name: &str,
function_name: &str,
) -> Option<Arc<FunctionCatalog>> {
let schema = self.get_schema_by_name(schema_name)?;
schema.get_function_by_name(function_name)
}

pub fn create_function(
&self,
schema_name: String,
name: String,
arg_types: Vec<DataType>,
return_type: DataType,
language: String,
body: String,
) {
let schema_idx = self.get_schema_id_by_name(&schema_name).unwrap();
let mut inner = self.inner.lock().unwrap();
let schema = inner.schemas.get_mut(&schema_idx).unwrap();
schema.create_function(name, arg_types, return_type, language, body);
}
}

impl Inner {
Expand Down
20 changes: 20 additions & 0 deletions src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ impl SchemaCatalog {
pub fn get_function_by_name(&self, name: &str) -> Option<Arc<FunctionCatalog>> {
self.functions.get(name).cloned()
}

pub fn create_function(
&mut self,
name: String,
arg_types: Vec<DataType>,
return_type: DataType,
language: String,
body: String,
) {
self.functions.insert(
name.clone(),
Arc::new(FunctionCatalog {
name: name.clone(),
arg_types,
return_type,
language,
body,
}),
);
}
}

#[cfg(test)]
Expand Down
34 changes: 34 additions & 0 deletions src/executor/create_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0.

use super::*;
use crate::binder::CreateFunction;
use crate::catalog::RootCatalogRef;

/// The executor of `create function` statement.
pub struct CreateFunctionExecutor {
pub f: CreateFunction,
pub catalog: RootCatalogRef,
}

impl CreateFunctionExecutor {
#[try_stream(boxed, ok = DataChunk, error = ExecutorError)]
pub async fn execute(self) {
let CreateFunction {
schema_name,
name,
arg_types,
return_type,
language,
body,
} = self.f;

self.catalog.create_function(
schema_name.clone(),
name.clone(),
arg_types,
return_type,
language,
body,
);
}
}
8 changes: 8 additions & 0 deletions src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ use self::top_n::TopNExecutor;
use self::values::*;
use self::window::*;
use crate::array::DataChunk;
use crate::executor::create_function::CreateFunctionExecutor;
use crate::planner::{Expr, ExprAnalysis, Optimizer, RecExpr, TypeSchemaAnalysis};
use crate::storage::{Storage, TracedStorageError};
use crate::types::{ColumnIndex, ConvertError, DataType};

mod copy_from_file;
mod copy_to_file;
mod create;
mod create_function;
mod delete;
mod drop;
mod evaluator;
Expand Down Expand Up @@ -302,6 +304,12 @@ impl<S: Storage> Builder<S> {
}
.execute(),

CreateFunction(f) => CreateFunctionExecutor {
f,
catalog: self.optimizer.catalog().clone(),
}
.execute(),

Drop(plan) => DropExecutor {
plan,
storage: self.storage.clone(),
Expand Down

0 comments on commit 7b49788

Please sign in to comment.