From a04277cff7b059f32b246ed12c0f13f57b0be896 Mon Sep 17 00:00:00 2001 From: Wang Zhiyong Date: Sat, 21 Dec 2024 10:48:20 +0800 Subject: [PATCH] Merge OSPP PR (#828) * Fix test case error * update * Query plan cache for cypher queries. (#676) * plan cache codebase. * plan cache codebase. * split plan_cache.h and plan_cache_param. * integrate plan_cache into execution process. * add test cases for parameterized query execution. * fail direction * plan cache codebase * add more pattern for fastQueryParam * add more pattern for fastQueryParam * fix lint error * remove buildit deps. * fix bug in cypher visitor * remove unused dir --------- Co-authored-by: Ke Huang <569078986@qq.com> * basic columnar-based data structure (#682) * basic columnar-based data structure * remove unnecessary function in cypher _string_t * improve resizeOverflowBuffer * add FieldType --------- Co-authored-by: Shipeng Qi * Column record (#683) * basic columnar-based data structure * columnar data record structure * remove unnecessary function in cypher _string_t * improve resizeOverflowBuffer * add FieldType * support FieldType * format fix --------- Co-authored-by: yannan-wyn <129476350+yannan-wyn@users.noreply.github.com> * Tugraph supporting column based data done (#685) * basic columnar-based data structure * columnar data record structure * operators support column data * modify runtime_context * db support columnar-data done * remove unnecessary function in cypher _string_t * improve resizeOverflowBuffer * add FieldType * support FieldType * initialize ColumnVector with FieldType * format fix * delete comments, adjust indent * another formatting * modify according to reviewer's comments * add username after TODO * TODO (username) * TODO * TODO * fix potential memory leaking * delete unlikely definition and usage * make BATCH_SIZE an independent config * fix coding standard * remove whitespace * static cast FLAGS_BATCH_SIZE --------- Co-authored-by: yannan-wyn 
<129476350+yannan-wyn@users.noreply.github.com> * Add column related testing and benchmark (#686) * basic columnar-based data structure * columnar data record structure * operators support column data * modify runtime_context * db support columnar-data done * column related testing and benchmark * remove unnecessary function in cypher _string_t * improve resizeOverflowBuffer * add FieldType * support FieldType * initialize ColumnVector with FieldType * format fix * delete comments, adjust indent * another formatting * modify according to reviewer's comments * add username after TODO * TODO (username) * TODO * TODO * fix potential memory leaking * delete unlikely definition and usage * add testing for bitmask * make BATCH_SIZE an independent config * fix coding standard * remove whitespace * static cast FLAGS_BATCH_SIZE * delete wrong assert * rm reference in moveDataChunk --------- Co-authored-by: yannan-wyn <129476350+yannan-wyn@users.noreply.github.com> * Query compilation framework. (#687) * compilation execution framework. * execution framework. * remove static_var usage in data structures. * test framework. * generate code on current directory * delete files after executions. * fix bugs in test framework. * LLVM framework * LLVM backend * fix compilation error. * fix lint error. * fix lint error. * fix lint error. 
--------- Co-authored-by: Ke Huang <569078986@qq.com> * revert --------- Co-authored-by: RT_Enzyme <52275903001@stu.ecnu.edu.cn> Co-authored-by: Ke Huang <569078986@qq.com> Co-authored-by: Myrrolinz Co-authored-by: Shipeng Qi Co-authored-by: yannan-wyn <129476350+yannan-wyn@users.noreply.github.com> --- ci/images/tugraph-runtime-centos7-Dockerfile | 2 +- .../ops/op_all_node_scan_col.cpp | 15 + .../execution_plan/ops/op_all_node_scan_col.h | 146 +++++++ src/cypher/execution_plan/ops/op_config.cpp | 20 + src/cypher/execution_plan/ops/op_config.h | 21 + .../execution_plan/ops/op_limit_col.cpp | 15 + src/cypher/execution_plan/ops/op_limit_col.h | 68 ++++ .../ops/op_produce_results_col.cpp | 15 + .../ops/op_produce_results_col.h | 199 ++++++++++ .../execution_plan/ops/op_project_col.cpp | 15 + .../execution_plan/ops/op_project_col.h | 170 ++++++++ .../execution_plan/plan_cache/plan_cache.cpp | 14 + .../execution_plan/plan_cache/plan_cache.h | 108 +++++ .../plan_cache/plan_cache_param.cpp | 165 ++++++++ .../plan_cache/plan_cache_param.h | 28 ++ .../experimental/data_type/field_data.h | 214 ++++++++++ src/cypher/experimental/data_type/record.h | 92 +++++ src/cypher/experimental/expressions/cexpr.cpp | 16 + src/cypher/experimental/expressions/cexpr.h | 238 +++++++++++ .../expressions/kernal/binary.cpp | 200 ++++++++++ src/cypher/experimental/jit/TuJIT.cpp | 21 + src/cypher/experimental/jit/TuJIT.h | 95 +++++ src/cypher/resultset/bit_mask.h | 221 +++++++++++ src/cypher/resultset/column_vector.h | 325 +++++++++++++++ src/cypher/resultset/cypher_string_t.h | 66 +++ test/QueryTester.cpp | 293 ++++++++++++++ test/QueryTester.h | 100 +++++ .../vector_index/cypher/vector_index.result | 2 +- test/test_bit_mask.cpp | 160 ++++++++ test/test_column_vector.cpp | 163 ++++++++ test/test_plan_cache.cpp | 59 +++ test/test_query_benchmark.cpp | 76 ++++ test/test_query_col.cpp | 375 ++++++++++++++++++ test/test_query_compilation.cpp | 128 ++++++ toolkits/CMakeLists.txt | 14 + 
toolkits/lgraph_compilation.cpp | 72 ++++ 36 files changed, 3929 insertions(+), 2 deletions(-) create mode 100644 src/cypher/execution_plan/ops/op_all_node_scan_col.cpp create mode 100644 src/cypher/execution_plan/ops/op_all_node_scan_col.h create mode 100644 src/cypher/execution_plan/ops/op_config.cpp create mode 100644 src/cypher/execution_plan/ops/op_config.h create mode 100644 src/cypher/execution_plan/ops/op_limit_col.cpp create mode 100644 src/cypher/execution_plan/ops/op_limit_col.h create mode 100644 src/cypher/execution_plan/ops/op_produce_results_col.cpp create mode 100644 src/cypher/execution_plan/ops/op_produce_results_col.h create mode 100644 src/cypher/execution_plan/ops/op_project_col.cpp create mode 100644 src/cypher/execution_plan/ops/op_project_col.h create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache.cpp create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache.h create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache_param.cpp create mode 100644 src/cypher/execution_plan/plan_cache/plan_cache_param.h create mode 100644 src/cypher/experimental/data_type/field_data.h create mode 100644 src/cypher/experimental/data_type/record.h create mode 100644 src/cypher/experimental/expressions/cexpr.cpp create mode 100644 src/cypher/experimental/expressions/cexpr.h create mode 100644 src/cypher/experimental/expressions/kernal/binary.cpp create mode 100644 src/cypher/experimental/jit/TuJIT.cpp create mode 100644 src/cypher/experimental/jit/TuJIT.h create mode 100644 src/cypher/resultset/bit_mask.h create mode 100644 src/cypher/resultset/column_vector.h create mode 100644 src/cypher/resultset/cypher_string_t.h create mode 100644 test/QueryTester.cpp create mode 100644 test/QueryTester.h create mode 100644 test/test_bit_mask.cpp create mode 100644 test/test_column_vector.cpp create mode 100644 test/test_plan_cache.cpp create mode 100644 test/test_query_benchmark.cpp create mode 100644 test/test_query_col.cpp create mode 100644 
test/test_query_compilation.cpp create mode 100644 toolkits/lgraph_compilation.cpp diff --git a/ci/images/tugraph-runtime-centos7-Dockerfile b/ci/images/tugraph-runtime-centos7-Dockerfile index ca85ba6dc5..4541e83ada 100644 --- a/ci/images/tugraph-runtime-centos7-Dockerfile +++ b/ci/images/tugraph-runtime-centos7-Dockerfile @@ -12,7 +12,7 @@ RUN sed -e "s|^mirrorlist=|#mirrorlist=|g" \ RUN yum install -y \ libgfortran5.x86_64 \ libgomp \ - libcurl-devel.x86_64 && yum clean all + wget && yum clean all # install tugraph # specifies the path of the object storage where the installation package resides diff --git a/src/cypher/execution_plan/ops/op_all_node_scan_col.cpp b/src/cypher/execution_plan/ops/op_all_node_scan_col.cpp new file mode 100644 index 0000000000..8e79f5c1aa --- /dev/null +++ b/src/cypher/execution_plan/ops/op_all_node_scan_col.cpp @@ -0,0 +1,15 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "cypher/execution_plan/ops/op_all_node_scan_col.h" diff --git a/src/cypher/execution_plan/ops/op_all_node_scan_col.h b/src/cypher/execution_plan/ops/op_all_node_scan_col.h new file mode 100644 index 0000000000..d800902fa9 --- /dev/null +++ b/src/cypher/execution_plan/ops/op_all_node_scan_col.h @@ -0,0 +1,146 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "cypher/execution_plan/ops/op.h"
+#include "cypher/resultset/column_vector.h"
+#include "cypher/execution_plan/ops/op_config.h"
+
+namespace cypher {
+
+/**
+ * Batch ("columnar") variant of the all-node scan operator.
+ *
+ * Each RealConsume() call replaces `columnar_` with a new chunk and fills it with the
+ * fields of up to FLAGS_BATCH_SIZE vertices read from the node's vertex iterator,
+ * one column per property name.
+ */
+class AllNodeScanCol : public OpBase {
+    /* NOTE: Nodes in pattern graph are stored in std::vector, whose reference
+     * will become INVALID after reallocation.
+     * TODO(anyone) Make sure not add nodes to the pattern graph, otherwise use NodeId instead. */
+    friend class LocateNodeByVid;
+    friend class LocateNodeByVidV2;
+    friend class LocateNodeByIndexedProp;
+    friend class LocateNodeByIndexedPropV2;
+
+    Node *node_ = nullptr;
+    lgraph::VIter *it_ = nullptr;  // also can be derived from node
+    std::string alias_;  // also can be derived from node
+    std::string label_;  // also can be derived from node
+    // NOTE(review): the two ints below have no default initializer; they are only assigned
+    // in the constructor body.
+    int node_rec_idx_;  // index of node in record
+    int rec_length_;  // number of entries in a record.
+    const SymbolTable *sym_tab_ = nullptr;  // build time context
+    // false until the first vertex has been read; afterwards the iterator is advanced
+    // with Next() before every read (see RealConsume).
+    bool consuming_ = false;  // whether begin consuming
+
+ public:
+    /**
+     * @param node     pattern-graph node to scan; its iterator and alias are cached and the
+     *                 alias is added to `modifies`.
+     * @param sym_tab  symbol table in which the alias is looked up to find the record slot.
+     *
+     * The CYPHER_THROW_ASSERT below requires a non-null `node` whose alias is present in the
+     * symbol table.
+     * NOTE(review): `sym_tab` is dereferenced before that assertion, so it must be non-null.
+     */
+    AllNodeScanCol(Node *node, const SymbolTable *sym_tab)
+        : OpBase(OpType::ALL_NODE_SCAN, "All Node Scan"), node_(node), sym_tab_(sym_tab) {
+        if (node) {
+            it_ = node->ItRef();
+            alias_ = node->Alias();
+            modifies.emplace_back(alias_);
+        }
+        auto it = sym_tab->symbols.find(alias_);
+        CYPHER_THROW_ASSERT(node && it != sym_tab->symbols.end());
+        if (it != sym_tab->symbols.end()) node_rec_idx_ = it->second.id;
+        rec_length_ = sym_tab->symbols.size();
+        consuming_ = false;
+    }
+
+    /**
+     * Allocates the output record, marks this node's slot as Entry::NODE pointing at `node_`,
+     * and initializes the node's iterator as a VERTEX_ITER on the context's transaction.
+     * Always returns OP_OK.
+     */
+    OpResult Initialize(RTContext *ctx) override {
+        // allocate a new record
+        record = std::make_shared(rec_length_, sym_tab_, ctx->param_tab_);
+        record->values[node_rec_idx_].type = Entry::NODE;
+        record->values[node_rec_idx_].node = node_;
+        // transaction allocated before in plan:execute
+        // TODO(anyone) remove patternGraph's state (ctx)
+        node_->ItRef()->Initialize(ctx->txn_->GetTxn().get(), lgraph::VIter::VERTEX_ITER);
+        return OP_OK;
+    }
+
+    /**
+     * Produces the next batch into a freshly allocated `columnar_`.
+     *
+     * For every visited vertex, each (name, value) pair returned by GetFields() is written
+     * to the column registered under that name: STRING values into `string_columns_`, all
+     * other types into `columnar_data_`. A column is created the first time its property
+     * name is seen (constructed from the element size, FLAGS_BATCH_SIZE and the field type),
+     * and the vertex id is appended to `property_vids_` for that property.
+     *
+     * @return OP_DEPLETED when the iterator is missing or invalid before any vertex was read
+     *         in this call; OP_OK otherwise, including a partially filled last batch.
+     */
+    OpResult RealConsume(RTContext *ctx) override {
+        uint32_t count = 0;
+        columnar_ = std::make_shared();
+        while (count < FLAGS_BATCH_SIZE) {
+            node_->SetVid(-1);
+            if (!it_ || !it_->IsValid()) return (count > 0) ? OP_OK : OP_DEPLETED;
+            if (!consuming_) {
+                // Very first vertex of the scan: read the iterator where it stands.
+                consuming_ = true;
+            } else {
+                it_->Next();
+                if (!it_->IsValid()) {
+                    return (count > 0) ? OP_OK : OP_DEPLETED;
+                }
+            }
+            int64_t vid = it_->GetId();
+            for (auto& property : node_->ItRef()->GetFields()) {
+                const std::string& property_name = property.first;
+                const lgraph_api::FieldData& field = property.second;
+                if (field.type == lgraph_api::FieldType::STRING) {
+                    if (columnar_->string_columns_.find(property_name) ==
+                        columnar_->string_columns_.end()) {
+                        columnar_->string_columns_[property_name] =
+                            std::make_unique(sizeof(cypher_string_t),
+                                             FLAGS_BATCH_SIZE, field.type);
+                        columnar_->property_positions_[property_name] = 0;
+                    }
+                    columnar_->property_vids_[property_name].push_back(vid);
+                    // Post-increment: `pos` is the slot written now, the stored value is
+                    // the next free slot of this property's column.
+                    uint32_t pos = columnar_->property_positions_[property_name]++;
+                    StringColumn::AddString(
+                        columnar_->string_columns_[property_name].get(), pos,
+                        field.AsString().c_str(), field.AsString().size());
+                } else {
+                    if (columnar_->columnar_data_.find(property_name) ==
+                        columnar_->columnar_data_.end()) {
+                        size_t element_size = ColumnVector::GetFieldSize(field.type);
+                        columnar_->columnar_data_[property_name] =
+                            std::make_unique(element_size,
+                                             FLAGS_BATCH_SIZE, field.type);
+                        columnar_->property_positions_[property_name] = 0;
+                    }
+                    columnar_->property_vids_[property_name].push_back(vid);
+                    uint32_t pos = columnar_->property_positions_[property_name]++;
+                    ColumnVector::InsertIntoColumnVector(
+                        columnar_->columnar_data_[property_name].get(), field, pos);
+                }
+            }
+
+            count++;
+        }
+        return OP_OK;
+    }
+
+    /**
+     * Restarts the scan from the first vertex on the next RealConsume().
+     * A complete reset also drops `record` and calls FreeIter() on an initialized iterator;
+     * a partial reset only calls Reset() on it.
+     */
+    OpResult ResetImpl(bool complete) override {
+        consuming_ = false;
+        if (complete) {
+            // undo method initialize()
+            record = nullptr;
+            // TODO(anyone) cleaned in ExecutionPlan::Execute
+            if (it_ && it_->Initialized()) it_->FreeIter();
+        } else {
+            if (it_ && it_->Initialized()) it_->Reset();
+        }
+        return OP_OK;
+    }
+
+    /** Operator name followed by the scanned alias in brackets, e.g. "All Node Scan [n]". */
+    std::string ToString() const override {
+        std::string str(name);
+        str.append(" [").append(alias_).append("]");
+        return str;
+    }
+
+    Node *GetNode() const { return node_; }
+
+    const SymbolTable *SymTab() const { return sym_tab_; }
+
+    CYPHER_DEFINE_VISITABLE()
+
+    CYPHER_DEFINE_CONST_VISITABLE()
+};
+}  // namespace cypher
diff --git a/src/cypher/execution_plan/ops/op_config.cpp b/src/cypher/execution_plan/ops/op_config.cpp
new file mode 100644
index 0000000000..bca34fcd5d
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_config.cpp
@@ -0,0 +1,20 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+
+#include "cypher/execution_plan/ops/op_config.h"
+
+namespace cypher {
+DEFINE_int64(BATCH_SIZE, 32, "The batch size for processing");
+}
diff --git a/src/cypher/execution_plan/ops/op_config.h b/src/cypher/execution_plan/ops/op_config.h
new file mode 100644
index 0000000000..36791f9798
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_config.h
@@ -0,0 +1,21 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+
+#pragma once
+#include
+
+namespace cypher {
+DECLARE_int64(BATCH_SIZE);
+}
diff --git a/src/cypher/execution_plan/ops/op_limit_col.cpp b/src/cypher/execution_plan/ops/op_limit_col.cpp
new file mode 100644
index 0000000000..8b3ebb4bfe
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_limit_col.cpp
@@ -0,0 +1,15 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "cypher/execution_plan/ops/op_limit_col.h"
diff --git a/src/cypher/execution_plan/ops/op_limit_col.h b/src/cypher/execution_plan/ops/op_limit_col.h
new file mode 100644
index 0000000000..b76d13aa08
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_limit_col.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "cypher/execution_plan/ops/op.h"
+#include "cypher/execution_plan/ops/op_config.h"
+
+namespace cypher {
+
+/**
+ * Columnar LIMIT operator.
+ *
+ * Forwards the batches of its first child through `columnar_`, truncating each one so that
+ * the running total `consumed_` never exceeds `limit_`.
+ *
+ * NOTE(review): the running total advances by min(FLAGS_BATCH_SIZE, remaining) per batch,
+ * not by the number of rows actually present in the child's chunk - confirm that a child
+ * can only return a short batch as its last one.
+ */
+class LimitCol : public OpBase {
+    friend class LazyProjectTopN;
+    size_t limit_ = 0;  // Max number of records to consume.
+    size_t consumed_ = 0;  // Number of records consumed so far.
+
+ public:
+    explicit LimitCol(size_t limit) : OpBase(OpType::LIMIT, "Limit"), limit_(limit) {}
+
+    /**
+     * Initializes the child, then allocates an empty output chunk and shares the child's
+     * record. A non-OK result from the child is returned unchanged.
+     */
+    OpResult Initialize(RTContext *ctx) override {
+        CYPHER_THROW_ASSERT(!children.empty());
+        auto &child = children[0];
+        auto res = child->Initialize(ctx);
+        if (res != OP_OK) return res;
+        columnar_ = std::make_shared();
+        record = child->record;
+        return OP_OK;
+    }
+
+    /**
+     * Pulls one batch from the child and exposes at most the remaining number of rows.
+     *
+     * @return OP_DEPLETED once `limit_` rows have been accounted for; otherwise the child's
+     *         result. On a non-OK child result nothing is truncated or counted.
+     */
+    OpResult RealConsume(RTContext *ctx) override {
+        if (consumed_ >= limit_) return OP_DEPLETED;
+        CYPHER_THROW_ASSERT(!children.empty());
+        auto &child = children[0];
+        auto res = child->Consume(ctx);
+        // A child that did not return OP_OK has no batch to forward. Bail out before
+        // touching `child->columnar_` (which may be stale or unset) and before advancing
+        // the row accounting.
+        if (res != OP_OK) return res;
+        columnar_ = child->columnar_;
+        // Rows still allowed through, capped at one batch. The minimum is taken in size_t
+        // and narrowed explicitly; it never exceeds FLAGS_BATCH_SIZE.
+        const size_t remaining = limit_ - consumed_;
+        const int usable_r =
+            static_cast<int>(std::min(static_cast<size_t>(FLAGS_BATCH_SIZE), remaining));
+        columnar_->TruncateData(usable_r);
+        consumed_ += static_cast<size_t>(usable_r);
+        return res;
+    }
+
+    /** Restarts the row accounting; `complete` makes no difference here. */
+    OpResult ResetImpl(bool complete) override {
+        consumed_ = 0;
+        return OP_OK;
+    }
+
+    /** Operator name followed by the limit in brackets, e.g. "Limit [10]". */
+    std::string ToString() const override {
+        std::string str(name);
+        str.append(" [").append(std::to_string(limit_)).append("]");
+        return str;
+    }
+
+    CYPHER_DEFINE_VISITABLE()
+
+    CYPHER_DEFINE_CONST_VISITABLE()
+};
+}  // namespace cypher
diff --git a/src/cypher/execution_plan/ops/op_produce_results_col.cpp b/src/cypher/execution_plan/ops/op_produce_results_col.cpp
new file mode 100644
index 0000000000..2e7676a502
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_produce_results_col.cpp
@@ -0,0 +1,15 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "cypher/execution_plan/ops/op_produce_results_col.h"
diff --git a/src/cypher/execution_plan/ops/op_produce_results_col.h b/src/cypher/execution_plan/ops/op_produce_results_col.h
new file mode 100644
index 0000000000..15c840855a
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_produce_results_col.h
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include
+#include "cypher/execution_plan/ops/op.h"
+#include "lgraph/lgraph_result.h"
+#include "lgraph/lgraph_types.h"
+#include "lgraph_api/result_element.h"
+#include "resultset/record.h"
+#include "server/json_convert.h"
+#include "server/bolt_session.h"
+#include "boost/regex.hpp"
+#include "cypher/execution_plan/ops/op_produce_results.h"
+
+namespace cypher {
+
+/**
+ * Result-producing operator for columnar plans.
+ *
+ * Without a bolt connection, each RealConsume() pulls one batch from the first child,
+ * appends it to `final_r` and passes `final_r` to ctx->moveDataChunk().
+ * With a bolt connection, it first waits for a PullN/DiscardN message from the session and
+ * then streams or discards one child result per call.
+ */
+class ProduceResultsCol : public OpBase {
+    // Only Uninitialized and Consuming are assigned anywhere in this class.
+    enum {
+        Uninitialized,
+        RefreshAfterPass,
+        Resetted,
+        Consuming,
+    } state_;
+    lgraph_api::NODEMAP node_map_;
+    lgraph_api::RELPMAP relp_map_;
+    // Accumulates every batch consumed on the non-bolt path.
+    std::shared_ptr final_r;
+
+ public:
+    ProduceResultsCol() : OpBase(OpType::PRODUCE_RESULTS, "Produce Results") {
+        state_ = Uninitialized;
+    }
+
+    /**
+     * Initializes the first child (if any) and allocates `columnar_` and `final_r`.
+     * Always returns OP_OK.
+     * NOTE(review): the child's Initialize() result is discarded, so a failing child is not
+     * reported from here - confirm this is intended.
+     */
+    OpResult Initialize(RTContext *ctx) override {
+        if (!children.empty()) {
+            children[0]->Initialize(ctx);
+        }
+        columnar_ = std::make_shared();
+        final_r = std::make_shared();
+        return OP_OK;
+    }
+
+    /* ProduceResults next operation
+     * called each time a new result record is required
+     *
+     * Returns OP_DEPLETED when there is no child, OP_ERR when the bolt session is closed,
+     * reset, interrupted or in an unexpected state, the child's result when the child does
+     * not return OP_OK, and OP_OK otherwise. */
+    OpResult RealConsume(RTContext *ctx) override {
+        if (state_ == Uninitialized) {
+            // Lazy initialization on the first call after construction or a complete reset.
+            Initialize(ctx);
+            state_ = Consuming;
+        }
+        if (children.empty()) return OP_DEPLETED;
+        if (ctx->bolt_conn_) {
+            if (ctx->bolt_conn_->has_closed()) {
+                LOG_INFO() << "The bolt connection is closed, cancel the op execution.";
+                return OP_ERR;
+            }
+            auto session = (bolt::BoltSession *)ctx->bolt_conn_->GetContext();
+            // Wait (polling in 100 ms steps) until the client sends the next streaming
+            // message, giving up if the connection closes in the meantime.
+            while (session->state == bolt::SessionState::STREAMING && !session->streaming_msg) {
+                session->streaming_msg = session->msgs.Pop(std::chrono::milliseconds(100));
+                if (ctx->bolt_conn_->has_closed()) {
+                    LOG_INFO() << "The bolt connection is closed, cancel the op execution.";
+                    return OP_ERR;
+                }
+                if (!session->streaming_msg) {
+                    continue;
+                }
+                if (session->streaming_msg.value().type == bolt::BoltMsg::PullN ||
+                    session->streaming_msg.value().type == bolt::BoltMsg::DiscardN) {
+                    // PullN/DiscardN must carry exactly one field; its "n" entry becomes
+                    // the number of results to stream/discard for this message.
+                    const auto &fields = session->streaming_msg.value().fields;
+                    if (fields.size() != 1) {
+                        std::string err =
+                            FMA_FMT("{} msg fields size error, size: {}",
+                                    bolt::ToString(session->streaming_msg.value().type).c_str(),
+                                    fields.size());
+                        LOG_ERROR() << err;
+                        bolt::PackStream ps;
+                        ps.AppendFailure({{"code", "error"}, {"message", err}});
+                        ctx->bolt_conn_->PostResponse(std::move(ps.MutableBuffer()));
+                        session->state = bolt::SessionState::FAILED;
+                        return OP_ERR;
+                    }
+                    auto &val =
+                        std::any_cast &>(fields[0]);
+                    auto n = std::any_cast(val.at("n"));
+                    session->streaming_msg.value().n = n;
+                } else if (session->streaming_msg.value().type == bolt::BoltMsg::Reset) {
+                    // RESET: acknowledge, put the session back to READY and stop executing.
+                    LOG_INFO() << "Receive RESET, cancel the op execution.";
+                    bolt::PackStream ps;
+                    ps.AppendSuccess();
+                    ctx->bolt_conn_->PostResponse(std::move(ps.MutableBuffer()));
+                    session->state = bolt::SessionState::READY;
+                    return OP_ERR;
+                } else {
+                    LOG_ERROR() << FMA_FMT(
+                        "Unexpected msg:{} in STREAMING state, cancel the op execution, "
+                        "close the connection.",
+                        bolt::ToString(session->streaming_msg.value().type));
+                    ctx->bolt_conn_->Close();
+                    return OP_ERR;
+                }
+                break;
+            }
+            if (session->state == bolt::SessionState::INTERRUPTED) {
+                LOG_WARN() << "The session state is INTERRUPTED, cancel the op execution.";
+                return OP_ERR;
+            } else if (session->state != bolt::SessionState::STREAMING) {
+                // NOTE(review): this literal is streamed as-is (no FMA_FMT), so the "{}"
+                // placeholder is logged verbatim and the actual state is never printed.
+                LOG_ERROR() << "Unexpected state: {} in op execution, close the connection.";
+                ctx->bolt_conn_->Close();
+                return OP_ERR;
+            } else if (session->streaming_msg.value().type != bolt::BoltMsg::PullN &&
+                       session->streaming_msg.value().type != bolt::BoltMsg::DiscardN) {
+                LOG_ERROR() << FMA_FMT("Unexpected msg: {} in op execution, "
+                                       "cancel the op execution, close the connection.",
+                                       bolt::ToString(session->streaming_msg.value().type));
+                ctx->bolt_conn_->Close();
+                return OP_ERR;
+            }
+            auto child = children[0];
+            auto res = child->Consume(ctx);
+            if (res != OP_OK) {
+                // Child finished or failed: flush records still buffered for a PullN, send
+                // the final SUCCESS and return the session to READY.
+                if (ctx->result_->Size() > 0 &&
+                    session->streaming_msg.value().type == bolt::BoltMsg::PullN) {
+                    session->ps.AppendRecords(ctx->result_->BoltRecords());
+                }
+                session->ps.AppendSuccess();
+                session->state = bolt::SessionState::READY;
+                ctx->bolt_conn_->PostResponse(std::move(session->ps.MutableBuffer()));
+                session->ps.Reset();
+                return res;
+            }
+            if (session->streaming_msg.value().type == bolt::BoltMsg::PullN) {
+                // NOTE(review): this path converts `child->record` (the row record), not
+                // `child->columnar_` - looks carried over from the row-based operator;
+                // confirm the bolt path is meant to work with columnar children.
+                auto record = ctx->result_->MutableRecord();
+                RRecordToURecord(ctx->txn_.get(), ctx->result_->Header(), child->record,
+                                 *record, node_map_, relp_map_);
+                session->ps.AppendRecords(ctx->result_->BoltRecords());
+                ctx->result_->ClearRecords();
+                bool sync = false;
+                if (--session->streaming_msg.value().n == 0) {
+                    // Requested count reached: report has_more and wait for the next message.
+                    std::unordered_map meta;
+                    meta["has_more"] = true;
+                    session->ps.AppendSuccess(meta);
+                    session->state = bolt::SessionState::STREAMING;
+                    session->streaming_msg.reset();
+                    sync = true;
+                }
+                // Send when the request is complete or the pending buffer exceeds 1024.
+                if (sync || session->ps.ConstBuffer().size() > 1024) {
+                    ctx->bolt_conn_->PostResponse(std::move(session->ps.MutableBuffer()));
+                    session->ps.Reset();
+                }
+            } else if (session->streaming_msg.value().type == bolt::BoltMsg::DiscardN) {
+                if (--session->streaming_msg.value().n == 0) {
+                    std::unordered_map meta;
+                    meta["has_more"] = true;
+                    session->ps.AppendSuccess(meta);
+                    session->state = bolt::SessionState::STREAMING;
+                    session->streaming_msg.reset();
+                    ctx->bolt_conn_->PostResponse(std::move(session->ps.MutableBuffer()));
+                    session->ps.Reset();
+                }
+            }
+            return OP_OK;
+        } else {
+            // Non-bolt path: accumulate the child's batch and publish the running total.
+            auto child = children[0];
+            auto res = child->Consume(ctx);
+            if (res != OP_OK) return res;
+            columnar_ = child->columnar_;
+            final_r->Append(*columnar_);
+            ctx->moveDataChunk(final_r);
+            return OP_OK;
+        }
+    }
+
+    /* Restart: a complete reset makes the next RealConsume() call Initialize() again. */
+    OpResult ResetImpl(bool complete) override {
+        if (complete) state_ = Uninitialized;
+        return OP_OK;
+    }
+
+    /** Returns the operator name. */
+    std::string ToString() const override {
+        std::string str(name);
+        return str;
+    }
+
+    CYPHER_DEFINE_VISITABLE()
+
+    CYPHER_DEFINE_CONST_VISITABLE()
+};
+}  // namespace cypher
diff --git a/src/cypher/execution_plan/ops/op_project_col.cpp b/src/cypher/execution_plan/ops/op_project_col.cpp
new file mode 100644
index 0000000000..d6f5b935a3
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_project_col.cpp
@@ -0,0 +1,15 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "cypher/execution_plan/ops/op_project_col.h"
diff --git a/src/cypher/execution_plan/ops/op_project_col.h b/src/cypher/execution_plan/ops/op_project_col.h
new file mode 100644
index 0000000000..557572b199
--- /dev/null
+++ b/src/cypher/execution_plan/ops/op_project_col.h
@@ -0,0 +1,170 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "parser/clause.h"
+#include "arithmetic/arithmetic_expression.h"
+#include "cypher/execution_plan/ops/op.h"
+
+namespace cypher {
+
+/**
+ * Columnar PROJECT operator.
+ *
+ * Keeps the return expressions and their aliases, and on every RealConsume() narrows the
+ * child's chunk to the columns named in `aliases` (the property names extracted from the
+ * return aliases by ExtractAlias()).
+ */
+class ProjectCol : public OpBase {
+    friend class LazyProjectTopN;
+    const SymbolTable &sym_tab_;
+    std::vector return_elements_;
+    std::vector return_alias_;
+    // Property names to project; rebuilt by ExtractAlias() on every Initialize().
+    std::vector aliases;
+    // Set after the single response of a child-less projection has been produced.
+    bool single_response_;
+    enum {
+        Uninitialized,
+        RefreshAfterPass,
+        Resetted,
+        Consuming,
+    } state_;
+
+    /* Construct arithmetic expressions from return clause.
+     * Uses the RETURN clause when present, otherwise the WITH clause. Throws
+     * lgraph::CypherException when two items resolve to the same alias. */
+    void _BuildArithmeticExpressions(const parser::QueryPart *stmt) {
+        const auto &return_body = stmt->return_clause ? std::get<1>(*stmt->return_clause)
+                                                      : std::get<1>(*stmt->with_clause);
+        const auto &return_items = std::get<0>(return_body);
+        std::unordered_set distinct_alias;
+        for (auto &item : return_items) {
+            auto &expr = std::get<0>(item);
+            auto &var = std::get<1>(item);
+            ArithExprNode ae(expr, sym_tab_);
+            return_elements_.emplace_back(ae);
+            auto alias = var.empty() ? expr.ToString(false) : var;
+            if (distinct_alias.find(alias) != distinct_alias.end()) {
+                throw lgraph::CypherException("Duplicate alias: " + alias);
+            }
+            distinct_alias.emplace(alias);
+            return_alias_.emplace_back(alias);
+            if (!var.empty()) modifies.emplace_back(var);
+        }
+    }
+
+    /* Rebuild `aliases` from `return_alias_`: for every alias containing "n.", keep the text
+     * after its first occurrence (e.g. "n.name" -> "name").
+     * NOTE(review): the match is on the first "n." anywhere in the alias, not only on a
+     * leading variable named `n`. */
+    void ExtractAlias() {
+        // TODO(Myrrolinz): refactor this function to use ArithExprNode
+        // Initialize() runs again after a complete reset; start from an empty list so a
+        // column is not projected and merged more than once.
+        aliases.clear();
+        const std::string prefix = "n.";
+        for (const auto& i : return_alias_) {
+            std::size_t pos = i.find(prefix);
+            if (pos != std::string::npos) {
+                aliases.push_back(i.substr(pos + prefix.size()));
+            }
+        }
+    }
+
+ public:
+    /** Builds the projection from the RETURN (or WITH) clause of `stmt`. */
+    ProjectCol(const parser::QueryPart *stmt, const SymbolTable *sym_tab)
+        : OpBase(OpType::PROJECT, "Project"), sym_tab_(*sym_tab) {
+        single_response_ = false;
+        state_ = Uninitialized;
+        _BuildArithmeticExpressions(stmt);
+    }
+
+    /** Builds the projection from ready-made (expression, alias) items; an empty alias
+     *  falls back to the expression's string form. */
+    ProjectCol(const std::vector> &items,
+               const SymbolTable *sym_tab)
+        : OpBase(OpType::PROJECT, "Project"), sym_tab_(*sym_tab) {
+        single_response_ = false;
+        state_ = Uninitialized;
+        for (const auto &[expr, var] : items) {
+            return_elements_.emplace_back(expr);
+            return_alias_.emplace_back(var.empty() ? expr.ToString() : var);
+            if (!var.empty()) modifies.emplace_back(var);
+        }
+    }
+
+    /**
+     * Initializes the first child (returning its result if it is not OP_OK), allocates the
+     * output chunk and record, and extracts the property names to project.
+     */
+    OpResult Initialize(RTContext *ctx) override {
+        if (!children.empty()) {
+            auto &child = children[0];
+            auto res = child->Initialize(ctx);
+            if (res != OP_OK) return res;
+        }
+        /* projection */
+        columnar_ = std::make_shared();
+        record = std::make_shared(return_elements_.size());
+        ExtractAlias();
+        return OP_OK;
+    }
+
+
+    /**
+     * Consumes one batch from the child and, when `aliases` is non-empty, replaces
+     * `columnar_` with a chunk holding only those columns. Without a child it reports
+     * OP_OK once and OP_DEPLETED afterwards.
+     */
+    OpResult RealConsume(RTContext *ctx) override {
+        OpResult res = OP_OK;
+        std::shared_ptr r;
+
+        if (!children.empty()) {
+            CYPHER_THROW_ASSERT(children.size() == 1);
+            res = children[0]->Consume(ctx);
+            r = children[0]->record;
+            columnar_ = children[0]->columnar_;
+        } else {
+            // TODO(Myrrolinz): QUERY: RETURN 1+2
+            // Return a single record followed by NULL
+            // on the second call.
+            if (single_response_) return OP_DEPLETED;
+            single_response_ = true;
+            r = std::make_shared(sym_tab_.symbols.size());
+        }
+        if (res != OP_OK) return res;
+        if (!aliases.empty()) {
+            std::shared_ptr column_r = std::make_shared();
+            for (auto &i : aliases) {
+                auto v = Evaluate(ctx, *columnar_, i);
+                column_r->MergeColumn(*v);
+            }
+            columnar_ = column_r;
+        }
+        return OP_OK;
+    }
+
+    /** A complete reset drops the record and re-arms the child-less single response. */
+    OpResult ResetImpl(bool complete) override {
+        if (complete) {
+            record = nullptr;
+            single_response_ = false;
+            state_ = Uninitialized;
+        }
+        return OP_OK;
+    }
+
+    /** Operator name followed by the comma-separated return aliases in brackets. */
+    std::string ToString() const override {
+        std::string str(name);
+        str.append(" [");
+        for (auto &i : return_alias_) {
+            str.append(i).append(",");
+        }
+        if (!return_alias_.empty()) str.pop_back();
+        str.append("]");
+        return str;
+    }
+
+    /** Returns a new chunk into which column `alias` of `columnar_record` was copied. */
+    std::shared_ptr Evaluate(RTContext *ctx,
+                             const DataChunk &columnar_record,
+                             const std::string &alias) {
+        std::shared_ptr projected_columnar_ = std::make_shared();
+        projected_columnar_->CopyColumn(alias, columnar_record);
+        return projected_columnar_;
+    }
+
+    const std::vector &ReturnElements() const { return return_elements_; }
+
+    const std::vector &ReturnAlias() const { return return_alias_; }
+
+    CYPHER_DEFINE_VISITABLE()
+
+    CYPHER_DEFINE_CONST_VISITABLE()
+};
+}  // namespace cypher
diff --git a/src/cypher/execution_plan/plan_cache/plan_cache.cpp b/src/cypher/execution_plan/plan_cache/plan_cache.cpp
new file mode 100644
index 0000000000..96bcb31e18
--- /dev/null
+++ b/src/cypher/execution_plan/plan_cache/plan_cache.cpp
@@ -0,0 +1,14 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "cypher/execution_plan/plan_cache/plan_cache.h"
diff --git a/src/cypher/execution_plan/plan_cache/plan_cache.h b/src/cypher/execution_plan/plan_cache/plan_cache.h
new file mode 100644
index 0000000000..4e50a99bec
--- /dev/null
+++ b/src/cypher/execution_plan/plan_cache/plan_cache.h
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +#pragma once + +#include +#include + +#include "parser/clause.h" +#include "parser/data_typedef.h" + +namespace cypher { +class ASTCacheObj { + public: + std::vector stmts; + parser::CmdType cmd; + + ASTCacheObj() {} + + ASTCacheObj(const std::vector &stmts, parser::CmdType cmd) + : stmts(stmts), cmd(cmd) { + } + + std::vector Stmt() { + return stmts; + } + + parser::CmdType CmdType() { + return cmd; + } +}; + +template +class PlanCacheEntry { + public: + std::string key; + T value; + + PlanCacheEntry(const std::string &key, const T &value) : key(key), value(value) {} +}; + +template +class LRUPlanCache { + typedef PlanCacheEntry Entry; + std::list _item_list; + std::unordered_map _item_map; + size_t _max_size; + mutable std::shared_mutex _mutex; + inline void _KickOut() { + while (_item_map.size() > _max_size) { + auto last_it = _item_list.end(); + last_it--; + _item_map.erase(last_it->key); + _item_list.pop_back(); + } + } + + public: + explicit LRUPlanCache(size_t max_size) : _max_size(max_size) {} + + LRUPlanCache() : _max_size(512) {} + + void add_plan(std::string param_query, const Value &val) { + std::unique_lock lock(_mutex); + auto it = _item_map.find(param_query); + if (it == _item_map.end()) { + _item_list.emplace_front(std::move(param_query), val); + _item_map.emplace(_item_list.begin()->key, _item_list.begin()); + _KickOut(); + } else { + // Overwrite the cached value if the query is already present in the cache. + // And move the entry to the front of the list. + it->second->value = val; + _item_list.splice(_item_list.begin(), _item_list, it->second); + } + } + + // Get the cached value for the given parameterized query. Before calling this function, + // you MUST parameterize the query using the fastQueryParam(). 
+ bool get_plan(const std::string &param_query, Value &val) { + // parameterized raw query + std::unique_lock lock(_mutex); /* exclusive: splice() below mutates the LRU list */ + auto it = _item_map.find(param_query); + if (it == _item_map.end()) { + return false; + } + _item_list.splice(_item_list.begin(), _item_list, it->second); + val = it->second->value; + return true; + } + + size_t max_size() const { return _max_size; } + + size_t current_size() const { std::shared_lock lock(_mutex); return _item_map.size(); } +}; + +typedef LRUPlanCache ASTCache; +} // namespace cypher diff --git a/src/cypher/execution_plan/plan_cache/plan_cache_param.cpp b/src/cypher/execution_plan/plan_cache/plan_cache_param.cpp new file mode 100644 index 0000000000..0bd98220a6 --- /dev/null +++ b/src/cypher/execution_plan/plan_cache/plan_cache_param.cpp @@ -0,0 +1,165 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "cypher/execution_plan/plan_cache/plan_cache_param.h" + +namespace cypher { +std::string fastQueryParam(RTContext *ctx, const std::string query) { + /** + * We don't parameterize the queries or literals in query if: + * 1. The query is a CALL statement. + * 2. limit/skip `n`. + * 3. Range literals: ()->[e*..3]->(m) + * 4. the items in return body: return RETURN a,-2,9.78,'im a string' (@todo) + * 5. match ... 
create: MATCH (c {name:$0}) CREATE (p:Person {name:$1, birthyear:$2})-[r:BORN_IN]->(c) RETURN p,r,c + */ + antlr4::ANTLRInputStream input(query); + parser::LcypherLexer lexer(&input); + antlr4::CommonTokenStream token_stream(&lexer); + token_stream.fill(); + + std::vector tokens = token_stream.getTokens(); + size_t delete_size = 0; + std::string param_query = query; + + bool prev_limit_skip = false; + bool in_return_body = false; + bool prev_double_dots = false; // e*..3 + bool in_rel = false; // -[n]-> + int param_num = 0; + if (tokens[0]->getType() == parser::LcypherParser::CALL) { + // Don't parameterize plugin CALL statements + return query; + } + for (size_t i = 0; i < tokens.size(); i++) { + parser::Expression expr; + bool is_param = false; /* only the literal cases below set it */ + switch (tokens[i]->getType()) { + case parser::LcypherParser::CREATE: { + // We don't parameterize the Create statements + // Remove the parsed parameters. + for (auto it = ctx->param_tab_.begin(); it!= ctx->param_tab_.end(); ) { + if (it->first[0] == '$' && std::isdigit(it->first[1])) { + it = ctx->param_tab_.erase(it); + } else { + ++it; + } + } + return query; + } + case parser::LcypherParser::T__13: { // '-' + size_t j = i; + while (++j < tokens.size() && tokens[j]->getType() == parser::LcypherParser::SP) { + } + if (j < tokens.size() && tokens[j]->getType() == parser::LcypherParser::T__7) { + in_rel = true; + } + i = j; + break; + } + case parser::LcypherParser::T__8: { // ']' + in_rel = false; + break; + } + case parser::LcypherParser::StringLiteral: { + // String literal + auto str = tokens[i]->getText(); + std::string res; + // remove escape character + for (size_t i = 1; i < str.length() - 1; i++) { + if (str[i] == '\\') { + i++; + } + res.push_back(str[i]); + } + expr.type = parser::Expression::STRING; + expr.data = std::make_shared(std::move(res)); + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::HexInteger: + case 
parser::LcypherParser::DecimalInteger: + case parser::LcypherParser::OctalInteger: { + if (in_rel) { + // The integer literals in relationships are range literals. + // -[:HAS_CHILD*1..]-> + break; + } + if (prev_limit_skip || prev_double_dots) { + break; + } + // Integer literal + expr.type = parser::Expression::DataType::INT; + expr.data = std::stol(tokens[i]->getText()); + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::ExponentDecimalReal: + case parser::LcypherParser::RegularDecimalReal: { + // Double literal + expr.type = parser::Expression::DataType::DOUBLE; + expr.data = std::stod(tokens[i]->getText()); + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::TRUE_: { + expr.type = parser::Expression::DataType::BOOL; + expr.data = true; + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::FALSE_: { + expr.type = parser::Expression::DataType::BOOL; + expr.data = false; + ctx->param_tab_.emplace("$" + std::to_string(param_num), MakeFieldData(expr)); + is_param = true; + break; + } + case parser::LcypherParser::RETURN: { + in_return_body = true; + break; + } + default: + break; + } + + // Replace the token with placeholder + if (is_param) { + if (!in_return_body) { + size_t start_index = tokens[i]->getStartIndex() - delete_size; + size_t end_index = tokens[i]->getStopIndex() - delete_size; + // Indicate the position in raw parameterized query + std::string count = "$" + std::to_string(param_num); + param_query.replace(start_index, end_index - start_index + 1, count); + delete_size += (end_index - start_index + 1) - count.size(); + param_num++; + } + is_param = false; + } + if (tokens[i]->getType() == parser::LcypherParser::LIMIT || + tokens[i]->getType() == parser::LcypherParser::L_SKIP) { + 
prev_limit_skip = true; + } else if (tokens[i]->getType() == parser::LcypherParser::T__11) { + prev_double_dots = true; + } else if (tokens[i]->getType() < parser::LcypherParser::SP) { + prev_limit_skip = false; + prev_double_dots = false; + } + } + return param_query; +} +} // namespace cypher diff --git a/src/cypher/execution_plan/plan_cache/plan_cache_param.h b/src/cypher/execution_plan/plan_cache/plan_cache_param.h new file mode 100644 index 0000000000..83410f7383 --- /dev/null +++ b/src/cypher/execution_plan/plan_cache/plan_cache_param.h @@ -0,0 +1,28 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "./antlr4-runtime.h" +#include "parser/generated/LcypherLexer.h" +#include "parser/generated/LcypherParser.h" + +#include "execution_plan/runtime_context.h" +#include "parser/clause.h" +#include "parser/expression.h" + +namespace cypher { + +// Leverage the lexer to parameterize queries +std::string fastQueryParam(RTContext *ctx, const std::string query); +} diff --git a/src/cypher/experimental/data_type/field_data.h b/src/cypher/experimental/data_type/field_data.h new file mode 100644 index 0000000000..853805ae52 --- /dev/null +++ b/src/cypher/experimental/data_type/field_data.h @@ -0,0 +1,214 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include +#include +#include +#include "core/data_type.h" +#include "cypher/cypher_types.h" +#include "cypher/cypher_exception.h" + +using builder::static_var; +using builder::dyn_var; +using lgraph::FieldType; + +namespace cypher { +namespace compilation { + +struct CScalarData { + static constexpr const char* type_name = "CScalarData"; + std::variant< + std::monostate, // Represent the null state + dyn_var, + dyn_var, + dyn_var, + dyn_var, + dyn_var + > constant_; + + lgraph::FieldType type_; + + CScalarData() { + type_ = lgraph_api::FieldType::NUL; + } + + CScalarData(CScalarData &&data) + : constant_(std::move(data.constant_)), type_(data.type_) {} + + CScalarData(const CScalarData& other) + : constant_(other.constant_), type_(other.type_) {} + + explicit CScalarData(const lgraph::FieldData& other) { + type_ = other.type; + switch (other.type) { + case lgraph::FieldType::NUL: + constant_.emplace(); + break; + case lgraph::FieldType::INT64: + constant_.emplace>((int64_t)other.integer()); + break; + default: + CYPHER_TODO(); + } + } + + explicit CScalarData(int64_t integer) { + constant_.emplace>(integer); + type_ = lgraph::FieldType::INT64; + } + + explicit CScalarData(const static_var &integer) + : type_(FieldType::INT64) { + constant_ = (dyn_var) integer; + } + + explicit CScalarData(const dyn_var &integer) + : constant_(integer), type_(FieldType::INT64) {} + + explicit CScalarData(dyn_var&& integer) + : constant_(std::move(integer)), type_(FieldType::INT64) {} + + inline dyn_var integer() const { + switch (type_) { + case FieldType::NUL: + case FieldType::BOOL: + throw std::bad_cast(); + case FieldType::INT8: + return 
std::get>(constant_); + case FieldType::INT16: + return std::get>(constant_); + case FieldType::INT32: + return std::get>(constant_); + case FieldType::INT64: + return std::get>(constant_); + case FieldType::FLOAT: + case FieldType::DOUBLE: + case FieldType::DATE: + case FieldType::DATETIME: + case FieldType::STRING: + case FieldType::BLOB: + case FieldType::POINT: + case FieldType::LINESTRING: + case FieldType::POLYGON: + case FieldType::SPATIAL: + case FieldType::FLOAT_VECTOR: + throw std::bad_cast(); + } + return dyn_var(0); + } + + inline dyn_var real() const { + switch (type_) { + case FieldType::NUL: + case FieldType::BOOL: + case FieldType::INT8: + case FieldType::INT16: + case FieldType::INT32: + case FieldType::INT64: + throw std::bad_cast(); + case FieldType::FLOAT: + return std::get>(constant_); /* return the value instead of falling through */ + case FieldType::DOUBLE: + return std::get>(constant_); + case FieldType::DATE: + case FieldType::DATETIME: + case FieldType::STRING: + case FieldType::BLOB: + case FieldType::POINT: + case FieldType::LINESTRING: + case FieldType::POLYGON: + case FieldType::SPATIAL: + case FieldType::FLOAT_VECTOR: + throw std::bad_cast(); + } + return dyn_var(0); + } + + dyn_var Int64() const { + return std::get>(constant_); + } + + inline bool is_integer() const { + return type_ >= FieldType::INT8 && type_ <= FieldType::INT64; + } + + inline bool is_real() const { + return type_ == FieldType::DOUBLE || type_ == FieldType::FLOAT; + } + + bool is_null() const { return type_ == lgraph::FieldType::NUL; } + + bool is_string() const { return type_ == lgraph::FieldType::STRING; } + + CScalarData& operator=(CScalarData&& other) noexcept { + if (this != &other) { + constant_ = std::move(other.constant_); + type_ = std::move(other.type_); + } + return *this; + } + + CScalarData& operator=(const CScalarData& other) { + if (this != &other) { + constant_ = other.constant_; + type_ = other.type_; + } + return *this; + } + + CScalarData operator+(const CScalarData& other) const; +}; + +struct CFieldData { 
+ enum FieldType { SCALAR, ARRAY, MAP} type; + + CScalarData scalar; + std::vector* array = nullptr; + std::unordered_map* map = nullptr; + + CFieldData() : type(SCALAR) {} + + CFieldData(const CFieldData &data) : type(data.type), scalar(data.scalar) {} + + explicit CFieldData(const CScalarData& scalar) : type(SCALAR), scalar(scalar) {} + + CFieldData& operator=(const CFieldData& data) { + this->type = data.type; + this->scalar = data.scalar; + return *this; + } + + CFieldData& operator=(CFieldData&& data) { + this->type = std::move(data.type); + this->scalar = std::move(data.scalar); + return *this; + } + + explicit CFieldData(const static_var& scalar) : type(SCALAR), scalar(scalar) {} + + bool is_null() const { return type == SCALAR && scalar.is_null(); } + + bool is_string() const { return type == SCALAR && scalar.is_string(); } + + bool is_integer() const { return type == SCALAR && scalar.is_integer(); } + + bool is_real() const { return type == SCALAR && scalar.is_real(); } + + CFieldData operator+(const CFieldData& other) const; + + CFieldData operator-(const CFieldData& other) const; +}; +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/data_type/record.h b/src/cypher/experimental/data_type/record.h new file mode 100644 index 0000000000..e32a0b6d51 --- /dev/null +++ b/src/cypher/experimental/data_type/record.h @@ -0,0 +1,92 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include +#include "core/data_type.h" +#include "cypher/cypher_types.h" +#include "cypher/cypher_exception.h" +#include "parser/data_typedef.h" +#include "graph/node.h" +#include "graph/relationship.h" +#include "cypher/resultset/record.h" +#include "experimental/data_type/field_data.h" + +namespace cypher { + +struct SymbolTable; +class RTContext; + +namespace compilation { +struct CEntry { + compilation::CFieldData constant_; + cypher::Node* node_ = nullptr; + cypher::Relationship* relationship_ = nullptr; + + enum RecordEntryType { + UNKNOWN = 0, + CONSTANT, + NODE, + RELATIONSHIP, + VAR_LEN_RELP, + HEADER, // TODO(anyone) useless? + NODE_SNAPSHOT, + RELP_SNAPSHOT, + } type_ = UNKNOWN; /* defined tag for default-constructed entries */ + + CEntry() = default; + + explicit CEntry(const cypher::Entry& entry) { + switch (entry.type) { + case cypher::Entry::CONSTANT: { + constant_ = CFieldData(CScalarData(entry.constant.scalar)); + type_ = CONSTANT; + break; + } + case cypher::Entry::NODE: { + node_ = entry.node; + type_ = NODE; + break; + } + case cypher::Entry::RELATIONSHIP: { + relationship_ = entry.relationship; + type_ = RELATIONSHIP; + break; + } + default: + CYPHER_TODO(); + } + } + + explicit CEntry(const CFieldData &data) : constant_(data), type_(CONSTANT) {} + + explicit CEntry(CFieldData&& data) : constant_(std::move(data)), type_(CONSTANT) {} + + explicit CEntry(const CScalarData& scalar) : constant_(scalar), type_(CONSTANT) {} +}; + +struct CRecord { // Should be derived from cypher::Record + std::vector values; + + CRecord() = default; + + explicit CRecord(const cypher::Record &record) { + for (auto& entry : record.values) { + values.emplace_back(entry); + } + } +}; +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/expressions/cexpr.cpp b/src/cypher/experimental/expressions/cexpr.cpp new file mode 100644 index 0000000000..d28d5c9e24 --- /dev/null +++ b/src/cypher/experimental/expressions/cexpr.cpp @@ -0,0 +1,16 @@ +/** + * Copyright 2022 AntGroup CO., 
Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "cypher/experimental/expressions/cexpr.h" + diff --git a/src/cypher/experimental/expressions/cexpr.h b/src/cypher/experimental/expressions/cexpr.h new file mode 100644 index 0000000000..af0bcda915 --- /dev/null +++ b/src/cypher/experimental/expressions/cexpr.h @@ -0,0 +1,238 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include +#include + +#include +#include +#include +#include + +#include "geax-front-end/ast/Ast.h" +#include "geax-front-end/ast/expr/AggFunc.h" +#include "cypher/arithmetic/agg_ctx.h" +#include "cypher/arithmetic/ast_agg_expr_detector.h" +#include "cypher/execution_plan/visitor/visitor.h" +#include "cypher/resultset/record.h" +#include "cypher/parser/symbol_table.h" +#include "cypher/cypher_types.h" +#include "core/data_type.h" + +#include "experimental/data_type/field_data.h" +#include "experimental/data_type/record.h" +#include "cypher/execution_plan/runtime_context.h" + +namespace cypher { +namespace compilation { + +// struct ArithOperandNode { +// static constexpr const char* type_name = "ArithOperandNode"; +// CScalarData constant; +// struct Variadic { +// static_var alias; +// static_var alias_idx; +// static_var entity_prop; +// } variadic; +// struct Variable { +// bool hasMapFieldName; +// std::string _value_alias; +// std::string _map_field_name; +// } variable; + +// enum ArithOperandType { +// AR_OPERAND_CONSTANT, +// AR_OPERAND_VARIADIC, +// AR_OPERAND_PARAMETER, +// AR_OPERAND_VARIABLE, +// } type; + +// ArithOperandNode() = default; + +// ArithOperandNode(CScalarData &&data) : constant(std::move(data)) { +// std::cout<<"use move constructor"< +void checkedAnyCast(const std::any& s, TargetType& d) { + try { + d = std::any_cast(s); + } catch (...) 
{ + // TODO(lingsu): remove in future + assert(false); + } +} + +class ExprEvaluator : public geax::frontend::AstExprNodeVisitorImpl { + public: + ExprEvaluator() = delete; + + ExprEvaluator(geax::frontend::Expr* expr, const SymbolTable* sym_tab) + : expr_(expr), sym_tab_(sym_tab) {} + + ~ExprEvaluator() = default; + + std::vector agg_exprs_; + std::vector> agg_ctxs_; + size_t agg_pos_; + + enum class VisitMode { + EVALUATE, + AGGREGATE, + } visit_mode_; + + CEntry Evaluate(RTContext* ctx, const CRecord* record) { + ctx_ = ctx; + record_ = record; + agg_pos_ = 0; + visit_mode_ = VisitMode::EVALUATE; + CEntry entry; + checkedAnyCast(expr_->accept(*this), entry); + return entry; + } + + void Aggregate(RTContext* ctx, const CRecord* record) { + ctx_ = ctx; + record_ = record; + visit_mode_ = VisitMode::AGGREGATE; + if (agg_exprs_.empty()) { + agg_exprs_ = AstAggExprDetector::GetAggExprs(expr_); + } + for (size_t i = 0; i < agg_exprs_.size(); i++) { + agg_pos_ = i; + agg_exprs_[i]->accept(*this); + } + } + + void Reduce() { + for (auto agg_ctx : agg_ctxs_) { + agg_ctx->ReduceNext(); + } + } + + geax::frontend::Expr* GetExpression() { + return expr_; + } + + private: + std::any visit(geax::frontend::GetField* node) override; + std::any visit(geax::frontend::TupleGet* node) override; + std::any visit(geax::frontend::Not* node) override; + std::any visit(geax::frontend::Neg* node) override; + std::any visit(geax::frontend::Tilde* node) override; + std::any visit(geax::frontend::VSome* node) override; + std::any visit(geax::frontend::BEqual* node) override; + std::any visit(geax::frontend::BNotEqual* node) override; + std::any visit(geax::frontend::BGreaterThan* node) override; + std::any visit(geax::frontend::BNotSmallerThan* node) override; + std::any visit(geax::frontend::BSmallerThan* node) override; + std::any visit(geax::frontend::BNotGreaterThan* node) override; + std::any visit(geax::frontend::BSafeEqual* node) override; + std::any visit(geax::frontend::BAdd* node) 
override; + std::any visit(geax::frontend::BSub* node) override; + std::any visit(geax::frontend::BDiv* node) override; + std::any visit(geax::frontend::BMul* node) override; + std::any visit(geax::frontend::BMod* node) override; + std::any visit(geax::frontend::BSquare* node) override; + std::any visit(geax::frontend::BAnd* node) override; + std::any visit(geax::frontend::BOr* node) override; + std::any visit(geax::frontend::BXor* node) override; + std::any visit(geax::frontend::BBitAnd* node) override; + std::any visit(geax::frontend::BBitOr* node) override; + std::any visit(geax::frontend::BBitXor* node) override; + std::any visit(geax::frontend::BBitLeftShift* node) override; + std::any visit(geax::frontend::BBitRightShift* node) override; + std::any visit(geax::frontend::BConcat* node) override; + std::any visit(geax::frontend::BIndex* node) override; + std::any visit(geax::frontend::BLike* node) override; + std::any visit(geax::frontend::BIn* node) override; + std::any visit(geax::frontend::If* node) override; + std::any visit(geax::frontend::Function* node) override; + std::any visit(geax::frontend::Case* node) override; + std::any visit(geax::frontend::Cast* node) override; + std::any visit(geax::frontend::MatchCase* node) override; + std::any visit(geax::frontend::AggFunc* node) override; + std::any visit(geax::frontend::BAggFunc* node) override; + std::any visit(geax::frontend::MultiCount* node) override; + std::any visit(geax::frontend::Windowing* node) override; + std::any visit(geax::frontend::MkList* node) override; + std::any visit(geax::frontend::MkMap* node) override; + std::any visit(geax::frontend::MkRecord* node) override; + std::any visit(geax::frontend::MkSet* node) override; + std::any visit(geax::frontend::MkTuple* node) override; + std::any visit(geax::frontend::VBool* node) override; + std::any visit(geax::frontend::VInt* node) override; + std::any visit(geax::frontend::VDouble* node) override; + std::any visit(geax::frontend::VString* 
node) override; + std::any visit(geax::frontend::VDate* node) override; + std::any visit(geax::frontend::VDatetime* node) override; + std::any visit(geax::frontend::VDuration* node) override; + std::any visit(geax::frontend::VTime* node) override; + std::any visit(geax::frontend::VNull* node) override; + std::any visit(geax::frontend::VNone* node) override; + std::any visit(geax::frontend::Ref* node) override; + std::any visit(geax::frontend::Param* node) override; + std::any visit(geax::frontend::SingleLabel* node) override; + std::any visit(geax::frontend::LabelOr* node) override; + std::any visit(geax::frontend::IsLabeled* node) override; + std::any visit(geax::frontend::IsNull* node) override; + std::any visit(geax::frontend::ListComprehension* node) override; + std::any visit(geax::frontend::Exists* node) override; + + std::any reportError() override; + + private: + std::string error_msg_; + geax::frontend::Expr* expr_; + RTContext* ctx_; + const SymbolTable* sym_tab_; + const CRecord* record_; + std::shared_ptr agg_func_; +}; + +struct CExprNode { + static constexpr const char* type_name = "ArithExprNode"; + // ArithOperandNode operand_; + // ArithExprNode* left_; + // ArithExprNode* right_; + std::shared_ptr evaluator_; + // OpType op_; + + CExprNode() = default; + + inline CEntry Eval(cypher::RTContext *ctx, const CRecord &record) { + return evaluator_->Evaluate(ctx, &record); + } +}; +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/expressions/kernal/binary.cpp b/src/cypher/experimental/expressions/kernal/binary.cpp new file mode 100644 index 0000000000..52e9b7d2cf --- /dev/null +++ b/src/cypher/experimental/expressions/kernal/binary.cpp @@ -0,0 +1,200 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include +#include "cypher/cypher_types.h" +#include "core/data_type.h" + +#include "cypher/cypher_exception.h" +#include "cypher/experimental/data_type/field_data.h" +#include "cypher/experimental/data_type/record.h" +#include "cypher/utils/geax_util.h" +#include "cypher/experimental/expressions/cexpr.h" + +namespace cypher { +namespace compilation { +CFieldData CFieldData::operator+(const CFieldData &other) const { + if (is_null() || other.is_null()) return CFieldData(); + CFieldData ret; + if (type == CFieldData::ARRAY || other.type == CFieldData::ARRAY) { + CYPHER_TODO(); + } else if (is_string() || other.is_string()) { + CYPHER_TODO(); + } else if ((is_integer() || is_real()) && (other.is_integer() || other.is_real())) { + if (is_integer() && other.is_integer()) { + ret.scalar = CScalarData(scalar.Int64() + other.scalar.Int64()); + } else { + dyn_var x_n = is_integer() ? (dyn_var)scalar.integer() + : scalar.real(); + dyn_var y_n = other.is_integer() ? (dyn_var)other.scalar.integer() + : other.scalar.real(); /* accessor chosen by other's own type */ + ret.scalar = CScalarData(x_n + y_n); + } + } + return ret; +} + +CFieldData CFieldData::operator-(const CFieldData &other) const { + if (is_null() || other.is_null()) return CFieldData(); + CFieldData ret; + if (type == CFieldData::ARRAY || other.type == CFieldData::ARRAY) { + CYPHER_TODO(); + } else if (is_string() || other.is_string()) { + CYPHER_TODO(); + } else if ((is_integer() || is_real()) && (other.is_integer() || other.is_real())) { + if (is_integer() && other.is_integer()) { + ret.scalar = CScalarData(scalar.Int64() - other.scalar.Int64()); + } else { + dyn_var x_n = is_integer() ? 
(dyn_var)scalar.integer() + : scalar.real(); + dyn_var y_n = other.is_integer() ? (dyn_var)other.scalar.integer() + : other.scalar.real(); /* accessor chosen by other's own type */ + ret.scalar = CScalarData(x_n - y_n); + } + } + return ret; +} + +static CFieldData add(const CFieldData& x, const CFieldData& y) { + return x + y; +} + +static CFieldData sub(const CFieldData& x, const CFieldData& y) { + return x - y; +} + +static CFieldData div(const CFieldData& x, const CFieldData y) { + if (x.is_null() || y.is_null()) return CFieldData(); + if (!(x.is_integer() || x.is_real()) || !(y.is_integer() || y.is_real())) + throw lgraph::CypherException("Type mismatch: expect Integer or Float in div expr"); + CFieldData ret; + if (x.is_integer() && y.is_integer()) { + dyn_var x_n = x.scalar.integer(); + dyn_var y_n = y.scalar.integer(); + if (y_n == 0) throw lgraph::CypherException("divide by zero"); + ret.scalar = CScalarData(x_n / y_n); + } else { + dyn_var x_n = x.is_integer() ? (dyn_var) x.scalar.integer() + : x.scalar.real(); + dyn_var y_n = y.is_integer()? 
(dyn_var) y.scalar.integer() + : y.scalar.real(); + if (y_n == 0) CYPHER_TODO(); + ret.scalar = CScalarData(x_n / y_n); /* real-valued quotient */ + } + return ret; +} + +#ifndef DO_BINARY_EXPR +#define DO_BINARY_EXPR(func) \ + auto lef = std::any_cast(node->left()->accept(*this)); \ + auto rig = std::any_cast(node->right()->accept(*this)); \ + if (lef.type_ != CEntry::RecordEntryType::CONSTANT || \ + rig.type_ != CEntry::RecordEntryType::CONSTANT) { \ + NOT_SUPPORT_AND_THROW(); \ + } \ + return CEntry(func(lef.constant_, rig.constant_)); +#endif + +std::any ExprEvaluator::visit(geax::frontend::BAdd* node) { DO_BINARY_EXPR(add); } + +std::any ExprEvaluator::visit(geax::frontend::Ref* node) { + auto it = sym_tab_->symbols.find(node->name()); + if (it == sym_tab_->symbols.end()) NOT_SUPPORT_AND_THROW(); + switch (it->second.type) { + case SymbolNode::NODE: + case SymbolNode::RELATIONSHIP: + case SymbolNode::CONSTANT: + case SymbolNode::PARAMETER: + return record_->values[it->second.id]; + case SymbolNode::NAMED_PATH: + { + // auto it = sym_tab_->anot_collection.path_elements.find(node->name()); + // if (it == sym_tab_->anot_collection.path_elements.end()) + // throw lgraph::CypherException("path_elements error: " + node->name()) + // const std::vector>& elements = it->second; + // std::vector params; + // for (auto ref: elements) { + // params.emplace_back(ref.get(), *sym_tab_); + // } + CYPHER_TODO(); + } + } + return std::any(); +} + +std::any ExprEvaluator::visit(geax::frontend::GetField* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::TupleGet* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Not* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Neg* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Tilde* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VSome* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BEqual* node) { CYPHER_TODO(); } +std::any 
ExprEvaluator::visit(geax::frontend::BNotEqual* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BGreaterThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BNotSmallerThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSmallerThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BNotGreaterThan* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSafeEqual* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSub* node) { DO_BINARY_EXPR(sub); } +std::any ExprEvaluator::visit(geax::frontend::BDiv* node) { DO_BINARY_EXPR(div); } +std::any ExprEvaluator::visit(geax::frontend::BMul* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BMod* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BSquare* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BAnd* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BOr* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BXor* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitAnd* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitOr* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitXor* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitLeftShift* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BBitRightShift* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BConcat* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BIndex* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BLike* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BIn* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::If* node) { CYPHER_TODO(); } +std::any 
ExprEvaluator::visit(geax::frontend::Function* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Case* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Cast* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MatchCase* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::AggFunc* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::BAggFunc* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MultiCount* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Windowing* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkList* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkMap* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkRecord* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkSet* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::MkTuple* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VBool* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VInt* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDouble* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VString* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDate* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDatetime* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VDuration* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VTime* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VNull* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::VNone* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Param* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::SingleLabel* node) { 
CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::LabelOr* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::IsLabeled* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::IsNull* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::ListComprehension* node) { CYPHER_TODO(); } +std::any ExprEvaluator::visit(geax::frontend::Exists* node) { CYPHER_TODO(); } +std::any ExprEvaluator::reportError() { CYPHER_TODO(); } +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/jit/TuJIT.cpp b/src/cypher/experimental/jit/TuJIT.cpp new file mode 100644 index 0000000000..49099306f6 --- /dev/null +++ b/src/cypher/experimental/jit/TuJIT.cpp @@ -0,0 +1,21 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "cypher/experimental/jit/TuJIT.h" + +namespace cypher { +namespace compilation { + +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/experimental/jit/TuJIT.h b/src/cypher/experimental/jit/TuJIT.h new file mode 100644 index 0000000000..a4c4bf781e --- /dev/null +++ b/src/cypher/experimental/jit/TuJIT.h @@ -0,0 +1,95 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace cypher { +namespace compilation { +class JITCompiler; +class JITSymbolResolver; +class JITModuleMemoryManager; + +/** Custom JIT implementation inspired by CHJIT in clickhouse + * Main use cases: + * 1. Compiled functions in module. + * 2. Release memory for compiled function. + */ +class TuJIT { + public: + TuJIT(); + + ~TuJIT(); + + struct CompileModule { + // Size of compiled module code in bytes + size_t size_; + + // Module identifier. Should not be changed by client + uint64_t identifier_; + + // Vector of compiled functions. Should not be changed by client. + // It is client responsibility to cast result function to right signature. + // After call to deleteCompiledModule compiled functions from module become invalid. + std::unordered_map function_name_to_symbol_; + }; + + // Compile module. In compile function client responsibility is to fill module with necessary + // IR code, then it will be compiled by TuJIT instance. + // Return compiled module. + CompileModule compileModule(std::function compile_funciton); + + // Delete compiled module. Pointers to functions from module become invalid after this call. + // It is client responsibility to be sure that there are no pointers to compiled module code. + void deleteCompiledModule(const CompileModule& module_info); + + // Register external symbol for TuJIT instance to use, during linking. + // It can be function, or global constant. + // It is client responsibility to be sure that address of symbol + // is valid during TuJIT instance lifetime. 
+ void registerExternalSymbol(const std::string& symbol_name, void* address); + + // Total compiled code size for modules that are currently valid. + size_t getCompiledCodeSize() const { + return compiled_code_size_.load(std::memory_order_relaxed); + } + + private: + std::unique_ptr createModulerForCompilation(); + + CompileModule compileModule(std::unique_ptr module); + + std::string getMangleName(const std::string& name_to_mangle) const; + + void runOptimizationPassesOnModule(llvm::Module& module) const; + + static std::unique_ptr getTargetMachine_; + + llvm::LLVMContext context_; + std::unique_ptr machine_; + llvm::DataLayout layout_; + std::unique_ptr compiler_; + + std::unordered_map> + module_identifier_to_memory_manager_; + uint64_t current_module_key_ = 0; + std::atomic compiled_code_size_ = 0; + mutable std::mutex jit_lock_; +}; +} // namespace compilation +} // namespace cypher diff --git a/src/cypher/resultset/bit_mask.h b/src/cypher/resultset/bit_mask.h new file mode 100644 index 0000000000..74f985df7e --- /dev/null +++ b/src/cypher/resultset/bit_mask.h @@ -0,0 +1,221 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace cypher { + +constexpr uint64_t BITMASKS_SINGLE_ONE[64] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000, 0x10000, 0x20000, 0x40000, 0x80000, + 0x100000, 0x200000, 0x400000, 0x800000, 0x1000000, 0x2000000, 0x4000000, 0x8000000, 0x10000000, + 0x20000000, 0x40000000, 0x80000000, 0x100000000, 0x200000000, 0x400000000, 0x800000000, + 0x1000000000, 0x2000000000, 0x4000000000, 0x8000000000, 0x10000000000, 0x20000000000, + 0x40000000000, 0x80000000000, 0x100000000000, 0x200000000000, 0x400000000000, 0x800000000000, + 0x1000000000000, 0x2000000000000, 0x4000000000000, 0x8000000000000, 0x10000000000000, + 0x20000000000000, 0x40000000000000, 0x80000000000000, 0x100000000000000, 0x200000000000000, + 0x400000000000000, 0x800000000000000, 0x1000000000000000, 0x2000000000000000, + 0x4000000000000000, 0x8000000000000000}; +constexpr uint64_t BITMASKS_SINGLE_ZERO[64] = {0xfffffffffffffffe, 0xfffffffffffffffd, + 0xfffffffffffffffb, 0xfffffffffffffff7, 0xffffffffffffffef, 0xffffffffffffffdf, + 0xffffffffffffffbf, 0xffffffffffffff7f, 0xfffffffffffffeff, 0xfffffffffffffdff, + 0xfffffffffffffbff, 0xfffffffffffff7ff, 0xffffffffffffefff, 0xffffffffffffdfff, + 0xffffffffffffbfff, 0xffffffffffff7fff, 0xfffffffffffeffff, 0xfffffffffffdffff, + 0xfffffffffffbffff, 0xfffffffffff7ffff, 0xffffffffffefffff, 0xffffffffffdfffff, + 0xffffffffffbfffff, 0xffffffffff7fffff, 0xfffffffffeffffff, 0xfffffffffdffffff, + 0xfffffffffbffffff, 0xfffffffff7ffffff, 0xffffffffefffffff, 0xffffffffdfffffff, + 0xffffffffbfffffff, 0xffffffff7fffffff, 0xfffffffeffffffff, 0xfffffffdffffffff, + 0xfffffffbffffffff, 0xfffffff7ffffffff, 0xffffffefffffffff, 0xffffffdfffffffff, + 0xffffffbfffffffff, 0xffffff7fffffffff, 0xfffffeffffffffff, 0xfffffdffffffffff, + 0xfffffbffffffffff, 0xfffff7ffffffffff, 0xffffefffffffffff, 0xffffdfffffffffff, + 0xffffbfffffffffff, 0xffff7fffffffffff, 
0xfffeffffffffffff, 0xfffdffffffffffff, + 0xfffbffffffffffff, 0xfff7ffffffffffff, 0xffefffffffffffff, 0xffdfffffffffffff, + 0xffbfffffffffffff, 0xff7fffffffffffff, 0xfeffffffffffffff, 0xfdffffffffffffff, + 0xfbffffffffffffff, 0xf7ffffffffffffff, 0xefffffffffffffff, 0xdfffffffffffffff, + 0xbfffffffffffffff, 0x7fffffffffffffff}; +constexpr uint64_t LOWER_BITMASKS[65] = {0x0, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f, 0xff, 0x1ff, + 0x3ff, 0x7ff, 0xfff, 0x1fff, 0x3fff, 0x7fff, 0xffff, 0x1ffff, 0x3ffff, 0x7ffff, 0xfffff, + 0x1fffff, 0x3fffff, 0x7fffff, 0xffffff, 0x1ffffff, 0x3ffffff, 0x7ffffff, 0xfffffff, 0x1fffffff, + 0x3fffffff, 0x7fffffff, 0xffffffff, 0x1ffffffff, 0x3ffffffff, 0x7ffffffff, 0xfffffffff, + 0x1fffffffff, 0x3fffffffff, 0x7fffffffff, 0xffffffffff, 0x1ffffffffff, 0x3ffffffffff, + 0x7ffffffffff, 0xfffffffffff, 0x1fffffffffff, 0x3fffffffffff, 0x7fffffffffff, 0xffffffffffff, + 0x1ffffffffffff, 0x3ffffffffffff, 0x7ffffffffffff, 0xfffffffffffff, 0x1fffffffffffff, + 0x3fffffffffffff, 0x7fffffffffffff, 0xffffffffffffff, 0x1ffffffffffffff, 0x3ffffffffffffff, + 0x7ffffffffffffff, 0xfffffffffffffff, 0x1fffffffffffffff, 0x3fffffffffffffff, + 0x7fffffffffffffff, 0xffffffffffffffff}; +constexpr uint64_t HIGH_BITMASKS[65] = {0x0, 0x8000000000000000, 0xc000000000000000, + 0xe000000000000000, 0xf000000000000000, 0xf800000000000000, 0xfc00000000000000, + 0xfe00000000000000, 0xff00000000000000, 0xff80000000000000, 0xffc0000000000000, + 0xffe0000000000000, 0xfff0000000000000, 0xfff8000000000000, 0xfffc000000000000, + 0xfffe000000000000, 0xffff000000000000, 0xffff800000000000, 0xffffc00000000000, + 0xffffe00000000000, 0xfffff00000000000, 0xfffff80000000000, 0xfffffc0000000000, + 0xfffffe0000000000, 0xffffff0000000000, 0xffffff8000000000, 0xffffffc000000000, + 0xffffffe000000000, 0xfffffff000000000, 0xfffffff800000000, 0xfffffffc00000000, + 0xfffffffe00000000, 0xffffffff00000000, 0xffffffff80000000, 0xffffffffc0000000, + 0xffffffffe0000000, 0xfffffffff0000000, 0xfffffffff8000000, 
0xfffffffffc000000, + 0xfffffffffe000000, 0xffffffffff000000, 0xffffffffff800000, 0xffffffffffc00000, + 0xffffffffffe00000, 0xfffffffffff00000, 0xfffffffffff80000, 0xfffffffffffc0000, + 0xfffffffffffe0000, 0xffffffffffff0000, 0xffffffffffff8000, 0xffffffffffffc000, + 0xffffffffffffe000, 0xfffffffffffff000, 0xfffffffffffff800, 0xfffffffffffffc00, + 0xfffffffffffffe00, 0xffffffffffffff00, 0xffffffffffffff80, 0xffffffffffffffc0, + 0xffffffffffffffe0, 0xfffffffffffffff0, 0xfffffffffffffff8, 0xfffffffffffffffc, + 0xfffffffffffffffe, 0xffffffffffffffff}; + +class BitMask { + public: + static constexpr uint64_t NO_NULL_ENTRY = 0; + static constexpr uint64_t ALL_NULL_ENTRY = ~uint64_t(NO_NULL_ENTRY); + static constexpr uint64_t BITS_PER_ENTRY_LOG2 = 6; // 64 bits per entry + static constexpr uint64_t BITS_PER_ENTRY = (uint64_t)1 << BITS_PER_ENTRY_LOG2; + static constexpr uint64_t BYTES_PER_ENTRY = BITS_PER_ENTRY >> 3; // 8 bytes per entry + + explicit BitMask(uint64_t capacity) : may_contain_nulls_{false} { + auto num_null_entries = (capacity + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY; + buffer_ = std::make_unique(num_null_entries); + data_ = buffer_.get(); + size_ = num_null_entries; + std::fill(data_, data_ + num_null_entries, NO_NULL_ENTRY); + } + + explicit BitMask(uint64_t* mask_data, size_t size, bool may_contain_nulls) + : data_{mask_data}, size_{size}, buffer_{nullptr}, may_contain_nulls_{may_contain_nulls} {} + BitMask(const BitMask& other) { + size_ = other.size_; + may_contain_nulls_ = other.may_contain_nulls_; + buffer_ = std::make_unique(size_); + std::copy(other.data_, other.data_ + size_, buffer_.get()); + data_ = buffer_.get(); + } + + BitMask& operator=(const BitMask& other) { + if (this == &other) return *this; + size_ = other.size_; + may_contain_nulls_ = other.may_contain_nulls_; + buffer_ = std::make_unique(size_); + std::copy(other.data_, other.data_ + size_, buffer_.get()); + data_ = buffer_.get(); + return *this; + } + + void SetAllNonNull() { + if 
(!may_contain_nulls_) return; + std::fill(data_, data_ + size_, NO_NULL_ENTRY); + may_contain_nulls_ = false; + } + + void SetAllNull() { + std::fill(data_, data_ + size_, ALL_NULL_ENTRY); + may_contain_nulls_ = true; + } + + bool HasNoNullsGuarantee() const { return !may_contain_nulls_; } + + static void SetBit(uint64_t* entries, uint32_t pos, bool is_null) { + auto [entry_pos, bit_pos_in_entry] = GetEntryAndBitPos(pos); + if (is_null) { + entries[entry_pos] |= BITMASKS_SINGLE_ONE[bit_pos_in_entry]; + } else { + entries[entry_pos] &= BITMASKS_SINGLE_ZERO[bit_pos_in_entry]; + } + } + + void SetBit(uint32_t pos, bool is_null) { + SetBit(data_, pos, is_null); + if (is_null) { + may_contain_nulls_ = true; + } + } + + bool IsBitSet(uint32_t pos) const { + auto [entry_pos, bit_pos_in_entry] = GetEntryAndBitPos(pos); + return data_[entry_pos] & BITMASKS_SINGLE_ONE[bit_pos_in_entry]; + } + + const uint64_t* GetData() const { return data_; } + + static uint64_t GetNumEntries(uint64_t num_bits) { + return (num_bits >> BITS_PER_ENTRY_LOG2) + + ((num_bits & (BITS_PER_ENTRY - 1)) == 0 ? 
0 : 1); + } + + void resize(uint64_t capacity) { + auto num_entries = (capacity + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY; + auto resized_buffer = std::make_unique(num_entries); + if (data_) { + std::memcpy(resized_buffer.get(), data_, + std::min(size_, num_entries) * sizeof(uint64_t)); + } + buffer_ = std::move(resized_buffer); + data_ = buffer_.get(); + size_ = num_entries; + } + + void SetNullFromRange(uint64_t offset, uint64_t num_bits_to_set, bool is_null) { + if (is_null) { + may_contain_nulls_ = true; + } + SetNullRange(data_, offset, num_bits_to_set, is_null); + } + + static void SetNullRange(uint64_t* null_entries, uint64_t offset, + uint64_t num_bits_to_set, bool is_null) { + auto [first_entry_pos, first_bit_pos] = GetEntryAndBitPos(offset); + auto [last_entry_pos, last_bit_pos] = GetEntryAndBitPos(offset + num_bits_to_set); + + if (last_entry_pos > first_entry_pos + 1) { + std::fill(null_entries + first_entry_pos + 1, null_entries + last_entry_pos, + is_null ? ALL_NULL_ENTRY : NO_NULL_ENTRY); + } + + if (first_entry_pos == last_entry_pos) { + if (is_null) { + null_entries[first_entry_pos] |= (~LOWER_BITMASKS[first_bit_pos] + & ~HIGH_BITMASKS[BITS_PER_ENTRY - last_bit_pos]); + } else { + null_entries[first_entry_pos] &= (LOWER_BITMASKS[first_bit_pos] + | HIGH_BITMASKS[BITS_PER_ENTRY - last_bit_pos]); + } + } else { + if (is_null) { + null_entries[first_entry_pos] |= ~LOWER_BITMASKS[first_bit_pos]; + if (last_bit_pos > 0) { + null_entries[last_entry_pos] |= LOWER_BITMASKS[last_bit_pos]; + } + } else { + null_entries[first_entry_pos] &= LOWER_BITMASKS[first_bit_pos]; + if (last_bit_pos > 0) { + null_entries[last_entry_pos] &= ~LOWER_BITMASKS[last_bit_pos]; + } + } + } + } + + private: + static std::pair GetEntryAndBitPos(uint64_t pos) { + auto entry_pos = pos >> BITS_PER_ENTRY_LOG2; + return {entry_pos, pos - (entry_pos << BITS_PER_ENTRY_LOG2)}; + } + + private: + uint64_t* data_; + size_t size_; + std::unique_ptr buffer_; + bool may_contain_nulls_; +}; + +} 
// namespace cypher diff --git a/src/cypher/resultset/column_vector.h b/src/cypher/resultset/column_vector.h new file mode 100644 index 0000000000..b94dc914c9 --- /dev/null +++ b/src/cypher/resultset/column_vector.h @@ -0,0 +1,325 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include +#include +#include "cypher/resultset/bit_mask.h" +#include "cypher/resultset/cypher_string_t.h" + +namespace cypher { + +constexpr size_t DEFAULT_VECTOR_CAPACITY = 2048; + +class ColumnVector { + friend class StringColumn; + + public: + explicit ColumnVector(size_t element_size, size_t capacity = DEFAULT_VECTOR_CAPACITY, + lgraph_api::FieldType field_type = lgraph_api::FieldType::NUL) + : element_size_(element_size), + capacity_(capacity), + field_type_(field_type), + data_(new uint8_t[element_size * capacity]()), + bitmask_(capacity) {} + + ColumnVector(const ColumnVector& other) + : element_size_(other.element_size_), + capacity_(other.capacity_), + field_type_(other.field_type_), + data_(new uint8_t[other.element_size_ * other.capacity_]), + bitmask_(other.bitmask_) { + // Check if the ColumnVector contains strings + if (element_size_ == sizeof(cypher_string_t)) { + // Initialize overflow buffer + if (other.overflow_buffer_) { + overflow_buffer_capacity_ = other.overflow_buffer_capacity_; + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + overflow_offset_ = 0; // will update this as we copy strings + } + // Copy each cypher_string_t individually + for (uint32_t i = 0; i 
< capacity_; ++i) { + auto& src_str = reinterpret_cast(other.data_.get())[i]; + auto& dst_str = reinterpret_cast(data_.get())[i]; + dst_str.len = src_str.len; + if (cypher_string_t::IsShortString(src_str.len)) { + // Copy the short string directly + std::memcpy(dst_str.prefix, src_str.prefix, src_str.len); + } else { + // Copy the prefix + std::memcpy(dst_str.prefix, src_str.prefix, cypher_string_t::PREFIX_LENGTH); + // Allocate overflow space in the new overflow buffer + uint64_t overflow_size = src_str.len - cypher_string_t::PREFIX_LENGTH; + if (!overflow_buffer_) { + // Initialize overflow buffer if not already done + overflow_buffer_capacity_ = std::max(overflow_size, + static_cast(1024)); + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + overflow_offset_ = 0; + } else if (overflow_offset_ + overflow_size > overflow_buffer_capacity_) { + // Resize overflow buffer if necessary + ResizeOverflowBuffer(overflow_offset_ + overflow_size); + } + // Copy the overflow data + void* dst_overflow_ptr = overflow_buffer_.get() + overflow_offset_; + std::memcpy(dst_overflow_ptr, reinterpret_cast(src_str.overflowPtr), + overflow_size); + dst_str.overflowPtr = reinterpret_cast(dst_overflow_ptr); + overflow_offset_ += overflow_size; + } + } + } else { + // For non-string data, we can copy the data directly + std::memcpy(data_.get(), other.data_.get(), element_size_ * capacity_); + /* Copy overflow buffer if it exists (though for non-string data, it shouldn't). + Just in case for future use. 
*/ + if (other.overflow_buffer_) { + overflow_buffer_capacity_ = other.overflow_buffer_capacity_; + overflow_offset_ = other.overflow_offset_; + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + std::memcpy(overflow_buffer_.get(), other.overflow_buffer_.get(), overflow_offset_); + } + } + } + + ColumnVector& operator=(const ColumnVector& other) { + if (this == &other) return *this; + element_size_ = other.element_size_; + capacity_ = other.capacity_; + field_type_ = other.field_type_; + data_ = std::unique_ptr(new uint8_t[other.element_size_ * other.capacity_]); + std::memcpy(data_.get(), other.data_.get(), other.element_size_ * other.capacity_); + bitmask_ = other.bitmask_; + overflow_buffer_capacity_ = other.overflow_buffer_capacity_; + overflow_offset_ = other.overflow_offset_; + if (other.overflow_buffer_) { + overflow_buffer_ = std::unique_ptr( + new uint8_t[other.overflow_buffer_capacity_]); + std::memcpy(overflow_buffer_.get(), other.overflow_buffer_.get(), overflow_offset_); + } else { + overflow_buffer_.reset(); + } + return *this; + } + + ~ColumnVector() = default; + + void SetAllNull() { bitmask_.SetAllNull(); } + void SetAllNonNull() { bitmask_.SetAllNonNull(); } + bool HasNoNullsGuarantee() const { return bitmask_.HasNoNullsGuarantee(); } + + void SetNullRange(uint32_t start, uint32_t len, bool value) { + bitmask_.SetNullFromRange(start, len, value); + } + + void SetNull(uint32_t pos, bool is_null) { bitmask_.SetBit(pos, is_null); } + + bool IsNull(uint32_t pos) const { return bitmask_.IsBitSet(pos); } + + uint8_t* data() const { return data_.get(); } + + uint32_t GetElementSize() const { return element_size_; } + + uint32_t GetCapacity() const { return capacity_; } + + lgraph_api::FieldType GetFieldType() const { return field_type_; } + + template + const T& GetValue(uint32_t pos) const { + if (pos >= capacity_) { + throw std::out_of_range("Index out of range in GetValue"); + } + return reinterpret_cast(data_.get())[pos]; + } + + 
template + T& GetValue(uint32_t pos) { + if (pos >= capacity_) { + throw std::out_of_range("Index out of range in GetValue"); + } + return reinterpret_cast(data_.get())[pos]; + } + + template + void SetValue(uint32_t pos, T val) { + if (pos >= capacity_) { + throw std::out_of_range("Index out of range in SetValue"); + } + reinterpret_cast(data_.get())[pos] = val; + } + + void* AllocateOverflow(uint64_t size) const { + if (!overflow_buffer_) { + overflow_buffer_capacity_ = std::max(size, static_cast(1024)); + overflow_buffer_ = std::make_unique(overflow_buffer_capacity_); + overflow_offset_ = 0; + } else if (overflow_offset_ + size > overflow_buffer_capacity_) { + uint64_t new_capacity = overflow_offset_ + size; + new_capacity = ((new_capacity + 1023) / 1024) * 1024; + ResizeOverflowBuffer(new_capacity); + } + void* ptr = overflow_buffer_.get() + overflow_offset_; + overflow_offset_ += size; + return ptr; + } + + // fetch field size + static size_t GetFieldSize(lgraph_api::FieldType type) { + switch (type) { + case lgraph_api::FieldType::BOOL: + return sizeof(bool); + case lgraph_api::FieldType::INT8: + return sizeof(int8_t); + case lgraph_api::FieldType::INT16: + return sizeof(int16_t); + case lgraph_api::FieldType::INT32: + return sizeof(int32_t); + case lgraph_api::FieldType::INT64: + return sizeof(int64_t); + case lgraph_api::FieldType::FLOAT: + return sizeof(float); + case lgraph_api::FieldType::DOUBLE: + return sizeof(double); + default: + throw std::runtime_error("Unsupported field type"); + } + } + + // insert data into column vector + static void InsertIntoColumnVector(ColumnVector* column_vector, + const lgraph_api::FieldData& field, + uint32_t pos) { + switch (field.type) { + case lgraph_api::FieldType::BOOL: + column_vector->SetValue(pos, field.AsBool()); + break; + case lgraph_api::FieldType::INT8: + column_vector->SetValue(pos, field.AsInt8()); + break; + case lgraph_api::FieldType::INT16: + column_vector->SetValue(pos, field.AsInt16()); + break; + 
case lgraph_api::FieldType::INT32: + column_vector->SetValue(pos, field.AsInt32()); + break; + case lgraph_api::FieldType::INT64: + column_vector->SetValue(pos, field.AsInt64()); + break; + case lgraph_api::FieldType::FLOAT: + column_vector->SetValue(pos, field.AsFloat()); + break; + case lgraph_api::FieldType::DOUBLE: + column_vector->SetValue(pos, field.AsDouble()); + break; + default: + throw std::runtime_error("Unsupported field type"); + } + } + + private: + void ResizeOverflowBuffer(uint64_t new_capacity) const { + if (new_capacity <= overflow_buffer_capacity_) return; + auto new_buffer = std::make_unique(new_capacity); + if (overflow_buffer_) { + std::memcpy(new_buffer.get(), overflow_buffer_.get(), overflow_offset_); + } + overflow_buffer_ = std::move(new_buffer); + overflow_buffer_capacity_ = new_capacity; + } + + private: + uint32_t element_size_; // size of each element in bytes + uint32_t capacity_; // number of elements + lgraph_api::FieldType field_type_; + std::unique_ptr data_; + BitMask bitmask_; + mutable uint64_t overflow_buffer_capacity_; + mutable std::unique_ptr overflow_buffer_ = nullptr; + mutable uint64_t overflow_offset_; +}; + + +class StringColumn { + public: + // add string to vector + static void AddString(ColumnVector* vector, uint32_t vectorPos, cypher_string_t& srcStr) { + auto& dstStr = vector->GetValue(vectorPos); + if (cypher_string_t::IsShortString(srcStr.len)) { + dstStr.SetShortString(reinterpret_cast(srcStr.prefix), srcStr.len); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(srcStr.len)); + dstStr.SetLongString(reinterpret_cast(srcStr.prefix), srcStr.len); + } + } + + static void AddString(ColumnVector* vector, uint32_t vectorPos, const char* srcStr, + uint64_t length) { + auto& dstStr = vector->GetValue(vectorPos); + if (cypher_string_t::IsShortString(length)) { + dstStr.SetShortString(srcStr, length); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + 
dstStr.SetLongString(srcStr, length); + } + } + + static void AddString(ColumnVector* vector, uint32_t vectorPos, const std::string& srcStr) { + AddString(vector, vectorPos, srcStr.data(), srcStr.length()); + } + + static cypher_string_t& ReserveString(ColumnVector* vector, uint32_t vectorPos, + uint64_t length) { + auto& dstStr = vector->GetValue(vectorPos); + dstStr.len = length; + if (!cypher_string_t::IsShortString(length)) { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + } + return dstStr; + } + + static void ReserveString(ColumnVector* vector, cypher_string_t& dstStr, uint64_t length) { + dstStr.len = length; + if (!cypher_string_t::IsShortString(length)) { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + } + } + + static void AddString(ColumnVector* vector, cypher_string_t& dstStr, cypher_string_t& srcStr) { + if (cypher_string_t::IsShortString(srcStr.len)) { + dstStr.SetShortString(reinterpret_cast(srcStr.prefix), srcStr.len); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(srcStr.len)); + dstStr.SetLongString(reinterpret_cast(srcStr.prefix), srcStr.len); + } + } + + static void AddString(ColumnVector* vector, cypher_string_t& dstStr, const char* srcStr, + uint64_t length) { + if (cypher_string_t::IsShortString(length)) { + dstStr.SetShortString(srcStr, length); + } else { + dstStr.overflowPtr = reinterpret_cast(vector->AllocateOverflow(length)); + dstStr.SetLongString(srcStr, length); + } + } + + static void AddString(ColumnVector* vector, cypher_string_t& dstStr, + const std::string& srcStr) { + AddString(vector, dstStr, srcStr.data(), srcStr.length()); + } +}; +} // namespace cypher diff --git a/src/cypher/resultset/cypher_string_t.h b/src/cypher/resultset/cypher_string_t.h new file mode 100644 index 0000000000..a07404f547 --- /dev/null +++ b/src/cypher/resultset/cypher_string_t.h @@ -0,0 +1,66 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include +#include + +namespace cypher { + +struct cypher_string_t { + static constexpr uint64_t PREFIX_LENGTH = 4; + static constexpr uint64_t INLINED_SUFFIX_LENGTH = 8; + static constexpr uint64_t SHORT_STR_LENGTH = PREFIX_LENGTH + INLINED_SUFFIX_LENGTH; + + uint32_t len; + uint8_t prefix[PREFIX_LENGTH]; + union { + uint8_t data[INLINED_SUFFIX_LENGTH]; + uint64_t overflowPtr; + }; + + cypher_string_t() : len{0}, overflowPtr{0} {} + + static bool IsShortString(uint32_t len) { return len <= SHORT_STR_LENGTH; } + + void SetShortString(const char* value, uint64_t length) { + len = length; + std::memcpy(prefix, value, length); + } + + void SetLongString(const char* value, uint64_t length) { + len = length; + std::memcpy(prefix, value, PREFIX_LENGTH); + std::memcpy(reinterpret_cast(overflowPtr), value + PREFIX_LENGTH, + length - PREFIX_LENGTH); + } + + std::string GetAsShortString() const { + return std::string(reinterpret_cast(prefix), len); + } + + std::string GetAsString() const { + if (IsShortString(len)) { + return std::string(reinterpret_cast(prefix), len); + } else { + return std::string(reinterpret_cast(prefix), PREFIX_LENGTH) + + std::string(reinterpret_cast(overflowPtr), len - PREFIX_LENGTH); + } + } +}; + +} // namespace cypher diff --git a/test/QueryTester.cpp b/test/QueryTester.cpp new file mode 100644 index 0000000000..cc3cedb4b7 --- /dev/null +++ b/test/QueryTester.cpp @@ -0,0 +1,293 @@ +/** + * Copyright 2022 AntGroup CO., Ltd. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "QueryTester.h"
+
+extern GraphFactory::GRAPH_DATASET_TYPE _ut_graph_dataset_type;
+extern lgraph::ut::QUERY_TYPE _ut_query_type;
+
+
+// Sends Boost.Log records to std::cout, formatted as "[timestamp]: message".
+void InitLogging() {
+    boost::log::add_console_log(std::cout,
+                                boost::log::keywords::format = "[%TimeStamp%]: %Message%");
+    boost::log::add_common_attributes();
+}
+
+QueryTester::QueryTester() {
+    // Function-local static: the initialiser (and therefore InitLogging) runs only once,
+    // however many testers are constructed.
+    [[maybe_unused]] static bool logging_initialized = []() {
+        InitLogging();
+        return true;
+    }();
+}
+
+// Runs the "<test_suite_dir_>/demo" case using the dataset and query type held in the
+// _ut_graph_dataset_type / _ut_query_type globals. The output is not compared against a
+// .result file (check_result is false).
+void QueryTester::RunTestDemo() {
+    set_graph_type(_ut_graph_dataset_type);
+    set_query_type(_ut_query_type);
+    std::string dir = test_suite_dir_ + "/demo";
+    test_file(dir, false);
+}
+
+void QueryTester::set_graph_type(GraphFactory::GRAPH_DATASET_TYPE graph_type) {
+    graph_type_ = graph_type;
+}
+
+void QueryTester::set_query_type(lgraph::ut::QUERY_TYPE query_type) {
+    query_type_ = query_type;
+}
+
+// Recreates the test graph under db_dir_ and opens a fresh Galaxy and runtime context on it.
+void QueryTester::init_db() {
+    // Release the context first: it is built from a raw pointer to the galaxy.
+    ctx_.reset();
+    galaxy_.reset();
+    GraphFactory::create_graph(graph_type_, db_dir_);
+    gconf_.dir = db_dir_;
+    // NOTE(review): the template arguments of both make_shared calls were missing from the
+    // patch text; they are restored from how galaxy_ and ctx_ are used -- confirm.
+    galaxy_ = std::make_shared<lgraph::Galaxy>(gconf_, true, nullptr);
+    ctx_ = std::make_shared<cypher::RTContext>(
+        nullptr, galaxy_.get(),
+        lgraph::_detail::DEFAULT_ADMIN_NAME, graph_name_);
+}
+
+// Parses, plans and executes one GQL statement against ctx_.
+// On success `result` receives ctx_->result_->Dump(false) and true is returned.
+// On any failure (null context, parse/visit/dump/build error, exception during execution)
+// `result` receives the error text where one is available and false is returned.
+bool QueryTester::test_gql_case(const std::string& gql, std::string& result) {
+    if (ctx_ == nullptr) {
+        UT_LOG() << "ctx_ is nullptr";
+        return false;
+    }
+    geax::frontend::AntlrGqlParser parser(gql);
+    parser::GqlParser::GqlRequestContext* rule = parser.gqlRequest();
+    if (!parser.error().empty()) {
+        UT_LOG() << "parser.gqlRequest() error: " << parser.error();
+        result = parser.error();
+        return false;
+    }
+    geax::common::ObjectArenaAllocator objAlloc_;
+    GQLResolveCtx gql_ctx{objAlloc_};
+    GQLAstVisitor visitor{gql_ctx};
+    rule->accept(&visitor);
+    auto ret = visitor.error();
+    if (ret != GEAXErrorCode::GEAX_SUCCEED) {
+        UT_LOG() << "rule->accept(&visitor) ret: " << ToString(ret);
+        result = ToString(ret);
+        return false;
+    }
+    AstNode* node = visitor.result();
+    // rewrite ast
+    cypher::GenAnonymousAliasRewriter gen_anonymous_alias_rewriter;
+    node->accept(gen_anonymous_alias_rewriter);
+    // dump
+    AstDumper dumper;
+    ret = dumper.handle(node);
+    if (ret != GEAXErrorCode::GEAX_SUCCEED) {
+        UT_LOG() << "dumper.handle(node) gql: " << gql;
+        UT_LOG() << "dumper.handle(node) ret: " << ToString(ret);
+        UT_LOG() << "dumper.handle(node) error_msg: " << dumper.error_msg();
+        result = dumper.error_msg();
+        return false;
+    }
+    UT_DBG() << "--- dumper.handle(node) dump ---";
+    UT_DBG() << dumper.dump();
+    cypher::ExecutionPlanV2 execution_plan_v2;
+    ret = execution_plan_v2.Build(node, ctx_.get());
+    if (ret != GEAXErrorCode::GEAX_SUCCEED) {
+        UT_LOG() << "build execution_plan_v2 failed: " << execution_plan_v2.ErrorMsg();
+        result = execution_plan_v2.ErrorMsg();
+        return false;
+    }
+    try {
+        execution_plan_v2.Execute(ctx_.get());
+    } catch (const std::exception& e) {
+        UT_LOG() << e.what();
+        result = e.what();
+        return false;
+    }
+    UT_LOG() << "-----result-----";
+    result = ctx_->result_->Dump(false);
+    UT_LOG() << result;
+    return true;
+}
+
+// Parses, rewrites, plans and executes one Cypher statement against ctx_.
+// On success `result` receives ctx_->data_chunk_->Dump(false) and true is returned.
+// A null context, a dump failure or a plan-build failure returns false; the latter two
+// store the error text in `result`.
+// An exception is logged and copied into `result`, but still reported as true.
+bool QueryTester::test_cypher_case(const std::string& cypher, std::string& result) {
+    // Same guard as test_gql_case: everything below dereferences ctx_.
+    if (ctx_ == nullptr) {
+        UT_LOG() << "ctx_ is nullptr";
+        return false;
+    }
+    try {
+        antlr4::ANTLRInputStream input(cypher);
+        parser::LcypherLexer lexer(&input);
+        antlr4::CommonTokenStream tokens(&lexer);
+        parser::LcypherParser parser(&tokens);
+        parser.addErrorListener(&parser::CypherErrorListener::INSTANCE);
+        geax::common::ObjectArenaAllocator objAlloc_;
+        parser::CypherBaseVisitorV2 visitor(objAlloc_, parser.oC_Cypher(), ctx_.get());
+        AstNode* node = visitor.result();
+        // rewrite ast
+        cypher::GenAnonymousAliasRewriter gen_anonymous_alias_rewriter;
+        node->accept(gen_anonymous_alias_rewriter);
+        cypher::MultiPathPatternRewriter multi_path_pattern_rewriter(objAlloc_);
+        node->accept(multi_path_pattern_rewriter);
+        cypher::PushDownFilterAstRewriter push_down_filter_ast_writer(objAlloc_, ctx_.get());
+        node->accept(push_down_filter_ast_writer);
+        // dump
+        AstDumper dumper;
+        auto ret = dumper.handle(node);
+        if (ret != GEAXErrorCode::GEAX_SUCCEED) {
+            UT_LOG() << "dumper.handle(node) cypher: " << cypher;
+            UT_LOG() << "dumper.handle(node) ret: " << ToString(ret);
+            UT_LOG() << "dumper.handle(node) error_msg: " << dumper.error_msg();
+            result = dumper.error_msg();
+            return false;
+        }
+        UT_DBG() << "--- dumper.handle(node) dump ---";
+        UT_DBG() << dumper.dump();
+        cypher::ExecutionPlanV2 execution_plan_v2;
+        ret = execution_plan_v2.Build(node, ctx_.get());
+        if (ret != GEAXErrorCode::GEAX_SUCCEED) {
+            UT_LOG() << "build execution_plan_v2 failed: " << execution_plan_v2.ErrorMsg();
+            result = execution_plan_v2.ErrorMsg();
+            return false;
+        }
+        try {
+            execution_plan_v2.Execute(ctx_.get());
+        } catch (const std::exception& e) {
+            UT_LOG() << e.what();
+            result = e.what();
+            // NOTE(review): reporting success here means test_file never flags a throwing
+            // query, even without an "-- error" directive; test_gql_case returns false in
+            // the same situation. Kept as-is in case it is deliberate -- confirm.
+            return true;
+        }
+        UT_LOG() << "-----MY result-----";
+        // Columnar output; the row-based alternative would be ctx_->result_->Dump(false).
+        result = ctx_->data_chunk_->Dump(false);
+        UT_LOG() << result;
+    } catch (const std::exception& e) {
+        UT_LOG() << e.what();
+        result = e.what();
+        // NOTE(review): see the note on the inner handler; failures are reported as success.
+        return true;
+    }
+    return true;
+}
+
+// Executes every statement of "<file_prefix>.test" and writes statements and results to
+// "<file_prefix>.real".
+//   - Blank lines and lines starting with COMMENT_PREFIX are copied through unchanged.
+//   - "-- error" marks the next statement as allowed to fail.
+//   - "-- loadProcedure name path [read_only=true]" loads a stored procedure.
+//   - Any other line is appended to the current statement, which runs once a line ends
+//     with END_LINE_SUFFIX (or at end of file).
+// When check_result is true the .real file is diffed against "<file_prefix>.result"; it is
+// removed on a match and the test is failed otherwise.
+void QueryTester::test_file(const std::string& file_prefix, bool check_result) {
+    std::string test_file = file_prefix + TEST_SUFFIX;
+    std::string result_file = file_prefix + RESULT_SUFFIX;
+    std::string real_file = file_prefix + REAL_SUFFIX;
+    std::string line, query, result;
+    bool is_error = false;
+    fma_common::LocalFileSystem fs;
+    if (!fs.FileExists(test_file)) {
+        UT_ERR() << "test_file not exists: " << test_file;
+        UT_EXPECT_TRUE(false);
+        return;
+    }
+    std::ifstream test_file_in(test_file);
+    std::ofstream real_file_out(real_file);
+    init_db();
+    UT_DBG() << "test_file: " << test_file;
+    // Runs the accumulated statement, records it and its result, then resets the state.
+    auto test_query_handle_result = [&]() {
+        UT_LOG() << "-----" << lgraph::ut::ToString(query_type_) << "-----";
+        UT_LOG() << query;
+        bool success = false;
+        if (query_type_ == lgraph::ut::QUERY_TYPE::CYPHER) {
+            success = test_cypher_case(query, result);
+        } else if (query_type_ == lgraph::ut::QUERY_TYPE::GQL) {
+            success = test_gql_case(query, result);
+        } else {
+            LOG_FATAL() << "unhandled query_type_: " << lgraph::ut::ToString(query_type_);
+            UT_EXPECT_TRUE(false);
+            return;
+        }
+        if (!success && !is_error) {
+            UT_EXPECT_TRUE(false);
+        }
+        real_file_out << query << std::endl;
+        real_file_out << result << std::endl;
+        query.clear();
+        is_error = false;
+    };
+    while (std::getline(test_file_in, line)) {
+        std::string line_t = fma_common::Strip(line, ' ');
+        bool start_with_comment_prefix = fma_common::StartsWith(line_t, COMMENT_PREFIX);
+        if (start_with_comment_prefix || line_t.empty()) {
+            real_file_out << line << std::endl;
+            continue;
+        }
+        if (fma_common::StartsWith(line_t, ERROR_CMD_PREFIX)) {
+            real_file_out << line << std::endl;
+            is_error = true;
+            continue;
+        } else if (fma_common::StartsWith(line_t, LOAD_PROCEDURE_CMD_PREFIX)) {
+            // Load stored procedure
+            // Input format: -- loadProcedure name procedure_source_path [read_only=true]
+            // The default value for read_only is true.
+            real_file_out << line << std::endl;
+            auto args = fma_common::Split(line_t, " ");
+            // "--", "loadProcedure", name, path and an optional read_only flag.
+            const bool arity_ok = args.size() == 4 || args.size() == 5;
+            if (arity_ok && !args[2].empty() && !args[3].empty()) {
+                load_procedure(args[2], args[3],
+                               args.size() != 5 || args[4] == LOAD_PROCEDURE_READ_ONLY);
+                continue;
+            }
+            UT_EXPECT_TRUE(false);
+            // The malformed directive has been echoed and reported; do not also run it
+            // as part of a query.
+            continue;
+        }
+        if (!query.empty()) {
+            query += "\n";
+        }
+        query += line;
+        if (fma_common::EndsWith(line_t, END_LINE_SUFFIX)) {
+            test_query_handle_result();
+        }
+    }
+    if (!query.empty()) {
+        test_query_handle_result();
+    }
+    test_file_in.close();
+    real_file_out.close();
+    if (!check_result) {
+        return;
+    }
+    if (diff_file(real_file, result_file)) {
+        fs.Remove(real_file);
+    } else {
+        UT_EXPECT_TRUE(false);
+    }
+}
+
+// Reads the procedure source at procedure_source_path, base64-encodes it and registers it
+// under `name` through CALL db.plugin.loadPlugin. Fails the test when the source cannot be
+// opened or when the CALL statement is reported as failed.
+void QueryTester::load_procedure(const std::string& name, const std::string& procedure_source_path,
+                                 bool read_only) {
+    std::ifstream f(procedure_source_path, std::ios::in);
+    if (!f.is_open()) {
+        // Without this check an unreadable path would silently load an empty plugin.
+        UT_ERR() << "cannot open procedure source: " << procedure_source_path;
+        UT_EXPECT_TRUE(false);
+        return;
+    }
+    std::string buf;
+    std::string text;
+    while (getline(f, buf)) {
+        text += buf;
+        text += "\n";
+    }
+    f.close();
+    std::string encoded = lgraph_api::encode_base64(text);
+    std::string result;
+    std::string procedure_version = "v1";
+    UT_EXPECT_TRUE(test_cypher_case(
+        FMA_FMT("CALL db.plugin.loadPlugin('CPP','{}','{}','CPP','{}', {}, '{}')", name,
+                encoded, name, read_only ? "true" : "false", procedure_version),
+        result));
+}
+
+// Runs `diff lef rig` in a subprocess. Returns true when diff exits with status 0;
+// otherwise logs the command and its stdout and returns false.
+bool QueryTester::diff_file(const std::string& lef, const std::string& rig) {
+    std::string cmd = fma_common::StringFormatter::Format("diff {} {}", lef, rig);
+    lgraph::SubProcess diff(cmd, false);
+    diff.Wait();
+    if (diff.GetExitCode() != 0) {
+        UT_LOG() << "-----" << cmd << "-----";
+        UT_LOG() << diff.Stdout();
+    }
+    return diff.GetExitCode() == 0;
+}
diff --git a/test/QueryTester.h b/test/QueryTester.h
new file mode 100644
index 0000000000..3a84a2c51b
--- /dev/null
+++ b/test/QueryTester.h
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+
+#ifndef TEST_QUERYTESTER_H_
+#define TEST_QUERYTESTER_H_
+
+// NOTE(review): the targets of the standard includes were missing from the patch text.
+// The headers below are the ones QueryTester visibly relies on (std::exception,
+// std::ifstream/std::ofstream, std::cout, std::shared_ptr, std::string) -- confirm against
+// the original commit.
+#include <exception>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "./graph_factory.h"
+#include "./antlr4-runtime.h"
+#include "geax-front-end/ast/AstNode.h"
+#include "geax-front-end/ast/AstDumper.h"
+#include "geax-front-end/isogql/GQLResolveCtx.h"
+#include "geax-front-end/isogql/GQLAstVisitor.h"
+#include "geax-front-end/isogql/parser/AntlrGqlParser.h"
+
+#include "cypher/parser/generated/LcypherLexer.h"
+#include "cypher/parser/generated/LcypherParser.h"
+#include "cypher/parser/cypher_base_visitor_v2.h"
+#include "cypher/parser/cypher_error_listener.h"
+#include "cypher/rewriter/GenAnonymousAliasRewriter.h"
+#include "cypher/rewriter/MultiPathPatternRewriter.h"
+#include "fma-common/file_system.h"
+#include "db/galaxy.h"
+#include "cypher/rewriter/PushDownFilterAstRewriter.h"
+#include "cypher/execution_plan/runtime_context.h"
+#include "cypher/execution_plan/execution_plan_v2.h"
+#include "lgraph/lgraph_utils.h"
+#include "./ut_utils.h"
+#include "./ut_config.h"
+#include "./ut_types.h"
+
+#ifndef LOGGING_UTILS_H
+#define LOGGING_UTILS_H
+
+// NOTE(review): include targets restored from the Boost.Log names InitLogging uses
+// (keywords::format, add_common_attributes, add_console_log) -- confirm.
+#include <boost/log/keywords/format.hpp>
+#include <boost/log/utility/setup/common_attributes.hpp>
+#include <boost/log/utility/setup/console.hpp>
+
+// Routes Boost.Log output to the console; defined in QueryTester.cpp.
+void InitLogging();
+
+#endif  // LOGGING_UTILS_H
+
+
+// NOTE(review): a using-directive in a header leaks into every includer. QueryTester.cpp
+// uses unqualified geax::frontend names, so it is kept here.
+using namespace geax::frontend;
+using geax::frontend::GEAXErrorCode;
+
+// Drives ".test" query scripts against a freshly created test graph, outside of gtest
+// fixtures. RunTestDemo() is the only public entry point besides the constructor.
+class QueryTester {
+ public:
+    QueryTester();
+    // Runs the "demo" case of the test suite without comparing against a .result file.
+    void RunTestDemo();
+
+ private:
+    void set_graph_type(GraphFactory::GRAPH_DATASET_TYPE graph_type);
+    void set_query_type(lgraph::ut::QUERY_TYPE query_type);
+    // Recreates the graph in db_dir_ and rebuilds galaxy_ and ctx_.
+    void init_db();
+    // Execute one statement; `result` receives the dumped result or the error text.
+    bool test_gql_case(const std::string& gql, std::string& result);
+    bool test_cypher_case(const std::string& cypher, std::string& result);
+    // Runs <file_prefix>.test, writes <file_prefix>.real and optionally diffs it against
+    // <file_prefix>.result.
+    void test_file(const std::string& file_prefix, bool check_result = true);
+    void load_procedure(const std::string& name, const std::string& procedure_source_path,
+                        bool read_only = true);
+    // Returns true when `diff lef rig` exits with status 0.
+    bool diff_file(const std::string& lef, const std::string& rig);
+
+    // NOTE(review): the template arguments of both shared_ptr members were missing from
+    // the patch text; restored from how init_db() builds them -- confirm.
+    std::shared_ptr<cypher::RTContext> ctx_;
+    std::shared_ptr<lgraph::Galaxy> galaxy_;
+    lgraph::Galaxy::Config gconf_;
+    inline static const std::string TEST_SUFFIX = ".test";
+    inline static const std::string REAL_SUFFIX = ".real";
+    inline static const std::string RESULT_SUFFIX = ".result";
+    inline static const std::string COMMENT_PREFIX = "#";
+    inline static const std::string END_LINE_SUFFIX = ";";
+    inline static const std::string LOAD_PROCEDURE_CMD_PREFIX = "-- loadProcedure";
+    inline static const std::string ERROR_CMD_PREFIX = "-- error";
+    inline static const std::string LOAD_PROCEDURE_READ_ONLY = "read_only=true";
+    std::string db_dir_ = "./testdb";
+    std::string graph_name_ = "default";
+    GraphFactory::GRAPH_DATASET_TYPE graph_type_ = GraphFactory::GRAPH_DATASET_TYPE::YAGO;
+    lgraph::ut::QUERY_TYPE query_type_ = lgraph::ut::QUERY_TYPE::GQL;
+
+ protected:
+    std::string test_suite_dir_ = lgraph::ut::TEST_RESOURCE_DIRECTORY + "/cases";
+};
+
+#endif  // TEST_QUERYTESTER_H_
diff --git a/test/resource/unit_test/vector_index/cypher/vector_index.result b/test/resource/unit_test/vector_index/cypher/vector_index.result
index ca8ca885d2..16675122c4 100644
--- a/test/resource/unit_test/vector_index/cypher/vector_index.result
+++ b/test/resource/unit_test/vector_index/cypher/vector_index.result
@@ -7,7 +7,7 @@ CALL db.addVertexVectorIndex('person','embedding2', {dimension:4});
 CALL db.addVertexVectorIndex('person','name', {dimension:4});
 [VectorIndexException] Only FLOAT_VECTOR type supports vector index
 CALL db.showVertexVectorIndex();
