diff --git a/examples/Readme.md b/examples/Readme.md index 15f6bec304..8ad782b98c 100644 --- a/examples/Readme.md +++ b/examples/Readme.md @@ -22,6 +22,7 @@ List of Examples: * Kotlin: * CsvExprValueExample: how to create an `ExprValue` for a custom data format, in this case CSV * CustomFunctionsExample: how to create and register user defined functions (UDF) + * CustomProceduresExample: how to create and register stored procedures * EvaluationWithBindings: query evaluation with global bindings * EvaluationWithLazyBindings: query evaluation with global bindings that are lazily evaluated * ParserErrorExample: inspecting errors thrown by the `Parser` diff --git a/examples/src/kotlin/org/partiql/examples/CustomProceduresExample.kt b/examples/src/kotlin/org/partiql/examples/CustomProceduresExample.kt new file mode 100644 index 0000000000..9435d27e16 --- /dev/null +++ b/examples/src/kotlin/org/partiql/examples/CustomProceduresExample.kt @@ -0,0 +1,130 @@ +package org.partiql.examples + +import com.amazon.ion.IonDecimal +import com.amazon.ion.IonStruct +import com.amazon.ion.system.IonSystemBuilder +import org.partiql.examples.util.Example +import org.partiql.lang.CompilerPipeline +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.eval.BindingCase +import org.partiql.lang.eval.BindingName +import org.partiql.lang.eval.Bindings +import org.partiql.lang.eval.EvaluationException +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.ExprValueFactory +import org.partiql.lang.eval.ExprValueType +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedureSignature +import org.partiql.lang.eval.stringValue +import java.io.PrintStream +import java.math.BigDecimal +import java.math.RoundingMode + +private val ion = IonSystemBuilder.standard().build() + +/** + * A simple custom stored procedure that calculates the moon weight for each crewmate of the given crew, storing the + * moon weight in the [EvaluationSession] global bindings. This procedure also returns the number of crewmates we + * calculated the moon weight for, returning -1 if no crew is found. + * + * This example demonstrates how to create a custom stored procedure, check argument types, and modify the + * [EvaluationSession]. + */ +class CalculateCrewMoonWeight(private val valueFactory: ExprValueFactory): StoredProcedure { + private val MOON_GRAVITATIONAL_CONSTANT = BigDecimal(1.622 / 9.81) + + // [StoredProcedureSignature] takes two arguments: + // 1. the name of the stored procedure + // 2. the arity of this stored procedure. Checks to arity are taken care of by the evaluator. However, we must + // still check that the passed arguments are of the right type in our implementation of the procedure. + override val signature = StoredProcedureSignature(name = "calculate_crew_moon_weight", arity = 1) + + // `call` is where you define the logic of the stored procedure given an [EvaluationSession] and a list of + // arguments + override fun call(session: EvaluationSession, args: List): ExprValue { + // We first check that the first argument is a string + val crewName = args.first() + // In the future the evaluator will also verify function argument types, but for now we must verify their type + // manually + if (crewName.type != ExprValueType.STRING) { + val errorContext = PropertyValueMap().also { + it[Property.EXPECTED_ARGUMENT_TYPES] = "STRING" + it[Property.ACTUAL_ARGUMENT_TYPES] = crewName.type.name + it[Property.FUNCTION_NAME] = signature.name + } + throw EvaluationException("First argument to ${signature.name} was not a string", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false) + } + + // Next we check if the given `crewName` is in the [EvaluationSession]'s global bindings. If not, we return 0. + val sessionGlobals = session.globals + val crewBindings = sessionGlobals[BindingName(crewName.stringValue(), BindingCase.INSENSITIVE)] + ?: return valueFactory.newInt(-1) + + // Now that we've confirmed the given `crewName` is in the session's global bindings, we calculate and store + // the moon weight for each crewmate in the crew. + // In addition, we keep a running a tally of how many crewmates we do this for. + var numCalculated = 0 + for (crewmateBinding in crewBindings) { + val crewmate = crewmateBinding.ionValue as IonStruct + val mass = crewmate["mass"] as IonDecimal + val moonWeight = (mass.decimalValue() * MOON_GRAVITATIONAL_CONSTANT).setScale(1, RoundingMode.HALF_UP) + crewmate.add("moonWeight", ion.newDecimal(moonWeight)) + + numCalculated++ + } + return valueFactory.newInt(numCalculated) + } +} + +/** + * Demonstrates the use of custom stored procedure [CalculateCrewMoonWeight] in PartiQL queries. + */ +class CustomProceduresExample(out: PrintStream) : Example(out) { + override fun run() { + /** + * To make custom stored procedures available to the PartiQL query being executed, they must be passed to + * [CompilerPipeline.Builder.addProcedure]. + */ + val pipeline = CompilerPipeline.build(ion) { + addProcedure(CalculateCrewMoonWeight(valueFactory)) + } + + // Here, we initialize the crews to be stored in our global session bindings + val initialCrews = Bindings.ofMap( + mapOf( + "crew1" to pipeline.valueFactory.newFromIonValue( + ion.singleValue("""[ { name: "Neil", mass: 80.5 }, + { name: "Buzz", mass: 72.3 }, + { name: "Michael", mass: 89.9 } ]""")), + "crew2" to pipeline.valueFactory.newFromIonValue( + ion.singleValue("""[ { name: "James", mass: 77.1 }, + { name: "Spock", mass: 81.6 } ]""")) + ) + ) + val session = EvaluationSession.build { globals(initialCrews) } + + val crew1BindingName = BindingName("crew1", BindingCase.INSENSITIVE) + val crew2BindingName = BindingName("crew2", BindingCase.INSENSITIVE) + + out.println("Initial global session bindings:") + print("Crew 1:", "${session.globals[crew1BindingName]}") + print("Crew 2:", "${session.globals[crew2BindingName]}") + + // We call our custom stored procedure using PartiQL's `EXEC` clause. Here we call our stored procedure + // 'calculate_crew_moon_weight' with the arg 'crew1', which outputs the number of crewmates we've calculated + // the moon weight for + val procedureCall = "EXEC calculate_crew_moon_weight 'crew1'" + val procedureCallOutput = pipeline.compile(procedureCall).eval(session) + print("Number of calculated moon weights:", "$procedureCallOutput") + + out.println("Updated global session bindings:") + print("Crew 1:", "${session.globals[crew1BindingName]}") + print("Crew 2:", "${session.globals[crew2BindingName]}") + } +} diff --git a/examples/src/kotlin/org/partiql/examples/util/Main.kt b/examples/src/kotlin/org/partiql/examples/util/Main.kt index 37e830283f..f2e955be87 100644 --- a/examples/src/kotlin/org/partiql/examples/util/Main.kt +++ b/examples/src/kotlin/org/partiql/examples/util/Main.kt @@ -15,6 +15,7 @@ private val examples = mapOf( // Kotlin Examples CsvExprValueExample::class.java.simpleName to CsvExprValueExample(System.out), CustomFunctionsExample::class.java.simpleName to CustomFunctionsExample(System.out), + CustomProceduresExample::class.java.simpleName to CustomProceduresExample(System.out), EvaluationWithBindings::class.java.simpleName to EvaluationWithBindings(System.out), EvaluationWithLazyBindings::class.java.simpleName to EvaluationWithLazyBindings(System.out), ParserErrorExample::class.java.simpleName to ParserErrorExample(System.out), diff --git a/examples/test/org/partiql/examples/CustomProceduresExampleTest.kt b/examples/test/org/partiql/examples/CustomProceduresExampleTest.kt new file mode 100644 index 0000000000..5866a8cf5e --- /dev/null +++ b/examples/test/org/partiql/examples/CustomProceduresExampleTest.kt @@ -0,0 +1,24 @@ +package org.partiql.examples + +import org.partiql.examples.util.Example +import java.io.PrintStream + +class CustomProceduresExampleTest : BaseExampleTest() { + override fun example(out: PrintStream): Example = CustomProceduresExample(out) + + override val expected = """ + |Initial global session bindings: + |Crew 1: + | [{'name': 'Neil', 'mass': 80.5}, {'name': 'Buzz', 'mass': 72.3}, {'name': 'Michael', 'mass': 89.9}] + |Crew 2: + | [{'name': 'James', 'mass': 77.1}, {'name': 'Spock', 'mass': 81.6}] + |Number of calculated moon weights: + | 3 + |Updated global session bindings: + |Crew 1: + | [{'name': 'Neil', 'mass': 80.5, 'moonWeight': 13.3}, {'name': 'Buzz', 'mass': 72.3, 'moonWeight': 12.0}, {'name': 'Michael', 'mass': 89.9, 'moonWeight': 14.9}] + |Crew 2: + | [{'name': 'James', 'mass': 77.1}, {'name': 'Spock', 'mass': 81.6}] + | + """.trimMargin() +} diff --git a/lang/resources/org/partiql/type-domains/partiql.ion b/lang/resources/org/partiql/type-domains/partiql.ion index 50084f905f..f213d87514 100644 --- a/lang/resources/org/partiql/type-domains/partiql.ion +++ b/lang/resources/org/partiql/type-domains/partiql.ion @@ -26,7 +26,12 @@ (where where::(? expr))) // Data definition operations also cannot be composed with other `expr` nodes. - (ddl op::ddl_op)) + (ddl op::ddl_op) + + // Stored procedure calls are only allowed at the top level of a query and cannot be used as an expression + // Currently supports stored procedure calls with the unnamed argument syntax: + // EXEC [.*] + (exec procedure_name::symbol args::(* expr 0))) // The expressions that can result in values. (sum expr diff --git a/lang/src/org/partiql/lang/CompilerPipeline.kt b/lang/src/org/partiql/lang/CompilerPipeline.kt index ebb2f93981..e44810ba9e 100644 --- a/lang/src/org/partiql/lang/CompilerPipeline.kt +++ b/lang/src/org/partiql/lang/CompilerPipeline.kt @@ -18,6 +18,7 @@ import com.amazon.ion.* import org.partiql.lang.ast.* import org.partiql.lang.eval.* import org.partiql.lang.eval.builtins.* +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure import org.partiql.lang.syntax.* /** @@ -35,7 +36,13 @@ data class StepContext( * Includes built-in functions as well as custom functions added while the [CompilerPipeline] * was being built. */ - val functions: @JvmSuppressWildcards Map + val functions: @JvmSuppressWildcards Map, + + /** + * Returns a list of all stored procedures which are available for execution. + * Only includes the custom stored procedures added while the [CompilerPipeline] was being built. + */ + val procedures: @JvmSuppressWildcards Map ) /** @@ -65,6 +72,12 @@ interface CompilerPipeline { */ val functions: @JvmSuppressWildcards Map + /** + * Returns a list of all stored procedures which are available for execution. + * Only includes the custom stored procedures added while the [CompilerPipeline] was being built. + */ + val procedures: @JvmSuppressWildcards Map + /** Compiles the specified PartiQL query using the configured parser. */ fun compile(query: String): Expression @@ -106,6 +119,7 @@ interface CompilerPipeline { private var parser: Parser? = null private var compileOptions: CompileOptions? = null private val customFunctions: MutableMap = HashMap() + private val customProcedures: MutableMap = HashMap() private val preProcessingSteps: MutableList = ArrayList() /** @@ -137,6 +151,13 @@ interface CompilerPipeline { */ fun addFunction(function: ExprFunction): Builder = this.apply { customFunctions[function.name] = function } + /** + * Add a custom stored procedure which will be callable by the compiled queries. + * + * Stored procedures added here will replace any built-in procedure with the same name. + */ + fun addProcedure(procedure: StoredProcedure): Builder = this.apply { customProcedures[procedure.signature.name] = procedure } + /** Adds a preprocessing step to be executed after parsing but before compilation. */ fun addPreprocessingStep(step: ProcessingStep): Builder = this.apply { preProcessingSteps.add(step) } @@ -153,6 +174,7 @@ interface CompilerPipeline { parser ?: SqlParser(valueFactory.ion), compileOptions ?: CompileOptions.standard(), allFunctions, + customProcedures, preProcessingSteps) } } @@ -163,17 +185,18 @@ private class CompilerPipelineImpl( private val parser: Parser, override val compileOptions: CompileOptions, override val functions: Map, + override val procedures: Map, private val preProcessingSteps: List ) : CompilerPipeline { - private val compiler = EvaluatingCompiler(valueFactory, functions, compileOptions) + private val compiler = EvaluatingCompiler(valueFactory, functions, procedures, compileOptions) override fun compile(query: String): Expression { return compile(parser.parseExprNode(query)) } override fun compile(query: ExprNode): Expression { - val context = StepContext(valueFactory, compileOptions, functions) + val context = StepContext(valueFactory, compileOptions, functions, procedures) val preProcessedQuery = preProcessingSteps.fold(query) { currentExprNode, step -> step(currentExprNode, context) diff --git a/lang/src/org/partiql/lang/ast/AstSerialization.kt b/lang/src/org/partiql/lang/ast/AstSerialization.kt index 1052eb9baf..27a6b01529 100644 --- a/lang/src/org/partiql/lang/ast/AstSerialization.kt +++ b/lang/src/org/partiql/lang/ast/AstSerialization.kt @@ -94,6 +94,7 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem): is DropTable -> case { writeDropTable(expr) } is DropIndex -> case { writeDropIndex(expr) } is Parameter -> case { writeParameter(expr)} + is Exec -> throw UnsupportedOperationException("EXEC clause not supported by the V0 AST") }.toUnit() } } diff --git a/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt b/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt index 902a51a19b..5380d3f4fb 100644 --- a/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt +++ b/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt @@ -2,6 +2,7 @@ package org.partiql.lang.ast import com.amazon.ionelement.api.toIonElement import org.partiql.lang.domains.PartiqlAst +import org.partiql.pig.runtime.SymbolPrimitive import org.partiql.pig.runtime.asPrimitive /** Converts an [ExprNode] to a [PartiqlAst.statement]. */ @@ -16,12 +17,16 @@ fun ExprNode.toAstStatement(): PartiqlAst.Statement { is CreateTable, is CreateIndex, is DropTable, is DropIndex -> toAstDdl() + is Exec -> toAstExec() } } private fun PartiQlMetaContainer.toElectrolyteMetaContainer(): ElectrolyteMetaContainer = com.amazon.ionelement.api.metaContainerOf(map { it.tag to it }) +private fun SymbolicName.toSymbolPrimitive() : SymbolPrimitive = + SymbolPrimitive(this.name, this.metas.toElectrolyteMetaContainer()) + private fun ExprNode.toAstDdl(): PartiqlAst.Statement { val thiz = this val metas = metas.toElectrolyteMetaContainer() @@ -30,7 +35,7 @@ private fun ExprNode.toAstDdl(): PartiqlAst.Statement { when(thiz) { is Literal, is LiteralMissing, is VariableReference, is Parameter, is NAry, is CallAgg, is Typed, is Path, is SimpleCase, is SearchedCase, is Select, is Struct, is Seq, - is DataManipulation -> error("Can't convert ${thiz.javaClass} to PartiqlAst.ddl") + is DataManipulation, is Exec -> error("Can't convert ${thiz.javaClass} to PartiqlAst.ddl") is CreateTable -> ddl(createTable(thiz.tableName), metas) is CreateIndex -> ddl(createIndex(identifier(thiz.tableName, caseSensitive()), thiz.keys.map { it.toAstExpr() }), metas) @@ -48,6 +53,18 @@ private fun ExprNode.toAstDdl(): PartiqlAst.Statement { } } +private fun ExprNode.toAstExec() : PartiqlAst.Statement { + val node = this + val metas = metas.toElectrolyteMetaContainer() + + return PartiqlAst.build { + when (node) { + is Exec -> exec_(node.procedureName.toSymbolPrimitive(), node.args.map { it.toAstExpr() }, metas) + else -> error("Can't convert ${node.javaClass} to PartiqlAst.Statement.Exec") + } + } +} + fun ExprNode.toAstExpr(): PartiqlAst.Expr { val node = this val metas = this.metas.toElectrolyteMetaContainer() @@ -147,8 +164,8 @@ fun ExprNode.toAstExpr(): PartiqlAst.Expr { SeqType.BAG -> bag(node.values.map { it.toAstExpr() }) } - // These are handled by `toAstDml()` - is DataManipulation, is CreateTable, is CreateIndex, is DropTable, is DropIndex -> + // These are handled by `toAstDml()`, `toAstDdl()`, and `toAstExec()` + is DataManipulation, is CreateTable, is CreateIndex, is DropTable, is DropIndex, is Exec -> error("Can't transform ${node.javaClass} to a PartiqlAst.expr }") } } diff --git a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt index d11bc77c8b..0cbca85f1c 100644 --- a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt +++ b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt @@ -2,6 +2,7 @@ package org.partiql.lang.ast import com.amazon.ion.IonSystem import com.amazon.ionelement.api.toIonValue +import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.domains.PartiqlAst.CaseSensitivity import org.partiql.lang.domains.PartiqlAst.DdlOp import org.partiql.lang.domains.PartiqlAst.DmlOp @@ -33,6 +34,7 @@ private class StatementTransformer(val ion: IonSystem) { is Statement.Query -> stmt.toExprNode() is Statement.Dml -> stmt.toExprNode() is Statement.Ddl -> stmt.toExprNode() + is Statement.Exec -> stmt.toExprNode() } private fun ElectrolyteMetaContainer.toPartiQlMetaContainer(): PartiQlMetaContainer { @@ -344,4 +346,8 @@ private class StatementTransformer(val ion: IonSystem) { metas = metas) } } + + private fun Statement.Exec.toExprNode(): ExprNode { + return Exec(procedureName.toSymbolicName(), this.args.toExprNodeList(), metas.toPartiQlMetaContainer()) + } } diff --git a/lang/src/org/partiql/lang/ast/ast.kt b/lang/src/org/partiql/lang/ast/ast.kt index 8be4f19948..8ac3d463a2 100644 --- a/lang/src/org/partiql/lang/ast/ast.kt +++ b/lang/src/org/partiql/lang/ast/ast.kt @@ -103,6 +103,9 @@ sealed class ExprNode : AstNode(), HasMetas { is Parameter -> { copy(metas = metas) } + is Exec -> { + copy(metas = metas) + } } } } @@ -210,6 +213,19 @@ data class Typed( override val children: List = listOf(expr, type) } +//******************************** +// Stored procedure clauses +//******************************** + +/** Represents a call to a stored procedure, i.e. `EXEC stored_procedure [.*]` */ +data class Exec( + val procedureName: SymbolicName, + val args: List, + override val metas: MetaContainer +) : ExprNode() { + override val children: List = args +} + //******************************** // Path expressions //******************************** diff --git a/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt b/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt index f322988b72..402dc8a808 100644 --- a/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt +++ b/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt @@ -49,6 +49,7 @@ open class AstRewriterBase : AstRewriter { is CreateIndex -> rewriteCreateIndex(node) is DropTable -> rewriteDropTable(node) is DropIndex -> rewriteDropIndex(node) + is Exec -> rewriteExec(node) } open fun rewriteMetas(itemWithMetas: HasMetas): MetaContainer = itemWithMetas.metas @@ -398,4 +399,10 @@ open class AstRewriterBase : AstRewriter { rewriteVariableReference(node.identifier) as VariableReference, rewriteMetas(node)) + open fun rewriteExec(node: Exec): Exec = + Exec( + rewriteSymbolicName(node.procedureName), + node.args.map { rewriteExprNode(it) }, + rewriteMetas(node)) + } diff --git a/lang/src/org/partiql/lang/ast/passes/AstWalker.kt b/lang/src/org/partiql/lang/ast/passes/AstWalker.kt index d57efd20d0..ef46a8d314 100644 --- a/lang/src/org/partiql/lang/ast/passes/AstWalker.kt +++ b/lang/src/org/partiql/lang/ast/passes/AstWalker.kt @@ -116,6 +116,7 @@ open class AstWalker(private val visitor: AstVisitor) { } } is CreateTable, is DropTable, is DropIndex -> case { } + is Exec -> case { } }.toUnit() } } diff --git a/lang/src/org/partiql/lang/errors/ErrorAndErrorContexts.kt b/lang/src/org/partiql/lang/errors/ErrorAndErrorContexts.kt index ab2e7e7afe..c84d57eff7 100644 --- a/lang/src/org/partiql/lang/errors/ErrorAndErrorContexts.kt +++ b/lang/src/org/partiql/lang/errors/ErrorAndErrorContexts.kt @@ -72,6 +72,7 @@ enum class Property(val propertyName: String, val propertyType: PropertyType) { LIKE_PATTERN("pattern", STRING_CLASS), LIKE_ESCAPE("escape_char", STRING_CLASS), FUNCTION_NAME("function_name", STRING_CLASS), + PROCEDURE_NAME("procedure_name", STRING_CLASS), EXPECTED_ARGUMENT_TYPES("expected_types", STRING_CLASS), ACTUAL_ARGUMENT_TYPES("actual_types", STRING_CLASS), FEATURE_NAME("FEATURE_NAME", STRING_CLASS), diff --git a/lang/src/org/partiql/lang/errors/ErrorCode.kt b/lang/src/org/partiql/lang/errors/ErrorCode.kt index 5ae6a600f7..05bcf58011 100644 --- a/lang/src/org/partiql/lang/errors/ErrorCode.kt +++ b/lang/src/org/partiql/lang/errors/ErrorCode.kt @@ -297,6 +297,16 @@ enum class ErrorCode(private val category: ErrorCategory, LOC_TOKEN, "Aggregate function calls take 1 argument only"), + PARSE_NO_STORED_PROCEDURE_PROVIDED( + ErrorCategory.PARSER, + LOC_TOKEN, + "No stored procedure provided"), + + PARSE_EXEC_AT_UNEXPECTED_LOCATION( + ErrorCategory.PARSER, + LOC_TOKEN, + "EXEC call found at unexpected location"), + PARSE_MALFORMED_JOIN( ErrorCategory.PARSER, LOC_TOKEN, @@ -407,11 +417,24 @@ enum class ErrorCode(private val category: ErrorCategory, "No such function: ${errorContext?.get(Property.FUNCTION_NAME)?.stringValue() ?: UNKNOWN} " }, + EVALUATOR_NO_SUCH_PROCEDURE( + ErrorCategory.EVALUATOR, + LOCATION + setOf(Property.PROCEDURE_NAME), + ""){ + override fun getErrorMessage(errorContext: PropertyValueMap?): String = + "No such stored procedure: ${errorContext?.get(Property.PROCEDURE_NAME)?.stringValue() ?: UNKNOWN} " + }, + EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_FUNC_CALL( ErrorCategory.EVALUATOR, LOCATION + setOf(Property.EXPECTED_ARITY_MIN, Property.EXPECTED_ARITY_MAX), "Incorrect number of arguments to function call"), + EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL( + ErrorCategory.EVALUATOR, + LOCATION + setOf(Property.EXPECTED_ARITY_MIN, Property.EXPECTED_ARITY_MAX), + "Incorrect number of arguments to procedure call"), + EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_FUNC_CALL( ErrorCategory.EVALUATOR, LOCATION + setOf(Property.EXPECTED_ARGUMENT_TYPES, Property.ACTUAL_ARGUMENT_TYPES, Property.FUNCTION_NAME), @@ -422,6 +445,16 @@ enum class ErrorCode(private val category: ErrorCategory, "got: ${errorContext?.get(Property.ACTUAL_ARGUMENT_TYPES) ?: UNKNOWN}" }, + EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL( + ErrorCategory.EVALUATOR, + LOCATION + setOf(Property.EXPECTED_ARGUMENT_TYPES, Property.ACTUAL_ARGUMENT_TYPES, Property.FUNCTION_NAME), + "Incorrect type of arguments to procedure call") { + override fun getErrorMessage(errorContext: PropertyValueMap?): String = + "Invalid argument types for ${errorContext?.get(Property.FUNCTION_NAME) ?: UNKNOWN}, " + + "expected: ${errorContext?.get(Property.EXPECTED_ARGUMENT_TYPES) ?: UNKNOWN} " + + "got: ${errorContext?.get(Property.ACTUAL_ARGUMENT_TYPES) ?: UNKNOWN}" + }, + EVALUATOR_CONCAT_FAILED_DUE_TO_INCOMPATIBLE_TYPE( ErrorCategory.EVALUATOR, LOCATION + setOf(Property.ACTUAL_ARGUMENT_TYPES), diff --git a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt index 96eb8144aa..4c2a47f702 100644 --- a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt +++ b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt @@ -21,6 +21,7 @@ import org.partiql.lang.ast.passes.* import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.errors.* import org.partiql.lang.eval.binding.* +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure import org.partiql.lang.eval.like.PatternPart import org.partiql.lang.eval.like.executePattern import org.partiql.lang.eval.like.parsePattern @@ -55,6 +56,7 @@ import kotlin.collections.* internal class EvaluatingCompiler( private val valueFactory: ExprValueFactory, private val functions: Map, + private val procedures: Map, private val compileOptions: CompileOptions = CompileOptions.standard() ) { private val thunkFactory = ThunkFactory(compileOptions.thunkOptions) @@ -279,6 +281,8 @@ internal class EvaluatingCompiler( is CreateIndex, is DropIndex, is DropTable -> compileDdl(expr) + is Exec -> compileExec(expr) + } } @@ -1925,6 +1929,48 @@ internal class EvaluatingCompiler( } } + private fun compileExec(node: Exec): ThunkEnv { + val (procedureName, args, metas: MetaContainer) = node + val procedure = procedures[procedureName.name] ?: err( + "No such stored procedure: ${procedureName.name}", + ErrorCode.EVALUATOR_NO_SUCH_PROCEDURE, + errorContextFrom(metas).also { + it[Property.PROCEDURE_NAME] = procedureName.name + }, + internal = false) + + // Check arity + if (args.size !in procedure.signature.arity) { + val errorContext = errorContextFrom(metas).also { + it[Property.EXPECTED_ARITY_MIN] = procedure.signature.arity.first + it[Property.EXPECTED_ARITY_MAX] = procedure.signature.arity.last + } + + val message = when { + procedure.signature.arity.first == 1 && procedure.signature.arity.last == 1 -> + "${procedure.signature.name} takes a single argument, received: ${args.size}" + procedure.signature.arity.first == procedure.signature.arity.last -> + "${procedure.signature.name} takes exactly ${procedure.signature.arity.first} arguments, received: ${args.size}" + else -> + "${procedure.signature.name} takes between ${procedure.signature.arity.first} and " + + "${procedure.signature.arity.last} arguments, received: ${args.size}" + } + + throw EvaluationException(message, + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false) + } + + // Compile the procedure's arguments + val argThunks = args.map { compileExprNode(it) } + + return thunkFactory.thunkEnv(metas) { env -> + val procedureArgValues = argThunks.map { it(env) } + procedure.call(env.session, procedureArgValues) + } + } + /** A special wrapper for `UNPIVOT` values as a BAG. */ private class UnpivotedExprValue(private val values: Iterable) : BaseExprValue() { override val type = ExprValueType.BAG diff --git a/lang/src/org/partiql/lang/eval/builtins/storedprocedure/StoredProcedure.kt b/lang/src/org/partiql/lang/eval/builtins/storedprocedure/StoredProcedure.kt new file mode 100644 index 0000000000..8e2b17dc23 --- /dev/null +++ b/lang/src/org/partiql/lang/eval/builtins/storedprocedure/StoredProcedure.kt @@ -0,0 +1,40 @@ +package org.partiql.lang.eval.builtins.storedprocedure + +import org.partiql.lang.eval.EvaluatingCompiler +import org.partiql.lang.eval.EvaluationSession +import org.partiql.lang.eval.ExprFunction +import org.partiql.lang.eval.ExprValue +import org.partiql.lang.eval.Expression + +/** + * A typed version of a stored procedure signature. This signature includes the stored procedure's [name] and [arity]. + */ +data class StoredProcedureSignature(val name: String, val arity: IntRange) { + constructor(name: String, arity: Int) : this(name, (arity..arity)) +} + +/** + * Represents a stored procedure that can be invoked. + * + * Stored procedures differ from functions (i.e. [ExprFunction]) in that: + * 1. stored procedures are allowed to have side-effects + * 2. stored procedures are only allowed at the top level of a query and cannot be used as an [Expression] (i.e. stored + * procedures can only be called using `EXEC sproc 'arg1', 'arg2'` and cannot be called from queries such as + * `SELECT * FROM (EXEC sproc 'arg1', 'arg2')` + */ +interface StoredProcedure { + /** + * [StoredProcedureSignature] representing the stored procedure's name and arity to be referenced in stored + * procedure calls. + */ + val signature: StoredProcedureSignature + + /** + * Invokes the stored procedure. Proper arity is checked by the [EvaluatingCompiler], but argument type checking + * is left to the implementation. + * + * @param session the calling environment session + * @param args argument list supplied to the stored procedure + */ + fun call(session: EvaluationSession, args: List): ExprValue +} \ No newline at end of file diff --git a/lang/src/org/partiql/lang/syntax/SqlParser.kt b/lang/src/org/partiql/lang/syntax/SqlParser.kt index a8bdf85836..3a7d0d5b03 100644 --- a/lang/src/org/partiql/lang/syntax/SqlParser.kt +++ b/lang/src/org/partiql/lang/syntax/SqlParser.kt @@ -114,7 +114,8 @@ class SqlParser(private val ion: IonSystem) : Parser { DROP_TABLE, DROP_INDEX, CREATE_INDEX, - PARAMETER; + PARAMETER, + EXEC; val identifier = name.toLowerCase() } @@ -335,6 +336,10 @@ class SqlParser(private val ion: IonSystem) : Parser { } } } + EXEC -> { + val procedureName = SymbolicName(token?.text!!.toLowerCase(), token.toSourceLocationMetaContainer()) + Exec(procedureName, children.map { it.toExprNode() }, metas) + } CALL_AGG -> { val funcExpr = VariableReference( @@ -1015,6 +1020,7 @@ class SqlParser(private val ion: IonSystem) : Parser { tail.tail.parseFunctionCall(head!!) else -> err("Unexpected keyword", PARSE_UNEXPECTED_KEYWORD) } + "exec" -> tail.parseExec() else -> err("Unexpected keyword", PARSE_UNEXPECTED_KEYWORD) } LEFT_PAREN -> { @@ -1712,6 +1718,28 @@ class SqlParser(private val ion: IonSystem) : Parser { } } + private fun List.parseExec(): ParseNode { + var rem = this + if (rem.head?.type == EOF) { + rem.err("No stored procedure provided", PARSE_NO_STORED_PROCEDURE_PROVIDED) + } + + val procedureName = rem.head + rem = rem.tail + + // Stored procedure call has no args + if (rem.head?.type == EOF) { + return ParseNode(EXEC, procedureName, emptyList(), rem) + } + + else if (rem.head?.type == LEFT_PAREN) { + rem.err("Unexpected $LEFT_PAREN found following stored procedure call", PARSE_UNEXPECTED_TOKEN) + } + + return rem.parseArgList(aliasSupportType = NONE, mode = NORMAL_ARG_LIST) + .copy(type = EXEC, token = procedureName) + } + /** * Parses substring * @@ -2216,9 +2244,26 @@ class SqlParser(private val ion: IonSystem) : Parser { return ParseNode(ARG_LIST, null, items, rem) } + /** + * Checks that the given [Token] list does not have any top-level tokens outside of the top level query. Throws + * an error if so. + * + * Currently only checks for `EXEC`. DDL and DML along with corresponding error codes to be added in the future + * (https://github.com/partiql/partiql-lang-kotlin/issues/354). + */ + private fun List.checkUnexpectedTopLevelToken() { + for ((index, token) in this.withIndex()) { + if (token.keywordText == "exec" && index != 0) { + token.err("EXEC call found at unexpected location", PARSE_EXEC_AT_UNEXPECTED_LOCATION) + } + } + } + /** Entry point into the parser. */ override fun parseExprNode(source: String): ExprNode { - val node = SqlLexer(ion).tokenize(source).parseExpression() + val tokens = SqlLexer(ion).tokenize(source) + tokens.checkUnexpectedTopLevelToken() + val node = tokens.parseExpression() val rem = node.remaining if (!rem.onlyEndOfStatement()) { when(rem.head?.type ) { diff --git a/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt b/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt index 76800282fc..a6dd8672d6 100644 --- a/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt +++ b/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt @@ -1546,4 +1546,87 @@ class ParserErrorsTest : TestBase() { Property.TOKEN_TYPE to TokenType.EOF, Property.TOKEN_VALUE to ion.newSymbol("EOF"))) + //**************************************** + // EXEC clause parsing errors + //**************************************** + + @Test + fun execNoStoredProcedureProvided() = checkInputThrowingParserException( + "EXEC", + ErrorCode.PARSE_NO_STORED_PROCEDURE_PROVIDED, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 5L, + Property.TOKEN_TYPE to TokenType.EOF, + Property.TOKEN_VALUE to ion.newSymbol("EOF"))) + + @Test + fun execCommaBetweenStoredProcedureAndArg() = checkInputThrowingParserException( + "EXEC foo, arg0, arg1", + ErrorCode.PARSE_UNEXPECTED_TERM, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 9L, + Property.TOKEN_TYPE to TokenType.COMMA, + Property.TOKEN_VALUE to ion.newSymbol(","))) + + @Test + fun execArgTrailingComma() = checkInputThrowingParserException( + "EXEC foo arg0, arg1,", + ErrorCode.PARSE_UNEXPECTED_TERM, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 21L, + Property.TOKEN_TYPE to TokenType.EOF, + Property.TOKEN_VALUE to ion.newSymbol("EOF"))) + + @Test + fun execUnexpectedParen() = checkInputThrowingParserException( + "EXEC foo()", + ErrorCode.PARSE_UNEXPECTED_TOKEN, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 9L, + Property.TOKEN_TYPE to TokenType.LEFT_PAREN, + Property.TOKEN_VALUE to ion.newSymbol("("))) + + @Test + fun execUnexpectedParenWithArgs() = checkInputThrowingParserException( + "EXEC foo(arg0, arg1)", + ErrorCode.PARSE_UNEXPECTED_TOKEN, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 9L, + Property.TOKEN_TYPE to TokenType.LEFT_PAREN, + Property.TOKEN_VALUE to ion.newSymbol("("))) + + @Test + fun execAtUnexpectedLocation() = checkInputThrowingParserException( + "EXEC EXEC", + ErrorCode.PARSE_EXEC_AT_UNEXPECTED_LOCATION, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.TOKEN_TYPE to TokenType.KEYWORD, + Property.TOKEN_VALUE to ion.newSymbol("exec"))) + + @Test + fun execAtUnexpectedLocationAfterExec() = checkInputThrowingParserException( + "EXEC foo EXEC", + ErrorCode.PARSE_EXEC_AT_UNEXPECTED_LOCATION, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 10L, + Property.TOKEN_TYPE to TokenType.KEYWORD, + Property.TOKEN_VALUE to ion.newSymbol("exec"))) + + @Test + fun execAtUnexpectedLocationInExpression() = checkInputThrowingParserException( + "SELECT * FROM (EXEC undrop 'foo')", + ErrorCode.PARSE_EXEC_AT_UNEXPECTED_LOCATION, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 16L, + Property.TOKEN_TYPE to TokenType.KEYWORD, + Property.TOKEN_VALUE to ion.newSymbol("exec"))) } diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerExecTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerExecTests.kt new file mode 100644 index 0000000000..9c37f3f290 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerExecTests.kt @@ -0,0 +1,274 @@ +package org.partiql.lang.eval + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.CompilerPipeline +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.errors.PropertyValueMap +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure +import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedureSignature +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.softAssert +import org.partiql.lang.util.to + +private fun createWrongSProcErrorContext(arg: ExprValue, expectedArgTypes: String, procName: String): PropertyValueMap { + val errorContext = PropertyValueMap() + errorContext[Property.EXPECTED_ARGUMENT_TYPES] = expectedArgTypes + errorContext[Property.ACTUAL_ARGUMENT_TYPES] = arg.type.name + errorContext[Property.FUNCTION_NAME] = procName + return errorContext +} + +/** + * Simple stored procedure that takes no arguments and outputs 0. + */ +private class ZeroArgProcedure(val valueFactory: ExprValueFactory): StoredProcedure { + override val signature = StoredProcedureSignature("zero_arg_procedure", 0) + + override fun call(session: EvaluationSession, args: List): ExprValue { + return valueFactory.newInt(0) + } +} + +/** + * Simple stored procedure that takes no arguments and outputs -1. Used to show that added stored procedures of the + * same name will be overridden. + */ +private class OverriddenZeroArgProcedure(val valueFactory: ExprValueFactory): StoredProcedure { + override val signature = StoredProcedureSignature("zero_arg_procedure", 0) + + override fun call(session: EvaluationSession, args: List): ExprValue { + return valueFactory.newInt(-1) + } +} + +/** + * Simple stored procedure that takes one integer argument and outputs that argument back. + */ +private class OneArgProcedure(val valueFactory: ExprValueFactory): StoredProcedure { + override val signature = StoredProcedureSignature("one_arg_procedure", 1) + + override fun call(session: EvaluationSession, args: List): ExprValue { + val arg = args.first() + if (arg.type != ExprValueType.INT) { + val errorContext = createWrongSProcErrorContext(arg, "INT", signature.name) + throw EvaluationException("invalid first argument", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false) + } + return arg + } +} + +/** + * Simple stored procedure that takes two integer arguments and outputs the args as a string separated by + * a space. + */ +private class TwoArgProcedure(val valueFactory: ExprValueFactory): StoredProcedure { + override val signature = StoredProcedureSignature("two_arg_procedure", 2) + + override fun call(session: EvaluationSession, args: List): ExprValue { + val arg1 = args.first() + if (arg1.type != ExprValueType.INT) { + val errorContext = createWrongSProcErrorContext(arg1, "INT", signature.name) + throw EvaluationException("invalid first argument", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false) + } + + val arg2 = args[1] + if (arg2.type != ExprValueType.INT) { + val errorContext = createWrongSProcErrorContext(arg2, "INT", signature.name) + throw EvaluationException("invalid second argument", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false) + } + return valueFactory.newString("$arg1 $arg2") + } +} + +/** + * Simple stored procedure that takes one string argument and checks if the binding (case-insensitive) is in the + * current session's global bindings. If so, returns the value associated with that binding. Otherwise, returns missing. + */ +private class OutputBindingProcedure(val valueFactory: ExprValueFactory): StoredProcedure { + override val signature = StoredProcedureSignature("output_binding", 1) + + override fun call(session: EvaluationSession, args: List): ExprValue { + val arg = args.first() + if (arg.type != ExprValueType.STRING) { + val errorContext = createWrongSProcErrorContext(arg, "STRING", signature.name) + throw EvaluationException("invalid first argument", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + errorContext, + internal = false) + } + val bindingName = BindingName(arg.stringValue(), BindingCase.INSENSITIVE) + return when(val value = session.globals[bindingName]) { + null -> valueFactory.missingValue + else -> value + } + } +} + +class EvaluatingCompilerExecTest : EvaluatorTestBase() { + private val session = mapOf("A" to "[ { id : 1 } ]").toSession() + + /** + * Custom [CompilerPipeline] w/ additional stored procedures + */ + private val pipeline = CompilerPipeline.build(ion) { + addProcedure(OverriddenZeroArgProcedure(valueFactory)) + addProcedure(ZeroArgProcedure(valueFactory)) + addProcedure(OneArgProcedure(valueFactory)) + addProcedure(TwoArgProcedure(valueFactory)) + addProcedure(OutputBindingProcedure(valueFactory)) + } + + /** + * Runs the given [query] with the provided [session] using the custom [CompilerPipeline] with additional stored + * procedures to query. + */ + private fun evalSProc(query: String, session: EvaluationSession): ExprValue { + val e = pipeline.compile(query) + return e.eval(session) + } + + /** + * Similar to [EvaluatorTestBase]'s [runTestCase], but evaluates using a [CompilerPipeline] with added stored + * procedures. + */ + private fun runSProcTestCase(tc: EvaluatorTestCase, session: EvaluationSession) { + val queryExprValue = evalSProc(tc.sqlUnderTest, session) + val expectedExprValue = evalSProc(tc.expectedSql, session) + + if(!expectedExprValue.exprEquals(queryExprValue)) { + println("Expected ionValue : ${expectedExprValue.ionValue}") + println("Actual ionValue : ${queryExprValue.ionValue}") + + fail("Expected and actual ExprValue instances are not equivalent") + } + } + + /** + * Similar to [EvaluatorTestBase]'s [checkInputThrowingEvaluationException], but evaluates using a + * [CompilerPipeline] with added stored procedures. + */ + private fun checkInputThrowingEvaluationExceptionSProc(tc: EvaluatorErrorTestCase, session: EvaluationSession) { + softAssert { + try { + val result = evalSProc(tc.sqlUnderTest, session = session).ionValue; + fail("Expected EvaluationException but there was no Exception. " + + "The unexpected result was: \n${result.toPrettyString()}") + } + catch (e: EvaluationException) { + if (tc.cause != null) assertThat(e).hasRootCauseExactlyInstanceOf(tc.cause.java) + checkErrorAndErrorContext(tc.errorCode, e, tc.expectErrorContextValues) + } + catch (e: Exception) { + fail("Expected EvaluationException but a different exception was thrown:\n\t $e") + } + } + } + + private class ArgsProviderValid : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // OverriddenZeroArgProcedure w/ same name as ZeroArgProcedure overridden + EvaluatorTestCase( + "EXEC zero_arg_procedure", + "0"), + EvaluatorTestCase( + "EXEC one_arg_procedure 1", + "1"), + EvaluatorTestCase( + "EXEC two_arg_procedure 1, 2", + "'1 2'"), + EvaluatorTestCase( + "EXEC output_binding 'A'", + "[{'id':1}]"), + EvaluatorTestCase( + "EXEC output_binding 'B'", + "MISSING")) + } + + @ParameterizedTest + @ArgumentsSource(ArgsProviderValid::class) + fun validTests(tc: EvaluatorTestCase) = runSProcTestCase(tc, session) + + + private class ArgsProviderError : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // call function that is not a stored procedure + EvaluatorErrorTestCase( + "EXEC utcnow", + ErrorCode.EVALUATOR_NO_SUCH_PROCEDURE, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.PROCEDURE_NAME to "utcnow")), + // call function that is not a stored procedure, w/ args + EvaluatorErrorTestCase( + "EXEC substring 0, 1, 'foo'", + ErrorCode.EVALUATOR_NO_SUCH_PROCEDURE, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.PROCEDURE_NAME to "substring")), + // invalid # args to sproc (too many) + EvaluatorErrorTestCase( + "EXEC zero_arg_procedure 1", + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.EXPECTED_ARITY_MIN to 0, + Property.EXPECTED_ARITY_MAX to 0)), + // invalid # args to sproc (too many) + EvaluatorErrorTestCase( + "EXEC two_arg_procedure 1, 2, 3", + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.EXPECTED_ARITY_MIN to 2, + Property.EXPECTED_ARITY_MAX to 2)), + // invalid # args to sproc (too few) + EvaluatorErrorTestCase( + "EXEC one_arg_procedure", + ErrorCode.EVALUATOR_INCORRECT_NUMBER_OF_ARGUMENTS_TO_PROCEDURE_CALL, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.EXPECTED_ARITY_MIN to 1, + Property.EXPECTED_ARITY_MAX to 1)), + // invalid first arg type + EvaluatorErrorTestCase( + "EXEC one_arg_procedure 'foo'", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.EXPECTED_ARGUMENT_TYPES to "INT", + Property.ACTUAL_ARGUMENT_TYPES to "STRING", + Property.FUNCTION_NAME to "one_arg_procedure")), + // invalid second arg type + EvaluatorErrorTestCase( + "EXEC two_arg_procedure 1, 'two'", + ErrorCode.EVALUATOR_INCORRECT_TYPE_OF_ARGUMENTS_TO_PROCEDURE_CALL, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 6L, + Property.EXPECTED_ARGUMENT_TYPES to "INT", + Property.ACTUAL_ARGUMENT_TYPES to "STRING", + Property.FUNCTION_NAME to "two_arg_procedure")) + ) + } + + @ParameterizedTest + @ArgumentsSource(ArgsProviderError::class) + fun errorTests(tc: EvaluatorErrorTestCase) = checkInputThrowingEvaluationExceptionSProc(tc, session) +} diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTest.kt b/lang/test/org/partiql/lang/syntax/SqlParserTest.kt index 39774579da..3b13338770 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTest.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTest.kt @@ -14,6 +14,8 @@ package org.partiql.lang.syntax +import com.amazon.ion.Decimal +import com.amazon.ionelement.api.ionDecimal import com.amazon.ionelement.api.ionInt import com.amazon.ionelement.api.ionString import org.junit.Test @@ -3130,4 +3132,53 @@ class SqlParserTest : SqlParserTestBase() { fromLet = let(letBinding(call("foo", listOf(id("table1"))), "A")) ) } + + //**************************************** + // EXEC clause parsing + //**************************************** + @Test + fun execNoArgs() = assertExpression( + "EXEC foo") { + exec("foo", emptyList()) + } + + @Test + fun execOneStringArg() = assertExpression( + "EXEC foo 'bar'") { + exec("foo", listOf(lit(ionString("bar")))) + } + + @Test + fun execOneIntArg() = assertExpression( + "EXEC foo 1") { + exec("foo", listOf(lit(ionInt(1)))) + } + + @Test + fun execMultipleArg() = assertExpression( + "EXEC foo 'bar0', `1d0`, 2, [3]") { + exec("foo", listOf(lit(ionString("bar0")), lit(ionDecimal(Decimal.valueOf(1))), lit(ionInt(2)), list(lit(ionInt(3))))) + } + + @Test + fun execWithMissing() = assertExpression( + "EXEC foo MISSING") { + exec("foo", listOf(missing())) + } + + @Test + fun execWithBag() = assertExpression( + "EXEC foo <<1>>") { + exec("foo", listOf(bag(lit(ionInt(1))))) + } + + @Test + fun execWithSelectQuery() = assertExpression( + "EXEC foo SELECT baz FROM bar") { + exec("foo", listOf( + select( + project = projectList(projectExpr(id("baz"))), + from = scan(id("bar")) + ))) + } } diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt index d1a0fe9862..c3bed9ea61 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt @@ -103,9 +103,10 @@ abstract class SqlParserTestBase : TestBase() { */ private fun unwrapQuery(statement: PartiqlAst.Statement) : SexpElement { return when (statement) { - is PartiqlAst.Statement.Query -> statement.expr.toIonElement() - is PartiqlAst.Statement.Dml, - is PartiqlAst.Statement.Ddl -> statement.toIonElement() + is PartiqlAst.Statement.Query -> statement.expr.toIonElement() + is PartiqlAst.Statement.Dml, + is PartiqlAst.Statement.Ddl, + is PartiqlAst.Statement.Exec -> statement.toIonElement() } }