diff --git a/src/core/parser/duckpgq_parser.cpp b/src/core/parser/duckpgq_parser.cpp index c29d0e8..245115c 100644 --- a/src/core/parser/duckpgq_parser.cpp +++ b/src/core/parser/duckpgq_parser.cpp @@ -58,66 +58,75 @@ void duckpgq_find_match_function(TableRef *table_ref, duckpgq_find_match_function(join_ref->right.get(), duckpgq_state); } else if (auto subquery_ref = dynamic_cast(table_ref)) { // Handle SubqueryRef case - duckpgq_handle_statement(subquery_ref->subquery.get(), duckpgq_state); + auto subquery = subquery_ref->subquery.get(); + duckpgq_find_select_statement(subquery, duckpgq_state); } } -ParserExtensionPlanResult -duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { - if (statement->type == StatementType::SELECT_STATEMENT) { - const auto select_statement = dynamic_cast(statement); - auto node = dynamic_cast(select_statement->node.get()); - CTENode *cte_node = nullptr; - - // Check if node is not a SelectNode - if (!node) { - // Attempt to cast to CTENode - cte_node = dynamic_cast(select_statement->node.get()); - if (cte_node) { - // Get the child node as a SelectNode if cte_node is valid - node = dynamic_cast(cte_node->child.get()); - } +ParserExtensionPlanResult duckpgq_find_select_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { + const auto select_statement = dynamic_cast(statement); + auto node = dynamic_cast(select_statement->node.get()); + CTENode *cte_node = nullptr; + + // Check if node is not a SelectNode + if (!node) { + // Attempt to cast to CTENode + cte_node = dynamic_cast(select_statement->node.get()); + if (cte_node) { + // Get the child node as a SelectNode if cte_node is valid + node = dynamic_cast(cte_node->child.get()); } + } - // Check if node is a ShowRef - if (node) { - const auto describe_node = - dynamic_cast(node->from_table.get()); - if (describe_node) { - ParserExtensionPlanResult result; - result.function = DescribePropertyGraphFunction(); - result.requires_valid_transaction = true; - result.return_type = StatementReturnType::QUERY_RESULT; - return result; - } + // Check if node is a ShowRef + if (node) { + const auto describe_node = + dynamic_cast(node->from_table.get()); + if (describe_node) { + ParserExtensionPlanResult result; + result.function = DescribePropertyGraphFunction(); + result.requires_valid_transaction = true; + result.return_type = StatementReturnType::QUERY_RESULT; + return result; } + } - // Collect CTE keys - vector cte_keys; - if (node) { - cte_keys = node->cte_map.map.Keys(); - } else if (cte_node) { - cte_keys = cte_node->cte_map.map.Keys(); + // Collect CTE keys + vector cte_keys; + if (node) { + cte_keys = node->cte_map.map.Keys(); + } else if (cte_node) { + cte_keys = cte_node->cte_map.map.Keys(); + } + for (auto &key : cte_keys) { + auto cte = node->cte_map.map.find(key); + auto cte_select_statement = + dynamic_cast(cte->second->query.get()); + if (cte_select_statement == nullptr) { + continue; // Skip non-select statements } - for (auto &key : cte_keys) { - auto cte = node->cte_map.map.find(key); - auto cte_select_statement = - dynamic_cast(cte->second->query.get()); - if (cte_select_statement == nullptr) { - continue; // Skip non-select statements - } - auto cte_node = - dynamic_cast(cte_select_statement->node.get()); - if (cte_node) { - duckpgq_find_match_function(cte_node->from_table.get(), duckpgq_state); - } + auto cte_node = + dynamic_cast(cte_select_statement->node.get()); + if (cte_node) { + duckpgq_find_match_function(cte_node->from_table.get(), duckpgq_state); } - if (node) { - duckpgq_find_match_function(node->from_table.get(), duckpgq_state); - } else { - throw Exception(ExceptionType::INTERNAL, "node is a nullptr."); + } + if (node) { + duckpgq_find_match_function(node->from_table.get(), duckpgq_state); + } else { + throw Exception(ExceptionType::INTERNAL, "node is a nullptr."); + } + return {}; +} + +ParserExtensionPlanResult +duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state) { + if (statement->type == StatementType::SELECT_STATEMENT) { + auto result = duckpgq_find_select_statement(statement, duckpgq_state); + if (result.function.bind == nullptr) { + throw Exception(ExceptionType::BINDER, "use duckpgq_bind instead"); } - throw Exception(ExceptionType::BINDER, "use duckpgq_bind instead"); + return result; } if (statement->type == StatementType::CREATE_STATEMENT) { const auto &create_statement = statement->Cast(); diff --git a/src/include/duckpgq/core/parser/duckpgq_parser.hpp b/src/include/duckpgq/core/parser/duckpgq_parser.hpp index 891bc6b..52ef65d 100644 --- a/src/include/duckpgq/core/parser/duckpgq_parser.hpp +++ b/src/include/duckpgq/core/parser/duckpgq_parser.hpp @@ -26,6 +26,9 @@ ParserExtensionPlanResult duckpgq_plan(ParserExtensionInfo *info, ClientContext &, unique_ptr); +ParserExtensionPlanResult duckpgq_find_select_statement( + SQLStatement *statement, DuckPGQState &duckpgq_state); + ParserExtensionPlanResult duckpgq_handle_statement(SQLStatement *statement, DuckPGQState &duckpgq_state);