-[{"deleted_ids_num":0,"dimension":4,"distance_type":"l2","elements_num":0,"field_name":"embedding1","index_type":"hnsw","label_name":"person","memory_usage":380,"parameter":{"hnsw.ef_construction":100,"hnsw.m":16}},{"deleted_ids_num":0,"dimension":4,"distance_type":"l2","elements_num":0,"field_name":"embedding2","index_type":"hnsw","label_name":"person","memory_usage":380,"parameter":{"hnsw.ef_construction":100,"hnsw.m":16}}]
+[{"deleted_ids_num":0,"dimension":4,"distance_type":"l2","elements_num":0,"field_name":"embedding1","index_type":"hnsw","label_name":"person","memory_usage":224,"parameter":{"hnsw.ef_construction":100,"hnsw.m":16}},{"deleted_ids_num":0,"dimension":4,"distance_type":"l2","elements_num":0,"field_name":"embedding2","index_type":"hnsw","label_name":"person","memory_usage":224,"parameter":{"hnsw.ef_construction":100,"hnsw.m":16}}]
 CREATE (n:person {id:1, name:'name1', embedding1: [1.0,1.0,1.0,1.0], embedding2: [11.0,11.0,11.0,11.0]});
 [{"":"created 1 vertices, created 0 edges."}]
 CREATE (n:person {id:2, name:'name2', embedding1: [2.0,2.0,2.0,2.0], embedding2: [12.0,12.0,12.0,12.0]});
