Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix calcite plan overlapping #1476

Open
wants to merge 7 commits into
base: branch-21.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory;
import org.apache.calcite.sql.parser.SqlParseException;
Expand Down Expand Up @@ -271,22 +273,14 @@ public RelationalAlgebraGenerator(FrameworkConfig frameworkConfig, HepProgram he
String response = "";

try {
response = RelOptUtil.toString(getRelationalAlgebra(sql));
RelNode optimizedPlan = getRelationalAlgebra(sql);
response = RelOptUtil.dumpPlan("", optimizedPlan, SqlExplainFormat.TEXT, SqlExplainLevel.NON_COST_ATTRIBUTES);
}catch(SqlValidationException ex){
//System.out.println(ex.getMessage());
//System.out.println("Found validation err!");
throw ex;
//return "fail: \n " + ex.getMessage();
}catch(SqlSyntaxException ex){
//System.out.println(ex.getMessage());
//System.out.println("Found syntax err!");
throw ex;
//return "fail: \n " + ex.getMessage();
} catch(Exception ex) {
//System.out.println(ex.toString());
//System.out.println(ex.getMessage());
ex.printStackTrace();

LOGGER.error(ex.getMessage());
return "fail: \n " + ex.getMessage();
}
Expand Down
4 changes: 3 additions & 1 deletion engine/src/bmr/MemoryMonitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ namespace ral {
iter != starting_node->kernel_unit->output_.cache_machines_.end(); iter++) {
size_t amount_downgraded = 0;
do {
amount_downgraded = iter->second->downgradeCacheData();
for (auto cache : iter->second){
amount_downgraded = cache->downgradeCacheData();
}
} while (amount_downgraded > 0 && need_to_free_memory()); // if amount_downgraded is 0 then there was nothing left to downgrade
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct tree_processor {
std::vector<ral::io::Schema> schemas;
std::vector<std::string> table_names;
std::vector<std::string> table_scans;
std::map<int, std::shared_ptr<node>> mapped_ids; //TODO to clean this state
const bool transform_operators_bigger_than_gpu = false;

tree_processor( node root,
Expand Down Expand Up @@ -133,11 +134,19 @@ struct tree_processor {
root_ptr->level = level;
root_ptr->kernel_unit = make_kernel(kernel_id, expr, query_graph);
kernel_id++;
for (auto &child : p_tree.get_child("children")) {
auto child_node_ptr = std::make_shared<node>();
root_ptr->children.push_back(child_node_ptr);
kernel_id = expr_tree_from_json(kernel_id, child.second, child_node_ptr.get(), level + 1, query_graph);
auto step_id = get_id_from_expression(expr);
if(mapped_ids.find(step_id) == mapped_ids.end()){
for (auto &child : p_tree.get_child("children")) {
auto child_node_ptr = std::make_shared<node>();
root_ptr->children.push_back(child_node_ptr);
kernel_id = expr_tree_from_json(kernel_id, child.second, child_node_ptr.get(), level + 1, query_graph);
}
std::shared_ptr<node> my_ptr(root_ptr);
mapped_ids[step_id] = my_ptr;
}else{ //This node was already processed
root_ptr->children.push_back(mapped_ids[step_id]);
}

return kernel_id;
}

Expand Down
79 changes: 56 additions & 23 deletions engine/src/execution_graph/logic_controllers/taskflow/port.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

namespace ral {
namespace cache {
// Registers `port_name` on this port with a null cache-machine placeholder.
// NOTE(review): this span appears to interleave the pre-change one-line
// definition and its post-change replacement from a diff view — confirm
// against the real file which one is live.
void port::register_port(std::string port_name) { cache_machines_[port_name] = nullptr; }
// Vector-based variant: every call appends another nullptr entry for `port_name`.
void port::register_port(std::string port_name) {
cache_machines_[port_name].push_back(nullptr);
}

std::shared_ptr<CacheMachine> & port::get_cache(const std::string & port_name) {
if(port_name.length() == 0) {
// NOTE: id is the `default` cache_machine name
auto id = std::to_string(kernel_->get_id());
auto it = cache_machines_.find(id);
return it->second;
return it->second[0]; //default
}
auto it = cache_machines_.find(port_name);
if (it == cache_machines_.end()){
Expand All @@ -24,22 +26,31 @@ std::shared_ptr<CacheMachine> & port::get_cache(const std::string & port_name) {
logger->error("|||{info}|||||","info"_a=log_detail);
}
}
return it->second;
return it->second[0]; //default
}

// Associates `cache_machine` with `port_name` on this port.
// NOTE(review): this span appears to interleave pre-change and post-change
// diff lines (the direct assignment vs. the vector handling below) — confirm
// against the real file.
void port::register_cache(const std::string & port_name, std::shared_ptr<CacheMachine> cache_machine) {
this->cache_machines_[port_name] = cache_machine;
// Vector-based variant: fill slot 0, creating it when the vector is still empty.
if(this->cache_machines_[port_name].empty()){
this->cache_machines_[port_name].push_back(cache_machine); //todo
}
else{
// Only element 0 is ever overwritten here; any further elements are left untouched.
this->cache_machines_[port_name][0] = cache_machine;
}
}
// Calls finish() on every cache machine held in cache_machines_.
// NOTE(review): this span appears to interleave pre-change and post-change
// diff lines (single-cache loop vs. nested vector loop) — confirm against the
// real file.
// NOTE(review): register_port stores nullptr entries; presumably every entry
// has been replaced by register_cache before this runs — verify, since a null
// entry would be dereferenced here.
void port::finish() {
for(auto it : cache_machines_) {
it.second->finish();
// `auto` by value copies each map entry (key plus vector of shared_ptr) per iteration.
for (auto cache_vector : cache_machines_){
for (auto cache : cache_vector.second){
cache->finish();
}
}
}

// Returns true only when every cache machine in cache_machines_ reports
// is_finished(); returns false at the first one that does not.
// NOTE(review): this span appears to interleave pre-change and post-change
// diff lines (single-cache loop vs. nested vector loop) — confirm against the
// real file.
bool port::all_finished(){
for (auto cache : cache_machines_){
if (!cache.second->is_finished())
return false;
// `auto` by value copies each map entry per iteration.
for (auto cache_vector : cache_machines_){
for (auto cache : cache_vector.second){
if (!cache->is_finished())
return false;
}
}
// No unfinished cache was found (also the result when there are no caches).
return true;
}
Expand All @@ -48,33 +59,47 @@ bool port::is_finished(const std::string & port_name){
if(port_name.length() == 0) {
// NOTE: id is the `default` cache_machine name
auto id = std::to_string(kernel_->get_id());
auto it = cache_machines_.find(id);
return it->second->is_finished();
auto cache_vector = cache_machines_.find(id);
bool is_finished = true;
for (auto cache : cache_vector->second){
is_finished = is_finished && cache->is_finished();
}
return is_finished;
}
auto it = cache_machines_.find(port_name);
return it->second->is_finished();
bool is_finished = true;
auto cache_vector = cache_machines_.find(port_name);
for (auto cache : cache_vector->second){
is_finished = is_finished && cache->is_finished();
}
return is_finished;
}

// Sums get_num_bytes_added() over every cache machine in cache_machines_.
// NOTE(review): this span appears to interleave pre-change and post-change
// diff lines (single-cache loop vs. nested vector loop) — confirm against the
// real file.
uint64_t port::total_bytes_added(){
uint64_t total = 0;
for (auto cache : cache_machines_){
total += cache.second->get_num_bytes_added();
// `auto` by value copies each map entry per iteration.
for (auto cache_vector : cache_machines_){
for (auto cache : cache_vector.second){
total += cache->get_num_bytes_added();
}
}
return total;
}

// Sums get_num_rows_added() over every cache machine in cache_machines_.
// NOTE(review): this span appears to interleave pre-change and post-change
// diff lines (single-cache loop vs. nested vector loop) — confirm against the
// real file.
uint64_t port::total_rows_added(){
uint64_t total = 0;
for (auto cache : cache_machines_){
total += cache.second->get_num_rows_added();
// `auto` by value copies each map entry per iteration.
for (auto cache_vector : cache_machines_){
for (auto cache : cache_vector.second){
total += cache->get_num_rows_added();
}
}
return total;
}

// Sums get_num_batches_added() over every cache machine in cache_machines_.
// NOTE(review): this span appears to interleave pre-change and post-change
// diff lines (single-cache loop vs. nested vector loop) — confirm against the
// real file.
uint64_t port::total_batches_added(){
uint64_t total = 0;
for (auto cache : cache_machines_){
total += cache.second->get_num_batches_added();
// `auto` by value copies each map entry per iteration.
for (auto cache_vector : cache_machines_){
for (auto cache : cache_vector.second){
total += cache->get_num_batches_added();
}
}
return total;
}
Expand All @@ -83,11 +108,19 @@ uint64_t port::get_num_rows_added(const std::string & port_name){
if(port_name.length() == 0) {
// NOTE: id is the `default` cache_machine name
auto id = std::to_string(kernel_->get_id());
auto it = cache_machines_.find(id);
return it->second->get_num_rows_added();
uint64_t num_rows_added = 0;
auto cache_vector = cache_machines_.find(id);
for (auto cache : cache_vector->second){
num_rows_added += cache->get_num_rows_added();
}
return num_rows_added;
}
auto it = cache_machines_.find(port_name);
return it->second->get_num_rows_added();
uint64_t num_rows_added = 0;
auto cache_vector = cache_machines_.find(port_name);
for (auto cache : cache_vector->second){
num_rows_added += cache->get_num_rows_added();
}
return num_rows_added;
}


Expand Down
4 changes: 2 additions & 2 deletions engine/src/execution_graph/logic_controllers/taskflow/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class port {

void finish();

// Returns a mutable reference to the cache machine stored under `port_name`.
// NOTE(review): these two lines appear to be the pre-change and post-change
// versions of the same member from a diff view — confirm which one is live.
// NOTE(review): in the vector-valued form, cache_machines_[port_name] inserts
// an empty vector for an unknown name, so the trailing [0] would index an
// empty vector — verify every caller registers the port first.
std::shared_ptr<CacheMachine> & operator[](const std::string & port_name) { return cache_machines_[port_name]; }
std::shared_ptr<CacheMachine> & operator[](const std::string & port_name) { return cache_machines_[port_name][0]; } //todo

bool all_finished();

Expand All @@ -86,7 +86,7 @@ class port {

public:
kernel * kernel_;
std::map<std::string, std::shared_ptr<CacheMachine>> cache_machines_;
std::map<std::string, std::vector<std::shared_ptr<CacheMachine>>> cache_machines_;
};

} // namespace cache
Expand Down
10 changes: 10 additions & 0 deletions engine/src/parser/expression_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,16 @@ std::string get_named_expression(const std::string & query_part, const std::stri
return query_part.substr(start_position, end_position - start_position);
}

/**
 * @brief Extracts the integer id that follows the "), id = " marker in an expression.
 *
 * @param query_part Expression text to scan.
 * @return The integer parsed from the text right after the first occurrence of
 *         the marker, or -1 when the marker does not occur in `query_part`.
 * @throws std::invalid_argument, std::out_of_range Propagated from std::stoi
 *         when the text after the marker does not start with a representable int.
 */
int get_id_from_expression(const std::string & query_part){
	const std::string prefix_id_pattern = "), id = ";

	// Search once and keep the offset as size_t: the previous code ran find()
	// twice and narrowed the resulting position to int.
	const std::size_t marker_position = query_part.find(prefix_id_pattern);
	if(marker_position == std::string::npos) {
		return -1; // pattern not found
	}

	// std::stoi parses the leading integer and ignores any trailing characters.
	return std::stoi(query_part.substr(marker_position + prefix_id_pattern.size()));
}

std::vector<int> get_projections(const std::string & query_part) {

// On Calcite, the select count(*) case is represented with
Expand Down
1 change: 1 addition & 0 deletions engine/src/parser/expression_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ bool is_var_column(const std::string& token);
bool is_inequality(const std::string& token);

std::string get_named_expression(const std::string & query_part, const std::string & expression_name);
int get_id_from_expression(const std::string & query_part);

std::vector<int> get_projections(const std::string & query_part);

Expand Down