diff --git a/test/test_bit_mask.cpp b/test/test_bit_mask.cpp
new file mode 100644
index 0000000000..99568602be
--- /dev/null
+++ b/test/test_bit_mask.cpp
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+// NOTE(review): the patch text carried four #include directives here whose targets were
+// lost. <cstdint> is the only standard header the tests below visibly need (uint64_t);
+// restore anything else from the original commit.
+#include <cstdint>
+
+#include "./graph_factory.h"
+/* Make sure include graph_factory.h BEFORE antlr4-runtime.h. Otherwise causing the following error:
+ * ‘EOF’ was not declared in this scope.
+ * For the former (include/butil) uses macro EOF, which is undefined in antlr4. */
+#include "./antlr4-runtime.h"
+#include "geax-front-end/ast/AstNode.h"
+#include "geax-front-end/ast/AstDumper.h"
+#include "geax-front-end/isogql/GQLResolveCtx.h"
+#include "geax-front-end/isogql/GQLAstVisitor.h"
+#include "geax-front-end/isogql/parser/AntlrGqlParser.h"
+
+#include "cypher/parser/generated/LcypherLexer.h"
+#include "cypher/parser/generated/LcypherParser.h"
+#include "cypher/parser/cypher_base_visitor.h"
+#include "cypher/parser/cypher_error_listener.h"
+#include "cypher/rewriter/GenAnonymousAliasRewriter.h"
+#include "fma-common/file_system.h"
+#include "db/galaxy.h"
+#include "cypher/execution_plan/runtime_context.h"
+#include "cypher/execution_plan/execution_plan_v2.h"
+#include "lgraph/lgraph_utils.h"
+#include "./ut_utils.h"
+#include "./ut_config.h"
+#include "./ut_types.h"
+#include "cypher/resultset/column_vector.h"
+#include "cypher/resultset/bit_mask.h"
+#include "cypher/resultset/cypher_string_t.h"
+
+using namespace geax::frontend;
+using geax::frontend::GEAXErrorCode;
+
+using namespace cypher;
+
+// A freshly constructed mask has no bit set and reports the no-nulls guarantee.
+TEST(BitMaskTest, Constructor) {
+    uint64_t capacity = 128;
+    BitMask bm(capacity);
+    // Ensure that initially no bits are set as null
+    for (uint64_t i = 0; i < capacity; ++i) {
+        EXPECT_FALSE(bm.IsBitSet(i));
+    }
+    EXPECT_TRUE(bm.HasNoNullsGuarantee());
+}
+
+// Setting individual bits (including positions 63 and 127) affects only those positions
+// and drops the no-nulls guarantee.
+TEST(BitMaskTest, SetAndCheckNullBits) {
+    uint64_t capacity = 128;
+    BitMask bm(capacity);
+    bm.SetBit(10, true);
+    bm.SetBit(63, true);
+    bm.SetBit(127, true);
+    EXPECT_TRUE(bm.IsBitSet(10));
+    EXPECT_TRUE(bm.IsBitSet(63));
+    EXPECT_TRUE(bm.IsBitSet(127));
+    EXPECT_FALSE(bm.IsBitSet(0));
+    EXPECT_FALSE(bm.IsBitSet(64));
+    EXPECT_FALSE(bm.HasNoNullsGuarantee());
+}
+
+// SetAllNull marks every position and drops the no-nulls guarantee.
+TEST(BitMaskTest, SetAllNull) {
+    uint64_t capacity = 64;
+    BitMask bm(capacity);
+    bm.SetAllNull();
+
+    for (uint64_t i = 0; i < capacity; ++i) {
+        EXPECT_TRUE(bm.IsBitSet(i));
+    }
+    EXPECT_FALSE(bm.HasNoNullsGuarantee());
+}
+
+// SetAllNonNull clears previously set bits and restores the no-nulls guarantee.
+TEST(BitMaskTest, SetAllNonNull) {
+    uint64_t capacity = 64;
+    BitMask bm(capacity);
+    bm.SetBit(20, true);
+    bm.SetBit(40, true);
+    EXPECT_TRUE(bm.IsBitSet(20));
+    EXPECT_TRUE(bm.IsBitSet(40));
+    bm.SetAllNonNull();
+
+    for (uint64_t i = 0; i < capacity; ++i) {
+        EXPECT_FALSE(bm.IsBitSet(i));
+    }
+
+    EXPECT_TRUE(bm.HasNoNullsGuarantee());
+}
+
+// A copy-constructed mask carries the same set and unset bits as its source.
+TEST(BitMaskTest, CopyConstructor) {
+    uint64_t capacity = 64;
+    BitMask bm1(capacity);
+
+    bm1.SetBit(15, true);
+    bm1.SetBit(30, true);
+
+    // Use copy constructor
+    BitMask bm2 = bm1;
+    EXPECT_TRUE(bm2.IsBitSet(15));
+    EXPECT_TRUE(bm2.IsBitSet(30));
+    EXPECT_FALSE(bm2.IsBitSet(0));
+    EXPECT_FALSE(bm2.IsBitSet(63));
+}
+
+// Copy assignment between masks of different capacity brings over the source's bits.
+TEST(BitMaskTest, CopyAssignment) {
+    uint64_t capacity1 = 64;
+    uint64_t capacity2 = 128;
+    BitMask bm1(capacity1);
+    BitMask bm2(capacity2);
+    bm1.SetBit(10, true);
+    bm2.SetBit(100, true);
+
+    // Copy assignment
+    bm2 = bm1;
+    EXPECT_TRUE(bm2.IsBitSet(10));
+}
+
+// Growing the mask keeps existing bits and leaves the new positions unset.
+TEST(BitMaskTest, Resize) {
+    uint64_t initial_capacity = 64;
+    BitMask bm(initial_capacity);
+    bm.SetBit(10, true);
+    bm.SetBit(63, true);
+    EXPECT_TRUE(bm.IsBitSet(10));
+    EXPECT_TRUE(bm.IsBitSet(63));
+
+    // Resize the bitmask
+    bm.resize(128);
+    EXPECT_TRUE(bm.IsBitSet(10));
+    EXPECT_TRUE(bm.IsBitSet(63));
+    EXPECT_FALSE(bm.IsBitSet(64));
+    EXPECT_FALSE(bm.IsBitSet(127));
+}
+
+// SetNullFromRange(10, 20, true) is expected to set positions 10..29 and nothing on
+// either side of that range.
+TEST(BitMaskTest, SetNullFromRange) {
+    uint64_t capacity = 128;
+    BitMask bm(capacity);
+
+    // Set a range of bits to null
+    bm.SetNullFromRange(10, 20, true);
+
+    for (uint64_t i = 10; i < 30; ++i) {
+        EXPECT_TRUE(bm.IsBitSet(i));
+    }
+    EXPECT_FALSE(bm.IsBitSet(9));
+    EXPECT_FALSE(bm.IsBitSet(30));
+}
diff --git a/test/test_column_vector.cpp b/test/test_column_vector.cpp
new file mode 100644
index 0000000000..b46751692c
--- /dev/null
+++ b/test/test_column_vector.cpp
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+// NOTE(review): the patch text carried four #include directives here whose targets were
+// lost. The headers below are the standard ones the tests visibly use (size_t, int32_t,
+// std::out_of_range, std::string) -- confirm against the original commit.
+#include <cstddef>
+#include <cstdint>
+#include <stdexcept>
+#include <string>
+
+#include "./graph_factory.h"
+/* Make sure include graph_factory.h BEFORE antlr4-runtime.h. Otherwise causing the following error:
+ * ‘EOF’ was not declared in this scope.
+ * For the former (include/butil) uses macro EOF, which is undefined in antlr4. */
+#include "./antlr4-runtime.h"
+#include "geax-front-end/ast/AstNode.h"
+#include "geax-front-end/ast/AstDumper.h"
+#include "geax-front-end/isogql/GQLResolveCtx.h"
+#include "geax-front-end/isogql/GQLAstVisitor.h"
+#include "geax-front-end/isogql/parser/AntlrGqlParser.h"
+
+#include "cypher/parser/generated/LcypherLexer.h"
+#include "cypher/parser/generated/LcypherParser.h"
+#include "cypher/parser/cypher_base_visitor.h"
+#include "cypher/parser/cypher_error_listener.h"
+#include "cypher/rewriter/GenAnonymousAliasRewriter.h"
+#include "fma-common/file_system.h"
+#include "db/galaxy.h"
+#include "cypher/execution_plan/runtime_context.h"
+#include "cypher/execution_plan/execution_plan_v2.h"
+#include "lgraph/lgraph_utils.h"
+#include "./ut_utils.h"
+#include "./ut_config.h"
+#include "./ut_types.h"
+#include "cypher/resultset/column_vector.h"
+#include "cypher/resultset/bit_mask.h"
+#include "cypher/resultset/cypher_string_t.h"
+
+using namespace geax::frontend;
+using geax::frontend::GEAXErrorCode;
+
+using namespace cypher;
+
+// NOTE(review): the explicit template arguments in this file (GetValue<...>,
+// static_cast<...>) were missing from the patch text and are restored from the element
+// type each test stores -- confirm against the original commit.
+
+// The constructor records the element size and capacity it was given.
+TEST(ColumnVectorTest, Constructor) {
+    ColumnVector cv(sizeof(int32_t), 10);
+    EXPECT_EQ(cv.GetElementSize(), sizeof(int32_t));
+    EXPECT_EQ(cv.GetCapacity(), 10);
+}
+
+// A value written with SetValue is read back unchanged.
+TEST(ColumnVectorTest, SetAndGetValue) {
+    ColumnVector cv(sizeof(int32_t), 10);
+    int32_t value = 42;
+    cv.SetValue(0, value);
+    EXPECT_EQ(cv.GetValue<int32_t>(0), value);
+}
+
+// The null flag of a slot can be switched on and off again.
+TEST(ColumnVectorTest, SetNullAndCheck) {
+    ColumnVector cv(sizeof(int32_t), 10);
+    cv.SetNull(0, true);
+    EXPECT_TRUE(cv.IsNull(0));
+    cv.SetNull(0, false);
+    EXPECT_FALSE(cv.IsNull(0));
+}
+
+// A short string round-trips through StringColumn::AddString.
+TEST(StringColumnTest, AddShortString) {
+    ColumnVector cv(sizeof(cypher_string_t), 10);
+    std::string shortString = "abcd";  // 4 bytes short string
+    StringColumn::AddString(&cv, 0, shortString);
+
+    auto& storedString = cv.GetValue<cypher_string_t>(0);
+    EXPECT_EQ(storedString.GetAsString(), shortString);
+}
+
+// A string longer than 12 bytes round-trips as well.
+TEST(StringColumnTest, AddLongString) {
+    ColumnVector cv(sizeof(cypher_string_t), 10);
+    std::string longString = "This is a very long string to test overflow buffer.";  // > 12 bytes
+    StringColumn::AddString(&cv, 0, longString);
+
+    auto& storedString = cv.GetValue<cypher_string_t>(0);
+    EXPECT_EQ(storedString.GetAsString(), longString);
+}
+
+// A copy-constructed vector holds the same stored value.
+TEST(ColumnVectorTest, CopyConstructor) {
+    ColumnVector cv1(sizeof(int32_t), 10);
+    int32_t value = 99;
+    cv1.SetValue(0, value);
+
+    ColumnVector cv2 = cv1;  // Use copy constructor
+    EXPECT_EQ(cv2.GetValue<int32_t>(0), value);
+}
+
+// Copy assignment onto a vector of different capacity brings over the stored value.
+TEST(ColumnVectorTest, CopyAssignment) {
+    ColumnVector cv1(sizeof(int32_t), 10);
+    int32_t value = 100;
+    cv1.SetValue(0, value);
+
+    ColumnVector cv2(sizeof(int32_t), 5);
+    cv2 = cv1;  // Use copy assignment operator
+    EXPECT_EQ(cv2.GetValue<int32_t>(0), value);
+}
+
+// Overflow allocations made before and after a large request return different pointers.
+TEST(ColumnVectorTest, ResizeOverflowBuffer) {
+    ColumnVector cv(sizeof(cypher_string_t), 1);
+    std::string longString = "This string will cause the overflow buffer to resize.";
+
+    // Add long string to trigger overflow buffer allocation
+    StringColumn::AddString(&cv, 0, longString);
+    EXPECT_EQ(cv.GetValue<cypher_string_t>(0).GetAsString(), longString);
+
+    // Access the overflow buffer to check for resize
+    void* initialPtr = cv.AllocateOverflow(1);
+    cv.AllocateOverflow(2000);  // Force resize
+    void* newPtr = cv.AllocateOverflow(1);
+    EXPECT_NE(initialPtr, newPtr);  // Pointer should change after resize
+}
+
+// Reading or writing slot 0 of a zero-capacity vector throws std::out_of_range.
+TEST(ColumnVectorTest, AccessEmptyVector) {
+    ColumnVector cv(sizeof(int32_t), 0);  // capacity is 0
+    EXPECT_THROW(cv.SetValue(0, 42), std::out_of_range);
+    EXPECT_THROW(cv.GetValue<int32_t>(0), std::out_of_range);
+}
+
+// The last slot of a DEFAULT_VECTOR_CAPACITY-sized vector is writable and readable.
+TEST(ColumnVectorTest, MaximumCapacity) {
+    ColumnVector cv(sizeof(int32_t), DEFAULT_VECTOR_CAPACITY);
+    EXPECT_NO_THROW(cv.SetValue(DEFAULT_VECTOR_CAPACITY - 1, 42));
+    EXPECT_EQ(cv.GetValue<int32_t>(DEFAULT_VECTOR_CAPACITY - 1), 42);
+}
+
+// Writing at index -1 is rejected with std::out_of_range.
+TEST(ColumnVectorTest, NegativeIndexAccess) {
+    ColumnVector cv(sizeof(int32_t), 10);
+    EXPECT_THROW(cv.SetValue(-1, 42), std::out_of_range);
+}
+
+// A 2000-character string stored in a one-slot column round-trips intact.
+TEST(ColumnVectorTest, OverflowBufferExpansion) {
+    ColumnVector cv(sizeof(cypher_string_t), 1);
+
+    // Add long string to trigger overflow buffer allocation
+    std::string longString(2000, 'x');
+    StringColumn::AddString(&cv, 0, longString);
+    EXPECT_EQ(cv.GetValue<cypher_string_t>(0).GetAsString(), longString);
+}
+
+// One million int32_t values are written and then read back in order.
+TEST(ColumnVectorTest, LargeScaleData) {
+    size_t largeSize = 1000000;
+    ColumnVector cv(sizeof(int32_t), largeSize);
+
+    for (size_t i = 0; i < largeSize; ++i) {
+        cv.SetValue(i, static_cast<int32_t>(i));
+    }
+
+    for (size_t i = 0; i < largeSize; ++i) {
+        EXPECT_EQ(cv.GetValue<int32_t>(i), static_cast<int32_t>(i));
+    }
+}
diff --git a/test/test_plan_cache.cpp b/test/test_plan_cache.cpp
new file mode 100644
index 0000000000..16db0a2c1e
--- /dev/null
+++ b/test/test_plan_cache.cpp
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include <string>
+
+#include "./antlr4-runtime.h"
+#include "cypher/execution_plan/plan_cache/plan_cache_param.h"
+#include "cypher/execution_plan/plan_cache/plan_cache.h"
+
+#include "gtest/gtest.h"
+#include "./ut_utils.h"
+#include "core/data_type.h"
+#include "./test_tools.h"
+
+class TestPlanCache : public TuGraphTest {};
+
+// NOTE(review): the template argument of LRUPlanCache was missing from the patch text;
+// <int> matches the int values these tests store and read back -- confirm.
+
+// Plans added to the cache can be looked up again and are counted by current_size().
+TEST_F(TestPlanCache, basicCaching) {
+    cypher::LRUPlanCache<int> cache(512);
+
+    cache.add_plan("1", 1);
+    // Initialised, and the lookup result is asserted, so a miss fails the test instead of
+    // comparing an indeterminate value.
+    int value = 0;
+    ASSERT_TRUE(cache.get_plan("1", value));
+    ASSERT_EQ(value, 1);
+
+    cache.add_plan("2", 2);
+    ASSERT_TRUE(cache.get_plan("2", value));
+    ASSERT_EQ(value, 2);
+
+    ASSERT_EQ(cache.current_size(), 2);
+}
+
+// Adding 522 plans to a cache built with 512 leaves the first 10 keys missing and the
+// remaining 512 retrievable.
+TEST_F(TestPlanCache, eviction) {
+    cypher::LRUPlanCache<int> cache(512);
+
+    for (int i = 0; i < 522; i++) {
+        cache.add_plan(std::to_string(i), i);
+    }
+
+    for (int i = 0; i < 10; i++) {
+        int val = 0;
+        ASSERT_FALSE(cache.get_plan(std::to_string(i), val));
+    }
+
+    for (int i = 10; i < 522; i++) {
+        int val = 0;
+        ASSERT_TRUE(cache.get_plan(std::to_string(i), val));
+    }
+}
diff --git a/test/test_query_benchmark.cpp b/test/test_query_benchmark.cpp
new file mode 100644
index 0000000000..f2a5215d50
--- /dev/null
+++ b/test/test_query_benchmark.cpp
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+// NOTE(review): the target of this include was missing from the patch text; restored from
+// the benchmark::State / BENCHMARK usage below -- confirm.
+#include <benchmark/benchmark.h>
+
+#include <iostream>
+#include <string>
+
+#include "QueryTester.h"
+
+// Define the external variables
+GraphFactory::GRAPH_DATASET_TYPE _ut_graph_dataset_type = GraphFactory::GRAPH_DATASET_TYPE::YAGO;
+lgraph::ut::QUERY_TYPE _ut_query_type = lgraph::ut::QUERY_TYPE::CYPHER;
+
+// Function to run the core test code
+void TestDemoFunction() {
+    QueryTester tester;
+    tester.RunTestDemo();
+}
+
+// Benchmark function: every timed iteration runs the demo case ten times.
+static void BM_TestDemoCol(benchmark::State& state) {
+    for (auto _ : state) {
+        // You can adjust the loop count as needed
+        for (int i = 0; i < 10; ++i) {
+            TestDemoFunction();
+        }
+    }
+}
+
+// Register the benchmark
+BENCHMARK(BM_TestDemoCol);
+
+// Handles "--dataset <name>" and "--query_type <cypher|gql>", then hands the command line
+// to Google Benchmark. Returns 1 when one of these options is missing its argument.
+int main(int argc, char** argv) {
+    for (int i = 1; i < argc; ++i) {
+        const std::string option = argv[i];
+        if (option == "--dataset") {
+            if (i + 1 >= argc) {
+                std::cerr << "--dataset option requires one argument." << std::endl;
+                return 1;
+            }
+            const std::string dataset_type = argv[++i];
+            // Only YAGO is wired up so far. Any other name is reported, not silently
+            // accepted, and the benchmark still runs on YAGO as before.
+            if (dataset_type != "yago" && dataset_type != "YAGO") {
+                std::cerr << "unsupported --dataset value '" << dataset_type
+                          << "', using YAGO." << std::endl;
+            }
+            _ut_graph_dataset_type = GraphFactory::GRAPH_DATASET_TYPE::YAGO;
+            // Add other dataset types as needed
+        } else if (option == "--query_type") {
+            if (i + 1 >= argc) {
+                std::cerr << "--query_type option requires one argument." << std::endl;
+                return 1;
+            }
+            const std::string query_type = argv[++i];
+            if (query_type == "cypher") {
+                _ut_query_type = lgraph::ut::QUERY_TYPE::CYPHER;
+            } else if (query_type == "gql") {
+                _ut_query_type = lgraph::ut::QUERY_TYPE::GQL;
+            } else {
+                // Previously ignored without a word; the current query type is kept.
+                std::cerr << "unsupported --query_type value '" << query_type
+                          << "', keeping the current query type." << std::endl;
+            }
+            // Add other query types as needed
+        }
+    }
+
+    // Initialize Google Benchmark
+    ::benchmark::Initialize(&argc, argv);
+    ::benchmark::RunSpecifiedBenchmarks();
+    return 0;
+}
diff --git a/test/test_query_col.cpp b/test/test_query_col.cpp
new file mode 100644
index 0000000000..af551dcd55
--- /dev/null
+++ b/test/test_query_col.cpp
@@ -0,0 +1,375 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include
+#include
+#include
+#include
+
+#include "./graph_factory.h"
+/* Make sure include graph_factory.h BEFORE antlr4-runtime.h. Otherwise causing the following error:
+ * ‘EOF’ was not declared in this scope.
+ * For the former (include/butil) uses macro EOF, which is undefined in antlr4.
*/ +#include "./antlr4-runtime.h" +#include "geax-front-end/ast/AstNode.h" +#include "geax-front-end/ast/AstDumper.h" +#include "geax-front-end/isogql/GQLResolveCtx.h" +#include "geax-front-end/isogql/GQLAstVisitor.h" +#include "geax-front-end/isogql/parser/AntlrGqlParser.h" + +#include "cypher/parser/generated/LcypherLexer.h" +#include "cypher/parser/generated/LcypherParser.h" +#include "cypher/parser/cypher_base_visitor_v2.h" +#include "cypher/parser/cypher_error_listener.h" +#include "cypher/rewriter/GenAnonymousAliasRewriter.h" +#include "cypher/rewriter/MultiPathPatternRewriter.h" +#include "fma-common/file_system.h" +#include "db/galaxy.h" +#include "cypher/rewriter/PushDownFilterAstRewriter.h" +#include "cypher/execution_plan/runtime_context.h" +#include "cypher/execution_plan/execution_plan_v2.h" +#include "lgraph/lgraph_utils.h" +#include "./ut_utils.h" +#include "./ut_config.h" +#include "./ut_types.h" + +using namespace geax::frontend; +using geax::frontend::GEAXErrorCode; + +extern GraphFactory::GRAPH_DATASET_TYPE _ut_graph_dataset_type; +extern lgraph::ut::QUERY_TYPE _ut_query_type; + +class TestQueryCol : public TuGraphTest { + private: + std::shared_ptr ctx_; + std::shared_ptr galaxy_; + lgraph::Galaxy::Config gconf_; + inline static const std::string TEST_SUFFIX = ".test"; + inline static const std::string REAL_SUFFIX = ".real"; + inline static const std::string RESULT_SUFFIX = ".result"; + inline static const std::string COMMENT_PREFIX = "#"; + inline static const std::string END_LINE_SUFFIX = ";"; + inline static const std::string LOAD_PROCEDURE_CMD_PREFIX = "-- loadProcedure"; + inline static const std::string ERROR_CMD_PREFIX = "-- error"; + inline static const std::string LOAD_PROCEDURE_READ_ONLY = "read_only=true"; + std::string db_dir_ = "./testdb"; + std::string graph_name_ = "default"; + GraphFactory::GRAPH_DATASET_TYPE graph_type_ = GraphFactory::GRAPH_DATASET_TYPE::YAGO; + lgraph::ut::QUERY_TYPE query_type_ = lgraph::ut::QUERY_TYPE::GQL; 
+ + protected: + std::string test_suite_dir_ = lgraph::ut::TEST_RESOURCE_DIRECTORY + "/cases"; + + void set_graph_type(GraphFactory::GRAPH_DATASET_TYPE graph_type) { + graph_type_ = graph_type; + } + + void set_query_type(lgraph::ut::QUERY_TYPE query_type) { + query_type_ = query_type; + } + + bool diff_file(const std::string& lef, const std::string& rig) { + std::string cmd = fma_common::StringFormatter::Format("diff {} {}", lef, rig); + lgraph::SubProcess diff(cmd, false); + diff.Wait(); + if (diff.GetExitCode() != 0) { + UT_LOG() << "-----" << cmd << "-----"; + UT_LOG() << diff.Stdout(); + } + return diff.GetExitCode() == 0; + } + + void init_db() { + ctx_.reset(); + galaxy_.reset(); + GraphFactory::create_graph(graph_type_, db_dir_); + gconf_.dir = db_dir_; + galaxy_ = std::make_shared(gconf_, true, nullptr); + ctx_ = std::make_shared( + nullptr, galaxy_.get(), + lgraph::_detail::DEFAULT_ADMIN_NAME, graph_name_); + } + + bool test_gql_case(const std::string& gql, std::string& result) { + if (ctx_ == nullptr) { + UT_LOG() << "ctx_ is nullptr"; + return false; + } + geax::frontend::AntlrGqlParser parser(gql); + parser::GqlParser::GqlRequestContext* rule = parser.gqlRequest(); + if (!parser.error().empty()) { + UT_LOG() << "parser.gqlRequest() error: " << parser.error(); + result = parser.error(); + return false; + } + geax::common::ObjectArenaAllocator objAlloc_; + GQLResolveCtx gql_ctx{objAlloc_}; + GQLAstVisitor visitor{gql_ctx}; + rule->accept(&visitor); + auto ret = visitor.error(); + if (ret != GEAXErrorCode::GEAX_SUCCEED) { + UT_LOG() << "rule->accept(&visitor) ret: " << ToString(ret); + result = ToString(ret); + return false; + } + AstNode* node = visitor.result(); + // rewrite ast + cypher::GenAnonymousAliasRewriter gen_anonymous_alias_rewriter; + node->accept(gen_anonymous_alias_rewriter); + // dump + AstDumper dumper; + ret = dumper.handle(node); + if (ret != GEAXErrorCode::GEAX_SUCCEED) { + UT_LOG() << "dumper.handle(node) gql: " << gql; + UT_LOG() << 
"dumper.handle(node) ret: " << ToString(ret); + UT_LOG() << "dumper.handle(node) error_msg: " << dumper.error_msg(); + result = dumper.error_msg(); + return false; + } else { + UT_DBG() << "--- dumper.handle(node) dump ---"; + UT_DBG() << dumper.dump(); + } + cypher::ExecutionPlanV2 execution_plan_v2; + ret = execution_plan_v2.Build(node, ctx_.get()); + if (ret != GEAXErrorCode::GEAX_SUCCEED) { + UT_LOG() << "build execution_plan_v2 failed: " << execution_plan_v2.ErrorMsg(); + result = execution_plan_v2.ErrorMsg(); + return false; + } else { + try { + execution_plan_v2.Execute(ctx_.get()); + } catch (std::exception &e) { + UT_LOG() << e.what(); + result = e.what(); + return false; + } + UT_LOG() << "-----result-----"; + result = ctx_->result_->Dump(false); + UT_LOG() << result; + } + return true; + } + + bool test_cypher_case(const std::string& cypher, std::string& result) { + try { + antlr4::ANTLRInputStream input(cypher); + parser::LcypherLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + parser::LcypherParser parser(&tokens); + parser.addErrorListener(&parser::CypherErrorListener::INSTANCE); + geax::common::ObjectArenaAllocator objAlloc_; + parser::CypherBaseVisitorV2 visitor(objAlloc_, parser.oC_Cypher(), ctx_.get()); + AstNode* node = visitor.result(); + // rewrite ast + cypher::GenAnonymousAliasRewriter gen_anonymous_alias_rewriter; + node->accept(gen_anonymous_alias_rewriter); + cypher::MultiPathPatternRewriter multi_path_pattern_rewriter(objAlloc_); + node->accept(multi_path_pattern_rewriter); + cypher::PushDownFilterAstRewriter push_down_filter_ast_writer(objAlloc_, ctx_.get()); + node->accept(push_down_filter_ast_writer); + // dump + AstDumper dumper; + auto ret = dumper.handle(node); + if (ret != GEAXErrorCode::GEAX_SUCCEED) { + UT_LOG() << "dumper.handle(node) gql: " << cypher; + UT_LOG() << "dumper.handle(node) ret: " << ToString(ret); + UT_LOG() << "dumper.handle(node) error_msg: " << dumper.error_msg(); + result = dumper.error_msg(); 
+                return false;
+            } else {
+                UT_DBG() << "--- dumper.handle(node) dump ---";
+                UT_DBG() << dumper.dump();
+            }
+            cypher::ExecutionPlanV2 execution_plan_v2;
+            ret = execution_plan_v2.Build(node, ctx_.get());
+            if (ret != GEAXErrorCode::GEAX_SUCCEED) {
+                UT_LOG() << "build execution_plan_v2 failed: " << execution_plan_v2.ErrorMsg();
+                result = execution_plan_v2.ErrorMsg();
+                return false;
+            } else {
+//                if (visitor.CommandType() != parser::CmdType::QUERY) {
+//                    ctx_->result_info_ = std::make_unique();
+//                    ctx_->result_ = std::make_unique();
+//                    std::string header, data;
+//                    if (visitor.CommandType() == parser::CmdType::EXPLAIN) {
+//                        header = "@plan";
+//                        data = execution_plan_v2.DumpPlan(0, false);
+//                    } else {
+//                        header = "@profile";
+//                        data = execution_plan_v2.DumpGraph();
+//                    }
+//                    ctx_->result_->ResetHeader({{header, lgraph_api::LGraphType::STRING}});
+//                    auto r = ctx_->result_->MutableRecord();
+//                    r->Insert(header, lgraph::FieldData(data));
+//                    result = ctx_->result_->Dump(false);
+//                    return true;
+//                }
+                try {
+                    execution_plan_v2.Execute(ctx_.get());
+                } catch (std::exception& e) {
+                    UT_LOG() << e.what();
+                    result = e.what();
+                    return true;  // NOTE(review): reports success after a throw - confirm
+                }
+                // UT_LOG() << "-----result-----";
+                // result = ctx_->result_->Dump(false);
+                // UT_LOG() << result;
+                // UT_LOG() << "-----MY result-----";
+                // ctx_->data_chunk_->Print();
+                UT_LOG() << "-----MY result-----";
+                result = ctx_->data_chunk_->Dump(false);  // dump taken from data_chunk_, not result_
+                // UT_LOG() << "-----ORIGINAL result-----";
+                // result = ctx_->result_->Dump(false);
+                // UT_LOG() << "-----result-----";
+                // result = ctx_->result_->Dump(false);
+                UT_LOG() << result;
+            }
+        } catch (std::exception& e) {
+            UT_LOG() << e.what();
+            result = e.what();
+            return true;  // NOTE(review): parse/rewrite exceptions also count as success - confirm
+        }
+        return true;
+    }
+    // Runs test_file() for each entry of fs.ListFiles(dir) ending in TEST_SUFFIX (suffix removed).
+    void test_files(const std::string& dir) {
+        fma_common::LocalFileSystem fs;
+        for (auto& file : fs.ListFiles(dir)) {
+            if (fma_common::EndsWith(file, TEST_SUFFIX)) {
+                test_file(file.substr(0, file.size() - TEST_SUFFIX.size()));
+            }
+        }
+    }
+    // Replay TEST_SUFFIX file into a REAL_SUFFIX file; if check_result, diff with RESULT_SUFFIX.
+    void test_file(const std::string& file_prefix, bool check_result = true) {
+        std::string test_file = file_prefix + TEST_SUFFIX;
+        std::string result_file = file_prefix + RESULT_SUFFIX;
+        std::string real_file = file_prefix + REAL_SUFFIX;
+        std::string line, query, result;
+        bool is_error = false;  // set by an ERROR_CMD_PREFIX line, cleared after the next query
+        fma_common::LocalFileSystem fs;
+        if (!fs.FileExists(test_file)) {
+            UT_ERR() << "test_file not exists: " << test_file;
+            UT_EXPECT_TRUE(false);
+            return;
+        }
+        std::ifstream test_file_in(test_file);
+        std::ofstream real_file_out(real_file);
+        init_db();
+        UT_DBG() << "test_file: " << test_file;
+        auto test_query_handle_result = [&]() {  // runs `query`, then appends query + result
+            UT_LOG() << "-----" << lgraph::ut::ToString(query_type_) << "-----";
+            UT_LOG() << query;
+            bool success;
+            if (query_type_ == lgraph::ut::QUERY_TYPE::CYPHER) {
+                success = test_cypher_case(query, result);
+            } else if (query_type_ == lgraph::ut::QUERY_TYPE::GQL) {
+                success = test_gql_case(query, result);
+            } else {
+                LOG_FATAL() << "unhandled query_type_: " << lgraph::ut::ToString(query_type_);
+                UT_EXPECT_TRUE(false);
+                return;
+            }
+            if (!success && !is_error) {  // a failure is tolerated only after an error directive
+                UT_EXPECT_TRUE(false);
+            }
+            real_file_out << query << std::endl;
+            real_file_out << result << std::endl;
+            query.clear();
+            is_error = false;
+        };
+        while (std::getline(test_file_in, line)) {
+            std::string line_t = fma_common::Strip(line, ' ');
+            bool start_with_comment_prefix = fma_common::StartsWith(line_t, COMMENT_PREFIX);
+            if (start_with_comment_prefix || line_t.empty()) {  // copied through unchanged
+                real_file_out << line << std::endl;
+                continue;
+            }
+            if (fma_common::StartsWith(line_t, ERROR_CMD_PREFIX)) {
+                real_file_out << line << std::endl;
+                is_error = true;
+                continue;
+            } else if (fma_common::StartsWith(line_t, LOAD_PROCEDURE_CMD_PREFIX)) {
+                // Load stored procedure
+                // Input format: -- loadProcedure name procedure_source_path [read_only=true]
+                // The default value for read_only is true; size() / 2 == 2 accepts 4 or 5 tokens.
+                real_file_out << line << std::endl;
+                auto args = fma_common::Split(line_t, " ");
+                if (args.size() / 2 == 2 && !args[2].empty() && !args[3].empty()) {
+                    load_procedure(args[2], args[3],
+                                   args.size() != 5 || args[4] == LOAD_PROCEDURE_READ_ONLY);
+                    continue;
+                }
+                UT_EXPECT_TRUE(false);  // NOTE(review): malformed line still falls into query
+            }
+            if (!query.empty()) {
+                query += "\n";
+            }
+            query += line;
+            if (fma_common::EndsWith(line_t, END_LINE_SUFFIX)) {  // statement complete: run it
+                test_query_handle_result();
+            }
+        }
+        if (!query.empty()) {  // last statement had no END_LINE_SUFFIX
+            test_query_handle_result();
+        }
+        test_file_in.close();
+        real_file_out.close();
+        if (!check_result) {
+            return;
+        }
+        if (diff_file(real_file, result_file)) {
+            fma_common::LocalFileSystem fs;  // NOTE(review): shadows the outer fs
+            fs.Remove(real_file);
+        } else {
+            UT_EXPECT_TRUE(false);
+        }
+    }
+    // Reads procedure_source_path, base64-encodes it and issues CALL db.plugin.loadPlugin.
+    void load_procedure(const std::string& name, const std::string& procedure_source_path,
+                        bool read_only = true) {
+        std::ifstream f;
+        f.open(procedure_source_path, std::ios::in);  // NOTE(review): open failure is not checked
+        std::string buf;
+        std::string text = "";
+        while (getline(f, buf)) {
+            text += buf;
+            text += "\n";
+        }
+        f.close();
+        std::string encoded = lgraph_api::encode_base64(text);
+        std::string result;
+        std::string procedure_version = "v1";
+        UT_EXPECT_TRUE(test_cypher_case(
+            FMA_FMT("CALL db.plugin.loadPlugin('CPP','{}','{}','CPP','{}', {}, '{}')", name,
+                    encoded, name, read_only ? "true" : "false", procedure_version),
+            result));
+        return;
+    }
+    // Public entry point: replays the "/cases/demo" case without diffing the result file.
+    void RunTestDemo() {
+        set_graph_type(_ut_graph_dataset_type);
+        set_query_type(_ut_query_type);
+        std::string dir = lgraph::ut::TEST_RESOURCE_DIRECTORY + "/cases/demo";
+        test_file(dir, false);
+    }
+};
+// Replays test_suite_dir_ + "/demo" with check_result = false (no result-file comparison).
+TEST_F(TestQueryCol, TestDemo) {
+    set_graph_type(_ut_graph_dataset_type);
+    set_query_type(_ut_query_type);
+    std::string dir = test_suite_dir_ + "/demo";
+    test_file(dir, false);
+}
+
diff --git a/test/test_query_compilation.cpp b/test/test_query_compilation.cpp
new file mode 100644
index 0000000000..321a575fbb
--- /dev/null
+++ b/test/test_query_compilation.cpp
@@ -0,0 +1,128 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <ostream>
+#include <sstream>
+#include <string>
+
+#include "cypher/experimental/data_type/field_data.h"
+#include "cypher/experimental/data_type/record.h"
+#include "cypher/experimental/expressions/cexpr.h"
+#include "cypher/parser/symbol_table.h"
+#include "cypher/execution_plan/runtime_context.h"
+#include "geax-front-end/ast/Ast.h"
+
+#include "blocks/c_code_generator.h"
+#include "builder/builder.h"
+#include "builder/builder_context.h"
+#include "builder/dyn_var.h"
+#include "builder/static_var.h"
+using builder::dyn_var;
+using builder::static_var;
+using cypher::compilation::CFieldData;
+using cypher::compilation::CScalarData;
+using cypher::compilation::CRecord;
+using cypher::compilation::CEntry;
+
+#include "gtest/gtest.h"
+
+#include "core/value.h"
+#include "./ut_utils.h"
+// Run `command` through popen() and return its captured stdout ("" when popen() fails).
+std::string execute(const std::string& command) {
+    std::string result;
+    FILE* pipe = popen(command.c_str(), "r");
+    if (!pipe) {
+        std::cerr << "popen() failed!" << std::endl;
+        return "";
+    }
+    char buf[128];
+    while (fgets(buf, sizeof(buf), pipe) != nullptr) {
+        result += buf;
+    }
+    pclose(pipe);
+    return result;
+}
+// Write func_body to test_add.cpp, build it with g++, run the binary and return its stdout.
+std::string execute_func(std::string &func_body) {
+    const std::string file_name = "test_add.cpp";
+    const std::string output_name = "test_add";
+    std::ofstream out_file(file_name);
+    if (!out_file) {
+        std::cerr << "Failed to open file for writing!" << std::endl;
+        return "";
+    }
+    out_file << func_body;
+    out_file.close();
+    // define and execute compiler commands
+    std::string compile_cmd = "g++ " + file_name + " -o " + output_name;
+    int compile_res = system(compile_cmd.c_str());
+    if (compile_res != 0) {
+        std::cerr << "Compilation failed!" << std::endl;
+        return "";
+    }
+    // run the binary g++ just produced (-o output_name); "./a" is never created
+    std::string output = execute("./" + output_name);
+    // delete both files; std::remove() returns 0 on success, so evaluate each call separately
+    const bool source_left = std::remove(file_name.c_str()) != 0;
+    if (std::remove(output_name.c_str()) != 0 || source_left) {
+        std::cerr << "Failed to delete files: " << file_name << ", " << output_name << std::endl;
+    }
+    return output;
+}
+class TestQueryCompilation : public TuGraphTest {};
+// Function handed to extract_function_ast(): evaluates a + b (both 10) via ExprEvaluator.
+dyn_var add(void) {
+    cypher::SymbolTable sym_tab;
+    // NOTE(review): dyn_var/static_var template arguments are absent in this copy - confirm.
+    CFieldData a(std::move(CScalarData(10)));
+    geax::frontend::Ref ref1;
+    ref1.setName(std::string("a"));
+    sym_tab.symbols.emplace("a",
+        cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL));
+
+    CFieldData b(static_var(10));
+    geax::frontend::Ref ref2;
+    ref2.setName(std::string("b"));
+    sym_tab.symbols.emplace("b",
+        cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL));
+
+    geax::frontend::BAdd add;
+    add.setLeft((geax::frontend::Expr*)&ref1);
+    add.setRight((geax::frontend::Expr*)&ref2);
+    CRecord record;
+    record.values.push_back(CEntry(a));
+    record.values.push_back(CEntry(b));
+
+    cypher::compilation::ExprEvaluator evaluator(&add, &sym_tab);
+    cypher::RTContext ctx;
+    return evaluator.Evaluate(&ctx, &record).constant_.scalar.Int64();
+}
+// Generates C code for add(), wraps it in a main() that prints add(), and checks for "20".
+TEST_F(TestQueryCompilation, Add) {
+    builder::builder_context context;
+    auto ast = context.extract_function_ast(add, "add");
+    std::ostringstream oss;
+    oss << "#include <iostream>\n";
+    block::c_code_generator::generate_code(ast, oss, 0);
+    oss << "int main() {\n std::cout << add();\n return 0;\n}";
+    std::string body = oss.str();
+    std::cout <<"Generated code: \n" << body << std::endl;
+    std::string res = execute_func(body);
+    ASSERT_EQ(res, "20");
+}
diff --git a/toolkits/CMakeLists.txt b/toolkits/CMakeLists.txt
index f21862c42c..11edbe91cd 100644
--- a/toolkits/CMakeLists.txt
+++ b/toolkits/CMakeLists.txt
@@ -1,6 +1,20 @@
 cmake_minimum_required(VERSION 3.13)
 project(TuGraph C CXX)
 
+############### lgraph_compilation ################
+# 
set(TARGET_LGRAPH_COMPILATION lgraph_compilation)
+
+# add_executable(${TARGET_LGRAPH_COMPILATION}
+# lgraph_compilation.cpp)
+
+# target_include_directories(${TARGET_LGRAPH_COMPILATION} PUBLIC
+# ${CMAKE_SOURCE_DIR}/deps/buildit/include)
+# add_dependencies(${TARGET_LGRAPH_COMPILATION} buildit)
+# target_link_libraries(${TARGET_LGRAPH_COMPILATION}
+# lgraph_cypher_lib
+        # ${CMAKE_SOURCE_DIR}/deps/buildit/build/libbuildit.a
+        # librocksdb.a)
+
 ############### lgraph_import ######################
 set(TARGET_LGRAPH_IMPORT lgraph_import)
 
diff --git a/toolkits/lgraph_compilation.cpp b/toolkits/lgraph_compilation.cpp
new file mode 100644
index 0000000000..4d169caf55
--- /dev/null
+++ b/toolkits/lgraph_compilation.cpp
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2022 AntGroup CO., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include  // NOTE(review): header name is missing in this copy - restore from upstream
+#include "cypher/experimental/data_type/field_data.h"
+#include "cypher/experimental/data_type/record.h"
+#include "cypher/experimental/expressions/cexpr.h"
+#include "cypher/parser/symbol_table.h"
+#include "cypher/execution_plan/runtime_context.h"
+#include "geax-front-end/ast/Ast.h"
+#include "blocks/c_code_generator.h"
+#include "builder/builder.h"
+#include "builder/builder_context.h"
+#include "builder/dyn_var.h"
+using namespace cypher::compilation;
+using builder::static_var;
+using builder::dyn_var;
+// Not called from main(): puts 10 into each of two std::variant objects and returns their sum.
+dyn_var bar(void) {  // NOTE(review): template arguments are missing throughout this copy
+    std::variant, static_var> a;
+    std::variant, dyn_var> b;
+    a = (std::variant, static_var>)static_var(10);
+    b = dyn_var(10);
+    auto res = std::get>(a) + std::get>(b);
+    return res;
+}
+// Evaluates the BAdd expression a + b against a two-entry record; returns its Int64 scalar.
+dyn_var foo(void) {
+    cypher::SymbolTable sym_tab;
+
+    CFieldData a(std::move(CScalarData(10)));
+    geax::frontend::Ref ref1;
+    ref1.setName(std::string("a"));
+    sym_tab.symbols.emplace("a",
+        cypher::SymbolNode(0, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL));
+
+    CFieldData b(static_var(10));
+    geax::frontend::Ref ref2;
+    ref2.setName(std::string("b"));
+    sym_tab.symbols.emplace("b",
+        cypher::SymbolNode(1, cypher::SymbolNode::CONSTANT, cypher::SymbolNode::LOCAL));
+
+    geax::frontend::BAdd add;
+    add.setLeft((geax::frontend::Expr*)&ref1);
+    add.setRight((geax::frontend::Expr*)&ref2);
+    CRecord record;
+    record.values.push_back(CEntry(a));  // record entry for symbol "a"
+    record.values.push_back(CEntry(b));  // record entry for symbol "b"
+
+    ExprEvaluator evaluator(&add, &sym_tab);
+    cypher::RTContext ctx;
+    return evaluator.Evaluate(&ctx, &record).constant_.scalar.Int64();
+}
+// Writes the code generated for foo, followed by a main() that prints foo(), to stdout.
+int main() {
+    builder::builder_context context;
+    std::cout << "#include " << std::endl;  // NOTE(review): header name missing in this copy
+    block::c_code_generator::generate_code(context.extract_function_ast(foo, "foo"), std::cout, 0);
+    std::cout << "int main() {\n std::cout << foo() << std::endl;\n return 0;\n}";
+    return 0;
+}